Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- .venv/lib/python3.11/site-packages/ray/air/__init__.py +22 -0
- .venv/lib/python3.11/site-packages/ray/air/config.py +766 -0
- .venv/lib/python3.11/site-packages/ray/air/constants.py +94 -0
- .venv/lib/python3.11/site-packages/ray/air/data_batch_type.py +11 -0
- .venv/lib/python3.11/site-packages/ray/air/execution/__init__.py +12 -0
- .venv/lib/python3.11/site-packages/ray/air/execution/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/air/execution/resources/__init__.py +12 -0
- .venv/lib/python3.11/site-packages/ray/air/execution/resources/__pycache__/fixed.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/air/execution/resources/__pycache__/placement_group.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/air/execution/resources/__pycache__/request.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/air/execution/resources/__pycache__/resource_manager.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/air/execution/resources/fixed.py +147 -0
- .venv/lib/python3.11/site-packages/ray/air/execution/resources/placement_group.py +214 -0
- .venv/lib/python3.11/site-packages/ray/air/execution/resources/request.py +255 -0
- .venv/lib/python3.11/site-packages/ray/air/execution/resources/resource_manager.py +155 -0
- .venv/lib/python3.11/site-packages/ray/air/result.py +283 -0
- .venv/lib/python3.11/site-packages/ray/air/session.py +1 -0
- .venv/lib/python3.11/site-packages/ray/air/util/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/air/util/tensor_extensions/pandas.py +1451 -0
- .venv/lib/python3.11/site-packages/ray/air/util/torch_dist.py +191 -0
- .venv/lib/python3.11/site-packages/ray/air/util/transform_pyarrow.py +39 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/__pycache__/deployment_state.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/common.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/handle_noop_latency.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/handle_throughput.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/http_noop_latency.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/microbenchmark.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/proxy_benchmark.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/common.py +276 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/handle_noop_latency.py +34 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/handle_throughput.py +62 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/http_noop_latency.py +32 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/microbenchmark.py +182 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/proxy_benchmark.py +294 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/serialization/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/serialization/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/serialization/__pycache__/common.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/serialization/__pycache__/serialization_benchmark.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/serialization/common.py +29 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/serialization/serialization_benchmark.py +163 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__pycache__/common.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__pycache__/streaming_core_throughput.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__pycache__/streaming_grpc_throughput.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__pycache__/streaming_handle_throughput.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__pycache__/streaming_http_throughput.cpython-311.pyc +0 -0
.gitattributes
CHANGED
|
@@ -155,3 +155,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
|
|
| 155 |
.venv/lib/python3.11/site-packages/ray/data/__pycache__/dataset.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 156 |
.venv/lib/python3.11/site-packages/ray/data/__pycache__/read_api.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 157 |
.venv/lib/python3.11/site-packages/numpy/ma/tests/__pycache__/test_extras.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 155 |
.venv/lib/python3.11/site-packages/ray/data/__pycache__/dataset.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 156 |
.venv/lib/python3.11/site-packages/ray/data/__pycache__/read_api.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 157 |
.venv/lib/python3.11/site-packages/numpy/ma/tests/__pycache__/test_extras.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 158 |
+
.venv/lib/python3.11/site-packages/ray/serve/_private/__pycache__/deployment_state.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/ray/air/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.air.config import (
|
| 2 |
+
CheckpointConfig,
|
| 3 |
+
DatasetConfig,
|
| 4 |
+
FailureConfig,
|
| 5 |
+
RunConfig,
|
| 6 |
+
ScalingConfig,
|
| 7 |
+
)
|
| 8 |
+
from ray.air.data_batch_type import DataBatchType
|
| 9 |
+
from ray.air.execution.resources.request import AcquiredResources, ResourceRequest
|
| 10 |
+
from ray.air.result import Result
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"DataBatchType",
|
| 14 |
+
"RunConfig",
|
| 15 |
+
"Result",
|
| 16 |
+
"ScalingConfig",
|
| 17 |
+
"DatasetConfig",
|
| 18 |
+
"FailureConfig",
|
| 19 |
+
"CheckpointConfig",
|
| 20 |
+
"AcquiredResources",
|
| 21 |
+
"ResourceRequest",
|
| 22 |
+
]
|
.venv/lib/python3.11/site-packages/ray/air/config.py
ADDED
|
@@ -0,0 +1,766 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from collections import Counter, defaultdict
|
| 3 |
+
from dataclasses import _MISSING_TYPE, dataclass, fields
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import (
|
| 6 |
+
TYPE_CHECKING,
|
| 7 |
+
Any,
|
| 8 |
+
Callable,
|
| 9 |
+
Dict,
|
| 10 |
+
List,
|
| 11 |
+
Mapping,
|
| 12 |
+
Optional,
|
| 13 |
+
Tuple,
|
| 14 |
+
Union,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
import pyarrow.fs
|
| 18 |
+
|
| 19 |
+
from ray._private.ray_constants import RESOURCE_CONSTRAINT_PREFIX
|
| 20 |
+
from ray._private.storage import _get_storage_uri
|
| 21 |
+
from ray._private.thirdparty.tabulate.tabulate import tabulate
|
| 22 |
+
from ray.data.preprocessor import Preprocessor
|
| 23 |
+
from ray.util.annotations import Deprecated, PublicAPI
|
| 24 |
+
from ray.widgets import Template, make_table_html_repr
|
| 25 |
+
|
| 26 |
+
if TYPE_CHECKING:
|
| 27 |
+
from ray.train import SyncConfig
|
| 28 |
+
from ray.tune.callback import Callback
|
| 29 |
+
from ray.tune.execution.placement_groups import PlacementGroupFactory
|
| 30 |
+
from ray.tune.experimental.output import AirVerbosity
|
| 31 |
+
from ray.tune.search.sample import Domain
|
| 32 |
+
from ray.tune.stopper import Stopper
|
| 33 |
+
from ray.tune.utils.log import Verbosity
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Dict[str, List] is to support `tune.grid_search`:
|
| 37 |
+
# TODO(sumanthratna/matt): Upstream this to Tune.
|
| 38 |
+
SampleRange = Union["Domain", Dict[str, List]]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
MAX = "max"
|
| 42 |
+
MIN = "min"
|
| 43 |
+
_DEPRECATED_VALUE = "DEPRECATED"
|
| 44 |
+
|
| 45 |
+
DATASET_CONFIG_DEPRECATION_MSG = """
|
| 46 |
+
Use `ray.train.DataConfig` instead of DatasetConfig to configure data ingest for training. See https://docs.ray.io/en/releases-2.6.3/ray-air/check-ingest.html#migrating-from-the-legacy-datasetconfig-api for more details.
|
| 47 |
+
""" # noqa: E501
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
logger = logging.getLogger(__name__)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _repr_dataclass(obj, *, default_values: Optional[Dict[str, Any]] = None) -> str:
|
| 54 |
+
"""A utility function to elegantly represent dataclasses.
|
| 55 |
+
|
| 56 |
+
In contrast to the default dataclass `__repr__`, which shows all parameters, this
|
| 57 |
+
function only shows parameters with non-default values.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
obj: The dataclass to represent.
|
| 61 |
+
default_values: An optional dictionary that maps field names to default values.
|
| 62 |
+
Use this parameter to specify default values that are generated dynamically
|
| 63 |
+
(e.g., in `__post_init__` or by a `default_factory`). If a default value
|
| 64 |
+
isn't specified in `default_values`, then the default value is inferred from
|
| 65 |
+
the `dataclass`.
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
A representation of the dataclass.
|
| 69 |
+
"""
|
| 70 |
+
if default_values is None:
|
| 71 |
+
default_values = {}
|
| 72 |
+
|
| 73 |
+
non_default_values = {} # Maps field name to value.
|
| 74 |
+
|
| 75 |
+
def equals(value, default_value):
|
| 76 |
+
# We need to special case None because of a bug in pyarrow:
|
| 77 |
+
# https://github.com/apache/arrow/issues/38535
|
| 78 |
+
if value is None and default_value is None:
|
| 79 |
+
return True
|
| 80 |
+
if value is None or default_value is None:
|
| 81 |
+
return False
|
| 82 |
+
return value == default_value
|
| 83 |
+
|
| 84 |
+
for field in fields(obj):
|
| 85 |
+
value = getattr(obj, field.name)
|
| 86 |
+
default_value = default_values.get(field.name, field.default)
|
| 87 |
+
is_required = isinstance(field.default, _MISSING_TYPE)
|
| 88 |
+
if is_required or not equals(value, default_value):
|
| 89 |
+
non_default_values[field.name] = value
|
| 90 |
+
|
| 91 |
+
string = f"{obj.__class__.__name__}("
|
| 92 |
+
string += ", ".join(
|
| 93 |
+
f"{name}={value!r}" for name, value in non_default_values.items()
|
| 94 |
+
)
|
| 95 |
+
string += ")"
|
| 96 |
+
|
| 97 |
+
return string
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
@dataclass
|
| 101 |
+
@PublicAPI(stability="stable")
|
| 102 |
+
class ScalingConfig:
|
| 103 |
+
"""Configuration for scaling training.
|
| 104 |
+
|
| 105 |
+
For more details, see :ref:`train_scaling_config`.
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
trainer_resources: Resources to allocate for the training coordinator.
|
| 109 |
+
The training coordinator launches the worker group and executes
|
| 110 |
+
the training function per worker, and this process does NOT require
|
| 111 |
+
GPUs. The coordinator is always scheduled on the same node as the
|
| 112 |
+
rank 0 worker, so one example use case is to set a minimum amount
|
| 113 |
+
of resources (e.g. CPU memory) required by the rank 0 node.
|
| 114 |
+
By default, this assigns 1 CPU to the training coordinator.
|
| 115 |
+
num_workers: The number of workers (Ray actors) to launch.
|
| 116 |
+
Each worker will reserve 1 CPU by default. The number of CPUs
|
| 117 |
+
reserved by each worker can be overridden with the
|
| 118 |
+
``resources_per_worker`` argument.
|
| 119 |
+
use_gpu: If True, training will be done on GPUs (1 per worker).
|
| 120 |
+
Defaults to False. The number of GPUs reserved by each
|
| 121 |
+
worker can be overridden with the ``resources_per_worker``
|
| 122 |
+
argument.
|
| 123 |
+
resources_per_worker: If specified, the resources
|
| 124 |
+
defined in this Dict is reserved for each worker.
|
| 125 |
+
Define the ``"CPU"`` key (case-sensitive) to
|
| 126 |
+
override the number of CPUs used by each worker.
|
| 127 |
+
This can also be used to request :ref:`custom resources <custom-resources>`.
|
| 128 |
+
placement_strategy: The placement strategy to use for the
|
| 129 |
+
placement group of the Ray actors. See :ref:`Placement Group
|
| 130 |
+
Strategies <pgroup-strategy>` for the possible options.
|
| 131 |
+
accelerator_type: [Experimental] If specified, Ray Train will launch the
|
| 132 |
+
training coordinator and workers on the nodes with the specified type
|
| 133 |
+
of accelerators.
|
| 134 |
+
See :ref:`the available accelerator types <accelerator_types>`.
|
| 135 |
+
Ensure that your cluster has instances with the specified accelerator type
|
| 136 |
+
or is able to autoscale to fulfill the request.
|
| 137 |
+
|
| 138 |
+
Example:
|
| 139 |
+
|
| 140 |
+
.. code-block:: python
|
| 141 |
+
|
| 142 |
+
from ray.train import ScalingConfig
|
| 143 |
+
scaling_config = ScalingConfig(
|
| 144 |
+
# Number of distributed workers.
|
| 145 |
+
num_workers=2,
|
| 146 |
+
# Turn on/off GPU.
|
| 147 |
+
use_gpu=True,
|
| 148 |
+
# Assign extra CPU/GPU/custom resources per worker.
|
| 149 |
+
resources_per_worker={"GPU": 1, "CPU": 1, "memory": 1e9, "custom": 1.0},
|
| 150 |
+
# Try to schedule workers on different nodes.
|
| 151 |
+
placement_strategy="SPREAD",
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
"""
|
| 155 |
+
|
| 156 |
+
trainer_resources: Optional[Union[Dict, SampleRange]] = None
|
| 157 |
+
num_workers: Union[int, SampleRange] = 1
|
| 158 |
+
use_gpu: Union[bool, SampleRange] = False
|
| 159 |
+
resources_per_worker: Optional[Union[Dict, SampleRange]] = None
|
| 160 |
+
placement_strategy: Union[str, SampleRange] = "PACK"
|
| 161 |
+
accelerator_type: Optional[str] = None
|
| 162 |
+
|
| 163 |
+
def __post_init__(self):
|
| 164 |
+
if self.resources_per_worker:
|
| 165 |
+
if not self.use_gpu and self.num_gpus_per_worker > 0:
|
| 166 |
+
raise ValueError(
|
| 167 |
+
"`use_gpu` is False but `GPU` was found in "
|
| 168 |
+
"`resources_per_worker`. Either set `use_gpu` to True or "
|
| 169 |
+
"remove `GPU` from `resources_per_worker."
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
if self.use_gpu and self.num_gpus_per_worker == 0:
|
| 173 |
+
raise ValueError(
|
| 174 |
+
"`use_gpu` is True but `GPU` is set to 0 in "
|
| 175 |
+
"`resources_per_worker`. Either set `use_gpu` to False or "
|
| 176 |
+
"request a positive number of `GPU` in "
|
| 177 |
+
"`resources_per_worker."
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
def __repr__(self):
|
| 181 |
+
return _repr_dataclass(self)
|
| 182 |
+
|
| 183 |
+
def _repr_html_(self) -> str:
|
| 184 |
+
return make_table_html_repr(obj=self, title=type(self).__name__)
|
| 185 |
+
|
| 186 |
+
def __eq__(self, o: "ScalingConfig") -> bool:
|
| 187 |
+
if not isinstance(o, type(self)):
|
| 188 |
+
return False
|
| 189 |
+
return self.as_placement_group_factory() == o.as_placement_group_factory()
|
| 190 |
+
|
| 191 |
+
@property
|
| 192 |
+
def _resources_per_worker_not_none(self):
|
| 193 |
+
if self.resources_per_worker is None:
|
| 194 |
+
if self.use_gpu:
|
| 195 |
+
# Note that we don't request any CPUs, which avoids possible
|
| 196 |
+
# scheduling contention. Generally nodes have many more CPUs than
|
| 197 |
+
# GPUs, so not requesting a CPU does not lead to oversubscription.
|
| 198 |
+
resources_per_worker = {"GPU": 1}
|
| 199 |
+
else:
|
| 200 |
+
resources_per_worker = {"CPU": 1}
|
| 201 |
+
else:
|
| 202 |
+
resources_per_worker = {
|
| 203 |
+
k: v for k, v in self.resources_per_worker.items() if v != 0
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
if self.use_gpu:
|
| 207 |
+
resources_per_worker.setdefault("GPU", 1)
|
| 208 |
+
|
| 209 |
+
if self.accelerator_type:
|
| 210 |
+
accelerator = f"{RESOURCE_CONSTRAINT_PREFIX}{self.accelerator_type}"
|
| 211 |
+
resources_per_worker.setdefault(accelerator, 0.001)
|
| 212 |
+
return resources_per_worker
|
| 213 |
+
|
| 214 |
+
@property
|
| 215 |
+
def _trainer_resources_not_none(self):
|
| 216 |
+
if self.trainer_resources is None:
|
| 217 |
+
if self.num_workers:
|
| 218 |
+
# For Google Colab, don't allocate resources to the base Trainer.
|
| 219 |
+
# Colab only has 2 CPUs, and because of this resource scarcity,
|
| 220 |
+
# we have to be careful on where we allocate resources. Since Colab
|
| 221 |
+
# is not distributed, the concern about many parallel Ray Tune trials
|
| 222 |
+
# leading to all Trainers being scheduled on the head node if we set
|
| 223 |
+
# `trainer_resources` to 0 is no longer applicable.
|
| 224 |
+
try:
|
| 225 |
+
import google.colab # noqa: F401
|
| 226 |
+
|
| 227 |
+
trainer_num_cpus = 0
|
| 228 |
+
except ImportError:
|
| 229 |
+
trainer_num_cpus = 1
|
| 230 |
+
else:
|
| 231 |
+
# If there are no additional workers, then always reserve 1 CPU for
|
| 232 |
+
# the Trainer.
|
| 233 |
+
trainer_num_cpus = 1
|
| 234 |
+
|
| 235 |
+
trainer_resources = {"CPU": trainer_num_cpus}
|
| 236 |
+
else:
|
| 237 |
+
trainer_resources = {
|
| 238 |
+
k: v for k, v in self.trainer_resources.items() if v != 0
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
return trainer_resources
|
| 242 |
+
|
| 243 |
+
@property
|
| 244 |
+
def total_resources(self):
|
| 245 |
+
"""Map of total resources required for the trainer."""
|
| 246 |
+
total_resource_map = defaultdict(float, self._trainer_resources_not_none)
|
| 247 |
+
for k, value in self._resources_per_worker_not_none.items():
|
| 248 |
+
total_resource_map[k] += value * self.num_workers
|
| 249 |
+
return dict(total_resource_map)
|
| 250 |
+
|
| 251 |
+
@property
|
| 252 |
+
def num_cpus_per_worker(self):
|
| 253 |
+
"""The number of CPUs to set per worker."""
|
| 254 |
+
return self._resources_per_worker_not_none.get("CPU", 0)
|
| 255 |
+
|
| 256 |
+
@property
|
| 257 |
+
def num_gpus_per_worker(self):
|
| 258 |
+
"""The number of GPUs to set per worker."""
|
| 259 |
+
return self._resources_per_worker_not_none.get("GPU", 0)
|
| 260 |
+
|
| 261 |
+
@property
|
| 262 |
+
def additional_resources_per_worker(self):
|
| 263 |
+
"""Resources per worker, not including CPU or GPU resources."""
|
| 264 |
+
return {
|
| 265 |
+
k: v
|
| 266 |
+
for k, v in self._resources_per_worker_not_none.items()
|
| 267 |
+
if k not in ["CPU", "GPU"]
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
def as_placement_group_factory(self) -> "PlacementGroupFactory":
|
| 271 |
+
"""Returns a PlacementGroupFactory to specify resources for Tune."""
|
| 272 |
+
from ray.tune.execution.placement_groups import PlacementGroupFactory
|
| 273 |
+
|
| 274 |
+
trainer_bundle = self._trainer_resources_not_none
|
| 275 |
+
worker_bundle = self._resources_per_worker_not_none
|
| 276 |
+
|
| 277 |
+
# Colocate Trainer and rank0 worker by merging their bundles
|
| 278 |
+
# Note: This empty bundle is required so that the Tune actor manager schedules
|
| 279 |
+
# the Trainable onto the combined bundle while taking none of its resources,
|
| 280 |
+
# rather than a non-empty head bundle.
|
| 281 |
+
combined_bundle = dict(Counter(trainer_bundle) + Counter(worker_bundle))
|
| 282 |
+
bundles = [{}, combined_bundle] + [worker_bundle] * (self.num_workers - 1)
|
| 283 |
+
return PlacementGroupFactory(bundles, strategy=self.placement_strategy)
|
| 284 |
+
|
| 285 |
+
@classmethod
|
| 286 |
+
def from_placement_group_factory(
|
| 287 |
+
cls, pgf: "PlacementGroupFactory"
|
| 288 |
+
) -> "ScalingConfig":
|
| 289 |
+
"""Create a ScalingConfig from a Tune's PlacementGroupFactory
|
| 290 |
+
|
| 291 |
+
Note that this is only needed for ResourceChangingScheduler, which
|
| 292 |
+
modifies a trial's PlacementGroupFactory but doesn't propagate
|
| 293 |
+
the changes to ScalingConfig. TrainTrainable needs to reconstruct
|
| 294 |
+
a ScalingConfig from on the trial's PlacementGroupFactory.
|
| 295 |
+
"""
|
| 296 |
+
|
| 297 |
+
# pgf.bundles = [{trainer + worker}, {worker}, ..., {worker}]
|
| 298 |
+
num_workers = len(pgf.bundles)
|
| 299 |
+
combined_resources = pgf.bundles[0]
|
| 300 |
+
resources_per_worker = pgf.bundles[-1]
|
| 301 |
+
use_gpu = bool(resources_per_worker.get("GPU", False))
|
| 302 |
+
placement_strategy = pgf.strategy
|
| 303 |
+
|
| 304 |
+
# In `as_placement_group_factory`, we merged the trainer resource into the
|
| 305 |
+
# first worker resources bundle. We need to calculate the resources diff to
|
| 306 |
+
# get the trainer resources.
|
| 307 |
+
# Note: If there's only one worker, we won't be able to calculate the diff.
|
| 308 |
+
# We'll have empty trainer bundle and assign all resources to the worker.
|
| 309 |
+
trainer_resources = dict(
|
| 310 |
+
Counter(combined_resources) - Counter(resources_per_worker)
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
return ScalingConfig(
|
| 314 |
+
trainer_resources=trainer_resources,
|
| 315 |
+
num_workers=num_workers,
|
| 316 |
+
use_gpu=use_gpu,
|
| 317 |
+
resources_per_worker=resources_per_worker,
|
| 318 |
+
placement_strategy=placement_strategy,
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
@dataclass
|
| 323 |
+
@Deprecated(DATASET_CONFIG_DEPRECATION_MSG)
|
| 324 |
+
class DatasetConfig:
|
| 325 |
+
"""Configuration for ingest of a single Dataset.
|
| 326 |
+
|
| 327 |
+
See :ref:`the AIR Dataset configuration guide <data-ingest-torch>` for
|
| 328 |
+
usage examples.
|
| 329 |
+
|
| 330 |
+
This config defines how the Dataset should be read into the DataParallelTrainer.
|
| 331 |
+
It configures the preprocessing, splitting, and ingest strategy per-dataset.
|
| 332 |
+
|
| 333 |
+
DataParallelTrainers declare default DatasetConfigs for each dataset passed in the
|
| 334 |
+
``datasets`` argument. Users have the opportunity to selectively override these
|
| 335 |
+
configs by passing the ``dataset_config`` argument. Trainers can also define user
|
| 336 |
+
customizable values (e.g., XGBoostTrainer doesn't support streaming ingest).
|
| 337 |
+
|
| 338 |
+
Args:
|
| 339 |
+
fit: Whether to fit preprocessors on this dataset. This can be set on at most
|
| 340 |
+
one dataset at a time. True by default for the "train" dataset only.
|
| 341 |
+
split: Whether the dataset should be split across multiple workers.
|
| 342 |
+
True by default for the "train" dataset only.
|
| 343 |
+
required: Whether to raise an error if the Dataset isn't provided by the user.
|
| 344 |
+
False by default.
|
| 345 |
+
transform: Whether to transform the dataset with the fitted preprocessor.
|
| 346 |
+
This must be enabled at least for the dataset that is fit.
|
| 347 |
+
True by default.
|
| 348 |
+
max_object_store_memory_fraction [Experimental]: The maximum fraction
|
| 349 |
+
of Ray's shared-memory object store to use for the dataset. The
|
| 350 |
+
default value is -1, meaning that the preprocessed dataset should
|
| 351 |
+
be cached, which may cause spilling if its size is larger than the
|
| 352 |
+
object store's capacity. Pipelined ingest (all other values, 0 or
|
| 353 |
+
higher) is experimental. Note that the absolute memory capacity
|
| 354 |
+
used is based on the object store capacity at invocation time; this
|
| 355 |
+
does not currently cover autoscaling cases where the size of the
|
| 356 |
+
cluster may change.
|
| 357 |
+
global_shuffle: Whether to enable global shuffle (per pipeline window
|
| 358 |
+
in streaming mode). Note that this is an expensive all-to-all operation,
|
| 359 |
+
and most likely you want to use local shuffle instead.
|
| 360 |
+
See https://docs.ray.io/en/master/data/faq.html and
|
| 361 |
+
https://docs.ray.io/en/master/ray-air/check-ingest.html.
|
| 362 |
+
False by default.
|
| 363 |
+
randomize_block_order: Whether to randomize the iteration order over blocks.
|
| 364 |
+
The main purpose of this is to prevent data fetching hotspots in the
|
| 365 |
+
cluster when running many parallel workers / trials on the same data.
|
| 366 |
+
We recommend enabling it always. True by default.
|
| 367 |
+
per_epoch_preprocessor [Experimental]: A preprocessor to re-apply on
|
| 368 |
+
each pass of the dataset. The main use case for this is to apply a
|
| 369 |
+
random transform on a training dataset on each epoch. The
|
| 370 |
+
per-epoch preprocessor will be applied *after* all other
|
| 371 |
+
preprocessors and in parallel with the dataset consumer.
|
| 372 |
+
use_stream_api: Deprecated. Use max_object_store_memory_fraction instead.
|
| 373 |
+
stream_window_size: Deprecated. Use max_object_store_memory_fraction instead.
|
| 374 |
+
"""
|
| 375 |
+
|
| 376 |
+
# TODO(ekl) could we unify DataParallelTrainer and Trainer so the same data ingest
|
| 377 |
+
# strategy applies to all Trainers?
|
| 378 |
+
|
| 379 |
+
fit: Optional[bool] = None
|
| 380 |
+
split: Optional[bool] = None
|
| 381 |
+
required: Optional[bool] = None
|
| 382 |
+
transform: Optional[bool] = None
|
| 383 |
+
max_object_store_memory_fraction: Optional[float] = None
|
| 384 |
+
global_shuffle: Optional[bool] = None
|
| 385 |
+
randomize_block_order: Optional[bool] = None
|
| 386 |
+
per_epoch_preprocessor: Optional["Preprocessor"] = None
|
| 387 |
+
# Deprecated.
|
| 388 |
+
use_stream_api: Optional[int] = None
|
| 389 |
+
stream_window_size: Optional[int] = None
|
| 390 |
+
|
| 391 |
+
def __post_init__(self):
|
| 392 |
+
raise DeprecationWarning(DATASET_CONFIG_DEPRECATION_MSG)
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
@dataclass
|
| 396 |
+
@PublicAPI(stability="stable")
|
| 397 |
+
class FailureConfig:
|
| 398 |
+
"""Configuration related to failure handling of each training/tuning run.
|
| 399 |
+
|
| 400 |
+
Args:
|
| 401 |
+
max_failures: Tries to recover a run at least this many times.
|
| 402 |
+
Will recover from the latest checkpoint if present.
|
| 403 |
+
Setting to -1 will lead to infinite recovery retries.
|
| 404 |
+
Setting to 0 will disable retries. Defaults to 0.
|
| 405 |
+
fail_fast: Whether to fail upon the first error.
|
| 406 |
+
If fail_fast='raise' provided, the original error during training will be
|
| 407 |
+
immediately raised. fail_fast='raise' can easily leak resources and
|
| 408 |
+
should be used with caution.
|
| 409 |
+
"""
|
| 410 |
+
|
| 411 |
+
max_failures: int = 0
|
| 412 |
+
fail_fast: Union[bool, str] = False
|
| 413 |
+
|
| 414 |
+
def __post_init__(self):
|
| 415 |
+
# Same check as in TuneController
|
| 416 |
+
if not (isinstance(self.fail_fast, bool) or self.fail_fast.upper() == "RAISE"):
|
| 417 |
+
raise ValueError(
|
| 418 |
+
"fail_fast must be one of {bool, 'raise'}. " f"Got {self.fail_fast}."
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
# Same check as in tune.run
|
| 422 |
+
if self.fail_fast and self.max_failures != 0:
|
| 423 |
+
raise ValueError(
|
| 424 |
+
f"max_failures must be 0 if fail_fast={repr(self.fail_fast)}."
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
def __repr__(self):
|
| 428 |
+
return _repr_dataclass(self)
|
| 429 |
+
|
| 430 |
+
def _repr_html_(self):
|
| 431 |
+
return Template("scrollableTable.html.j2").render(
|
| 432 |
+
table=tabulate(
|
| 433 |
+
{
|
| 434 |
+
"Setting": ["Max failures", "Fail fast"],
|
| 435 |
+
"Value": [self.max_failures, self.fail_fast],
|
| 436 |
+
},
|
| 437 |
+
tablefmt="html",
|
| 438 |
+
showindex=False,
|
| 439 |
+
headers="keys",
|
| 440 |
+
),
|
| 441 |
+
max_height="none",
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
@dataclass
|
| 446 |
+
@PublicAPI(stability="stable")
|
| 447 |
+
class CheckpointConfig:
|
| 448 |
+
"""Configurable parameters for defining the checkpointing strategy.
|
| 449 |
+
|
| 450 |
+
Default behavior is to persist all checkpoints to disk. If
|
| 451 |
+
``num_to_keep`` is set, the default retention policy is to keep the
|
| 452 |
+
checkpoints with maximum timestamp, i.e. the most recent checkpoints.
|
| 453 |
+
|
| 454 |
+
Args:
|
| 455 |
+
num_to_keep: The number of checkpoints to keep
|
| 456 |
+
on disk for this run. If a checkpoint is persisted to disk after
|
| 457 |
+
there are already this many checkpoints, then an existing
|
| 458 |
+
checkpoint will be deleted. If this is ``None`` then checkpoints
|
| 459 |
+
will not be deleted. Must be >= 1.
|
| 460 |
+
checkpoint_score_attribute: The attribute that will be used to
|
| 461 |
+
score checkpoints to determine which checkpoints should be kept
|
| 462 |
+
on disk when there are greater than ``num_to_keep`` checkpoints.
|
| 463 |
+
This attribute must be a key from the checkpoint
|
| 464 |
+
dictionary which has a numerical value. Per default, the last
|
| 465 |
+
checkpoints will be kept.
|
| 466 |
+
checkpoint_score_order: Either "max" or "min".
|
| 467 |
+
If "max", then checkpoints with highest values of
|
| 468 |
+
``checkpoint_score_attribute`` will be kept.
|
| 469 |
+
If "min", then checkpoints with lowest values of
|
| 470 |
+
``checkpoint_score_attribute`` will be kept.
|
| 471 |
+
checkpoint_frequency: Number of iterations between checkpoints. If 0
|
| 472 |
+
this will disable checkpointing.
|
| 473 |
+
Please note that most trainers will still save one checkpoint at
|
| 474 |
+
the end of training.
|
| 475 |
+
This attribute is only supported
|
| 476 |
+
by trainers that don't take in custom training loops.
|
| 477 |
+
checkpoint_at_end: If True, will save a checkpoint at the end of training.
|
| 478 |
+
This attribute is only supported by trainers that don't take in
|
| 479 |
+
custom training loops. Defaults to True for trainers that support it
|
| 480 |
+
and False for generic function trainables.
|
| 481 |
+
_checkpoint_keep_all_ranks: This experimental config is deprecated.
|
| 482 |
+
This behavior is now controlled by reporting `checkpoint=None`
|
| 483 |
+
in the workers that shouldn't persist a checkpoint.
|
| 484 |
+
For example, if you only want the rank 0 worker to persist a checkpoint
|
| 485 |
+
(e.g., in standard data parallel training), then you should save and
|
| 486 |
+
report a checkpoint if `ray.train.get_context().get_world_rank() == 0`
|
| 487 |
+
and `None` otherwise.
|
| 488 |
+
_checkpoint_upload_from_workers: This experimental config is deprecated.
|
| 489 |
+
Uploading checkpoint directly from the worker is now the default behavior.
|
| 490 |
+
"""
|
| 491 |
+
|
| 492 |
+
num_to_keep: Optional[int] = None
|
| 493 |
+
checkpoint_score_attribute: Optional[str] = None
|
| 494 |
+
checkpoint_score_order: Optional[str] = MAX
|
| 495 |
+
checkpoint_frequency: Optional[int] = 0
|
| 496 |
+
checkpoint_at_end: Optional[bool] = None
|
| 497 |
+
_checkpoint_keep_all_ranks: Optional[bool] = _DEPRECATED_VALUE
|
| 498 |
+
_checkpoint_upload_from_workers: Optional[bool] = _DEPRECATED_VALUE
|
| 499 |
+
|
| 500 |
+
def __post_init__(self):
|
| 501 |
+
if self._checkpoint_keep_all_ranks != _DEPRECATED_VALUE:
|
| 502 |
+
raise DeprecationWarning(
|
| 503 |
+
"The experimental `_checkpoint_keep_all_ranks` config is deprecated. "
|
| 504 |
+
"This behavior is now controlled by reporting `checkpoint=None` "
|
| 505 |
+
"in the workers that shouldn't persist a checkpoint. "
|
| 506 |
+
"For example, if you only want the rank 0 worker to persist a "
|
| 507 |
+
"checkpoint (e.g., in standard data parallel training), "
|
| 508 |
+
"then you should save and report a checkpoint if "
|
| 509 |
+
"`ray.train.get_context().get_world_rank() == 0` "
|
| 510 |
+
"and `None` otherwise."
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
if self._checkpoint_upload_from_workers != _DEPRECATED_VALUE:
|
| 514 |
+
raise DeprecationWarning(
|
| 515 |
+
"The experimental `_checkpoint_upload_from_workers` config is "
|
| 516 |
+
"deprecated. Uploading checkpoint directly from the worker is "
|
| 517 |
+
"now the default behavior."
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
if self.num_to_keep is not None and self.num_to_keep <= 0:
|
| 521 |
+
raise ValueError(
|
| 522 |
+
f"Received invalid num_to_keep: "
|
| 523 |
+
f"{self.num_to_keep}. "
|
| 524 |
+
f"Must be None or an integer >= 1."
|
| 525 |
+
)
|
| 526 |
+
if self.checkpoint_score_order not in (MAX, MIN):
|
| 527 |
+
raise ValueError(
|
| 528 |
+
f"checkpoint_score_order must be either " f'"{MAX}" or "{MIN}".'
|
| 529 |
+
)
|
| 530 |
+
|
| 531 |
+
if self.checkpoint_frequency < 0:
|
| 532 |
+
raise ValueError(
|
| 533 |
+
f"checkpoint_frequency must be >=0, got {self.checkpoint_frequency}"
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
+
def __repr__(self):
|
| 537 |
+
return _repr_dataclass(self)
|
| 538 |
+
|
| 539 |
+
def _repr_html_(self) -> str:
|
| 540 |
+
if self.num_to_keep is None:
|
| 541 |
+
num_to_keep_repr = "All"
|
| 542 |
+
else:
|
| 543 |
+
num_to_keep_repr = self.num_to_keep
|
| 544 |
+
|
| 545 |
+
if self.checkpoint_score_attribute is None:
|
| 546 |
+
checkpoint_score_attribute_repr = "Most recent"
|
| 547 |
+
else:
|
| 548 |
+
checkpoint_score_attribute_repr = self.checkpoint_score_attribute
|
| 549 |
+
|
| 550 |
+
if self.checkpoint_at_end is None:
|
| 551 |
+
checkpoint_at_end_repr = ""
|
| 552 |
+
else:
|
| 553 |
+
checkpoint_at_end_repr = self.checkpoint_at_end
|
| 554 |
+
|
| 555 |
+
return Template("scrollableTable.html.j2").render(
|
| 556 |
+
table=tabulate(
|
| 557 |
+
{
|
| 558 |
+
"Setting": [
|
| 559 |
+
"Number of checkpoints to keep",
|
| 560 |
+
"Checkpoint score attribute",
|
| 561 |
+
"Checkpoint score order",
|
| 562 |
+
"Checkpoint frequency",
|
| 563 |
+
"Checkpoint at end",
|
| 564 |
+
],
|
| 565 |
+
"Value": [
|
| 566 |
+
num_to_keep_repr,
|
| 567 |
+
checkpoint_score_attribute_repr,
|
| 568 |
+
self.checkpoint_score_order,
|
| 569 |
+
self.checkpoint_frequency,
|
| 570 |
+
checkpoint_at_end_repr,
|
| 571 |
+
],
|
| 572 |
+
},
|
| 573 |
+
tablefmt="html",
|
| 574 |
+
showindex=False,
|
| 575 |
+
headers="keys",
|
| 576 |
+
),
|
| 577 |
+
max_height="none",
|
| 578 |
+
)
|
| 579 |
+
|
| 580 |
+
@property
|
| 581 |
+
def _tune_legacy_checkpoint_score_attr(self) -> Optional[str]:
|
| 582 |
+
"""Same as ``checkpoint_score_attr`` in ``tune.run``.
|
| 583 |
+
|
| 584 |
+
Only used for Legacy API compatibility.
|
| 585 |
+
"""
|
| 586 |
+
if self.checkpoint_score_attribute is None:
|
| 587 |
+
return self.checkpoint_score_attribute
|
| 588 |
+
prefix = ""
|
| 589 |
+
if self.checkpoint_score_order == MIN:
|
| 590 |
+
prefix = "min-"
|
| 591 |
+
return f"{prefix}{self.checkpoint_score_attribute}"
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
@dataclass
|
| 595 |
+
@PublicAPI(stability="stable")
|
| 596 |
+
class RunConfig:
|
| 597 |
+
"""Runtime configuration for training and tuning runs.
|
| 598 |
+
|
| 599 |
+
Upon resuming from a training or tuning run checkpoint,
|
| 600 |
+
Ray Train/Tune will automatically apply the RunConfig from
|
| 601 |
+
the previously checkpointed run.
|
| 602 |
+
|
| 603 |
+
Args:
|
| 604 |
+
name: Name of the trial or experiment. If not provided, will be deduced
|
| 605 |
+
from the Trainable.
|
| 606 |
+
storage_path: [Beta] Path where all results and checkpoints are persisted.
|
| 607 |
+
Can be a local directory or a destination on cloud storage.
|
| 608 |
+
For multi-node training/tuning runs, this must be set to a
|
| 609 |
+
shared storage location (e.g., S3, NFS).
|
| 610 |
+
This defaults to the local ``~/ray_results`` directory.
|
| 611 |
+
storage_filesystem: [Beta] A custom filesystem to use for storage.
|
| 612 |
+
If this is provided, `storage_path` should be a path with its
|
| 613 |
+
prefix stripped (e.g., `s3://bucket/path` -> `bucket/path`).
|
| 614 |
+
failure_config: Failure mode configuration.
|
| 615 |
+
checkpoint_config: Checkpointing configuration.
|
| 616 |
+
sync_config: Configuration object for syncing. See train.SyncConfig.
|
| 617 |
+
verbose: 0, 1, or 2. Verbosity mode.
|
| 618 |
+
0 = silent, 1 = default, 2 = verbose. Defaults to 1.
|
| 619 |
+
If the ``RAY_AIR_NEW_OUTPUT=1`` environment variable is set,
|
| 620 |
+
uses the old verbosity settings:
|
| 621 |
+
0 = silent, 1 = only status updates, 2 = status and brief
|
| 622 |
+
results, 3 = status and detailed results.
|
| 623 |
+
stop: Stop conditions to consider. Refer to ray.tune.stopper.Stopper
|
| 624 |
+
for more info. Stoppers should be serializable.
|
| 625 |
+
callbacks: [DeveloperAPI] Callbacks to invoke.
|
| 626 |
+
Refer to ray.tune.callback.Callback for more info.
|
| 627 |
+
Callbacks should be serializable.
|
| 628 |
+
Currently only stateless callbacks are supported for resumed runs.
|
| 629 |
+
(any state of the callback will not be checkpointed by Tune
|
| 630 |
+
and thus will not take effect in resumed runs).
|
| 631 |
+
progress_reporter: [DeveloperAPI] Progress reporter for reporting
|
| 632 |
+
intermediate experiment progress. Defaults to CLIReporter if
|
| 633 |
+
running in command-line, or JupyterNotebookReporter if running in
|
| 634 |
+
a Jupyter notebook.
|
| 635 |
+
log_to_file: [DeveloperAPI] Log stdout and stderr to files in
|
| 636 |
+
trial directories. If this is `False` (default), no files
|
| 637 |
+
are written. If `true`, outputs are written to `trialdir/stdout`
|
| 638 |
+
and `trialdir/stderr`, respectively. If this is a single string,
|
| 639 |
+
this is interpreted as a file relative to the trialdir, to which
|
| 640 |
+
both streams are written. If this is a Sequence (e.g. a Tuple),
|
| 641 |
+
it has to have length 2 and the elements indicate the files to
|
| 642 |
+
which stdout and stderr are written, respectively.
|
| 643 |
+
|
| 644 |
+
"""
|
| 645 |
+
|
| 646 |
+
name: Optional[str] = None
|
| 647 |
+
storage_path: Optional[str] = None
|
| 648 |
+
storage_filesystem: Optional[pyarrow.fs.FileSystem] = None
|
| 649 |
+
failure_config: Optional[FailureConfig] = None
|
| 650 |
+
checkpoint_config: Optional[CheckpointConfig] = None
|
| 651 |
+
sync_config: Optional["SyncConfig"] = None
|
| 652 |
+
verbose: Optional[Union[int, "AirVerbosity", "Verbosity"]] = None
|
| 653 |
+
stop: Optional[Union[Mapping, "Stopper", Callable[[str, Mapping], bool]]] = None
|
| 654 |
+
callbacks: Optional[List["Callback"]] = None
|
| 655 |
+
progress_reporter: Optional[
|
| 656 |
+
"ray.tune.progress_reporter.ProgressReporter" # noqa: F821
|
| 657 |
+
] = None
|
| 658 |
+
log_to_file: Union[bool, str, Tuple[str, str]] = False
|
| 659 |
+
|
| 660 |
+
# Deprecated
|
| 661 |
+
local_dir: Optional[str] = None
|
| 662 |
+
|
| 663 |
+
def __post_init__(self):
|
| 664 |
+
from ray.train import SyncConfig
|
| 665 |
+
from ray.train.constants import DEFAULT_STORAGE_PATH
|
| 666 |
+
from ray.tune.experimental.output import AirVerbosity, get_air_verbosity
|
| 667 |
+
|
| 668 |
+
if self.local_dir is not None:
|
| 669 |
+
raise DeprecationWarning(
|
| 670 |
+
"The `RunConfig(local_dir)` argument is deprecated. "
|
| 671 |
+
"You should set the `RunConfig(storage_path)` instead."
|
| 672 |
+
"See the docs: https://docs.ray.io/en/latest/train/user-guides/"
|
| 673 |
+
"persistent-storage.html#setting-the-local-staging-directory"
|
| 674 |
+
)
|
| 675 |
+
|
| 676 |
+
if self.storage_path is None:
|
| 677 |
+
# TODO(justinvyu): [Deprecated] Remove in 2.30
|
| 678 |
+
self.storage_path = DEFAULT_STORAGE_PATH
|
| 679 |
+
|
| 680 |
+
# If no remote path is set, try to get Ray Storage URI
|
| 681 |
+
ray_storage_uri: Optional[str] = _get_storage_uri()
|
| 682 |
+
if ray_storage_uri is not None:
|
| 683 |
+
logger.info(
|
| 684 |
+
"Using configured Ray Storage URI as the `storage_path`: "
|
| 685 |
+
f"{ray_storage_uri}"
|
| 686 |
+
)
|
| 687 |
+
self.storage_path = ray_storage_uri
|
| 688 |
+
|
| 689 |
+
if not self.failure_config:
|
| 690 |
+
self.failure_config = FailureConfig()
|
| 691 |
+
|
| 692 |
+
if not self.sync_config:
|
| 693 |
+
self.sync_config = SyncConfig()
|
| 694 |
+
|
| 695 |
+
if not self.checkpoint_config:
|
| 696 |
+
self.checkpoint_config = CheckpointConfig()
|
| 697 |
+
|
| 698 |
+
if self.verbose is None:
|
| 699 |
+
# Default `verbose` value. For new output engine,
|
| 700 |
+
# this is AirVerbosity.DEFAULT.
|
| 701 |
+
# For old output engine, this is Verbosity.V3_TRIAL_DETAILS
|
| 702 |
+
# Todo (krfricke): Currently uses number to pass test_configs::test_repr
|
| 703 |
+
self.verbose = get_air_verbosity(AirVerbosity.DEFAULT) or 3
|
| 704 |
+
|
| 705 |
+
if isinstance(self.storage_path, Path):
|
| 706 |
+
self.storage_path = self.storage_path.as_posix()
|
| 707 |
+
|
| 708 |
+
def __repr__(self):
|
| 709 |
+
from ray.train import SyncConfig
|
| 710 |
+
|
| 711 |
+
return _repr_dataclass(
|
| 712 |
+
self,
|
| 713 |
+
default_values={
|
| 714 |
+
"failure_config": FailureConfig(),
|
| 715 |
+
"sync_config": SyncConfig(),
|
| 716 |
+
"checkpoint_config": CheckpointConfig(),
|
| 717 |
+
},
|
| 718 |
+
)
|
| 719 |
+
|
| 720 |
+
def _repr_html_(self) -> str:
|
| 721 |
+
reprs = []
|
| 722 |
+
if self.failure_config is not None:
|
| 723 |
+
reprs.append(
|
| 724 |
+
Template("title_data_mini.html.j2").render(
|
| 725 |
+
title="Failure Config", data=self.failure_config._repr_html_()
|
| 726 |
+
)
|
| 727 |
+
)
|
| 728 |
+
if self.sync_config is not None:
|
| 729 |
+
reprs.append(
|
| 730 |
+
Template("title_data_mini.html.j2").render(
|
| 731 |
+
title="Sync Config", data=self.sync_config._repr_html_()
|
| 732 |
+
)
|
| 733 |
+
)
|
| 734 |
+
if self.checkpoint_config is not None:
|
| 735 |
+
reprs.append(
|
| 736 |
+
Template("title_data_mini.html.j2").render(
|
| 737 |
+
title="Checkpoint Config", data=self.checkpoint_config._repr_html_()
|
| 738 |
+
)
|
| 739 |
+
)
|
| 740 |
+
|
| 741 |
+
# Create a divider between each displayed repr
|
| 742 |
+
subconfigs = [Template("divider.html.j2").render()] * (2 * len(reprs) - 1)
|
| 743 |
+
subconfigs[::2] = reprs
|
| 744 |
+
|
| 745 |
+
settings = Template("scrollableTable.html.j2").render(
|
| 746 |
+
table=tabulate(
|
| 747 |
+
{
|
| 748 |
+
"Name": self.name,
|
| 749 |
+
"Local results directory": self.local_dir,
|
| 750 |
+
"Verbosity": self.verbose,
|
| 751 |
+
"Log to file": self.log_to_file,
|
| 752 |
+
}.items(),
|
| 753 |
+
tablefmt="html",
|
| 754 |
+
headers=["Setting", "Value"],
|
| 755 |
+
showindex=False,
|
| 756 |
+
),
|
| 757 |
+
max_height="300px",
|
| 758 |
+
)
|
| 759 |
+
|
| 760 |
+
return Template("title_data.html.j2").render(
|
| 761 |
+
title="RunConfig",
|
| 762 |
+
data=Template("run_config.html.j2").render(
|
| 763 |
+
subconfigs=subconfigs,
|
| 764 |
+
settings=settings,
|
| 765 |
+
),
|
| 766 |
+
)
|
.venv/lib/python3.11/site-packages/ray/air/constants.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Key to denote the preprocessor in the checkpoint dict.
|
| 2 |
+
PREPROCESSOR_KEY = "_preprocessor"
|
| 3 |
+
|
| 4 |
+
# Key to denote the model in the checkpoint dict.
|
| 5 |
+
MODEL_KEY = "model"
|
| 6 |
+
|
| 7 |
+
# Key to denote which dataset is the evaluation dataset.
|
| 8 |
+
# Only used in trainers which do not support multiple
|
| 9 |
+
# evaluation datasets.
|
| 10 |
+
EVALUATION_DATASET_KEY = "evaluation"
|
| 11 |
+
|
| 12 |
+
# Key to denote which dataset is the training dataset.
|
| 13 |
+
# This is the dataset that the preprocessor is fit on.
|
| 14 |
+
TRAIN_DATASET_KEY = "train"
|
| 15 |
+
|
| 16 |
+
# Name to use for the column when representing tensors in table format.
|
| 17 |
+
TENSOR_COLUMN_NAME = "__value__"
|
| 18 |
+
|
| 19 |
+
# The maximum length of strings returned by `__repr__` for AIR objects constructed with
|
| 20 |
+
# default values.
|
| 21 |
+
MAX_REPR_LENGTH = int(80 * 1.5)
|
| 22 |
+
|
| 23 |
+
# Timeout used when putting exceptions raised by runner thread into the queue.
|
| 24 |
+
_ERROR_REPORT_TIMEOUT = 10
|
| 25 |
+
|
| 26 |
+
# Timeout when fetching new results after signaling the training function to continue.
|
| 27 |
+
_RESULT_FETCH_TIMEOUT = 0.2
|
| 28 |
+
|
| 29 |
+
# Timeout for fetching exceptions raised by the training function.
|
| 30 |
+
_ERROR_FETCH_TIMEOUT = 1
|
| 31 |
+
|
| 32 |
+
# The key used to identify whether we have already warned about ray.air.session
|
| 33 |
+
# functions being used outside of the session
|
| 34 |
+
SESSION_MISUSE_LOG_ONCE_KEY = "air_warn_session_misuse"
|
| 35 |
+
|
| 36 |
+
# Name of attribute in Checkpoint storing current Tune ID for restoring
|
| 37 |
+
# training with Ray Train
|
| 38 |
+
CHECKPOINT_ID_ATTR = "_current_checkpoint_id"
|
| 39 |
+
|
| 40 |
+
# Name of the marker dropped by the Trainable. If a worker detects
|
| 41 |
+
# the presence of the marker in the trial dir, it will use lazy
|
| 42 |
+
# checkpointing.
|
| 43 |
+
LAZY_CHECKPOINT_MARKER_FILE = ".lazy_checkpoint_marker"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# The timestamp of when the result is generated.
|
| 47 |
+
# Default to when the result is processed by tune.
|
| 48 |
+
TIMESTAMP = "timestamp"
|
| 49 |
+
|
| 50 |
+
# (Auto-filled) Time in seconds this iteration took to run.
|
| 51 |
+
# This may be overridden to override the system-computed time difference.
|
| 52 |
+
TIME_THIS_ITER_S = "time_this_iter_s"
|
| 53 |
+
|
| 54 |
+
# (Auto-filled) The index of this training iteration.
|
| 55 |
+
TRAINING_ITERATION = "training_iteration"
|
| 56 |
+
|
| 57 |
+
# File that stores parameters of the trial.
|
| 58 |
+
EXPR_PARAM_FILE = "params.json"
|
| 59 |
+
|
| 60 |
+
# Pickle File that stores parameters of the trial.
|
| 61 |
+
EXPR_PARAM_PICKLE_FILE = "params.pkl"
|
| 62 |
+
|
| 63 |
+
# File that stores the progress of the trial.
|
| 64 |
+
EXPR_PROGRESS_FILE = "progress.csv"
|
| 65 |
+
|
| 66 |
+
# File that stores results of the trial.
|
| 67 |
+
EXPR_RESULT_FILE = "result.json"
|
| 68 |
+
|
| 69 |
+
# File that stores the pickled error file
|
| 70 |
+
EXPR_ERROR_PICKLE_FILE = "error.pkl"
|
| 71 |
+
|
| 72 |
+
# File that stores the error file
|
| 73 |
+
EXPR_ERROR_FILE = "error.txt"
|
| 74 |
+
|
| 75 |
+
# File that stores the checkpoint metadata
|
| 76 |
+
CHECKPOINT_TUNE_METADATA_FILE = ".tune_metadata"
|
| 77 |
+
|
| 78 |
+
# ==================================================
|
| 79 |
+
# Environment Variables
|
| 80 |
+
# ==================================================
|
| 81 |
+
|
| 82 |
+
# Integer value which if set will copy files in reported AIR directory
|
| 83 |
+
# checkpoints instead of moving them (if worker is on the same node as Trainable)
|
| 84 |
+
COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV = (
|
| 85 |
+
"TRAIN_COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING"
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
# NOTE: When adding a new environment variable, please track it in this list.
|
| 89 |
+
# TODO(ml-team): Most env var constants should get moved here.
|
| 90 |
+
AIR_ENV_VARS = {
|
| 91 |
+
COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV,
|
| 92 |
+
"RAY_AIR_FULL_TRACEBACKS",
|
| 93 |
+
"RAY_AIR_NEW_OUTPUT",
|
| 94 |
+
}
|
.venv/lib/python3.11/site-packages/ray/air/data_batch_type.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TYPE_CHECKING, Dict, Union
|
| 2 |
+
|
| 3 |
+
if TYPE_CHECKING:
|
| 4 |
+
import numpy
|
| 5 |
+
import pandas # noqa: F401
|
| 6 |
+
import pyarrow
|
| 7 |
+
|
| 8 |
+
# TODO de-dup with ray.data.block.DataBatch
|
| 9 |
+
DataBatchType = Union[
|
| 10 |
+
"numpy.ndarray", "pyarrow.Table" "pandas.DataFrame", Dict[str, "numpy.ndarray"]
|
| 11 |
+
]
|
.venv/lib/python3.11/site-packages/ray/air/execution/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.air.execution.resources.fixed import FixedResourceManager
|
| 2 |
+
from ray.air.execution.resources.placement_group import PlacementGroupResourceManager
|
| 3 |
+
from ray.air.execution.resources.request import AcquiredResources, ResourceRequest
|
| 4 |
+
from ray.air.execution.resources.resource_manager import ResourceManager
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"ResourceRequest",
|
| 8 |
+
"AcquiredResources",
|
| 9 |
+
"ResourceManager",
|
| 10 |
+
"FixedResourceManager",
|
| 11 |
+
"PlacementGroupResourceManager",
|
| 12 |
+
]
|
.venv/lib/python3.11/site-packages/ray/air/execution/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (679 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/air/execution/resources/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.air.execution.resources.fixed import FixedResourceManager
|
| 2 |
+
from ray.air.execution.resources.placement_group import PlacementGroupResourceManager
|
| 3 |
+
from ray.air.execution.resources.request import AcquiredResources, ResourceRequest
|
| 4 |
+
from ray.air.execution.resources.resource_manager import ResourceManager
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"ResourceRequest",
|
| 8 |
+
"AcquiredResources",
|
| 9 |
+
"ResourceManager",
|
| 10 |
+
"FixedResourceManager",
|
| 11 |
+
"PlacementGroupResourceManager",
|
| 12 |
+
]
|
.venv/lib/python3.11/site-packages/ray/air/execution/resources/__pycache__/fixed.cpython-311.pyc
ADDED
|
Binary file (7.71 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/air/execution/resources/__pycache__/placement_group.cpython-311.pyc
ADDED
|
Binary file (10.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/air/execution/resources/__pycache__/request.cpython-311.pyc
ADDED
|
Binary file (12.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/air/execution/resources/__pycache__/resource_manager.cpython-311.pyc
ADDED
|
Binary file (7.74 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/air/execution/resources/fixed.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Dict, List, Optional
|
| 3 |
+
|
| 4 |
+
import ray
|
| 5 |
+
from ray import LOCAL_MODE, SCRIPT_MODE
|
| 6 |
+
from ray.air.execution.resources.request import (
|
| 7 |
+
AcquiredResources,
|
| 8 |
+
RemoteRayEntity,
|
| 9 |
+
ResourceRequest,
|
| 10 |
+
)
|
| 11 |
+
from ray.air.execution.resources.resource_manager import ResourceManager
|
| 12 |
+
from ray.util.annotations import DeveloperAPI
|
| 13 |
+
|
| 14 |
+
# Avoid numerical errors by multiplying and subtracting with this number.
|
| 15 |
+
# Compare: 0.99 - 0.33 = 0.65999... vs (0.99 * 1000 - 0.33 * 1000) / 1000 = 0.66
|
| 16 |
+
_DIGITS = 100000
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@DeveloperAPI
|
| 20 |
+
@dataclass
|
| 21 |
+
class FixedAcquiredResources(AcquiredResources):
|
| 22 |
+
bundles: List[Dict[str, float]]
|
| 23 |
+
|
| 24 |
+
def _annotate_remote_entity(
|
| 25 |
+
self, entity: RemoteRayEntity, bundle: Dict[str, float], bundle_index: int
|
| 26 |
+
) -> RemoteRayEntity:
|
| 27 |
+
bundle = bundle.copy()
|
| 28 |
+
num_cpus = bundle.pop("CPU", 0)
|
| 29 |
+
num_gpus = bundle.pop("GPU", 0)
|
| 30 |
+
memory = bundle.pop("memory", 0.0)
|
| 31 |
+
|
| 32 |
+
return entity.options(
|
| 33 |
+
num_cpus=num_cpus,
|
| 34 |
+
num_gpus=num_gpus,
|
| 35 |
+
memory=memory,
|
| 36 |
+
resources=bundle,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@DeveloperAPI
|
| 41 |
+
class FixedResourceManager(ResourceManager):
|
| 42 |
+
"""Fixed budget based resource manager.
|
| 43 |
+
|
| 44 |
+
This resource manager keeps track of a fixed set of resources. When resources
|
| 45 |
+
are acquired, they are subtracted from the budget. When resources are freed,
|
| 46 |
+
they are added back to the budget.
|
| 47 |
+
|
| 48 |
+
The resource manager still requires resources to be requested before they become
|
| 49 |
+
available. However, because the resource requests are virtual, this will not
|
| 50 |
+
trigger autoscaling.
|
| 51 |
+
|
| 52 |
+
Additionally, resources are not reserved on request, only on acquisition. Thus,
|
| 53 |
+
acquiring a resource can change the availability of other requests. Note that
|
| 54 |
+
this behavior may be changed in future implementations.
|
| 55 |
+
|
| 56 |
+
The fixed resource manager does not support placement strategies. Using
|
| 57 |
+
``STRICT_SPREAD`` will result in an error. ``STRICT_PACK`` will succeed only
|
| 58 |
+
within a placement group bundle. All other placement group arguments will be
|
| 59 |
+
ignored.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
total_resources: Budget of resources to manage. Defaults to all available
|
| 63 |
+
resources in the current task or all cluster resources (if outside a task).
|
| 64 |
+
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
_resource_cls: AcquiredResources = FixedAcquiredResources
|
| 68 |
+
|
| 69 |
+
def __init__(self, total_resources: Optional[Dict[str, float]] = None):
|
| 70 |
+
rtc = ray.get_runtime_context()
|
| 71 |
+
|
| 72 |
+
if not total_resources:
|
| 73 |
+
if rtc.worker.mode in {None, SCRIPT_MODE, LOCAL_MODE}:
|
| 74 |
+
total_resources = ray.cluster_resources()
|
| 75 |
+
else:
|
| 76 |
+
total_resources = rtc.get_assigned_resources()
|
| 77 |
+
|
| 78 |
+
# If we are in a placement group, all of our resources will be in a bundle
|
| 79 |
+
# and thus fulfill requirements of STRICT_PACK - but only if child tasks
|
| 80 |
+
# are captured by the pg.
|
| 81 |
+
self._allow_strict_pack = (
|
| 82 |
+
ray.util.get_current_placement_group() is not None
|
| 83 |
+
and rtc.should_capture_child_tasks_in_placement_group
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
self._total_resources = total_resources
|
| 87 |
+
self._requested_resources = []
|
| 88 |
+
self._used_resources = []
|
| 89 |
+
|
| 90 |
+
@property
|
| 91 |
+
def _available_resources(self) -> Dict[str, float]:
|
| 92 |
+
available_resources = self._total_resources.copy()
|
| 93 |
+
|
| 94 |
+
for used_resources in self._used_resources:
|
| 95 |
+
all_resources = used_resources.required_resources
|
| 96 |
+
for k, v in all_resources.items():
|
| 97 |
+
available_resources[k] = (
|
| 98 |
+
available_resources[k] * _DIGITS - v * _DIGITS
|
| 99 |
+
) / _DIGITS
|
| 100 |
+
return available_resources
|
| 101 |
+
|
| 102 |
+
def request_resources(self, resource_request: ResourceRequest):
|
| 103 |
+
if resource_request.strategy == "STRICT_SPREAD" or (
|
| 104 |
+
not self._allow_strict_pack and resource_request.strategy == "STRICT_PACK"
|
| 105 |
+
):
|
| 106 |
+
raise RuntimeError(
|
| 107 |
+
f"Requested a resource with placement strategy "
|
| 108 |
+
f"{resource_request.strategy}, but this cannot be fulfilled by a "
|
| 109 |
+
f"FixedResourceManager. In a nested setting, please set the inner "
|
| 110 |
+
f"placement strategy to be less restrictive (i.e. no STRICT_ strategy)."
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
self._requested_resources.append(resource_request)
|
| 114 |
+
|
| 115 |
+
def cancel_resource_request(self, resource_request: ResourceRequest):
|
| 116 |
+
self._requested_resources.remove(resource_request)
|
| 117 |
+
|
| 118 |
+
def has_resources_ready(self, resource_request: ResourceRequest) -> bool:
|
| 119 |
+
if resource_request not in self._requested_resources:
|
| 120 |
+
return False
|
| 121 |
+
|
| 122 |
+
available_resources = self._available_resources
|
| 123 |
+
all_resources = resource_request.required_resources
|
| 124 |
+
for k, v in all_resources.items():
|
| 125 |
+
if available_resources.get(k, 0.0) < v:
|
| 126 |
+
return False
|
| 127 |
+
return True
|
| 128 |
+
|
| 129 |
+
def acquire_resources(
|
| 130 |
+
self, resource_request: ResourceRequest
|
| 131 |
+
) -> Optional[AcquiredResources]:
|
| 132 |
+
if not self.has_resources_ready(resource_request):
|
| 133 |
+
return None
|
| 134 |
+
|
| 135 |
+
self._used_resources.append(resource_request)
|
| 136 |
+
return self._resource_cls(
|
| 137 |
+
bundles=resource_request.bundles, resource_request=resource_request
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
def free_resources(self, acquired_resource: AcquiredResources):
|
| 141 |
+
resources = acquired_resource.resource_request
|
| 142 |
+
self._used_resources.remove(resources)
|
| 143 |
+
|
| 144 |
+
def clear(self):
|
| 145 |
+
# Reset internal state
|
| 146 |
+
self._requested_resources = []
|
| 147 |
+
self._used_resources = []
|
.venv/lib/python3.11/site-packages/ray/air/execution/resources/placement_group.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Dict, List, Optional, Set
|
| 5 |
+
|
| 6 |
+
import ray
|
| 7 |
+
from ray.air.execution.resources.request import (
|
| 8 |
+
AcquiredResources,
|
| 9 |
+
RemoteRayEntity,
|
| 10 |
+
ResourceRequest,
|
| 11 |
+
)
|
| 12 |
+
from ray.air.execution.resources.resource_manager import ResourceManager
|
| 13 |
+
from ray.util.annotations import DeveloperAPI
|
| 14 |
+
from ray.util.placement_group import PlacementGroup, remove_placement_group
|
| 15 |
+
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@DeveloperAPI
@dataclass
class PlacementGroupAcquiredResources(AcquiredResources):
    """Acquired resources that are backed by a Ray placement group.

    Remote Ray entities annotated through this class are scheduled on the
    bundles of the associated placement group.
    """

    # The placement group backing these acquired resources.
    placement_group: PlacementGroup

    def _annotate_remote_entity(
        self, entity: RemoteRayEntity, bundle: Dict[str, float], bundle_index: int
    ) -> RemoteRayEntity:
        # Work on a copy: the well-known resource keys get their dedicated
        # ``options()`` arguments, everything else stays in ``resources``.
        remaining = dict(bundle)
        cpus = remaining.pop("CPU", 0)
        gpus = remaining.pop("GPU", 0)
        mem = remaining.pop("memory", 0.0)

        strategy = PlacementGroupSchedulingStrategy(
            placement_group=self.placement_group,
            placement_group_bundle_index=bundle_index,
            placement_group_capture_child_tasks=True,
        )
        return entity.options(
            scheduling_strategy=strategy,
            num_cpus=cpus,
            num_gpus=gpus,
            memory=mem,
            resources=remaining,
        )
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@DeveloperAPI
class PlacementGroupResourceManager(ResourceManager):
    """Resource manager using placement groups as the resource backend.

    This manager will use placement groups to fulfill resource requests. Requesting
    a resource will schedule the placement group. Acquiring a resource will
    return a ``PlacementGroupAcquiredResources`` that can be used to schedule
    Ray tasks and actors on the placement group. Freeing an acquired resource
    will destroy the associated placement group.

    Ray core does not emit events when resources are available. Instead, the
    scheduling state has to be periodically updated.

    By default, placement group scheduling state is refreshed every time
    resource state is inquired, but not more often than once every
    ``update_interval_s`` seconds. Alternatively, staging futures can be
    retrieved (and awaited) with ``get_resource_futures()`` and a state update
    can be forced with ``update_state()``.

    Args:
        update_interval_s: Minimum interval in seconds between updating scheduling
            state of placement groups.

    """

    # Class used to wrap acquired placement groups.
    # NOTE(review): annotation says ``AcquiredResources`` but this holds the
    # class itself, not an instance (``Type`` is not imported here).
    _resource_cls: AcquiredResources = PlacementGroupAcquiredResources

    def __init__(self, update_interval_s: float = 0.1):
        # Internally, the placement group lifecycle is like this:
        # - Resources are requested with ``request_resources()``
        # - A placement group is scheduled ("staged")
        # - A ``PlacementGroup.ready()`` future is scheduled ("staging future")
        # - We update the scheduling state when we need to
        #   (e.g. when ``has_resources_ready()`` is called)
        # - When staging futures resolve, a placement group is moved from
        #   "staging" to "ready"
        # - When a resource request is canceled, we remove a placement group
        #   from "staging". If there are no staged placement groups
        #   (because they are already "ready"), we remove one from "ready" instead.
        # - When a resource is acquired, the pg is removed from "ready" and moved
        #   to "acquired"
        # - When a resource is freed, the pg is removed from "acquired" and destroyed

        # Mapping of placement group to request
        self._pg_to_request: Dict[PlacementGroup, ResourceRequest] = {}

        # PGs that are staged but not "ready", yet (i.e. not CREATED)
        self._request_to_staged_pgs: Dict[
            ResourceRequest, Set[PlacementGroup]
        ] = defaultdict(set)

        # PGs that are CREATED and can be used by tasks and actors
        self._request_to_ready_pgs: Dict[
            ResourceRequest, Set[PlacementGroup]
        ] = defaultdict(set)

        # Staging futures used to update internal state.
        # We keep a double mapping here for better lookup efficiency.
        self._staging_future_to_pg: Dict[ray.ObjectRef, PlacementGroup] = dict()
        self._pg_to_staging_future: Dict[PlacementGroup, ray.ObjectRef] = dict()

        # Set of acquired PGs. We keep track of these here to make sure we
        # only free PGs that this manager managed.
        self._acquired_pgs: Set[PlacementGroup] = set()

        # Minimum time between updates of the internal state
        self.update_interval_s = update_interval_s
        # Initialize in the past so the first inquiry triggers an update.
        self._last_update = time.monotonic() - self.update_interval_s - 1

    def get_resource_futures(self) -> List[ray.ObjectRef]:
        """Return pending ``PlacementGroup.ready()`` futures to await."""
        return list(self._staging_future_to_pg.keys())

    def _maybe_update_state(self):
        """Refresh scheduling state, at most once per ``update_interval_s``."""
        now = time.monotonic()
        if now > self._last_update + self.update_interval_s:
            self.update_state()

    def update_state(self):
        """Move placement groups whose staging future resolved to "ready"."""
        # timeout=0: non-blocking poll of all staging futures.
        ready, _ = ray.wait(
            list(self._staging_future_to_pg.keys()),
            num_returns=len(self._staging_future_to_pg),
            timeout=0,
        )
        for future in ready:
            # Remove staging future
            pg = self._staging_future_to_pg.pop(future)
            self._pg_to_staging_future.pop(pg)
            # Fetch resource request
            request = self._pg_to_request[pg]
            # Remove from staging, add to ready
            self._request_to_staged_pgs[request].remove(pg)
            self._request_to_ready_pgs[request].add(pg)
        self._last_update = time.monotonic()

    def request_resources(self, resource_request: ResourceRequest):
        """Schedule a placement group for the request and stage it."""
        pg = resource_request.to_placement_group()
        self._pg_to_request[pg] = resource_request
        self._request_to_staged_pgs[resource_request].add(pg)

        future = pg.ready()
        self._staging_future_to_pg[future] = pg
        self._pg_to_staging_future[pg] = future

    def cancel_resource_request(self, resource_request: ResourceRequest):
        """Cancel one staged (preferred) or ready placement group.

        Raises:
            RuntimeError: If no placement group is staged or ready for
                this request.
        """
        if self._request_to_staged_pgs[resource_request]:
            # PG was staging
            pg = self._request_to_staged_pgs[resource_request].pop()

            future = self._pg_to_staging_future.pop(pg)
            self._staging_future_to_pg.pop(future)

            # Cancel the pg.ready task.
            # Otherwise, it will be pending node assignment forever.
            ray.cancel(future)
        elif self._request_to_ready_pgs[resource_request]:
            # PG is ready
            pg = self._request_to_ready_pgs[resource_request].pop()
        else:
            # Bug fix: the original code called ``set.pop()`` unconditionally,
            # which raises an opaque ``KeyError`` on an empty set and made this
            # error message unreachable.
            raise RuntimeError(
                "Cannot cancel resource request: No placement group was "
                f"staged or is ready. Make sure to not cancel more resource "
                f"requests than you've created. Request: {resource_request}"
            )

        self._pg_to_request.pop(pg)
        remove_placement_group(pg)

    def has_resources_ready(self, resource_request: ResourceRequest) -> bool:
        """Returns True if a placement group for this request is ready."""
        if not self._request_to_ready_pgs[resource_request]:
            # Only update state if needed
            self._maybe_update_state()

        return bool(self._request_to_ready_pgs[resource_request])

    def acquire_resources(
        self, resource_request: ResourceRequest
    ) -> Optional[PlacementGroupAcquiredResources]:
        """Acquire a ready placement group, or return None."""
        if not self.has_resources_ready(resource_request):
            return None

        pg = self._request_to_ready_pgs[resource_request].pop()
        self._acquired_pgs.add(pg)

        return self._resource_cls(placement_group=pg, resource_request=resource_request)

    def free_resources(self, acquired_resource: PlacementGroupAcquiredResources):
        """Destroy the placement group backing the acquired resources."""
        pg = acquired_resource.placement_group

        # Raises KeyError if the PG was not acquired through this manager.
        self._acquired_pgs.remove(pg)
        remove_placement_group(pg)
        self._pg_to_request.pop(pg)

    def clear(self):
        """Remove all managed placement groups and reset internal state."""
        if not ray.is_initialized():
            # Nothing to remove; Ray already shut down.
            return

        for staged_pgs in self._request_to_staged_pgs.values():
            for staged_pg in staged_pgs:
                remove_placement_group(staged_pg)

        for ready_pgs in self._request_to_ready_pgs.values():
            for ready_pg in ready_pgs:
                remove_placement_group(ready_pg)

        for acquired_pg in self._acquired_pgs:
            remove_placement_group(acquired_pg)

        # Reset internal state by re-running the constructor.
        self.__init__(update_interval_s=self.update_interval_s)

    def __del__(self):
        # Best-effort cleanup of all placement groups on garbage collection.
        self.clear()
|
.venv/lib/python3.11/site-packages/ray/air/execution/resources/request.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
import json
|
| 3 |
+
from copy import deepcopy
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from inspect import signature
|
| 6 |
+
from typing import Dict, List, Union
|
| 7 |
+
|
| 8 |
+
import ray
|
| 9 |
+
from ray.util import placement_group
|
| 10 |
+
from ray.util.annotations import DeveloperAPI
|
| 11 |
+
|
| 12 |
+
RemoteRayEntity = Union[ray.remote_function.RemoteFunction, ray.actor.ActorClass]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _sum_bundles(bundles: List[Dict[str, float]]) -> Dict[str, float]:
|
| 16 |
+
"""Sum all resources in a list of resource bundles.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
bundles: List of resource bundles.
|
| 20 |
+
|
| 21 |
+
Returns: Dict containing all resources summed up.
|
| 22 |
+
"""
|
| 23 |
+
resources = {}
|
| 24 |
+
for bundle in bundles:
|
| 25 |
+
for k, v in bundle.items():
|
| 26 |
+
resources[k] = resources.get(k, 0) + v
|
| 27 |
+
return resources
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@DeveloperAPI
class ResourceRequest:
    """Request for resources.

    This class is used to define a resource request. A resource request comprises one
    or more bundles of resources and instructions on the scheduling behavior.

    The resource request can be submitted to a resource manager, which will
    schedule the resources. Depending on the resource backend, this may instruct
    Ray to scale up (autoscaling).

    Resource requests are compatible with the most fine-grained low-level resource
    backend, which are Ray placement groups.

    Args:
        bundles: A list of bundles which represent the resources requirements.
            E.g. ``[{"CPU": 1, "GPU": 1}]``.
        strategy: The scheduling strategy to acquire the bundles.

            - "PACK": Packs Bundles into as few nodes as possible.
            - "SPREAD": Places Bundles across distinct nodes as even as possible.
            - "STRICT_PACK": Packs Bundles into one node. The group is
              not allowed to span multiple nodes.
            - "STRICT_SPREAD": Packs Bundles across distinct nodes.
        *args: Passed to the call of ``placement_group()``, if applicable.
        **kwargs: Passed to the call of ``placement_group()``, if applicable.

    """

    def __init__(
        self,
        bundles: List[Dict[str, Union[int, float]]],
        strategy: str = "PACK",
        *args,
        **kwargs,
    ):
        if not bundles:
            raise ValueError("Cannot initialize a ResourceRequest with zero bundles.")

        # Remove empty resource keys (and normalize all values to float)
        self._bundles = [
            {k: float(v) for k, v in bundle.items() if v != 0} for bundle in bundles
        ]

        # Check if the head bundle is empty (no resources defined or all resources
        # are 0, and thus removed in the previous step)
        if not self._bundles[0]:
            # This is when the head bundle doesn't need resources.
            self._head_bundle_is_empty = True
            self._bundles.pop(0)

            if not self._bundles:
                raise ValueError(
                    "Cannot initialize a ResourceRequest with an empty head "
                    "and zero worker bundles."
                )
        else:
            self._head_bundle_is_empty = False

        self._strategy = strategy
        self._args = args
        self._kwargs = kwargs

        # Lazily computed hash cache (see ``__hash__``).
        self._hash = None
        # Bound arguments for ``placement_group()`` (see ``_bind``).
        self._bound = None

        self._bind()

    @property
    def head_bundle_is_empty(self):
        """Returns True if head bundle is empty while child bundles
        need resources.

        This is considered an internal API within Tune.
        """
        return self._head_bundle_is_empty

    @property
    @DeveloperAPI
    def head_cpus(self) -> float:
        """Returns the number of cpus in the head bundle."""
        return 0.0 if self._head_bundle_is_empty else self._bundles[0].get("CPU", 0.0)

    @property
    @DeveloperAPI
    def bundles(self) -> List[Dict[str, float]]:
        """Returns a deep copy of resource bundles"""
        # Deep copy so callers cannot mutate our internal bundle dicts.
        return deepcopy(self._bundles)

    @property
    def required_resources(self) -> Dict[str, float]:
        """Returns a dict containing the sums of all resources"""
        return _sum_bundles(self._bundles)

    @property
    @DeveloperAPI
    def strategy(self) -> str:
        """Returns the placement strategy"""
        return self._strategy

    def _bind(self):
        """Bind the args and kwargs to the `placement_group()` signature.

        We bind the args and kwargs, so we can compare equality of two resource
        requests. The main reason for this is that the `placement_group()` API
        can evolve independently from the ResourceRequest API (e.g. adding new
        arguments). Then, `ResourceRequest(bundles, strategy, arg=arg)` should
        be the same as `ResourceRequest(bundles, strategy, arg)`.
        """
        # Bind against the *live* signature of ``placement_group`` so that
        # positional and keyword spellings of the same argument normalize to
        # the same ``BoundArguments``.
        sig = signature(placement_group)
        try:
            self._bound = sig.bind(
                self._bundles, self._strategy, *self._args, **self._kwargs
            )
        except Exception as exc:
            raise RuntimeError(
                "Invalid definition for resource request. Please check "
                "that you passed valid arguments to the ResourceRequest "
                "object."
            ) from exc

    def to_placement_group(self):
        """Create (schedule) a placement group from this request."""
        return placement_group(*self._bound.args, **self._bound.kwargs)

    def __eq__(self, other: "ResourceRequest"):
        # Equality is defined over the bound placement group arguments plus
        # the empty-head flag (which is not part of the bound arguments).
        return (
            isinstance(other, ResourceRequest)
            and self._bound == other._bound
            and self.head_bundle_is_empty == other.head_bundle_is_empty
        )

    def __hash__(self):
        # NOTE: a cached hash value of 0 is falsy and would be recomputed on
        # every call; harmless, since the computation is deterministic.
        if not self._hash:
            # Cache hash
            self._hash = hash(
                json.dumps(
                    {"args": self._bound.args, "kwargs": self._bound.kwargs},
                    sort_keys=True,
                    indent=0,
                    ensure_ascii=True,
                )
            )
        return self._hash

    def __getstate__(self):
        # ``_bound`` holds a ``BoundArguments`` tied to the live
        # ``placement_group`` signature; drop it (and the hash cache) and
        # recompute both on unpickling in ``__setstate__``.
        state = self.__dict__.copy()
        state.pop("_hash", None)
        state.pop("_bound", None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._hash = None
        self._bound = None
        # Re-bind against the signature available in this process.
        self._bind()

    def __repr__(self) -> str:
        return (
            f"<ResourceRequest (_bound={self._bound}, "
            f"head_bundle_is_empty={self.head_bundle_is_empty})>"
        )
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
@DeveloperAPI
@dataclass
class AcquiredResources(abc.ABC):
    """Base class for resources that have been acquired.

    Acquired resources can be associated to Ray objects, which can then be
    scheduled using these resources.

    Internally this can point e.g. to a placement group, a placement
    group bundle index, or just raw resources.

    The main API is the ``annotate_remote_entities`` method. It associates
    remote Ray objects (tasks and actors) with the acquired resources by
    setting their Ray remote options accordingly.
    """

    # The request these resources were acquired for.
    resource_request: ResourceRequest

    def annotate_remote_entities(
        self, entities: List[RemoteRayEntity]
    ) -> List[Union[RemoteRayEntity]]:
        """Return remote ray entities (tasks/actors) to use the acquired resources.

        The first entity will be associated with the first bundle, the second
        entity will be associated with the second bundle, etc.

        Args:
            entities: Remote Ray entities to annotate with the acquired resources.
        """
        worker_bundles = self.resource_request.bundles
        head_is_empty = self.resource_request.head_bundle_is_empty

        # An empty head bundle still counts as one usable slot.
        num_bundles = len(worker_bundles) + int(head_is_empty)

        if len(entities) > num_bundles:
            raise RuntimeError(
                f"The number of callables to annotate ({len(entities)}) cannot "
                f"exceed the number of available bundles ({num_bundles})."
            )

        annotated: List[RemoteRayEntity] = []
        remaining = list(entities)

        if head_is_empty:
            # The resource-less head entity goes on the first bundle index,
            # requesting no resources from it.
            annotated.append(
                self._annotate_remote_entity(remaining[0], {}, bundle_index=0)
            )
            # The rest of the entities map onto the worker bundles.
            remaining = remaining[1:]

        for idx, (entity, bundle) in enumerate(zip(remaining, worker_bundles)):
            annotated.append(
                self._annotate_remote_entity(entity, bundle, bundle_index=idx)
            )

        return annotated

    def _annotate_remote_entity(
        self, entity: RemoteRayEntity, bundle: Dict[str, float], bundle_index: int
    ) -> RemoteRayEntity:
        # Backend-specific annotation implemented by subclasses.
        raise NotImplementedError
|
.venv/lib/python3.11/site-packages/ray/air/execution/resources/resource_manager.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
|
| 4 |
+
import ray
|
| 5 |
+
from ray.air.execution.resources.request import AcquiredResources, ResourceRequest
|
| 6 |
+
from ray.util.annotations import DeveloperAPI
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@DeveloperAPI
class ResourceManager(abc.ABC):
    """Resource manager interface.

    A resource manager can be used to request resources from a Ray cluster and
    allocate them to remote Ray tasks or actors.

    Resources have to be requested before they can be acquired.

    Resources managed by the resource manager can be in three states:

    1. "Requested": The resources have been requested but are not yet available to
       schedule remote Ray objects. The resource request may trigger autoscaling,
       and can be cancelled if no longer needed.
    2. "Ready": The requested resources are now available to schedule remote Ray
       objects. They can be acquired and subsequently used by remote Ray objects.
       The resource request can still be cancelled if no longer needed.
    3. "Acquired": The resources have been acquired by a caller to use for scheduling
       remote Ray objects. Note that it is the responsibility of the caller to
       schedule the Ray objects with these resources.
       The associated resource request has been completed and can no longer be
       cancelled. The acquired resources can be freed by the resource manager when
       they are no longer used.

    The flow is as follows:

    .. code-block:: python

        # Create resource manager
        resource_manager = ResourceManager()

        # Create resource request
        resource_request = ResourceRequest([{"CPU": 4}])

        # Pass to resource manager
        resource_manager.request_resources(resource_request)

        # Wait until ready
        while not resource_manager.has_resources_ready(resource_request):
            time.sleep(1)

        # Once ready, acquire resources
        acquired_resource = resource_manager.acquire_resources(resource_request)

        # Bind to remote task or actor
        [annotated_remote_fn] = acquired_resource.annotate_remote_entities(
            [remote_fn])

        # Run remote function. This will use the acquired resources
        ray.get(annotated_remote_fn.remote())

        # After using the resources, free
        resource_manager.free_resources(acquired_resource)

    """

    def request_resources(self, resource_request: ResourceRequest):
        """Request resources.

        Depending on the backend, resources can trigger autoscaling. Requested
        resources can be ready or not ready. Once they are "ready", they can
        be acquired and used by remote Ray objects.

        Resource requests can be cancelled anytime using ``cancel_resource_request()``.
        Once acquired, the resource request is removed. Acquired resources can be
        freed with ``free_resources()``.
        """
        raise NotImplementedError

    def cancel_resource_request(self, resource_request: ResourceRequest):
        """Cancel resource request.

        Resource requests can be cancelled anytime before a resource is acquired.
        Acquiring a resource will remove the associated resource request.
        Acquired resources can be freed with ``free_resources()``.
        """
        raise NotImplementedError

    def has_resources_ready(self, resource_request: ResourceRequest) -> bool:
        """Returns True if resources for the given request are ready to be acquired."""
        raise NotImplementedError

    def acquire_resources(
        self, resource_request: ResourceRequest
    ) -> Optional[AcquiredResources]:
        """Acquire resources. Returns None if resources are not ready to be acquired.

        Acquiring resources will remove the associated resource request.
        Acquired resources can be returned with ``free_resources()``.
        """
        raise NotImplementedError

    def free_resources(self, acquired_resource: AcquiredResources):
        """Free acquired resources from usage and return them to the resource manager.

        Freeing resources will return the resources to the manager, but there are
        no guarantees about the tasks and actors scheduled on the resources. The caller
        should make sure that any references to tasks or actors scheduled on the
        resources have been removed before calling ``free_resources()``.
        """
        raise NotImplementedError

    def get_resource_futures(self) -> List[ray.ObjectRef]:
        """Return futures for resources to await.

        Depending on the backend, we use resource futures to determine availability
        of resources (e.g. placement groups) or resolution of requests.
        In this case, the futures can be awaited externally by the caller.

        When a resource future resolves, the caller may call ``update_state()``
        to force the resource manager to update its internal state immediately.
        """
        # Default: backends without futures return an empty list.
        return []

    def update_state(self):
        """Update internal state of the resource manager.

        The resource manager may have internal state that needs periodic updating.
        For instance, depending on the backend, resource futures can be awaited
        externally (with ``get_resource_futures()``).

        If such a future resolves, the caller can instruct the resource
        manager to update its internal state immediately.
        """
        # Default: no internal state to refresh.
        pass

    def clear(self):
        """Reset internal state and clear all resources.

        Calling this method will reset the resource manager to its initialization state.
        All resources will be removed.

        Clearing the state will remove tracked resources from the manager, but there are
        no guarantees about the tasks and actors scheduled on the resources. The caller
        should make sure that any references to tasks or actors scheduled on the
        resources have been removed before calling ``clear()``.
        """
        raise NotImplementedError

    def __reduce__(self):
        """We disallow serialization.

        Shared resource managers should live on an actor.
        """
        raise ValueError(
            f"Resource managers cannot be serialized. Resource manager: {str(self)}"
        )
|
.venv/lib/python3.11/site-packages/ray/air/result.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
| 8 |
+
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import pyarrow
|
| 11 |
+
|
| 12 |
+
import ray
|
| 13 |
+
from ray.air.constants import (
|
| 14 |
+
EXPR_ERROR_PICKLE_FILE,
|
| 15 |
+
EXPR_PROGRESS_FILE,
|
| 16 |
+
EXPR_RESULT_FILE,
|
| 17 |
+
)
|
| 18 |
+
from ray.util.annotations import PublicAPI
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from ray.train import Checkpoint
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@PublicAPI(stability="stable")
@dataclass
class Result:
    """The final result of a ML training run or a Tune trial.

    This is the output produced by ``Trainer.fit``.
    ``Tuner.fit`` outputs a :class:`~ray.tune.ResultGrid` that is a collection
    of ``Result`` objects.

    This API is the recommended way to access the outputs such as:
    - checkpoints (``Result.checkpoint``)
    - the history of reported metrics (``Result.metrics_dataframe``, ``Result.metrics``)
    - errors encountered during a training run (``Result.error``)

    The constructor is a private API -- use ``Result.from_path`` to create a result
    object from a directory.

    Attributes:
        metrics: The latest set of reported metrics.
        checkpoint: The latest checkpoint.
        error: The execution error of the Trainable run, if the trial finishes in error.
        path: Path pointing to the result directory on persistent storage. This can
            point to a remote storage location (e.g. S3) or to a local location (path
            on the head node). The path is accessible via the result's associated
            `filesystem`. For instance, for a result stored in S3 at
            ``s3://bucket/location``, ``path`` will have the value ``bucket/location``.
        metrics_dataframe: The full result dataframe of the Trainable.
            The dataframe is indexed by iterations and contains reported
            metrics. Note that the dataframe columns are indexed with the
            *flattened* keys of reported metrics, so the format of this dataframe
            may be slightly different than ``Result.metrics``, which is an unflattened
            dict of the latest set of reported metrics.
        best_checkpoints: A list of tuples of the best checkpoints and
            their associated metrics. The number of
            saved checkpoints is determined by :class:`~ray.train.CheckpointConfig`
            (by default, all checkpoints will be saved).
    """

    metrics: Optional[Dict[str, Any]]
    checkpoint: Optional["Checkpoint"]
    error: Optional[Exception]
    path: str
    metrics_dataframe: Optional["pd.DataFrame"] = None
    best_checkpoints: Optional[List[Tuple["Checkpoint", Dict[str, Any]]]] = None
    # Filesystem used to access `path`; None means the local filesystem.
    _storage_filesystem: Optional[pyarrow.fs.FileSystem] = None
    # Attributes included (in this order) in the __repr__ output.
    _items_to_repr = ["error", "metrics", "path", "filesystem", "checkpoint"]

    @property
    def config(self) -> Optional[Dict[str, Any]]:
        """The config associated with the result.

        Returns the ``"config"`` entry of the latest reported metrics, or
        None if no metrics were reported.
        """
        if not self.metrics:
            return None
        return self.metrics.get("config", None)

    @property
    def filesystem(self) -> pyarrow.fs.FileSystem:
        """Return the filesystem that can be used to access the result path.

        Returns:
            pyarrow.fs.FileSystem implementation. Defaults to the local
            filesystem when no storage filesystem was provided.
        """
        return self._storage_filesystem or pyarrow.fs.LocalFileSystem()

    def _repr(self, indent: int = 0) -> str:
        """Construct the representation with specified number of space indent."""
        # Imported lazily to avoid a circular import at module load time.
        from ray.tune.experimental.output import BLACKLISTED_KEYS
        from ray.tune.result import AUTO_RESULT_KEYS

        shown_attributes = {k: getattr(self, k) for k in self._items_to_repr}
        if self.error:
            # Only show the exception type name, not the full exception repr.
            shown_attributes["error"] = type(self.error).__name__
        else:
            shown_attributes.pop("error")

        shown_attributes["filesystem"] = shown_attributes["filesystem"].type_name

        if self.metrics:
            # Hide auto-filled and blacklisted keys to keep the repr readable.
            exclude = set(AUTO_RESULT_KEYS)
            exclude.update(BLACKLISTED_KEYS)
            shown_attributes["metrics"] = {
                k: v for k, v in self.metrics.items() if k not in exclude
            }

        cls_indent = " " * indent
        kws_indent = " " * (indent + 2)

        kws = [
            f"{kws_indent}{key}={value!r}" for key, value in shown_attributes.items()
        ]
        kws_repr = ",\n".join(kws)
        return "{0}{1}(\n{2}\n{0})".format(cls_indent, type(self).__name__, kws_repr)

    def __repr__(self) -> str:
        return self._repr(indent=0)

    @staticmethod
    def _read_file_as_str(
        storage_filesystem: pyarrow.fs.FileSystem,
        storage_path: str,
    ) -> str:
        """Opens a file as an input stream reading all byte content sequentially and
        decoding read bytes as utf-8 string.

        Args:
            storage_filesystem: The filesystem to use.
            storage_path: The source to open for reading.

        Returns:
            The full file content decoded as a UTF-8 string.
        """

        with storage_filesystem.open_input_stream(storage_path) as f:
            return f.readall().decode()

    @classmethod
    def from_path(
        cls,
        path: Union[str, os.PathLike],
        storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
    ) -> "Result":
        """Restore a Result object from local or remote trial directory.

        Args:
            path: A path of a trial directory on local or remote storage
                (ex: s3://bucket/path or /tmp/ray_results).
            storage_filesystem: A custom filesystem to use. If not provided,
                this will be auto-resolved by pyarrow. If provided, the path
                is assumed to be prefix-stripped already, and must be a valid path
                on the filesystem.

        Returns:
            A :py:class:`Result` object of that trial.

        Raises:
            RuntimeError: If the trial directory doesn't exist, or if it
                contains neither a result json nor a progress csv file.
        """
        # TODO(justinvyu): Fix circular dependency.
        from ray.train import Checkpoint
        from ray.train._internal.storage import (
            _exists_at_fs_path,
            _list_at_fs_path,
            get_fs_and_path,
        )
        from ray.train.constants import CHECKPOINT_DIR_NAME

        fs, fs_path = get_fs_and_path(path, storage_filesystem)
        if not _exists_at_fs_path(fs, fs_path):
            raise RuntimeError(f"Trial folder {fs_path} doesn't exist!")

        # Restore metrics from result.json
        result_json_file = Path(fs_path, EXPR_RESULT_FILE).as_posix()
        progress_csv_file = Path(fs_path, EXPR_PROGRESS_FILE).as_posix()
        if _exists_at_fs_path(fs, result_json_file):
            # result.json is newline-delimited JSON: one dict per iteration.
            lines = cls._read_file_as_str(fs, result_json_file).split("\n")
            json_list = [json.loads(line) for line in lines if line]
            metrics_df = pd.json_normalize(json_list, sep="/")
            latest_metrics = json_list[-1] if json_list else {}
        # Fallback to restore from progress.csv
        elif _exists_at_fs_path(fs, progress_csv_file):
            metrics_df = pd.read_csv(
                io.StringIO(cls._read_file_as_str(fs, progress_csv_file))
            )
            latest_metrics = (
                metrics_df.iloc[-1].to_dict() if not metrics_df.empty else {}
            )
        else:
            raise RuntimeError(
                f"Failed to restore the Result object: Neither {EXPR_RESULT_FILE}"
                f" nor {EXPR_PROGRESS_FILE} exists in the trial folder!"
            )

        # Restore all checkpoints from the checkpoint folders
        checkpoint_dir_names = sorted(
            _list_at_fs_path(
                fs,
                fs_path,
                file_filter=lambda file_info: file_info.type
                == pyarrow.fs.FileType.Directory
                and file_info.base_name.startswith("checkpoint_"),
            )
        )

        if checkpoint_dir_names:
            checkpoints = [
                Checkpoint(
                    path=Path(fs_path, checkpoint_dir_name).as_posix(), filesystem=fs
                )
                for checkpoint_dir_name in checkpoint_dir_names
            ]

            # Pair each checkpoint with the last metrics row reported for it.
            metrics = []
            for checkpoint_dir_name in checkpoint_dir_names:
                metrics_corresponding_to_checkpoint = metrics_df[
                    metrics_df[CHECKPOINT_DIR_NAME] == checkpoint_dir_name
                ]
                if metrics_corresponding_to_checkpoint.empty:
                    logger.warning(
                        "Could not find metrics corresponding to "
                        f"{checkpoint_dir_name}. These will default to an empty dict."
                    )
                metrics.append(
                    {}
                    if metrics_corresponding_to_checkpoint.empty
                    else metrics_corresponding_to_checkpoint.iloc[-1].to_dict()
                )

            latest_checkpoint = checkpoints[-1]
            # TODO(justinvyu): These are ordered by checkpoint index, since we don't
            # know the metric to order these with.
            best_checkpoints = list(zip(checkpoints, metrics))
        else:
            best_checkpoints = latest_checkpoint = None

        # Restore the trial error if it exists
        error = None
        error_file_path = Path(fs_path, EXPR_ERROR_PICKLE_FILE).as_posix()
        if _exists_at_fs_path(fs, error_file_path):
            # NOTE(review): unpickles trusted trial output written by Ray itself;
            # do not point `from_path` at untrusted directories.
            with fs.open_input_stream(error_file_path) as f:
                error = ray.cloudpickle.load(f)

        return Result(
            metrics=latest_metrics,
            checkpoint=latest_checkpoint,
            path=fs_path,
            _storage_filesystem=fs,
            metrics_dataframe=metrics_df,
            best_checkpoints=best_checkpoints,
            error=error,
        )

    @PublicAPI(stability="alpha")
    def get_best_checkpoint(self, metric: str, mode: str) -> Optional["Checkpoint"]:
        """Get the best checkpoint from this trial based on a specific metric.

        Any checkpoints without an associated metric value will be filtered out.

        Args:
            metric: The key for checkpoints to order on.
            mode: One of ["min", "max"].

        Returns:
            :class:`Checkpoint <ray.train.Checkpoint>` object.

        Raises:
            RuntimeError: If no checkpoints exist, or if none of the saved
                checkpoints have a reported value for ``metric``.
            ValueError: If ``mode`` is not one of ["min", "max"].
        """
        if not self.best_checkpoints:
            raise RuntimeError("No checkpoint exists in the trial directory!")

        if mode not in ["max", "min"]:
            raise ValueError(
                f'Unsupported mode: {mode}. Please choose from ["min", "max"]!'
            )

        op = max if mode == "max" else min
        # Keep only (checkpoint, metrics) pairs that actually report `metric`.
        valid_checkpoints = [
            ckpt_info for ckpt_info in self.best_checkpoints if metric in ckpt_info[1]
        ]

        if not valid_checkpoints:
            raise RuntimeError(
                f"Invalid metric name {metric}! "
                f"You may choose from the following metrics: {self.metrics.keys()}."
            )

        return op(valid_checkpoints, key=lambda x: x[1][metric])[0]
|
.venv/lib/python3.11/site-packages/ray/air/session.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from ray.train._internal.session import * # noqa: F401,F403
|
.venv/lib/python3.11/site-packages/ray/air/util/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/air/util/tensor_extensions/pandas.py
ADDED
|
@@ -0,0 +1,1451 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from
|
| 2 |
+
# https://github.com/CODAIT/text-extensions-for-pandas/blob/dc03278689fe1c5f131573658ae19815ba25f33e/text_extensions_for_pandas/array/tensor.py
|
| 3 |
+
# and
|
| 4 |
+
# https://github.com/CODAIT/text-extensions-for-pandas/blob/dc03278689fe1c5f131573658ae19815ba25f33e/text_extensions_for_pandas/array/arrow_conversion.py
|
| 5 |
+
|
| 6 |
+
#
|
| 7 |
+
# Copyright (c) 2020 IBM Corp.
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
|
| 20 |
+
# Modifications:
|
| 21 |
+
# - Added ArrowTensorType.to_pandas_type()
|
| 22 |
+
# - Added ArrowTensorArray.__getitem__()
|
| 23 |
+
# - Added ArrowTensorArray.__iter__()
|
| 24 |
+
# - Added support for column casts to extension types.
|
| 25 |
+
# - Fleshed out docstrings and examples.
|
| 26 |
+
# - Fixed TensorArray.isna() so it returns an appropriate ExtensionArray.
|
| 27 |
+
# - Added different (more vectorized) TensorArray.take() operation.
|
| 28 |
+
# - Added support for more reducers (agg funcs) to TensorArray.
|
| 29 |
+
# - Added support for logical operators to TensorArray(Element).
|
| 30 |
+
# - Added support for heterogeneously-shaped tensors.
|
| 31 |
+
# - Miscellaneous small bug fixes and optimizations.
|
| 32 |
+
|
| 33 |
+
import numbers
|
| 34 |
+
import os
|
| 35 |
+
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
|
| 36 |
+
|
| 37 |
+
import numpy as np
|
| 38 |
+
import pandas as pd
|
| 39 |
+
import pyarrow as pa
|
| 40 |
+
from packaging.version import Version
|
| 41 |
+
from pandas._typing import Dtype
|
| 42 |
+
from pandas.compat import set_function_name
|
| 43 |
+
from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries
|
| 44 |
+
from pandas.core.indexers import check_array_indexer, validate_indices
|
| 45 |
+
|
| 46 |
+
from ray.air.util.tensor_extensions.utils import (
|
| 47 |
+
_create_possibly_ragged_ndarray,
|
| 48 |
+
_is_ndarray_variable_shaped_tensor,
|
| 49 |
+
)
|
| 50 |
+
from ray.util.annotations import PublicAPI
|
| 51 |
+
|
| 52 |
+
try:
|
| 53 |
+
from pandas.core.dtypes.generic import ABCIndex
|
| 54 |
+
except ImportError:
|
| 55 |
+
# ABCIndexClass changed to ABCIndex in Pandas 1.3
|
| 56 |
+
from pandas.core.dtypes.generic import ABCIndexClass as ABCIndex
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
#############################################
|
| 60 |
+
# Begin patching of ExtensionArrayFormatter #
|
| 61 |
+
#############################################
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _format_strings_patched(self) -> List[str]:
    """Format a TensorArray column for display (patch for pandas 1.1 - 1.2).

    Delegates to the stock pandas formatter for non-tensor or 1-D values.
    For multi-dimensional tensor columns, formats the flattened element
    values and reshapes the resulting string array back to the original
    tensor shape before a final formatting pass.
    """
    from pandas.core.construction import extract_array
    from pandas.io.formats.format import format_array

    if not isinstance(self.values, TensorArray):
        return self._format_strings_orig()

    extracted = extract_array(self.values, extract_numpy=True)
    ndarray = np.asarray(extracted)

    if ndarray.ndim == 1:
        return self._format_strings_orig()

    def _run_format_array(arr, fmt):
        # Delegate to pandas' format_array, forwarding this formatter's options.
        return format_array(
            arr,
            fmt,
            float_format=self.float_format,
            na_rep=self.na_rep,
            digits=self.digits,
            space=self.space,
            justify=self.justify,
            decimal=self.decimal,
            leading_space=self.leading_space,
            quoting=self.quoting,
        )

    element_formatter = self.formatter
    if element_formatter is None:
        element_formatter = extracted._formatter(boxed=True)

    # Flatten array, format the flat values, then restore the tensor shape
    # (use ravel_compat in v1.3.0).
    flattened = ndarray.ravel("K")
    formatted_flat = np.asarray(_run_format_array(flattened, element_formatter))
    memory_order = "F" if ndarray.flags.f_contiguous else "C"
    formatted = formatted_flat.reshape(ndarray.shape, order=memory_order)

    # Format the array of nested strings with the default formatter.
    return _run_format_array(formatted, None)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _format_strings_patched_v1_0_0(self) -> List[str]:
    """Format a TensorArray column for display (patch for older pandas).

    Same flatten-format-reshape approach as ``_format_strings_patched``, but
    the final pass uses a slimmed-down string formatter instead of
    ``format_array`` to work around
    https://github.com/pandas-dev/pandas/issues/33770.
    """
    from functools import partial

    from pandas.core.construction import extract_array
    from pandas.io.formats.format import format_array
    from pandas.io.formats.printing import pprint_thing

    if not isinstance(self.values, TensorArray):
        return self._format_strings_orig()

    extracted = extract_array(self.values, extract_numpy=True)
    ndarray = np.asarray(extracted)

    if ndarray.ndim == 1:
        return self._format_strings_orig()

    def _run_format_array(arr, fmt):
        # Delegate to pandas' format_array, forwarding this formatter's options.
        return format_array(
            arr,
            fmt,
            float_format=self.float_format,
            na_rep=self.na_rep,
            digits=self.digits,
            space=self.space,
            justify=self.justify,
            decimal=self.decimal,
            leading_space=self.leading_space,
        )

    element_formatter = self.formatter
    if element_formatter is None:
        element_formatter = extracted._formatter(boxed=True)

    # Flatten array, format the flat values, then restore the tensor shape
    # (use ravel_compat in v1.3.0).
    flattened = ndarray.ravel("K")
    formatted_flat = np.asarray(_run_format_array(flattened, element_formatter))
    memory_order = "F" if ndarray.flags.f_contiguous else "C"
    formatted = formatted_flat.reshape(ndarray.shape, order=memory_order)

    # Slimmed down version of GenericArrayFormatter due to:
    # https://github.com/pandas-dev/pandas/issues/33770
    def _stringify_slim(values_, leading_space):
        to_text = partial(
            pprint_thing,
            escape_chars=("\t", "\r", "\n"),
        )
        template = "{v}" if leading_space is False else " {v}"
        return [template.format(v=str(to_text(v))) for v in values_]

    return _stringify_slim(formatted, self.leading_space)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# Set this env var to "0" to disable Ray's TensorArray display patch entirely.
_FORMATTER_ENABLED_ENV_VAR = "TENSOR_COLUMN_EXTENSION_FORMATTER_ENABLED"

if os.getenv(_FORMATTER_ENABLED_ENV_VAR, "1") == "1":
    # The formatter class was renamed (made private) in pandas 2.2.0.
    if Version(pd.__version__) < Version("2.2.0"):
        from pandas.io.formats.format import ExtensionArrayFormatter

        formatter_cls = ExtensionArrayFormatter
    else:
        from pandas.io.formats.format import _ExtensionArrayFormatter

        formatter_cls = _ExtensionArrayFormatter
    # Keep the stock implementation around; the patched functions fall back
    # to it for non-TensorArray and 1-D values.
    formatter_cls._format_strings_orig = formatter_cls._format_strings
    # NOTE(review): the else-branch here applies the "v1_0_0" patch to both
    # pandas < 1.1.0 AND pandas >= 1.3.0 — confirm this is intentional for
    # modern pandas versions.
    if Version("1.1.0") <= Version(pd.__version__) < Version("1.3.0"):
        formatter_cls._format_strings = _format_strings_patched
    else:
        formatter_cls._format_strings = _format_strings_patched_v1_0_0
    # Marker so other code (and re-imports) can detect the monkeypatch.
    formatter_cls._patched_by_ray_datasets = True
|
| 183 |
+
|
| 184 |
+
###########################################
|
| 185 |
+
# End patching of ExtensionArrayFormatter #
|
| 186 |
+
###########################################
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
@PublicAPI(stability="beta")
|
| 190 |
+
@pd.api.extensions.register_extension_dtype
|
| 191 |
+
class TensorDtype(pd.api.extensions.ExtensionDtype):
|
| 192 |
+
"""
|
| 193 |
+
Pandas extension type for a column of homogeneous-typed tensors.
|
| 194 |
+
|
| 195 |
+
This extension supports tensors in which the elements have different shapes.
|
| 196 |
+
However, each tensor element must be non-ragged, i.e. each tensor element must have
|
| 197 |
+
a well-defined, non-ragged shape.
|
| 198 |
+
|
| 199 |
+
See:
|
| 200 |
+
https://github.com/pandas-dev/pandas/blob/master/pandas/core/dtypes/base.py
|
| 201 |
+
for up-to-date interface documentation and the subclassing contract. The
|
| 202 |
+
docstrings of the below properties and methods were copied from the base
|
| 203 |
+
ExtensionDtype.
|
| 204 |
+
|
| 205 |
+
Examples:
|
| 206 |
+
>>> # Create a DataFrame with a list of ndarrays as a column.
|
| 207 |
+
>>> import pandas as pd
|
| 208 |
+
>>> import numpy as np
|
| 209 |
+
>>> import ray
|
| 210 |
+
>>> df = pd.DataFrame({
|
| 211 |
+
... "one": [1, 2, 3],
|
| 212 |
+
... "two": list(np.arange(24).reshape((3, 2, 2, 2)))})
|
| 213 |
+
>>> # Note the opaque np.object dtype for this column.
|
| 214 |
+
>>> df.dtypes # doctest: +SKIP
|
| 215 |
+
one int64
|
| 216 |
+
two object
|
| 217 |
+
dtype: object
|
| 218 |
+
>>> # Cast column to our TensorDtype extension type.
|
| 219 |
+
>>> from ray.data.extensions import TensorDtype
|
| 220 |
+
>>> df["two"] = df["two"].astype(TensorDtype(np.int64, (3, 2, 2, 2)))
|
| 221 |
+
>>> # Note that the column dtype is now TensorDtype instead of
|
| 222 |
+
>>> # np.object.
|
| 223 |
+
>>> df.dtypes # doctest: +SKIP
|
| 224 |
+
one int64
|
| 225 |
+
two TensorDtype(shape=(3, 2, 2, 2), dtype=int64)
|
| 226 |
+
dtype: object
|
| 227 |
+
>>> # Pandas is now aware of this tensor column, and we can do the
|
| 228 |
+
>>> # typical DataFrame operations on this column.
|
| 229 |
+
>>> col = 2 * (df["two"] + 10)
|
| 230 |
+
>>> # The ndarrays underlying the tensor column will be manipulated,
|
| 231 |
+
>>> # but the column itself will continue to be a Pandas type.
|
| 232 |
+
>>> type(col) # doctest: +SKIP
|
| 233 |
+
pandas.core.series.Series
|
| 234 |
+
>>> col # doctest: +SKIP
|
| 235 |
+
0 [[[ 2 4]
|
| 236 |
+
[ 6 8]]
|
| 237 |
+
[[10 12]
|
| 238 |
+
[14 16]]]
|
| 239 |
+
1 [[[18 20]
|
| 240 |
+
[22 24]]
|
| 241 |
+
[[26 28]
|
| 242 |
+
[30 32]]]
|
| 243 |
+
2 [[[34 36]
|
| 244 |
+
[38 40]]
|
| 245 |
+
[[42 44]
|
| 246 |
+
[46 48]]]
|
| 247 |
+
Name: two, dtype: TensorDtype(shape=(3, 2, 2, 2), dtype=int64)
|
| 248 |
+
>>> # Once you do an aggregation on that column that returns a single
|
| 249 |
+
>>> # row's value, you get back our TensorArrayElement type.
|
| 250 |
+
>>> tensor = col.mean()
|
| 251 |
+
>>> type(tensor) # doctest: +SKIP
|
| 252 |
+
ray.data.extensions.tensor_extension.TensorArrayElement
|
| 253 |
+
>>> tensor # doctest: +SKIP
|
| 254 |
+
array([[[18., 20.],
|
| 255 |
+
[22., 24.]],
|
| 256 |
+
[[26., 28.],
|
| 257 |
+
[30., 32.]]])
|
| 258 |
+
>>> # This is a light wrapper around a NumPy ndarray, and can easily
|
| 259 |
+
>>> # be converted to an ndarray.
|
| 260 |
+
>>> type(tensor.to_numpy()) # doctest: +SKIP
|
| 261 |
+
numpy.ndarray
|
| 262 |
+
>>> # In addition to doing Pandas operations on the tensor column,
|
| 263 |
+
>>> # you can now put the DataFrame into a Dataset.
|
| 264 |
+
>>> ds = ray.data.from_pandas(df) # doctest: +SKIP
|
| 265 |
+
>>> # Internally, this column is represented the corresponding
|
| 266 |
+
>>> # Arrow tensor extension type.
|
| 267 |
+
>>> ds.schema() # doctest: +SKIP
|
| 268 |
+
one: int64
|
| 269 |
+
two: extension<arrow.py_extension_type<ArrowTensorType>>
|
| 270 |
+
>>> # You can write the dataset to Parquet.
|
| 271 |
+
>>> ds.write_parquet("/some/path") # doctest: +SKIP
|
| 272 |
+
>>> # And you can read it back.
|
| 273 |
+
>>> read_ds = ray.data.read_parquet("/some/path") # doctest: +SKIP
|
| 274 |
+
>>> read_ds.schema() # doctest: +SKIP
|
| 275 |
+
one: int64
|
| 276 |
+
two: extension<arrow.py_extension_type<ArrowTensorType>>
|
| 277 |
+
>>> read_df = ray.get(read_ds.to_pandas_refs())[0] # doctest: +SKIP
|
| 278 |
+
>>> read_df.dtypes # doctest: +SKIP
|
| 279 |
+
one int64
|
| 280 |
+
two TensorDtype(shape=(3, 2, 2, 2), dtype=int64)
|
| 281 |
+
dtype: object
|
| 282 |
+
>>> # The tensor extension type is preserved along the
|
| 283 |
+
>>> # Pandas --> Arrow --> Parquet --> Arrow --> Pandas
|
| 284 |
+
>>> # conversion chain.
|
| 285 |
+
>>> read_df.equals(df) # doctest: +SKIP
|
| 286 |
+
True
|
| 287 |
+
"""
|
| 288 |
+
|
| 289 |
+
# NOTE(Clark): This is apparently required to prevent integer indexing
|
| 290 |
+
# errors, but is an undocumented ExtensionDtype attribute. See issue:
|
| 291 |
+
# https://github.com/CODAIT/text-extensions-for-pandas/issues/166
|
| 292 |
+
base = None
|
| 293 |
+
|
| 294 |
+
def __init__(self, shape: Tuple[Optional[int], ...], dtype: np.dtype):
|
| 295 |
+
self._shape = shape
|
| 296 |
+
self._dtype = dtype
|
| 297 |
+
|
| 298 |
+
@property
|
| 299 |
+
def type(self):
|
| 300 |
+
"""
|
| 301 |
+
The scalar type for the array, e.g. ``int``
|
| 302 |
+
It's expected ``ExtensionArray[item]`` returns an instance
|
| 303 |
+
of ``ExtensionDtype.type`` for scalar ``item``, assuming
|
| 304 |
+
that value is valid (not NA). NA values do not need to be
|
| 305 |
+
instances of `type`.
|
| 306 |
+
"""
|
| 307 |
+
return TensorArrayElement
|
| 308 |
+
|
| 309 |
+
@property
|
| 310 |
+
def element_dtype(self):
|
| 311 |
+
"""
|
| 312 |
+
The dtype of the underlying tensor elements.
|
| 313 |
+
"""
|
| 314 |
+
return self._dtype
|
| 315 |
+
|
| 316 |
+
@property
|
| 317 |
+
def element_shape(self):
|
| 318 |
+
"""
|
| 319 |
+
The shape of the underlying tensor elements. This will be a tuple of Nones if
|
| 320 |
+
the corresponding TensorArray for this TensorDtype holds variable-shaped tensor
|
| 321 |
+
elements.
|
| 322 |
+
"""
|
| 323 |
+
return self._shape
|
| 324 |
+
|
| 325 |
+
@property
|
| 326 |
+
def is_variable_shaped(self):
|
| 327 |
+
"""
|
| 328 |
+
Whether the corresponding TensorArray for this TensorDtype holds variable-shaped
|
| 329 |
+
tensor elements.
|
| 330 |
+
"""
|
| 331 |
+
return all(dim_size is None for dim_size in self.shape)
|
| 332 |
+
|
| 333 |
+
@property
|
| 334 |
+
def name(self) -> str:
|
| 335 |
+
"""
|
| 336 |
+
A string identifying the data type.
|
| 337 |
+
Will be used for display in, e.g. ``Series.dtype``
|
| 338 |
+
"""
|
| 339 |
+
return f"numpy.ndarray(shape={self._shape}, dtype={self._dtype})"
|
| 340 |
+
|
| 341 |
+
@classmethod
|
| 342 |
+
def construct_from_string(cls, string: str):
|
| 343 |
+
r"""
|
| 344 |
+
Construct this type from a string.
|
| 345 |
+
|
| 346 |
+
This is useful mainly for data types that accept parameters.
|
| 347 |
+
For example, a period dtype accepts a frequency parameter that
|
| 348 |
+
can be set as ``period[H]`` (where H means hourly frequency).
|
| 349 |
+
|
| 350 |
+
By default, in the abstract class, just the name of the type is
|
| 351 |
+
expected. But subclasses can overwrite this method to accept
|
| 352 |
+
parameters.
|
| 353 |
+
|
| 354 |
+
Parameters
|
| 355 |
+
----------
|
| 356 |
+
string : str
|
| 357 |
+
The name of the type, for example ``category``.
|
| 358 |
+
|
| 359 |
+
Returns
|
| 360 |
+
-------
|
| 361 |
+
ExtensionDtype
|
| 362 |
+
Instance of the dtype.
|
| 363 |
+
|
| 364 |
+
Raises
|
| 365 |
+
------
|
| 366 |
+
TypeError
|
| 367 |
+
If a class cannot be constructed from this 'string'.
|
| 368 |
+
|
| 369 |
+
Examples
|
| 370 |
+
--------
|
| 371 |
+
For extension dtypes with arguments the following may be an
|
| 372 |
+
adequate implementation.
|
| 373 |
+
|
| 374 |
+
>>> import re
|
| 375 |
+
>>> @classmethod
|
| 376 |
+
... def construct_from_string(cls, string):
|
| 377 |
+
... pattern = re.compile(r"^my_type\[(?P<arg_name>.+)\]$")
|
| 378 |
+
... match = pattern.match(string)
|
| 379 |
+
... if match:
|
| 380 |
+
... return cls(**match.groupdict())
|
| 381 |
+
... else:
|
| 382 |
+
... raise TypeError(
|
| 383 |
+
... f"Cannot construct a '{cls.__name__}' from '{string}'"
|
| 384 |
+
... )
|
| 385 |
+
"""
|
| 386 |
+
import ast
|
| 387 |
+
import re
|
| 388 |
+
|
| 389 |
+
if not isinstance(string, str):
|
| 390 |
+
raise TypeError(
|
| 391 |
+
f"'construct_from_string' expects a string, got {type(string)}"
|
| 392 |
+
)
|
| 393 |
+
# Upstream code uses exceptions as part of its normal control flow and
|
| 394 |
+
# will pass this method bogus class names.
|
| 395 |
+
regex = (
|
| 396 |
+
r"^(TensorDtype|numpy.ndarray)"
|
| 397 |
+
r"\(shape=(\((?:(?:\d+|None),?\s?)*\)), dtype=(\w+)\)$"
|
| 398 |
+
)
|
| 399 |
+
m = re.search(regex, string)
|
| 400 |
+
err_msg = (
|
| 401 |
+
f"Cannot construct a '{cls.__name__}' from '{string}'; expected a string "
|
| 402 |
+
"like 'TensorDtype(shape=(1, 2, 3), dtype=int64)'."
|
| 403 |
+
)
|
| 404 |
+
if m is None:
|
| 405 |
+
raise TypeError(err_msg)
|
| 406 |
+
groups = m.groups()
|
| 407 |
+
if len(groups) != 3:
|
| 408 |
+
raise TypeError(err_msg)
|
| 409 |
+
_, shape, dtype = groups
|
| 410 |
+
shape = ast.literal_eval(shape)
|
| 411 |
+
dtype = np.dtype(dtype)
|
| 412 |
+
return cls(shape, dtype)
|
| 413 |
+
|
| 414 |
+
@classmethod
|
| 415 |
+
def construct_array_type(cls):
|
| 416 |
+
"""
|
| 417 |
+
Return the array type associated with this dtype.
|
| 418 |
+
|
| 419 |
+
Returns
|
| 420 |
+
-------
|
| 421 |
+
type
|
| 422 |
+
"""
|
| 423 |
+
return TensorArray
|
| 424 |
+
|
| 425 |
+
def __from_arrow__(self, array: Union[pa.Array, pa.ChunkedArray]):
|
| 426 |
+
"""
|
| 427 |
+
Convert a pyarrow (chunked) array to a TensorArray.
|
| 428 |
+
|
| 429 |
+
This and TensorArray.__arrow_array__ make up the
|
| 430 |
+
Pandas extension type + array <--> Arrow extension type + array
|
| 431 |
+
interoperability protocol. See
|
| 432 |
+
https://pandas.pydata.org/pandas-docs/stable/development/extending.html#compatibility-with-apache-arrow
|
| 433 |
+
for more information.
|
| 434 |
+
"""
|
| 435 |
+
if isinstance(array, pa.ChunkedArray):
|
| 436 |
+
if array.num_chunks > 1:
|
| 437 |
+
# TODO(Clark): Remove concat and construct from list with
|
| 438 |
+
# shape.
|
| 439 |
+
values = np.concatenate(
|
| 440 |
+
[chunk.to_numpy() for chunk in array.iterchunks()]
|
| 441 |
+
)
|
| 442 |
+
else:
|
| 443 |
+
values = array.chunk(0).to_numpy()
|
| 444 |
+
else:
|
| 445 |
+
values = array.to_numpy()
|
| 446 |
+
|
| 447 |
+
return TensorArray(values)
|
| 448 |
+
|
| 449 |
+
def __str__(self) -> str:
|
| 450 |
+
return self.name
|
| 451 |
+
|
| 452 |
+
def __repr__(self) -> str:
|
| 453 |
+
return str(self)
|
| 454 |
+
|
| 455 |
+
@property
|
| 456 |
+
def _is_boolean(self):
|
| 457 |
+
"""
|
| 458 |
+
Whether this extension array should be considered boolean.
|
| 459 |
+
|
| 460 |
+
By default, ExtensionArrays are assumed to be non-numeric.
|
| 461 |
+
Setting this to True will affect the behavior of several places,
|
| 462 |
+
e.g.
|
| 463 |
+
|
| 464 |
+
* is_bool
|
| 465 |
+
* boolean indexing
|
| 466 |
+
|
| 467 |
+
Returns
|
| 468 |
+
-------
|
| 469 |
+
bool
|
| 470 |
+
"""
|
| 471 |
+
# This is needed to support returning a TensorArray from .isnan().
|
| 472 |
+
from pandas.core.dtypes.common import is_bool_dtype
|
| 473 |
+
|
| 474 |
+
return is_bool_dtype(self._dtype)
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
class _TensorOpsMixin(pd.api.extensions.ExtensionScalarOpsMixin):
|
| 478 |
+
"""
|
| 479 |
+
Mixin for TensorArray operator support, applying operations on the
|
| 480 |
+
underlying ndarrays.
|
| 481 |
+
"""
|
| 482 |
+
|
| 483 |
+
@classmethod
|
| 484 |
+
def _create_method(cls, op, coerce_to_dtype=True, result_dtype=None):
|
| 485 |
+
"""
|
| 486 |
+
Add support for binary operators by unwrapping, applying, and
|
| 487 |
+
rewrapping.
|
| 488 |
+
"""
|
| 489 |
+
|
| 490 |
+
# NOTE(Clark): This overrides, but coerce_to_dtype, result_dtype might
|
| 491 |
+
# not be needed
|
| 492 |
+
|
| 493 |
+
def _binop(self, other):
|
| 494 |
+
lvalues = self._tensor
|
| 495 |
+
|
| 496 |
+
if isinstance(other, (ABCDataFrame, ABCSeries, ABCIndex)):
|
| 497 |
+
# Rely on Pandas to unbox and dispatch to us.
|
| 498 |
+
return NotImplemented
|
| 499 |
+
|
| 500 |
+
# divmod returns a tuple
|
| 501 |
+
if op_name in ["__divmod__", "__rdivmod__"]:
|
| 502 |
+
# TODO(Clark): Add support for divmod and rdivmod.
|
| 503 |
+
# div, mod = result
|
| 504 |
+
raise NotImplementedError
|
| 505 |
+
|
| 506 |
+
if isinstance(other, (TensorArray, TensorArrayElement)):
|
| 507 |
+
rvalues = other._tensor
|
| 508 |
+
else:
|
| 509 |
+
rvalues = other
|
| 510 |
+
|
| 511 |
+
result = op(lvalues, rvalues)
|
| 512 |
+
|
| 513 |
+
# Force a TensorArray if rvalue is not a scalar.
|
| 514 |
+
if isinstance(self, TensorArrayElement) and (
|
| 515 |
+
not isinstance(other, TensorArrayElement) or not np.isscalar(other)
|
| 516 |
+
):
|
| 517 |
+
result_wrapped = TensorArray(result)
|
| 518 |
+
else:
|
| 519 |
+
result_wrapped = cls(result)
|
| 520 |
+
|
| 521 |
+
return result_wrapped
|
| 522 |
+
|
| 523 |
+
op_name = f"__{op.__name__}__"
|
| 524 |
+
return set_function_name(_binop, op_name, cls)
|
| 525 |
+
|
| 526 |
+
@classmethod
|
| 527 |
+
def _create_logical_method(cls, op):
|
| 528 |
+
return cls._create_method(op)
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
class _TensorScalarCastMixin:
|
| 532 |
+
"""
|
| 533 |
+
Mixin for casting scalar tensors to a particular numeric type.
|
| 534 |
+
"""
|
| 535 |
+
|
| 536 |
+
def _scalarfunc(self, func: Callable[[Any], Any]):
|
| 537 |
+
return func(self._tensor)
|
| 538 |
+
|
| 539 |
+
def __complex__(self):
|
| 540 |
+
return self._scalarfunc(complex)
|
| 541 |
+
|
| 542 |
+
def __float__(self):
|
| 543 |
+
return self._scalarfunc(float)
|
| 544 |
+
|
| 545 |
+
def __int__(self):
|
| 546 |
+
return self._scalarfunc(int)
|
| 547 |
+
|
| 548 |
+
def __hex__(self):
|
| 549 |
+
return self._scalarfunc(hex)
|
| 550 |
+
|
| 551 |
+
def __oct__(self):
|
| 552 |
+
return self._scalarfunc(oct)
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
@PublicAPI(stability="beta")
|
| 556 |
+
class TensorArrayElement(_TensorOpsMixin, _TensorScalarCastMixin):
|
| 557 |
+
"""
|
| 558 |
+
Single element of a TensorArray, wrapping an underlying ndarray.
|
| 559 |
+
"""
|
| 560 |
+
|
| 561 |
+
def __init__(self, values: np.ndarray):
|
| 562 |
+
"""
|
| 563 |
+
Construct a TensorArrayElement from a NumPy ndarray.
|
| 564 |
+
|
| 565 |
+
Args:
|
| 566 |
+
values: ndarray that underlies this TensorArray element.
|
| 567 |
+
"""
|
| 568 |
+
self._tensor = values
|
| 569 |
+
|
| 570 |
+
def __repr__(self):
|
| 571 |
+
return self._tensor.__repr__()
|
| 572 |
+
|
| 573 |
+
def __str__(self):
|
| 574 |
+
return self._tensor.__str__()
|
| 575 |
+
|
| 576 |
+
@property
|
| 577 |
+
def numpy_dtype(self):
|
| 578 |
+
"""
|
| 579 |
+
Get the dtype of the tensor.
|
| 580 |
+
:return: The numpy dtype of the backing ndarray
|
| 581 |
+
"""
|
| 582 |
+
return self._tensor.dtype
|
| 583 |
+
|
| 584 |
+
@property
|
| 585 |
+
def numpy_ndim(self):
|
| 586 |
+
"""
|
| 587 |
+
Get the number of tensor dimensions.
|
| 588 |
+
:return: integer for the number of dimensions
|
| 589 |
+
"""
|
| 590 |
+
return self._tensor.ndim
|
| 591 |
+
|
| 592 |
+
@property
|
| 593 |
+
def numpy_shape(self):
|
| 594 |
+
"""
|
| 595 |
+
Get the shape of the tensor.
|
| 596 |
+
:return: A tuple of integers for the numpy shape of the backing ndarray
|
| 597 |
+
"""
|
| 598 |
+
return self._tensor.shape
|
| 599 |
+
|
| 600 |
+
@property
|
| 601 |
+
def numpy_size(self):
|
| 602 |
+
"""
|
| 603 |
+
Get the size of the tensor.
|
| 604 |
+
:return: integer for the number of elements in the tensor
|
| 605 |
+
"""
|
| 606 |
+
return self._tensor.size
|
| 607 |
+
|
| 608 |
+
def to_numpy(self):
|
| 609 |
+
"""
|
| 610 |
+
Return the values of this element as a NumPy ndarray.
|
| 611 |
+
"""
|
| 612 |
+
return np.asarray(self._tensor)
|
| 613 |
+
|
| 614 |
+
def __array__(self, dtype: np.dtype = None, **kwargs) -> np.ndarray:
|
| 615 |
+
return np.asarray(self._tensor, dtype=dtype, **kwargs)
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
@PublicAPI(stability="beta")
|
| 619 |
+
class TensorArray(
|
| 620 |
+
pd.api.extensions.ExtensionArray,
|
| 621 |
+
_TensorOpsMixin,
|
| 622 |
+
_TensorScalarCastMixin,
|
| 623 |
+
):
|
| 624 |
+
"""
|
| 625 |
+
Pandas `ExtensionArray` representing a tensor column, i.e. a column
|
| 626 |
+
consisting of ndarrays as elements.
|
| 627 |
+
|
| 628 |
+
This extension supports tensors in which the elements have different shapes.
|
| 629 |
+
However, each tensor element must be non-ragged, i.e. each tensor element must have
|
| 630 |
+
a well-defined, non-ragged shape.
|
| 631 |
+
|
| 632 |
+
Examples:
|
| 633 |
+
>>> # Create a DataFrame with a list of ndarrays as a column.
|
| 634 |
+
>>> import pandas as pd
|
| 635 |
+
>>> import numpy as np
|
| 636 |
+
>>> import ray
|
| 637 |
+
>>> from ray.data.extensions import TensorArray
|
| 638 |
+
>>> df = pd.DataFrame({
|
| 639 |
+
... "one": [1, 2, 3],
|
| 640 |
+
... "two": TensorArray(np.arange(24).reshape((3, 2, 2, 2)))})
|
| 641 |
+
>>> # Note that the column dtype is TensorDtype.
|
| 642 |
+
>>> df.dtypes # doctest: +SKIP
|
| 643 |
+
one int64
|
| 644 |
+
two TensorDtype(shape=(3, 2, 2, 2), dtype=int64)
|
| 645 |
+
dtype: object
|
| 646 |
+
>>> # Pandas is aware of this tensor column, and we can do the
|
| 647 |
+
>>> # typical DataFrame operations on this column.
|
| 648 |
+
>>> col = 2 * (df["two"] + 10)
|
| 649 |
+
>>> # The ndarrays underlying the tensor column will be manipulated,
|
| 650 |
+
>>> # but the column itself will continue to be a Pandas type.
|
| 651 |
+
>>> type(col) # doctest: +SKIP
|
| 652 |
+
pandas.core.series.Series
|
| 653 |
+
>>> col # doctest: +SKIP
|
| 654 |
+
0 [[[ 2 4]
|
| 655 |
+
[ 6 8]]
|
| 656 |
+
[[10 12]
|
| 657 |
+
[14 16]]]
|
| 658 |
+
1 [[[18 20]
|
| 659 |
+
[22 24]]
|
| 660 |
+
[[26 28]
|
| 661 |
+
[30 32]]]
|
| 662 |
+
2 [[[34 36]
|
| 663 |
+
[38 40]]
|
| 664 |
+
[[42 44]
|
| 665 |
+
[46 48]]]
|
| 666 |
+
Name: two, dtype: TensorDtype(shape=(3, 2, 2, 2), dtype=int64)
|
| 667 |
+
>>> # Once you do an aggregation on that column that returns a single
|
| 668 |
+
>>> # row's value, you get back our TensorArrayElement type.
|
| 669 |
+
>>> tensor = col.mean() # doctest: +SKIP
|
| 670 |
+
>>> type(tensor) # doctest: +SKIP
|
| 671 |
+
ray.data.extensions.tensor_extension.TensorArrayElement
|
| 672 |
+
>>> tensor # doctest: +SKIP
|
| 673 |
+
array([[[18., 20.],
|
| 674 |
+
[22., 24.]],
|
| 675 |
+
[[26., 28.],
|
| 676 |
+
[30., 32.]]])
|
| 677 |
+
>>> # This is a light wrapper around a NumPy ndarray, and can easily
|
| 678 |
+
>>> # be converted to an ndarray.
|
| 679 |
+
>>> type(tensor.to_numpy()) # doctest: +SKIP
|
| 680 |
+
numpy.ndarray
|
| 681 |
+
>>> # In addition to doing Pandas operations on the tensor column,
|
| 682 |
+
>>> # you can now put the DataFrame into a Dataset.
|
| 683 |
+
>>> ds = ray.data.from_pandas(df) # doctest: +SKIP
|
| 684 |
+
>>> # Internally, this column is represented the corresponding
|
| 685 |
+
>>> # Arrow tensor extension type.
|
| 686 |
+
>>> ds.schema() # doctest: +SKIP
|
| 687 |
+
one: int64
|
| 688 |
+
two: extension<arrow.py_extension_type<ArrowTensorType>>
|
| 689 |
+
>>> # You can write the dataset to Parquet.
|
| 690 |
+
>>> ds.write_parquet("/some/path") # doctest: +SKIP
|
| 691 |
+
>>> # And you can read it back.
|
| 692 |
+
>>> read_ds = ray.data.read_parquet("/some/path") # doctest: +SKIP
|
| 693 |
+
>>> read_ds.schema() # doctest: +SKIP
|
| 694 |
+
one: int64
|
| 695 |
+
two: extension<arrow.py_extension_type<ArrowTensorType>>
|
| 696 |
+
|
| 697 |
+
>>> read_df = ray.get(read_ds.to_pandas_refs())[0] # doctest: +SKIP
|
| 698 |
+
>>> read_df.dtypes # doctest: +SKIP
|
| 699 |
+
one int64
|
| 700 |
+
two TensorDtype(shape=(3, 2, 2, 2), dtype=int64)
|
| 701 |
+
dtype: object
|
| 702 |
+
>>> # The tensor extension type is preserved along the
|
| 703 |
+
>>> # Pandas --> Arrow --> Parquet --> Arrow --> Pandas
|
| 704 |
+
>>> # conversion chain.
|
| 705 |
+
>>> read_df.equals(df) # doctest: +SKIP
|
| 706 |
+
True
|
| 707 |
+
"""
|
| 708 |
+
|
| 709 |
+
SUPPORTED_REDUCERS = {
|
| 710 |
+
"sum": np.sum,
|
| 711 |
+
"all": np.all,
|
| 712 |
+
"any": np.any,
|
| 713 |
+
"min": np.min,
|
| 714 |
+
"max": np.max,
|
| 715 |
+
"mean": np.mean,
|
| 716 |
+
"median": np.median,
|
| 717 |
+
"prod": np.prod,
|
| 718 |
+
"std": np.std,
|
| 719 |
+
"var": np.var,
|
| 720 |
+
}
|
| 721 |
+
|
| 722 |
+
# See https://github.com/pandas-dev/pandas/blob/master/pandas/core/arrays/base.py
|
| 723 |
+
# for interface documentation and the subclassing contract.
|
| 724 |
+
def __init__(
|
| 725 |
+
self,
|
| 726 |
+
values: Union[
|
| 727 |
+
np.ndarray,
|
| 728 |
+
ABCSeries,
|
| 729 |
+
Sequence[Union[np.ndarray, TensorArrayElement]],
|
| 730 |
+
TensorArrayElement,
|
| 731 |
+
Any,
|
| 732 |
+
],
|
| 733 |
+
):
|
| 734 |
+
"""
|
| 735 |
+
Args:
|
| 736 |
+
values: A NumPy ndarray or sequence of NumPy ndarrays of equal
|
| 737 |
+
shape.
|
| 738 |
+
"""
|
| 739 |
+
# Try to convert some well-known objects to ndarrays before handing off to
|
| 740 |
+
# ndarray handling logic.
|
| 741 |
+
if isinstance(values, ABCSeries):
|
| 742 |
+
values = _create_possibly_ragged_ndarray(values)
|
| 743 |
+
elif isinstance(values, Sequence):
|
| 744 |
+
values = [
|
| 745 |
+
np.asarray(v) if isinstance(v, TensorArrayElement) else v
|
| 746 |
+
for v in values
|
| 747 |
+
]
|
| 748 |
+
values = _create_possibly_ragged_ndarray(values)
|
| 749 |
+
elif isinstance(values, TensorArrayElement):
|
| 750 |
+
values = np.array([np.asarray(values)], copy=False)
|
| 751 |
+
|
| 752 |
+
if isinstance(values, np.ndarray):
|
| 753 |
+
if values.dtype.type is np.object_:
|
| 754 |
+
if len(values) == 0:
|
| 755 |
+
# Tensor is empty, pass through to create empty TensorArray.
|
| 756 |
+
pass
|
| 757 |
+
elif all(
|
| 758 |
+
isinstance(v, (np.ndarray, TensorArrayElement, Sequence))
|
| 759 |
+
and not isinstance(v, str)
|
| 760 |
+
for v in values
|
| 761 |
+
):
|
| 762 |
+
values = [np.asarray(v) for v in values]
|
| 763 |
+
# Try to convert ndarrays of ndarrays/TensorArrayElements with an
|
| 764 |
+
# opaque object type to a properly typed ndarray of ndarrays.
|
| 765 |
+
values = _create_possibly_ragged_ndarray(values)
|
| 766 |
+
else:
|
| 767 |
+
raise TypeError(
|
| 768 |
+
"Expected a well-typed ndarray or an object-typed ndarray of "
|
| 769 |
+
"ndarray pointers, but got an object-typed ndarray whose "
|
| 770 |
+
f"subndarrays are of type {type(values[0])}."
|
| 771 |
+
)
|
| 772 |
+
elif isinstance(values, TensorArray):
|
| 773 |
+
raise TypeError("Use the copy() method to create a copy of a TensorArray.")
|
| 774 |
+
else:
|
| 775 |
+
raise TypeError(
|
| 776 |
+
"Expected a numpy.ndarray or sequence of numpy.ndarray, "
|
| 777 |
+
f"but received {values} of type {type(values).__name__} instead."
|
| 778 |
+
)
|
| 779 |
+
assert isinstance(values, np.ndarray)
|
| 780 |
+
self._tensor = values
|
| 781 |
+
self._is_variable_shaped = None
|
| 782 |
+
|
| 783 |
+
@classmethod
|
| 784 |
+
def _from_sequence(
|
| 785 |
+
cls, scalars, *, dtype: Optional[Dtype] = None, copy: bool = False
|
| 786 |
+
):
|
| 787 |
+
"""
|
| 788 |
+
Construct a new ExtensionArray from a sequence of scalars.
|
| 789 |
+
|
| 790 |
+
Parameters
|
| 791 |
+
----------
|
| 792 |
+
scalars : Sequence
|
| 793 |
+
Each element will be an instance of the scalar type for this
|
| 794 |
+
array, ``cls.dtype.type`` or be converted into this type in this
|
| 795 |
+
method.
|
| 796 |
+
dtype : dtype, optional
|
| 797 |
+
Construct for this particular dtype. This should be a Dtype
|
| 798 |
+
compatible with the ExtensionArray.
|
| 799 |
+
copy : bool, default False
|
| 800 |
+
If True, copy the underlying data.
|
| 801 |
+
|
| 802 |
+
Returns
|
| 803 |
+
-------
|
| 804 |
+
ExtensionArray
|
| 805 |
+
"""
|
| 806 |
+
if copy and isinstance(scalars, np.ndarray):
|
| 807 |
+
scalars = scalars.copy()
|
| 808 |
+
elif isinstance(scalars, TensorArray):
|
| 809 |
+
scalars = scalars._tensor.copy() if copy else scalars._tensor
|
| 810 |
+
return TensorArray(scalars)
|
| 811 |
+
|
| 812 |
+
@classmethod
|
| 813 |
+
def _from_factorized(
|
| 814 |
+
cls, values: np.ndarray, original: pd.api.extensions.ExtensionArray
|
| 815 |
+
):
|
| 816 |
+
"""
|
| 817 |
+
Reconstruct an ExtensionArray after factorization.
|
| 818 |
+
|
| 819 |
+
Parameters
|
| 820 |
+
----------
|
| 821 |
+
values : ndarray
|
| 822 |
+
An integer ndarray with the factorized values.
|
| 823 |
+
original : ExtensionArray
|
| 824 |
+
The original ExtensionArray that factorize was called on.
|
| 825 |
+
|
| 826 |
+
See Also
|
| 827 |
+
--------
|
| 828 |
+
factorize : Top-level factorize method that dispatches here.
|
| 829 |
+
ExtensionArray.factorize : Encode the extension array as an enumerated
|
| 830 |
+
type.
|
| 831 |
+
"""
|
| 832 |
+
raise NotImplementedError
|
| 833 |
+
|
| 834 |
+
def __getitem__(
|
| 835 |
+
self, item: Union[int, slice, np.ndarray]
|
| 836 |
+
) -> Union["TensorArray", "TensorArrayElement"]:
|
| 837 |
+
"""
|
| 838 |
+
Select a subset of self.
|
| 839 |
+
|
| 840 |
+
Parameters
|
| 841 |
+
----------
|
| 842 |
+
item : int, slice, or ndarray
|
| 843 |
+
* int: The position in 'self' to get.
|
| 844 |
+
* slice: A slice object, where 'start', 'stop', and 'step' are
|
| 845 |
+
integers or None
|
| 846 |
+
* ndarray: A 1-d boolean NumPy ndarray the same length as 'self'
|
| 847 |
+
|
| 848 |
+
Returns
|
| 849 |
+
-------
|
| 850 |
+
item : scalar or ExtensionArray
|
| 851 |
+
|
| 852 |
+
Notes
|
| 853 |
+
-----
|
| 854 |
+
For scalar ``item``, return a scalar value suitable for the array's
|
| 855 |
+
type. This should be an instance of ``self.dtype.type``.
|
| 856 |
+
For slice ``key``, return an instance of ``ExtensionArray``, even
|
| 857 |
+
if the slice is length 0 or 1.
|
| 858 |
+
For a boolean mask, return an instance of ``ExtensionArray``, filtered
|
| 859 |
+
to the values where ``item`` is True.
|
| 860 |
+
"""
|
| 861 |
+
# Return scalar if single value is selected, a TensorArrayElement for
|
| 862 |
+
# single array element, or TensorArray for slice.
|
| 863 |
+
if isinstance(item, int):
|
| 864 |
+
value = self._tensor[item]
|
| 865 |
+
if np.isscalar(value):
|
| 866 |
+
return value
|
| 867 |
+
else:
|
| 868 |
+
return TensorArrayElement(value)
|
| 869 |
+
else:
|
| 870 |
+
# BEGIN workaround for Pandas issue #42430
|
| 871 |
+
if isinstance(item, tuple) and len(item) > 1 and item[0] == Ellipsis:
|
| 872 |
+
if len(item) > 2:
|
| 873 |
+
# Hopefully this case is not possible, but can't be sure
|
| 874 |
+
raise ValueError(
|
| 875 |
+
"Workaround Pandas issue #42430 not "
|
| 876 |
+
"implemented for tuple length > 2"
|
| 877 |
+
)
|
| 878 |
+
item = item[1]
|
| 879 |
+
# END workaround for issue #42430
|
| 880 |
+
if isinstance(item, TensorArray):
|
| 881 |
+
item = np.asarray(item)
|
| 882 |
+
item = check_array_indexer(self, item)
|
| 883 |
+
return TensorArray(self._tensor[item])
|
| 884 |
+
|
| 885 |
+
def __len__(self) -> int:
|
| 886 |
+
"""
|
| 887 |
+
Length of this array.
|
| 888 |
+
|
| 889 |
+
Returns
|
| 890 |
+
-------
|
| 891 |
+
length : int
|
| 892 |
+
"""
|
| 893 |
+
return len(self._tensor)
|
| 894 |
+
|
| 895 |
+
@property
|
| 896 |
+
def dtype(self) -> pd.api.extensions.ExtensionDtype:
|
| 897 |
+
"""
|
| 898 |
+
An instance of 'ExtensionDtype'.
|
| 899 |
+
"""
|
| 900 |
+
if self.is_variable_shaped:
|
| 901 |
+
# A tensor is only considered variable-shaped if it's non-empty, so no
|
| 902 |
+
# non-empty check is needed here.
|
| 903 |
+
dtype = self._tensor[0].dtype
|
| 904 |
+
shape = (None,) * self._tensor[0].ndim
|
| 905 |
+
else:
|
| 906 |
+
dtype = self.numpy_dtype
|
| 907 |
+
shape = self.numpy_shape[1:]
|
| 908 |
+
return TensorDtype(shape, dtype)
|
| 909 |
+
|
| 910 |
+
@property
|
| 911 |
+
def is_variable_shaped(self):
|
| 912 |
+
"""
|
| 913 |
+
Whether this TensorArray holds variable-shaped tensor elements.
|
| 914 |
+
"""
|
| 915 |
+
if self._is_variable_shaped is None:
|
| 916 |
+
self._is_variable_shaped = _is_ndarray_variable_shaped_tensor(self._tensor)
|
| 917 |
+
return self._is_variable_shaped
|
| 918 |
+
|
| 919 |
+
@property
|
| 920 |
+
def nbytes(self) -> int:
|
| 921 |
+
"""
|
| 922 |
+
The number of bytes needed to store this object in memory.
|
| 923 |
+
"""
|
| 924 |
+
return self._tensor.nbytes
|
| 925 |
+
|
| 926 |
+
def isna(self) -> "TensorArray":
|
| 927 |
+
"""
|
| 928 |
+
A 1-D array indicating if each value is missing.
|
| 929 |
+
|
| 930 |
+
Returns
|
| 931 |
+
-------
|
| 932 |
+
na_values : Union[np.ndarray, ExtensionArray]
|
| 933 |
+
In most cases, this should return a NumPy ndarray. For
|
| 934 |
+
exceptional cases like ``SparseArray``, where returning
|
| 935 |
+
an ndarray would be expensive, an ExtensionArray may be
|
| 936 |
+
returned.
|
| 937 |
+
|
| 938 |
+
Notes
|
| 939 |
+
-----
|
| 940 |
+
If returning an ExtensionArray, then
|
| 941 |
+
|
| 942 |
+
* ``na_values._is_boolean`` should be True
|
| 943 |
+
* `na_values` should implement :func:`ExtensionArray._reduce`
|
| 944 |
+
* ``na_values.any`` and ``na_values.all`` should be implemented
|
| 945 |
+
"""
|
| 946 |
+
if self._tensor.dtype.type is np.object_:
|
| 947 |
+
# Avoid comparing with __eq__ because the elements of the tensor
|
| 948 |
+
# may do something funny with that operation.
|
| 949 |
+
return np.array(
|
| 950 |
+
[self._tensor[i] is None for i in range(len(self))], dtype=bool
|
| 951 |
+
)
|
| 952 |
+
elif self._tensor.dtype.type is np.str_:
|
| 953 |
+
return np.all(self._tensor == "", axis=tuple(range(1, self._tensor.ndim)))
|
| 954 |
+
else:
|
| 955 |
+
return np.all(
|
| 956 |
+
np.isnan(self._tensor), axis=tuple(range(1, self._tensor.ndim))
|
| 957 |
+
)
|
| 958 |
+
|
| 959 |
+
def take(
|
| 960 |
+
self, indices: Sequence[int], allow_fill: bool = False, fill_value: Any = None
|
| 961 |
+
) -> "TensorArray":
|
| 962 |
+
"""
|
| 963 |
+
Take elements from an array.
|
| 964 |
+
|
| 965 |
+
Parameters
|
| 966 |
+
----------
|
| 967 |
+
indices : sequence of int
|
| 968 |
+
Indices to be taken.
|
| 969 |
+
allow_fill : bool, default False
|
| 970 |
+
How to handle negative values in `indices`.
|
| 971 |
+
|
| 972 |
+
* False: negative values in `indices` indicate positional indices
|
| 973 |
+
from the right (the default). This is similar to
|
| 974 |
+
:func:`numpy.take`.
|
| 975 |
+
|
| 976 |
+
* True: negative values in `indices` indicate
|
| 977 |
+
missing values. These values are set to `fill_value`. Any other
|
| 978 |
+
other negative values raise a ``ValueError``.
|
| 979 |
+
|
| 980 |
+
fill_value : any, optional
|
| 981 |
+
Fill value to use for NA-indices when `allow_fill` is True.
|
| 982 |
+
This may be ``None``, in which case the default NA value for
|
| 983 |
+
the type, ``self.dtype.na_value``, is used.
|
| 984 |
+
|
| 985 |
+
For many ExtensionArrays, there will be two representations of
|
| 986 |
+
`fill_value`: a user-facing "boxed" scalar, and a low-level
|
| 987 |
+
physical NA value. `fill_value` should be the user-facing version,
|
| 988 |
+
and the implementation should handle translating that to the
|
| 989 |
+
physical version for processing the take if necessary.
|
| 990 |
+
|
| 991 |
+
Returns
|
| 992 |
+
-------
|
| 993 |
+
ExtensionArray
|
| 994 |
+
|
| 995 |
+
Raises
|
| 996 |
+
------
|
| 997 |
+
IndexError
|
| 998 |
+
When the indices are out of bounds for the array.
|
| 999 |
+
ValueError
|
| 1000 |
+
When `indices` contains negative values other than ``-1``
|
| 1001 |
+
and `allow_fill` is True.
|
| 1002 |
+
|
| 1003 |
+
See Also
|
| 1004 |
+
--------
|
| 1005 |
+
numpy.take : Take elements from an array along an axis.
|
| 1006 |
+
api.extensions.take : Take elements from an array.
|
| 1007 |
+
|
| 1008 |
+
Notes
|
| 1009 |
+
-----
|
| 1010 |
+
ExtensionArray.take is called by ``Series.__getitem__``, ``.loc``,
|
| 1011 |
+
``iloc``, when `indices` is a sequence of values. Additionally,
|
| 1012 |
+
it's called by :meth:`Series.reindex`, or any other method
|
| 1013 |
+
that causes realignment, with a `fill_value`.
|
| 1014 |
+
|
| 1015 |
+
Examples
|
| 1016 |
+
--------
|
| 1017 |
+
Here's an example implementation, which relies on casting the
|
| 1018 |
+
extension array to object dtype. This uses the helper method
|
| 1019 |
+
:func:`pandas.api.extensions.take`.
|
| 1020 |
+
|
| 1021 |
+
.. code-block:: python
|
| 1022 |
+
|
| 1023 |
+
def take(self, indices, allow_fill=False, fill_value=None):
|
| 1024 |
+
from pandas.core.algorithms import take
|
| 1025 |
+
|
| 1026 |
+
# If the ExtensionArray is backed by an ndarray, then
|
| 1027 |
+
# just pass that here instead of coercing to object.
|
| 1028 |
+
data = self.astype(object)
|
| 1029 |
+
|
| 1030 |
+
if allow_fill and fill_value is None:
|
| 1031 |
+
fill_value = self.dtype.na_value
|
| 1032 |
+
|
| 1033 |
+
# fill value should always be translated from the scalar
|
| 1034 |
+
# type for the array, to the physical storage type for
|
| 1035 |
+
# the data, before passing to take.
|
| 1036 |
+
|
| 1037 |
+
result = take(data, indices, fill_value=fill_value,
|
| 1038 |
+
allow_fill=allow_fill)
|
| 1039 |
+
return self._from_sequence(result, dtype=self.dtype)
|
| 1040 |
+
"""
|
| 1041 |
+
if allow_fill:
|
| 1042 |
+
# With allow_fill being True, negative values in `indices` indicate
|
| 1043 |
+
# missing values and should be set to `fill_value`.
|
| 1044 |
+
indices = np.asarray(indices, dtype=np.intp)
|
| 1045 |
+
validate_indices(indices, len(self._tensor))
|
| 1046 |
+
|
| 1047 |
+
# Check if there are missing indices to fill, otherwise we can
|
| 1048 |
+
# delegate to NumPy ndarray .take().
|
| 1049 |
+
has_missing = np.any(indices < 0)
|
| 1050 |
+
if has_missing:
|
| 1051 |
+
if fill_value is None:
|
| 1052 |
+
fill_value = np.nan
|
| 1053 |
+
|
| 1054 |
+
# Create an array populated with fill value.
|
| 1055 |
+
values = np.full((len(indices),) + self._tensor.shape[1:], fill_value)
|
| 1056 |
+
|
| 1057 |
+
# Put tensors at the given positive indices into array.
|
| 1058 |
+
is_nonneg = indices >= 0
|
| 1059 |
+
np.put(values, np.where(is_nonneg)[0], self._tensor[indices[is_nonneg]])
|
| 1060 |
+
|
| 1061 |
+
return TensorArray(values)
|
| 1062 |
+
|
| 1063 |
+
# Delegate take to NumPy array.
|
| 1064 |
+
values = self._tensor.take(indices, axis=0)
|
| 1065 |
+
|
| 1066 |
+
return TensorArray(values)
|
| 1067 |
+
|
| 1068 |
+
def copy(self) -> "TensorArray":
|
| 1069 |
+
"""
|
| 1070 |
+
Return a copy of the array.
|
| 1071 |
+
|
| 1072 |
+
Returns
|
| 1073 |
+
-------
|
| 1074 |
+
ExtensionArray
|
| 1075 |
+
"""
|
| 1076 |
+
# TODO(Clark): Copy cached properties.
|
| 1077 |
+
return TensorArray(self._tensor.copy())
|
| 1078 |
+
|
| 1079 |
+
@classmethod
|
| 1080 |
+
def _concat_same_type(cls, to_concat: Sequence["TensorArray"]) -> "TensorArray":
|
| 1081 |
+
"""
|
| 1082 |
+
Concatenate multiple array of this dtype.
|
| 1083 |
+
|
| 1084 |
+
Parameters
|
| 1085 |
+
----------
|
| 1086 |
+
to_concat : sequence of this type
|
| 1087 |
+
|
| 1088 |
+
Returns
|
| 1089 |
+
-------
|
| 1090 |
+
ExtensionArray
|
| 1091 |
+
"""
|
| 1092 |
+
should_flatten = False
|
| 1093 |
+
shape = None
|
| 1094 |
+
for a in to_concat:
|
| 1095 |
+
if shape is None:
|
| 1096 |
+
shape = a.dtype.element_shape
|
| 1097 |
+
if a.is_variable_shaped or a.dtype.element_shape != shape:
|
| 1098 |
+
should_flatten = True
|
| 1099 |
+
break
|
| 1100 |
+
if should_flatten:
|
| 1101 |
+
concated = TensorArray(
|
| 1102 |
+
np.array([e for a in to_concat for e in a._tensor], dtype=object)
|
| 1103 |
+
)
|
| 1104 |
+
else:
|
| 1105 |
+
concated = TensorArray(np.concatenate([a._tensor for a in to_concat]))
|
| 1106 |
+
return concated
|
| 1107 |
+
|
| 1108 |
+
def __setitem__(self, key: Union[int, np.ndarray], value: Any) -> None:
|
| 1109 |
+
"""
|
| 1110 |
+
Set one or more values inplace.
|
| 1111 |
+
|
| 1112 |
+
This method is not required to satisfy the pandas extension array
|
| 1113 |
+
interface.
|
| 1114 |
+
|
| 1115 |
+
Parameters
|
| 1116 |
+
----------
|
| 1117 |
+
key : int, ndarray, or slice
|
| 1118 |
+
When called from, e.g. ``Series.__setitem__``, ``key`` will be
|
| 1119 |
+
one of
|
| 1120 |
+
|
| 1121 |
+
* scalar int
|
| 1122 |
+
* ndarray of integers.
|
| 1123 |
+
* boolean ndarray
|
| 1124 |
+
* slice object
|
| 1125 |
+
|
| 1126 |
+
value : ExtensionDtype.type, Sequence[ExtensionDtype.type], or object
|
| 1127 |
+
value or values to be set of ``key``.
|
| 1128 |
+
|
| 1129 |
+
Returns
|
| 1130 |
+
-------
|
| 1131 |
+
None
|
| 1132 |
+
"""
|
| 1133 |
+
key = check_array_indexer(self, key)
|
| 1134 |
+
if isinstance(value, TensorArrayElement) or np.isscalar(value):
|
| 1135 |
+
value = np.asarray(value)
|
| 1136 |
+
if isinstance(value, list):
|
| 1137 |
+
value = [
|
| 1138 |
+
np.asarray(v) if isinstance(v, TensorArrayElement) else v for v in value
|
| 1139 |
+
]
|
| 1140 |
+
if isinstance(value, ABCSeries) and isinstance(value.dtype, TensorDtype):
|
| 1141 |
+
value = value.values
|
| 1142 |
+
if value is None or isinstance(value, Sequence) and len(value) == 0:
|
| 1143 |
+
self._tensor[key] = np.full_like(self._tensor[key], np.nan)
|
| 1144 |
+
elif isinstance(key, (int, slice, np.ndarray)):
|
| 1145 |
+
self._tensor[key] = value
|
| 1146 |
+
else:
|
| 1147 |
+
raise NotImplementedError(
|
| 1148 |
+
f"__setitem__ with key type '{type(key)}' not implemented"
|
| 1149 |
+
)
|
| 1150 |
+
|
| 1151 |
+
def __contains__(self, item) -> bool:
|
| 1152 |
+
"""
|
| 1153 |
+
Return for `item in self`.
|
| 1154 |
+
"""
|
| 1155 |
+
if isinstance(item, TensorArrayElement):
|
| 1156 |
+
np_item = np.asarray(item)
|
| 1157 |
+
if np_item.size == 1 and np.isnan(np_item).all():
|
| 1158 |
+
return self.isna().any()
|
| 1159 |
+
return super().__contains__(item)
|
| 1160 |
+
|
| 1161 |
+
def __repr__(self):
|
| 1162 |
+
return self._tensor.__repr__()
|
| 1163 |
+
|
| 1164 |
+
def __str__(self):
|
| 1165 |
+
return self._tensor.__str__()
|
| 1166 |
+
|
| 1167 |
+
    def _values_for_factorize(self) -> Tuple[np.ndarray, Any]:
        """Factorization is not supported for tensor extension arrays."""
        # TODO(Clark): return self._tensor, np.nan
        raise NotImplementedError
|
| 1170 |
+
|
| 1171 |
+
def _reduce(self, name: str, skipna: bool = True, **kwargs):
|
| 1172 |
+
"""
|
| 1173 |
+
Return a scalar result of performing the reduction operation.
|
| 1174 |
+
|
| 1175 |
+
Parameters
|
| 1176 |
+
----------
|
| 1177 |
+
name : str
|
| 1178 |
+
Name of the function, supported values are:
|
| 1179 |
+
{ any, all, min, max, sum, mean, median, prod,
|
| 1180 |
+
std, var, sem, kurt, skew }.
|
| 1181 |
+
skipna : bool, default True
|
| 1182 |
+
If True, skip NaN values.
|
| 1183 |
+
**kwargs
|
| 1184 |
+
Additional keyword arguments passed to the reduction function.
|
| 1185 |
+
Currently, `ddof` is the only supported kwarg.
|
| 1186 |
+
|
| 1187 |
+
Returns
|
| 1188 |
+
-------
|
| 1189 |
+
scalar
|
| 1190 |
+
|
| 1191 |
+
Raises
|
| 1192 |
+
------
|
| 1193 |
+
TypeError : subclass does not define reductions
|
| 1194 |
+
"""
|
| 1195 |
+
supported_kwargs = ["ddof"]
|
| 1196 |
+
reducer_kwargs = {}
|
| 1197 |
+
for kw in supported_kwargs:
|
| 1198 |
+
try:
|
| 1199 |
+
reducer_kwargs[kw] = kwargs[kw]
|
| 1200 |
+
except KeyError:
|
| 1201 |
+
pass
|
| 1202 |
+
try:
|
| 1203 |
+
return TensorArrayElement(
|
| 1204 |
+
self.SUPPORTED_REDUCERS[name](self._tensor, axis=0, **reducer_kwargs)
|
| 1205 |
+
)
|
| 1206 |
+
except KeyError:
|
| 1207 |
+
raise NotImplementedError(f"'{name}' aggregate not implemented.") from None
|
| 1208 |
+
|
| 1209 |
+
    def __array__(self, dtype: np.dtype = None, **kwargs) -> np.ndarray:
        """NumPy array protocol: expose the backing tensor as an ndarray."""
        return np.asarray(self._tensor, dtype=dtype, **kwargs)
|
| 1211 |
+
|
| 1212 |
+
    def __array_ufunc__(self, ufunc: Callable, method: str, *inputs, **kwargs):
        """
        Supports NumPy ufuncs without requiring sloppy coercion to an
        ndarray.
        """
        out = kwargs.get("out", ())
        # Only handle operands we know how to unwrap; anything else defers to
        # NumPy's normal dispatch.
        for x in inputs + out:
            if not isinstance(x, (TensorArray, np.ndarray, numbers.Number)):
                return NotImplemented

        # Defer to the implementation of the ufunc on unwrapped values.
        inputs = tuple(x._tensor if isinstance(x, TensorArray) else x for x in inputs)
        if out:
            kwargs["out"] = tuple(
                x._tensor if isinstance(x, TensorArray) else x for x in out
            )
        result = getattr(ufunc, method)(*inputs, **kwargs)

        if type(result) is tuple:
            # Multiple return values: rewrap each one.
            return tuple(type(self)(x) for x in result)
        elif method == "at":
            # In-place "at" method has no return value.
            return None
        else:
            # One return value: rewrap as a TensorArray (or subclass).
            return type(self)(result)
|
| 1239 |
+
|
| 1240 |
+
    def to_numpy(
        self,
        dtype: np.dtype = None,
        copy: bool = False,
        na_value: Any = pd.api.extensions.no_default,
    ):
        """
        Convert to a NumPy ndarray.

        .. versionadded:: 1.0.0

        This is similar to :meth:`numpy.asarray`, but may provide additional
        control over how the conversion is done.

        Parameters
        ----------
        dtype : str or numpy.dtype, optional
            The dtype to pass to :meth:`numpy.asarray`.
        copy : bool, default False
            Whether to ensure that the returned value is a not a view on
            another array. Note that ``copy=False`` does not *ensure* that
            ``to_numpy()`` is no-copy. Rather, ``copy=True`` ensure that
            a copy is made, even if not strictly necessary.
        na_value : Any, optional
            The value to use for missing values. The default value depends
            on `dtype` and the type of the array.

            .. note:: Accepted for interface compatibility but currently
                unused by this implementation.

        Returns
        -------
        numpy.ndarray
        """
        if dtype is not None:
            dtype = pd.api.types.pandas_dtype(dtype)
            if copy:
                # Force a fresh array in the requested dtype.
                values = np.array(self._tensor, dtype=dtype, copy=True)
            else:
                # ndarray.astype may still copy if the dtype differs.
                values = self._tensor.astype(dtype)
        elif copy:
            values = self._tensor.copy()
        else:
            # No dtype and no copy requested: hand back the backing array.
            values = self._tensor
        return values
|
| 1282 |
+
|
| 1283 |
+
@property
|
| 1284 |
+
def numpy_dtype(self):
|
| 1285 |
+
"""
|
| 1286 |
+
Get the dtype of the tensor.
|
| 1287 |
+
:return: The numpy dtype of the backing ndarray
|
| 1288 |
+
"""
|
| 1289 |
+
return self._tensor.dtype
|
| 1290 |
+
|
| 1291 |
+
@property
|
| 1292 |
+
def numpy_ndim(self):
|
| 1293 |
+
"""
|
| 1294 |
+
Get the number of tensor dimensions.
|
| 1295 |
+
:return: integer for the number of dimensions
|
| 1296 |
+
"""
|
| 1297 |
+
return self._tensor.ndim
|
| 1298 |
+
|
| 1299 |
+
@property
|
| 1300 |
+
def numpy_shape(self):
|
| 1301 |
+
"""
|
| 1302 |
+
Get the shape of the tensor.
|
| 1303 |
+
:return: A tuple of integers for the numpy shape of the backing ndarray
|
| 1304 |
+
"""
|
| 1305 |
+
return self._tensor.shape
|
| 1306 |
+
|
| 1307 |
+
@property
|
| 1308 |
+
def numpy_size(self):
|
| 1309 |
+
"""
|
| 1310 |
+
Get the size of the tensor.
|
| 1311 |
+
:return: integer for the number of elements in the tensor
|
| 1312 |
+
"""
|
| 1313 |
+
return self._tensor.size
|
| 1314 |
+
|
| 1315 |
+
def astype(self, dtype, copy=True):
|
| 1316 |
+
"""
|
| 1317 |
+
Cast to a NumPy array with 'dtype'.
|
| 1318 |
+
|
| 1319 |
+
Parameters
|
| 1320 |
+
----------
|
| 1321 |
+
dtype : str or dtype
|
| 1322 |
+
Typecode or data-type to which the array is cast.
|
| 1323 |
+
copy : bool, default True
|
| 1324 |
+
Whether to copy the data, even if not necessary. If False,
|
| 1325 |
+
a copy is made only if the old dtype does not match the
|
| 1326 |
+
new dtype.
|
| 1327 |
+
|
| 1328 |
+
Returns
|
| 1329 |
+
-------
|
| 1330 |
+
array : ndarray
|
| 1331 |
+
NumPy ndarray with 'dtype' for its dtype.
|
| 1332 |
+
"""
|
| 1333 |
+
dtype = pd.api.types.pandas_dtype(dtype)
|
| 1334 |
+
|
| 1335 |
+
if isinstance(dtype, TensorDtype):
|
| 1336 |
+
values = TensorArray(self._tensor.copy()) if copy else self
|
| 1337 |
+
elif not (
|
| 1338 |
+
pd.api.types.is_object_dtype(dtype) and pd.api.types.is_string_dtype(dtype)
|
| 1339 |
+
):
|
| 1340 |
+
values = np.array([str(t) for t in self._tensor])
|
| 1341 |
+
if isinstance(dtype, pd.StringDtype):
|
| 1342 |
+
return dtype.construct_array_type()._from_sequence(values, copy=False)
|
| 1343 |
+
else:
|
| 1344 |
+
return values
|
| 1345 |
+
elif pd.api.types.is_object_dtype(dtype):
|
| 1346 |
+
# Interpret astype(object) as "cast to an array of numpy arrays"
|
| 1347 |
+
values = np.empty(len(self), dtype=object)
|
| 1348 |
+
for i in range(len(self)):
|
| 1349 |
+
values[i] = self._tensor[i]
|
| 1350 |
+
else:
|
| 1351 |
+
values = self._tensor.astype(dtype, copy=copy)
|
| 1352 |
+
return values
|
| 1353 |
+
|
| 1354 |
+
def any(self, axis=None, out=None, keepdims=False):
|
| 1355 |
+
"""
|
| 1356 |
+
Test whether any array element along a given axis evaluates to True.
|
| 1357 |
+
|
| 1358 |
+
See numpy.any() documentation for more information
|
| 1359 |
+
https://numpy.org/doc/stable/reference/generated/numpy.any.html#numpy.any
|
| 1360 |
+
|
| 1361 |
+
:param axis: Axis or axes along which a logical OR reduction is
|
| 1362 |
+
performed.
|
| 1363 |
+
:param out: Alternate output array in which to place the result.
|
| 1364 |
+
:param keepdims: If this is set to True, the axes which are reduced are
|
| 1365 |
+
left in the result as dimensions with size one.
|
| 1366 |
+
:return: single boolean unless axis is not None else TensorArray
|
| 1367 |
+
"""
|
| 1368 |
+
result = self._tensor.any(axis=axis, out=out, keepdims=keepdims)
|
| 1369 |
+
return result if axis is None else TensorArray(result)
|
| 1370 |
+
|
| 1371 |
+
def all(self, axis=None, out=None, keepdims=False):
|
| 1372 |
+
"""
|
| 1373 |
+
Test whether all array elements along a given axis evaluate to True.
|
| 1374 |
+
|
| 1375 |
+
:param axis: Axis or axes along which a logical AND reduction is
|
| 1376 |
+
performed.
|
| 1377 |
+
:param out: Alternate output array in which to place the result.
|
| 1378 |
+
:param keepdims: If this is set to True, the axes which are reduced are
|
| 1379 |
+
left in the result as dimensions with size one.
|
| 1380 |
+
:return: single boolean unless axis is not None else TensorArray
|
| 1381 |
+
"""
|
| 1382 |
+
result = self._tensor.all(axis=axis, out=out, keepdims=keepdims)
|
| 1383 |
+
return result if axis is None else TensorArray(result)
|
| 1384 |
+
|
| 1385 |
+
def __arrow_array__(self, type=None):
|
| 1386 |
+
"""
|
| 1387 |
+
Convert this TensorArray to an ArrowTensorArray extension array.
|
| 1388 |
+
|
| 1389 |
+
This and TensorDtype.__from_arrow__ make up the
|
| 1390 |
+
Pandas extension type + array <--> Arrow extension type + array
|
| 1391 |
+
interoperability protocol. See
|
| 1392 |
+
https://pandas.pydata.org/pandas-docs/stable/development/extending.html#compatibility-with-apache-arrow
|
| 1393 |
+
for more information.
|
| 1394 |
+
"""
|
| 1395 |
+
from ray.air.util.tensor_extensions.arrow import (
|
| 1396 |
+
ArrowTensorArray,
|
| 1397 |
+
ArrowVariableShapedTensorArray,
|
| 1398 |
+
)
|
| 1399 |
+
|
| 1400 |
+
if self.is_variable_shaped:
|
| 1401 |
+
return ArrowVariableShapedTensorArray.from_numpy(self._tensor)
|
| 1402 |
+
else:
|
| 1403 |
+
return ArrowTensorArray.from_numpy(self._tensor)
|
| 1404 |
+
|
| 1405 |
+
    @property
    def _is_boolean(self):
        """
        Whether this extension array should be considered boolean.

        By default, ExtensionArrays are assumed to be non-numeric.
        Setting this to True will affect the behavior of several places,
        e.g.

        * is_bool
        * boolean indexing

        Returns
        -------
        bool
        """
        # This is needed to support returning a TensorArray from .isnan().
        # Delegates the decision to the dtype.
        return self.dtype._is_boolean()
|
| 1423 |
+
|
| 1424 |
+
|
| 1425 |
+
# Add operators from the mixin to the TensorArrayElement and TensorArray
# classes. These class methods install the arithmetic, comparison, and
# logical dunder operators at import time, so both classes support the
# full operator surface without hand-written definitions.
TensorArrayElement._add_arithmetic_ops()
TensorArrayElement._add_comparison_ops()
TensorArrayElement._add_logical_ops()
TensorArray._add_arithmetic_ops()
TensorArray._add_comparison_ops()
TensorArray._add_logical_ops()
|
| 1433 |
+
|
| 1434 |
+
|
| 1435 |
+
@PublicAPI(stability="beta")
def column_needs_tensor_extension(s: pd.Series) -> bool:
    """Return whether the provided pandas Series column needs a tensor extension
    representation. This tensor extension representation provides more efficient slicing
    and interop with ML frameworks.

    Args:
        s: The pandas Series column that may need to be represented using the tensor
            extension.

    Returns:
        Whether the provided Series needs a tensor extension representation.
    """
    # NOTE: This is an O(1) check.
    if s.dtype.type is not np.object_:
        return False
    if s.empty:
        return False
    # Object-dtyped, non-empty: tensor columns hold ndarrays as elements.
    return isinstance(s.iloc[0], np.ndarray)
|
.venv/lib/python3.11/site-packages/ray/air/util/torch_dist.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This file is modeled after ray/python/ray/train/torch/config.py
|
| 2 |
+
|
| 3 |
+
The logics are duplicated right now to allow maximum flexibility for
|
| 4 |
+
setting up PyTorch DDP process groups outside the context of Ray Train.
|
| 5 |
+
Eventually, these use cases should be consolidated.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
from abc import ABC
|
| 10 |
+
from collections import defaultdict
|
| 11 |
+
from datetime import timedelta
|
| 12 |
+
from typing import Callable, List, T
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.distributed as dist
|
| 16 |
+
|
| 17 |
+
import ray
|
| 18 |
+
from ray.actor import ActorHandle
|
| 19 |
+
from ray.air._internal.torch_utils import get_devices
|
| 20 |
+
from ray.train._internal.utils import get_address_and_port
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class TorchDistributedWorker(ABC):
    """Interface required by ``init_torch_dist_process_group()``.

    Modeled after RayTrainerWorker: it allows arbitrary functions to be
    executed on a remote DDP worker actor.
    """

    def execute(self, func: Callable[..., T], *args, **kwargs) -> T:
        """Run ``func`` on this worker and return its result.

        Args:
            func: The function to execute.
            args, kwargs: Arguments forwarded to ``func``.
        """
        return func(*args, **kwargs)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _init_torch_distributed(
    init_method: str,
    backend: str,
    rank: int,
    world_size: int,
    local_rank: int,
    local_world_size: int,
    master_addr: str,
    master_port: str,
    gpu_ids: List[int],
    **init_process_group_kwargs,
):
    """Initialize torch distributed backend"""
    # Build the rendezvous URL; "env" reads the address from the two
    # MASTER_* env vars set just below, "tcp" embeds it in the URL.
    if init_method == "env":
        os.environ["MASTER_ADDR"] = str(master_addr)
        os.environ["MASTER_PORT"] = str(master_port)
        url = "env://"
    elif init_method == "tcp":
        url = f"tcp://{master_addr}:{master_port}"
    else:
        raise ValueError(
            f"The provided init_method ("
            f"{init_method}) is not supported. Must "
            f"be either 'env' or 'tcp'."
        )

    if backend == "nccl":
        # Same as in Ray Train
        os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
        # All workers on a same node should share the same set of
        # visible GPUs. Otherwise they can't talk among themselves.
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(gid) for gid in gpu_ids)

    # Explicit keyword arguments take precedence over caller-supplied kwargs.
    init_process_group_kwargs.update(
        dict(
            backend=backend,
            init_method=url,
            rank=rank,
            world_size=world_size,
        )
    )
    # Default rendezvous timeout of 30 minutes unless the caller overrides it.
    init_process_group_kwargs.setdefault("timeout", timedelta(seconds=1800))

    dist.init_process_group(**init_process_group_kwargs)

    # Export the standard torch.distributed env vars so code launched later
    # in this process can discover its rank/world layout.
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(local_rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["LOCAL_WORLD_SIZE"] = str(local_world_size)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _get_node_and_gpu_ids():
    """Return this worker's Ray node id and its visible GPU ids."""
    runtime_ctx = ray.get_runtime_context()
    return runtime_ctx.get_node_id(), ray.get_gpu_ids()
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def init_torch_dist_process_group(
    workers: List[ActorHandle],
    backend: str = "gloo",
    init_method: str = "env",
    **init_process_group_kwargs,
) -> List[int]:
    """Initialize a torch distributed process group.

    Note: this util assumes that the order of the workers passed in
    are their global ranks.

    Args:
        workers: A list of TorchDistributedWorker actors.
        backend: The torch distributed backend to use,
            possible choices are "gloo" or "nccl".
        init_method: The initialization method to use,
            possible choices are "env" or "tcp".
        init_process_group_kwargs: Additional kwargs to pass to the call to
            :meth:`torch.distributed.init_process_group`.

    Returns:
        Local ranks on their respective nodes for the list of workers.
    """
    if not dist.is_available():
        raise RuntimeError("Distributed torch is not available.")

    # Build a map from node_id to workers on that node.
    node_and_gpu_ids = ray.get(
        [w.execute.remote(_get_node_and_gpu_ids) for w in workers]
    )
    # All the workers on a specific node.
    node_to_workers = defaultdict(list)
    # All the gpu ids visible to all the workers on a specific node.
    node_to_gpu_ids = defaultdict(set)
    for i, (node_id, gpu_ids) in enumerate(node_and_gpu_ids):
        node_to_workers[node_id].append(i)
        # Force list.
        if not isinstance(gpu_ids, list):
            gpu_ids = [gpu_ids]
        # It is possible for a worker to have access to multiple GPUs.
        for gpu_id in gpu_ids:
            node_to_gpu_ids[node_id].add(gpu_id)

    # Assume the first worker is the master.
    master_addr, master_port = ray.get(workers[0].execute.remote(get_address_and_port))

    setup_futures = []
    world_size = len(workers)
    local_ranks = []
    for rank, worker in enumerate(workers):
        node_id = node_and_gpu_ids[rank][0]
        # Local rank = position of this worker among the workers on its node
        # (worker indices were appended in global-rank order above).
        local_rank = node_to_workers[node_id].index(rank)
        local_world_size = len(node_to_workers[node_id])
        setup_futures.append(
            worker.execute.remote(
                _init_torch_distributed,
                init_method=init_method,
                backend=backend,
                rank=rank,
                world_size=world_size,
                local_rank=local_rank,
                local_world_size=local_world_size,
                master_addr=master_addr,
                master_port=master_port,
                # list(set) will sort the gpu ids, so VISIBLE_CUDA_DEVICES
                # is always sorted.
                gpu_ids=list(node_to_gpu_ids[node_id]),
                **init_process_group_kwargs,
            )
        )
        local_ranks.append(local_rank)

    # Wait for all workers to join the process group.
    ray.get(setup_futures)

    return local_ranks
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _shutdown_torch_distributed():
    """Tear down the torch process group and release cached GPU memory."""
    dist.destroy_process_group()

    if not torch.cuda.is_available():
        return

    # Clean up cuda memory on every device this worker was assigned.
    for dev in get_devices():
        with torch.cuda.device(dev):
            torch.cuda.empty_cache()
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def shutdown_torch_dist_process_group(workers: List[ActorHandle]):
    """Shut down torch distributed on every worker, blocking until done."""
    shutdown_refs = [w.execute.remote(_shutdown_torch_distributed) for w in workers]
    ray.get(shutdown_refs)
|
.venv/lib/python3.11/site-packages/ray/air/util/transform_pyarrow.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
try:
|
| 2 |
+
import pyarrow
|
| 3 |
+
except ImportError:
|
| 4 |
+
pyarrow = None
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def _is_column_extension_type(ca: "pyarrow.ChunkedArray") -> bool:
    """Whether the provided Arrow Table column is an extension array, using an Arrow
    extension type.
    """
    # The chunked array's type is shared by all chunks, so a single check
    # covers the whole column.
    return isinstance(ca.type, pyarrow.ExtensionType)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _concatenate_extension_column(ca: "pyarrow.ChunkedArray") -> "pyarrow.Array":
    """Concatenate chunks of an extension column into a contiguous array.

    This concatenation is required for creating copies and for .take() to work on
    extension arrays.
    See https://issues.apache.org/jira/browse/ARROW-16503.

    Args:
        ca: A chunked extension-typed column.

    Returns:
        A single contiguous pyarrow extension Array.

    Raises:
        ValueError: If ``ca`` is not extension-typed.
    """
    from ray.air.util.tensor_extensions.arrow import (
        ArrowTensorArray,
        get_arrow_extension_tensor_types,
    )

    if not _is_column_extension_type(ca):
        # Bug fix: the message previously lacked the f-prefix, so "{ca}" was
        # printed literally instead of interpolating the offending array.
        raise ValueError(f"Chunked array isn't an extension array: {ca}")

    tensor_extension_types = get_arrow_extension_tensor_types()

    if ca.num_chunks == 0:
        # Create empty storage array.
        storage = pyarrow.array([], type=ca.type.storage_type)
    elif isinstance(ca.type, tensor_extension_types):
        # Tensor extensions have their own chunk-concatenation logic.
        return ArrowTensorArray._concat_same_type(ca.chunks)
    else:
        storage = pyarrow.concat_arrays([c.storage for c in ca.chunks])

    # Rewrap the concatenated storage in the original extension type.
    return ca.type.__arrow_ext_class__().from_storage(ca.type, storage)
|
.venv/lib/python3.11/site-packages/ray/serve/_private/__pycache__/deployment_state.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fed1e2f11cb7a8c80b117de5ff1af60d479c6ccb4594d860948a52f971a6ec45
|
| 3 |
+
size 125314
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (202 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/common.cpython-311.pyc
ADDED
|
Binary file (17.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/handle_noop_latency.cpython-311.pyc
ADDED
|
Binary file (2.04 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/handle_throughput.cpython-311.pyc
ADDED
|
Binary file (2.47 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/http_noop_latency.cpython-311.pyc
ADDED
|
Binary file (2.08 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/microbenchmark.cpython-311.pyc
ADDED
|
Binary file (11.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/proxy_benchmark.cpython-311.pyc
ADDED
|
Binary file (18.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/common.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import inspect
|
| 3 |
+
import logging
|
| 4 |
+
import random
|
| 5 |
+
import string
|
| 6 |
+
import time
|
| 7 |
+
from functools import partial
|
| 8 |
+
from typing import Any, Callable, Coroutine, List, Optional, Tuple
|
| 9 |
+
|
| 10 |
+
import aiohttp
|
| 11 |
+
import aiohttp.client_exceptions
|
| 12 |
+
import grpc
|
| 13 |
+
import numpy as np
|
| 14 |
+
import pandas as pd
|
| 15 |
+
from starlette.responses import StreamingResponse
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
|
| 18 |
+
from ray import serve
|
| 19 |
+
from ray.serve.generated import serve_pb2, serve_pb2_grpc
|
| 20 |
+
from ray.serve.handle import DeploymentHandle
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
async def run_latency_benchmark(
    f: Callable, num_requests: int, *, num_warmup_requests: int = 100
) -> pd.Series:
    """Call ``f`` repeatedly and return per-call latencies in milliseconds.

    The first ``num_warmup_requests`` calls are executed but excluded from
    the returned series.
    """
    if inspect.iscoroutinefunction(f):
        to_call = f
    else:

        async def to_call():
            f()

    latencies = []
    total_calls = num_requests + num_warmup_requests
    for i in tqdm(range(total_calls)):
        start = time.perf_counter()
        await to_call()
        elapsed_ms = 1000 * (time.perf_counter() - start)

        # Don't include warm-up requests.
        if i >= num_warmup_requests:
            latencies.append(elapsed_ms)

    return pd.Series(latencies)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
async def run_throughput_benchmark(
    fn: Callable[[], List[float]],
    multiplier: int = 1,
    num_trials: int = 10,
    trial_runtime: float = 1,
) -> Tuple[float, float, pd.Series]:
    """Benchmarks throughput of a function.

    Args:
        fn: The function to benchmark. If this returns anything, it must
            return a list of latencies.
        multiplier: The number of requests or tokens (or whatever unit
            is appropriate for this throughput benchmark) that is
            completed in one call to `fn`.
        num_trials: The number of trials to run.
        trial_runtime: How long each trial should run for. During the
            duration of one trial, `fn` will be repeatedly called.

    Returns (mean, stddev, latencies).
    """
    # Warmup: exercise fn for ~100ms before taking measurements.
    warmup_start = time.time()
    while time.time() - warmup_start < 0.1:
        await fn()

    # Benchmark: record per-trial throughput and any reported latencies.
    trial_throughputs = []
    latencies = []
    for _ in tqdm(range(num_trials)):
        trial_start = time.perf_counter()
        num_calls = 0
        while time.perf_counter() - trial_start < trial_runtime:
            reported = await fn()
            if reported:
                latencies.extend(reported)

            num_calls += 1
        trial_end = time.perf_counter()
        trial_throughputs.append(multiplier * num_calls / (trial_end - trial_start))

    return (
        round(np.mean(trial_throughputs), 2),
        round(np.std(trial_throughputs), 2),
        pd.Series(latencies),
    )
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
async def do_single_http_batch(
    *,
    batch_size: int = 100,
    url: str = "http://localhost:8000",
    stream: bool = False,
) -> List[float]:
    """Sends a batch of http requests and returns e2e latencies."""

    # By default, aiohttp limits the number of client connections to 100.
    # We need to use TCPConnector to configure the limit if batch size
    # is greater than 100.
    connector = aiohttp.TCPConnector(limit=batch_size)
    async with aiohttp.ClientSession(
        connector=connector, raise_for_status=True
    ) as session:

        async def do_query():
            start = time.perf_counter()
            try:
                if stream:
                    # Drain the streamed body so the measured latency covers
                    # the full response, not just the headers.
                    async with session.get(url) as r:
                        async for chunk, _ in r.content.iter_chunks():
                            pass
                else:
                    await session.get(url)
            except aiohttp.client_exceptions.ClientConnectionError:
                # Deliberately swallowed: a dropped connection still yields a
                # latency sample instead of aborting the whole batch.
                pass

            end = time.perf_counter()
            return 1000 * (end - start)

        # Fire all requests concurrently and collect per-request latencies.
        return await asyncio.gather(*[do_query() for _ in range(batch_size)])
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
async def do_single_grpc_batch(
    *, batch_size: int = 100, target: str = "localhost:9000"
):
    """Sends a batch of gRPC requests and returns e2e latencies in ms.

    Args:
        batch_size: Number of concurrent requests to send.
        target: host:port of the gRPC server.
    """
    # Fix: use the channel as an async context manager so it is closed when
    # the batch completes (the previous version leaked the channel).
    async with grpc.aio.insecure_channel(target) as channel:
        stub = serve_pb2_grpc.RayServeBenchmarkServiceStub(channel)
        payload = serve_pb2.StringData(data="")

        async def do_query():
            start = time.perf_counter()

            await stub.grpc_call(payload)

            end = time.perf_counter()
            return 1000 * (end - start)

        # Fire all requests concurrently and collect per-request latencies.
        return await asyncio.gather(*[do_query() for _ in range(batch_size)])
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
async def collect_profile_events(coro: Coroutine):
    """Collects profiling events using Viztracer.

    Runs ``coro`` under a VizTracer session and saves the trace.
    """

    from viztracer import VizTracer

    tracer = VizTracer()
    tracer.start()

    # Fix: stop and save the tracer even if the coroutine raises, so a
    # failed benchmark still leaves a usable (partial) trace behind.
    try:
        await coro
    finally:
        tracer.stop()
        tracer.save()
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def generate_payload(size: int = 100, chars=string.ascii_uppercase + string.digits):
    """Return a random string of ``size`` characters drawn from ``chars``."""
    picks = (random.choice(chars) for _ in range(size))
    return "".join(picks)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class Blackhole:
    """Accepts values and discards them."""

    def sink(self, o):
        # Intentionally a no-op.
        pass
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
@serve.deployment
class Noop:
    """Deployment that answers every request with an empty payload."""

    def __init__(self):
        # Quiet Serve's per-request logging during benchmarking.
        logging.getLogger("ray.serve").setLevel(logging.WARNING)

    def __call__(self, *args, **kwargs):
        return b""
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
@serve.deployment
class Streamer:
    """Deployment that streams a fixed number of tokens per request,
    sleeping a fixed delay before each token."""

    def __init__(self, tokens_per_request: int, inter_token_delay_ms: int = 10):
        # Quiet Serve's per-request logging during benchmarking.
        logging.getLogger("ray.serve").setLevel(logging.WARNING)
        self._num_tokens = tokens_per_request
        self._token_delay_s = inter_token_delay_ms / 1000

    async def stream(self):
        for _token_idx in range(self._num_tokens):
            await asyncio.sleep(self._token_delay_s)
            yield b"hi"

    async def __call__(self):
        return StreamingResponse(self.stream())
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
@serve.deployment
class IntermediateRouter:
    """Pass-through deployment that proxies a downstream streaming handle."""

    def __init__(self, handle: DeploymentHandle):
        # Silence per-request access logs so they don't skew benchmark numbers.
        logging.getLogger("ray.serve").setLevel(logging.WARNING)
        # The downstream method is a generator, so the handle must stream.
        self._handle = handle.options(stream=True)

    async def stream(self):
        """Re-yield every token produced by the downstream deployment."""
        downstream = self._handle.stream.remote()
        async for token in downstream:
            yield token

    def __call__(self):
        return StreamingResponse(self.stream())
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
@serve.deployment
class Benchmarker:
    """Driver deployment measuring latency/throughput of a downstream handle."""

    def __init__(
        self,
        handle: DeploymentHandle,
        stream: bool = False,
    ):
        # Silence per-request access logs so they don't skew benchmark numbers.
        logging.getLogger("ray.serve").setLevel(logging.WARNING)
        self._handle = handle.options(stream=stream)
        self._stream = stream

    async def do_single_request(self, payload: Any = None) -> float:
        """Completes a single unary request. Returns e2e latency in ms."""
        start = time.perf_counter()
        if payload is None:
            await self._handle.remote()
        else:
            await self._handle.remote(payload)
        return 1000 * (time.perf_counter() - start)

    async def _do_single_stream(self) -> float:
        """Consumes a single streaming request. Returns e2e latency in ms."""
        start = time.perf_counter()
        # Drain the stream fully; tokens themselves are discarded.
        async for _ in self._handle.stream.remote():
            pass
        return 1000 * (time.perf_counter() - start)

    async def _do_single_batch(self, batch_size: int) -> List[float]:
        # Fan out `batch_size` concurrent requests of the configured kind.
        if self._stream:
            make_request = self._do_single_stream
        else:
            make_request = self.do_single_request
        return await asyncio.gather(*(make_request() for _ in range(batch_size)))

    async def run_latency_benchmark(
        self, *, num_requests: int, payload: Any = None
    ) -> pd.Series:
        """Run `num_requests` sequential unary requests; returns latencies (ms)."""

        async def issue_one():
            await self.do_single_request(payload)

        return await run_latency_benchmark(issue_one, num_requests=num_requests)

    async def run_throughput_benchmark(
        self,
        *,
        batch_size: int,
        num_trials: int,
        trial_runtime: float,
        tokens_per_request: Optional[float] = None,
    ) -> Tuple[float, float]:
        """Measure throughput; returns (mean, stddev) per second."""
        if self._stream:
            # For streaming, throughput is counted in tokens, not requests.
            assert tokens_per_request
            multiplier = tokens_per_request * batch_size
        else:
            multiplier = batch_size

        return await run_throughput_benchmark(
            fn=partial(
                self._do_single_batch,
                batch_size=batch_size,
            ),
            multiplier=multiplier,
            num_trials=num_trials,
            trial_runtime=trial_runtime,
        )
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/handle_noop_latency.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time

import click
import pandas as pd

from ray import serve
from ray.serve._private.benchmarks.common import Benchmarker, Noop
from ray.serve.handle import DeploymentHandle


@click.command(help="Benchmark no-op DeploymentHandle latency.")
@click.option("--num-replicas", type=int, default=1)
@click.option("--num-requests", type=int, default=100)
def main(num_replicas: int, num_requests: int):
    """Deploy a no-op app behind a Benchmarker and report handle latencies."""
    h: DeploymentHandle = serve.run(
        Benchmarker.bind(Noop.options(num_replicas=num_replicas).bind())
    )

    # `Benchmarker.run_latency_benchmark` declares keyword-only parameters;
    # passing `num_requests` positionally raised a TypeError in the replica.
    latencies: pd.Series = h.run_latency_benchmark.remote(
        num_requests=num_requests,
    ).result()

    # Let the logs flush to avoid interwoven output.
    time.sleep(1)

    print(
        "Latency (ms) for noop DeploymentHandle requests "
        f"(num_replicas={num_replicas},num_requests={num_requests}):"
    )
    print(latencies.describe(percentiles=[0.5, 0.9, 0.95, 0.99]))


if __name__ == "__main__":
    main()
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/handle_throughput.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import click

from ray import serve
from ray.serve._private.benchmarks.common import Benchmarker, Hello
from ray.serve.handle import DeploymentHandle


@click.command(help="Benchmark deployment handle throughput.")
@click.option(
    "--batch-size",
    type=int,
    default=100,
    help="Number of requests to send to downstream deployment in each trial.",
)
@click.option(
    "--num-replicas",
    type=int,
    default=1,
    help="Number of replicas in the downstream deployment.",
)
@click.option(
    "--num-trials",
    type=int,
    default=5,
    help="Number of trials of the benchmark to run.",
)
@click.option(
    "--trial-runtime",
    type=int,
    default=1,
    help="Duration to run each trial of the benchmark for (seconds).",
)
def main(
    batch_size: int,
    num_replicas: int,
    num_trials: int,
    trial_runtime: float,
):
    """Deploy a Hello app behind a Benchmarker and report handle throughput."""
    downstream = Hello.options(
        num_replicas=num_replicas, ray_actor_options={"num_cpus": 0}
    ).bind()
    h: DeploymentHandle = serve.run(Benchmarker.bind(downstream))

    mean, stddev = h.run_throughput_benchmark.remote(
        batch_size=batch_size,
        num_trials=num_trials,
        trial_runtime=trial_runtime,
    ).result()

    print(
        "DeploymentHandle throughput {}: {} +- {} requests/s".format(
            f"(num_replicas={num_replicas}, batch_size={batch_size})",
            mean,
            stddev,
        )
    )


if __name__ == "__main__":
    main()
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/http_noop_latency.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio

import click
import pandas as pd
import requests

from ray import serve
from ray.serve._private.benchmarks.common import Noop, run_latency_benchmark


@click.command(help="Benchmark no-op HTTP latency.")
@click.option("--num-replicas", type=int, default=1)
@click.option("--num-requests", type=int, default=100)
def main(num_replicas: int, num_requests: int):
    """Deploy a no-op app and measure HTTP request latencies against it."""
    serve.run(Noop.options(num_replicas=num_replicas).bind())

    def issue_request():
        return requests.get("http://localhost:8000")

    loop = asyncio.new_event_loop()
    latencies: pd.Series = loop.run_until_complete(
        run_latency_benchmark(issue_request, num_requests=num_requests)
    )

    print(
        "Latency (ms) for noop HTTP requests "
        f"(num_replicas={num_replicas},num_requests={num_requests}):"
    )
    print(latencies.describe(percentiles=[0.5, 0.9, 0.95, 0.99]))


if __name__ == "__main__":
    main()
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/microbenchmark.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Runs several scenarios with varying max batch size, max concurrent queries,
|
| 2 |
+
# number of replicas, and with intermediate serve handles (to simulate ensemble
|
| 3 |
+
# models) either on or off.
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
import logging
|
| 7 |
+
from pprint import pprint
|
| 8 |
+
from typing import Dict, Union
|
| 9 |
+
|
| 10 |
+
import aiohttp
|
| 11 |
+
from starlette.requests import Request
|
| 12 |
+
|
| 13 |
+
import ray
|
| 14 |
+
from ray import serve
|
| 15 |
+
from ray.serve._private.benchmarks.common import run_throughput_benchmark
|
| 16 |
+
from ray.serve.handle import DeploymentHandle
|
| 17 |
+
|
| 18 |
+
NUM_CLIENTS = 8
|
| 19 |
+
CALLS_PER_BATCH = 100
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
async def fetch(session, data):
    """Issue one HTTP GET with `data` and assert the app answered "ok"."""
    async with session.get("http://localhost:8000/", data=data) as response:
        body = await response.text()
        assert body == "ok", body
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@ray.remote
class Client:
    """Ray actor that drives HTTP load from its own process."""

    def ready(self):
        # Used as a barrier to wait for actor startup.
        return "ok"

    async def do_queries(self, num, data):
        """Issue `num` sequential requests through a fresh client session."""
        async with aiohttp.ClientSession() as session:
            for _ in range(num):
                await fetch(session, data)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def build_app(
    intermediate_handles: bool,
    num_replicas: int,
    max_batch_size: int,
    max_ongoing_requests: int,
):
    """Construct the benchmark app, optionally with an intermediate handle hop."""

    @serve.deployment(max_ongoing_requests=1000)
    class Upstream:
        """Forwards the raw request body to the downstream deployment."""

        def __init__(self, handle: DeploymentHandle):
            self._handle = handle
            # Turn off access log.
            logging.getLogger("ray.serve").setLevel(logging.WARNING)

        async def __call__(self, req: Request):
            return await self._handle.remote(await req.body())

    @serve.deployment(
        num_replicas=num_replicas,
        max_ongoing_requests=max_ongoing_requests,
    )
    class Downstream:
        """Terminal deployment; batches before replying when batching is on."""

        def __init__(self):
            # Turn off access log.
            logging.getLogger("ray.serve").setLevel(logging.WARNING)

        @serve.batch(max_batch_size=max_batch_size)
        async def batch(self, reqs):
            return [b"ok"] * len(reqs)

        async def __call__(self, req: Union[bytes, Request]):
            # Only take the batching path when batching is actually enabled.
            if max_batch_size <= 1:
                return b"ok"
            return await self.batch(req)

    if intermediate_handles:
        return Upstream.bind(Downstream.bind())
    return Downstream.bind()
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
async def trial(
    intermediate_handles: bool,
    num_replicas: int,
    max_batch_size: int,
    max_ongoing_requests: int,
    data_size: str,
) -> Dict[str, float]:
    """Run one benchmark configuration; returns {condition key: avg TPS}."""
    results = {}

    trial_key_base = (
        f"replica:{num_replicas}/batch_size:{max_batch_size}/"
        f"concurrent_queries:{max_ongoing_requests}/"
        f"data_size:{data_size}/intermediate_handle:{intermediate_handles}"
    )

    print(
        f"intermediate_handles={intermediate_handles},"
        f"num_replicas={num_replicas},"
        f"max_batch_size={max_batch_size},"
        f"max_ongoing_requests={max_ongoing_requests},"
        f"data_size={data_size}"
    )

    serve.run(
        build_app(
            intermediate_handles, num_replicas, max_batch_size, max_ongoing_requests
        )
    )

    if data_size == "small":
        data = None
    elif data_size == "large":
        data = b"a" * 1024 * 1024
    else:
        raise ValueError("data_size should be 'small' or 'large'.")

    async with aiohttp.ClientSession() as session:

        async def single_client():
            for _ in range(CALLS_PER_BATCH):
                await fetch(session, data)

        single_client_avg_tps, single_client_std_tps = await run_throughput_benchmark(
            single_client,
            multiplier=CALLS_PER_BATCH,
        )
        print(
            "\t{} {} +- {} requests/s".format(
                "single client {} data".format(data_size),
                single_client_avg_tps,
                single_client_std_tps,
            )
        )
        results[f"num_client:1/{trial_key_base}"] = single_client_avg_tps

        clients = [Client.remote() for _ in range(NUM_CLIENTS)]
        # Block until every client actor has started.
        ray.get([client.ready.remote() for client in clients])

        async def many_clients():
            ray.get([a.do_queries.remote(CALLS_PER_BATCH, data) for a in clients])

        multi_client_avg_tps, _ = await run_throughput_benchmark(
            many_clients,
            multiplier=CALLS_PER_BATCH * len(clients),
        )
        results[f"num_client:{len(clients)}/{trial_key_base}"] = multi_client_avg_tps
        return results
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
async def main():
    """Sweep all benchmark conditions and print a combined results dict."""
    results = {}
    for intermediate_handles in [False, True]:
        for num_replicas in [1, 8]:
            for max_batch_size, max_ongoing_requests in [
                (1, 1),
                (1, 10000),
                (10000, 10000),
            ]:
                # TODO(edoakes): large data causes broken pipe errors.
                for data_size in ["small"]:
                    trial_results = await trial(
                        intermediate_handles,
                        num_replicas,
                        max_batch_size,
                        max_ongoing_requests,
                        data_size,
                    )
                    results.update(trial_results)

    print("Results from all conditions:")
    pprint(results)
    return results


if __name__ == "__main__":
    ray.init()
    serve.start()
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(main())
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/proxy_benchmark.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Runs some request ping to compare HTTP and gRPC performances in TPS and latency.
|
| 2 |
+
# Note: this takes around 1 hour to run.
|
| 3 |
+
|
| 4 |
+
import asyncio
|
| 5 |
+
import json
|
| 6 |
+
import logging
|
| 7 |
+
import time
|
| 8 |
+
from random import random
|
| 9 |
+
from typing import Callable, Dict
|
| 10 |
+
|
| 11 |
+
import aiohttp
|
| 12 |
+
import numpy as np
|
| 13 |
+
import pandas as pd
|
| 14 |
+
from grpc import aio
|
| 15 |
+
from starlette.requests import Request
|
| 16 |
+
|
| 17 |
+
import ray
|
| 18 |
+
from ray import serve
|
| 19 |
+
from ray.serve._private.common import RequestProtocol
|
| 20 |
+
from ray.serve.config import gRPCOptions
|
| 21 |
+
from ray.serve.generated import serve_pb2, serve_pb2_grpc
|
| 22 |
+
from ray.serve.handle import DeploymentHandle
|
| 23 |
+
|
| 24 |
+
CALLS_PER_BATCH = 100
|
| 25 |
+
DELTA = 10**-7
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
async def get_query_tps(name: str, fn: Callable, multiplier: int = CALLS_PER_BATCH):
    """Get query TPS.

    Run the function for 0.5 seconds 10 times to calculate how many requests can
    be completed. And use those stats to calculate the mean and std of TPS.
    """
    # Warm up for ~0.1s so connection setup doesn't pollute the measurement.
    warmup_start = time.time()
    while time.time() - warmup_start < 0.1:
        await fn()

    stats = []
    for _ in range(10):
        completed = 0
        trial_start = time.time()
        while time.time() - trial_start < 0.5:
            await fn()
            completed += 1
        elapsed = time.time() - trial_start
        stats.append(multiplier * completed / elapsed)

    tps_mean = round(np.mean(stats), 2)
    tps_std = round(np.std(stats), 2)
    print(f"\t{name} {tps_mean} +- {tps_std} requests/s")

    return tps_mean, tps_std
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
async def get_query_latencies(name: str, fn: Callable):
    """Get query latencies.

    Awaits `fn` (which returns per-client lists of latencies in seconds) and
    reports the mean and std in milliseconds.
    """
    many_client_results = np.asarray(await fn())
    # `flatten()` returns a new array — the original discarded the result,
    # making the call a no-op. Keep the flattened array explicitly.
    many_client_results = many_client_results.flatten()
    latency_ms_mean = round(np.mean(many_client_results) * 1000, 2)
    latency_ms_std = round(np.std(many_client_results) * 1000, 2)
    print(f"\t{name} {latency_ms_mean} +- {latency_ms_std} ms")

    return latency_ms_mean, latency_ms_std
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
async def fetch_http(session, data):
    """Send the numbers to the HTTP proxy and decode the float response."""
    data_json = {"nums": data}
    response = await session.get("http://localhost:8000/", json=data_json)
    body = await response.read()
    # Raises if the deployment returned something that isn't a number.
    float(body.decode())


async def fetch_grpc(stub, data):
    """Send the numbers over gRPC and touch the output field of the reply."""
    reply = await stub.grpc_call(serve_pb2.RawData(nums=data))
    # Access the field so the response is fully materialized.
    reply.output
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@ray.remote
class HTTPClient:
    """Ray actor that issues HTTP requests and (optionally) times them."""

    def ready(self):
        # Startup barrier.
        return "ok"

    async def do_queries(self, num, data):
        async with aiohttp.ClientSession() as session:
            for _ in range(num):
                await fetch_http(session, data)

    async def time_queries(self, num, data):
        """Return a list of per-request wall-clock latencies in seconds."""
        latencies = []
        async with aiohttp.ClientSession() as session:
            for _ in range(num):
                begin = time.time()
                await fetch_http(session, data)
                latencies.append(time.time() - begin)

        return latencies
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@ray.remote
class gRPCClient:
    """Ray actor that issues gRPC requests and (optionally) times them."""

    def __init__(self):
        channel = aio.insecure_channel("localhost:9000")
        self.stub = serve_pb2_grpc.RayServeBenchmarkServiceStub(channel)

    def ready(self):
        # Startup barrier.
        return "ok"

    async def do_queries(self, num, data):
        for _ in range(num):
            await fetch_grpc(self.stub, data)

    async def time_queries(self, num, data):
        """Return a list of per-request wall-clock latencies in seconds."""
        latencies = []
        for _ in range(num):
            begin = time.time()
            await fetch_grpc(self.stub, data)
            latencies.append(time.time() - begin)
        return latencies
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def build_app(
    num_replicas: int,
    max_ongoing_requests: int,
    data_size: int,
):
    """Build a two-stage app: preprocessing (HTTP/gRPC entry) -> model inference."""

    @serve.deployment(max_ongoing_requests=1000)
    class DataPreprocessing:
        def __init__(self, handle: DeploymentHandle):
            self._handle = handle
            # Turn off access log.
            logging.getLogger("ray.serve").setLevel(logging.WARNING)

        def normalize(self, raw: np.ndarray) -> np.ndarray:
            # Min-max normalization; DELTA guards against division by zero.
            low = np.min(raw)
            high = np.max(raw)
            return (raw - low) / (high - low + DELTA)

        async def __call__(self, req: Request):
            """HTTP entrypoint.

            It parses the request, normalize the data, and send to model for inference.
            """
            body = json.loads(await req.body())
            processed = self.normalize(np.asarray(body["nums"]))
            return await self._handle.remote(processed)

        async def grpc_call(self, raq_data):
            """gRPC entrypoint.

            It parses the request, normalize the data, and send to model for inference.
            """
            processed = self.normalize(np.asarray(raq_data.nums))
            output = await self._handle.remote(processed)
            return serve_pb2.ModelOutput(output=output)

        async def call_with_string(self, raq_data):
            """gRPC entrypoint."""
            return serve_pb2.ModelOutput(output=0)

    @serve.deployment(
        num_replicas=num_replicas,
        max_ongoing_requests=max_ongoing_requests,
    )
    class ModelInference:
        def __init__(self):
            # Turn off access log.
            logging.getLogger("ray.serve").setLevel(logging.WARNING)
            self._model = np.random.randn(data_size, data_size)

        async def __call__(self, processed: np.ndarray) -> float:
            # Run a dot product with a random matrix to simulate a model inference.
            model_output = np.dot(processed, self._model)
            return sum(model_output)

    return DataPreprocessing.bind(ModelInference.bind())
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
async def trial(
    num_replicas: int,
    max_ongoing_requests: int,
    data_size: int,
    num_clients: int,
    proxy: RequestProtocol,
) -> Dict[str, float]:
    """Benchmark one (proxy, clients, replicas, data size) configuration."""
    # Generate input data as array of random floats.
    data = [random() for _ in range(data_size)]

    # Build and deploy the app.
    serve.run(
        build_app(
            num_replicas=num_replicas,
            max_ongoing_requests=max_ongoing_requests,
            data_size=data_size,
        )
    )

    # Start clients.
    if proxy == RequestProtocol.GRPC:
        client_cls = gRPCClient
    elif proxy == RequestProtocol.HTTP:
        client_cls = HTTPClient
    clients = [client_cls.remote() for _ in range(num_clients)]
    ray.get([client.ready.remote() for client in clients])

    async def client_time_queries():
        return ray.get([a.time_queries.remote(CALLS_PER_BATCH, data) for a in clients])

    async def client_do_queries():
        ray.get([a.do_queries.remote(CALLS_PER_BATCH, data) for a in clients])

    trial_key_base = (
        f"proxy:{proxy}/"
        f"num_client:{num_clients}/"
        f"replica:{num_replicas}/"
        f"concurrent_queries:{max_ongoing_requests}/"
        f"data_size:{data_size}"
    )
    tps_mean, tps_sdt = await get_query_tps(
        trial_key_base,
        client_do_queries,
    )
    latency_ms_mean, latency_ms_std = await get_query_latencies(
        trial_key_base,
        client_time_queries,
    )

    return {
        "proxy": proxy.value,
        "num_client": num_clients,
        "replica": num_replicas,
        "concurrent_queries": max_ongoing_requests,
        "data_size": data_size,
        "tps_mean": tps_mean,
        "tps_sdt": tps_sdt,
        "latency_ms_mean": latency_ms_mean,
        "latency_ms_std": latency_ms_std,
    }
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
async def main():
    """Sweep all proxy benchmark conditions and print a tab-separated table."""
    start_time = time.time()
    results = []
    for num_replicas in [1, 8]:
        for max_ongoing_requests in [1, 10_000]:
            for data_size in [1, 100, 10_000]:
                for num_clients in [1, 8]:
                    for proxy in [RequestProtocol.GRPC, RequestProtocol.HTTP]:
                        row = await trial(
                            num_replicas=num_replicas,
                            max_ongoing_requests=max_ongoing_requests,
                            data_size=data_size,
                            num_clients=num_clients,
                            proxy=proxy,
                        )
                        results.append(row)

    print(f"Total time: {time.time() - start_time}s")
    print("results", results)

    df = pd.DataFrame.from_dict(results)
    df = df.sort_values(
        by=["proxy", "num_client", "replica", "concurrent_queries", "data_size"]
    )
    print("Results from all conditions:")
    # Print the results in with tab separated so we can copy into google sheets.
    for i in range(len(df.index)):
        print("\t".join(map(str, list(df.iloc[i]))))


if __name__ == "__main__":
    ray.init()

    grpc_port = 9000
    grpc_servicer_functions = [
        "ray.serve.generated.serve_pb2_grpc."
        "add_RayServeBenchmarkServiceServicer_to_server",
    ]
    serve.start(
        grpc_options=gRPCOptions(
            port=grpc_port,
            grpc_servicer_functions=grpc_servicer_functions,
        )
    )
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(main())
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/serialization/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/serialization/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (216 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/serialization/__pycache__/common.cpython-311.pyc
ADDED
|
Binary file (1.53 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/serialization/__pycache__/serialization_benchmark.cpython-311.pyc
ADDED
|
Binary file (7.37 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/serialization/common.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
from typing import List, Optional

from pydantic import BaseModel

#
# NOTE: PLEASE READ CAREFULLY BEFORE CHANGING
#
# Payloads in this module are purposefully extracted from benchmark file to force
# Ray's cloudpickle behavior when it does NOT serialize the class definition itself
# along with its payload (instead relying on it being imported)
#


class PayloadPydantic(BaseModel):
    """Pydantic payload exercised by the serialization benchmarks."""

    text: Optional[str] = None
    floats: Optional[List[float]] = None
    ints: Optional[List[int]] = None
    ts: Optional[float] = None
    reason: Optional[str] = None


@dataclass
class PayloadDataclass:
    """Dataclass payload mirroring `PayloadPydantic` field-for-field."""

    text: Optional[str] = None
    floats: Optional[List[float]] = None
    ints: Optional[List[int]] = None
    ts: Optional[float] = None
    reason: Optional[str] = None
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/serialization/serialization_benchmark.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
import copy
import enum
import pickle
import time
from typing import Any, Callable

import click
import msgpack

from ray._private.serialization import SerializationContext
from ray.cloudpickle import cloudpickle_fast
from ray.serve._private.benchmarks.common import (
    collect_profile_events,
    run_latency_benchmark,
)
from ray.serve._private.benchmarks.serialization.common import (
    PayloadDataclass,
    PayloadPydantic,
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class PayloadType(enum.Enum):
    """CLI choice (`--payload-type`) of which payload class to benchmark."""

    PYDANTIC = "pydantic"  # PayloadPydantic (pydantic BaseModel)
    DATACLASS = "dataclass"  # PayloadDataclass (stdlib dataclass)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class SerializerType(enum.Enum):
    """CLI choice (`--serializer`) of which serialization path to benchmark."""

    RAY = "ray"  # Ray's SerializationContext
    PICKLE = "pickle"  # stdlib pickle.dumps
    CLOUDPICKLE = "cloudpickle"  # ray.cloudpickle fast path
    MSGPACK = "msgpack"  # msgpack via a custom serializer registered with Ray
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# Percentiles reported in the latency summary (passed to `describe`).
_PERCENTILES = [0.5, 0.99]


# Shared serialization context used by the "ray" and "msgpack" benchmark paths.
# NOTE(review): constructed with None — presumably no worker context is needed
# for pure serialize/to_bytes benchmarking; confirm against SerializationContext.
sc = SerializationContext(None)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _create_model(cls):
|
| 41 |
+
return cls(
|
| 42 |
+
text="Test output",
|
| 43 |
+
floats=[float(f) for f in range(1, 100)],
|
| 44 |
+
ints=list(range(1, 100)),
|
| 45 |
+
ts=time.time(),
|
| 46 |
+
reason="Success!",
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _blackhole(o):
|
| 51 |
+
"""Placeholder to be used in the benchmark to make sure runtime
|
| 52 |
+
doesn't optimize out unused results"""
|
| 53 |
+
pass
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
async def run_serializer_benchmark(
    model, serializer: Callable[[Any], bytes], iterations: int
):
    """Time `serializer(model)` over `iterations` runs and print latency stats."""

    def _single_pass():
        # Feed the result to the blackhole so it is not optimized away.
        _blackhole(serializer(model))

    latencies = await run_latency_benchmark(_single_pass, iterations)

    print("Latencies (ms):\n", latencies.describe(percentiles=_PERCENTILES))
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@click.command(help="Benchmark serialization latency")
@click.option(
    "--trials",
    type=int,
    default=1000,
    help="Total number of trials to run in a single benchmark run",
)
@click.option(
    "--batch-size",
    type=int,
    default=10,
    help="Controls how many objects are contained in a serialized batch",
)
@click.option(
    "--payload-type",
    type=PayloadType,
    help="Target type of the payload to be benchmarked (supported: pydantic, "
    "dataclass)",
)
@click.option(
    "--serializer",
    type=SerializerType,
    help="Target type of the serializer to be benchmarked (supported: ray, pickle, "
    "cloudpickle, msgpack)",
)
@click.option(
    "--profile-events",
    type=bool,
    default=False,
)
def main(
    trials: int,
    batch_size: int,
    payload_type: PayloadType,
    serializer: SerializerType,
    profile_events: bool,
):
    """Run the serialization latency benchmark.

    Builds a batch of `batch_size` copies of the selected payload, then measures
    the latency of serializing that batch `trials` times with the selected
    serializer, optionally collecting asyncio profile events.
    """
    if serializer == SerializerType.RAY:

        def _serialize(obj):
            # Ray's default serialization path.
            so = sc.serialize(obj)
            bs = so.to_bytes()
            return bs

    elif serializer == SerializerType.CLOUDPICKLE:

        def _serialize(obj):
            bs = cloudpickle_fast.dumps(obj)
            return bs

    elif serializer == SerializerType.PICKLE:

        def _serialize(obj):
            bs = pickle.dumps(obj)
            return bs

    elif serializer == SerializerType.MSGPACK:

        def _dumps(obj):
            bs = msgpack.dumps(obj.__dict__)
            # print(f"Bytes ({len(bs)}): ", bs)
            return bs

        def _loads(bs):
            # Renamed from `dict` to avoid shadowing the builtin.
            # NOTE(review): always reconstructs PayloadPydantic, even when the
            # benchmarked payload is PayloadDataclass — confirm whether
            # msgpack + dataclass is an intended combination.
            payload_dict = msgpack.loads(bs)
            return PayloadPydantic(**payload_dict)

        sc._register_cloudpickle_serializer(PayloadPydantic, _dumps, _loads)

        def _serialize(obj):
            so = sc.serialize(obj)
            bs = so.to_bytes()
            return bs

    else:
        raise NotImplementedError(serializer)

    if payload_type == PayloadType.PYDANTIC:
        model = _create_model(PayloadPydantic)
    elif payload_type == PayloadType.DATACLASS:
        model = _create_model(PayloadDataclass)
    else:
        raise NotImplementedError(f"Not supported ({payload_type})")

    # BUG FIX: was `model.copy(deep=True)`, which only exists on Pydantic
    # models (and is deprecated in Pydantic v2) and raises AttributeError for
    # PayloadDataclass. `copy.deepcopy` behaves the same for both payload types.
    payload = [copy.deepcopy(model) for _ in range(batch_size)]

    routine = run_serializer_benchmark(payload, _serialize, trials)

    if profile_events:
        routine = collect_profile_events(routine)

    asyncio.run(routine)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
# Script entry point: invoke the click command (parses CLI args from sys.argv).
if __name__ == "__main__":
    main()
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (212 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__pycache__/common.cpython-311.pyc
ADDED
|
Binary file (7.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__pycache__/streaming_core_throughput.cpython-311.pyc
ADDED
|
Binary file (4.08 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__pycache__/streaming_grpc_throughput.cpython-311.pyc
ADDED
|
Binary file (9.07 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__pycache__/streaming_handle_throughput.cpython-311.pyc
ADDED
|
Binary file (4.19 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__pycache__/streaming_http_throughput.cpython-311.pyc
ADDED
|
Binary file (7.7 kB). View file
|
|
|