koichi12 commited on
Commit
de7cd93
·
verified ·
1 Parent(s): c9870ab

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .venv/lib/python3.11/site-packages/ray/__pycache__/__init__.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/ray/__pycache__/_version.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/ray/__pycache__/actor.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/ray/__pycache__/client_builder.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/ray/__pycache__/cluster_utils.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/ray/__pycache__/cross_language.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/ray/__pycache__/exceptions.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/ray/__pycache__/job_config.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/ray/__pycache__/remote_function.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/ray/__pycache__/runtime_context.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/ray/__pycache__/setup-dev.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/ray/__pycache__/types.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/ray/core/libjemalloc.so +3 -0
  15. .venv/lib/python3.11/site-packages/ray/dag/__init__.py +46 -0
  16. .venv/lib/python3.11/site-packages/ray/dag/__pycache__/__init__.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/ray/dag/__pycache__/base.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/ray/dag/__pycache__/class_node.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/ray/dag/__pycache__/collective_node.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/ray/dag/__pycache__/conftest.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/ray/dag/__pycache__/constants.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/ray/dag/__pycache__/context.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/ray/dag/__pycache__/dag_node.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/ray/dag/__pycache__/dag_node_operation.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/ray/dag/__pycache__/format_utils.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/ray/dag/__pycache__/function_node.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/ray/dag/__pycache__/output_node.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/ray/dag/__pycache__/py_obj_scanner.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/ray/dag/__pycache__/utils.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/ray/dag/__pycache__/vis_utils.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/ray/dag/base.py +8 -0
  32. .venv/lib/python3.11/site-packages/ray/dag/class_node.py +321 -0
  33. .venv/lib/python3.11/site-packages/ray/dag/collective_node.py +191 -0
  34. .venv/lib/python3.11/site-packages/ray/dag/compiled_dag_node.py +0 -0
  35. .venv/lib/python3.11/site-packages/ray/dag/conftest.py +16 -0
  36. .venv/lib/python3.11/site-packages/ray/dag/constants.py +33 -0
  37. .venv/lib/python3.11/site-packages/ray/dag/context.py +101 -0
  38. .venv/lib/python3.11/site-packages/ray/dag/dag_node.py +622 -0
  39. .venv/lib/python3.11/site-packages/ray/dag/dag_node_operation.py +789 -0
  40. .venv/lib/python3.11/site-packages/ray/dag/dag_operation_future.py +95 -0
  41. .venv/lib/python3.11/site-packages/ray/dag/experimental/__init__.py +0 -0
  42. .venv/lib/python3.11/site-packages/ray/dag/experimental/__pycache__/__init__.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/ray/dag/format_utils.py +155 -0
  44. .venv/lib/python3.11/site-packages/ray/dag/function_node.py +60 -0
  45. .venv/lib/python3.11/site-packages/ray/dag/input_node.py +321 -0
  46. .venv/lib/python3.11/site-packages/ray/dag/output_node.py +45 -0
  47. .venv/lib/python3.11/site-packages/ray/dag/py_obj_scanner.py +105 -0
  48. .venv/lib/python3.11/site-packages/ray/dag/utils.py +66 -0
  49. .venv/lib/python3.11/site-packages/ray/dag/vis_utils.py +115 -0
  50. .venv/lib/python3.11/site-packages/ray/experimental/channel/__init__.py +39 -0
.gitattributes CHANGED
@@ -158,3 +158,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
158
  .venv/lib/python3.11/site-packages/ray/serve/_private/__pycache__/deployment_state.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
159
  .venv/lib/python3.11/site-packages/xgrammar/xgrammar_bindings.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
160
  .venv/lib/python3.11/site-packages/ray/_raylet.so filter=lfs diff=lfs merge=lfs -text
 
 
158
  .venv/lib/python3.11/site-packages/ray/serve/_private/__pycache__/deployment_state.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
159
  .venv/lib/python3.11/site-packages/xgrammar/xgrammar_bindings.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
160
  .venv/lib/python3.11/site-packages/ray/_raylet.so filter=lfs diff=lfs merge=lfs -text
161
+ .venv/lib/python3.11/site-packages/ray/core/libjemalloc.so filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/ray/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (8.47 kB). View file
 
