Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- .venv/lib/python3.11/site-packages/ray/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/__pycache__/_version.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/__pycache__/actor.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/__pycache__/client_builder.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/__pycache__/cluster_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/__pycache__/cross_language.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/__pycache__/exceptions.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/__pycache__/job_config.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/__pycache__/remote_function.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/__pycache__/runtime_context.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/__pycache__/setup-dev.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/__pycache__/types.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/core/libjemalloc.so +3 -0
- .venv/lib/python3.11/site-packages/ray/dag/__init__.py +46 -0
- .venv/lib/python3.11/site-packages/ray/dag/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/dag/__pycache__/base.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/dag/__pycache__/class_node.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/dag/__pycache__/collective_node.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/dag/__pycache__/conftest.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/dag/__pycache__/constants.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/dag/__pycache__/context.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/dag/__pycache__/dag_node.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/dag/__pycache__/dag_node_operation.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/dag/__pycache__/format_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/dag/__pycache__/function_node.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/dag/__pycache__/output_node.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/dag/__pycache__/py_obj_scanner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/dag/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/dag/__pycache__/vis_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/dag/base.py +8 -0
- .venv/lib/python3.11/site-packages/ray/dag/class_node.py +321 -0
- .venv/lib/python3.11/site-packages/ray/dag/collective_node.py +191 -0
- .venv/lib/python3.11/site-packages/ray/dag/compiled_dag_node.py +0 -0
- .venv/lib/python3.11/site-packages/ray/dag/conftest.py +16 -0
- .venv/lib/python3.11/site-packages/ray/dag/constants.py +33 -0
- .venv/lib/python3.11/site-packages/ray/dag/context.py +101 -0
- .venv/lib/python3.11/site-packages/ray/dag/dag_node.py +622 -0
- .venv/lib/python3.11/site-packages/ray/dag/dag_node_operation.py +789 -0
- .venv/lib/python3.11/site-packages/ray/dag/dag_operation_future.py +95 -0
- .venv/lib/python3.11/site-packages/ray/dag/experimental/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/dag/experimental/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/dag/format_utils.py +155 -0
- .venv/lib/python3.11/site-packages/ray/dag/function_node.py +60 -0
- .venv/lib/python3.11/site-packages/ray/dag/input_node.py +321 -0
- .venv/lib/python3.11/site-packages/ray/dag/output_node.py +45 -0
- .venv/lib/python3.11/site-packages/ray/dag/py_obj_scanner.py +105 -0
- .venv/lib/python3.11/site-packages/ray/dag/utils.py +66 -0
- .venv/lib/python3.11/site-packages/ray/dag/vis_utils.py +115 -0
- .venv/lib/python3.11/site-packages/ray/experimental/channel/__init__.py +39 -0
.gitattributes
CHANGED
|
@@ -158,3 +158,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
|
|
| 158 |
.venv/lib/python3.11/site-packages/ray/serve/_private/__pycache__/deployment_state.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 159 |
.venv/lib/python3.11/site-packages/xgrammar/xgrammar_bindings.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 160 |
.venv/lib/python3.11/site-packages/ray/_raylet.so filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 158 |
.venv/lib/python3.11/site-packages/ray/serve/_private/__pycache__/deployment_state.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 159 |
.venv/lib/python3.11/site-packages/xgrammar/xgrammar_bindings.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 160 |
.venv/lib/python3.11/site-packages/ray/_raylet.so filter=lfs diff=lfs merge=lfs -text
|
| 161 |
+
.venv/lib/python3.11/site-packages/ray/core/libjemalloc.so filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/ray/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (8.47 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/__pycache__/_version.cpython-311.pyc
ADDED
|
Binary file (378 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/__pycache__/actor.cpython-311.pyc
ADDED
|
Binary file (70.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/__pycache__/client_builder.cpython-311.pyc
ADDED
|
Binary file (17.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/__pycache__/cluster_utils.cpython-311.pyc
ADDED
|
Binary file (18.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/__pycache__/cross_language.cpython-311.pyc
ADDED
|
Binary file (5.14 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/__pycache__/exceptions.cpython-311.pyc
ADDED
|
Binary file (41.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/__pycache__/job_config.cpython-311.pyc
ADDED
|
Binary file (11.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/__pycache__/remote_function.cpython-311.pyc
ADDED
|
Binary file (21.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/__pycache__/runtime_context.cpython-311.pyc
ADDED
|
Binary file (25.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/__pycache__/setup-dev.cpython-311.pyc
ADDED
|
Binary file (7.78 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/__pycache__/types.cpython-311.pyc
ADDED
|
Binary file (636 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/core/libjemalloc.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0284919db23f95e692026039838aef89b5964b5cfec4a88acb9b3a9f4a226fd5
|
| 3 |
+
size 885296
|
.venv/lib/python3.11/site-packages/ray/dag/__init__.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.dag.dag_node import DAGNode
|
| 2 |
+
from ray.dag.function_node import FunctionNode
|
| 3 |
+
from ray.dag.class_node import (
|
| 4 |
+
ClassNode,
|
| 5 |
+
ClassMethodNode,
|
| 6 |
+
)
|
| 7 |
+
from ray.dag.collective_node import CollectiveOutputNode
|
| 8 |
+
from ray.dag.input_node import (
|
| 9 |
+
InputNode,
|
| 10 |
+
InputAttributeNode,
|
| 11 |
+
DAGInputData,
|
| 12 |
+
)
|
| 13 |
+
from ray.dag.output_node import MultiOutputNode
|
| 14 |
+
from ray.dag.dag_operation_future import DAGOperationFuture, GPUFuture
|
| 15 |
+
from ray.dag.constants import (
|
| 16 |
+
PARENT_CLASS_NODE_KEY,
|
| 17 |
+
PREV_CLASS_METHOD_CALL_KEY,
|
| 18 |
+
BIND_INDEX_KEY,
|
| 19 |
+
IS_CLASS_METHOD_OUTPUT_KEY,
|
| 20 |
+
COLLECTIVE_OPERATION_KEY,
|
| 21 |
+
DAGNODE_TYPE_KEY,
|
| 22 |
+
)
|
| 23 |
+
from ray.dag.vis_utils import plot
|
| 24 |
+
from ray.dag.context import DAGContext
|
| 25 |
+
|
| 26 |
+
__all__ = [
|
| 27 |
+
"ClassNode",
|
| 28 |
+
"ClassMethodNode",
|
| 29 |
+
"CollectiveOutputNode",
|
| 30 |
+
"DAGNode",
|
| 31 |
+
"DAGOperationFuture",
|
| 32 |
+
"FunctionNode",
|
| 33 |
+
"GPUFuture",
|
| 34 |
+
"InputNode",
|
| 35 |
+
"InputAttributeNode",
|
| 36 |
+
"DAGInputData",
|
| 37 |
+
"PARENT_CLASS_NODE_KEY",
|
| 38 |
+
"PREV_CLASS_METHOD_CALL_KEY",
|
| 39 |
+
"BIND_INDEX_KEY",
|
| 40 |
+
"IS_CLASS_METHOD_OUTPUT_KEY",
|
| 41 |
+
"COLLECTIVE_OPERATION_KEY",
|
| 42 |
+
"DAGNODE_TYPE_KEY",
|
| 43 |
+
"plot",
|
| 44 |
+
"MultiOutputNode",
|
| 45 |
+
"DAGContext",
|
| 46 |
+
]
|
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.42 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/base.cpython-311.pyc
ADDED
|
Binary file (683 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/class_node.cpython-311.pyc
ADDED
|
Binary file (14.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/collective_node.cpython-311.pyc
ADDED
|
Binary file (9.93 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/conftest.cpython-311.pyc
ADDED
|
Binary file (794 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/constants.cpython-311.pyc
ADDED
|
Binary file (1.07 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/context.cpython-311.pyc
ADDED
|
Binary file (5.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/dag_node.cpython-311.pyc
ADDED
|
Binary file (26.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/dag_node_operation.cpython-311.pyc
ADDED
|
Binary file (34.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/format_utils.cpython-311.pyc
ADDED
|
Binary file (7.13 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/function_node.cpython-311.pyc
ADDED
|
Binary file (2.82 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/output_node.cpython-311.pyc
ADDED
|
Binary file (2.82 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/py_obj_scanner.cpython-311.pyc
ADDED
|
Binary file (5.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (3.43 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/vis_utils.cpython-311.pyc
ADDED
|
Binary file (4.73 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/dag/base.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This module defines the base class for object scanning and gets rid of
|
| 2 |
+
reference cycles."""
|
| 3 |
+
from ray.util.annotations import DeveloperAPI
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@DeveloperAPI
|
| 7 |
+
class DAGNodeBase:
|
| 8 |
+
"""Common base class for a node in a Ray task graph."""
|
.venv/lib/python3.11/site-packages/ray/dag/class_node.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from weakref import ReferenceType
|
| 2 |
+
|
| 3 |
+
import ray
|
| 4 |
+
from ray.dag.dag_node import DAGNode
|
| 5 |
+
from ray.dag.input_node import InputNode
|
| 6 |
+
from ray.dag.format_utils import get_dag_node_str
|
| 7 |
+
from ray.dag.constants import (
|
| 8 |
+
PARENT_CLASS_NODE_KEY,
|
| 9 |
+
PREV_CLASS_METHOD_CALL_KEY,
|
| 10 |
+
BIND_INDEX_KEY,
|
| 11 |
+
IS_CLASS_METHOD_OUTPUT_KEY,
|
| 12 |
+
)
|
| 13 |
+
from ray.util.annotations import DeveloperAPI
|
| 14 |
+
|
| 15 |
+
from typing import Any, Dict, List, Union, Tuple, Optional
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@DeveloperAPI
|
| 19 |
+
class ClassNode(DAGNode):
|
| 20 |
+
"""Represents an actor creation in a Ray task DAG."""
|
| 21 |
+
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
cls,
|
| 25 |
+
cls_args,
|
| 26 |
+
cls_kwargs,
|
| 27 |
+
cls_options,
|
| 28 |
+
other_args_to_resolve=None,
|
| 29 |
+
):
|
| 30 |
+
self._body = cls
|
| 31 |
+
self._last_call: Optional["ClassMethodNode"] = None
|
| 32 |
+
super().__init__(
|
| 33 |
+
cls_args,
|
| 34 |
+
cls_kwargs,
|
| 35 |
+
cls_options,
|
| 36 |
+
other_args_to_resolve=other_args_to_resolve,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
if self._contains_input_node():
|
| 40 |
+
raise ValueError(
|
| 41 |
+
"InputNode handles user dynamic input the the DAG, and "
|
| 42 |
+
"cannot be used as args, kwargs, or other_args_to_resolve "
|
| 43 |
+
"in ClassNode constructor because it is not available at "
|
| 44 |
+
"class construction or binding time."
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
def _copy_impl(
|
| 48 |
+
self,
|
| 49 |
+
new_args: List[Any],
|
| 50 |
+
new_kwargs: Dict[str, Any],
|
| 51 |
+
new_options: Dict[str, Any],
|
| 52 |
+
new_other_args_to_resolve: Dict[str, Any],
|
| 53 |
+
):
|
| 54 |
+
return ClassNode(
|
| 55 |
+
self._body,
|
| 56 |
+
new_args,
|
| 57 |
+
new_kwargs,
|
| 58 |
+
new_options,
|
| 59 |
+
other_args_to_resolve=new_other_args_to_resolve,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
def _execute_impl(self, *args, **kwargs):
|
| 63 |
+
"""Executor of ClassNode by ray.remote()
|
| 64 |
+
|
| 65 |
+
Args and kwargs are to match base class signature, but not in the
|
| 66 |
+
implementation. All args and kwargs should be resolved and replaced
|
| 67 |
+
with value in bound_args and bound_kwargs via bottom-up recursion when
|
| 68 |
+
current node is executed.
|
| 69 |
+
"""
|
| 70 |
+
return (
|
| 71 |
+
ray.remote(self._body)
|
| 72 |
+
.options(**self._bound_options)
|
| 73 |
+
.remote(*self._bound_args, **self._bound_kwargs)
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
def _contains_input_node(self) -> bool:
|
| 77 |
+
"""Check if InputNode is used in children DAGNodes with current node
|
| 78 |
+
as the root.
|
| 79 |
+
"""
|
| 80 |
+
children_dag_nodes = self._get_all_child_nodes()
|
| 81 |
+
for child in children_dag_nodes:
|
| 82 |
+
if isinstance(child, InputNode):
|
| 83 |
+
return True
|
| 84 |
+
return False
|
| 85 |
+
|
| 86 |
+
def __getattr__(self, method_name: str):
|
| 87 |
+
# User trying to call .bind() without a bind class method
|
| 88 |
+
if method_name == "bind" and "bind" not in dir(self._body):
|
| 89 |
+
raise AttributeError(f".bind() cannot be used again on {type(self)} ")
|
| 90 |
+
# Raise an error if the method is invalid.
|
| 91 |
+
getattr(self._body, method_name)
|
| 92 |
+
call_node = _UnboundClassMethodNode(self, method_name, {})
|
| 93 |
+
return call_node
|
| 94 |
+
|
| 95 |
+
def __str__(self) -> str:
|
| 96 |
+
return get_dag_node_str(self, str(self._body))
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class _UnboundClassMethodNode(object):
|
| 100 |
+
def __init__(self, actor: ClassNode, method_name: str, options: dict):
|
| 101 |
+
# TODO(sang): Theoretically, We should use weakref cuz it is
|
| 102 |
+
# a circular dependency but when I used weakref, it fails
|
| 103 |
+
# because we cannot serialize the weakref.
|
| 104 |
+
self._actor = actor
|
| 105 |
+
self._method_name = method_name
|
| 106 |
+
self._options = options
|
| 107 |
+
|
| 108 |
+
def bind(self, *args, **kwargs):
|
| 109 |
+
other_args_to_resolve = {
|
| 110 |
+
PARENT_CLASS_NODE_KEY: self._actor,
|
| 111 |
+
PREV_CLASS_METHOD_CALL_KEY: self._actor._last_call,
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
node = ClassMethodNode(
|
| 115 |
+
self._method_name,
|
| 116 |
+
args,
|
| 117 |
+
kwargs,
|
| 118 |
+
self._options,
|
| 119 |
+
other_args_to_resolve=other_args_to_resolve,
|
| 120 |
+
)
|
| 121 |
+
self._actor._last_call = node
|
| 122 |
+
return node
|
| 123 |
+
|
| 124 |
+
def __getattr__(self, attr: str):
|
| 125 |
+
if attr == "remote":
|
| 126 |
+
raise AttributeError(
|
| 127 |
+
".remote() cannot be used on ClassMethodNodes. Use .bind() instead "
|
| 128 |
+
"to express an symbolic actor call."
|
| 129 |
+
)
|
| 130 |
+
else:
|
| 131 |
+
return self.__getattribute__(attr)
|
| 132 |
+
|
| 133 |
+
def options(self, **options):
|
| 134 |
+
self._options = options
|
| 135 |
+
return self
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class _ClassMethodOutput:
|
| 139 |
+
"""Represents a class method output in a Ray function DAG."""
|
| 140 |
+
|
| 141 |
+
def __init__(self, class_method_call: "ClassMethodNode", output_idx: int):
|
| 142 |
+
# The upstream class method call that returns multiple values.
|
| 143 |
+
self._class_method_call = class_method_call
|
| 144 |
+
# The output index of the return value from the upstream class method call.
|
| 145 |
+
self._output_idx = output_idx
|
| 146 |
+
|
| 147 |
+
@property
|
| 148 |
+
def class_method_call(self) -> "ClassMethodNode":
|
| 149 |
+
return self._class_method_call
|
| 150 |
+
|
| 151 |
+
@property
|
| 152 |
+
def output_idx(self) -> int:
|
| 153 |
+
return self._output_idx
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
@DeveloperAPI
|
| 157 |
+
class ClassMethodNode(DAGNode):
|
| 158 |
+
"""Represents an actor method invocation in a Ray function DAG."""
|
| 159 |
+
|
| 160 |
+
def __init__(
|
| 161 |
+
self,
|
| 162 |
+
method_name: str,
|
| 163 |
+
method_args: Tuple[Any],
|
| 164 |
+
method_kwargs: Dict[str, Any],
|
| 165 |
+
method_options: Dict[str, Any],
|
| 166 |
+
other_args_to_resolve: Dict[str, Any],
|
| 167 |
+
):
|
| 168 |
+
self._bound_args = method_args or []
|
| 169 |
+
self._bound_kwargs = method_kwargs or {}
|
| 170 |
+
self._bound_options = method_options or {}
|
| 171 |
+
self._method_name: str = method_name
|
| 172 |
+
# Parse other_args_to_resolve and assign to variables
|
| 173 |
+
self._parent_class_node: Union[
|
| 174 |
+
ClassNode, ReferenceType["ray._private.actor.ActorHandle"]
|
| 175 |
+
] = other_args_to_resolve.get(PARENT_CLASS_NODE_KEY)
|
| 176 |
+
# Used to track lineage of ClassMethodCall to preserve deterministic
|
| 177 |
+
# submission and execution order.
|
| 178 |
+
self._prev_class_method_call: Optional[
|
| 179 |
+
ClassMethodNode
|
| 180 |
+
] = other_args_to_resolve.get(PREV_CLASS_METHOD_CALL_KEY, None)
|
| 181 |
+
# The index/order when bind() is called on this class method
|
| 182 |
+
self._bind_index: Optional[int] = other_args_to_resolve.get(
|
| 183 |
+
BIND_INDEX_KEY, None
|
| 184 |
+
)
|
| 185 |
+
# Represent if the ClassMethodNode is a class method output. If True,
|
| 186 |
+
# the node is a placeholder for a return value from the ClassMethodNode
|
| 187 |
+
# that returns multiple values. If False, the node is a class method call.
|
| 188 |
+
self._is_class_method_output: bool = other_args_to_resolve.get(
|
| 189 |
+
IS_CLASS_METHOD_OUTPUT_KEY, False
|
| 190 |
+
)
|
| 191 |
+
# Represents the return value from the upstream ClassMethodNode that
|
| 192 |
+
# returns multiple values. If the node is a class method call, this is None.
|
| 193 |
+
self._class_method_output: Optional[_ClassMethodOutput] = None
|
| 194 |
+
if self._is_class_method_output:
|
| 195 |
+
# Set the upstream ClassMethodNode and the output index of the return
|
| 196 |
+
# value from `method_args`.
|
| 197 |
+
self._class_method_output = _ClassMethodOutput(
|
| 198 |
+
method_args[0], method_args[1]
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
# The actor creation task dependency is encoded as the first argument,
|
| 202 |
+
# and the ordering dependency as the second, which ensures they are
|
| 203 |
+
# executed prior to this node.
|
| 204 |
+
super().__init__(
|
| 205 |
+
method_args,
|
| 206 |
+
method_kwargs,
|
| 207 |
+
method_options,
|
| 208 |
+
other_args_to_resolve=other_args_to_resolve,
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
def _copy_impl(
|
| 212 |
+
self,
|
| 213 |
+
new_args: List[Any],
|
| 214 |
+
new_kwargs: Dict[str, Any],
|
| 215 |
+
new_options: Dict[str, Any],
|
| 216 |
+
new_other_args_to_resolve: Dict[str, Any],
|
| 217 |
+
):
|
| 218 |
+
return ClassMethodNode(
|
| 219 |
+
self._method_name,
|
| 220 |
+
new_args,
|
| 221 |
+
new_kwargs,
|
| 222 |
+
new_options,
|
| 223 |
+
other_args_to_resolve=new_other_args_to_resolve,
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
def _execute_impl(self, *args, **kwargs):
|
| 227 |
+
"""Executor of ClassMethodNode by ray.remote()
|
| 228 |
+
|
| 229 |
+
Args and kwargs are to match base class signature, but not in the
|
| 230 |
+
implementation. All args and kwargs should be resolved and replaced
|
| 231 |
+
with value in bound_args and bound_kwargs via bottom-up recursion when
|
| 232 |
+
current node is executed.
|
| 233 |
+
"""
|
| 234 |
+
if self.is_class_method_call:
|
| 235 |
+
method_body = getattr(self._parent_class_node, self._method_name)
|
| 236 |
+
# Execute with bound args.
|
| 237 |
+
return method_body.options(**self._bound_options).remote(
|
| 238 |
+
*self._bound_args,
|
| 239 |
+
**self._bound_kwargs,
|
| 240 |
+
)
|
| 241 |
+
else:
|
| 242 |
+
assert self._class_method_output is not None
|
| 243 |
+
return self._bound_args[0][self._class_method_output.output_idx]
|
| 244 |
+
|
| 245 |
+
def __str__(self) -> str:
|
| 246 |
+
return get_dag_node_str(self, f"{self._method_name}()")
|
| 247 |
+
|
| 248 |
+
def __repr__(self) -> str:
|
| 249 |
+
return self.__str__()
|
| 250 |
+
|
| 251 |
+
def get_method_name(self) -> str:
|
| 252 |
+
return self._method_name
|
| 253 |
+
|
| 254 |
+
def _get_bind_index(self) -> int:
|
| 255 |
+
return self._bind_index
|
| 256 |
+
|
| 257 |
+
def _get_remote_method(self, method_name):
|
| 258 |
+
method_body = getattr(self._parent_class_node, method_name)
|
| 259 |
+
return method_body
|
| 260 |
+
|
| 261 |
+
def _get_actor_handle(self) -> Optional["ray.actor.ActorHandle"]:
|
| 262 |
+
if not isinstance(self._parent_class_node, ray.actor.ActorHandle):
|
| 263 |
+
return None
|
| 264 |
+
return self._parent_class_node
|
| 265 |
+
|
| 266 |
+
@property
|
| 267 |
+
def num_returns(self) -> int:
|
| 268 |
+
"""
|
| 269 |
+
Return the number of return values from the class method call. If the
|
| 270 |
+
node is a class method output, return the number of return values from
|
| 271 |
+
the upstream class method call.
|
| 272 |
+
"""
|
| 273 |
+
|
| 274 |
+
if self.is_class_method_call:
|
| 275 |
+
num_returns = self._bound_options.get("num_returns", None)
|
| 276 |
+
if num_returns is None:
|
| 277 |
+
method = self._get_remote_method(self._method_name)
|
| 278 |
+
num_returns = method.__getstate__()["num_returns"]
|
| 279 |
+
return num_returns
|
| 280 |
+
else:
|
| 281 |
+
assert self._class_method_output is not None
|
| 282 |
+
return self._class_method_output.class_method_call.num_returns
|
| 283 |
+
|
| 284 |
+
@property
|
| 285 |
+
def is_class_method_call(self) -> bool:
|
| 286 |
+
"""
|
| 287 |
+
Return True if the node is a class method call, False if the node is a
|
| 288 |
+
class method output.
|
| 289 |
+
"""
|
| 290 |
+
return not self._is_class_method_output
|
| 291 |
+
|
| 292 |
+
@property
|
| 293 |
+
def is_class_method_output(self) -> bool:
|
| 294 |
+
"""
|
| 295 |
+
Return True if the node is a class method output, False if the node is a
|
| 296 |
+
class method call.
|
| 297 |
+
"""
|
| 298 |
+
return self._is_class_method_output
|
| 299 |
+
|
| 300 |
+
@property
|
| 301 |
+
def class_method_call(self) -> Optional["ClassMethodNode"]:
|
| 302 |
+
"""
|
| 303 |
+
Return the upstream class method call that returns multiple values. If
|
| 304 |
+
the node is a class method output, return None.
|
| 305 |
+
"""
|
| 306 |
+
|
| 307 |
+
if self._class_method_output is None:
|
| 308 |
+
return None
|
| 309 |
+
return self._class_method_output.class_method_call
|
| 310 |
+
|
| 311 |
+
@property
|
| 312 |
+
def output_idx(self) -> Optional[int]:
|
| 313 |
+
"""
|
| 314 |
+
Return the output index of the return value from the upstream class
|
| 315 |
+
method call that returns multiple values. If the node is a class method
|
| 316 |
+
call, return None.
|
| 317 |
+
"""
|
| 318 |
+
|
| 319 |
+
if self._class_method_output is None:
|
| 320 |
+
return None
|
| 321 |
+
return self._class_method_output.output_idx
|
.venv/lib/python3.11/site-packages/ray/dag/collective_node.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Union, Tuple, Optional, TYPE_CHECKING
|
| 2 |
+
|
| 3 |
+
if TYPE_CHECKING:
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
import ray
|
| 7 |
+
from ray.dag import (
|
| 8 |
+
DAGNode,
|
| 9 |
+
ClassMethodNode,
|
| 10 |
+
)
|
| 11 |
+
from ray.dag.constants import COLLECTIVE_OPERATION_KEY
|
| 12 |
+
from ray.experimental.channel import ChannelContext
|
| 13 |
+
from ray.experimental.channel.torch_tensor_nccl_channel import _init_communicator
|
| 14 |
+
from ray.experimental.channel.torch_tensor_type import Communicator, TorchTensorType
|
| 15 |
+
from ray.experimental.util.types import _CollectiveOp, ReduceOp
|
| 16 |
+
from ray.util.annotations import DeveloperAPI
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class _CollectiveOperation:
|
| 20 |
+
"""
|
| 21 |
+
Represent metadata for a NCCL collective operation.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
input_nodes: A list of input nodes to the collective operation.
|
| 25 |
+
op: The collective operation to perform.
|
| 26 |
+
transport: The transport to use for the collective operation.
|
| 27 |
+
|
| 28 |
+
Requirements:
|
| 29 |
+
1. Input nodes are unique.
|
| 30 |
+
2. Actor handles are unique.
|
| 31 |
+
3. Actor handles match the custom NCCL group if specified.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(
|
| 35 |
+
self,
|
| 36 |
+
input_nodes: List[DAGNode],
|
| 37 |
+
op: _CollectiveOp,
|
| 38 |
+
transport: Optional[Union[str, Communicator]] = None,
|
| 39 |
+
):
|
| 40 |
+
if len(input_nodes) == 0:
|
| 41 |
+
raise ValueError("Expected input nodes for a collective operation")
|
| 42 |
+
if len(set(input_nodes)) != len(input_nodes):
|
| 43 |
+
raise ValueError("Expected unique input nodes for a collective operation")
|
| 44 |
+
|
| 45 |
+
self._actor_handles: List["ray.actor.ActorHandle"] = []
|
| 46 |
+
for input_node in input_nodes:
|
| 47 |
+
actor_handle = input_node._get_actor_handle()
|
| 48 |
+
if actor_handle is None:
|
| 49 |
+
raise ValueError("Expected an actor handle from the input node")
|
| 50 |
+
self._actor_handles.append(actor_handle)
|
| 51 |
+
if len(set(self._actor_handles)) != len(self._actor_handles):
|
| 52 |
+
invalid_input_nodes = [
|
| 53 |
+
input_node
|
| 54 |
+
for input_node in input_nodes
|
| 55 |
+
if self._actor_handles.count(input_node._get_actor_handle()) > 1
|
| 56 |
+
]
|
| 57 |
+
raise ValueError(
|
| 58 |
+
"Expected unique actor handles for a collective operation, "
|
| 59 |
+
"but found duplicate actor handles from input nodes: "
|
| 60 |
+
f"{invalid_input_nodes}"
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
self._op = op
|
| 64 |
+
if not isinstance(self._op, ReduceOp):
|
| 65 |
+
raise NotImplementedError("Only ReduceOp is implemented")
|
| 66 |
+
if transport is None:
|
| 67 |
+
transport = TorchTensorType.NCCL
|
| 68 |
+
self._type_hint = TorchTensorType(transport=transport, _direct_return=True)
|
| 69 |
+
if isinstance(transport, Communicator):
|
| 70 |
+
if set(transport.get_actor_handles()) != set(self._actor_handles):
|
| 71 |
+
raise ValueError(
|
| 72 |
+
"Expected actor handles to match the custom NCCL group"
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
def __str__(self) -> str:
|
| 76 |
+
return (
|
| 77 |
+
f"CollectiveGroup("
|
| 78 |
+
f"_actor_handles={self._actor_handles}, "
|
| 79 |
+
f"_op={self._op}, "
|
| 80 |
+
f"_type_hint={self._type_hint})"
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
@property
|
| 84 |
+
def actor_handles(self) -> List["ray.actor.ActorHandle"]:
|
| 85 |
+
return self._actor_handles
|
| 86 |
+
|
| 87 |
+
@property
|
| 88 |
+
def type_hint(self) -> TorchTensorType:
|
| 89 |
+
return self._type_hint
|
| 90 |
+
|
| 91 |
+
def init_communicator(self, communicator_id: Optional[str] = None) -> str:
|
| 92 |
+
"""
|
| 93 |
+
Initialize the communicator if it has not been initialized yet. If
|
| 94 |
+
`communicator_id` is provided, it means the communicator has already
|
| 95 |
+
been initialized.
|
| 96 |
+
"""
|
| 97 |
+
type_hint = self._type_hint
|
| 98 |
+
if type_hint.communicator_id is not None:
|
| 99 |
+
return type_hint.communicator_id
|
| 100 |
+
if communicator_id is None:
|
| 101 |
+
communicator_id = _init_communicator(
|
| 102 |
+
self._actor_handles, type_hint.get_custom_communicator()
|
| 103 |
+
)
|
| 104 |
+
type_hint.set_communicator_id(communicator_id)
|
| 105 |
+
return communicator_id
|
| 106 |
+
|
| 107 |
+
def get_communicator(self) -> Communicator:
|
| 108 |
+
if self._type_hint.communicator_id is not None:
|
| 109 |
+
ctx = ChannelContext.get_current()
|
| 110 |
+
communicator = ctx.communicators[self._type_hint.communicator_id]
|
| 111 |
+
elif self._type_hint.get_custom_communicator() is not None:
|
| 112 |
+
communicator = self._type_hint.get_custom_communicator()
|
| 113 |
+
else:
|
| 114 |
+
raise ValueError("Expected a NCCL group")
|
| 115 |
+
return communicator
|
| 116 |
+
|
| 117 |
+
def execute(self, send_buf: "torch.Tensor") -> "torch.Tensor":
|
| 118 |
+
"""
|
| 119 |
+
Call the collective operation on the input tensor. An output tensor is
|
| 120 |
+
allocated and returned.
|
| 121 |
+
"""
|
| 122 |
+
import torch
|
| 123 |
+
|
| 124 |
+
if not isinstance(send_buf, torch.Tensor):
|
| 125 |
+
raise ValueError("Expected a torch tensor")
|
| 126 |
+
communicator = self.get_communicator()
|
| 127 |
+
recv_buf = torch.empty_like(send_buf)
|
| 128 |
+
communicator.allreduce(send_buf, recv_buf, self._op)
|
| 129 |
+
return recv_buf
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@DeveloperAPI
class CollectiveOutputNode(ClassMethodNode):
    """Represent an output node from a NCCL collective operation in a Ray DAG."""

    def __init__(
        self,
        method_name: str,
        method_args: Tuple[
            DAGNode,
        ],
        method_kwargs: Dict[str, Any],
        method_options: Dict[str, Any],
        other_args_to_resolve: Dict[str, Any],
    ):
        # A collective output node wraps exactly one upstream DAG node.
        valid_input = (
            isinstance(method_args, tuple)
            and len(method_args) == 1
            and isinstance(method_args[0], DAGNode)
        )
        if not valid_input:
            raise ValueError("Expected a single input node")
        self._input_node = method_args[0]

        # The collective operation must be supplied via other_args_to_resolve.
        self._collective_op: _CollectiveOperation = other_args_to_resolve.get(
            COLLECTIVE_OPERATION_KEY, None
        )
        if self._collective_op is None:
            raise ValueError("Expected a collective operation")

        super().__init__(
            method_name,
            method_args,
            method_kwargs,
            method_options,
            other_args_to_resolve,
        )

    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ):
        # Produce a copy of this node bound to the new arguments/options.
        return CollectiveOutputNode(
            self._method_name,
            new_args,
            new_kwargs,
            new_options,
            other_args_to_resolve=new_other_args_to_resolve,
        )

    def _execute_impl(self, *args, **kwargs):
        # Collective outputs only run inside a compiled (accelerated) DAG.
        raise NotImplementedError(
            "CollectiveOutputNode is only supported with dag.experimental_compile()"
        )

    @property
    def collective_op(self) -> _CollectiveOperation:
        """The collective operation this node is an output of."""
        return self._collective_op
|
.venv/lib/python3.11/site-packages/ray/dag/compiled_dag_node.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/ray/dag/conftest.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pytest
|
| 3 |
+
|
| 4 |
+
import ray
|
| 5 |
+
|
| 6 |
+
TEST_NAMESPACE = "ray_dag_test_namespace"
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@pytest.fixture(scope="session")
def shared_ray_instance():
    """Session-scoped local Ray cluster shared across DAG tests."""
    # Drop any lingering RAY_ADDRESS (e.g. "http://127.0.0.1:8265" left over
    # from a previous local job submission) so a fresh local cluster starts.
    os.environ.pop("RAY_ADDRESS", None)
    yield ray.init(num_cpus=16, namespace=TEST_NAMESPACE, log_to_driver=True)
|
.venv/lib/python3.11/site-packages/ray/dag/constants.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
# Reserved keys used to handle ClassMethodNode in Ray DAG building.
|
| 4 |
+
PARENT_CLASS_NODE_KEY = "parent_class_node"
|
| 5 |
+
PREV_CLASS_METHOD_CALL_KEY = "prev_class_method_call"
|
| 6 |
+
BIND_INDEX_KEY = "bind_index"
|
| 7 |
+
IS_CLASS_METHOD_OUTPUT_KEY = "is_class_method_output"
|
| 8 |
+
|
| 9 |
+
# Reserved keys used to handle CollectiveOutputNode in Ray DAG building.
|
| 10 |
+
COLLECTIVE_OPERATION_KEY = "collective_operation"
|
| 11 |
+
|
| 12 |
+
# Reserved key to distinguish DAGNode type and avoid collision with user dict.
|
| 13 |
+
DAGNODE_TYPE_KEY = "__dag_node_type__"
|
| 14 |
+
|
| 15 |
+
# Feature flag to turn off the deadlock detection.
|
| 16 |
+
RAY_CGRAPH_ENABLE_DETECT_DEADLOCK = (
|
| 17 |
+
os.environ.get("RAY_CGRAPH_ENABLE_DETECT_DEADLOCK", "1") == "1"
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
# Feature flag to turn on profiling.
|
| 21 |
+
RAY_CGRAPH_ENABLE_PROFILING = os.environ.get("RAY_CGRAPH_ENABLE_PROFILING", "0") == "1"
|
| 22 |
+
|
| 23 |
+
# Feature flag to turn on NVTX (NVIDIA Tools Extension Library) profiling.
|
| 24 |
+
# With this flag, Compiled Graph uses nvtx to automatically annotate and profile
|
| 25 |
+
# function calls during each actor's execution loop.
|
| 26 |
+
RAY_CGRAPH_ENABLE_NVTX_PROFILING = (
|
| 27 |
+
os.environ.get("RAY_CGRAPH_ENABLE_NVTX_PROFILING", "0") == "1"
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
# Feature flag to turn on visualization of the execution schedule.
|
| 31 |
+
RAY_CGRAPH_VISUALIZE_SCHEDULE = (
|
| 32 |
+
os.environ.get("RAY_CGRAPH_VISUALIZE_SCHEDULE", "0") == "1"
|
| 33 |
+
)
|
.venv/lib/python3.11/site-packages/ray/dag/context.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
import os
|
| 3 |
+
import threading
|
| 4 |
+
from typing import Optional
|
| 5 |
+
from ray.util.annotations import DeveloperAPI
|
| 6 |
+
|
| 7 |
+
# The context singleton on this process.
|
| 8 |
+
_default_context: "Optional[DAGContext]" = None
|
| 9 |
+
_context_lock = threading.Lock()
|
| 10 |
+
|
| 11 |
+
DEFAULT_SUBMIT_TIMEOUT_S = int(os.environ.get("RAY_CGRAPH_submit_timeout", 10))
DEFAULT_GET_TIMEOUT_S = int(os.environ.get("RAY_CGRAPH_get_timeout", 10))
DEFAULT_TEARDOWN_TIMEOUT_S = int(os.environ.get("RAY_CGRAPH_teardown_timeout", 30))
DEFAULT_READ_ITERATION_TIMEOUT_S = float(
    os.environ.get("RAY_CGRAPH_read_iteration_timeout_s", 0.1)
)
# Default buffer size is 1MB.
DEFAULT_BUFFER_SIZE_BYTES = int(os.environ.get("RAY_CGRAPH_buffer_size_bytes", 1e6))
# The default number of in-flight executions that can be submitted before
# consuming the output.
DEFAULT_MAX_INFLIGHT_EXECUTIONS = int(
    os.environ.get("RAY_CGRAPH_max_inflight_executions", 10)
)

# Bug fix: the previous `bool(os.environ.get(..., 0))` treated ANY non-empty
# string -- including "0" and "false" -- as True, because a set environment
# variable is always a string. Parse the flag explicitly so that
# RAY_CGRAPH_overlap_gpu_communication=0 actually disables the feature,
# matching the `== "1"` convention used by the other flags in this package.
DEFAULT_OVERLAP_GPU_COMMUNICATION = (
    os.environ.get("RAY_CGRAPH_overlap_gpu_communication", "0").strip().lower()
    in ("1", "true")
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@DeveloperAPI
@dataclass
class DAGContext:
    """Global settings for Ray DAG.

    You can configure parameters in the DAGContext by setting the environment
    variables, `RAY_CGRAPH_<param>` (e.g., `RAY_CGRAPH_buffer_size_bytes`) or Python.

    Examples:
        >>> from ray.dag import DAGContext
        >>> DAGContext.get_current().buffer_size_bytes
        1000000
        >>> DAGContext.get_current().buffer_size_bytes = 500
        >>> DAGContext.get_current().buffer_size_bytes
        500

    Args:
        submit_timeout: The maximum time in seconds to wait for execute()
            calls.
        get_timeout: The maximum time in seconds to wait when retrieving a
            result from the DAG during `ray.get`. Set this higher than the
            expected time to execute the entire DAG.
        teardown_timeout: The maximum time in seconds to wait for the DAG to
            cleanly shut down.
        read_iteration_timeout: The timeout in seconds for each read
            iteration that reads one of the input channels. On timeout the
            read moves on to the next input channel. Must be less than or
            equal to `get_timeout`.
        buffer_size_bytes: The initial buffer size in bytes for messages
            passed between tasks in the DAG; buffers are resized
            automatically for larger messages.
        max_inflight_executions: The maximum number of in-flight executions
            that can be submitted via `execute` or `execute_async` before
            consuming the output using `ray.get()`. Exceeding it raises
            `RayCgraphCapacityExceeded`.
        overlap_gpu_communication: (experimental) Whether to overlap GPU
            communication with computation during DAG execution, which can
            improve performance.
    """

    submit_timeout: int = DEFAULT_SUBMIT_TIMEOUT_S
    get_timeout: int = DEFAULT_GET_TIMEOUT_S
    teardown_timeout: int = DEFAULT_TEARDOWN_TIMEOUT_S
    read_iteration_timeout: float = DEFAULT_READ_ITERATION_TIMEOUT_S
    buffer_size_bytes: int = DEFAULT_BUFFER_SIZE_BYTES
    max_inflight_executions: int = DEFAULT_MAX_INFLIGHT_EXECUTIONS
    overlap_gpu_communication: bool = DEFAULT_OVERLAP_GPU_COMMUNICATION

    def __post_init__(self):
        # The per-iteration read timeout must never exceed the overall get
        # timeout, otherwise a single read could outlast the whole `ray.get`.
        if self.read_iteration_timeout > self.get_timeout:
            raise ValueError(
                "RAY_CGRAPH_read_iteration_timeout_s "
                f"({self.read_iteration_timeout}) must be less than or equal to "
                f"RAY_CGRAPH_get_timeout ({self.get_timeout})"
            )

    @staticmethod
    def get_current() -> "DAGContext":
        """Get or create a singleton context.

        If the context has not yet been created in this process, it will be
        initialized with default settings.
        """
        global _default_context

        with _context_lock:
            if _default_context is None:
                _default_context = DAGContext()

            return _default_context
|
.venv/lib/python3.11/site-packages/ray/dag/dag_node.py
ADDED
|
@@ -0,0 +1,622 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from ray.experimental.channel.auto_transport_type import AutoTransportType
|
| 3 |
+
from ray.experimental.channel.torch_tensor_type import TorchTensorType
|
| 4 |
+
import ray
|
| 5 |
+
from ray.dag.base import DAGNodeBase
|
| 6 |
+
from ray.dag.py_obj_scanner import _PyObjScanner
|
| 7 |
+
from ray.util.annotations import DeveloperAPI
|
| 8 |
+
|
| 9 |
+
from itertools import chain
|
| 10 |
+
|
| 11 |
+
from typing import (
|
| 12 |
+
Optional,
|
| 13 |
+
Union,
|
| 14 |
+
List,
|
| 15 |
+
Tuple,
|
| 16 |
+
Dict,
|
| 17 |
+
Any,
|
| 18 |
+
TypeVar,
|
| 19 |
+
Callable,
|
| 20 |
+
)
|
| 21 |
+
import uuid
|
| 22 |
+
import asyncio
|
| 23 |
+
|
| 24 |
+
from ray.dag.compiled_dag_node import build_compiled_dag_from_ray_dag
|
| 25 |
+
from ray.experimental.channel import ChannelOutputType
|
| 26 |
+
from ray.experimental.channel.communicator import Communicator
|
| 27 |
+
|
| 28 |
+
T = TypeVar("T")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@DeveloperAPI
|
| 32 |
+
class DAGNode(DAGNodeBase):
|
| 33 |
+
"""Abstract class for a node in a Ray task graph.
|
| 34 |
+
|
| 35 |
+
A node has a type (e.g., FunctionNode), data (e.g., function options and
|
| 36 |
+
body), arguments (Python values, DAGNodes, and DAGNodes nested within Python
|
| 37 |
+
argument values) and options (Ray API .options() used for function, class
|
| 38 |
+
or class method)
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def __init__(
|
| 42 |
+
self,
|
| 43 |
+
args: Tuple[Any],
|
| 44 |
+
kwargs: Dict[str, Any],
|
| 45 |
+
options: Dict[str, Any],
|
| 46 |
+
other_args_to_resolve: Dict[str, Any],
|
| 47 |
+
):
|
| 48 |
+
"""
|
| 49 |
+
args:
|
| 50 |
+
args (Tuple[Any]): Bound node arguments.
|
| 51 |
+
ex: func_or_class.bind(1)
|
| 52 |
+
kwargs (Dict[str, Any]): Bound node keyword arguments.
|
| 53 |
+
ex: func_or_class.bind(a=1)
|
| 54 |
+
options (Dict[str, Any]): Bound node options arguments.
|
| 55 |
+
ex: func_or_class.options(num_cpus=2)
|
| 56 |
+
other_args_to_resolve (Dict[str, Any]): Bound kwargs to resolve
|
| 57 |
+
that's specific to subclass implementation without exposing
|
| 58 |
+
as args in base class, example: ClassMethodNode
|
| 59 |
+
"""
|
| 60 |
+
self._bound_args: Tuple[Any] = args or []
|
| 61 |
+
self._bound_kwargs: Dict[str, Any] = kwargs or {}
|
| 62 |
+
self._bound_options: Dict[str, Any] = options or {}
|
| 63 |
+
self._bound_other_args_to_resolve: Optional[Dict[str, Any]] = (
|
| 64 |
+
other_args_to_resolve or {}
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
# The list of nodes that use this DAG node as an argument.
|
| 68 |
+
self._downstream_nodes: List["DAGNode"] = []
|
| 69 |
+
|
| 70 |
+
# UUID that is not changed over copies of this node.
|
| 71 |
+
self._stable_uuid = uuid.uuid4().hex
|
| 72 |
+
|
| 73 |
+
# Indicates whether this DAG node contains nested DAG nodes.
|
| 74 |
+
# Nested DAG nodes are allowed in traditional DAGs but not
|
| 75 |
+
# in Ray Compiled Graphs, except for MultiOutputNode.
|
| 76 |
+
self._args_contain_nested_dag_node = False
|
| 77 |
+
|
| 78 |
+
# The list of nodes that this DAG node uses as an argument.
|
| 79 |
+
self._upstream_nodes: List["DAGNode"] = self._collect_upstream_nodes()
|
| 80 |
+
|
| 81 |
+
# Cached values from last call to execute()
|
| 82 |
+
self.cache_from_last_execute = {}
|
| 83 |
+
|
| 84 |
+
self._type_hint: ChannelOutputType = ChannelOutputType()
|
| 85 |
+
|
| 86 |
+
# If the original type hint is an AutoTransportType, we make a copy
|
| 87 |
+
# here when it is resolved to the actual type, as additional debugging
|
| 88 |
+
# information. Otherwise, it is None.
|
| 89 |
+
self._original_type_hint: Optional[ChannelOutputType] = None
|
| 90 |
+
|
| 91 |
+
# Whether this node calls `experimental_compile`.
|
| 92 |
+
self.is_cgraph_output_node = False
|
| 93 |
+
|
| 94 |
+
def _collect_upstream_nodes(self) -> List["DAGNode"]:
|
| 95 |
+
"""
|
| 96 |
+
Retrieve upstream nodes and update their downstream dependencies.
|
| 97 |
+
|
| 98 |
+
Currently, the DAG assumes that all DAGNodes in `args`, `kwargs`, and
|
| 99 |
+
`other_args_to_resolve` are upstream nodes. However, Ray Compiled Graphs
|
| 100 |
+
builds the upstream/downstream relationship based only on args. Be cautious
|
| 101 |
+
when persisting DAGNodes in `other_args_to_resolve` and kwargs in the future.
|
| 102 |
+
|
| 103 |
+
TODO (kevin85421): Currently, the upstream nodes and downstream nodes have
|
| 104 |
+
circular references. Therefore, it relies on the garbage collector to clean
|
| 105 |
+
them up instead of reference counting. We should consider using weak references
|
| 106 |
+
to avoid circular references.
|
| 107 |
+
"""
|
| 108 |
+
upstream_nodes: List["DAGNode"] = []
|
| 109 |
+
|
| 110 |
+
# Ray Compiled Graphs do not allow nested DAG nodes in arguments.
|
| 111 |
+
# Specifically, a DAGNode should not be placed inside any type of
|
| 112 |
+
# container. However, we only know if this is a compiled graph
|
| 113 |
+
# when calling `experimental_compile`. Therefore, we need to check
|
| 114 |
+
# in advance if the arguments contain nested DAG nodes and raise
|
| 115 |
+
# an error after compilation.
|
| 116 |
+
assert hasattr(self._bound_args, "__iter__")
|
| 117 |
+
for arg in self._bound_args:
|
| 118 |
+
if isinstance(arg, DAGNode):
|
| 119 |
+
upstream_nodes.append(arg)
|
| 120 |
+
else:
|
| 121 |
+
scanner = _PyObjScanner()
|
| 122 |
+
dag_nodes = scanner.find_nodes(arg)
|
| 123 |
+
upstream_nodes.extend(dag_nodes)
|
| 124 |
+
scanner.clear()
|
| 125 |
+
self._args_contain_nested_dag_node = len(dag_nodes) > 0
|
| 126 |
+
|
| 127 |
+
scanner = _PyObjScanner()
|
| 128 |
+
other_upstream_nodes: List["DAGNode"] = scanner.find_nodes(
|
| 129 |
+
[
|
| 130 |
+
self._bound_kwargs,
|
| 131 |
+
self._bound_other_args_to_resolve,
|
| 132 |
+
]
|
| 133 |
+
)
|
| 134 |
+
upstream_nodes.extend(other_upstream_nodes)
|
| 135 |
+
scanner.clear()
|
| 136 |
+
# Update dependencies.
|
| 137 |
+
for upstream_node in upstream_nodes:
|
| 138 |
+
upstream_node._downstream_nodes.append(self)
|
| 139 |
+
return upstream_nodes
|
| 140 |
+
|
| 141 |
+
def with_tensor_transport(
|
| 142 |
+
self,
|
| 143 |
+
transport: Optional[Union[str, Communicator]] = "auto",
|
| 144 |
+
_static_shape: bool = False,
|
| 145 |
+
_direct_return: bool = False,
|
| 146 |
+
):
|
| 147 |
+
if transport == "auto":
|
| 148 |
+
self._type_hint = AutoTransportType(
|
| 149 |
+
_static_shape=_static_shape,
|
| 150 |
+
_direct_return=_direct_return,
|
| 151 |
+
)
|
| 152 |
+
elif transport == "nccl":
|
| 153 |
+
self._type_hint = TorchTensorType(
|
| 154 |
+
transport=transport,
|
| 155 |
+
_static_shape=_static_shape,
|
| 156 |
+
_direct_return=_direct_return,
|
| 157 |
+
)
|
| 158 |
+
else:
|
| 159 |
+
if not isinstance(transport, Communicator):
|
| 160 |
+
raise ValueError(
|
| 161 |
+
"transport must be 'auto', 'nccl' or a Communicator type"
|
| 162 |
+
)
|
| 163 |
+
self._type_hint = TorchTensorType(
|
| 164 |
+
transport=transport,
|
| 165 |
+
_static_shape=_static_shape,
|
| 166 |
+
_direct_return=_direct_return,
|
| 167 |
+
)
|
| 168 |
+
return self
|
| 169 |
+
|
| 170 |
+
@property
|
| 171 |
+
def type_hint(self) -> ChannelOutputType:
|
| 172 |
+
return self._type_hint
|
| 173 |
+
|
| 174 |
+
@type_hint.setter
|
| 175 |
+
def type_hint(self, type_hint: ChannelOutputType) -> None:
|
| 176 |
+
if isinstance(self._type_hint, AutoTransportType):
|
| 177 |
+
self._original_type_hint = self._type_hint
|
| 178 |
+
self._type_hint = type_hint
|
| 179 |
+
|
| 180 |
+
def get_args(self) -> Tuple[Any]:
|
| 181 |
+
"""Return the tuple of arguments for this node."""
|
| 182 |
+
|
| 183 |
+
return self._bound_args
|
| 184 |
+
|
| 185 |
+
def get_kwargs(self) -> Dict[str, Any]:
|
| 186 |
+
"""Return the dict of keyword arguments for this node."""
|
| 187 |
+
|
| 188 |
+
return self._bound_kwargs.copy()
|
| 189 |
+
|
| 190 |
+
def get_options(self) -> Dict[str, Any]:
|
| 191 |
+
"""Return the dict of options arguments for this node."""
|
| 192 |
+
|
| 193 |
+
return self._bound_options.copy()
|
| 194 |
+
|
| 195 |
+
def get_other_args_to_resolve(self) -> Dict[str, Any]:
|
| 196 |
+
"""Return the dict of other args to resolve arguments for this node."""
|
| 197 |
+
return self._bound_other_args_to_resolve.copy()
|
| 198 |
+
|
| 199 |
+
def get_stable_uuid(self) -> str:
|
| 200 |
+
"""Return stable uuid for this node.
|
| 201 |
+
1) Generated only once at first instance creation
|
| 202 |
+
2) Stable across pickling, replacement and JSON serialization.
|
| 203 |
+
"""
|
| 204 |
+
return self._stable_uuid
|
| 205 |
+
|
| 206 |
+
async def get_object_refs_from_last_execute(self) -> Dict[str, Any]:
|
| 207 |
+
"""Gets cached object refs from the last call to execute().
|
| 208 |
+
|
| 209 |
+
After this DAG is executed through execute(), retrieves a map between node
|
| 210 |
+
UUID to a reference to the return value of the default executor on that node.
|
| 211 |
+
"""
|
| 212 |
+
cache = {}
|
| 213 |
+
for node_uuid, value in self.cache_from_last_execute.items():
|
| 214 |
+
if isinstance(value, asyncio.Task):
|
| 215 |
+
cache[node_uuid] = await value
|
| 216 |
+
else:
|
| 217 |
+
cache[node_uuid] = value
|
| 218 |
+
|
| 219 |
+
return cache
|
| 220 |
+
|
| 221 |
+
def clear_cache(self):
|
| 222 |
+
self.cache_from_last_execute = {}
|
| 223 |
+
|
| 224 |
+
def experimental_compile(
|
| 225 |
+
self,
|
| 226 |
+
_submit_timeout: Optional[float] = None,
|
| 227 |
+
_buffer_size_bytes: Optional[int] = None,
|
| 228 |
+
enable_asyncio: bool = False,
|
| 229 |
+
_max_inflight_executions: Optional[int] = None,
|
| 230 |
+
_overlap_gpu_communication: Optional[bool] = None,
|
| 231 |
+
) -> "ray.dag.CompiledDAG":
|
| 232 |
+
"""Compile an accelerated execution path for this DAG.
|
| 233 |
+
|
| 234 |
+
Args:
|
| 235 |
+
_submit_timeout: The maximum time in seconds to wait for execute() calls.
|
| 236 |
+
None means using default timeout, 0 means immediate timeout
|
| 237 |
+
(immediate success or timeout without blocking), -1 means
|
| 238 |
+
infinite timeout (block indefinitely).
|
| 239 |
+
_buffer_size_bytes: The initial buffer size in bytes for messages
|
| 240 |
+
that can be passed between tasks in the DAG. The buffers will
|
| 241 |
+
be automatically resized if larger messages are written to the
|
| 242 |
+
channel.
|
| 243 |
+
enable_asyncio: Whether to enable asyncio for this DAG.
|
| 244 |
+
_max_inflight_executions: The maximum number of in-flight executions that
|
| 245 |
+
can be submitted via `execute` or `execute_async` before consuming
|
| 246 |
+
the output using `ray.get()`. If the caller submits more executions,
|
| 247 |
+
`RayCgraphCapacityExceeded` is raised.
|
| 248 |
+
_overlap_gpu_communication: (experimental) Whether to overlap GPU
|
| 249 |
+
communication with computation during DAG execution. If True, the
|
| 250 |
+
communication and computation can be overlapped, which can improve
|
| 251 |
+
the performance of the DAG execution. If None, the default value
|
| 252 |
+
will be used.
|
| 253 |
+
|
| 254 |
+
Returns:
|
| 255 |
+
A compiled DAG.
|
| 256 |
+
"""
|
| 257 |
+
from ray.dag import DAGContext
|
| 258 |
+
|
| 259 |
+
ctx = DAGContext.get_current()
|
| 260 |
+
if _buffer_size_bytes is None:
|
| 261 |
+
_buffer_size_bytes = ctx.buffer_size_bytes
|
| 262 |
+
|
| 263 |
+
# Validate whether this DAG node has already been compiled.
|
| 264 |
+
if self.is_cgraph_output_node:
|
| 265 |
+
raise ValueError(
|
| 266 |
+
"It is not allowed to call `experimental_compile` on the same DAG "
|
| 267 |
+
"object multiple times no matter whether `teardown` is called or not. "
|
| 268 |
+
"Please reuse the existing compiled DAG or create a new one."
|
| 269 |
+
)
|
| 270 |
+
# Whether this node is an output node in the DAG. We cannot determine
|
| 271 |
+
# this in the constructor because the output node is determined when
|
| 272 |
+
# `experimental_compile` is called.
|
| 273 |
+
self.is_cgraph_output_node = True
|
| 274 |
+
return build_compiled_dag_from_ray_dag(
|
| 275 |
+
self,
|
| 276 |
+
_submit_timeout,
|
| 277 |
+
_buffer_size_bytes,
|
| 278 |
+
enable_asyncio,
|
| 279 |
+
_max_inflight_executions,
|
| 280 |
+
_overlap_gpu_communication,
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
def execute(
|
| 284 |
+
self, *args, _ray_cache_refs: bool = False, **kwargs
|
| 285 |
+
) -> Union[ray.ObjectRef, "ray.actor.ActorHandle"]:
|
| 286 |
+
"""Execute this DAG using the Ray default executor _execute_impl().
|
| 287 |
+
|
| 288 |
+
Args:
|
| 289 |
+
_ray_cache_refs: If true, stores the the default executor's return values
|
| 290 |
+
on each node in this DAG in a cache. These should be a mix of:
|
| 291 |
+
- ray.ObjectRefs pointing to the outputs of method and function nodes
|
| 292 |
+
- Serve handles for class nodes
|
| 293 |
+
- resolved values representing user input at runtime
|
| 294 |
+
"""
|
| 295 |
+
|
| 296 |
+
def executor(node):
|
| 297 |
+
return node._execute_impl(*args, **kwargs)
|
| 298 |
+
|
| 299 |
+
result = self.apply_recursive(executor)
|
| 300 |
+
if _ray_cache_refs:
|
| 301 |
+
self.cache_from_last_execute = executor.cache
|
| 302 |
+
return result
|
| 303 |
+
|
| 304 |
+
def _get_toplevel_child_nodes(self) -> List["DAGNode"]:
|
| 305 |
+
"""Return the list of nodes specified as top-level args.
|
| 306 |
+
|
| 307 |
+
For example, in `f.remote(a, [b])`, only `a` is a top-level arg.
|
| 308 |
+
|
| 309 |
+
This list of nodes are those that are typically resolved prior to
|
| 310 |
+
task execution in Ray. This does not include nodes nested within args.
|
| 311 |
+
For that, use ``_get_all_child_nodes()``.
|
| 312 |
+
"""
|
| 313 |
+
|
| 314 |
+
# we use List instead of Set here because the hash key of the node
|
| 315 |
+
# object changes each time we create it. So if using Set here, the
|
| 316 |
+
# order of returned children can be different if we create the same
|
| 317 |
+
# nodes and dag one more time.
|
| 318 |
+
children = []
|
| 319 |
+
for a in self.get_args():
|
| 320 |
+
if isinstance(a, DAGNode):
|
| 321 |
+
if a not in children:
|
| 322 |
+
children.append(a)
|
| 323 |
+
for a in self.get_kwargs().values():
|
| 324 |
+
if isinstance(a, DAGNode):
|
| 325 |
+
if a not in children:
|
| 326 |
+
children.append(a)
|
| 327 |
+
for a in self.get_other_args_to_resolve().values():
|
| 328 |
+
if isinstance(a, DAGNode):
|
| 329 |
+
if a not in children:
|
| 330 |
+
children.append(a)
|
| 331 |
+
return children
|
| 332 |
+
|
| 333 |
+
def _get_all_child_nodes(self) -> List["DAGNode"]:
|
| 334 |
+
"""Return the list of nodes referenced by the args, kwargs, and
|
| 335 |
+
args_to_resolve in current node, even they're deeply nested.
|
| 336 |
+
|
| 337 |
+
Examples:
|
| 338 |
+
f.remote(a, [b]) -> [a, b]
|
| 339 |
+
f.remote(a, [b], key={"nested": [c]}) -> [a, b, c]
|
| 340 |
+
"""
|
| 341 |
+
|
| 342 |
+
scanner = _PyObjScanner()
|
| 343 |
+
# we use List instead of Set here, reason explained
|
| 344 |
+
# in `_get_toplevel_child_nodes`.
|
| 345 |
+
children = []
|
| 346 |
+
for n in scanner.find_nodes(
|
| 347 |
+
[
|
| 348 |
+
self._bound_args,
|
| 349 |
+
self._bound_kwargs,
|
| 350 |
+
self._bound_other_args_to_resolve,
|
| 351 |
+
]
|
| 352 |
+
):
|
| 353 |
+
if n not in children:
|
| 354 |
+
children.append(n)
|
| 355 |
+
scanner.clear()
|
| 356 |
+
return children
|
| 357 |
+
|
| 358 |
+
def _apply_and_replace_all_child_nodes(
|
| 359 |
+
self, fn: "Callable[[DAGNode], T]"
|
| 360 |
+
) -> "DAGNode":
|
| 361 |
+
"""Apply and replace all immediate child nodes using a given function.
|
| 362 |
+
|
| 363 |
+
This is a shallow replacement only. To recursively transform nodes in
|
| 364 |
+
the DAG, use ``apply_recursive()``.
|
| 365 |
+
|
| 366 |
+
Args:
|
| 367 |
+
fn: Callable that will be applied once to each child of this node.
|
| 368 |
+
|
| 369 |
+
Returns:
|
| 370 |
+
New DAGNode after replacing all child nodes.
|
| 371 |
+
"""
|
| 372 |
+
|
| 373 |
+
replace_table = {}
|
| 374 |
+
# CloudPickler scanner object for current layer of DAGNode. Same
|
| 375 |
+
# scanner should be use for a full find & replace cycle.
|
| 376 |
+
scanner = _PyObjScanner()
|
| 377 |
+
# Find all first-level nested DAGNode children in args.
|
| 378 |
+
# Update replacement table and execute the replace.
|
| 379 |
+
for node in scanner.find_nodes(
|
| 380 |
+
[
|
| 381 |
+
self._bound_args,
|
| 382 |
+
self._bound_kwargs,
|
| 383 |
+
self._bound_other_args_to_resolve,
|
| 384 |
+
]
|
| 385 |
+
):
|
| 386 |
+
if node not in replace_table:
|
| 387 |
+
replace_table[node] = fn(node)
|
| 388 |
+
new_args, new_kwargs, new_other_args_to_resolve = scanner.replace_nodes(
|
| 389 |
+
replace_table
|
| 390 |
+
)
|
| 391 |
+
scanner.clear()
|
| 392 |
+
|
| 393 |
+
# Return updated copy of self.
|
| 394 |
+
return self._copy(
|
| 395 |
+
new_args, new_kwargs, self.get_options(), new_other_args_to_resolve
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
    def apply_recursive(self, fn: "Callable[[DAGNode], T]") -> T:
        """Apply callable on each node in this DAG in a bottom-up tree walk.

        Args:
            fn: Callable that will be applied once to each node in the
                DAG. It will be applied recursively bottom-up, so nodes can
                assume the fn has been applied to their args already.

        Returns:
            Return type of the fn after application to the tree.
        """

        # On the first (outermost) call, wrap `fn` in a memoizing wrapper so
        # each node is transformed at most once even if it is reachable via
        # multiple paths. Recursive calls below receive the wrapper and take
        # the `else` branch. NOTE(review): the wrapper is detected by class
        # *name*, so any unrelated class named "_CachingFn" would be mistaken
        # for it.
        if not type(fn).__name__ == "_CachingFn":

            class _CachingFn:
                def __init__(self, fn):
                    # Maps a node's `_stable_uuid` to fn's result for it.
                    self.cache = {}
                    self.fn = fn
                    # Also expose the cache on the wrapped callable itself.
                    self.fn.cache = self.cache
                    # UUID of the single InputNode seen so far, if any.
                    self.input_node_uuid = None

                def __call__(self, node: "DAGNode"):
                    from ray.dag.input_node import InputNode

                    # Compute-once semantics keyed on the stable UUID.
                    if node._stable_uuid not in self.cache:
                        self.cache[node._stable_uuid] = self.fn(node)
                    if isinstance(node, InputNode):
                        # Enforce that the DAG has exactly one InputNode.
                        if not self.input_node_uuid:
                            self.input_node_uuid = node._stable_uuid
                        elif self.input_node_uuid != node._stable_uuid:
                            raise AssertionError(
                                "Each DAG should only have one unique InputNode."
                            )
                    return self.cache[node._stable_uuid]

            fn = _CachingFn(fn)
        else:
            # Already wrapped: short-circuit if this node was transformed.
            if self._stable_uuid in fn.cache:
                return fn.cache[self._stable_uuid]

        # Transform the children first (bottom-up), then apply fn to self.
        return fn(
            self._apply_and_replace_all_child_nodes(
                lambda node: node.apply_recursive(fn)
            )
        )
|
| 443 |
+
|
| 444 |
+
def traverse_and_apply(self, fn: "Callable[[DAGNode], T]"):
|
| 445 |
+
"""
|
| 446 |
+
Traverse all nodes in the connected component of the DAG that contains
|
| 447 |
+
the `self` node, and apply the given function to each node.
|
| 448 |
+
"""
|
| 449 |
+
visited = set()
|
| 450 |
+
queue = [self]
|
| 451 |
+
cgraph_output_node: Optional[DAGNode] = None
|
| 452 |
+
|
| 453 |
+
while queue:
|
| 454 |
+
node = queue.pop(0)
|
| 455 |
+
if node._args_contain_nested_dag_node:
|
| 456 |
+
self._raise_nested_dag_node_error(node._bound_args)
|
| 457 |
+
|
| 458 |
+
if node not in visited:
|
| 459 |
+
if node.is_cgraph_output_node:
|
| 460 |
+
# Validate whether there are multiple nodes that call
|
| 461 |
+
# `experimental_compile`.
|
| 462 |
+
if cgraph_output_node is not None:
|
| 463 |
+
raise ValueError(
|
| 464 |
+
"The DAG was compiled more than once. The following two "
|
| 465 |
+
"nodes call `experimental_compile`: "
|
| 466 |
+
f"(1) {cgraph_output_node}, (2) {node}"
|
| 467 |
+
)
|
| 468 |
+
cgraph_output_node = node
|
| 469 |
+
fn(node)
|
| 470 |
+
visited.add(node)
|
| 471 |
+
"""
|
| 472 |
+
Add all unseen downstream and upstream nodes to the queue.
|
| 473 |
+
This function should be called by the root of the DAG. However,
|
| 474 |
+
in some invalid cases, some nodes may not be descendants of the
|
| 475 |
+
root. Therefore, we also add upstream nodes to the queue so that
|
| 476 |
+
a meaningful error message can be raised when the DAG is compiled.
|
| 477 |
+
|
| 478 |
+
```
|
| 479 |
+
with InputNode() as inp:
|
| 480 |
+
dag = MultiOutputNode([a1.inc.bind(inp), a2.inc.bind(1)])
|
| 481 |
+
```
|
| 482 |
+
|
| 483 |
+
In the above example, `a2.inc` is not a descendant of inp. If we only
|
| 484 |
+
add downstream nodes to the queue, the `a2.inc` node will not be visited
|
| 485 |
+
, and the error message will be hard to understand, such as a key error
|
| 486 |
+
in the compiled DAG.
|
| 487 |
+
"""
|
| 488 |
+
for neighbor in chain.from_iterable(
|
| 489 |
+
[node._downstream_nodes, node._upstream_nodes]
|
| 490 |
+
):
|
| 491 |
+
if neighbor not in visited:
|
| 492 |
+
queue.append(neighbor)
|
| 493 |
+
|
| 494 |
+
def _raise_nested_dag_node_error(self, args):
|
| 495 |
+
"""
|
| 496 |
+
Raise an error for nested DAGNodes in Ray Compiled Graphs.
|
| 497 |
+
|
| 498 |
+
Args:
|
| 499 |
+
args: The arguments of the DAGNode.
|
| 500 |
+
"""
|
| 501 |
+
for arg in args:
|
| 502 |
+
if isinstance(arg, DAGNode):
|
| 503 |
+
continue
|
| 504 |
+
else:
|
| 505 |
+
scanner = _PyObjScanner()
|
| 506 |
+
dag_nodes = scanner.find_nodes([arg])
|
| 507 |
+
scanner.clear()
|
| 508 |
+
if len(dag_nodes) > 0:
|
| 509 |
+
raise ValueError(
|
| 510 |
+
f"Found {len(dag_nodes)} DAGNodes from the arg {arg} "
|
| 511 |
+
f"in {self}. Please ensure that the argument is a "
|
| 512 |
+
"single DAGNode and that a DAGNode is not allowed to "
|
| 513 |
+
"be placed inside any type of container."
|
| 514 |
+
)
|
| 515 |
+
raise AssertionError(
|
| 516 |
+
"A DAGNode's args should contain nested DAGNodes as args, "
|
| 517 |
+
"but none were found during the compilation process. This is a "
|
| 518 |
+
"Ray internal error. Please report this issue to the Ray team."
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
def _find_root(self) -> "DAGNode":
|
| 522 |
+
"""
|
| 523 |
+
Return the root node of the DAG. The root node must be an InputNode.
|
| 524 |
+
"""
|
| 525 |
+
from ray.dag.input_node import InputNode
|
| 526 |
+
|
| 527 |
+
node = self
|
| 528 |
+
while not isinstance(node, InputNode):
|
| 529 |
+
if len(node._upstream_nodes) == 0:
|
| 530 |
+
raise ValueError(
|
| 531 |
+
"No InputNode found in the DAG: when traversing upwards, "
|
| 532 |
+
f"no upstream node was found for {node}."
|
| 533 |
+
)
|
| 534 |
+
node = node._upstream_nodes[0]
|
| 535 |
+
return node
|
| 536 |
+
|
| 537 |
+
def apply_functional(
|
| 538 |
+
self,
|
| 539 |
+
source_input_list: Any,
|
| 540 |
+
predictate_fn: Callable,
|
| 541 |
+
apply_fn: Callable,
|
| 542 |
+
):
|
| 543 |
+
"""
|
| 544 |
+
Apply a given function to DAGNodes in source_input_list, and return
|
| 545 |
+
the replaced inputs without mutating or coping any DAGNode.
|
| 546 |
+
|
| 547 |
+
Args:
|
| 548 |
+
source_input_list: Source inputs to extract and apply function on
|
| 549 |
+
all children DAGNode instances.
|
| 550 |
+
predictate_fn: Applied on each DAGNode instance found and determine
|
| 551 |
+
if we should apply function to it. Can be used to filter node
|
| 552 |
+
types.
|
| 553 |
+
apply_fn: Function to appy on the node on bound attributes. Example:
|
| 554 |
+
apply_fn = lambda node: node._get_serve_deployment_handle(
|
| 555 |
+
node._deployment, node._bound_other_args_to_resolve
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
Returns:
|
| 559 |
+
replaced_inputs: Outputs of apply_fn on DAGNodes in
|
| 560 |
+
source_input_list that passes predictate_fn.
|
| 561 |
+
"""
|
| 562 |
+
replace_table = {}
|
| 563 |
+
scanner = _PyObjScanner()
|
| 564 |
+
for node in scanner.find_nodes(source_input_list):
|
| 565 |
+
if predictate_fn(node) and node not in replace_table:
|
| 566 |
+
replace_table[node] = apply_fn(node)
|
| 567 |
+
|
| 568 |
+
replaced_inputs = scanner.replace_nodes(replace_table)
|
| 569 |
+
scanner.clear()
|
| 570 |
+
|
| 571 |
+
return replaced_inputs
|
| 572 |
+
|
| 573 |
+
    def _execute_impl(
        self, *args, **kwargs
    ) -> Union[ray.ObjectRef, "ray.actor.ActorHandle"]:
        """Execute this node, assuming args have been transformed already.

        Abstract hook; subclasses provide the actual execution.

        Returns:
            Per the annotation, an object ref or an actor handle; the concrete
            value depends on the subclass.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError
|
| 578 |
+
|
| 579 |
+
    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ) -> "DAGNode":
        """Return a copy of this node with the given new args.

        Abstract hook used by `_copy`; subclasses construct the new instance
        of their own concrete type here. `_copy` then restores the shared
        metadata (stable UUID, type hints) on the returned instance.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError
|
| 588 |
+
|
| 589 |
+
def _copy(
|
| 590 |
+
self,
|
| 591 |
+
new_args: List[Any],
|
| 592 |
+
new_kwargs: Dict[str, Any],
|
| 593 |
+
new_options: Dict[str, Any],
|
| 594 |
+
new_other_args_to_resolve: Dict[str, Any],
|
| 595 |
+
) -> "DAGNode":
|
| 596 |
+
"""Return a copy of this node with the given new args."""
|
| 597 |
+
instance = self._copy_impl(
|
| 598 |
+
new_args, new_kwargs, new_options, new_other_args_to_resolve
|
| 599 |
+
)
|
| 600 |
+
instance._stable_uuid = self._stable_uuid
|
| 601 |
+
instance._type_hint = copy.deepcopy(self._type_hint)
|
| 602 |
+
instance._original_type_hint = copy.deepcopy(self._original_type_hint)
|
| 603 |
+
return instance
|
| 604 |
+
|
| 605 |
+
    def __getstate__(self):
        """Required due to overriding `__getattr__` else pickling fails.

        Returns the instance `__dict__` unchanged, restoring the default
        pickling behavior.
        """
        return self.__dict__
|
| 608 |
+
|
| 609 |
+
    def __setstate__(self, d: Dict[str, Any]):
        """Required due to overriding `__getattr__` else pickling fails.

        Restores the pickled state by updating the instance `__dict__`
        directly, mirroring `__getstate__`.
        """
        self.__dict__.update(d)
|
| 612 |
+
|
| 613 |
+
def __getattr__(self, attr: str):
|
| 614 |
+
if attr == "bind":
|
| 615 |
+
raise AttributeError(f".bind() cannot be used again on {type(self)} ")
|
| 616 |
+
elif attr == "remote":
|
| 617 |
+
raise AttributeError(
|
| 618 |
+
f".remote() cannot be used on {type(self)}. To execute the task "
|
| 619 |
+
"graph for this node, use .execute()."
|
| 620 |
+
)
|
| 621 |
+
else:
|
| 622 |
+
return self.__getattribute__(attr)
|
.venv/lib/python3.11/site-packages/ray/dag/dag_node_operation.py
ADDED
|
@@ -0,0 +1,789 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import total_ordering
|
| 2 |
+
from enum import Enum
|
| 3 |
+
from typing import Set, Tuple, List, Dict, Optional
|
| 4 |
+
import copy
|
| 5 |
+
import logging
|
| 6 |
+
import ray
|
| 7 |
+
import heapq
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class _DAGNodeOperationType(Enum):
    """
    There are three types of operations that a DAG node can perform:
    1. READ: Read from an input channel.
    2. COMPUTE: Execute the method corresponding to the node.
    3. WRITE: Write to an output channel.
    """

    READ = "READ"
    COMPUTE = "COMPUTE"
    WRITE = "WRITE"

    def viz_str(self):
        """
        A string representation of the operation type to be used in
        visualization. A single character is returned because conciseness
        is preferred in the rendered schedule.
        """
        labels = {
            _DAGNodeOperationType.READ: "R",
            _DAGNodeOperationType.COMPUTE: "C",
            _DAGNodeOperationType.WRITE: "W",
        }
        assert self in labels, f"Unknown operation type: {self}"
        return labels[self]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class _DAGNodeOperation:
    """One READ/COMPUTE/WRITE step performed by an executable task."""

    def __init__(
        self,
        exec_task_idx: int,
        operation_type: _DAGNodeOperationType,
        method_name: Optional[str] = None,
    ):
        """
        Args:
            exec_task_idx: The index of the task that this operation belongs to
                in the actor's ExecutableTask list. The index is not the same
                as bind_index because there may be more tasks bound to an actor
                than tasks that appear in the current compiled DAG.
            operation_type: The type of operation to perform.
            method_name: The name of the method that this operation originates
                from. This is only for visualization and debugging purposes.
        """
        self.exec_task_idx = exec_task_idx
        self.type = operation_type
        self.method_name = method_name

    def __repr__(self):
        return f"_DAGNodeOperation(exec_task_idx: {self.exec_task_idx},  type: {self.type})"

    def viz_str(self):
        """A string representation of the node used in visualization."""
        return f"[{self.exec_task_idx}] {self.method_name} {self.type.viz_str()}"

    def __hash__(self):
        return hash((self.exec_task_idx, self.type))

    def __eq__(self, other):
        # An operation is uniquely identified by `exec_task_idx` and type;
        # `method_name` is debug-only and intentionally excluded.
        return (self.exec_task_idx, self.type) == (
            other.exec_task_idx,
            other.type,
        )
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@total_ordering
class _DAGOperationGraphNode:
    """A single operation (READ/COMPUTE/WRITE) in the DAG operation graph,
    with its dependency edges and NCCL-collective bookkeeping."""

    def __init__(
        self,
        operation: _DAGNodeOperation,
        task_idx: int,
        actor_handle: "ray.actor.ActorHandle",
        requires_nccl: bool,
    ):
        """
        _DAGOperationGraphNode represents a node in the DAG operation graph.
        It contains information about the node's in-degree, out-degree, edges,
        and the operation it performs.

        Args:
            operation: The operation that this node performs. The operation
                can be a READ, COMPUTE, or WRITE operation.
            task_idx: A unique index which can be used to index into
                `CompiledDAG.idx_to_task` to get the corresponding task.
            actor_handle: The actor handle to which this operation belongs.
            requires_nccl: Whether this operation requires NCCL.
        """
        self.operation = operation
        self.task_idx = task_idx
        self.actor_handle = actor_handle
        self.requires_nccl = requires_nccl
        # The in_edges and out_edges are dicts of tuples to strings.
        # Each tuple (the key) contains an integer `task_idx`, which can be
        # used to index into `idx_to_task` to get the corresponding task,
        # and a `_DAGNodeOperationType`, which can be READ, COMPUTE, or WRITE.
        # The string (the value) is the visualization information of the edge,
        # it is a tuple of a label of the edge and a boolean indicating whether
        # the edge is a control dependency.
        self.in_edges: Dict[Tuple[int, _DAGNodeOperationType], Tuple[str, bool]] = {}
        self.out_edges: Dict[Tuple[int, _DAGNodeOperationType], Tuple[str, bool]] = {}
        # The collective nodes are the nodes that belong to the same collective
        # operation. Each node is represented by a tuple of its task idx and type.
        self.collective_idxs: Set[Tuple[int, _DAGNodeOperationType]] = set()
        # The ready collective nodes are the nodes that are ready to be executed,
        # i.e., their in-degrees are zero. When a collective node is ready, it
        # will be added to the ready collective nodes of all the nodes in its
        # collective operation.
        self.ready_collective_idxs: Set[Tuple[int, _DAGNodeOperationType]] = set()

    def __repr__(self):
        return (
            f"_DAGOperationGraphNode("
            f"operation: {self.operation}, "
            f"task_idx: {self.task_idx}, "
            f"actor_handle: {self.actor_handle}, "
            f"requires_nccl: {self.requires_nccl})"
        )

    def __lt__(self, other: "_DAGOperationGraphNode"):
        """
        This function defines the order of the nodes in the priority queue used in
        `_select_next_nodes`. The priority queue is a min-heap, so the node with
        higher priority is considered "less than" the other node.
        """

        # Default tie-break: smaller `exec_task_idx` wins, then smaller
        # `task_idx`.
        def compare(lhs: "_DAGOperationGraphNode", rhs: "_DAGOperationGraphNode"):
            # If both nodes belong to the same actor, the node with the smaller
            # `exec_task_idx` is prioritized. If two nodes belong to different
            # actors, it approximates balancing the scheduled tasks across actors,
            # by prioritizing the node with the smaller `exec_task_idx`. The tie
            # is broken by the `task_idx`.
            if lhs.operation.exec_task_idx != rhs.operation.exec_task_idx:
                return lhs.operation.exec_task_idx < rhs.operation.exec_task_idx
            return lhs.task_idx < rhs.task_idx

        if self.actor_handle == other.actor_handle:
            # When both nodes belong to the same actor, use the default comparison.
            return compare(self, other)
        elif self.is_nccl_op != other.is_nccl_op:
            # When one node is a NCCL operation and the other is not, prioritize
            # the non-NCCL operation.
            return not self.is_nccl_op
        else:
            # When either both nodes are NCCL operations or both nodes are not
            # NCCL operations, use the default comparison.
            return compare(self, other)

    def __eq__(self, other: "_DAGOperationGraphNode"):
        """
        Two operations are equal only when they have the same `exec_task_idx` and `type`
        and belong to the same actor.
        """
        return (
            self.actor_handle == other.actor_handle
            and self.operation.exec_task_idx == other.operation.exec_task_idx
            and self.operation.type == other.operation.type
        )

    def __hash__(self):
        """
        An operation is uniquely identified by its `task_idx` and type.

        NOTE(review): `__hash__` keys on (operation, task_idx) while `__eq__`
        keys on (actor, exec_task_idx, type); these agree only if a given
        actor's exec_task_idx determines a unique task_idx — confirm.
        """
        return hash((self.operation, self.task_idx))

    @property
    def in_degree(self) -> int:
        # Number of unsatisfied dependencies; a plain dict-length read.
        return len(self.in_edges)

    @property
    def is_ready(self) -> bool:
        """
        If a node is not a NCCL collective, it is ready when it has a zero
        in-degree. If it is a NCCL collective, it is ready when all the nodes
        in its collective operation have zero in-degrees.
        """
        return self.in_degree == 0 and (
            len(self.ready_collective_idxs) == len(self.collective_idxs)
        )

    @property
    def is_read(self) -> bool:
        # True when this node's operation is a READ.
        return self.operation.type == _DAGNodeOperationType.READ

    @property
    def is_nccl_collective(self) -> bool:
        """
        A node is a NCCL collective if it is a compute node and requires NCCL.
        """
        return (
            self.operation.type == _DAGNodeOperationType.COMPUTE and self.requires_nccl
        )

    @property
    def is_nccl_write(self) -> bool:
        """
        A node is a NCCL write if it is a write node and requires NCCL.
        """
        return self.operation.type == _DAGNodeOperationType.WRITE and self.requires_nccl

    @property
    def is_nccl_op(self) -> bool:
        # Either kind of blocking NCCL operation.
        return self.is_nccl_collective or self.is_nccl_write

    def viz_str(self):
        """
        A string representation of the node to be used in visualization.
        """
        return self.operation.viz_str()

    @property
    def _actor_id(self):
        # Hex string form of the owning actor's ID.
        return self.actor_handle._ray_actor_id.hex()
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def _add_edge(
    from_node: _DAGOperationGraphNode,
    to_node: _DAGOperationGraphNode,
    label: str = "",
    control_dependency: bool = False,
):
    """
    Connect `from_node` to `to_node` in the operation graph.

    Args:
        from_node: The node from which the edge originates.
        to_node: The node to which the edge points.
        label: Annotation shown on the edge when the execution schedule is
            visualized.
        control_dependency: Whether the edge is a control dependency (stored
            alongside the label as edge metadata).
    """
    # Tuples are immutable, so both endpoints can share one metadata value.
    edge_info = (label, control_dependency)
    from_node.out_edges[(to_node.task_idx, to_node.operation.type)] = edge_info
    to_node.in_edges[(from_node.task_idx, from_node.operation.type)] = edge_info
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def _push_candidate_node_if_ready(
    actor_to_candidates: Dict["ray._raylet.ActorID", List[_DAGOperationGraphNode]],
    graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]],
    node: _DAGOperationGraphNode,
) -> None:
    """Record that `node` has a zero in-degree and enqueue it if ready.

    Collective operations become ready only when every node of the collective
    has a zero in-degree, so exactly one node per collective ends up enqueued.
    """
    if node.is_nccl_collective:
        # Tell every peer in the collective that this member is ready.
        for task_idx, op_type in node.collective_idxs:
            peer = graph[task_idx][op_type]
            peer.ready_collective_idxs.add((node.task_idx, node.operation.type))
    if node.is_ready:
        heapq.heappush(actor_to_candidates[node.actor_handle._actor_id], node)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def _select_next_nodes(
    actor_to_candidates: Dict["ray._raylet.ActorID", List[_DAGOperationGraphNode]],
    graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]],
) -> Optional[List[_DAGOperationGraphNode]]:
    """
    Select the next nodes for the topological sort used to generate the
    execution schedule. Among all candidate _DAGOperationGraphNodes, the one
    with the top priority (as defined by `_DAGOperationGraphNode.__lt__`) is
    chosen.

    Each actor owns a min-heap of candidates whose head has the smallest
    `exec_task_idx`. A node enters its actor's heap once it has a zero
    in-degree; a NCCL collective node is additionally ready only when every
    node of its collective operation has a zero in-degree, and exactly one
    member of the collective is ultimately selected from the candidate lists.

    Because NCCL operations block, selecting a NCCL write node also selects
    all of its downstream NCCL read nodes, and selecting a NCCL collective
    node also selects every NCCL compute node in the same collective —
    otherwise the schedule could deadlock.

    Args:
        actor_to_candidates: Maps an actor id to its candidate list,
            maintained as a priority queue, so `candidates[0]` is the node
            with the smallest `bind_index`.
        graph: Maps a task index to a dictionary of its
            _DAGOperationGraphNodes keyed by operation type.

    Returns:
        The list of _DAGOperationGraphNodes to place into the corresponding
        execution schedules, or None when no candidate exists.
    """
    # Find the highest-priority head across all per-actor heaps.
    top_priority_node = None
    for candidates in actor_to_candidates.values():
        if not candidates:
            continue
        head = candidates[0]
        if top_priority_node is None or head < top_priority_node:
            top_priority_node = head

    if top_priority_node is None:
        return None

    owner_heap = actor_to_candidates[top_priority_node.actor_handle._actor_id]
    next_nodes = [heapq.heappop(owner_heap)]

    if top_priority_node.is_nccl_write:
        # NCCL is blocking, so every corresponding NCCL read node must be
        # scheduled together with the write to avoid a deadlock.
        for task_idx, op_type in top_priority_node.out_edges:
            downstream_node = graph[task_idx][op_type]
            assert downstream_node.is_read
            next_nodes.append(downstream_node)
        assert len(next_nodes) == 1 + len(top_priority_node.out_edges)
    elif top_priority_node.is_nccl_collective:
        # Likewise, the whole collective operation must be scheduled at once.
        for task_idx, op_type in top_priority_node.collective_idxs:
            collective_node = graph[task_idx][op_type]
            assert collective_node.is_nccl_collective and collective_node.is_ready
            if collective_node != top_priority_node:
                next_nodes.append(collective_node)
        assert len(next_nodes) == len(top_priority_node.collective_idxs)
    else:
        # A non-NCCL operation is scheduled on its own.
        assert len(next_nodes) == 1

    return next_nodes
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def _build_dag_node_operation_graph(
    idx_to_task: Dict[int, "ray.dag.compiled_dag_node.CompiledTask"],
    actor_to_operation_nodes: Dict[
        "ray.actor.ActorHandle", List[List[_DAGOperationGraphNode]]
    ],
) -> Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]]:
    """
    Generate a DAG node operation graph by adding edges based on the
    following rules:

    #1 Add edges from READ to COMPUTE, and from COMPUTE to WRITE, which
       belong to the same task.
    #2 Add an edge from COMPUTE with bind_index i to COMPUTE with bind_index
       i+1 if they belong to the same actor.
    #3 Add an edge from WRITE of the writer task to READ of the reader task.

    This is the step one of building an execution schedule for each actor.

    Args:
        idx_to_task: A dictionary that maps the `task_idx` to the `CompiledTask`.
            `CompiledTask` contains information about a DAGNode and its downstream
            nodes.

        actor_to_operation_nodes: A dictionary that maps an actor handle to
            a list of lists of _DAGOperationGraphNode. For the same actor, the
            index of the outer list corresponds to the index of the ExecutableTask
            in the list of `executable_tasks` in `actor_to_executable_tasks`. In
            the inner list, the order of operations is READ, COMPUTE, and WRITE.

    Returns:
        A graph where each node is a _DAGOperationGraphNode. The key is `task_idx`,
        the index to retrieve its task from `idx_to_task`, and the value is a
        dictionary that maps the _DAGNodeOperationType (READ, COMPUTE, or WRITE)
        to the corresponding _DAGOperationGraphNode
    """
    assert idx_to_task
    graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]] = {}

    for _, operation_nodes_list in actor_to_operation_nodes.items():
        # Tracks the COMPUTE node of the previous task on the same actor so
        # that rule #2 (control dependency between consecutive bind indices)
        # can be applied; reset per actor.
        prev_compute_node = None
        for operation_nodes in operation_nodes_list:
            task_idx = operation_nodes[0].task_idx
            # The inner list is ordered READ, COMPUTE, WRITE by contract
            # (see `actor_to_operation_nodes` in the docstring).
            read_node, compute_node, write_node = (
                operation_nodes[0],
                operation_nodes[1],
                operation_nodes[2],
            )
            # Add edges from READ to COMPUTE, and from COMPUTE to WRITE, which
            # belong to the same task.
            _add_edge(read_node, compute_node)
            _add_edge(compute_node, write_node)
            # Add an edge from COMPUTE with `bind_index` i to COMPUTE with
            # `bind_index` i+1 if they belong to the same actor.
            # The empty label and the final `True` mark this edge as a
            # control dependency (rendered dashed in the visualization).
            if prev_compute_node is not None:
                _add_edge(prev_compute_node, compute_node, "", True)
            prev_compute_node = compute_node
            assert task_idx not in graph
            graph[task_idx] = {
                _DAGNodeOperationType.READ: read_node,
                _DAGNodeOperationType.COMPUTE: compute_node,
                _DAGNodeOperationType.WRITE: write_node,
            }

    # Import `ray.dag` here to avoid circular import.
    from ray.dag import ClassMethodNode, CollectiveOutputNode, MultiOutputNode

    # Add an edge from WRITE of the writer task to READ of the reader task.
    for task_idx, task in idx_to_task.items():
        if not (
            isinstance(task.dag_node, ClassMethodNode)
            or isinstance(task.dag_node, CollectiveOutputNode)
        ):
            # The graph is used to generate an execution schedule for each actor.
            # The edge from the InputNode has no impact on the final execution
            # schedule.
            continue
        if (
            isinstance(task.dag_node, ClassMethodNode)
            and task.dag_node.is_class_method_output
        ):
            # Class method output node dependencies are handled at its upstream:
            # i.e., class method node
            continue
        for downstream_task_idx in task.downstream_task_idxs:
            downstream_dag_node = idx_to_task[downstream_task_idx].dag_node
            if isinstance(downstream_dag_node, MultiOutputNode):
                continue
            if (
                isinstance(downstream_dag_node, ClassMethodNode)
                and downstream_dag_node.is_class_method_output
            ):
                # The downstream node is a class-method *output* placeholder;
                # connect this writer directly to the consumers of that output
                # instead (the placeholder itself has no READ node in `graph`).
                consumer_idxs = idx_to_task[downstream_task_idx].downstream_task_idxs
                for consumer_idx in consumer_idxs:
                    if consumer_idx in graph:
                        # The edge label records the channel type used for the
                        # visualization: "nccl" for GPU communication,
                        # "shm" for shared memory.
                        _add_edge(
                            graph[task_idx][_DAGNodeOperationType.WRITE],
                            graph[consumer_idx][_DAGNodeOperationType.READ],
                            "nccl"
                            if graph[task_idx][
                                _DAGNodeOperationType.WRITE
                            ].requires_nccl
                            else "shm",
                        )
                continue
            _add_edge(
                graph[task_idx][_DAGNodeOperationType.WRITE],
                graph[downstream_task_idx][_DAGNodeOperationType.READ],
                "nccl"
                if graph[task_idx][_DAGNodeOperationType.WRITE].requires_nccl
                else "shm",
            )

    return graph
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
def _actor_viz_label(actor: "ray.actor.ActorHandle"):
|
| 474 |
+
"""
|
| 475 |
+
Returns the label of an actor in the visualization of the execution schedule.
|
| 476 |
+
|
| 477 |
+
Args:
|
| 478 |
+
actor: The actor to be represented.
|
| 479 |
+
"""
|
| 480 |
+
class_name = actor._ray_actor_creation_function_descriptor.class_name
|
| 481 |
+
actor_id = actor._ray_actor_id.hex()
|
| 482 |
+
return f"Actor class name: {class_name}\nActor ID: {actor_id}"
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def _node_viz_id_and_label(
    node: _DAGOperationGraphNode, idx: int, optimized_index: int
):
    """
    Build the visualization id and label for an operation node. The id is
    unique across all nodes because it embeds the owning actor's id.

    Args:
        node: The node to be represented.
        idx: The index of the node in the execution schedule.
        optimized_index: The index of the node in the optimized execution schedule.
    """
    label = f"{node.viz_str()} {idx},{optimized_index}"
    viz_id = f"{node._actor_id}_{label}"
    return viz_id, label
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
def _visualize_execution_schedule(
    actor_to_execution_schedule: Dict[
        "ray.actor.ActorHandle", List[_DAGOperationGraphNode]
    ],
    actor_to_overlapped_schedule: Optional[
        Dict["ray.actor.ActorHandle", List[_DAGOperationGraphNode]]
    ],
    graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]],
):
    """
    Visualize the execution schedule for each actor.

    The visualization will be saved as a PNG file named `compiled_graph_schedule.png`.
    Details of the visualization: # noqa

    Node description format:
    [<task_index>] <method_name> <operation> <orig_index>, <overlap_index>

    Node description fields:
    operation: is R(READ), C(COMPUTE), or W(WRITE)
    orig_index: the index in the original execution schedule
    overlap_index: the index in the overlap-communication optimized execution schedule
    If this is different from orig_index, the node is highlighted in red color

    Node grouping:
    The nodes belonging to the same actor are grouped in the same rectangle
    The actor class name and the actor id are shown in the rectangle

    Edges:
    black color (without label): data dependency
    black color (annotated with "shm"): shared memory channel
    blue color (annotated with "nccl"): NCCL channel
    dashed edge: control dependency between compute operations

    Args:
        actor_to_execution_schedule: A dictionary that maps an actor handle to
            the execution schedule which is a list of operation nodes.
        actor_to_overlapped_schedule: A dictionary that maps an actor handle to the
            optimized execution schedule which is a list of operation nodes.
        graph: A graph where each node is a _DAGOperationGraphNode. The key is
            `task_idx`, the index to retrieve its task from `idx_to_task`, and
            the value is a dictionary that maps the _DAGNodeOperationType (READ,
            COMPUTE, or WRITE) to the corresponding _DAGOperationGraphNode. It is
            generated by `_build_dag_node_operation_graph`.

    Raises:
        ImportError: If the optional `graphviz` package is not installed.
    """
    try:
        import graphviz
    except ImportError:
        raise ImportError(
            "Please install graphviz to visualize the execution schedule. "
            "You can install it by running `pip install graphviz`."
        )

    dot = graphviz.Digraph(comment="DAG")
    # A dictionary that maps a node to its visualization id
    node_to_viz_id: Dict[_DAGOperationGraphNode, str] = {}

    if actor_to_overlapped_schedule is None:
        # TODO(rui): make the visualization more concise by only displaying
        # the original schedule
        actor_to_overlapped_schedule = actor_to_execution_schedule
    # First pass: emit one node per scheduled operation, grouped per actor.
    for actor, execution_nodes in actor_to_execution_schedule.items():
        overlapped_schedule = actor_to_overlapped_schedule[actor]
        node_to_optimized_index = {
            node: i for i, node in enumerate(overlapped_schedule)
        }

        actor_id = actor._ray_actor_id.hex()
        with dot.subgraph(name=f"cluster_{actor_id}") as subgraph:
            subgraph.attr(rank=actor_id, label=_actor_viz_label(actor))
            for i, node in enumerate(execution_nodes):
                # NOTE(review): `.get` may return None when the node is absent
                # from the overlapped schedule; the label then shows "None" and
                # the node is colored red — presumably intended, confirm.
                optimized_index = node_to_optimized_index.get(node)
                node_viz_id, node_viz_label = _node_viz_id_and_label(
                    node, i, optimized_index
                )
                # Red highlights nodes whose position changed after the
                # overlap optimization.
                color = "red" if optimized_index != i else "black"
                subgraph.node(node_viz_id, node_viz_label, color=color)
                node_to_viz_id[node] = node_viz_id

    # Second pass: emit edges (requires all node viz ids from the first pass).
    for actor, execution_nodes in actor_to_execution_schedule.items():
        for i, node in enumerate(execution_nodes):
            node_viz_id = node_to_viz_id[node]
            for out_edge, viz_info in node.out_edges.items():
                label, control_dependency = viz_info
                out_task_idx, out_op_type = out_edge
                out_node = graph[out_task_idx][out_op_type]
                out_node_viz_id = node_to_viz_id[out_node]
                color = "blue" if label == "nccl" else "black"
                style = "dashed" if control_dependency else "solid"
                dot.edge(
                    node_viz_id, out_node_viz_id, label=label, color=color, style=style
                )

    # Add legend
    with dot.subgraph(name="cluster_legend") as legend:
        legend.attr(label="Legend", labelloc="t", fontsize="20", bgcolor="lightgrey")

        # Single node and its explanation
        legend.node("example_node", "[0] bwd C 10,10\n")
        # HTML-like graphviz label; mirrors the docstring above.
        explanation = (
            '<<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0">'  # noqa
            '<TR><TD ALIGN="LEFT"><B>Node description format:</B></TD></TR>'
            '<TR><TD ALIGN="LEFT">[<task_index>] <method_name> <operation> <orig_index>, <overlap_index></TD></TR>'  # noqa
            "<TR><TD></TD></TR>"
            '<TR><TD ALIGN="LEFT"><B>Node description fields:</B></TD></TR>'
            '<TR><TD ALIGN="LEFT">operation: is R(READ), C(COMPUTE), or W(WRITE)</TD></TR>'  # noqa
            '<TR><TD ALIGN="LEFT">orig_index: the index in the original execution schedule</TD></TR>'  # noqa
            '<TR><TD ALIGN="LEFT">overlap_index: the index in the overlap-communication optimized execution schedule</TD></TR>'  # noqa
            '<TR><TD ALIGN="LEFT">If this is different from orig_index, the node is highlighted in <FONT COLOR="red">red color</FONT></TD></TR>'  # noqa
            "<TR><TD></TD></TR>"
            '<TR><TD ALIGN="LEFT"><B>Node grouping:</B></TD></TR>'
            '<TR><TD ALIGN="LEFT">The nodes belonging to the same actor are grouped in the same rectangle</TD></TR>'  # noqa
            '<TR><TD ALIGN="LEFT">The actor class name and the actor id are shown in the rectangle</TD></TR>'  # noqa
            "<TR><TD></TD></TR>"
            '<TR><TD ALIGN="LEFT"><B>Edges:</B></TD></TR>'
            '<TR><TD ALIGN="LEFT">black color (without label): data dependency</TD></TR>'  # noqa
            '<TR><TD ALIGN="LEFT">black color (annotated with "shm"): shared memory channel</TD></TR>'  # noqa
            '<TR><TD ALIGN="LEFT"><FONT COLOR="blue">blue color</FONT> (annotated with "nccl): NCCL channel</TD></TR>'  # noqa
            '<TR><TD ALIGN="LEFT">dashed edge: control dependency between compute operations</TD></TR>'  # noqa
            "</TABLE>>"
        )

        legend.node("example_explanation", explanation, shape="plaintext")
        legend.edge("example_node", "example_explanation", style="invis")

    logger.info(
        "Writing compiled graph schedule visualization "
        "to compiled_graph_schedule.png"
    )
    dot.render("compiled_graph_schedule", format="png", view=False)
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
def _generate_actor_to_execution_schedule(
    graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]]
) -> Dict["ray.actor.ActorHandle", List[_DAGOperationGraphNode]]:
    """
    Generate an execution schedule for each actor. The schedule is a list of
    operation nodes to be executed. The function uses a topological sort
    algorithm to generate the schedule.

    Note: this mutates the nodes in `graph` (their `in_edges` are drained as
    dependencies are satisfied).

    Args:
        graph: A graph where each node is a _DAGOperationGraphNode. The key is
            `task_idx`, the index to retrieve its task from `idx_to_task`, and
            the value is a dictionary that maps the _DAGNodeOperationType (READ,
            COMPUTE, or WRITE) to the corresponding _DAGOperationGraphNode. It is
            generated by `_build_dag_node_operation_graph`.

    Returns:
        actor_to_execution_schedule: A dictionary that maps an actor handle to
            the execution schedule which is a list of operation nodes to be
            executed.
    """

    # Mapping from the actor handle to the execution schedule which is a list
    # of operations to be executed.
    actor_to_execution_schedule: Dict[
        "ray.actor.ActorHandle", List[_DAGOperationGraphNode]
    ] = defaultdict(list)

    # A dictionary mapping an actor id to a list of candidate nodes. The list
    # is maintained as a priority queue, so the head of the queue, i.e.,
    # `candidates[0]`, is the node with the smallest `bind_index`.
    actor_to_candidates: Dict[
        "ray._raylet.ActorID", List[_DAGOperationGraphNode]
    ] = defaultdict(list)
    for _, node_dict in graph.items():
        for _, node in node_dict.items():
            # A node with a zero in-degree edge means all of its dependencies
            # have been satisfied, including both data and control dependencies.
            # Therefore, it is a candidate for execution.
            if node.in_degree == 0:
                _push_candidate_node_if_ready(actor_to_candidates, graph, node)

    visited_nodes = set()

    # Use topological sort algorithm to generate the execution schedule.
    while True:
        # Select a list of nodes to be executed. There are three cases:
        # 1. If a selected node is not a NCCL operation, only itself is returned.
        # 2. If a selected node is a NCCL write operation, the corresponding NCCL
        #    read operations are also returned.
        # 3. If a selected node is a NCCL collective operation, all the nodes in
        #    its collective operation are returned.
        # In cases 1 and 3, all the selected nodes are ready. In case 2, the NCCL
        # write node is ready, while the NCCL read nodes are not ready until their
        # in-degrees are updated.
        nodes = _select_next_nodes(actor_to_candidates, graph)
        if nodes is None:
            # No candidates remain: all reachable nodes have been scheduled.
            break
        # Filter out the visited nodes.
        nodes = [node for node in nodes if node not in visited_nodes]
        # Add the selected nodes to the execution schedule.
        for node in nodes:
            actor_to_execution_schedule[node.actor_handle].append(node)
            visited_nodes.add(node)
        # Update the in-degree of the downstream nodes.
        for node in nodes:
            for out_node_task_idx, out_node_type in node.out_edges:
                out_node = graph[out_node_task_idx][out_node_type]
                # Removing the satisfied in-edge decrements `in_degree`.
                out_node.in_edges.pop((node.task_idx, node.operation.type))
                if out_node.in_degree == 0 and out_node not in visited_nodes:
                    # If the downstream node is already visited, it has been added
                    # to the execution schedule. They are the NCCL read nodes in
                    # case 2.
                    _push_candidate_node_if_ready(actor_to_candidates, graph, out_node)
    # Sanity checks: every task contributes exactly READ, COMPUTE, and WRITE
    # nodes, all of which must have been scheduled and become ready.
    assert len(visited_nodes) == len(graph) * 3, "Expected all nodes to be visited"
    for node in visited_nodes:
        assert node.is_ready, f"Expected {node} to be ready"
    for _, candidates in actor_to_candidates.items():
        assert len(candidates) == 0, "Expected all candidates to be empty"

    return actor_to_execution_schedule
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
def _generate_overlapped_execution_schedule(
    actor_to_execution_schedule: Dict[
        "ray.actor.ActorHandle", List[_DAGOperationGraphNode]
    ],
) -> Dict["ray.actor.ActorHandle", List[_DAGOperationGraphNode]]:
    """
    From an existing execution schedule, generate a new schedule by overlapping
    computation and communication.

    Currently, the algorithm generates a new schedule for each actor as follows:
    For each NCCL read operation (i.e., recv), scan backwards to find the nearest
    compute node to swap with so that the NCCL read operation can be overlapped
    with computation.

    Collective operations are not yet supported.

    Args:
        actor_to_execution_schedule: A dictionary that maps an actor handle to
            the existing execution schedule for the actor. The schedule is a
            list of operations to be executed. The input is not mutated; a deep
            copy is rearranged and returned.

    Returns:
        A dictionary that maps an actor handle to the overlapped execution schedule
        for the actor.
    """

    actor_to_overlapped_schedule: Dict[
        "ray.actor.ActorHandle", List[_DAGOperationGraphNode]
    ] = copy.deepcopy(actor_to_execution_schedule)
    for overlapped_schedule in actor_to_overlapped_schedule.values():
        for i in range(len(overlapped_schedule)):
            if (
                overlapped_schedule[i].operation.type == _DAGNodeOperationType.READ
                and overlapped_schedule[i].requires_nccl
            ):
                # For each NCCL read operation (i.e., recv), scan backwards
                # to find the nearest compute node to swap with so that
                # the NCCL read operation can be overlapped with computation.
                for j in range(i - 1, -1, -1):
                    if (
                        overlapped_schedule[j].operation.type
                        == _DAGNodeOperationType.COMPUTE
                    ):
                        # Found a desired compute operation, make the swap:
                        # rotate ops in [j, i] one slot right so the NCCL read
                        # lands at position j, preserving the relative order
                        # of the shifted operations.
                        nccl_read_op = overlapped_schedule[i]
                        prev_ops = overlapped_schedule[j:i]
                        overlapped_schedule[j + 1 : i + 1] = prev_ops
                        overlapped_schedule[j] = nccl_read_op
                        break
                    if (
                        overlapped_schedule[j].operation.type
                        == _DAGNodeOperationType.READ
                        or overlapped_schedule[j].operation.type
                        == _DAGNodeOperationType.WRITE
                    ) and overlapped_schedule[j].requires_nccl:
                        # Found a NCCL read/write operation, skip the overlap
                        # optimization to keep relative order of NCCL operations
                        break
    return actor_to_overlapped_schedule
|
| 775 |
+
|
| 776 |
+
|
| 777 |
+
def _extract_execution_schedule(
    actor_to_execution_schedule: Dict[
        "ray.actor.ActorHandle", List[_DAGOperationGraphNode]
    ]
) -> Dict["ray.actor.ActorHandle", List[_DAGNodeOperation]]:
    """
    Extract _DAGNodeOperation from _DAGOperationGraphNode in the schedule
    and discard unnecessary information.
    """
    schedule: Dict["ray.actor.ActorHandle", List[_DAGNodeOperation]] = {}
    for actor, graph_nodes in actor_to_execution_schedule.items():
        schedule[actor] = [graph_node.operation for graph_node in graph_nodes]
    return schedule
|
.venv/lib/python3.11/site-packages/ray/dag/dag_operation_future.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
|
| 3 |
+
from ray.util.annotations import DeveloperAPI
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
if TYPE_CHECKING:
|
| 7 |
+
import cupy as cp
|
| 8 |
+
|
| 9 |
+
T = TypeVar("T")
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@DeveloperAPI
class DAGOperationFuture(ABC, Generic[T]):
    """
    A future representing the result of a DAG operation.

    This is an abstraction that is internal to each actor,
    and is not exposed to the DAG caller.
    """

    @abstractmethod
    def wait(self):
        """
        Wait for the future and return the result of the operation.
        """
        raise NotImplementedError
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@DeveloperAPI
class ResolvedFuture(DAGOperationFuture):
    """
    A future whose value is already available at construction time.
    Calling `wait()` returns the stored result immediately, without blocking.
    """

    def __init__(self, result):
        """
        Create a future that is resolved from the start.

        Args:
            result: The result of the future.
        """
        self._value = result

    def wait(self):
        """
        Return the stored result immediately. This never blocks.
        """
        return self._value
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@DeveloperAPI
class GPUFuture(DAGOperationFuture[Any]):
    """
    A future for a GPU event on a CUDA stream.

    This future wraps a buffer, and records an event on the given stream
    when it is created. When the future is waited on, it makes the current
    CUDA stream wait on the event, then returns the buffer.

    The buffer must be a GPU tensor produced by an earlier operation launched
    on the given stream, or it could be CPU data. Then the future guarantees
    that when the wait() returns, the buffer is ready on the current stream.

    The `wait()` does not block CPU.
    """

    def __init__(self, buf: Any, stream: Optional["cp.cuda.Stream"] = None):
        """
        Initialize a GPU future on the given stream.

        Args:
            buf: The buffer to return when the future is resolved.
            stream: The CUDA stream to record the event on, this event is waited
                on when the future is resolved. If None, the current stream is used.
        """
        # Imported lazily so this module stays importable without cupy.
        import cupy as cp

        if stream is None:
            stream = cp.cuda.get_current_stream()

        self._buf = buf
        # Recording the event captures all work enqueued on `stream` so far,
        # including the operation that produced `buf`.
        self._event = cp.cuda.Event()
        self._event.record(stream)

    def wait(self) -> Any:
        """
        Wait for the future on the current CUDA stream and return the result from
        the GPU operation. This operation does not block CPU.
        """
        import cupy as cp

        # Enqueue a stream-side wait: subsequent work on the current stream is
        # ordered after the recorded event, without synchronizing the host.
        current_stream = cp.cuda.get_current_stream()
        current_stream.wait_event(self._event)
        return self._buf
|
.venv/lib/python3.11/site-packages/ray/dag/experimental/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/dag/experimental/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (193 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/dag/format_utils.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.dag import DAGNode
|
| 2 |
+
from ray.util.annotations import DeveloperAPI
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
@DeveloperAPI
def get_dag_node_str(
    dag_node: DAGNode,
    body_line,
):
    """Render a DAGNode as a multi-line debug string.

    The output lists the node's class and stable uuid, followed by the
    supplied body line and the pretty-printed bound args, kwargs, options,
    and other_args_to_resolve.
    """
    indent = _get_indentation()
    sections = [
        f"({dag_node.__class__.__name__}, {dag_node._stable_uuid})(",
        f"{indent}body={body_line}",
        f"{indent}args={_get_args_lines(dag_node._bound_args)}",
        f"{indent}kwargs={_get_kwargs_lines(dag_node._bound_kwargs)}",
        f"{indent}options={_get_options_lines(dag_node._bound_options)}",
        f"{indent}other_args_to_resolve="
        f"{_get_other_args_to_resolve_lines(dag_node._bound_other_args_to_resolve)}",
        ")",
    ]
    return "\n".join(sections)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _get_indentation(num_spaces=4):
|
| 26 |
+
return " " * num_spaces
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _get_args_lines(bound_args):
    """Pretty prints bounded args of a DAGNode, and recursively handle
    DAGNode in list / dict containers.

    DAGNode instances (including elements of a list or the values of a dict)
    are rendered with their multi-line repr, each line indented once; any
    other value falls back to ``str(value) + ", "``.

    Args:
        bound_args: The node's bound positional arguments.

    Returns:
        A bracketed, indented multi-line string, or "[]" when nothing was
        rendered.
    """
    indent = _get_indentation()
    lines = []

    def _append_indented_repr(value):
        # Split the (possibly multi-line) repr and indent every line once.
        for repr_line in str(value).split("\n"):
            lines.append(indent + repr_line)

    for arg in bound_args:
        if isinstance(arg, DAGNode):
            _append_indented_repr(arg)
        elif isinstance(arg, list):
            for ele in arg:
                _append_indented_repr(ele)
        elif isinstance(arg, dict):
            # Only the values are rendered; keys are intentionally dropped.
            for val in arg.values():
                _append_indented_repr(val)
        # TODO: (jiaodong) Handle nested containers and other obj types
        else:
            lines.append(indent + str(arg) + ", ")

    if not lines:
        return "[]"
    return "[" + "".join(f"\n{indent}{line}" for line in lines) + f"\n{indent}]"
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _get_kwargs_lines(bound_kwargs):
    """Pretty prints bounded kwargs of a DAGNode, and recursively handle
    DAGNode in list / dict containers.

    A DAGNode value is rendered over multiple lines with its key on the
    first line; list elements and nested-dict values are rendered with their
    indented repr (keys of nested dicts are dropped); any other value is
    rendered bare.

    Args:
        bound_kwargs: The node's bound keyword arguments.

    Returns:
        A braced, indented multi-line string, or "{}" when empty.
    """
    # TODO: (jiaodong) Nits, we're missing keys and indentation was a bit off.
    if not bound_kwargs:
        return "{}"
    indent = _get_indentation()
    kwargs_lines = []
    for key, val in bound_kwargs.items():
        if isinstance(val, DAGNode):
            node_repr_lines = str(val).split("\n")
            for index, node_repr_line in enumerate(node_repr_lines):
                if index == 0:
                    # First repr line carries the kwarg key.
                    kwargs_lines.append(
                        f"{indent}{key}:" + f"{indent}" + node_repr_line
                    )
                else:
                    kwargs_lines.append(f"{indent}{indent}" + node_repr_line)

        elif isinstance(val, list):
            for ele in val:
                for node_repr_line in str(ele).split("\n"):
                    kwargs_lines.append(f"{indent}" + node_repr_line)
        elif isinstance(val, dict):
            # Only the nested values are rendered; their keys are dropped.
            for inner_val in val.values():
                for node_repr_line in str(inner_val).split("\n"):
                    kwargs_lines.append(f"{indent}" + node_repr_line)
        # TODO: (jiaodong) Handle nested containers and other obj types
        else:
            # Normalize to str so the joined output is always well-formed.
            kwargs_lines.append(str(val))

    if kwargs_lines:
        kwargs_line = "{"
        for line in kwargs_lines:
            kwargs_line += f"\n{indent}{line}"
        kwargs_line += f"\n{indent}}}"
    else:
        kwargs_line = "{}"

    return kwargs_line
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _get_options_lines(bound_options):
    """Pretty prints .options() in DAGNode. Only prints non-empty values."""
    if not bound_options:
        return "{}"
    indent = _get_indentation()
    entries = [
        f"{indent}{key}: {value}" for key, value in bound_options.items() if value
    ]
    body = "".join(f"\n{indent}{entry}" for entry in entries)
    return "{" + body + f"\n{indent}}}"
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def _get_other_args_to_resolve_lines(other_args_to_resolve):
    """Pretty prints a DAGNode's `other_args_to_resolve` mapping.

    DAGNode values are expanded onto their own indented lines below the key;
    any other value is rendered on a single `key: value` line.
    """
    if not other_args_to_resolve:
        return "{}"
    indent = _get_indentation()
    rendered = []
    for key, val in other_args_to_resolve.items():
        if not isinstance(val, DAGNode):
            rendered.append(f"{indent}{key}: {val}")
            continue
        repr_lines = str(val).split("\n")
        # Key line, then the repr's first line on the next (deeper) line.
        rendered.append(
            f"{indent}{key}:"
            f"{indent}\n"
            f"{indent}{indent}{indent}{repr_lines[0]}"
        )
        rendered.extend(f"{indent}{indent}{line}" for line in repr_lines[1:])

    body = "".join(f"\n{indent}{line}" for line in rendered)
    return "{" + body + f"\n{indent}}}"
|
.venv/lib/python3.11/site-packages/ray/dag/function_node.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
import ray
|
| 5 |
+
from ray.dag.dag_node import DAGNode
|
| 6 |
+
from ray.dag.format_utils import get_dag_node_str
|
| 7 |
+
from ray.util.annotations import DeveloperAPI
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@DeveloperAPI
class FunctionNode(DAGNode):
    """Represents a bound task node in a Ray task DAG."""

    def __init__(
        self,
        func_body,
        func_args,
        func_kwargs,
        func_options,
        other_args_to_resolve=None,
    ):
        # Keep the raw Python function; it is wrapped with ray.remote()
        # lazily at execution time.
        self._body = func_body
        super().__init__(
            func_args,
            func_kwargs,
            func_options,
            other_args_to_resolve=other_args_to_resolve,
        )

    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ):
        """Return a new FunctionNode sharing this node's function body."""
        return FunctionNode(
            self._body,
            new_args,
            new_kwargs,
            new_options,
            other_args_to_resolve=new_other_args_to_resolve,
        )

    def _execute_impl(self, *args, **kwargs):
        """Executor of FunctionNode by ray.remote().

        Args and kwargs are to match base class signature, but not in the
        implementation. All args and kwargs should be resolved and replaced
        with value in bound_args and bound_kwargs via bottom-up recursion when
        current node is executed.
        """
        remote_func = ray.remote(self._body).options(**self._bound_options)
        return remote_func.remote(*self._bound_args, **self._bound_kwargs)

    def __str__(self) -> str:
        return get_dag_node_str(self, str(self._body))
|
.venv/lib/python3.11/site-packages/ray/dag/input_node.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Union, Optional
|
| 2 |
+
|
| 3 |
+
from ray.dag import DAGNode
|
| 4 |
+
from ray.dag.format_utils import get_dag_node_str
|
| 5 |
+
from ray.experimental.gradio_utils import type_to_string
|
| 6 |
+
from ray.util.annotations import DeveloperAPI
|
| 7 |
+
|
| 8 |
+
IN_CONTEXT_MANAGER = "__in_context_manager__"
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@DeveloperAPI
class InputNode(DAGNode):
    r"""Ray dag node used in DAG building API to mark entrypoints of a DAG.

    Should only be function or class method. A DAG can have multiple
    entrypoints, but only one instance of InputNode exists per DAG, shared
    among all DAGNodes.

    Example:

    .. code-block::

                    m1.forward
                   /          \
        dag_input             ensemble -> dag_output
                   \          /
                    m2.forward

    In this pipeline, each user input is broadcasted to both m1.forward and
    m2.forward as first stop of the DAG, and authored like

    .. code-block:: python

        import ray

        @ray.remote
        class Model:
            def __init__(self, val):
                self.val = val
            def forward(self, input):
                return self.val * input

        @ray.remote
        def combine(a, b):
            return a + b

        with InputNode() as dag_input:
            m1 = Model.bind(1)
            m2 = Model.bind(2)
            m1_output = m1.forward.bind(dag_input[0])
            m2_output = m2.forward.bind(dag_input.x)
            ray_dag = combine.bind(m1_output, m2_output)

        # Pass mix of args and kwargs as input.
        ray_dag.execute(1, x=2)  # 1 sent to m1, 2 sent to m2

        # Alternatively user can also pass single data object, list or dict
        # and access them via list index, object attribute or dict key str.
        ray_dag.execute(UserDataObject(m1=1, m2=2))
        # dag_input.m1, dag_input.m2
        ray_dag.execute([1, 2])
        # dag_input[0], dag_input[1]
        ray_dag.execute({"m1": 1, "m2": 2})
        # dag_input["m1"], dag_input["m2"]
    """

    def __init__(
        self,
        *args,
        input_type: Optional[Union[type, Dict[Union[int, str], type]]] = None,
        _other_args_to_resolve=None,
        **kwargs,
    ):
        """InputNode should only take attributes of validating and converting
        input data rather than the input data itself. User input should be
        provided via `ray_dag.execute(user_input)`.

        Args:
            input_type: Describes the data type of inputs user will be giving.
                - if given through singular InputNode: type of InputNode
                - if given through InputAttributeNodes: map of key -> type
                Used when deciding what Gradio block to represent the input
                nodes with.
            _other_args_to_resolve: Internal only to keep InputNode's execution
                context throughput pickling, replacement and serialization.
                User should not use or pass this field.
        """
        if len(args) != 0 or len(kwargs) != 0:
            raise ValueError("InputNode should not take any args or kwargs.")

        # Cache of InputAttributeNodes keyed by accessor key, so repeated
        # dag_input[k] / dag_input.k accesses return the same node object.
        self.input_attribute_nodes = {}

        self.input_type = input_type
        # A singular (non-dict) type applies to the whole input; record its
        # string form so downstream visualization can read it back.
        if input_type is not None and isinstance(input_type, type):
            if _other_args_to_resolve is None:
                _other_args_to_resolve = {}
            _other_args_to_resolve["result_type_string"] = type_to_string(input_type)

        super().__init__([], {}, {}, other_args_to_resolve=_other_args_to_resolve)

    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ):
        # Only the resolved-args context survives a copy; an InputNode has no
        # bound args/kwargs/options of its own.
        return InputNode(_other_args_to_resolve=new_other_args_to_resolve)

    def _execute_impl(self, *args, **kwargs):
        """Executor of InputNode."""
        # Catch and assert singleton context at dag execution time.
        assert self._in_context_manager(), (
            "InputNode is a singleton instance that should be only used in "
            "context manager for dag building and execution. See the docstring "
            "of class InputNode for examples."
        )
        # If user only passed in one value, for simplicity we just return it.
        if len(args) == 1 and len(kwargs) == 0:
            return args[0]

        # Multiple args/kwargs are wrapped so InputAttributeNodes can index
        # into them uniformly.
        return DAGInputData(*args, **kwargs)

    def _in_context_manager(self) -> bool:
        """Return if InputNode is created in context manager."""
        if (
            not self._bound_other_args_to_resolve
            or IN_CONTEXT_MANAGER not in self._bound_other_args_to_resolve
        ):
            return False
        else:
            return self._bound_other_args_to_resolve[IN_CONTEXT_MANAGER]

    def set_context(self, key: str, val: Any):
        """Set field in parent DAGNode attribute that can be resolved in both
        pickle and JSON serialization
        """
        self._bound_other_args_to_resolve[key] = val

    def __str__(self) -> str:
        return get_dag_node_str(self, "__InputNode__")

    def __getattr__(self, key: str):
        # NOTE(review): __getattr__ only fires after normal attribute lookup
        # fails; if self.input_attribute_nodes is ever missing (e.g. mid
        # unpickling before __init__ runs) this would recurse — TODO confirm.
        assert isinstance(
            key, str
        ), "Please only access dag input attributes with str key."
        if key not in self.input_attribute_nodes:
            self.input_attribute_nodes[key] = InputAttributeNode(
                self, key, "__getattr__"
            )
        return self.input_attribute_nodes[key]

    def __getitem__(self, key: Union[int, str]) -> Any:
        assert isinstance(key, (str, int)), (
            "Please only use int index or str as first-level key to "
            "access fields of dag input."
        )

        # Propagate the per-key declared type (if any) to the attribute node
        # as a string, for downstream visualization.
        input_type = None
        if self.input_type is not None and key in self.input_type:
            input_type = type_to_string(self.input_type[key])

        if key not in self.input_attribute_nodes:
            self.input_attribute_nodes[key] = InputAttributeNode(
                self, key, "__getitem__", input_type
            )
        return self.input_attribute_nodes[key]

    def __enter__(self):
        # Mark this node as context-managed so _execute_impl's assertion
        # passes at execution time.
        self.set_context(IN_CONTEXT_MANAGER, True)
        return self

    def __exit__(self, *args):
        pass

    def get_result_type(self) -> str:
        """Get type of the output of this DAGNode.

        Generated by ray.experimental.gradio_utils.type_to_string().
        """
        # Implicitly returns None when no type string was recorded.
        if "result_type_string" in self._bound_other_args_to_resolve:
            return self._bound_other_args_to_resolve["result_type_string"]
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
@DeveloperAPI
class InputAttributeNode(DAGNode):
    """Represents partial access of user input based on an index (int),
    object attribute or dict key (str).

    Examples:

    .. code-block:: python

        with InputNode() as dag_input:
            a = dag_input[0]
            b = dag_input.x
            ray_dag = add.bind(a, b)

        # This makes a = 1 and b = 2
        ray_dag.execute(1, x=2)

        with InputNode() as dag_input:
            a = dag_input[0]
            b = dag_input[1]
            ray_dag = add.bind(a, b)

        # This makes a = 2 and b = 3
        ray_dag.execute(2, 3)

        # Alternatively, you can input a single object
        # and the inputs are automatically indexed from the object:
        # This makes a = 2 and b = 3
        ray_dag.execute([2, 3])
    """

    def __init__(
        self,
        dag_input_node: InputNode,
        key: Union[int, str],
        accessor_method: str,
        input_type: Optional[str] = None,
    ):
        self._dag_input_node = dag_input_node
        self._key = key
        # Either "__getattr__" or "__getitem__", recording how the user
        # accessed this field on the InputNode.
        self._accessor_method = accessor_method
        super().__init__(
            [],
            {},
            {},
            {
                "dag_input_node": dag_input_node,
                "key": key,
                "accessor_method": accessor_method,
                # Type of the input tied to this node. Used by
                # gradio_visualize_graph.GraphVisualizer to determine which Gradio
                # component should be used for this node.
                "result_type_string": input_type,
            },
        )

    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ):
        # All construction state lives in other_args_to_resolve; positional
        # args/kwargs/options are always empty for this node type.
        return InputAttributeNode(
            new_other_args_to_resolve["dag_input_node"],
            new_other_args_to_resolve["key"],
            new_other_args_to_resolve["accessor_method"],
            new_other_args_to_resolve["result_type_string"],
        )

    def _execute_impl(self, *args, **kwargs):
        """Executor of InputAttributeNode.

        Args and kwargs are to match base class signature, but not in the
        implementation. All args and kwargs should be resolved and replaced
        with value in bound_args and bound_kwargs via bottom-up recursion when
        current node is executed.
        """

        # By execution time the parent InputNode has been resolved to either
        # a DAGInputData wrapper (multi-arg execute) or the raw user object
        # (single-arg execute).
        if isinstance(self._dag_input_node, DAGInputData):
            return self._dag_input_node[self._key]
        else:
            # dag.execute() is called with only one arg, thus when an
            # InputAttributeNode is executed, its dependent InputNode is
            # resolved with original user input python object.
            user_input_python_object = self._dag_input_node
            if isinstance(self._key, str):
                # NOTE(review): a str key with an accessor_method other than
                # the two handled here falls through and returns None
                # implicitly — presumably unreachable; confirm.
                if self._accessor_method == "__getitem__":
                    return user_input_python_object[self._key]
                elif self._accessor_method == "__getattr__":
                    return getattr(user_input_python_object, self._key)
            elif isinstance(self._key, int):
                return user_input_python_object[self._key]
            else:
                raise ValueError(
                    "Please only use int index or str as first-level key to "
                    "access fields of dag input."
                )

    def __str__(self) -> str:
        return get_dag_node_str(self, f'["{self._key}"]')

    def get_result_type(self) -> str:
        """Get type of the output of this DAGNode.

        Generated by ray.experimental.gradio_utils.type_to_string().
        """
        # Implicitly returns None when no type string was recorded.
        if "result_type_string" in self._bound_other_args_to_resolve:
            return self._bound_other_args_to_resolve["result_type_string"]

    @property
    def key(self) -> Union[int, str]:
        # Public accessor for the index/attribute key this node reads.
        return self._key
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
@DeveloperAPI
class DAGInputData:
    """If user passed multiple args and kwargs directly to dag.execute(), we
    generate this wrapper for all user inputs as one object, accessible via
    list index or object attribute key.
    """

    def __init__(self, *args, **kwargs):
        # Positional inputs are addressable by int index, keyword inputs by
        # their str name.
        self._args = list(args)
        self._kwargs = kwargs

    def __getitem__(self, key: Union[int, str]) -> Any:
        # Access list args by index.
        if isinstance(key, int):
            return self._args[key]
        # Access kwarg by key.
        if isinstance(key, str):
            return self._kwargs[key]
        raise ValueError(
            "Please only use int index or str as first-level key to "
            "access fields of dag input."
        )
|
.venv/lib/python3.11/site-packages/ray/dag/output_node.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ray
|
| 2 |
+
from typing import Any, Dict, List, Union, Tuple
|
| 3 |
+
|
| 4 |
+
from ray.dag import DAGNode
|
| 5 |
+
from ray.dag.format_utils import get_dag_node_str
|
| 6 |
+
from ray.util.annotations import DeveloperAPI
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@DeveloperAPI
class MultiOutputNode(DAGNode):
    """Ray dag node used in DAG building API to mark the endpoint of DAG"""

    def __init__(
        self,
        args: Union[List[DAGNode], Tuple[DAGNode]],
        other_args_to_resolve: Dict[str, Any] = None,
    ):
        # Tuples are accepted for convenience and normalized to a list;
        # anything else is rejected up front.
        bound = list(args) if isinstance(args, tuple) else args
        if not isinstance(bound, list):
            raise ValueError(f"Invalid input type for `args`, {type(args)}.")
        super().__init__(
            bound,
            {},
            {},
            other_args_to_resolve=other_args_to_resolve or {},
        )

    def _execute_impl(
        self, *args, **kwargs
    ) -> Union[ray.ObjectRef, "ray.actor.ActorHandle"]:
        # By execution time the bound args have been resolved bottom-up, so
        # simply return them as the DAG's outputs.
        return self._bound_args

    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ) -> "DAGNode":
        """Return a copy of this node with the given new args."""
        return MultiOutputNode(new_args, new_other_args_to_resolve)

    def __str__(self) -> str:
        return get_dag_node_str(self, "__MultiOutputNode__")
|
.venv/lib/python3.11/site-packages/ray/dag/py_obj_scanner.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
from typing import Any, Dict, Generic, List, Tuple, Type, TypeVar, Union
|
| 3 |
+
|
| 4 |
+
import pickle # noqa: F401
|
| 5 |
+
|
| 6 |
+
import ray
|
| 7 |
+
from ray.dag.base import DAGNodeBase
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Used in deserialization hooks to reference scanner instances.
|
| 11 |
+
_instances: Dict[int, "_PyObjScanner"] = {}
|
| 12 |
+
|
| 13 |
+
# Generic types for the scanner to transform from and to.
|
| 14 |
+
SourceType = TypeVar("SourceType")
|
| 15 |
+
TransformedType = TypeVar("TransformedType")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _get_node(instance_id: int, node_index: int) -> SourceType:
    """Get the node instance.

    Note: This function should be static and globally importable,
    otherwise the serialization overhead would be very significant.
    """
    # Look up the live scanner by id and map the serialized placeholder index
    # back to its replacement object.
    scanner = _instances[instance_id]
    return scanner._replace_index(node_index)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class _PyObjScanner(ray.cloudpickle.CloudPickler, Generic[SourceType, TransformedType]):
    """Utility to find and replace the `source_type` in Python objects.

    `source_type` can either be a single type or a tuple of multiple types.

    The caller must first call `find_nodes()`, then compute a replacement table and
    pass it to `replace_nodes`.

    This uses cloudpickle under the hood, so all sub-objects that are not `source_type`
    must be serializable.

    Args:
        source_type: the type(s) of object to find and replace. Default to DAGNodeBase.
    """

    def __init__(self, source_type: Union[Type, Tuple] = DAGNodeBase):
        self.source_type = source_type
        # Buffer to keep intermediate serialized state.
        self._buf = io.BytesIO()
        # List of top-level SourceType found during the serialization pass.
        self._found = None
        # List of other objects found during the serialization pass.
        # This is used to store references to objects so they won't be
        # serialized by cloudpickle.
        self._objects = []
        # Replacement table to consult during deserialization.
        self._replace_table: Dict[SourceType, TransformedType] = None
        # Register in the module-level map so _get_node can find this
        # instance by id during deserialization.
        _instances[id(self)] = self
        super().__init__(self._buf)

    def reducer_override(self, obj):
        """Hook for reducing objects.

        Objects of `self.source_type` are saved to `self._found` and a global map so
        they can later be replaced.

        All other objects fall back to the default `CloudPickler` serialization.
        """
        if isinstance(obj, self.source_type):
            # Replace the object with a placeholder that deserializes via
            # _get_node(id(self), index) — the replacement hook.
            index = len(self._found)
            self._found.append(obj)
            return _get_node, (id(self), index)

        return super().reducer_override(obj)

    def find_nodes(self, obj: Any) -> List[SourceType]:
        """
        Serialize `obj` and store all instances of `source_type` found in `_found`.

        Args:
            obj: The object to scan for `source_type`.
        Returns:
            A list of all instances of `source_type` found in `obj`.
        """
        assert (
            self._found is None
        ), "find_nodes cannot be called twice on the same PyObjScanner instance."
        self._found = []
        self._objects = []
        # dump() writes into self._buf and triggers reducer_override for
        # every matching object along the way.
        self.dump(obj)
        return self._found

    def replace_nodes(self, table: Dict[SourceType, TransformedType]) -> Any:
        """Replace previously found DAGNodes per the given table."""
        assert self._found is not None, "find_nodes must be called first"
        self._replace_table = table
        # Rewind and deserialize; each placeholder calls back into
        # _replace_index via _get_node, substituting the table entries.
        self._buf.seek(0)
        return pickle.load(self._buf)

    def _replace_index(self, i: int) -> SourceType:
        # Map the i-th found source object to its replacement.
        return self._replace_table[self._found[i]]

    def clear(self):
        """Clear the scanner from the _instances"""
        if id(self) in _instances:
            del _instances[id(self)]

    def __del__(self):
        # Best-effort cleanup so the global registry does not leak scanners.
        self.clear()
|
.venv/lib/python3.11/site-packages/ray/dag/utils.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict
|
| 2 |
+
|
| 3 |
+
from ray.dag import (
|
| 4 |
+
DAGNode,
|
| 5 |
+
InputNode,
|
| 6 |
+
InputAttributeNode,
|
| 7 |
+
FunctionNode,
|
| 8 |
+
ClassNode,
|
| 9 |
+
ClassMethodNode,
|
| 10 |
+
MultiOutputNode,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class _DAGNodeNameGenerator(object):
|
| 15 |
+
"""
|
| 16 |
+
Generate unique suffix for each given Node in the DAG.
|
| 17 |
+
Apply monotonic increasing id suffix for duplicated names.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(self):
|
| 21 |
+
self.name_to_suffix: Dict[str, int] = dict()
|
| 22 |
+
|
| 23 |
+
def get_node_name(self, node: DAGNode):
|
| 24 |
+
# InputNode should be unique.
|
| 25 |
+
if isinstance(node, InputNode):
|
| 26 |
+
return "INPUT_NODE"
|
| 27 |
+
if isinstance(node, MultiOutputNode):
|
| 28 |
+
return "MultiOutputNode"
|
| 29 |
+
# InputAttributeNode suffixes should match the user-defined key.
|
| 30 |
+
elif isinstance(node, InputAttributeNode):
|
| 31 |
+
return f"INPUT_ATTRIBUTE_NODE_{node._key}"
|
| 32 |
+
|
| 33 |
+
# As class, method, and function nodes may have duplicated names,
|
| 34 |
+
# generate unique suffixes for such nodes.
|
| 35 |
+
if isinstance(node, ClassMethodNode):
|
| 36 |
+
node_name = node.get_options().get("name", None) or node._method_name
|
| 37 |
+
elif isinstance(node, (ClassNode, FunctionNode)):
|
| 38 |
+
node_name = node.get_options().get("name", None) or node._body.__name__
|
| 39 |
+
# we use instance class name check here to avoid importing ServeNodes as
|
| 40 |
+
# serve components are not included in Ray Core.
|
| 41 |
+
elif type(node).__name__ in ("DeploymentNode", "DeploymentFunctionNode"):
|
| 42 |
+
node_name = node.get_deployment_name()
|
| 43 |
+
elif type(node).__name__ == "DeploymentFunctionExecutorNode":
|
| 44 |
+
node_name = node._deployment_function_handle.deployment_name
|
| 45 |
+
else:
|
| 46 |
+
raise ValueError(
|
| 47 |
+
"get_node_name() should only be called on DAGNode instances."
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
if node_name not in self.name_to_suffix:
|
| 51 |
+
self.name_to_suffix[node_name] = 0
|
| 52 |
+
return node_name
|
| 53 |
+
else:
|
| 54 |
+
self.name_to_suffix[node_name] += 1
|
| 55 |
+
suffix_num = self.name_to_suffix[node_name]
|
| 56 |
+
|
| 57 |
+
return f"{node_name}_{suffix_num}"
|
| 58 |
+
|
| 59 |
+
def reset(self):
|
| 60 |
+
self.name_to_suffix = dict()
|
| 61 |
+
|
| 62 |
+
def __enter__(self):
|
| 63 |
+
return self
|
| 64 |
+
|
| 65 |
+
def __exit__(self, *args):
|
| 66 |
+
self.reset()
|
.venv/lib/python3.11/site-packages/ray/dag/vis_utils.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.dag import DAGNode
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import tempfile
|
| 5 |
+
|
| 6 |
+
from ray.dag.utils import _DAGNodeNameGenerator
|
| 7 |
+
from ray.util.annotations import DeveloperAPI
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@DeveloperAPI
def plot(dag: DAGNode, to_file=None):
    """Render the DAG as an image via pydot/graphviz.

    Args:
        dag: Root node of the DAG to plot.
        to_file: Output path; the extension selects the image format
            (defaults to "png"). When None, a temporary .png file is used.

    Returns:
        An IPython ``display.Image`` when IPython is importable (e.g. inside
        a notebook), otherwise None.
    """
    if to_file is None:
        # NOTE(review): the NamedTemporaryFile stays open for the rest of the
        # function; on Windows, re-opening an open temp file by name can fail
        # — TODO confirm intended platforms.
        tmp_file = tempfile.NamedTemporaryFile(suffix=".png")
        to_file = tmp_file.name
        extension = "png"
    else:
        _, extension = os.path.splitext(to_file)
        if not extension:
            extension = "png"
        else:
            # Strip the leading "." returned by splitext.
            extension = extension[1:]

    graph = _dag_to_dot(dag)
    graph.write(to_file, format=extension)

    # Render the image directly if running inside a Jupyter notebook
    try:
        from IPython import display

        return display.Image(filename=to_file)
    except ImportError:
        pass

    # close temp file if needed
    # NOTE(review): only reached when IPython is absent; when an Image is
    # returned above, the temp file is left for the GC to close and delete.
    try:
        tmp_file.close()
    except NameError:
        pass
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _check_pydot_and_graphviz():
|
| 42 |
+
"""Check if pydot and graphviz are installed.
|
| 43 |
+
|
| 44 |
+
pydot and graphviz are required for plotting. We check this
|
| 45 |
+
during runtime rather than adding them to Ray dependencies.
|
| 46 |
+
|
| 47 |
+
"""
|
| 48 |
+
try:
|
| 49 |
+
import pydot
|
| 50 |
+
except ImportError:
|
| 51 |
+
raise ImportError(
|
| 52 |
+
"pydot is required to plot DAG, " "install it with `pip install pydot`."
|
| 53 |
+
)
|
| 54 |
+
try:
|
| 55 |
+
pydot.Dot.create(pydot.Dot())
|
| 56 |
+
except (OSError, pydot.InvocationException):
|
| 57 |
+
raise ImportError(
|
| 58 |
+
"graphviz is required to plot DAG, "
|
| 59 |
+
"download it from https://graphviz.gitlab.io/download/"
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _get_nodes_and_edges(dag: DAGNode):
|
| 64 |
+
"""Get all unique nodes and edges in the DAG.
|
| 65 |
+
|
| 66 |
+
A basic dfs with memorization to get all unique nodes
|
| 67 |
+
and edges in the DAG.
|
| 68 |
+
Unique nodes will be used to generate unique names,
|
| 69 |
+
while edges will be used to construct the graph.
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
edges = []
|
| 73 |
+
nodes = []
|
| 74 |
+
|
| 75 |
+
def _dfs(node):
|
| 76 |
+
nodes.append(node)
|
| 77 |
+
for child_node in node._get_all_child_nodes():
|
| 78 |
+
edges.append((child_node, node))
|
| 79 |
+
return node
|
| 80 |
+
|
| 81 |
+
dag.apply_recursive(_dfs)
|
| 82 |
+
return nodes, edges
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _dag_to_dot(dag: DAGNode):
    """Create a Dot graph from dag.

    TODO(lchu):
    1. add more Dot configs in kwargs,
    e.g. rankdir, alignment, etc.
    2. add more contents to graph,
    e.g. args, kwargs and options of each node

    """
    # Step 0: check dependencies and init graph
    _check_pydot_and_graphviz()
    import pydot

    # Left-to-right layout so the input node appears first.
    graph = pydot.Dot(rankdir="LR")

    # Step 1: generate unique name for each node in dag
    nodes, edges = _get_nodes_and_edges(dag)
    name_generator = _DAGNodeNameGenerator()
    node_names = {}
    for node in nodes:
        node_names[node] = name_generator.get_node_name(node)

    # Step 2: create graph with all the edges
    # Adding an edge implicitly adds both endpoint nodes to the graph.
    for edge in edges:
        graph.add_edge(pydot.Edge(node_names[edge[0]], node_names[edge[1]]))
    # if there is only one node
    # (no edges were added, so the lone node must be added explicitly)
    if len(nodes) == 1 and len(edges) == 0:
        graph.add_node(pydot.Node(node_names[nodes[0]]))

    return graph
|
.venv/lib/python3.11/site-packages/ray/experimental/channel/__init__.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.experimental.channel.cached_channel import CachedChannel
|
| 2 |
+
from ray.experimental.channel.common import ( # noqa: F401
|
| 3 |
+
AwaitableBackgroundReader,
|
| 4 |
+
AwaitableBackgroundWriter,
|
| 5 |
+
ChannelContext,
|
| 6 |
+
ChannelInterface,
|
| 7 |
+
ChannelOutputType,
|
| 8 |
+
CompiledDAGArgs,
|
| 9 |
+
ReaderInterface,
|
| 10 |
+
SynchronousReader,
|
| 11 |
+
SynchronousWriter,
|
| 12 |
+
WriterInterface,
|
| 13 |
+
)
|
| 14 |
+
from ray.experimental.channel.communicator import Communicator
|
| 15 |
+
from ray.experimental.channel.intra_process_channel import IntraProcessChannel
|
| 16 |
+
from ray.experimental.channel.shared_memory_channel import (
|
| 17 |
+
BufferedSharedMemoryChannel,
|
| 18 |
+
Channel,
|
| 19 |
+
CompositeChannel,
|
| 20 |
+
)
|
| 21 |
+
from ray.experimental.channel.torch_tensor_nccl_channel import TorchTensorNcclChannel
|
| 22 |
+
|
| 23 |
+
__all__ = [
|
| 24 |
+
"AwaitableBackgroundReader",
|
| 25 |
+
"AwaitableBackgroundWriter",
|
| 26 |
+
"CachedChannel",
|
| 27 |
+
"Channel",
|
| 28 |
+
"Communicator",
|
| 29 |
+
"ReaderInterface",
|
| 30 |
+
"SynchronousReader",
|
| 31 |
+
"SynchronousWriter",
|
| 32 |
+
"WriterInterface",
|
| 33 |
+
"ChannelContext",
|
| 34 |
+
"TorchTensorNcclChannel",
|
| 35 |
+
"IntraProcessChannel",
|
| 36 |
+
"CompositeChannel",
|
| 37 |
+
"BufferedSharedMemoryChannel",
|
| 38 |
+
"CompiledDAGArgs",
|
| 39 |
+
]
|