.venv/lib/python3.11/site-packages/ray/__pycache__/_version.cpython-311.pyc ADDED
Binary file (378 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/__pycache__/actor.cpython-311.pyc ADDED
Binary file (70.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/__pycache__/client_builder.cpython-311.pyc ADDED
Binary file (17.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/__pycache__/cluster_utils.cpython-311.pyc ADDED
Binary file (18.8 kB). View file
 
.venv/lib/python3.11/site-packages/ray/__pycache__/cross_language.cpython-311.pyc ADDED
Binary file (5.14 kB). View file
 
.venv/lib/python3.11/site-packages/ray/__pycache__/exceptions.cpython-311.pyc ADDED
Binary file (41.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/__pycache__/job_config.cpython-311.pyc ADDED
Binary file (11.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/__pycache__/remote_function.cpython-311.pyc ADDED
Binary file (21.3 kB). View file
 
.venv/lib/python3.11/site-packages/ray/__pycache__/runtime_context.cpython-311.pyc ADDED
Binary file (25.3 kB). View file
 
.venv/lib/python3.11/site-packages/ray/__pycache__/setup-dev.cpython-311.pyc ADDED
Binary file (7.78 kB). View file
 
.venv/lib/python3.11/site-packages/ray/__pycache__/types.cpython-311.pyc ADDED
Binary file (636 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/core/libjemalloc.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0284919db23f95e692026039838aef89b5964b5cfec4a88acb9b3a9f4a226fd5
3
+ size 885296
.venv/lib/python3.11/site-packages/ray/dag/__init__.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.dag.dag_node import DAGNode
2
+ from ray.dag.function_node import FunctionNode
3
+ from ray.dag.class_node import (
4
+ ClassNode,
5
+ ClassMethodNode,
6
+ )
7
+ from ray.dag.collective_node import CollectiveOutputNode
8
+ from ray.dag.input_node import (
9
+ InputNode,
10
+ InputAttributeNode,
11
+ DAGInputData,
12
+ )
13
+ from ray.dag.output_node import MultiOutputNode
14
+ from ray.dag.dag_operation_future import DAGOperationFuture, GPUFuture
15
+ from ray.dag.constants import (
16
+ PARENT_CLASS_NODE_KEY,
17
+ PREV_CLASS_METHOD_CALL_KEY,
18
+ BIND_INDEX_KEY,
19
+ IS_CLASS_METHOD_OUTPUT_KEY,
20
+ COLLECTIVE_OPERATION_KEY,
21
+ DAGNODE_TYPE_KEY,
22
+ )
23
+ from ray.dag.vis_utils import plot
24
+ from ray.dag.context import DAGContext
25
+
26
+ __all__ = [
27
+ "ClassNode",
28
+ "ClassMethodNode",
29
+ "CollectiveOutputNode",
30
+ "DAGNode",
31
+ "DAGOperationFuture",
32
+ "FunctionNode",
33
+ "GPUFuture",
34
+ "InputNode",
35
+ "InputAttributeNode",
36
+ "DAGInputData",
37
+ "PARENT_CLASS_NODE_KEY",
38
+ "PREV_CLASS_METHOD_CALL_KEY",
39
+ "BIND_INDEX_KEY",
40
+ "IS_CLASS_METHOD_OUTPUT_KEY",
41
+ "COLLECTIVE_OPERATION_KEY",
42
+ "DAGNODE_TYPE_KEY",
43
+ "plot",
44
+ "MultiOutputNode",
45
+ "DAGContext",
46
+ ]
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.42 kB). View file
 
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/base.cpython-311.pyc ADDED
Binary file (683 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/class_node.cpython-311.pyc ADDED
Binary file (14.2 kB). View file
 
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/collective_node.cpython-311.pyc ADDED
Binary file (9.93 kB). View file
 
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/conftest.cpython-311.pyc ADDED
Binary file (794 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/constants.cpython-311.pyc ADDED
Binary file (1.07 kB). View file
 
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/context.cpython-311.pyc ADDED
Binary file (5.7 kB). View file
 
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/dag_node.cpython-311.pyc ADDED
Binary file (26.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/dag_node_operation.cpython-311.pyc ADDED
Binary file (34.7 kB). View file
 
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/format_utils.cpython-311.pyc ADDED
Binary file (7.13 kB). View file
 
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/function_node.cpython-311.pyc ADDED
Binary file (2.82 kB). View file
 
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/output_node.cpython-311.pyc ADDED
Binary file (2.82 kB). View file
 
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/py_obj_scanner.cpython-311.pyc ADDED
Binary file (5.6 kB). View file
 
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/utils.cpython-311.pyc ADDED
Binary file (3.43 kB). View file
 
.venv/lib/python3.11/site-packages/ray/dag/__pycache__/vis_utils.cpython-311.pyc ADDED
Binary file (4.73 kB). View file
 
.venv/lib/python3.11/site-packages/ray/dag/base.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """This module defines the base class for object scanning and gets rid of
2
+ reference cycles."""
3
+ from ray.util.annotations import DeveloperAPI
4
+
5
+
6
+ @DeveloperAPI
7
+ class DAGNodeBase:
8
+ """Common base class for a node in a Ray task graph."""
.venv/lib/python3.11/site-packages/ray/dag/class_node.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from weakref import ReferenceType
2
+
3
+ import ray
4
+ from ray.dag.dag_node import DAGNode
5
+ from ray.dag.input_node import InputNode
6
+ from ray.dag.format_utils import get_dag_node_str
7
+ from ray.dag.constants import (
8
+ PARENT_CLASS_NODE_KEY,
9
+ PREV_CLASS_METHOD_CALL_KEY,
10
+ BIND_INDEX_KEY,
11
+ IS_CLASS_METHOD_OUTPUT_KEY,
12
+ )
13
+ from ray.util.annotations import DeveloperAPI
14
+
15
+ from typing import Any, Dict, List, Union, Tuple, Optional
16
+
17
+
18
+ @DeveloperAPI
19
+ class ClassNode(DAGNode):
20
+ """Represents an actor creation in a Ray task DAG."""
21
+
22
+ def __init__(
23
+ self,
24
+ cls,
25
+ cls_args,
26
+ cls_kwargs,
27
+ cls_options,
28
+ other_args_to_resolve=None,
29
+ ):
30
+ self._body = cls
31
+ self._last_call: Optional["ClassMethodNode"] = None
32
+ super().__init__(
33
+ cls_args,
34
+ cls_kwargs,
35
+ cls_options,
36
+ other_args_to_resolve=other_args_to_resolve,
37
+ )
38
+
39
+ if self._contains_input_node():
40
+ raise ValueError(
41
+ "InputNode handles user dynamic input the the DAG, and "
42
+ "cannot be used as args, kwargs, or other_args_to_resolve "
43
+ "in ClassNode constructor because it is not available at "
44
+ "class construction or binding time."
45
+ )
46
+
47
+ def _copy_impl(
48
+ self,
49
+ new_args: List[Any],
50
+ new_kwargs: Dict[str, Any],
51
+ new_options: Dict[str, Any],
52
+ new_other_args_to_resolve: Dict[str, Any],
53
+ ):
54
+ return ClassNode(
55
+ self._body,
56
+ new_args,
57
+ new_kwargs,
58
+ new_options,
59
+ other_args_to_resolve=new_other_args_to_resolve,
60
+ )
61
+
62
+ def _execute_impl(self, *args, **kwargs):
63
+ """Executor of ClassNode by ray.remote()
64
+
65
+ Args and kwargs are to match base class signature, but not in the
66
+ implementation. All args and kwargs should be resolved and replaced
67
+ with value in bound_args and bound_kwargs via bottom-up recursion when
68
+ current node is executed.
69
+ """
70
+ return (
71
+ ray.remote(self._body)
72
+ .options(**self._bound_options)
73
+ .remote(*self._bound_args, **self._bound_kwargs)
74
+ )
75
+
76
+ def _contains_input_node(self) -> bool:
77
+ """Check if InputNode is used in children DAGNodes with current node
78
+ as the root.
79
+ """
80
+ children_dag_nodes = self._get_all_child_nodes()
81
+ for child in children_dag_nodes:
82
+ if isinstance(child, InputNode):
83
+ return True
84
+ return False
85
+
86
+ def __getattr__(self, method_name: str):
87
+ # User trying to call .bind() without a bind class method
88
+ if method_name == "bind" and "bind" not in dir(self._body):
89
+ raise AttributeError(f".bind() cannot be used again on {type(self)} ")
90
+ # Raise an error if the method is invalid.
91
+ getattr(self._body, method_name)
92
+ call_node = _UnboundClassMethodNode(self, method_name, {})
93
+ return call_node
94
+
95
+ def __str__(self) -> str:
96
+ return get_dag_node_str(self, str(self._body))
97
+
98
+
99
+ class _UnboundClassMethodNode(object):
100
+ def __init__(self, actor: ClassNode, method_name: str, options: dict):
101
+ # TODO(sang): Theoretically, We should use weakref cuz it is
102
+ # a circular dependency but when I used weakref, it fails
103
+ # because we cannot serialize the weakref.
104
+ self._actor = actor
105
+ self._method_name = method_name
106
+ self._options = options
107
+
108
+ def bind(self, *args, **kwargs):
109
+ other_args_to_resolve = {
110
+ PARENT_CLASS_NODE_KEY: self._actor,
111
+ PREV_CLASS_METHOD_CALL_KEY: self._actor._last_call,
112
+ }
113
+
114
+ node = ClassMethodNode(
115
+ self._method_name,
116
+ args,
117
+ kwargs,
118
+ self._options,
119
+ other_args_to_resolve=other_args_to_resolve,
120
+ )
121
+ self._actor._last_call = node
122
+ return node
123
+
124
+ def __getattr__(self, attr: str):
125
+ if attr == "remote":
126
+ raise AttributeError(
127
+ ".remote() cannot be used on ClassMethodNodes. Use .bind() instead "
128
+ "to express an symbolic actor call."
129
+ )
130
+ else:
131
+ return self.__getattribute__(attr)
132
+
133
+ def options(self, **options):
134
+ self._options = options
135
+ return self
136
+
137
+
138
+ class _ClassMethodOutput:
139
+ """Represents a class method output in a Ray function DAG."""
140
+
141
+ def __init__(self, class_method_call: "ClassMethodNode", output_idx: int):
142
+ # The upstream class method call that returns multiple values.
143
+ self._class_method_call = class_method_call
144
+ # The output index of the return value from the upstream class method call.
145
+ self._output_idx = output_idx
146
+
147
+ @property
148
+ def class_method_call(self) -> "ClassMethodNode":
149
+ return self._class_method_call
150
+
151
+ @property
152
+ def output_idx(self) -> int:
153
+ return self._output_idx
154
+
155
+
156
+ @DeveloperAPI
157
+ class ClassMethodNode(DAGNode):
158
+ """Represents an actor method invocation in a Ray function DAG."""
159
+
160
+ def __init__(
161
+ self,
162
+ method_name: str,
163
+ method_args: Tuple[Any],
164
+ method_kwargs: Dict[str, Any],
165
+ method_options: Dict[str, Any],
166
+ other_args_to_resolve: Dict[str, Any],
167
+ ):
168
+ self._bound_args = method_args or []
169
+ self._bound_kwargs = method_kwargs or {}
170
+ self._bound_options = method_options or {}
171
+ self._method_name: str = method_name
172
+ # Parse other_args_to_resolve and assign to variables
173
+ self._parent_class_node: Union[
174
+ ClassNode, ReferenceType["ray._private.actor.ActorHandle"]
175
+ ] = other_args_to_resolve.get(PARENT_CLASS_NODE_KEY)
176
+ # Used to track lineage of ClassMethodCall to preserve deterministic
177
+ # submission and execution order.
178
+ self._prev_class_method_call: Optional[
179
+ ClassMethodNode
180
+ ] = other_args_to_resolve.get(PREV_CLASS_METHOD_CALL_KEY, None)
181
+ # The index/order when bind() is called on this class method
182
+ self._bind_index: Optional[int] = other_args_to_resolve.get(
183
+ BIND_INDEX_KEY, None
184
+ )
185
+ # Represent if the ClassMethodNode is a class method output. If True,
186
+ # the node is a placeholder for a return value from the ClassMethodNode
187
+ # that returns multiple values. If False, the node is a class method call.
188
+ self._is_class_method_output: bool = other_args_to_resolve.get(
189
+ IS_CLASS_METHOD_OUTPUT_KEY, False
190
+ )
191
+ # Represents the return value from the upstream ClassMethodNode that
192
+ # returns multiple values. If the node is a class method call, this is None.
193
+ self._class_method_output: Optional[_ClassMethodOutput] = None
194
+ if self._is_class_method_output:
195
+ # Set the upstream ClassMethodNode and the output index of the return
196
+ # value from `method_args`.
197
+ self._class_method_output = _ClassMethodOutput(
198
+ method_args[0], method_args[1]
199
+ )
200
+
201
+ # The actor creation task dependency is encoded as the first argument,
202
+ # and the ordering dependency as the second, which ensures they are
203
+ # executed prior to this node.
204
+ super().__init__(
205
+ method_args,
206
+ method_kwargs,
207
+ method_options,
208
+ other_args_to_resolve=other_args_to_resolve,
209
+ )
210
+
211
+ def _copy_impl(
212
+ self,
213
+ new_args: List[Any],
214
+ new_kwargs: Dict[str, Any],
215
+ new_options: Dict[str, Any],
216
+ new_other_args_to_resolve: Dict[str, Any],
217
+ ):
218
+ return ClassMethodNode(
219
+ self._method_name,
220
+ new_args,
221
+ new_kwargs,
222
+ new_options,
223
+ other_args_to_resolve=new_other_args_to_resolve,
224
+ )
225
+
226
+ def _execute_impl(self, *args, **kwargs):
227
+ """Executor of ClassMethodNode by ray.remote()
228
+
229
+ Args and kwargs are to match base class signature, but not in the
230
+ implementation. All args and kwargs should be resolved and replaced
231
+ with value in bound_args and bound_kwargs via bottom-up recursion when
232
+ current node is executed.
233
+ """
234
+ if self.is_class_method_call:
235
+ method_body = getattr(self._parent_class_node, self._method_name)
236
+ # Execute with bound args.
237
+ return method_body.options(**self._bound_options).remote(
238
+ *self._bound_args,
239
+ **self._bound_kwargs,
240
+ )
241
+ else:
242
+ assert self._class_method_output is not None
243
+ return self._bound_args[0][self._class_method_output.output_idx]
244
+
245
+ def __str__(self) -> str:
246
+ return get_dag_node_str(self, f"{self._method_name}()")
247
+
248
+ def __repr__(self) -> str:
249
+ return self.__str__()
250
+
251
+ def get_method_name(self) -> str:
252
+ return self._method_name
253
+
254
+ def _get_bind_index(self) -> int:
255
+ return self._bind_index
256
+
257
+ def _get_remote_method(self, method_name):
258
+ method_body = getattr(self._parent_class_node, method_name)
259
+ return method_body
260
+
261
+ def _get_actor_handle(self) -> Optional["ray.actor.ActorHandle"]:
262
+ if not isinstance(self._parent_class_node, ray.actor.ActorHandle):
263
+ return None
264
+ return self._parent_class_node
265
+
266
+ @property
267
+ def num_returns(self) -> int:
268
+ """
269
+ Return the number of return values from the class method call. If the
270
+ node is a class method output, return the number of return values from
271
+ the upstream class method call.
272
+ """
273
+
274
+ if self.is_class_method_call:
275
+ num_returns = self._bound_options.get("num_returns", None)
276
+ if num_returns is None:
277
+ method = self._get_remote_method(self._method_name)
278
+ num_returns = method.__getstate__()["num_returns"]
279
+ return num_returns
280
+ else:
281
+ assert self._class_method_output is not None
282
+ return self._class_method_output.class_method_call.num_returns
283
+
284
+ @property
285
+ def is_class_method_call(self) -> bool:
286
+ """
287
+ Return True if the node is a class method call, False if the node is a
288
+ class method output.
289
+ """
290
+ return not self._is_class_method_output
291
+
292
+ @property
293
+ def is_class_method_output(self) -> bool:
294
+ """
295
+ Return True if the node is a class method output, False if the node is a
296
+ class method call.
297
+ """
298
+ return self._is_class_method_output
299
+
300
+ @property
301
+ def class_method_call(self) -> Optional["ClassMethodNode"]:
302
+ """
303
+ Return the upstream class method call that returns multiple values. If
304
+ the node is a class method output, return None.
305
+ """
306
+
307
+ if self._class_method_output is None:
308
+ return None
309
+ return self._class_method_output.class_method_call
310
+
311
+ @property
312
+ def output_idx(self) -> Optional[int]:
313
+ """
314
+ Return the output index of the return value from the upstream class
315
+ method call that returns multiple values. If the node is a class method
316
+ call, return None.
317
+ """
318
+
319
+ if self._class_method_output is None:
320
+ return None
321
+ return self._class_method_output.output_idx
.venv/lib/python3.11/site-packages/ray/dag/collective_node.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Union, Tuple, Optional, TYPE_CHECKING
2
+
3
+ if TYPE_CHECKING:
4
+ import torch
5
+
6
+ import ray
7
+ from ray.dag import (
8
+ DAGNode,
9
+ ClassMethodNode,
10
+ )
11
+ from ray.dag.constants import COLLECTIVE_OPERATION_KEY
12
+ from ray.experimental.channel import ChannelContext
13
+ from ray.experimental.channel.torch_tensor_nccl_channel import _init_communicator
14
+ from ray.experimental.channel.torch_tensor_type import Communicator, TorchTensorType
15
+ from ray.experimental.util.types import _CollectiveOp, ReduceOp
16
+ from ray.util.annotations import DeveloperAPI
17
+
18
+
19
+ class _CollectiveOperation:
20
+ """
21
+ Represent metadata for a NCCL collective operation.
22
+
23
+ Args:
24
+ input_nodes: A list of input nodes to the collective operation.
25
+ op: The collective operation to perform.
26
+ transport: The transport to use for the collective operation.
27
+
28
+ Requirements:
29
+ 1. Input nodes are unique.
30
+ 2. Actor handles are unique.
31
+ 3. Actor handles match the custom NCCL group if specified.
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ input_nodes: List[DAGNode],
37
+ op: _CollectiveOp,
38
+ transport: Optional[Union[str, Communicator]] = None,
39
+ ):
40
+ if len(input_nodes) == 0:
41
+ raise ValueError("Expected input nodes for a collective operation")
42
+ if len(set(input_nodes)) != len(input_nodes):
43
+ raise ValueError("Expected unique input nodes for a collective operation")
44
+
45
+ self._actor_handles: List["ray.actor.ActorHandle"] = []
46
+ for input_node in input_nodes:
47
+ actor_handle = input_node._get_actor_handle()
48
+ if actor_handle is None:
49
+ raise ValueError("Expected an actor handle from the input node")
50
+ self._actor_handles.append(actor_handle)
51
+ if len(set(self._actor_handles)) != len(self._actor_handles):
52
+ invalid_input_nodes = [
53
+ input_node
54
+ for input_node in input_nodes
55
+ if self._actor_handles.count(input_node._get_actor_handle()) > 1
56
+ ]
57
+ raise ValueError(
58
+ "Expected unique actor handles for a collective operation, "
59
+ "but found duplicate actor handles from input nodes: "
60
+ f"{invalid_input_nodes}"
61
+ )
62
+
63
+ self._op = op
64
+ if not isinstance(self._op, ReduceOp):
65
+ raise NotImplementedError("Only ReduceOp is implemented")
66
+ if transport is None:
67
+ transport = TorchTensorType.NCCL
68
+ self._type_hint = TorchTensorType(transport=transport, _direct_return=True)
69
+ if isinstance(transport, Communicator):
70
+ if set(transport.get_actor_handles()) != set(self._actor_handles):
71
+ raise ValueError(
72
+ "Expected actor handles to match the custom NCCL group"
73
+ )
74
+
75
+ def __str__(self) -> str:
76
+ return (
77
+ f"CollectiveGroup("
78
+ f"_actor_handles={self._actor_handles}, "
79
+ f"_op={self._op}, "
80
+ f"_type_hint={self._type_hint})"
81
+ )
82
+
83
+ @property
84
+ def actor_handles(self) -> List["ray.actor.ActorHandle"]:
85
+ return self._actor_handles
86
+
87
+ @property
88
+ def type_hint(self) -> TorchTensorType:
89
+ return self._type_hint
90
+
91
+ def init_communicator(self, communicator_id: Optional[str] = None) -> str:
92
+ """
93
+ Initialize the communicator if it has not been initialized yet. If
94
+ `communicator_id` is provided, it means the communicator has already
95
+ been initialized.
96
+ """
97
+ type_hint = self._type_hint
98
+ if type_hint.communicator_id is not None:
99
+ return type_hint.communicator_id
100
+ if communicator_id is None:
101
+ communicator_id = _init_communicator(
102
+ self._actor_handles, type_hint.get_custom_communicator()
103
+ )
104
+ type_hint.set_communicator_id(communicator_id)
105
+ return communicator_id
106
+
107
+ def get_communicator(self) -> Communicator:
108
+ if self._type_hint.communicator_id is not None:
109
+ ctx = ChannelContext.get_current()
110
+ communicator = ctx.communicators[self._type_hint.communicator_id]
111
+ elif self._type_hint.get_custom_communicator() is not None:
112
+ communicator = self._type_hint.get_custom_communicator()
113
+ else:
114
+ raise ValueError("Expected a NCCL group")
115
+ return communicator
116
+
117
+ def execute(self, send_buf: "torch.Tensor") -> "torch.Tensor":
118
+ """
119
+ Call the collective operation on the input tensor. An output tensor is
120
+ allocated and returned.
121
+ """
122
+ import torch
123
+
124
+ if not isinstance(send_buf, torch.Tensor):
125
+ raise ValueError("Expected a torch tensor")
126
+ communicator = self.get_communicator()
127
+ recv_buf = torch.empty_like(send_buf)
128
+ communicator.allreduce(send_buf, recv_buf, self._op)
129
+ return recv_buf
130
+
131
+
132
+ @DeveloperAPI
133
+ class CollectiveOutputNode(ClassMethodNode):
134
+ """Represent an output node from a NCCL collective operation in a Ray DAG."""
135
+
136
+ def __init__(
137
+ self,
138
+ method_name: str,
139
+ method_args: Tuple[
140
+ DAGNode,
141
+ ],
142
+ method_kwargs: Dict[str, Any],
143
+ method_options: Dict[str, Any],
144
+ other_args_to_resolve: Dict[str, Any],
145
+ ):
146
+ # Parse the input node.
147
+ if not (
148
+ isinstance(method_args, tuple)
149
+ and len(method_args) == 1
150
+ and isinstance(method_args[0], DAGNode)
151
+ ):
152
+ raise ValueError("Expected a single input node")
153
+ self._input_node = method_args[0]
154
+ # Parse the collective operation.
155
+ self._collective_op: _CollectiveOperation = other_args_to_resolve.get(
156
+ COLLECTIVE_OPERATION_KEY, None
157
+ )
158
+ if self._collective_op is None:
159
+ raise ValueError("Expected a collective operation")
160
+
161
+ super().__init__(
162
+ method_name,
163
+ method_args,
164
+ method_kwargs,
165
+ method_options,
166
+ other_args_to_resolve,
167
+ )
168
+
169
+ def _copy_impl(
170
+ self,
171
+ new_args: List[Any],
172
+ new_kwargs: Dict[str, Any],
173
+ new_options: Dict[str, Any],
174
+ new_other_args_to_resolve: Dict[str, Any],
175
+ ):
176
+ return CollectiveOutputNode(
177
+ self._method_name,
178
+ new_args,
179
+ new_kwargs,
180
+ new_options,
181
+ other_args_to_resolve=new_other_args_to_resolve,
182
+ )
183
+
184
+ def _execute_impl(self, *args, **kwargs):
185
+ raise NotImplementedError(
186
+ "CollectiveOutputNode is only supported with dag.experimental_compile()"
187
+ )
188
+
189
+ @property
190
+ def collective_op(self) -> _CollectiveOperation:
191
+ return self._collective_op
.venv/lib/python3.11/site-packages/ray/dag/compiled_dag_node.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/ray/dag/conftest.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pytest
3
+
4
+ import ray
5
+
6
+ TEST_NAMESPACE = "ray_dag_test_namespace"
7
+
8
+
9
+ @pytest.fixture(scope="session")
10
+ def shared_ray_instance():
11
+ # Remove ray address for test ray cluster in case we have
12
+ # lingering RAY_ADDRESS="http://127.0.0.1:8265" from previous local job
13
+ # submissions.
14
+ if "RAY_ADDRESS" in os.environ:
15
+ del os.environ["RAY_ADDRESS"]
16
+ yield ray.init(num_cpus=16, namespace=TEST_NAMESPACE, log_to_driver=True)
.venv/lib/python3.11/site-packages/ray/dag/constants.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ # Reserved keys used to handle ClassMethodNode in Ray DAG building.
4
+ PARENT_CLASS_NODE_KEY = "parent_class_node"
5
+ PREV_CLASS_METHOD_CALL_KEY = "prev_class_method_call"
6
+ BIND_INDEX_KEY = "bind_index"
7
+ IS_CLASS_METHOD_OUTPUT_KEY = "is_class_method_output"
8
+
9
+ # Reserved keys used to handle CollectiveOutputNode in Ray DAG building.
10
+ COLLECTIVE_OPERATION_KEY = "collective_operation"
11
+
12
+ # Reserved key to distinguish DAGNode type and avoid collision with user dict.
13
+ DAGNODE_TYPE_KEY = "__dag_node_type__"
14
+
15
+ # Feature flag to turn off the deadlock detection.
16
+ RAY_CGRAPH_ENABLE_DETECT_DEADLOCK = (
17
+ os.environ.get("RAY_CGRAPH_ENABLE_DETECT_DEADLOCK", "1") == "1"
18
+ )
19
+
20
+ # Feature flag to turn on profiling.
21
+ RAY_CGRAPH_ENABLE_PROFILING = os.environ.get("RAY_CGRAPH_ENABLE_PROFILING", "0") == "1"
22
+
23
+ # Feature flag to turn on NVTX (NVIDIA Tools Extension Library) profiling.
24
+ # With this flag, Compiled Graph uses nvtx to automatically annotate and profile
25
+ # function calls during each actor's execution loop.
26
+ RAY_CGRAPH_ENABLE_NVTX_PROFILING = (
27
+ os.environ.get("RAY_CGRAPH_ENABLE_NVTX_PROFILING", "0") == "1"
28
+ )
29
+
30
+ # Feature flag to turn on visualization of the execution schedule.
31
+ RAY_CGRAPH_VISUALIZE_SCHEDULE = (
32
+ os.environ.get("RAY_CGRAPH_VISUALIZE_SCHEDULE", "0") == "1"
33
+ )
.venv/lib/python3.11/site-packages/ray/dag/context.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ import os
3
+ import threading
4
+ from typing import Optional
5
+ from ray.util.annotations import DeveloperAPI
6
+
7
+ # The context singleton on this process.
8
+ _default_context: "Optional[DAGContext]" = None
9
+ _context_lock = threading.Lock()
10
+
11
+ DEFAULT_SUBMIT_TIMEOUT_S = int(os.environ.get("RAY_CGRAPH_submit_timeout", 10))
12
+ DEFAULT_GET_TIMEOUT_S = int(os.environ.get("RAY_CGRAPH_get_timeout", 10))
13
+ DEFAULT_TEARDOWN_TIMEOUT_S = int(os.environ.get("RAY_CGRAPH_teardown_timeout", 30))
14
+ DEFAULT_READ_ITERATION_TIMEOUT_S = float(
15
+ os.environ.get("RAY_CGRAPH_read_iteration_timeout_s", 0.1)
16
+ )
17
+ # Default buffer size is 1MB.
18
+ DEFAULT_BUFFER_SIZE_BYTES = int(os.environ.get("RAY_CGRAPH_buffer_size_bytes", 1e6))
19
+ # The default number of in-flight executions that can be submitted before consuming the
20
+ # output.
21
+ DEFAULT_MAX_INFLIGHT_EXECUTIONS = int(
22
+ os.environ.get("RAY_CGRAPH_max_inflight_executions", 10)
23
+ )
24
+
25
+ DEFAULT_OVERLAP_GPU_COMMUNICATION = bool(
26
+ os.environ.get("RAY_CGRAPH_overlap_gpu_communication", 0)
27
+ )
28
+
29
+
30
+ @DeveloperAPI
31
+ @dataclass
32
+ class DAGContext:
33
+ """Global settings for Ray DAG.
34
+
35
+ You can configure parameters in the DAGContext by setting the environment
36
+ variables, `RAY_CGRAPH_<param>` (e.g., `RAY_CGRAPH_buffer_size_bytes`) or Python.
37
+
38
+ Examples:
39
+ >>> from ray.dag import DAGContext
40
+ >>> DAGContext.get_current().buffer_size_bytes
41
+ 1000000
42
+ >>> DAGContext.get_current().buffer_size_bytes = 500
43
+ >>> DAGContext.get_current().buffer_size_bytes
44
+ 500
45
+
46
+ Args:
47
+ submit_timeout: The maximum time in seconds to wait for execute()
48
+ calls.
49
+ get_timeout: The maximum time in seconds to wait when retrieving
50
+ a result from the DAG during `ray.get`. This should be set to a
51
+ value higher than the expected time to execute the entire DAG.
52
+ teardown_timeout: The maximum time in seconds to wait for the DAG to
53
+ cleanly shut down.
54
+ read_iteration_timeout: The timeout in seconds for each read iteration
55
+ that reads one of the input channels. If the timeout is reached, the
56
+ read operation will be interrupted and will try to read the next
57
+ input channel. It must be less than or equal to `get_timeout`.
58
+ buffer_size_bytes: The initial buffer size in bytes for messages
59
+ that can be passed between tasks in the DAG. The buffers will
60
+ be automatically resized if larger messages are written to the
61
+ channel.
62
+ max_inflight_executions: The maximum number of in-flight executions that
63
+ can be submitted via `execute` or `execute_async` before consuming
64
+ the output using `ray.get()`. If the caller submits more executions,
65
+ `RayCgraphCapacityExceeded` is raised.
66
+ overlap_gpu_communication: (experimental) Whether to overlap GPU
67
+ communication with computation during DAG execution. If True, the
68
+ communication and computation can be overlapped, which can improve
69
+ the performance of the DAG execution.
70
+ """
71
+
72
+ submit_timeout: int = DEFAULT_SUBMIT_TIMEOUT_S
73
+ get_timeout: int = DEFAULT_GET_TIMEOUT_S
74
+ teardown_timeout: int = DEFAULT_TEARDOWN_TIMEOUT_S
75
+ read_iteration_timeout: float = DEFAULT_READ_ITERATION_TIMEOUT_S
76
+ buffer_size_bytes: int = DEFAULT_BUFFER_SIZE_BYTES
77
+ max_inflight_executions: int = DEFAULT_MAX_INFLIGHT_EXECUTIONS
78
+ overlap_gpu_communication: bool = DEFAULT_OVERLAP_GPU_COMMUNICATION
79
+
80
+ def __post_init__(self):
81
+ if self.read_iteration_timeout > self.get_timeout:
82
+ raise ValueError(
83
+ "RAY_CGRAPH_read_iteration_timeout_s "
84
+ f"({self.read_iteration_timeout}) must be less than or equal to "
85
+ f"RAY_CGRAPH_get_timeout ({self.get_timeout})"
86
+ )
87
+
88
+ @staticmethod
89
+ def get_current() -> "DAGContext":
90
+ """Get or create a singleton context.
91
+
92
+ If the context has not yet been created in this process, it will be
93
+ initialized with default settings.
94
+ """
95
+ global _default_context
96
+
97
+ with _context_lock:
98
+ if _default_context is None:
99
+ _default_context = DAGContext()
100
+
101
+ return _default_context
.venv/lib/python3.11/site-packages/ray/dag/dag_node.py ADDED
@@ -0,0 +1,622 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from ray.experimental.channel.auto_transport_type import AutoTransportType
3
+ from ray.experimental.channel.torch_tensor_type import TorchTensorType
4
+ import ray
5
+ from ray.dag.base import DAGNodeBase
6
+ from ray.dag.py_obj_scanner import _PyObjScanner
7
+ from ray.util.annotations import DeveloperAPI
8
+
9
+ from itertools import chain
10
+
11
+ from typing import (
12
+ Optional,
13
+ Union,
14
+ List,
15
+ Tuple,
16
+ Dict,
17
+ Any,
18
+ TypeVar,
19
+ Callable,
20
+ )
21
+ import uuid
22
+ import asyncio
23
+
24
+ from ray.dag.compiled_dag_node import build_compiled_dag_from_ray_dag
25
+ from ray.experimental.channel import ChannelOutputType
26
+ from ray.experimental.channel.communicator import Communicator
27
+
28
+ T = TypeVar("T")
29
+
30
+
31
+ @DeveloperAPI
32
+ class DAGNode(DAGNodeBase):
33
+ """Abstract class for a node in a Ray task graph.
34
+
35
+ A node has a type (e.g., FunctionNode), data (e.g., function options and
36
+ body), arguments (Python values, DAGNodes, and DAGNodes nested within Python
37
+ argument values) and options (Ray API .options() used for function, class
38
+ or class method)
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ args: Tuple[Any],
44
+ kwargs: Dict[str, Any],
45
+ options: Dict[str, Any],
46
+ other_args_to_resolve: Dict[str, Any],
47
+ ):
48
+ """
49
+ args:
50
+ args (Tuple[Any]): Bound node arguments.
51
+ ex: func_or_class.bind(1)
52
+ kwargs (Dict[str, Any]): Bound node keyword arguments.
53
+ ex: func_or_class.bind(a=1)
54
+ options (Dict[str, Any]): Bound node options arguments.
55
+ ex: func_or_class.options(num_cpus=2)
56
+ other_args_to_resolve (Dict[str, Any]): Bound kwargs to resolve
57
+ that's specific to subclass implementation without exposing
58
+ as args in base class, example: ClassMethodNode
59
+ """
60
+ self._bound_args: Tuple[Any] = args or []
61
+ self._bound_kwargs: Dict[str, Any] = kwargs or {}
62
+ self._bound_options: Dict[str, Any] = options or {}
63
+ self._bound_other_args_to_resolve: Optional[Dict[str, Any]] = (
64
+ other_args_to_resolve or {}
65
+ )
66
+
67
+ # The list of nodes that use this DAG node as an argument.
68
+ self._downstream_nodes: List["DAGNode"] = []
69
+
70
+ # UUID that is not changed over copies of this node.
71
+ self._stable_uuid = uuid.uuid4().hex
72
+
73
+ # Indicates whether this DAG node contains nested DAG nodes.
74
+ # Nested DAG nodes are allowed in traditional DAGs but not
75
+ # in Ray Compiled Graphs, except for MultiOutputNode.
76
+ self._args_contain_nested_dag_node = False
77
+
78
+ # The list of nodes that this DAG node uses as an argument.
79
+ self._upstream_nodes: List["DAGNode"] = self._collect_upstream_nodes()
80
+
81
+ # Cached values from last call to execute()
82
+ self.cache_from_last_execute = {}
83
+
84
+ self._type_hint: ChannelOutputType = ChannelOutputType()
85
+
86
+ # If the original type hint is an AutoTransportType, we make a copy
87
+ # here when it is resolved to the actual type, as additional debugging
88
+ # information. Otherwise, it is None.
89
+ self._original_type_hint: Optional[ChannelOutputType] = None
90
+
91
+ # Whether this node calls `experimental_compile`.
92
+ self.is_cgraph_output_node = False
93
+
94
+ def _collect_upstream_nodes(self) -> List["DAGNode"]:
95
+ """
96
+ Retrieve upstream nodes and update their downstream dependencies.
97
+
98
+ Currently, the DAG assumes that all DAGNodes in `args`, `kwargs`, and
99
+ `other_args_to_resolve` are upstream nodes. However, Ray Compiled Graphs
100
+ builds the upstream/downstream relationship based only on args. Be cautious
101
+ when persisting DAGNodes in `other_args_to_resolve` and kwargs in the future.
102
+
103
+ TODO (kevin85421): Currently, the upstream nodes and downstream nodes have
104
+ circular references. Therefore, it relies on the garbage collector to clean
105
+ them up instead of reference counting. We should consider using weak references
106
+ to avoid circular references.
107
+ """
108
+ upstream_nodes: List["DAGNode"] = []
109
+
110
+ # Ray Compiled Graphs do not allow nested DAG nodes in arguments.
111
+ # Specifically, a DAGNode should not be placed inside any type of
112
+ # container. However, we only know if this is a compiled graph
113
+ # when calling `experimental_compile`. Therefore, we need to check
114
+ # in advance if the arguments contain nested DAG nodes and raise
115
+ # an error after compilation.
116
+ assert hasattr(self._bound_args, "__iter__")
117
+ for arg in self._bound_args:
118
+ if isinstance(arg, DAGNode):
119
+ upstream_nodes.append(arg)
120
+ else:
121
+ scanner = _PyObjScanner()
122
+ dag_nodes = scanner.find_nodes(arg)
123
+ upstream_nodes.extend(dag_nodes)
124
+ scanner.clear()
125
+ self._args_contain_nested_dag_node = len(dag_nodes) > 0
126
+
127
+ scanner = _PyObjScanner()
128
+ other_upstream_nodes: List["DAGNode"] = scanner.find_nodes(
129
+ [
130
+ self._bound_kwargs,
131
+ self._bound_other_args_to_resolve,
132
+ ]
133
+ )
134
+ upstream_nodes.extend(other_upstream_nodes)
135
+ scanner.clear()
136
+ # Update dependencies.
137
+ for upstream_node in upstream_nodes:
138
+ upstream_node._downstream_nodes.append(self)
139
+ return upstream_nodes
140
+
141
+ def with_tensor_transport(
142
+ self,
143
+ transport: Optional[Union[str, Communicator]] = "auto",
144
+ _static_shape: bool = False,
145
+ _direct_return: bool = False,
146
+ ):
147
+ if transport == "auto":
148
+ self._type_hint = AutoTransportType(
149
+ _static_shape=_static_shape,
150
+ _direct_return=_direct_return,
151
+ )
152
+ elif transport == "nccl":
153
+ self._type_hint = TorchTensorType(
154
+ transport=transport,
155
+ _static_shape=_static_shape,
156
+ _direct_return=_direct_return,
157
+ )
158
+ else:
159
+ if not isinstance(transport, Communicator):
160
+ raise ValueError(
161
+ "transport must be 'auto', 'nccl' or a Communicator type"
162
+ )
163
+ self._type_hint = TorchTensorType(
164
+ transport=transport,
165
+ _static_shape=_static_shape,
166
+ _direct_return=_direct_return,
167
+ )
168
+ return self
169
+
170
+ @property
171
+ def type_hint(self) -> ChannelOutputType:
172
+ return self._type_hint
173
+
174
+ @type_hint.setter
175
+ def type_hint(self, type_hint: ChannelOutputType) -> None:
176
+ if isinstance(self._type_hint, AutoTransportType):
177
+ self._original_type_hint = self._type_hint
178
+ self._type_hint = type_hint
179
+
180
+ def get_args(self) -> Tuple[Any]:
181
+ """Return the tuple of arguments for this node."""
182
+
183
+ return self._bound_args
184
+
185
+ def get_kwargs(self) -> Dict[str, Any]:
186
+ """Return the dict of keyword arguments for this node."""
187
+
188
+ return self._bound_kwargs.copy()
189
+
190
+ def get_options(self) -> Dict[str, Any]:
191
+ """Return the dict of options arguments for this node."""
192
+
193
+ return self._bound_options.copy()
194
+
195
+ def get_other_args_to_resolve(self) -> Dict[str, Any]:
196
+ """Return the dict of other args to resolve arguments for this node."""
197
+ return self._bound_other_args_to_resolve.copy()
198
+
199
+ def get_stable_uuid(self) -> str:
200
+ """Return stable uuid for this node.
201
+ 1) Generated only once at first instance creation
202
+ 2) Stable across pickling, replacement and JSON serialization.
203
+ """
204
+ return self._stable_uuid
205
+
206
+ async def get_object_refs_from_last_execute(self) -> Dict[str, Any]:
207
+ """Gets cached object refs from the last call to execute().
208
+
209
+ After this DAG is executed through execute(), retrieves a map between node
210
+ UUID to a reference to the return value of the default executor on that node.
211
+ """
212
+ cache = {}
213
+ for node_uuid, value in self.cache_from_last_execute.items():
214
+ if isinstance(value, asyncio.Task):
215
+ cache[node_uuid] = await value
216
+ else:
217
+ cache[node_uuid] = value
218
+
219
+ return cache
220
+
221
+ def clear_cache(self):
222
+ self.cache_from_last_execute = {}
223
+
224
+ def experimental_compile(
225
+ self,
226
+ _submit_timeout: Optional[float] = None,
227
+ _buffer_size_bytes: Optional[int] = None,
228
+ enable_asyncio: bool = False,
229
+ _max_inflight_executions: Optional[int] = None,
230
+ _overlap_gpu_communication: Optional[bool] = None,
231
+ ) -> "ray.dag.CompiledDAG":
232
+ """Compile an accelerated execution path for this DAG.
233
+
234
+ Args:
235
+ _submit_timeout: The maximum time in seconds to wait for execute() calls.
236
+ None means using default timeout, 0 means immediate timeout
237
+ (immediate success or timeout without blocking), -1 means
238
+ infinite timeout (block indefinitely).
239
+ _buffer_size_bytes: The initial buffer size in bytes for messages
240
+ that can be passed between tasks in the DAG. The buffers will
241
+ be automatically resized if larger messages are written to the
242
+ channel.
243
+ enable_asyncio: Whether to enable asyncio for this DAG.
244
+ _max_inflight_executions: The maximum number of in-flight executions that
245
+ can be submitted via `execute` or `execute_async` before consuming
246
+ the output using `ray.get()`. If the caller submits more executions,
247
+ `RayCgraphCapacityExceeded` is raised.
248
+ _overlap_gpu_communication: (experimental) Whether to overlap GPU
249
+ communication with computation during DAG execution. If True, the
250
+ communication and computation can be overlapped, which can improve
251
+ the performance of the DAG execution. If None, the default value
252
+ will be used.
253
+
254
+ Returns:
255
+ A compiled DAG.
256
+ """
257
+ from ray.dag import DAGContext
258
+
259
+ ctx = DAGContext.get_current()
260
+ if _buffer_size_bytes is None:
261
+ _buffer_size_bytes = ctx.buffer_size_bytes
262
+
263
+ # Validate whether this DAG node has already been compiled.
264
+ if self.is_cgraph_output_node:
265
+ raise ValueError(
266
+ "It is not allowed to call `experimental_compile` on the same DAG "
267
+ "object multiple times no matter whether `teardown` is called or not. "
268
+ "Please reuse the existing compiled DAG or create a new one."
269
+ )
270
+ # Whether this node is an output node in the DAG. We cannot determine
271
+ # this in the constructor because the output node is determined when
272
+ # `experimental_compile` is called.
273
+ self.is_cgraph_output_node = True
274
+ return build_compiled_dag_from_ray_dag(
275
+ self,
276
+ _submit_timeout,
277
+ _buffer_size_bytes,
278
+ enable_asyncio,
279
+ _max_inflight_executions,
280
+ _overlap_gpu_communication,
281
+ )
282
+
283
+ def execute(
284
+ self, *args, _ray_cache_refs: bool = False, **kwargs
285
+ ) -> Union[ray.ObjectRef, "ray.actor.ActorHandle"]:
286
+ """Execute this DAG using the Ray default executor _execute_impl().
287
+
288
+ Args:
289
+ _ray_cache_refs: If true, stores the the default executor's return values
290
+ on each node in this DAG in a cache. These should be a mix of:
291
+ - ray.ObjectRefs pointing to the outputs of method and function nodes
292
+ - Serve handles for class nodes
293
+ - resolved values representing user input at runtime
294
+ """
295
+
296
+ def executor(node):
297
+ return node._execute_impl(*args, **kwargs)
298
+
299
+ result = self.apply_recursive(executor)
300
+ if _ray_cache_refs:
301
+ self.cache_from_last_execute = executor.cache
302
+ return result
303
+
304
+ def _get_toplevel_child_nodes(self) -> List["DAGNode"]:
305
+ """Return the list of nodes specified as top-level args.
306
+
307
+ For example, in `f.remote(a, [b])`, only `a` is a top-level arg.
308
+
309
+ This list of nodes are those that are typically resolved prior to
310
+ task execution in Ray. This does not include nodes nested within args.
311
+ For that, use ``_get_all_child_nodes()``.
312
+ """
313
+
314
+ # we use List instead of Set here because the hash key of the node
315
+ # object changes each time we create it. So if using Set here, the
316
+ # order of returned children can be different if we create the same
317
+ # nodes and dag one more time.
318
+ children = []
319
+ for a in self.get_args():
320
+ if isinstance(a, DAGNode):
321
+ if a not in children:
322
+ children.append(a)
323
+ for a in self.get_kwargs().values():
324
+ if isinstance(a, DAGNode):
325
+ if a not in children:
326
+ children.append(a)
327
+ for a in self.get_other_args_to_resolve().values():
328
+ if isinstance(a, DAGNode):
329
+ if a not in children:
330
+ children.append(a)
331
+ return children
332
+
333
+ def _get_all_child_nodes(self) -> List["DAGNode"]:
334
+ """Return the list of nodes referenced by the args, kwargs, and
335
+ args_to_resolve in current node, even they're deeply nested.
336
+
337
+ Examples:
338
+ f.remote(a, [b]) -> [a, b]
339
+ f.remote(a, [b], key={"nested": [c]}) -> [a, b, c]
340
+ """
341
+
342
+ scanner = _PyObjScanner()
343
+ # we use List instead of Set here, reason explained
344
+ # in `_get_toplevel_child_nodes`.
345
+ children = []
346
+ for n in scanner.find_nodes(
347
+ [
348
+ self._bound_args,
349
+ self._bound_kwargs,
350
+ self._bound_other_args_to_resolve,
351
+ ]
352
+ ):
353
+ if n not in children:
354
+ children.append(n)
355
+ scanner.clear()
356
+ return children
357
+
358
+ def _apply_and_replace_all_child_nodes(
359
+ self, fn: "Callable[[DAGNode], T]"
360
+ ) -> "DAGNode":
361
+ """Apply and replace all immediate child nodes using a given function.
362
+
363
+ This is a shallow replacement only. To recursively transform nodes in
364
+ the DAG, use ``apply_recursive()``.
365
+
366
+ Args:
367
+ fn: Callable that will be applied once to each child of this node.
368
+
369
+ Returns:
370
+ New DAGNode after replacing all child nodes.
371
+ """
372
+
373
+ replace_table = {}
374
+ # CloudPickler scanner object for current layer of DAGNode. Same
375
+ # scanner should be use for a full find & replace cycle.
376
+ scanner = _PyObjScanner()
377
+ # Find all first-level nested DAGNode children in args.
378
+ # Update replacement table and execute the replace.
379
+ for node in scanner.find_nodes(
380
+ [
381
+ self._bound_args,
382
+ self._bound_kwargs,
383
+ self._bound_other_args_to_resolve,
384
+ ]
385
+ ):
386
+ if node not in replace_table:
387
+ replace_table[node] = fn(node)
388
+ new_args, new_kwargs, new_other_args_to_resolve = scanner.replace_nodes(
389
+ replace_table
390
+ )
391
+ scanner.clear()
392
+
393
+ # Return updated copy of self.
394
+ return self._copy(
395
+ new_args, new_kwargs, self.get_options(), new_other_args_to_resolve
396
+ )
397
+
398
+ def apply_recursive(self, fn: "Callable[[DAGNode], T]") -> T:
399
+ """Apply callable on each node in this DAG in a bottom-up tree walk.
400
+
401
+ Args:
402
+ fn: Callable that will be applied once to each node in the
403
+ DAG. It will be applied recursively bottom-up, so nodes can
404
+ assume the fn has been applied to their args already.
405
+
406
+ Returns:
407
+ Return type of the fn after application to the tree.
408
+ """
409
+
410
+ if not type(fn).__name__ == "_CachingFn":
411
+
412
+ class _CachingFn:
413
+ def __init__(self, fn):
414
+ self.cache = {}
415
+ self.fn = fn
416
+ self.fn.cache = self.cache
417
+ self.input_node_uuid = None
418
+
419
+ def __call__(self, node: "DAGNode"):
420
+ from ray.dag.input_node import InputNode
421
+
422
+ if node._stable_uuid not in self.cache:
423
+ self.cache[node._stable_uuid] = self.fn(node)
424
+ if isinstance(node, InputNode):
425
+ if not self.input_node_uuid:
426
+ self.input_node_uuid = node._stable_uuid
427
+ elif self.input_node_uuid != node._stable_uuid:
428
+ raise AssertionError(
429
+ "Each DAG should only have one unique InputNode."
430
+ )
431
+ return self.cache[node._stable_uuid]
432
+
433
+ fn = _CachingFn(fn)
434
+ else:
435
+ if self._stable_uuid in fn.cache:
436
+ return fn.cache[self._stable_uuid]
437
+
438
+ return fn(
439
+ self._apply_and_replace_all_child_nodes(
440
+ lambda node: node.apply_recursive(fn)
441
+ )
442
+ )
443
+
444
+ def traverse_and_apply(self, fn: "Callable[[DAGNode], T]"):
445
+ """
446
+ Traverse all nodes in the connected component of the DAG that contains
447
+ the `self` node, and apply the given function to each node.
448
+ """
449
+ visited = set()
450
+ queue = [self]
451
+ cgraph_output_node: Optional[DAGNode] = None
452
+
453
+ while queue:
454
+ node = queue.pop(0)
455
+ if node._args_contain_nested_dag_node:
456
+ self._raise_nested_dag_node_error(node._bound_args)
457
+
458
+ if node not in visited:
459
+ if node.is_cgraph_output_node:
460
+ # Validate whether there are multiple nodes that call
461
+ # `experimental_compile`.
462
+ if cgraph_output_node is not None:
463
+ raise ValueError(
464
+ "The DAG was compiled more than once. The following two "
465
+ "nodes call `experimental_compile`: "
466
+ f"(1) {cgraph_output_node}, (2) {node}"
467
+ )
468
+ cgraph_output_node = node
469
+ fn(node)
470
+ visited.add(node)
471
+ """
472
+ Add all unseen downstream and upstream nodes to the queue.
473
+ This function should be called by the root of the DAG. However,
474
+ in some invalid cases, some nodes may not be descendants of the
475
+ root. Therefore, we also add upstream nodes to the queue so that
476
+ a meaningful error message can be raised when the DAG is compiled.
477
+
478
+ ```
479
+ with InputNode() as inp:
480
+ dag = MultiOutputNode([a1.inc.bind(inp), a2.inc.bind(1)])
481
+ ```
482
+
483
+ In the above example, `a2.inc` is not a descendant of inp. If we only
484
+ add downstream nodes to the queue, the `a2.inc` node will not be visited
485
+ , and the error message will be hard to understand, such as a key error
486
+ in the compiled DAG.
487
+ """
488
+ for neighbor in chain.from_iterable(
489
+ [node._downstream_nodes, node._upstream_nodes]
490
+ ):
491
+ if neighbor not in visited:
492
+ queue.append(neighbor)
493
+
494
+ def _raise_nested_dag_node_error(self, args):
495
+ """
496
+ Raise an error for nested DAGNodes in Ray Compiled Graphs.
497
+
498
+ Args:
499
+ args: The arguments of the DAGNode.
500
+ """
501
+ for arg in args:
502
+ if isinstance(arg, DAGNode):
503
+ continue
504
+ else:
505
+ scanner = _PyObjScanner()
506
+ dag_nodes = scanner.find_nodes([arg])
507
+ scanner.clear()
508
+ if len(dag_nodes) > 0:
509
+ raise ValueError(
510
+ f"Found {len(dag_nodes)} DAGNodes from the arg {arg} "
511
+ f"in {self}. Please ensure that the argument is a "
512
+ "single DAGNode and that a DAGNode is not allowed to "
513
+ "be placed inside any type of container."
514
+ )
515
+ raise AssertionError(
516
+ "A DAGNode's args should contain nested DAGNodes as args, "
517
+ "but none were found during the compilation process. This is a "
518
+ "Ray internal error. Please report this issue to the Ray team."
519
+ )
520
+
521
+ def _find_root(self) -> "DAGNode":
522
+ """
523
+ Return the root node of the DAG. The root node must be an InputNode.
524
+ """
525
+ from ray.dag.input_node import InputNode
526
+
527
+ node = self
528
+ while not isinstance(node, InputNode):
529
+ if len(node._upstream_nodes) == 0:
530
+ raise ValueError(
531
+ "No InputNode found in the DAG: when traversing upwards, "
532
+ f"no upstream node was found for {node}."
533
+ )
534
+ node = node._upstream_nodes[0]
535
+ return node
536
+
537
+ def apply_functional(
538
+ self,
539
+ source_input_list: Any,
540
+ predictate_fn: Callable,
541
+ apply_fn: Callable,
542
+ ):
543
+ """
544
+ Apply a given function to DAGNodes in source_input_list, and return
545
+ the replaced inputs without mutating or coping any DAGNode.
546
+
547
+ Args:
548
+ source_input_list: Source inputs to extract and apply function on
549
+ all children DAGNode instances.
550
+ predictate_fn: Applied on each DAGNode instance found and determine
551
+ if we should apply function to it. Can be used to filter node
552
+ types.
553
+ apply_fn: Function to appy on the node on bound attributes. Example:
554
+ apply_fn = lambda node: node._get_serve_deployment_handle(
555
+ node._deployment, node._bound_other_args_to_resolve
556
+ )
557
+
558
+ Returns:
559
+ replaced_inputs: Outputs of apply_fn on DAGNodes in
560
+ source_input_list that passes predictate_fn.
561
+ """
562
+ replace_table = {}
563
+ scanner = _PyObjScanner()
564
+ for node in scanner.find_nodes(source_input_list):
565
+ if predictate_fn(node) and node not in replace_table:
566
+ replace_table[node] = apply_fn(node)
567
+
568
+ replaced_inputs = scanner.replace_nodes(replace_table)
569
+ scanner.clear()
570
+
571
+ return replaced_inputs
572
+
573
+ def _execute_impl(
574
+ self, *args, **kwargs
575
+ ) -> Union[ray.ObjectRef, "ray.actor.ActorHandle"]:
576
+ """Execute this node, assuming args have been transformed already."""
577
+ raise NotImplementedError
578
+
579
+ def _copy_impl(
580
+ self,
581
+ new_args: List[Any],
582
+ new_kwargs: Dict[str, Any],
583
+ new_options: Dict[str, Any],
584
+ new_other_args_to_resolve: Dict[str, Any],
585
+ ) -> "DAGNode":
586
+ """Return a copy of this node with the given new args."""
587
+ raise NotImplementedError
588
+
589
+ def _copy(
590
+ self,
591
+ new_args: List[Any],
592
+ new_kwargs: Dict[str, Any],
593
+ new_options: Dict[str, Any],
594
+ new_other_args_to_resolve: Dict[str, Any],
595
+ ) -> "DAGNode":
596
+ """Return a copy of this node with the given new args."""
597
+ instance = self._copy_impl(
598
+ new_args, new_kwargs, new_options, new_other_args_to_resolve
599
+ )
600
+ instance._stable_uuid = self._stable_uuid
601
+ instance._type_hint = copy.deepcopy(self._type_hint)
602
+ instance._original_type_hint = copy.deepcopy(self._original_type_hint)
603
+ return instance
604
+
605
+ def __getstate__(self):
606
+ """Required due to overriding `__getattr__` else pickling fails."""
607
+ return self.__dict__
608
+
609
+ def __setstate__(self, d: Dict[str, Any]):
610
+ """Required due to overriding `__getattr__` else pickling fails."""
611
+ self.__dict__.update(d)
612
+
613
+ def __getattr__(self, attr: str):
614
+ if attr == "bind":
615
+ raise AttributeError(f".bind() cannot be used again on {type(self)} ")
616
+ elif attr == "remote":
617
+ raise AttributeError(
618
+ f".remote() cannot be used on {type(self)}. To execute the task "
619
+ "graph for this node, use .execute()."
620
+ )
621
+ else:
622
+ return self.__getattribute__(attr)
.venv/lib/python3.11/site-packages/ray/dag/dag_node_operation.py ADDED
@@ -0,0 +1,789 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import total_ordering
2
+ from enum import Enum
3
+ from typing import Set, Tuple, List, Dict, Optional
4
+ import copy
5
+ import logging
6
+ import ray
7
+ import heapq
8
+ from collections import defaultdict
9
+
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ class _DAGNodeOperationType(Enum):
15
+ """
16
+ There are three types of operations that a DAG node can perform:
17
+ 1. READ: Read from an input channel.
18
+ 2. COMPUTE: Execute the method corresponding to the node.
19
+ 3. WRITE: Write to an output channel.
20
+ """
21
+
22
+ READ = "READ"
23
+ COMPUTE = "COMPUTE"
24
+ WRITE = "WRITE"
25
+
26
+ def viz_str(self):
27
+ """
28
+ A string representation of the operation type to be used in visualization.
29
+
30
+ The result string is a single character because conciseness is preferred.
31
+ """
32
+ if self == _DAGNodeOperationType.READ:
33
+ return "R"
34
+ elif self == _DAGNodeOperationType.COMPUTE:
35
+ return "C"
36
+ elif self == _DAGNodeOperationType.WRITE:
37
+ return "W"
38
+ assert False, f"Unknown operation type: {self}"
39
+
40
+
41
class _DAGNodeOperation:
    """
    A single operation (READ, COMPUTE, or WRITE) executed by an actor for one
    task in a compiled DAG.
    """

    def __init__(
        self,
        exec_task_idx: int,
        operation_type: _DAGNodeOperationType,
        method_name: Optional[str] = None,
    ):
        """
        Args:
            exec_task_idx: The index of the task that this operation belongs to
                in the actor's ExecutableTask list. The index is not the same
                as bind_index because there may be more tasks bound to an actor
                than tasks that appear in the current compiled DAG.
            operation_type: The type of operation to perform.
            method_name: The name of the method that this operation originates
                from. This is only for visualization and debugging purposes.
        """
        self.exec_task_idx = exec_task_idx
        self.type = operation_type
        self.method_name = method_name

    def __repr__(self):
        # Fixed: the original format string produced a double space before
        # "type:" (", " followed by " type:").
        return (
            f"_DAGNodeOperation("
            f"exec_task_idx: {self.exec_task_idx}, "
            f"type: {self.type})"
        )

    def viz_str(self):
        """
        A string representation of the node to be used in visualization.
        """
        return f"[{self.exec_task_idx}] {self.method_name} {self.type.viz_str()}"

    def __hash__(self):
        # Must stay consistent with __eq__: identity is (exec_task_idx, type).
        return hash((self.exec_task_idx, self.type))

    def __eq__(self, other):
        # An operation is uniquely identified by its `exec_task_idx` and type.
        # `method_name` is only for debugging purposes.
        if not isinstance(other, _DAGNodeOperation):
            # Fall back to the other operand's comparison (or identity) instead
            # of raising AttributeError on e.g. `op == None`.
            return NotImplemented
        return self.exec_task_idx == other.exec_task_idx and self.type == other.type
82
+
83
+
84
@total_ordering
class _DAGOperationGraphNode:
    def __init__(
        self,
        operation: _DAGNodeOperation,
        task_idx: int,
        actor_handle: "ray.actor.ActorHandle",
        requires_nccl: bool,
    ):
        """
        _DAGOperationGraphNode represents a node in the DAG operation graph.
        It contains information about the node's in-degree, out-degree, edges,
        and the operation it performs.

        Args:
            operation: The operation that this node performs. The operation
                can be a READ, COMPUTE, or WRITE operation.
            task_idx: A unique index which can be used to index into
                `CompiledDAG.idx_to_task` to get the corresponding task.
            actor_handle: The actor handle to which this operation belongs.
            requires_nccl: Whether this operation requires NCCL.
        """
        self.operation = operation
        self.task_idx = task_idx
        self.actor_handle = actor_handle
        self.requires_nccl = requires_nccl
        # The in_edges and out_edges are dicts of tuples to strings.
        # Each tuple (the key) contains an integer `task_idx`, which can be
        # used to index into `idx_to_task` to get the corresponding task,
        # and a `_DAGNodeOperationType`, which can be READ, COMPUTE, or WRITE.
        # The value is the visualization information of the edge: a tuple of
        # the edge's label and a boolean indicating whether the edge is a
        # control dependency.
        self.in_edges: Dict[Tuple[int, _DAGNodeOperationType], Tuple[str, bool]] = {}
        self.out_edges: Dict[Tuple[int, _DAGNodeOperationType], Tuple[str, bool]] = {}
        # The collective nodes are the nodes that belong to the same collective
        # operation. Each node is represented by a tuple of its task idx and type.
        self.collective_idxs: Set[Tuple[int, _DAGNodeOperationType]] = set()
        # The ready collective nodes are the nodes that are ready to be executed,
        # i.e., their in-degrees are zero. When a collective node is ready, it
        # will be added to the ready collective nodes of all the nodes in its
        # collective operation.
        self.ready_collective_idxs: Set[Tuple[int, _DAGNodeOperationType]] = set()

    def __repr__(self):
        return (
            f"_DAGOperationGraphNode("
            f"operation: {self.operation}, "
            f"task_idx: {self.task_idx}, "
            f"actor_handle: {self.actor_handle}, "
            f"requires_nccl: {self.requires_nccl})"
        )

    def __lt__(self, other: "_DAGOperationGraphNode"):
        """
        This function defines the order of the nodes in the priority queue used in
        `_select_next_nodes`. The priority queue is a min-heap, so the node with
        higher priority is considered "less than" the other node.
        """

        # Tie-break: smaller `exec_task_idx` wins, then smaller `task_idx`.
        def compare(lhs: "_DAGOperationGraphNode", rhs: "_DAGOperationGraphNode"):
            # If both nodes belong to the same actor, the node with the smaller
            # `exec_task_idx` is prioritized. If two nodes belong to different
            # actors, it approximates balancing the scheduled tasks across actors,
            # by prioritizing the node with the smaller `exec_task_idx`. The tie
            # is broken by the `task_idx`.
            if lhs.operation.exec_task_idx != rhs.operation.exec_task_idx:
                return lhs.operation.exec_task_idx < rhs.operation.exec_task_idx
            return lhs.task_idx < rhs.task_idx

        if self.actor_handle == other.actor_handle:
            # When both nodes belong to the same actor, use the default comparison.
            return compare(self, other)
        elif self.is_nccl_op != other.is_nccl_op:
            # When one node is a NCCL operation and the other is not, prioritize
            # the non-NCCL operation.
            return not self.is_nccl_op
        else:
            # When either both nodes are NCCL operations or both nodes are not
            # NCCL operations, use the default comparison.
            return compare(self, other)

    def __eq__(self, other: "_DAGOperationGraphNode"):
        """
        Two operations are equal only when they have the same `exec_task_idx` and `type`
        and belong to the same actor.
        """
        return (
            self.actor_handle == other.actor_handle
            and self.operation.exec_task_idx == other.operation.exec_task_idx
            and self.operation.type == other.operation.type
        )

    def __hash__(self):
        """
        An operation is uniquely identified by its `task_idx` and type.
        """
        # NOTE(review): hashes on (operation, task_idx) while __eq__ compares
        # (actor_handle, exec_task_idx, type); this is consistent only if equal
        # nodes always share the same task_idx — appears to hold, but confirm.
        return hash((self.operation, self.task_idx))

    @property
    def in_degree(self) -> int:
        return len(self.in_edges)

    @property
    def is_ready(self) -> bool:
        """
        If a node is not a NCCL collective, it is ready when it has a zero
        in-degree. If it is a NCCL collective, it is ready when all the nodes
        in its collective operation have zero in-degrees.
        """
        # For non-collective nodes both sets are empty, so the second
        # condition is trivially true.
        return self.in_degree == 0 and (
            len(self.ready_collective_idxs) == len(self.collective_idxs)
        )

    @property
    def is_read(self) -> bool:
        return self.operation.type == _DAGNodeOperationType.READ

    @property
    def is_nccl_collective(self) -> bool:
        """
        A node is a NCCL collective if it is a compute node and requires NCCL.
        """
        return (
            self.operation.type == _DAGNodeOperationType.COMPUTE and self.requires_nccl
        )

    @property
    def is_nccl_write(self) -> bool:
        """
        A node is a NCCL write if it is a write node and requires NCCL.
        """
        return self.operation.type == _DAGNodeOperationType.WRITE and self.requires_nccl

    @property
    def is_nccl_op(self) -> bool:
        # Either kind of NCCL operation; both are blocking and scheduled last
        # among candidates (see __lt__).
        return self.is_nccl_collective or self.is_nccl_write

    def viz_str(self):
        """
        A string representation of the node to be used in visualization.
        """
        return self.operation.viz_str()

    @property
    def _actor_id(self):
        # Hex string form of the owning actor's id; used as a stable key in
        # visualization and candidate bookkeeping.
        return self.actor_handle._ray_actor_id.hex()
231
+
232
+
233
def _add_edge(
    from_node: _DAGOperationGraphNode,
    to_node: _DAGOperationGraphNode,
    label: str = "",
    control_dependency: bool = False,
):
    """
    Record a directed edge from `from_node` to `to_node` on both endpoints.

    Args:
        from_node: The node from which the edge originates.
        to_node: The node to which the edge points.
        label: The label annotating the edge in the execution schedule
            visualization.
        control_dependency: Whether the edge represents a control dependency
            (rendered dashed in the visualization).
    """
    viz_info = (label, control_dependency)
    from_node.out_edges[(to_node.task_idx, to_node.operation.type)] = viz_info
    to_node.in_edges[(from_node.task_idx, from_node.operation.type)] = viz_info
255
+ )
256
+
257
+
258
def _push_candidate_node_if_ready(
    actor_to_candidates: Dict["ray._raylet.ActorID", List[_DAGOperationGraphNode]],
    graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]],
    node: _DAGOperationGraphNode,
) -> None:
    """
    Push `node` onto its actor's candidate min-heap once it is ready.

    A NCCL collective node first marks itself as ready on every member of its
    collective, so that a collective becomes schedulable only after all of its
    members have zero in-degree.
    """
    if node.is_nccl_collective:
        node_key = (node.task_idx, node.operation.type)
        for peer_task_idx, peer_op_type in node.collective_idxs:
            graph[peer_task_idx][peer_op_type].ready_collective_idxs.add(node_key)
    if node.is_ready:
        heapq.heappush(actor_to_candidates[node.actor_handle._actor_id], node)
277
+
278
+
279
def _select_next_nodes(
    actor_to_candidates: Dict["ray._raylet.ActorID", List[_DAGOperationGraphNode]],
    graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]],
) -> Optional[List[_DAGOperationGraphNode]]:
    """
    This function selects the next nodes for the topological sort to generate
    execution schedule. If there are multiple candidate _DAGOperationGraphNodes,
    select the node with the top priority. The priority is defined in
    `_DAGOperationGraphNode.__lt__`.

    For the implementation details, we maintain a priority queue for each actor,
    where the head of the priority queue is the node with the smallest `exec_task_idx`.
    When a node has a zero in-degree, it is added to the corresponding actor's
    priority queue. For a node other than a NCCL collective node, it is ready to be
    executed if it has a zero in-degree. For a NCCL collective node, it is ready to
    be executed when all the nodes in its collective operation have zero in-degrees.

    If a node is a NCCL collective node, it updates the `ready_collective_nodes` of
    all the nodes in its collective operation. Unless all the nodes in its collective
    group have zero in-degrees, this node is removed from the candidate list.
    Eventually, exactly one NCCL collective node from its collective operation is
    selected from the candidate list.

    If the selected node is a NCCL write node, select all the downstream NCCL
    read nodes. If the selected node is a NCCL collective node, select all the NCCL
    compute nodes in its collective operation.

    Args:
        actor_to_candidates: A dictionary mapping an actor id to a list of
            candidate nodes. The list is maintained as a priority queue, so
            the head of the queue, i.e., `candidates[0]`, is the node with
            the smallest `bind_index`.
        graph: A dictionary mapping the index of a task to a dictionary of its
            _DAGOperationGraphNodes for different operations.

    Returns:
        A list of _DAGOperationGraphNodes to be placed into the corresponding
        execution schedules, or None when no candidate remains.
    """
    # Scan the head of every actor's min-heap to find the globally
    # highest-priority candidate.
    top_priority_node = None
    for _, candidates in actor_to_candidates.items():
        if len(candidates) == 0:
            continue
        if top_priority_node is None or candidates[0] < top_priority_node:
            top_priority_node = candidates[0]

    if top_priority_node is None:
        return None
    next_nodes = [
        heapq.heappop(actor_to_candidates[top_priority_node.actor_handle._actor_id])
    ]

    if not top_priority_node.is_nccl_op:
        # A non-NCCL operation node is picked.
        assert len(next_nodes) == 1
    elif top_priority_node.is_nccl_write:
        # a NCCL write node is picked. NCCL is a blocking operation, so we need
        # to pick all the corresponding NCCL read nodes to avoid a deadlock.
        # (These reads may not be "ready" yet; the caller tolerates this.)
        for downstream_node_metadata in top_priority_node.out_edges:
            task_idx, op_type = downstream_node_metadata
            downstream_node = graph[task_idx][op_type]
            assert downstream_node.is_read
            next_nodes.append(downstream_node)
        assert len(next_nodes) == 1 + len(top_priority_node.out_edges)
    elif top_priority_node.is_nccl_collective:
        # a NCCL collective node is picked. NCCL is a blocking operation, so we need
        # to pick all the corresponding NCCL collective nodes in its collective
        # operation to avoid a deadlock.
        for collective_node_metadata in top_priority_node.collective_idxs:
            task_idx, op_type = collective_node_metadata
            collective_node = graph[task_idx][op_type]
            assert collective_node.is_nccl_collective and collective_node.is_ready
            if collective_node != top_priority_node:
                next_nodes.append(collective_node)
        assert len(next_nodes) == len(top_priority_node.collective_idxs)

    return next_nodes
356
+
357
+
358
def _build_dag_node_operation_graph(
    idx_to_task: Dict[int, "ray.dag.compiled_dag_node.CompiledTask"],
    actor_to_operation_nodes: Dict[
        "ray.actor.ActorHandle", List[List[_DAGOperationGraphNode]]
    ],
) -> Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]]:
    """
    Generate a DAG node operation graph by adding edges based on the
    following rules:

    #1 Add edges from READ to COMPUTE, and from COMPUTE to WRITE, which
       belong to the same task.
    #2 Add an edge from COMPUTE with bind_index i to COMPUTE with bind_index
       i+1 if they belong to the same actor.
    #3 Add an edge from WRITE of the writer task to READ of the reader task.

    This is the step one of building an execution schedule for each actor.

    Args:
        idx_to_task: A dictionary that maps the `task_idx` to the `CompiledTask`.
            `CompiledTask` contains information about a DAGNode and its downstream
            nodes.

        actor_to_operation_nodes: A dictionary that maps an actor handle to
            a list of lists of _DAGOperationGraphNode. For the same actor, the
            index of the outer list corresponds to the index of the ExecutableTask
            in the list of `executable_tasks` in `actor_to_executable_tasks`. In
            the inner list, the order of operations is READ, COMPUTE, and WRITE.

    Returns:
        A graph where each node is a _DAGOperationGraphNode. The key is `task_idx`,
        the index to retrieve its task from `idx_to_task`, and the value is a
        dictionary that maps the _DAGNodeOperationType (READ, COMPUTE, or WRITE)
        to the corresponding _DAGOperationGraphNode
    """
    assert idx_to_task
    graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]] = {}

    for _, operation_nodes_list in actor_to_operation_nodes.items():
        prev_compute_node = None
        for operation_nodes in operation_nodes_list:
            task_idx = operation_nodes[0].task_idx
            read_node, compute_node, write_node = (
                operation_nodes[0],
                operation_nodes[1],
                operation_nodes[2],
            )
            # Add edges from READ to COMPUTE, and from COMPUTE to WRITE, which
            # belong to the same task.
            _add_edge(read_node, compute_node)
            _add_edge(compute_node, write_node)
            # Add an edge from COMPUTE with `bind_index` i to COMPUTE with
            # `bind_index` i+1 if they belong to the same actor. This is a
            # control dependency (last arg True), rendered dashed in the viz.
            if prev_compute_node is not None:
                _add_edge(prev_compute_node, compute_node, "", True)
            prev_compute_node = compute_node
            assert task_idx not in graph
            graph[task_idx] = {
                _DAGNodeOperationType.READ: read_node,
                _DAGNodeOperationType.COMPUTE: compute_node,
                _DAGNodeOperationType.WRITE: write_node,
            }

    # Import `ray.dag` here to avoid circular import.
    from ray.dag import ClassMethodNode, CollectiveOutputNode, MultiOutputNode

    # Add an edge from WRITE of the writer task to READ of the reader task.
    # The edge label records the channel type: "nccl" or "shm" (shared memory).
    for task_idx, task in idx_to_task.items():
        if not (
            isinstance(task.dag_node, ClassMethodNode)
            or isinstance(task.dag_node, CollectiveOutputNode)
        ):
            # The graph is used to generate an execution schedule for each actor.
            # The edge from the InputNode has no impact on the final execution
            # schedule.
            continue
        if (
            isinstance(task.dag_node, ClassMethodNode)
            and task.dag_node.is_class_method_output
        ):
            # Class method output node dependencies are handled at its upstream:
            # i.e., class method node
            continue
        for downstream_task_idx in task.downstream_task_idxs:
            downstream_dag_node = idx_to_task[downstream_task_idx].dag_node
            if isinstance(downstream_dag_node, MultiOutputNode):
                continue
            if (
                isinstance(downstream_dag_node, ClassMethodNode)
                and downstream_dag_node.is_class_method_output
            ):
                # Skip over the output node and connect directly to its
                # consumers (only those present in this compiled DAG).
                consumer_idxs = idx_to_task[downstream_task_idx].downstream_task_idxs
                for consumer_idx in consumer_idxs:
                    if consumer_idx in graph:
                        _add_edge(
                            graph[task_idx][_DAGNodeOperationType.WRITE],
                            graph[consumer_idx][_DAGNodeOperationType.READ],
                            "nccl"
                            if graph[task_idx][
                                _DAGNodeOperationType.WRITE
                            ].requires_nccl
                            else "shm",
                        )
                continue
            _add_edge(
                graph[task_idx][_DAGNodeOperationType.WRITE],
                graph[downstream_task_idx][_DAGNodeOperationType.READ],
                "nccl"
                if graph[task_idx][_DAGNodeOperationType.WRITE].requires_nccl
                else "shm",
            )

    return graph
471
+
472
+
473
+ def _actor_viz_label(actor: "ray.actor.ActorHandle"):
474
+ """
475
+ Returns the label of an actor in the visualization of the execution schedule.
476
+
477
+ Args:
478
+ actor: The actor to be represented.
479
+ """
480
+ class_name = actor._ray_actor_creation_function_descriptor.class_name
481
+ actor_id = actor._ray_actor_id.hex()
482
+ return f"Actor class name: {class_name}\nActor ID: {actor_id}"
483
+
484
+
485
def _node_viz_id_and_label(
    node: _DAGOperationGraphNode, idx: int, optimized_index: int
):
    """
    Build the visualization id and label for a node. The id is unique across
    all nodes because it embeds the owning actor's id.

    Args:
        node: The node to be represented.
        idx: The index of the node in the execution schedule.
        optimized_index: The index of the node in the optimized execution schedule.
    """
    label = f"{node.viz_str()} {idx},{optimized_index}"
    viz_id = f"{node._actor_id}_{label}"
    return viz_id, label
500
+
501
+
502
def _visualize_execution_schedule(
    actor_to_execution_schedule: Dict[
        "ray.actor.ActorHandle", List[_DAGOperationGraphNode]
    ],
    actor_to_overlapped_schedule: Optional[
        Dict["ray.actor.ActorHandle", List[_DAGOperationGraphNode]]
    ],
    graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]],
):
    """
    Visualize the execution schedule for each actor.

    The visualization will be saved as a PNG file named `compiled_graph_schedule.png`.
    Details of the visualization: # noqa

    Node description format:
    [<task_index>] <method_name> <operation> <orig_index>, <overlap_index>

    Node description fields:
    operation: is R(READ), C(COMPUTE), or W(WRITE)
    orig_index: the index in the original execution schedule
    overlap_index: the index in the overlap-communication optimized execution schedule
    If this is different from orig_index, the node is highlighted in red color

    Node grouping:
    The nodes belonging to the same actor are grouped in the same rectangle
    The actor class name and the actor id are shown in the rectangle

    Edges:
    black color (without label): data dependency
    black color (annotated with "shm"): shared memory channel
    blue color (annotated with "nccl): NCCL channel
    dashed edge: control dependency between compute operations

    Args:
        actor_to_execution_schedule: A dictionary that maps an actor handle to
            the execution schedule which is a list of operation nodes.
        actor_to_overlapped_schedule: A dictionary that maps an actor handle to the
            optimized execution schedule which is a list of operation nodes.
        graph: A graph where each node is a _DAGOperationGraphNode. The key is
            `task_idx`, the index to retrieve its task from `idx_to_task`, and
            the value is a dictionary that maps the _DAGNodeOperationType (READ,
            COMPUTE, or WRITE) to the corresponding _DAGOperationGraphNode. It is
            generated by `_build_dag_node_operation_graph`.
    """
    # graphviz is an optional dependency; fail with an actionable message.
    try:
        import graphviz
    except ImportError:
        raise ImportError(
            "Please install graphviz to visualize the execution schedule. "
            "You can install it by running `pip install graphviz`."
        )

    dot = graphviz.Digraph(comment="DAG")
    # A dictionary that maps a node to its visualization id
    node_to_viz_id: Dict[_DAGOperationGraphNode, str] = {}

    if actor_to_overlapped_schedule is None:
        # TODO(rui): make the visualization more concise by only displaying
        # the original schedule
        actor_to_overlapped_schedule = actor_to_execution_schedule
    # First pass: emit one cluster per actor containing its scheduled nodes.
    for actor, execution_nodes in actor_to_execution_schedule.items():
        overlapped_schedule = actor_to_overlapped_schedule[actor]
        node_to_optimized_index = {
            node: i for i, node in enumerate(overlapped_schedule)
        }

        actor_id = actor._ray_actor_id.hex()
        with dot.subgraph(name=f"cluster_{actor_id}") as subgraph:
            subgraph.attr(rank=actor_id, label=_actor_viz_label(actor))
            for i, node in enumerate(execution_nodes):
                optimized_index = node_to_optimized_index.get(node)
                node_viz_id, node_viz_label = _node_viz_id_and_label(
                    node, i, optimized_index
                )
                # Highlight nodes whose position changed after overlapping.
                color = "red" if optimized_index != i else "black"
                subgraph.node(node_viz_id, node_viz_label, color=color)
                node_to_viz_id[node] = node_viz_id

    # Second pass: emit edges, now that every node has a viz id.
    for actor, execution_nodes in actor_to_execution_schedule.items():
        for i, node in enumerate(execution_nodes):
            node_viz_id = node_to_viz_id[node]
            for out_edge, viz_info in node.out_edges.items():
                label, control_dependency = viz_info
                out_task_idx, out_op_type = out_edge
                out_node = graph[out_task_idx][out_op_type]
                out_node_viz_id = node_to_viz_id[out_node]
                color = "blue" if label == "nccl" else "black"
                style = "dashed" if control_dependency else "solid"
                dot.edge(
                    node_viz_id, out_node_viz_id, label=label, color=color, style=style
                )

    # Add legend
    with dot.subgraph(name="cluster_legend") as legend:
        legend.attr(label="Legend", labelloc="t", fontsize="20", bgcolor="lightgrey")

        # Single node and its explanation
        legend.node("example_node", "[0] bwd C 10,10\n")
        explanation = (
            '<<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0">'  # noqa
            '<TR><TD ALIGN="LEFT"><B>Node description format:</B></TD></TR>'
            '<TR><TD ALIGN="LEFT">[&lt;task_index&gt;] &lt;method_name&gt; &lt;operation&gt; &lt;orig_index&gt;, &lt;overlap_index&gt;</TD></TR>'  # noqa
            "<TR><TD></TD></TR>"
            '<TR><TD ALIGN="LEFT"><B>Node description fields:</B></TD></TR>'
            '<TR><TD ALIGN="LEFT">operation: is R(READ), C(COMPUTE), or W(WRITE)</TD></TR>'  # noqa
            '<TR><TD ALIGN="LEFT">orig_index: the index in the original execution schedule</TD></TR>'  # noqa
            '<TR><TD ALIGN="LEFT">overlap_index: the index in the overlap-communication optimized execution schedule</TD></TR>'  # noqa
            '<TR><TD ALIGN="LEFT">If this is different from orig_index, the node is highlighted in <FONT COLOR="red">red color</FONT></TD></TR>'  # noqa
            "<TR><TD></TD></TR>"
            '<TR><TD ALIGN="LEFT"><B>Node grouping:</B></TD></TR>'
            '<TR><TD ALIGN="LEFT">The nodes belonging to the same actor are grouped in the same rectangle</TD></TR>'  # noqa
            '<TR><TD ALIGN="LEFT">The actor class name and the actor id are shown in the rectangle</TD></TR>'  # noqa
            "<TR><TD></TD></TR>"
            '<TR><TD ALIGN="LEFT"><B>Edges:</B></TD></TR>'
            '<TR><TD ALIGN="LEFT">black color (without label): data dependency</TD></TR>'  # noqa
            '<TR><TD ALIGN="LEFT">black color (annotated with "shm"): shared memory channel</TD></TR>'  # noqa
            '<TR><TD ALIGN="LEFT"><FONT COLOR="blue">blue color</FONT> (annotated with "nccl): NCCL channel</TD></TR>'  # noqa
            '<TR><TD ALIGN="LEFT">dashed edge: control dependency between compute operations</TD></TR>'  # noqa
            "</TABLE>>"
        )

        legend.node("example_explanation", explanation, shape="plaintext")
        legend.edge("example_node", "example_explanation", style="invis")

    logger.info(
        "Writing compiled graph schedule visualization "
        "to compiled_graph_schedule.png"
    )
    dot.render("compiled_graph_schedule", format="png", view=False)
632
+
633
+
634
def _generate_actor_to_execution_schedule(
    graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]]
) -> Dict["ray.actor.ActorHandle", List[_DAGOperationGraphNode]]:
    """
    Generate an execution schedule for each actor. The schedule is a list of
    operation nodes to be executed. The function uses a topological sort
    algorithm to generate the schedule.

    NOTE: this consumes the graph's in-edges (they are popped as dependencies
    are satisfied), so the graph is mutated in place.

    Args:
        graph: A graph where each node is a _DAGOperationGraphNode. The key is
            `task_idx`, the index to retrieve its task from `idx_to_task`, and
            the value is a dictionary that maps the _DAGNodeOperationType (READ,
            COMPUTE, or WRITE) to the corresponding _DAGOperationGraphNode. It is
            generated by `_build_dag_node_operation_graph`.

    Returns:
        actor_to_execution_schedule: A dictionary that maps an actor handle to
            the execution schedule which is a list of operation nodes to be
            executed.
    """

    # Mapping from the actor handle to the execution schedule which is a list
    # of operations to be executed.
    actor_to_execution_schedule: Dict[
        "ray.actor.ActorHandle", List[_DAGOperationGraphNode]
    ] = defaultdict(list)

    # A dictionary mapping an actor id to a list of candidate nodes. The list
    # is maintained as a priority queue, so the head of the queue, i.e.,
    # `candidates[0]`, is the node with the smallest `bind_index`.
    actor_to_candidates: Dict[
        "ray._raylet.ActorID", List[_DAGOperationGraphNode]
    ] = defaultdict(list)
    for _, node_dict in graph.items():
        for _, node in node_dict.items():
            # A node with a zero in-degree edge means all of its dependencies
            # have been satisfied, including both data and control dependencies.
            # Therefore, it is a candidate for execution.
            if node.in_degree == 0:
                _push_candidate_node_if_ready(actor_to_candidates, graph, node)

    visited_nodes = set()

    # Use topological sort algorithm to generate the execution schedule.
    while True:
        # Select a list of nodes to be executed. There are three cases:
        # 1. If a selected node is not a NCCL operation, only itself is returned.
        # 2. If a selected node is a NCCL write operation, the corresponding NCCL
        #    read operations are also returned.
        # 3. If a selected node is a NCCL collective operation, all the nodes in
        #    its collective operation are returned.
        # In cases 1 and 3, all the selected nodes are ready. In case 2, the NCCL
        # write node is ready, while the NCCL read nodes are not ready until their
        # in-degrees are updated.
        nodes = _select_next_nodes(actor_to_candidates, graph)
        if nodes is None:
            break
        # Filter out the visited nodes.
        nodes = [node for node in nodes if node not in visited_nodes]
        # Add the selected nodes to the execution schedule.
        for node in nodes:
            actor_to_execution_schedule[node.actor_handle].append(node)
            visited_nodes.add(node)
        # Update the in-degree of the downstream nodes.
        for node in nodes:
            for out_node_task_idx, out_node_type in node.out_edges:
                out_node = graph[out_node_task_idx][out_node_type]
                out_node.in_edges.pop((node.task_idx, node.operation.type))
                if out_node.in_degree == 0 and out_node not in visited_nodes:
                    # If the downstream node is already visited, it has been added
                    # to the execution schedule. They are the NCCL read nodes in
                    # case 2.
                    _push_candidate_node_if_ready(actor_to_candidates, graph, out_node)
    # Sanity checks: every (READ, COMPUTE, WRITE) triple of every task was
    # scheduled exactly once and no candidate was left behind.
    assert len(visited_nodes) == len(graph) * 3, "Expected all nodes to be visited"
    for node in visited_nodes:
        assert node.is_ready, f"Expected {node} to be ready"
    for _, candidates in actor_to_candidates.items():
        assert len(candidates) == 0, "Expected all candidates to be empty"

    return actor_to_execution_schedule
714
+
715
+
716
def _generate_overlapped_execution_schedule(
    actor_to_execution_schedule: Dict[
        "ray.actor.ActorHandle", List[_DAGOperationGraphNode]
    ],
) -> Dict["ray.actor.ActorHandle", List[_DAGOperationGraphNode]]:
    """
    From an existing execution schedule, generate a new schedule by overlapping
    computation and communication.

    Currently, the algorithm generates a new schedule for each actor as follows:
    For each NCCL read operation (i.e., recv), scan backwards to find the nearest
    compute node to swap with so that the NCCL read operation can be overlapped
    with computation.

    Collective operations are not yet supported.

    Args:
        actor_to_execution_schedule: A dictionary that maps an actor handle to
            the existing execution schedule for the actor. The schedule is a
            list of operations to be executed. The input is not modified; a
            deep copy is rearranged and returned.

    Returns:
        A dictionary that maps an actor handle to the overlapped execution schedule
        for the actor.
    """

    actor_to_overlapped_schedule: Dict[
        "ray.actor.ActorHandle", List[_DAGOperationGraphNode]
    ] = copy.deepcopy(actor_to_execution_schedule)
    for overlapped_schedule in actor_to_overlapped_schedule.values():
        for i in range(len(overlapped_schedule)):
            if (
                overlapped_schedule[i].operation.type == _DAGNodeOperationType.READ
                and overlapped_schedule[i].requires_nccl
            ):
                # For each NCCL read operation (i.e., recv), scan backwards
                # to find the nearest compute node to swap with so that
                # the NCCL read operation can be overlapped with computation.
                for j in range(i - 1, -1, -1):
                    if (
                        overlapped_schedule[j].operation.type
                        == _DAGNodeOperationType.COMPUTE
                    ):
                        # Found a desired compute operation, make the swap:
                        # the read moves to position j, and everything in
                        # [j, i) shifts one slot to the right.
                        nccl_read_op = overlapped_schedule[i]
                        prev_ops = overlapped_schedule[j:i]
                        overlapped_schedule[j + 1 : i + 1] = prev_ops
                        overlapped_schedule[j] = nccl_read_op
                        break
                    if (
                        overlapped_schedule[j].operation.type
                        == _DAGNodeOperationType.READ
                        or overlapped_schedule[j].operation.type
                        == _DAGNodeOperationType.WRITE
                    ) and overlapped_schedule[j].requires_nccl:
                        # Found a NCCL read/write operation, skip the overlap
                        # optimization to keep relative order of NCCL operations
                        break
    return actor_to_overlapped_schedule
775
+
776
+
777
def _extract_execution_schedule(
    actor_to_execution_schedule: Dict[
        "ray.actor.ActorHandle", List[_DAGOperationGraphNode]
    ]
) -> Dict["ray.actor.ActorHandle", List[_DAGNodeOperation]]:
    """
    Reduce each actor's schedule of _DAGOperationGraphNode entries to the bare
    _DAGNodeOperation objects, dropping the graph bookkeeping.
    """
    schedules: Dict["ray.actor.ActorHandle", List[_DAGNodeOperation]] = {}
    for actor, graph_nodes in actor_to_execution_schedule.items():
        schedules[actor] = [graph_node.operation for graph_node in graph_nodes]
    return schedules
.venv/lib/python3.11/site-packages/ray/dag/dag_operation_future.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
3
+ from ray.util.annotations import DeveloperAPI
4
+
5
+
6
+ if TYPE_CHECKING:
7
+ import cupy as cp
8
+
9
+ T = TypeVar("T")
10
+
11
+
12
@DeveloperAPI
class DAGOperationFuture(ABC, Generic[T]):
    """
    Abstract result handle for a DAG operation.

    Instances are internal to each actor and are never exposed to the
    caller of the DAG.
    """

    @abstractmethod
    def wait(self):
        """
        Block until the operation has completed and return its result.
        """
        raise NotImplementedError
27
+
28
+
29
@DeveloperAPI
class ResolvedFuture(DAGOperationFuture):
    """
    A future whose value is already known at construction time.

    `wait()` never blocks; it simply hands back the stored result.
    """

    def __init__(self, result):
        """
        Initialize a resolved future.

        Args:
            result: The value that `wait()` will return.
        """
        # Stored as-is; no copy is made.
        self._result = result

    def wait(self):
        """
        Return the stored result immediately without blocking.
        """
        return self._result
50
+
51
+
52
@DeveloperAPI
class GPUFuture(DAGOperationFuture[Any]):
    """
    A future for a GPU event on a CUDA stream.

    This future wraps a buffer, and records an event on the given stream
    when it is created. When the future is waited on, it makes the current
    CUDA stream wait on the event, then returns the buffer.

    The buffer must be a GPU tensor produced by an earlier operation launched
    on the given stream, or it could be CPU data. Then the future guarantees
    that when the wait() returns, the buffer is ready on the current stream.

    The `wait()` does not block CPU.
    """

    def __init__(self, buf: Any, stream: Optional["cp.cuda.Stream"] = None):
        """
        Initialize a GPU future on the given stream.

        Args:
            buf: The buffer to return when the future is resolved.
            stream: The CUDA stream to record the event on, this event is waited
                on when the future is resolved. If None, the current stream is used.
        """
        # cupy is imported lazily so the module can be loaded without a GPU.
        import cupy as cp

        if stream is None:
            stream = cp.cuda.get_current_stream()

        self._buf = buf
        # Record an event on the producing stream now; wait() later makes the
        # consuming stream wait on this event rather than blocking the CPU.
        self._event = cp.cuda.Event()
        self._event.record(stream)

    def wait(self) -> Any:
        """
        Wait for the future on the current CUDA stream and return the result from
        the GPU operation. This operation does not block CPU.
        """
        import cupy as cp

        current_stream = cp.cuda.get_current_stream()
        current_stream.wait_event(self._event)
        return self._buf
.venv/lib/python3.11/site-packages/ray/dag/experimental/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/dag/experimental/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (193 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/dag/format_utils.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.dag import DAGNode
2
+ from ray.util.annotations import DeveloperAPI
3
+
4
+
5
@DeveloperAPI
def get_dag_node_str(
    dag_node: DAGNode,
    body_line,
):
    """Render a multi-line, human-readable representation of ``dag_node``.

    Args:
        dag_node: The node whose bound args/kwargs/options are printed.
        body_line: Pre-formatted description of the node body (caller-chosen).
    """
    pad = _get_indentation()
    resolve_lines = _get_other_args_to_resolve_lines(
        dag_node._bound_other_args_to_resolve
    )
    parts = [
        f"({dag_node.__class__.__name__}, {dag_node._stable_uuid})(",
        f"{pad}body={body_line}",
        f"{pad}args={_get_args_lines(dag_node._bound_args)}",
        f"{pad}kwargs={_get_kwargs_lines(dag_node._bound_kwargs)}",
        f"{pad}options={_get_options_lines(dag_node._bound_options)}",
        f"{pad}other_args_to_resolve={resolve_lines}",
        ")",
    ]
    return "\n".join(parts)
23
+
24
+
25
+ def _get_indentation(num_spaces=4):
26
+ return " " * num_spaces
27
+
28
+
29
def _get_args_lines(bound_args):
    """Pretty-print the bound positional args of a DAGNode.

    DAGNode instances — whether top-level or inside list/dict containers —
    are rendered via their own multi-line ``str()``; anything else is shown
    with ``str()`` directly, followed by a trailing comma.
    """
    pad = _get_indentation()
    rendered = []

    def _append_repr(obj):
        # Indent every line of the (possibly multi-line) representation.
        for repr_line in str(obj).split("\n"):
            rendered.append(pad + repr_line)

    for arg in bound_args:
        if isinstance(arg, DAGNode):
            _append_repr(arg)
        elif isinstance(arg, list):
            for element in arg:
                _append_repr(element)
        elif isinstance(arg, dict):
            for value in arg.values():
                _append_repr(value)
        # TODO: (jiaodong) Handle nested containers and other obj types
        else:
            rendered.append(pad + str(arg) + ", ")

    if not rendered:
        return "[]"
    body = "".join(f"\n{pad}{line}" for line in rendered)
    return "[" + body + f"\n{pad}]"
63
+
64
+
65
def _get_kwargs_lines(bound_kwargs):
    """Pretty-print the bound keyword args of a DAGNode.

    DAGNode values — whether top-level or inside list/dict containers —
    are rendered via their own multi-line ``str()``.
    """
    # TODO: (jiaodong) Nits, we're missing keys and indentation was a bit off.
    if not bound_kwargs:
        return "{}"
    pad = _get_indentation()
    rendered = []
    for key, val in bound_kwargs.items():
        if isinstance(val, DAGNode):
            for index, repr_line in enumerate(str(val).split("\n")):
                if index == 0:
                    # The first line carries the kwarg key.
                    rendered.append(f"{pad}{key}:{pad}" + repr_line)
                else:
                    rendered.append(pad * 2 + repr_line)
        elif isinstance(val, list):
            for element in val:
                for repr_line in str(element).split("\n"):
                    rendered.append(pad + repr_line)
        elif isinstance(val, dict):
            for inner_val in val.values():
                for repr_line in str(inner_val).split("\n"):
                    rendered.append(pad + repr_line)
        # TODO: (jiaodong) Handle nested containers and other obj types
        else:
            # Appended raw; stringified by the f-string below.
            rendered.append(val)

    if not rendered:
        return "{}"
    body = "".join(f"\n{pad}{line}" for line in rendered)
    return "{" + body + f"\n{pad}}}"
108
+
109
+
110
def _get_options_lines(bound_options):
    """Pretty prints .options() in DAGNode. Only prints non-empty values."""
    if not bound_options:
        return "{}"
    pad = _get_indentation()
    # Keep only options whose values are truthy.
    entries = [f"{pad}{key}: {val}" for key, val in bound_options.items() if val]
    body = "".join(f"\n{pad}{entry}" for entry in entries)
    return "{" + body + f"\n{pad}}}"
125
+
126
+
127
def _get_other_args_to_resolve_lines(other_args_to_resolve):
    """Pretty-print a DAGNode's ``other_args_to_resolve`` mapping.

    DAGNode values get a key line followed by their indented multi-line
    ``str()``; plain values are shown inline as ``key: value``.
    """
    if not other_args_to_resolve:
        return "{}"
    pad = _get_indentation()
    rendered = []
    for key, val in other_args_to_resolve.items():
        if isinstance(val, DAGNode):
            for index, repr_line in enumerate(str(val).split("\n")):
                if index == 0:
                    # Key on its own line; node repr starts further indented.
                    rendered.append(f"{pad}{key}:{pad}\n{pad * 3}" + repr_line)
                else:
                    rendered.append(pad * 2 + repr_line)
        else:
            rendered.append(f"{pad}{key}: {val}")

    body = "".join(f"\n{pad}{line}" for line in rendered)
    return "{" + body + f"\n{pad}}}"
.venv/lib/python3.11/site-packages/ray/dag/function_node.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+
3
+
4
+ import ray
5
+ from ray.dag.dag_node import DAGNode
6
+ from ray.dag.format_utils import get_dag_node_str
7
+ from ray.util.annotations import DeveloperAPI
8
+
9
+
10
@DeveloperAPI
class FunctionNode(DAGNode):
    """Represents a bound task node in a Ray task DAG."""

    def __init__(
        self,
        func_body,
        func_args,
        func_kwargs,
        func_options,
        other_args_to_resolve=None,
    ):
        # The plain Python function wrapped by this node.
        self._body = func_body
        super().__init__(
            func_args,
            func_kwargs,
            func_options,
            other_args_to_resolve=other_args_to_resolve,
        )

    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ):
        """Return a new FunctionNode over the same body, rebound to new values."""
        return FunctionNode(
            self._body,
            new_args,
            new_kwargs,
            new_options,
            other_args_to_resolve=new_other_args_to_resolve,
        )

    def _execute_impl(self, *args, **kwargs):
        """Executor of FunctionNode by ray.remote().

        Args and kwargs are to match base class signature, but not in the
        implementation. All args and kwargs should be resolved and replaced
        with value in bound_args and bound_kwargs via bottom-up recursion when
        current node is executed.
        """
        remote_func = ray.remote(self._body).options(**self._bound_options)
        return remote_func.remote(*self._bound_args, **self._bound_kwargs)

    def __str__(self) -> str:
        return get_dag_node_str(self, str(self._body))
.venv/lib/python3.11/site-packages/ray/dag/input_node.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Union, Optional
2
+
3
+ from ray.dag import DAGNode
4
+ from ray.dag.format_utils import get_dag_node_str
5
+ from ray.experimental.gradio_utils import type_to_string
6
+ from ray.util.annotations import DeveloperAPI
7
+
8
+ IN_CONTEXT_MANAGER = "__in_context_manager__"
9
+
10
+
11
@DeveloperAPI
class InputNode(DAGNode):
    r"""Ray dag node used in DAG building API to mark entrypoints of a DAG.

    Should only be function or class method. A DAG can have multiple
    entrypoints, but only one instance of InputNode exists per DAG, shared
    among all DAGNodes.

    Example:

    .. code-block::

                m1.forward
               /          \
        dag_input         ensemble -> dag_output
               \          /
                m2.forward

    In this pipeline, each user input is broadcasted to both m1.forward and
    m2.forward as first stop of the DAG, and authored like

    .. code-block:: python

        import ray

        @ray.remote
        class Model:
            def __init__(self, val):
                self.val = val
            def forward(self, input):
                return self.val * input

        @ray.remote
        def combine(a, b):
            return a + b

        with InputNode() as dag_input:
            m1 = Model.bind(1)
            m2 = Model.bind(2)
            m1_output = m1.forward.bind(dag_input[0])
            m2_output = m2.forward.bind(dag_input.x)
            ray_dag = combine.bind(m1_output, m2_output)

        # Pass mix of args and kwargs as input.
        ray_dag.execute(1, x=2) # 1 sent to m1, 2 sent to m2

        # Alternatively user can also pass single data object, list or dict
        # and access them via list index, object attribute or dict key str.
        ray_dag.execute(UserDataObject(m1=1, m2=2))
        # dag_input.m1, dag_input.m2
        ray_dag.execute([1, 2])
        # dag_input[0], dag_input[1]
        ray_dag.execute({"m1": 1, "m2": 2})
        # dag_input["m1"], dag_input["m2"]
    """

    def __init__(
        self,
        *args,
        input_type: Optional[Union[type, Dict[Union[int, str], type]]] = None,
        _other_args_to_resolve=None,
        **kwargs,
    ):
        """InputNode should only take attributes of validating and converting
        input data rather than the input data itself. User input should be
        provided via `ray_dag.execute(user_input)`.

        Args:
            input_type: Describes the data type of inputs user will be giving.
                - if given through singular InputNode: type of InputNode
                - if given through InputAttributeNodes: map of key -> type
                Used when deciding what Gradio block to represent the input
                nodes with.
            _other_args_to_resolve: Internal only to keep InputNode's execution
                context throughout pickling, replacement and serialization.
                User should not use or pass this field.
        """
        if args or kwargs:
            raise ValueError("InputNode should not take any args or kwargs.")

        # Lazily-created InputAttributeNode children, keyed by attribute
        # name or index (see __getattr__/__getitem__).
        self.input_attribute_nodes = {}

        self.input_type = input_type
        if input_type is not None and isinstance(input_type, type):
            if _other_args_to_resolve is None:
                _other_args_to_resolve = {}
            _other_args_to_resolve["result_type_string"] = type_to_string(input_type)

        super().__init__([], {}, {}, other_args_to_resolve=_other_args_to_resolve)

    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ):
        # Only the resolved-context dict is carried over; InputNode binds
        # no args/kwargs/options of its own.
        return InputNode(_other_args_to_resolve=new_other_args_to_resolve)

    def _execute_impl(self, *args, **kwargs):
        """Executor of InputNode."""
        # Catch and assert singleton context at dag execution time.
        assert self._in_context_manager(), (
            "InputNode is a singleton instance that should be only used in "
            "context manager for dag building and execution. See the docstring "
            "of class InputNode for examples."
        )
        # If user only passed in one value, for simplicity we just return it.
        if len(args) == 1 and not kwargs:
            return args[0]

        return DAGInputData(*args, **kwargs)

    def _in_context_manager(self) -> bool:
        """Return if InputNode is created in context manager."""
        ctx = self._bound_other_args_to_resolve
        if not ctx or IN_CONTEXT_MANAGER not in ctx:
            return False
        return ctx[IN_CONTEXT_MANAGER]

    def set_context(self, key: str, val: Any):
        """Set field in parent DAGNode attribute that can be resolved in both
        pickle and JSON serialization
        """
        self._bound_other_args_to_resolve[key] = val

    def __str__(self) -> str:
        return get_dag_node_str(self, "__InputNode__")

    def __getattr__(self, key: str):
        assert isinstance(
            key, str
        ), "Please only access dag input attributes with str key."
        # Create the attribute-access child once and cache it.
        if key not in self.input_attribute_nodes:
            self.input_attribute_nodes[key] = InputAttributeNode(
                self, key, "__getattr__"
            )
        return self.input_attribute_nodes[key]

    def __getitem__(self, key: Union[int, str]) -> Any:
        assert isinstance(key, (str, int)), (
            "Please only use int index or str as first-level key to "
            "access fields of dag input."
        )

        # Propagate the declared per-key type, if one was provided.
        input_type = None
        if self.input_type is not None and key in self.input_type:
            input_type = type_to_string(self.input_type[key])

        # Create the index-access child once and cache it.
        if key not in self.input_attribute_nodes:
            self.input_attribute_nodes[key] = InputAttributeNode(
                self, key, "__getitem__", input_type
            )
        return self.input_attribute_nodes[key]

    def __enter__(self):
        self.set_context(IN_CONTEXT_MANAGER, True)
        return self

    def __exit__(self, *args):
        pass

    def get_result_type(self) -> str:
        """Get type of the output of this DAGNode.

        Generated by ray.experimental.gradio_utils.type_to_string().
        Returns None when no type string was recorded.
        """
        return self._bound_other_args_to_resolve.get("result_type_string")
+
183
+
184
@DeveloperAPI
class InputAttributeNode(DAGNode):
    """Represents partial access of user input based on an index (int),
    object attribute or dict key (str).

    Examples:

        .. code-block:: python

            with InputNode() as dag_input:
                a = dag_input[0]
                b = dag_input.x
                ray_dag = add.bind(a, b)

            # This makes a = 1 and b = 2
            ray_dag.execute(1, x=2)

            with InputNode() as dag_input:
                a = dag_input[0]
                b = dag_input[1]
                ray_dag = add.bind(a, b)

            # This makes a = 2 and b = 3
            ray_dag.execute(2, 3)

            # Alternatively, you can input a single object
            # and the inputs are automatically indexed from the object:
            # This makes a = 2 and b = 3
            ray_dag.execute([2, 3])
    """

    def __init__(
        self,
        dag_input_node: InputNode,
        key: Union[int, str],
        accessor_method: str,
        input_type: str = None,
    ):
        self._dag_input_node = dag_input_node
        self._key = key
        # Either "__getitem__" or "__getattr__" (set by InputNode).
        self._accessor_method = accessor_method
        super().__init__(
            [],
            {},
            {},
            {
                "dag_input_node": dag_input_node,
                "key": key,
                "accessor_method": accessor_method,
                # Type of the input tied to this node. Used by
                # gradio_visualize_graph.GraphVisualizer to determine which Gradio
                # component should be used for this node.
                "result_type_string": input_type,
            },
        )

    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ):
        # All state lives in other_args_to_resolve; rebuild from there.
        return InputAttributeNode(
            new_other_args_to_resolve["dag_input_node"],
            new_other_args_to_resolve["key"],
            new_other_args_to_resolve["accessor_method"],
            new_other_args_to_resolve["result_type_string"],
        )

    def _execute_impl(self, *args, **kwargs):
        """Executor of InputAttributeNode.

        Args and kwargs are to match base class signature, but not in the
        implementation. All args and kwargs should be resolved and replaced
        with value in bound_args and bound_kwargs via bottom-up recursion when
        current node is executed.
        """
        if isinstance(self._dag_input_node, DAGInputData):
            return self._dag_input_node[self._key]

        # dag.execute() is called with only one arg, thus when an
        # InputAttributeNode is executed, its dependent InputNode is
        # resolved with original user input python object.
        source = self._dag_input_node
        if isinstance(self._key, str):
            if self._accessor_method == "__getitem__":
                return source[self._key]
            elif self._accessor_method == "__getattr__":
                return getattr(source, self._key)
        elif isinstance(self._key, int):
            return source[self._key]
        else:
            raise ValueError(
                "Please only use int index or str as first-level key to "
                "access fields of dag input."
            )

    def __str__(self) -> str:
        return get_dag_node_str(self, f'["{self._key}"]')

    def get_result_type(self) -> str:
        """Get type of the output of this DAGNode.

        Generated by ray.experimental.gradio_utils.type_to_string().
        Returns None when no type string was recorded.
        """
        return self._bound_other_args_to_resolve.get("result_type_string")

    @property
    def key(self) -> Union[int, str]:
        return self._key
297
+
298
+
299
@DeveloperAPI
class DAGInputData:
    """If user passed multiple args and kwargs directly to dag.execute(), we
    generate this wrapper for all user inputs as one object, accessible via
    list index or object attribute key.
    """

    def __init__(self, *args, **kwargs):
        self._args = list(args)
        self._kwargs = kwargs

    def __getitem__(self, key: Union[int, str]) -> Any:
        if isinstance(key, int):
            # Positional arg by index.
            return self._args[key]
        if isinstance(key, str):
            # Keyword arg by name.
            return self._kwargs[key]
        raise ValueError(
            "Please only use int index or str as first-level key to "
            "access fields of dag input."
        )
.venv/lib/python3.11/site-packages/ray/dag/output_node.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ray
2
+ from typing import Any, Dict, List, Union, Tuple
3
+
4
+ from ray.dag import DAGNode
5
+ from ray.dag.format_utils import get_dag_node_str
6
+ from ray.util.annotations import DeveloperAPI
7
+
8
+
9
@DeveloperAPI
class MultiOutputNode(DAGNode):
    """Ray dag node used in DAG building API to mark the endpoint of DAG"""

    def __init__(
        self,
        args: Union[List[DAGNode], Tuple[DAGNode]],
        other_args_to_resolve: Dict[str, Any] = None,
    ):
        # Accept tuples for convenience; everything else must be a list.
        if isinstance(args, tuple):
            args = list(args)
        if not isinstance(args, list):
            raise ValueError(f"Invalid input type for `args`, {type(args)}.")
        super().__init__(
            args,
            {},
            {},
            other_args_to_resolve=other_args_to_resolve or {},
        )

    def _execute_impl(
        self, *args, **kwargs
    ) -> Union[ray.ObjectRef, "ray.actor.ActorHandle"]:
        # The endpoint simply forwards its (already resolved) bound args.
        return self._bound_args

    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ) -> "DAGNode":
        """Return a copy of this node with the given new args."""
        return MultiOutputNode(new_args, new_other_args_to_resolve)

    def __str__(self) -> str:
        return get_dag_node_str(self, "__MultiOutputNode__")
.venv/lib/python3.11/site-packages/ray/dag/py_obj_scanner.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ from typing import Any, Dict, Generic, List, Tuple, Type, TypeVar, Union
3
+
4
+ import pickle # noqa: F401
5
+
6
+ import ray
7
+ from ray.dag.base import DAGNodeBase
8
+
9
+
10
+ # Used in deserialization hooks to reference scanner instances.
11
+ _instances: Dict[int, "_PyObjScanner"] = {}
12
+
13
+ # Generic types for the scanner to transform from and to.
14
+ SourceType = TypeVar("SourceType")
15
+ TransformedType = TypeVar("TransformedType")
16
+
17
+
18
def _get_node(instance_id: int, node_index: int) -> SourceType:
    """Resolve scanner ``instance_id`` and return its replacement node at
    ``node_index``.

    Note: This function should be static and globally importable,
    otherwise the serialization overhead would be very significant.
    """
    scanner = _instances[instance_id]
    return scanner._replace_index(node_index)
25
+
26
+
27
class _PyObjScanner(ray.cloudpickle.CloudPickler, Generic[SourceType, TransformedType]):
    """Utility to find and replace the `source_type` in Python objects.

    `source_type` can either be a single type or a tuple of multiple types.

    The caller must first call `find_nodes()`, then compute a replacement
    table and pass it to `replace_nodes`.

    This uses cloudpickle under the hood, so all sub-objects that are not
    `source_type` must be serializable.

    Args:
        source_type: the type(s) of object to find and replace. Default to
            DAGNodeBase.
    """

    def __init__(self, source_type: Union[Type, Tuple] = DAGNodeBase):
        self.source_type = source_type
        # Buffer to keep intermediate serialized state.
        self._buf = io.BytesIO()
        # Top-level SourceType instances found during the serialization pass.
        self._found = None
        # References to other objects found during the serialization pass,
        # held so they won't be serialized by cloudpickle.
        self._objects = []
        # Replacement table consulted during deserialization.
        self._replace_table: Dict[SourceType, TransformedType] = None
        # Register globally so _get_node() can find us during load.
        _instances[id(self)] = self
        super().__init__(self._buf)

    def reducer_override(self, obj):
        """Hook for reducing objects.

        Objects of `self.source_type` are saved to `self._found` and a global
        map so they can later be replaced. All other objects fall back to the
        default `CloudPickler` serialization.
        """
        if isinstance(obj, self.source_type):
            node_index = len(self._found)
            self._found.append(obj)
            # On load, pickle will call _get_node(id(self), node_index).
            return _get_node, (id(self), node_index)
        return super().reducer_override(obj)

    def find_nodes(self, obj: Any) -> List[SourceType]:
        """
        Serialize `obj` and store all instances of `source_type` found in `_found`.

        Args:
            obj: The object to scan for `source_type`.
        Returns:
            A list of all instances of `source_type` found in `obj`.
        """
        assert (
            self._found is None
        ), "find_nodes cannot be called twice on the same PyObjScanner instance."
        self._found = []
        self._objects = []
        self.dump(obj)
        return self._found

    def replace_nodes(self, table: Dict[SourceType, TransformedType]) -> Any:
        """Replace previously found DAGNodes per the given table."""
        assert self._found is not None, "find_nodes must be called first"
        self._replace_table = table
        self._buf.seek(0)
        return pickle.load(self._buf)

    def _replace_index(self, i: int) -> SourceType:
        # Map the i-th node found during dump to its replacement.
        return self._replace_table[self._found[i]]

    def clear(self):
        """Clear the scanner from the _instances"""
        _instances.pop(id(self), None)

    def __del__(self):
        self.clear()
.venv/lib/python3.11/site-packages/ray/dag/utils.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+
3
+ from ray.dag import (
4
+ DAGNode,
5
+ InputNode,
6
+ InputAttributeNode,
7
+ FunctionNode,
8
+ ClassNode,
9
+ ClassMethodNode,
10
+ MultiOutputNode,
11
+ )
12
+
13
+
14
class _DAGNodeNameGenerator(object):
    """
    Generate unique suffix for each given Node in the DAG.
    Apply monotonic increasing id suffix for duplicated names.
    """

    def __init__(self):
        # base name -> highest suffix handed out so far.
        self.name_to_suffix: Dict[str, int] = dict()

    def get_node_name(self, node: DAGNode):
        """Return a display name for ``node``, unique within this generator."""
        # Fixed, suffix-free names for singleton-style nodes.
        if isinstance(node, InputNode):
            return "INPUT_NODE"
        if isinstance(node, MultiOutputNode):
            return "MultiOutputNode"
        # InputAttributeNode suffixes should match the user-defined key.
        if isinstance(node, InputAttributeNode):
            return f"INPUT_ATTRIBUTE_NODE_{node._key}"

        # As class, method, and function nodes may have duplicated names,
        # generate unique suffixes for such nodes.
        if isinstance(node, ClassMethodNode):
            base_name = node.get_options().get("name", None) or node._method_name
        elif isinstance(node, (ClassNode, FunctionNode)):
            base_name = node.get_options().get("name", None) or node._body.__name__
        # we use instance class name check here to avoid importing ServeNodes as
        # serve components are not included in Ray Core.
        elif type(node).__name__ in ("DeploymentNode", "DeploymentFunctionNode"):
            base_name = node.get_deployment_name()
        elif type(node).__name__ == "DeploymentFunctionExecutorNode":
            base_name = node._deployment_function_handle.deployment_name
        else:
            raise ValueError(
                "get_node_name() should only be called on DAGNode instances."
            )

        if base_name not in self.name_to_suffix:
            # First occurrence keeps the bare name.
            self.name_to_suffix[base_name] = 0
            return base_name
        self.name_to_suffix[base_name] += 1
        return f"{base_name}_{self.name_to_suffix[base_name]}"

    def reset(self):
        self.name_to_suffix = dict()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.reset()
.venv/lib/python3.11/site-packages/ray/dag/vis_utils.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.dag import DAGNode
2
+
3
+ import os
4
+ import tempfile
5
+
6
+ from ray.dag.utils import _DAGNodeNameGenerator
7
+ from ray.util.annotations import DeveloperAPI
8
+
9
+
10
@DeveloperAPI
def plot(dag: DAGNode, to_file=None):
    """Render ``dag`` as an image.

    Args:
        dag: The DAG to plot.
        to_file: Output path; the extension selects the format (defaults to
            png). If None, a temporary .png file is used.

    Returns:
        An ``IPython.display.Image`` when running inside Jupyter (IPython
        importable), otherwise None.
    """
    # Track whether we created a temp file ourselves instead of relying on
    # a NameError at cleanup time (the original used exception-driven
    # control flow to detect this).
    tmp_file = None
    if to_file is None:
        # NamedTemporaryFile deletes the file on close, so it is only
        # closed below when we are not returning an inline image.
        tmp_file = tempfile.NamedTemporaryFile(suffix=".png")
        to_file = tmp_file.name
        extension = "png"
    else:
        _, extension = os.path.splitext(to_file)
        # Strip the leading dot; fall back to png when no extension given.
        extension = extension[1:] if extension else "png"

    graph = _dag_to_dot(dag)
    graph.write(to_file, format=extension)

    # Render the image directly if running inside a Jupyter notebook.
    try:
        from IPython import display

        return display.Image(filename=to_file)
    except ImportError:
        pass

    # Close (and thereby delete) the temp file if we created one.
    if tmp_file is not None:
        tmp_file.close()
39
+
40
+
41
def _check_pydot_and_graphviz():
    """Check if pydot and graphviz are installed.

    pydot and graphviz are required for plotting. We check this
    during runtime rather than adding them to Ray dependencies.
    """
    try:
        import pydot
    except ImportError:
        raise ImportError(
            "pydot is required to plot DAG, install it with `pip install pydot`."
        )
    try:
        # Rendering an empty graph exercises the graphviz binary itself.
        pydot.Dot.create(pydot.Dot())
    except (OSError, pydot.InvocationException):
        raise ImportError(
            "graphviz is required to plot DAG, "
            "download it from https://graphviz.gitlab.io/download/"
        )
61
+
62
+
63
def _get_nodes_and_edges(dag: DAGNode):
    """Get all unique nodes and edges in the DAG.

    Relies on ``apply_recursive`` (a memoized DFS) so each unique node is
    visited exactly once. Unique nodes will be used to generate unique
    names, while edges will be used to construct the graph.
    """
    edges = []
    nodes = []

    def _collect(node):
        nodes.append(node)
        edges.extend((child, node) for child in node._get_all_child_nodes())
        return node

    dag.apply_recursive(_collect)
    return nodes, edges
83
+
84
+
85
def _dag_to_dot(dag: DAGNode):
    """Create a Dot graph from dag.

    TODO(lchu):
    1. add more Dot configs in kwargs,
    e.g. rankdir, alignment, etc.
    2. add more contents to graph,
    e.g. args, kwargs and options of each node
    """
    # Step 0: check dependencies and init graph
    _check_pydot_and_graphviz()
    import pydot

    graph = pydot.Dot(rankdir="LR")

    # Step 1: generate a unique name for each node in the dag.
    nodes, edges = _get_nodes_and_edges(dag)
    namer = _DAGNodeNameGenerator()
    node_names = {node: namer.get_node_name(node) for node in nodes}

    # Step 2: populate the graph with all the edges.
    for src, dst in edges:
        graph.add_edge(pydot.Edge(node_names[src], node_names[dst]))
    # A single-node DAG has no edges, so add the lone vertex explicitly.
    if len(nodes) == 1 and not edges:
        graph.add_node(pydot.Node(node_names[nodes[0]]))

    return graph
.venv/lib/python3.11/site-packages/ray/experimental/channel/__init__.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.experimental.channel.cached_channel import CachedChannel
2
+ from ray.experimental.channel.common import ( # noqa: F401
3
+ AwaitableBackgroundReader,
4
+ AwaitableBackgroundWriter,
5
+ ChannelContext,
6
+ ChannelInterface,
7
+ ChannelOutputType,
8
+ CompiledDAGArgs,
9
+ ReaderInterface,
10
+ SynchronousReader,
11
+ SynchronousWriter,
12
+ WriterInterface,
13
+ )
14
+ from ray.experimental.channel.communicator import Communicator
15
+ from ray.experimental.channel.intra_process_channel import IntraProcessChannel
16
+ from ray.experimental.channel.shared_memory_channel import (
17
+ BufferedSharedMemoryChannel,
18
+ Channel,
19
+ CompositeChannel,
20
+ )
21
+ from ray.experimental.channel.torch_tensor_nccl_channel import TorchTensorNcclChannel
22
+
23
+ __all__ = [
24
+ "AwaitableBackgroundReader",
25
+ "AwaitableBackgroundWriter",
26
+ "CachedChannel",
27
+ "Channel",
28
+ "Communicator",
29
+ "ReaderInterface",
30
+ "SynchronousReader",
31
+ "SynchronousWriter",
32
+ "WriterInterface",
33
+ "ChannelContext",
34
+ "TorchTensorNcclChannel",
35
+ "IntraProcessChannel",
36
+ "CompositeChannel",
37
+ "BufferedSharedMemoryChannel",
38
+ "CompiledDAGArgs",
39
+ ]