diff --git a/.gitattributes b/.gitattributes index 19e8357d01e7258ab53b5ddbd60ae495f1031d94..9be0503eb0fa566d638dc0237121a08aea38718f 100644 --- a/.gitattributes +++ b/.gitattributes @@ -158,3 +158,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_ .venv/lib/python3.11/site-packages/ray/serve/_private/__pycache__/deployment_state.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/xgrammar/xgrammar_bindings.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/ray/_raylet.so filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/ray/core/libjemalloc.so filter=lfs diff=lfs merge=lfs -text diff --git a/.venv/lib/python3.11/site-packages/ray/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccee4cd06b9cfd51c1db48a75189885bd235a8b3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/__pycache__/_version.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/__pycache__/_version.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2743c942152a362cad7a1bc54c9dbc5c425e98f8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/__pycache__/_version.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/__pycache__/actor.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/__pycache__/actor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c62d6315241ae452c1c0b4d388885f1788ef24cd Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/__pycache__/actor.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/__pycache__/client_builder.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/__pycache__/client_builder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..365bcfc30dc6eb67675279c484d61be67f48fa40 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/__pycache__/client_builder.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/__pycache__/cluster_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/__pycache__/cluster_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..458458070938d26575b97ea5d949cc26cb0a953e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/__pycache__/cluster_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/__pycache__/cross_language.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/__pycache__/cross_language.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72f0ec9db69fbf2236bd579cb03169ac477f86f2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/__pycache__/cross_language.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/__pycache__/exceptions.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/__pycache__/exceptions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf710adb3186b480884a8b9ad522fddec875b6c8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/__pycache__/exceptions.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/ray/__pycache__/job_config.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/__pycache__/job_config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..687f32c2786690912cc9cd959498871a414c80c3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/__pycache__/job_config.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/__pycache__/remote_function.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/__pycache__/remote_function.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7485213368a75372b7179f44abff65e0a1c05b00 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/__pycache__/remote_function.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/__pycache__/runtime_context.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/__pycache__/runtime_context.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..175278d0143700b7cb02c5db6b3e802692ab05b4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/__pycache__/runtime_context.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/__pycache__/setup-dev.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/__pycache__/setup-dev.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fb2b8baee30bf6a4728cd726848e42f6e6ef11f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/__pycache__/setup-dev.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/__pycache__/types.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/__pycache__/types.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88cc19c93540e0a535a2d5642d63a1746b1fe3cc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/__pycache__/types.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/libjemalloc.so b/.venv/lib/python3.11/site-packages/ray/core/libjemalloc.so new file mode 100644 index 0000000000000000000000000000000000000000..2ce59e3d65938408578e721d85346d6105cc5b43 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/core/libjemalloc.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0284919db23f95e692026039838aef89b5964b5cfec4a88acb9b3a9f4a226fd5 +size 885296 diff --git a/.venv/lib/python3.11/site-packages/ray/dag/__init__.py b/.venv/lib/python3.11/site-packages/ray/dag/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bc9970899cf41e4755e4d6c424b768d2357d78f2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dag/__init__.py @@ -0,0 +1,46 @@ +from ray.dag.dag_node import DAGNode +from ray.dag.function_node import FunctionNode +from ray.dag.class_node import ( + ClassNode, + ClassMethodNode, +) +from ray.dag.collective_node import CollectiveOutputNode +from ray.dag.input_node import ( + InputNode, + InputAttributeNode, + DAGInputData, +) +from ray.dag.output_node import MultiOutputNode +from ray.dag.dag_operation_future import DAGOperationFuture, GPUFuture +from ray.dag.constants import ( + PARENT_CLASS_NODE_KEY, + PREV_CLASS_METHOD_CALL_KEY, + BIND_INDEX_KEY, + IS_CLASS_METHOD_OUTPUT_KEY, + COLLECTIVE_OPERATION_KEY, + DAGNODE_TYPE_KEY, +) +from ray.dag.vis_utils import plot +from ray.dag.context import DAGContext + +__all__ = [ + "ClassNode", + "ClassMethodNode", + "CollectiveOutputNode", + "DAGNode", + 
"DAGOperationFuture", + "FunctionNode", + "GPUFuture", + "InputNode", + "InputAttributeNode", + "DAGInputData", + "PARENT_CLASS_NODE_KEY", + "PREV_CLASS_METHOD_CALL_KEY", + "BIND_INDEX_KEY", + "IS_CLASS_METHOD_OUTPUT_KEY", + "COLLECTIVE_OPERATION_KEY", + "DAGNODE_TYPE_KEY", + "plot", + "MultiOutputNode", + "DAGContext", +] diff --git a/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8fe8d2211a7a764b53c64c4c25d1c97da0079a3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/base.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..812788b795c1e557c6334d96232f16275164cdad Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/base.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/class_node.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/class_node.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0e2fd2fa1fb22604c4bc48eacc1dbf34c87ebb1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/class_node.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/collective_node.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/collective_node.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..206abbdd0fb7bafffe37aacf86b3cc455adce89e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/collective_node.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/conftest.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/conftest.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2995f1b4740eb9c103358609aea86e750308fe80 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/conftest.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/constants.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/constants.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05cf1e82ba92d02ebd49f4635c03d949792e17fe Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/constants.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/context.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/context.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..499bf5096df765d1562a73735e4e463bdfd98cb3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/context.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/dag_node.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/dag_node.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db0924132d42944400dc390c8cdc21210e62d967 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/dag_node.cpython-311.pyc differ diff 
--git a/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/dag_node_operation.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/dag_node_operation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c65ccbed99bc127255397aadc55c029e248e62a1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/dag_node_operation.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/format_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/format_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23c49e8a8c81dde8abc827c3127ee03bb0469e36 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/format_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/function_node.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/function_node.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57a246e3c75ec46ab267e7cdcaaab7bcd67f9a09 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/function_node.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/output_node.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/output_node.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64a1ac6b8c2495622e659d637f5fddbbdcf3e675 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/output_node.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/py_obj_scanner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/py_obj_scanner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6145377da03d5c0a872470d45cca1a1c287e8ac3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/py_obj_scanner.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7427e131f1899a4d3a5c3376956b827e9a85b28 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/vis_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/vis_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c83c7456a7226aead48fa689b7a67cd29292179c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/vis_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dag/base.py b/.venv/lib/python3.11/site-packages/ray/dag/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4153866cdeeba83e32becb86ab42ab2544646028 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dag/base.py @@ -0,0 +1,8 @@ +"""This module defines the base class for object scanning and gets rid of +reference cycles.""" +from ray.util.annotations import DeveloperAPI + + +@DeveloperAPI +class DAGNodeBase: + """Common base class for a node in a Ray task graph.""" diff --git a/.venv/lib/python3.11/site-packages/ray/dag/class_node.py 
b/.venv/lib/python3.11/site-packages/ray/dag/class_node.py new file mode 100644 index 0000000000000000000000000000000000000000..21eb1392f2467ebdb20404aec625172a34ef1ce3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dag/class_node.py @@ -0,0 +1,321 @@ +from weakref import ReferenceType + +import ray +from ray.dag.dag_node import DAGNode +from ray.dag.input_node import InputNode +from ray.dag.format_utils import get_dag_node_str +from ray.dag.constants import ( + PARENT_CLASS_NODE_KEY, + PREV_CLASS_METHOD_CALL_KEY, + BIND_INDEX_KEY, + IS_CLASS_METHOD_OUTPUT_KEY, +) +from ray.util.annotations import DeveloperAPI + +from typing import Any, Dict, List, Union, Tuple, Optional + + +@DeveloperAPI +class ClassNode(DAGNode): + """Represents an actor creation in a Ray task DAG.""" + + def __init__( + self, + cls, + cls_args, + cls_kwargs, + cls_options, + other_args_to_resolve=None, + ): + self._body = cls + self._last_call: Optional["ClassMethodNode"] = None + super().__init__( + cls_args, + cls_kwargs, + cls_options, + other_args_to_resolve=other_args_to_resolve, + ) + + if self._contains_input_node(): + raise ValueError( + "InputNode handles user dynamic input to the DAG, and " + "cannot be used as args, kwargs, or other_args_to_resolve " + "in the ClassNode constructor because it is not available at " + "class construction or binding time." + ) + + def _copy_impl( + self, + new_args: List[Any], + new_kwargs: Dict[str, Any], + new_options: Dict[str, Any], + new_other_args_to_resolve: Dict[str, Any], + ): + return ClassNode( + self._body, + new_args, + new_kwargs, + new_options, + other_args_to_resolve=new_other_args_to_resolve, + ) + + def _execute_impl(self, *args, **kwargs): + """Executor of ClassNode by ray.remote() + + Args and kwargs are to match base class signature, but not in the + implementation. All args and kwargs should be resolved and replaced + with value in bound_args and bound_kwargs via bottom-up recursion when + current node is executed. + """ + return ( + ray.remote(self._body) + .options(**self._bound_options) + .remote(*self._bound_args, **self._bound_kwargs) + ) + + def _contains_input_node(self) -> bool: + """Check if InputNode is used in children DAGNodes with current node + as the root. + """ + children_dag_nodes = self._get_all_child_nodes() + for child in children_dag_nodes: + if isinstance(child, InputNode): + return True + return False + + def __getattr__(self, method_name: str): + # User trying to call .bind() without a bind class method + if method_name == "bind" and "bind" not in dir(self._body): + raise AttributeError(f".bind() cannot be used again on {type(self)} ") + # Raise an error if the method is invalid. + getattr(self._body, method_name) + call_node = _UnboundClassMethodNode(self, method_name, {}) + return call_node + + def __str__(self) -> str: + return get_dag_node_str(self, str(self._body)) + + +class _UnboundClassMethodNode(object): + def __init__(self, actor: ClassNode, method_name: str, options: dict): + # TODO(sang): Theoretically, we should use a weakref here since this is + # a circular dependency, but weakrefs fail because they cannot + # be serialized.
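+ # Editor's sketch (illustrative, not part of the upstream file): ClassNode.__getattr__ + # above returns one of these wrappers, so for a hypothetical actor class Worker, + # w = Worker.bind() followed by w.compute.bind(x) first yields an + # _UnboundClassMethodNode for "compute", and its bind() below returns the + # ClassMethodNode wired to w.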
+ self._actor = actor + self._method_name = method_name + self._options = options + + def bind(self, *args, **kwargs): + other_args_to_resolve = { + PARENT_CLASS_NODE_KEY: self._actor, + PREV_CLASS_METHOD_CALL_KEY: self._actor._last_call, + } + + node = ClassMethodNode( + self._method_name, + args, + kwargs, + self._options, + other_args_to_resolve=other_args_to_resolve, + ) + self._actor._last_call = node + return node + + def __getattr__(self, attr: str): + if attr == "remote": + raise AttributeError( + ".remote() cannot be used on ClassMethodNodes. Use .bind() instead " + "to express a symbolic actor call." + ) + else: + return self.__getattribute__(attr) + + def options(self, **options): + self._options = options + return self + + +class _ClassMethodOutput: + """Represents a class method output in a Ray function DAG.""" + + def __init__(self, class_method_call: "ClassMethodNode", output_idx: int): + # The upstream class method call that returns multiple values. + self._class_method_call = class_method_call + # The output index of the return value from the upstream class method call. + self._output_idx = output_idx + + @property + def class_method_call(self) -> "ClassMethodNode": + return self._class_method_call + + @property + def output_idx(self) -> int: + return self._output_idx + + +@DeveloperAPI +class ClassMethodNode(DAGNode): + """Represents an actor method invocation in a Ray function DAG.""" + + def __init__( + self, + method_name: str, + method_args: Tuple[Any], + method_kwargs: Dict[str, Any], + method_options: Dict[str, Any], + other_args_to_resolve: Dict[str, Any], + ): + self._bound_args = method_args or [] + self._bound_kwargs = method_kwargs or {} + self._bound_options = method_options or {} + self._method_name: str = method_name + # Parse other_args_to_resolve and assign to variables + self._parent_class_node: Union[ + ClassNode, ReferenceType["ray._private.actor.ActorHandle"] + ] = other_args_to_resolve.get(PARENT_CLASS_NODE_KEY) + # Used to track lineage of ClassMethodCall to preserve deterministic + # submission and execution order. + self._prev_class_method_call: Optional[ + ClassMethodNode + ] = other_args_to_resolve.get(PREV_CLASS_METHOD_CALL_KEY, None) + # The index/order when bind() is called on this class method + self._bind_index: Optional[int] = other_args_to_resolve.get( + BIND_INDEX_KEY, None + ) + # Indicates whether the ClassMethodNode is a class method output. If True, + # the node is a placeholder for a return value from the ClassMethodNode + # that returns multiple values. If False, the node is a class method call. + self._is_class_method_output: bool = other_args_to_resolve.get( + IS_CLASS_METHOD_OUTPUT_KEY, False + ) + # Represents the return value from the upstream ClassMethodNode that + # returns multiple values. If the node is a class method call, this is None. + self._class_method_output: Optional[_ClassMethodOutput] = None + if self._is_class_method_output: + # Set the upstream ClassMethodNode and the output index of the return + # value from `method_args`. + self._class_method_output = _ClassMethodOutput( + method_args[0], method_args[1] + ) + + # The actor creation task dependency is encoded as the first argument, + # and the ordering dependency as the second, which ensures they are + # executed prior to this node.
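+ # Editor's note (inferred from the parsing above, not from upstream docs): a call + # that returns multiple values is modeled as one class method call node plus one + # placeholder node per return value; each placeholder has IS_CLASS_METHOD_OUTPUT_KEY + # set and method_args encoding (upstream_call_node, output_index), which is exactly + # what _ClassMethodOutput(method_args[0], method_args[1]) unpacks.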
+ super().__init__( + method_args, + method_kwargs, + method_options, + other_args_to_resolve=other_args_to_resolve, + ) + + def _copy_impl( + self, + new_args: List[Any], + new_kwargs: Dict[str, Any], + new_options: Dict[str, Any], + new_other_args_to_resolve: Dict[str, Any], + ): + return ClassMethodNode( + self._method_name, + new_args, + new_kwargs, + new_options, + other_args_to_resolve=new_other_args_to_resolve, + ) + + def _execute_impl(self, *args, **kwargs): + """Executor of ClassMethodNode by ray.remote() + + Args and kwargs are to match base class signature, but not in the + implementation. All args and kwargs should be resolved and replaced + with value in bound_args and bound_kwargs via bottom-up recursion when + current node is executed. + """ + if self.is_class_method_call: + method_body = getattr(self._parent_class_node, self._method_name) + # Execute with bound args. + return method_body.options(**self._bound_options).remote( + *self._bound_args, + **self._bound_kwargs, + ) + else: + assert self._class_method_output is not None + return self._bound_args[0][self._class_method_output.output_idx] + + def __str__(self) -> str: + return get_dag_node_str(self, f"{self._method_name}()") + + def __repr__(self) -> str: + return self.__str__() + + def get_method_name(self) -> str: + return self._method_name + + def _get_bind_index(self) -> int: + return self._bind_index + + def _get_remote_method(self, method_name): + method_body = getattr(self._parent_class_node, method_name) + return method_body + + def _get_actor_handle(self) -> Optional["ray.actor.ActorHandle"]: + if not isinstance(self._parent_class_node, ray.actor.ActorHandle): + return None + return self._parent_class_node + + @property + def num_returns(self) -> int: + """ + Return the number of return values from the class method call. If the + node is a class method output, return the number of return values from + the upstream class method call. + """ + + if self.is_class_method_call: + num_returns = self._bound_options.get("num_returns", None) + if num_returns is None: + method = self._get_remote_method(self._method_name) + num_returns = method.__getstate__()["num_returns"] + return num_returns + else: + assert self._class_method_output is not None + return self._class_method_output.class_method_call.num_returns + + @property + def is_class_method_call(self) -> bool: + """ + Return True if the node is a class method call, False if the node is a + class method output. + """ + return not self._is_class_method_output + + @property + def is_class_method_output(self) -> bool: + """ + Return True if the node is a class method output, False if the node is a + class method call. + """ + return self._is_class_method_output + + @property + def class_method_call(self) -> Optional["ClassMethodNode"]: + """ + Return the upstream class method call that returns multiple values. If + the node is a class method output, return None. + """ + + if self._class_method_output is None: + return None + return self._class_method_output.class_method_call + + @property + def output_idx(self) -> Optional[int]: + """ + Return the output index of the return value from the upstream class + method call that returns multiple values. If the node is a class method + call, return None. 
+ """ + + if self._class_method_output is None: + return None + return self._class_method_output.output_idx diff --git a/.venv/lib/python3.11/site-packages/ray/dag/collective_node.py b/.venv/lib/python3.11/site-packages/ray/dag/collective_node.py new file mode 100644 index 0000000000000000000000000000000000000000..660c5d575f215f49be81ef618571610945d8dffd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dag/collective_node.py @@ -0,0 +1,191 @@ +from typing import Any, Dict, List, Union, Tuple, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + import torch + +import ray +from ray.dag import ( + DAGNode, + ClassMethodNode, +) +from ray.dag.constants import COLLECTIVE_OPERATION_KEY +from ray.experimental.channel import ChannelContext +from ray.experimental.channel.torch_tensor_nccl_channel import _init_communicator +from ray.experimental.channel.torch_tensor_type import Communicator, TorchTensorType +from ray.experimental.util.types import _CollectiveOp, ReduceOp +from ray.util.annotations import DeveloperAPI + + +class _CollectiveOperation: + """ + Represent metadata for a NCCL collective operation. + + Args: + input_nodes: A list of input nodes to the collective operation. + op: The collective operation to perform. + transport: The transport to use for the collective operation. + + Requirements: + 1. Input nodes are unique. + 2. Actor handles are unique. + 3. Actor handles match the custom NCCL group if specified. + """ + + def __init__( + self, + input_nodes: List[DAGNode], + op: _CollectiveOp, + transport: Optional[Union[str, Communicator]] = None, + ): + if len(input_nodes) == 0: + raise ValueError("Expected input nodes for a collective operation") + if len(set(input_nodes)) != len(input_nodes): + raise ValueError("Expected unique input nodes for a collective operation") + + self._actor_handles: List["ray.actor.ActorHandle"] = [] + for input_node in input_nodes: + actor_handle = input_node._get_actor_handle() + if actor_handle is None: + raise ValueError("Expected an actor handle from the input node") + self._actor_handles.append(actor_handle) + if len(set(self._actor_handles)) != len(self._actor_handles): + invalid_input_nodes = [ + input_node + for input_node in input_nodes + if self._actor_handles.count(input_node._get_actor_handle()) > 1 + ] + raise ValueError( + "Expected unique actor handles for a collective operation, " + "but found duplicate actor handles from input nodes: " + f"{invalid_input_nodes}" + ) + + self._op = op + if not isinstance(self._op, ReduceOp): + raise NotImplementedError("Only ReduceOp is implemented") + if transport is None: + transport = TorchTensorType.NCCL + self._type_hint = TorchTensorType(transport=transport, _direct_return=True) + if isinstance(transport, Communicator): + if set(transport.get_actor_handles()) != set(self._actor_handles): + raise ValueError( + "Expected actor handles to match the custom NCCL group" + ) + + def __str__(self) -> str: + return ( + f"CollectiveGroup(" + f"_actor_handles={self._actor_handles}, " + f"_op={self._op}, " + f"_type_hint={self._type_hint})" + ) + + @property + def actor_handles(self) -> List["ray.actor.ActorHandle"]: + return self._actor_handles + + @property + def type_hint(self) -> TorchTensorType: + return self._type_hint + + def init_communicator(self, communicator_id: Optional[str] = None) -> str: + """ + Initialize the communicator if it has not been initialized yet. If + `communicator_id` is provided, it means the communicator has already + been initialized. 
+ """ + type_hint = self._type_hint + if type_hint.communicator_id is not None: + return type_hint.communicator_id + if communicator_id is None: + communicator_id = _init_communicator( + self._actor_handles, type_hint.get_custom_communicator() + ) + type_hint.set_communicator_id(communicator_id) + return communicator_id + + def get_communicator(self) -> Communicator: + if self._type_hint.communicator_id is not None: + ctx = ChannelContext.get_current() + communicator = ctx.communicators[self._type_hint.communicator_id] + elif self._type_hint.get_custom_communicator() is not None: + communicator = self._type_hint.get_custom_communicator() + else: + raise ValueError("Expected a NCCL group") + return communicator + + def execute(self, send_buf: "torch.Tensor") -> "torch.Tensor": + """ + Call the collective operation on the input tensor. An output tensor is + allocated and returned. + """ + import torch + + if not isinstance(send_buf, torch.Tensor): + raise ValueError("Expected a torch tensor") + communicator = self.get_communicator() + recv_buf = torch.empty_like(send_buf) + communicator.allreduce(send_buf, recv_buf, self._op) + return recv_buf + + +@DeveloperAPI +class CollectiveOutputNode(ClassMethodNode): + """Represent an output node from a NCCL collective operation in a Ray DAG.""" + + def __init__( + self, + method_name: str, + method_args: Tuple[ + DAGNode, + ], + method_kwargs: Dict[str, Any], + method_options: Dict[str, Any], + other_args_to_resolve: Dict[str, Any], + ): + # Parse the input node. + if not ( + isinstance(method_args, tuple) + and len(method_args) == 1 + and isinstance(method_args[0], DAGNode) + ): + raise ValueError("Expected a single input node") + self._input_node = method_args[0] + # Parse the collective operation. + self._collective_op: _CollectiveOperation = other_args_to_resolve.get( + COLLECTIVE_OPERATION_KEY, None + ) + if self._collective_op is None: + raise ValueError("Expected a collective operation") + + super().__init__( + method_name, + method_args, + method_kwargs, + method_options, + other_args_to_resolve, + ) + + def _copy_impl( + self, + new_args: List[Any], + new_kwargs: Dict[str, Any], + new_options: Dict[str, Any], + new_other_args_to_resolve: Dict[str, Any], + ): + return CollectiveOutputNode( + self._method_name, + new_args, + new_kwargs, + new_options, + other_args_to_resolve=new_other_args_to_resolve, + ) + + def _execute_impl(self, *args, **kwargs): + raise NotImplementedError( + "CollectiveOutputNode is only supported with dag.experimental_compile()" + ) + + @property + def collective_op(self) -> _CollectiveOperation: + return self._collective_op diff --git a/.venv/lib/python3.11/site-packages/ray/dag/compiled_dag_node.py b/.venv/lib/python3.11/site-packages/ray/dag/compiled_dag_node.py new file mode 100644 index 0000000000000000000000000000000000000000..122a1921af2c93591be63c67eced3a22246b6fcd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dag/compiled_dag_node.py @@ -0,0 +1,2986 @@ +import weakref +import asyncio +from collections import defaultdict +from contextlib import nullcontext +from dataclasses import dataclass, asdict +from typing import ( + TYPE_CHECKING, + Any, + Dict, + FrozenSet, + List, + Tuple, + Union, + Optional, + Set, +) +import logging +import threading +import time +import uuid +import traceback + +from ray.experimental.channel.auto_transport_type import ( + AutoTransportType, + TypeHintResolver, +) +import ray.exceptions +from ray.dag.dag_operation_future import GPUFuture, DAGOperationFuture, 
ResolvedFuture +from ray.experimental.channel.cached_channel import CachedChannel +from ray.experimental.channel.communicator import Communicator +from ray.dag.constants import ( + RAY_CGRAPH_ENABLE_NVTX_PROFILING, + RAY_CGRAPH_VISUALIZE_SCHEDULE, +) +import ray +from ray.exceptions import RayTaskError, RayChannelError, RayChannelTimeoutError +from ray.experimental.compiled_dag_ref import ( + CompiledDAGRef, + CompiledDAGFuture, + _process_return_vals, +) +from ray.experimental.channel import ( + ChannelContext, + ChannelInterface, + ChannelOutputType, + ReaderInterface, + SynchronousReader, + WriterInterface, + SynchronousWriter, + AwaitableBackgroundReader, + AwaitableBackgroundWriter, + CompiledDAGArgs, + CompositeChannel, + IntraProcessChannel, +) +from ray.util.annotations import DeveloperAPI + +from ray.experimental.channel.shared_memory_channel import ( + SharedMemoryType, +) +from ray.experimental.channel.torch_tensor_type import TorchTensorType + +from ray.experimental.channel.torch_tensor_nccl_channel import ( + _init_communicator, + _destroy_communicator, +) + +from ray.dag.dag_node_operation import ( + _DAGNodeOperation, + _DAGNodeOperationType, + _DAGOperationGraphNode, + _build_dag_node_operation_graph, + _extract_execution_schedule, + _generate_actor_to_execution_schedule, + _generate_overlapped_execution_schedule, + _visualize_execution_schedule, +) + +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy + +if TYPE_CHECKING: + import cupy as cp + +logger = logging.getLogger(__name__) + +# Keep track of every compiled dag created during the lifetime of +# this process. They are tracked as weakrefs, so when a compiled dag +# is GC'ed, it is automatically removed from here. This is used to tear down +# compiled dags at interpreter shutdown time. +_compiled_dags = weakref.WeakValueDictionary() + + +# Relying on __del__ doesn't work well upon shutdown because +# the destructor order is not guaranteed. We call this function +# upon `ray.worker.shutdown`, which is registered to the atexit handler, +# so that teardown is properly called before objects are destructed. +def _shutdown_all_compiled_dags(): + global _compiled_dags + for _, compiled_dag in _compiled_dags.items(): + # Kill DAG actors to avoid hanging during shutdown if the actor tasks + # cannot be cancelled. + compiled_dag.teardown(kill_actors=True) + _compiled_dags = weakref.WeakValueDictionary() + + +def _check_unused_dag_input_attributes( + output_node: "ray.dag.MultiOutputNode", input_attributes: Set[str] +) -> None: + """ + Helper function to check that all input attributes are used in the DAG. + For example, if the user creates an input attribute by calling + InputNode()["x"], we ensure that there is a path from the + InputAttributeNode corresponding to "x" to the DAG's output. If an + input attribute is not used, an error is raised. + + Args: + output_node: The starting node for the traversal. + input_attributes: A set of attributes accessed by the InputNode.
+ """ + from ray.dag import InputAttributeNode + + used_attributes = set() + visited_nodes = set() + stack: List["ray.dag.DAGNode"] = [output_node] + + while stack: + current_node = stack.pop() + if current_node in visited_nodes: + continue + visited_nodes.add(current_node) + + if isinstance(current_node, InputAttributeNode): + used_attributes.add(current_node.key) + + stack.extend(current_node._upstream_nodes) + + unused_attributes = input_attributes - used_attributes + if unused_attributes: + unused_attributes_str = ", ".join(str(key) for key in unused_attributes) + input_attributes_str = ", ".join(str(key) for key in input_attributes) + unused_phrase = "is unused" if len(unused_attributes) == 1 else "are unused" + + raise ValueError( + "Compiled Graph expects input to be accessed " + f"using all of attributes {input_attributes_str}, " + f"but {unused_attributes_str} {unused_phrase}. " + "Ensure all input attributes are used and contribute " + "to the computation of the Compiled Graph output." + ) + + +@DeveloperAPI +def do_allocate_channel( + self, + reader_and_node_list: List[Tuple["ray.actor.ActorHandle", str]], + typ: ChannelOutputType, + driver_actor_id: Optional[str] = None, +) -> ChannelInterface: + """Generic actor method to allocate an output channel. + + Args: + reader_and_node_list: A list of tuples, where each tuple contains a reader + actor handle and the node ID where the actor is located. + typ: The output type hint for the channel. + driver_actor_id: If this channel is read by a driver and that driver is an + actual actor, this will be the actor ID of that driver actor. + + Returns: + The allocated channel. + """ + # None means it is called from a driver. + writer: Optional["ray.actor.ActorHandle"] = None + try: + writer = ray.get_runtime_context().current_actor + except RuntimeError: + # This is the driver so there is no current actor handle. + pass + + output_channel = typ.create_channel( + writer, + reader_and_node_list, + driver_actor_id, + ) + return output_channel + + +@DeveloperAPI +def do_exec_tasks( + self, + tasks: List["ExecutableTask"], + schedule: List[_DAGNodeOperation], + overlap_gpu_communication: bool = False, +) -> None: + """A generic actor method to begin executing the operations belonging to an + actor. This runs an infinite loop to execute each _DAGNodeOperation in the + order specified by the schedule. It exits only if the actor dies or an + exception is thrown. + + Args: + tasks: the executable tasks corresponding to the actor methods. + schedule: A list of _DAGNodeOperation that should be executed in order. + overlap_gpu_communication: Whether to overlap GPU communication with + computation during DAG execution to improve performance. + """ + try: + for task in tasks: + task.prepare(overlap_gpu_communication=overlap_gpu_communication) + + if RAY_CGRAPH_ENABLE_NVTX_PROFILING: + try: + import nvtx + except ImportError: + raise ImportError( + "Please install nvtx to enable nsight profiling. " + "You can install it by running `pip install nvtx`." 
+ ) + nvtx_profile = nvtx.Profile() + nvtx_profile.enable() + + done = False + while True: + if done: + break + for operation in schedule: + done = tasks[operation.exec_task_idx].exec_operation( + self, operation.type, overlap_gpu_communication + ) + if done: + break + + if RAY_CGRAPH_ENABLE_NVTX_PROFILING: + nvtx_profile.disable() + except Exception: + logging.exception("Compiled DAG task exited with exception") + raise + + +@DeveloperAPI +def do_profile_tasks( + self, + tasks: List["ExecutableTask"], + schedule: List[_DAGNodeOperation], + overlap_gpu_communication: bool = False, +) -> None: + """A generic actor method similar to `do_exec_tasks`, but with profiling enabled. + + Args: + tasks: the executable tasks corresponding to the actor methods. + schedule: A list of _DAGNodeOperation that should be executed in order. + overlap_gpu_communication: Whether to overlap GPU communication with + computation during DAG execution to improve performance. + """ + try: + for task in tasks: + task.prepare(overlap_gpu_communication=overlap_gpu_communication) + + if not hasattr(self, "__ray_cgraph_events"): + self.__ray_cgraph_events = [] + + done = False + while True: + if done: + break + for operation in schedule: + start_t = time.perf_counter() + task = tasks[operation.exec_task_idx] + done = task.exec_operation( + self, operation.type, overlap_gpu_communication + ) + end_t = time.perf_counter() + + self.__ray_cgraph_events.append( + _ExecutableTaskRecord( + actor_classname=self.__class__.__name__, + actor_name=ray.get_runtime_context().get_actor_name(), + actor_id=ray.get_runtime_context().get_actor_id(), + method_name=task.method_name, + bind_index=task.bind_index, + operation=operation.type.value, + start_t=start_t, + end_t=end_t, + ) + ) + + if done: + break + except Exception: + logging.exception("Compiled DAG task exited with exception") + raise + + +@DeveloperAPI +def do_cancel_executable_tasks(self, tasks: List["ExecutableTask"]) -> None: + for task in tasks: + task.cancel() + + +def _wrap_exception(exc): + backtrace = ray._private.utils.format_error_message( + "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)), + task_exception=True, + ) + wrapped = RayTaskError( + function_name="do_exec_tasks", + traceback_str=backtrace, + cause=exc, + ) + return wrapped + + +def _get_nccl_group_id(type_hint: ChannelOutputType) -> Optional[str]: + """ + Get the NCCL group ID from the type hint. If the type hint does not + require NCCL, return None. + + Args: + type_hint: The type hint of the channel. + + Returns: + The NCCL group ID if the type hint requires NCCL, otherwise None. + """ + if type_hint.requires_nccl(): + assert isinstance(type_hint, TorchTensorType) + return type_hint.communicator_id + return None + + +def _device_context_manager(): + """ + Return a context manager for executing communication operations + (i.e., READ and WRITE). For NCCL operations, the context manager + uses the proper cuda device from channel context, otherwise, + nullcontext will be returned. + """ + if not ChannelContext.get_current().torch_available: + return nullcontext() + + import torch + + device = ChannelContext.get_current().torch_device + + if device.type == "cuda" and torch.cuda.is_available(): + # In the case of mocked NCCL, we may get a device with type "cuda" + # but CUDA is not available. We return nullcontext() in that case, + # otherwise torch raises a runtime error if the cuda device context + # manager is used. + # TODO(rui): consider better mocking NCCL to support device context. 
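+ # Editor's note (usage inferred from exec_operation later in this file): callers + # nest this manager around communication ops, e.g. with _device_context_manager(): + # with self._recv_stream: ..., so NCCL reads and writes run with the channel's CUDA + # device active, while CPU-only channels fall through to the nullcontext branches.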
+ return torch.cuda.device(device) + return nullcontext() + + +@DeveloperAPI +class CompiledTask: + """Wraps the normal Ray DAGNode with some metadata.""" + + def __init__(self, idx: int, dag_node: "ray.dag.DAGNode"): + """ + Args: + idx: A unique index into the original DAG. + dag_node: The original DAG node created by the user. + """ + self.idx = idx + self.dag_node = dag_node + + # Dict from task index to actor handle for immediate downstream tasks. + self.downstream_task_idxs: Dict[int, "ray.actor.ActorHandle"] = {} + # Case 1: The task represents a ClassMethodNode. + # + # Multiple return values are written to separate `output_channels`. + # `output_idxs` represents the tuple index of the output value for + # multiple returns in a tuple. If an output index is None, it means + # the complete return value is written to the output channel. + # Otherwise, the return value is a tuple and the index is used + # to extract the value to be written to the output channel. + # + # Case 2: The task represents an InputNode. + # + # `output_idxs` can be an integer or a string to retrieve the + # corresponding value from `args` or `kwargs` in the DAG's input. + self.output_channels: List[ChannelInterface] = [] + self.output_idxs: List[Optional[Union[int, str]]] = [] + # The DAGNodes that are arguments to this task. + # This is used for lazy resolution of the arguments' type hints. + self.arg_nodes: List["ray.dag.DAGNode"] = [] + # idxs of possible ClassMethodOutputNodes if they exist, used for visualization + self.output_node_idxs: List[int] = [] + + @property + def args(self) -> Tuple[Any]: + return self.dag_node.get_args() + + @property + def kwargs(self) -> Dict[str, Any]: + return self.dag_node.get_kwargs() + + @property + def num_readers(self) -> int: + return len(self.downstream_task_idxs) + + @property + def arg_type_hints(self) -> List["ChannelOutputType"]: + return [arg_node.type_hint for arg_node in self.arg_nodes] + + def __str__(self) -> str: + return f""" + Node: {self.dag_node} + Arguments: {self.args} + Output: {self.output_channels} + """ + + +class _ExecutableTaskInput: + """Represents an input to an ExecutableTask. + + Args: + input_variant: either an unresolved input (when type is ChannelInterface) + , or a resolved input value (when type is Any) + channel_idx: if input_variant is an unresolved input, this is the index + into the input channels list. + """ + + def __init__( + self, + input_variant: Union[ChannelInterface, Any], + channel_idx: Optional[int], + ): + self.input_variant = input_variant + self.channel_idx = channel_idx + + def resolve(self, channel_results: Any) -> Any: + """ + Resolve the input value from the channel results. + + Args: + channel_results: The results from reading the input channels. + """ + + if isinstance(self.input_variant, ChannelInterface): + value = channel_results[self.channel_idx] + else: + value = self.input_variant + return value + + +@DeveloperAPI +class ExecutableTask: + """A task that can be executed in a compiled DAG, and it + corresponds to an actor method. + """ + + def __init__( + self, + task: "CompiledTask", + resolved_args: List[Any], + resolved_kwargs: Dict[str, Any], + ): + """ + Args: + task: The CompiledTask that this ExecutableTask corresponds to. + resolved_args: The arguments to the method. Arguments that are + not Channels will get passed through to the actor method. + If the argument is a channel, it will be replaced by the + value read from the channel before the method executes. 
+ resolved_kwargs: The keyword arguments to the method. Currently, we + do not support binding kwargs to other DAG nodes, so the values + of the dictionary cannot be Channels. + """ + from ray.dag import CollectiveOutputNode + + self.method_name = task.dag_node.get_method_name() + self.bind_index = task.dag_node._get_bind_index() + self.output_channels = task.output_channels + self.output_idxs = task.output_idxs + self.input_type_hints: List[ChannelOutputType] = task.arg_type_hints + self.output_type_hint: ChannelOutputType = task.dag_node.type_hint + + # The NCCL collective operation. + self.collective_op: Optional["ray.dag.CollectiveOperation"] = None + if isinstance(task.dag_node, CollectiveOutputNode): + self.collective_op = task.dag_node.collective_op + + self.input_channels: List[ChannelInterface] = [] + self.task_inputs: List[_ExecutableTaskInput] = [] + self.resolved_kwargs: Dict[str, Any] = resolved_kwargs + # A unique index which can be used to index into `idx_to_task` to get + # the corresponding task. + self.task_idx = task.idx + + # Reverse map for input_channels: maps an input channel to + # its index in input_channels. + input_channel_to_idx: dict[ChannelInterface, int] = {} + + for arg in resolved_args: + if isinstance(arg, ChannelInterface): + channel = arg + + if channel in input_channel_to_idx: + # The same channel was added before, so reuse the index. + channel_idx = input_channel_to_idx[channel] + else: + # Add a new channel to the list of input channels. + self.input_channels.append(channel) + channel_idx = len(self.input_channels) - 1 + input_channel_to_idx[channel] = channel_idx + + task_input = _ExecutableTaskInput(arg, channel_idx) + else: + task_input = _ExecutableTaskInput(arg, None) + self.task_inputs.append(task_input) + + # Currently DAGs do not support binding kwargs to other DAG nodes. + for val in self.resolved_kwargs.values(): + assert not isinstance(val, ChannelInterface) + + # Input reader to read input data from upstream DAG nodes. + self.input_reader: ReaderInterface = SynchronousReader(self.input_channels) + # Output writer to write output data to downstream DAG nodes. + self.output_writer: WriterInterface = SynchronousWriter( + self.output_channels, self.output_idxs + ) + # The intermediate future for a READ or COMPUTE operation, + # and `wait()` must be called to get the actual result of the operation. + # The result of a READ operation will be used by a COMPUTE operation, + # and the result of a COMPUTE operation will be used by a WRITE operation. + self._intermediate_future: Optional[DAGOperationFuture] = None + + def cancel(self): + """ + Close all the input channels and the output channel. The exact behavior + depends on the type of channel. Typically, it will release the resources + used by the channels. + """ + self.input_reader.close() + self.output_writer.close() + + def prepare(self, overlap_gpu_communication: bool = False): + """ + Prepare the task for execution. The `exec_operation` function can only + be called after `prepare` has been called.
+ + Args: + overlap_gpu_communication: Whether to overlap GPU communication with + computation during DAG execution to improve performance + """ + for typ_hint in self.input_type_hints: + typ_hint.register_custom_serializer() + self.output_type_hint.register_custom_serializer() + self.input_reader.start() + self.output_writer.start() + + self._send_stream: Union["cp.cuda.Stream", nullcontext] = nullcontext() + self._recv_stream: Union["cp.cuda.Stream", nullcontext] = nullcontext() + if not overlap_gpu_communication: + return + + # Set up send_stream and recv_stream when overlap_gpu_communication + # is configured + if self.output_type_hint.requires_nccl(): + nccl_group_id = _get_nccl_group_id(self.output_type_hint) + nccl_group = ChannelContext.get_current().communicators.get(nccl_group_id) + assert nccl_group is not None + self._send_stream = nccl_group.send_stream + if self.input_type_hints: + for type_hint in self.input_type_hints: + if type_hint.requires_nccl(): + nccl_group_id = _get_nccl_group_id(type_hint) + nccl_group = ChannelContext.get_current().communicators.get( + nccl_group_id + ) + assert nccl_group is not None + if not isinstance(self._recv_stream, nullcontext): + assert self._recv_stream == nccl_group.recv_stream, ( + "Currently all torch tensor input channels of a " + "Compiled Graph task should use the same recv cuda stream." + ) + self._recv_stream = nccl_group.recv_stream + + def wrap_and_set_intermediate_future( + self, val: Any, wrap_in_gpu_future: bool + ) -> None: + """ + Wrap the value in a `DAGOperationFuture` and store to the intermediate future. + The value corresponds to result of a READ or COMPUTE operation. + + If wrap_in_gpu_future is True, the value will be wrapped in a GPUFuture, + Otherwise, the future will be a ResolvedFuture. + + Args: + val: The value to wrap in a future. + wrap_in_gpu_future: Whether to wrap the value in a GPUFuture. + """ + assert self._intermediate_future is None + + if wrap_in_gpu_future: + future = GPUFuture(val) + else: + future = ResolvedFuture(val) + self._intermediate_future = future + + def reset_and_wait_intermediate_future(self) -> Any: + """ + Reset the intermediate future and wait for the result. + + The wait does not block the CPU because: + - If the future is a ResolvedFuture, the result is immediately returned. + - If the future is a GPUFuture, the result is only waited by the current + CUDA stream, and the CPU is not blocked. + + Returns: + The result of a READ or COMPUTE operation from the intermediate future. + """ + future = self._intermediate_future + self._intermediate_future = None + return future.wait() + + def _read(self, overlap_gpu_communication: bool) -> bool: + """ + Read input data from upstream DAG nodes and cache the intermediate result. + + Args: + overlap_gpu_communication: Whether to overlap GPU communication with + computation during DAG execution to improve performance. + + Returns: + True if system error occurs and exit the loop; otherwise, False. + """ + assert self._intermediate_future is None + exit = False + try: + input_data = self.input_reader.read() + # When overlap_gpu_communication is enabled, wrap the result in + # a GPUFuture so that this read operation (communication) can + # be overlapped with computation. + self.wrap_and_set_intermediate_future( + input_data, wrap_in_gpu_future=overlap_gpu_communication + ) + except RayChannelError: + # Channel closed. Exit the loop. 
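+ # Editor's note (inference from this file, not upstream docs): RayChannelError is + # the cooperative shutdown path; teardown closes the channels, the blocked read + # raises, and the True returned below makes the do_exec_tasks loop exit.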
+ exit = True + return exit + + def _compute( + self, + overlap_gpu_communication: bool, + class_handle, + ) -> bool: + """ + Retrieve the intermediate result from the READ operation and perform the + computation. Then, cache the new intermediate result. The caller must ensure + that the last operation executed is READ so that the function retrieves the + correct intermediate result. + + Args: + overlap_gpu_communication: Whether to overlap GPU communication with + computation during DAG execution to improve performance. + class_handle: An instance of the class to which the actor belongs. For + example, the type of `class_handle` is `Worker` if the + actor belongs to the `class Worker` class. + Returns: + True if system error occurs and exit the loop; otherwise, False. + """ + input_data = self.reset_and_wait_intermediate_future() + try: + _process_return_vals(input_data, return_single_output=False) + except Exception as exc: + # Previous task raised an application-level exception. + # Propagate it and skip the actual task. We don't need to wrap the + # exception in a RayTaskError here because it has already been wrapped + # by the previous task. + self.wrap_and_set_intermediate_future( + exc, wrap_in_gpu_future=overlap_gpu_communication + ) + return False + + resolved_inputs = [] + for task_input in self.task_inputs: + resolved_inputs.append(task_input.resolve(input_data)) + + if self.collective_op is not None: + # Run a NCCL collective operation. + method = self.collective_op.execute + else: + # Run an actor method. + method = getattr(class_handle, self.method_name) + try: + output_val = method(*resolved_inputs, **self.resolved_kwargs) + except Exception as exc: + output_val = _wrap_exception(exc) + + # When overlap_gpu_communication is enabled, wrap the result in a GPUFuture + # so that this compute operation can be overlapped with communication. + self.wrap_and_set_intermediate_future( + output_val, wrap_in_gpu_future=overlap_gpu_communication + ) + return False + + def _write(self) -> bool: + """ + Retrieve the intermediate result from the COMPUTE operation and write to its + downstream DAG nodes. The caller must ensure that the last operation executed + is COMPUTE so that the function retrieves the correct intermediate result. + + Returns: + True if system error occurs and exit the loop; otherwise, False. + """ + output_val = self.reset_and_wait_intermediate_future() + exit = False + try: + self.output_writer.write(output_val) + except RayChannelError: + # Channel closed. Exit the loop. + exit = True + return exit + + def exec_operation( + self, + class_handle, + op_type: _DAGNodeOperationType, + overlap_gpu_communication: bool = False, + ) -> bool: + """ + An ExecutableTask corresponds to a DAGNode. It consists of three + operations: READ, COMPUTE, and WRITE, which should be executed in + order to ensure that each operation can read the correct intermediate + result. + Args: + class_handle: The handle of the class to which the actor belongs. + op_type: The type of the operation. Possible types are READ, + COMPUTE, and WRITE. + overlap_gpu_communication: Whether to overlap GPU communication with + computation during DAG execution to improve performance. + Returns: + True if the next operation should not be executed; otherwise, False.
+ """ + if op_type == _DAGNodeOperationType.READ: + with _device_context_manager(): + with self._recv_stream: + return self._read(overlap_gpu_communication) + elif op_type == _DAGNodeOperationType.COMPUTE: + return self._compute(overlap_gpu_communication, class_handle) + elif op_type == _DAGNodeOperationType.WRITE: + with _device_context_manager(): + with self._send_stream: + return self._write() + + +@dataclass +class _ExecutableTaskRecord: + actor_classname: str + actor_name: str + actor_id: str + method_name: str + bind_index: int + operation: str + start_t: float + end_t: float + + def to_dict(self): + return asdict(self) + + +@DeveloperAPI +class CompiledDAG: + """Experimental class for accelerated execution. + + This class should not be called directly. Instead, create + a ray.dag and call experimental_compile(). + + See REP https://github.com/ray-project/enhancements/pull/48 for more + information. + """ + + @ray.remote(num_cpus=0) + class DAGDriverProxyActor: + """ + To support the driver as a reader, the output writer needs to be able to invoke + remote functions on the driver. This is necessary so that the output writer can + create a reader ref on the driver node, and later potentially create a larger + reader ref on the driver node if the channel backing store needs to be resized. + However, remote functions cannot be invoked on the driver. + + A Compiled Graph creates an actor from this class when the DAG is initialized. + The actor is on the same node as the driver. This class has an empty + implementation, though it serves as a way for the output writer to invoke remote + functions on the driver node. + """ + + pass + + def __init__( + self, + submit_timeout: Optional[float] = None, + buffer_size_bytes: Optional[int] = None, + enable_asyncio: bool = False, + max_inflight_executions: Optional[int] = None, + overlap_gpu_communication: Optional[bool] = None, + ): + """ + Args: + submit_timeout: The maximum time in seconds to wait for execute() calls. + None means using default timeout (DAGContext.submit_timeout), + 0 means immediate timeout (immediate success or timeout without + blocking), -1 means infinite timeout (block indefinitely). + buffer_size_bytes: The initial buffer size in bytes for messages + that can be passed between tasks in the DAG. The buffers will + be automatically resized if larger messages are written to the + channel. + enable_asyncio: Whether to enable asyncio. If enabled, caller must + be running in an event loop and must use `execute_async` to + invoke the DAG. Otherwise, the caller should use `execute` to + invoke the DAG. + max_inflight_executions: The maximum number of in-flight executions that + can be submitted via `execute` or `execute_async` before consuming + the output using `ray.get()`. If the caller submits more executions, + `RayCgraphCapacityExceeded` is raised. + overlap_gpu_communication: (experimental) Whether to overlap GPU + communication with computation during DAG execution. If True, the + communication and computation can be overlapped, which can improve + the performance of the DAG execution. If None, the default value + will be used. + + Returns: + Channel: A wrapper around ray.ObjectRef. 
+ """ + from ray.dag import DAGContext + + ctx = DAGContext.get_current() + + self._enable_asyncio: bool = enable_asyncio + self._fut_queue = asyncio.Queue() + self._max_inflight_executions = max_inflight_executions + if self._max_inflight_executions is None: + self._max_inflight_executions = ctx.max_inflight_executions + self._dag_id = uuid.uuid4().hex + self._submit_timeout: Optional[float] = submit_timeout + if self._submit_timeout is None: + self._submit_timeout = ctx.submit_timeout + self._get_timeout: Optional[float] = ctx.get_timeout + self._buffer_size_bytes: Optional[int] = buffer_size_bytes + if self._buffer_size_bytes is None: + self._buffer_size_bytes = ctx.buffer_size_bytes + self._overlap_gpu_communication: Optional[bool] = overlap_gpu_communication + if self._overlap_gpu_communication is None: + self._overlap_gpu_communication = ctx.overlap_gpu_communication + + self._default_type_hint: ChannelOutputType = SharedMemoryType( + buffer_size_bytes=self._buffer_size_bytes, + # We conservatively set num_shm_buffers to _max_inflight_executions. + # It means that the DAG can be underutilized, but it guarantees there's + # no false positive timeouts. + num_shm_buffers=1, + ) + if not isinstance(self._buffer_size_bytes, int) or self._buffer_size_bytes <= 0: + raise ValueError( + "`buffer_size_bytes` must be a positive integer, found " + f"{self._buffer_size_bytes}" + ) + + # Used to ensure that the future returned to the + # caller corresponds to the correct DAG output. I.e. + # order of futures added to fut_queue should match the + # order of inputs written to the DAG. + self._dag_submission_lock = asyncio.Lock() + + # idx -> CompiledTask. + self.idx_to_task: Dict[int, "CompiledTask"] = {} + # DAGNode -> idx. + self.dag_node_to_idx: Dict["ray.dag.DAGNode", int] = {} + # idx counter. + self.counter: int = 0 + + # Attributes that are set during preprocessing. + # Preprocessing identifies the input node and output node. + self.input_task_idx: Optional[int] = None + self.output_task_idx: Optional[int] = None + # List of task indices that are input attribute nodes. + self.input_attr_task_idxs: List[int] = [] + # Denotes whether execute/execute_async returns a list of refs/futures. + self._returns_list: bool = False + # Number of expected positional args and kwargs that may be passed to + # dag.execute. + self._input_num_positional_args: Optional[int] = None + self._input_kwargs: Tuple[str, ...] = None + + # Cached attributes that are set during compilation. + self.dag_input_channels: Optional[List[ChannelInterface]] = None + self.dag_output_channels: Optional[List[ChannelInterface]] = None + self._dag_submitter: Optional[WriterInterface] = None + self._dag_output_fetcher: Optional[ReaderInterface] = None + + # ObjectRef for each worker's task. The task is an infinite loop that + # repeatedly executes the method specified in the DAG. + self.worker_task_refs: Dict["ray.actor.ActorHandle", "ray.ObjectRef"] = {} + # Set of actors present in the DAG. + self.actor_refs = set() + self.actor_to_tasks: Dict[ + "ray.actor.ActorHandle", List["CompiledTask"] + ] = defaultdict(list) + # Mapping from actor handle to its GPU IDs. + # This is used for type hint resolution for with_tensor_transport("auto"). + self.actor_to_gpu_ids: Dict["ray.actor.ActorHandle", List[str]] = {} + self.actor_to_executable_tasks: Dict[ + "ray.actor.ActorHandle", List["ExecutableTask"] + ] = {} + # Mapping from the actor handle to the execution schedule which is a list + # of operations to be executed. 
+ self.actor_to_execution_schedule: Dict[ + "ray.actor.ActorHandle", List[_DAGNodeOperation] + ] = defaultdict(list) + # Mapping from the actor handle to the node ID that the actor is on. + # A None actor handle means the actor is the driver. + self.actor_to_node_id: Dict[Optional["ray.actor.ActorHandle"], str] = {} + + # This is set to true when type hint of `transport="nccl"` is used. + self._use_default_nccl_group = False + # This is set to the specified custom communicator + # if there exists a type hint of `transport=custom_communicator`. + self._custom_communicator_p2p: Optional[Communicator] = None + # The NCCL group ID for P2P send/recv operations. + self._communicator_id_p2p: Optional[str] = None + # All the NCCL group IDs for P2P send/recv and collective operations. + self._communicator_ids: Set[str] = set() + # The index of the current execution. It is incremented each time + # the DAG is executed. + self._execution_index: int = -1 + # The maximum index of finished executions. + # All results with higher indexes have not been generated yet. + self._max_finished_execution_index: int = -1 + # execution_index -> {channel_index -> result} + self._result_buffer: Dict[int, Dict[int, Any]] = defaultdict(dict) + # channel to possible inner channel + self._channel_dict: Dict[ChannelInterface, ChannelInterface] = {} + + def _create_proxy_actor() -> "ray.actor.ActorHandle": + # Creates the driver actor on the same node as the driver. + # + # To support the driver as a reader, the output writer needs to be able to + # invoke remote functions on the driver (e.g., to create the reader ref, to + # create a reader ref for a larger object when the channel backing store is + # resized, etc.). The driver actor serves as a way for the output writer + # to invoke remote functions on the driver node. + return CompiledDAG.DAGDriverProxyActor.options( + scheduling_strategy=NodeAffinitySchedulingStrategy( + ray.get_runtime_context().get_node_id(), soft=False + ) + ).remote() + + self._proxy_actor = _create_proxy_actor() + # Set to True when `teardown` API is called. + self._is_teardown = False + # execution indices -> set of channel indices of destructed CompiledDAGRefs + # When a CompiledDagRef is destructed and its result has not been cached and + # ray.get has not been called on it, we will add it to this dict, so that + # we can lazily release the native buffers + self._destructed_ref_idxs: Dict[int, Set[Optional[int]]] = defaultdict(set) + + @property + def communicator_id_p2p(self) -> Optional[str]: + return self._communicator_id_p2p + + @property + def is_teardown(self) -> bool: + return self._is_teardown + + @property + def communicator_ids(self) -> Set[str]: + return self._communicator_ids + + def get_id(self) -> str: + """ + Get the unique ID of the compiled DAG. + """ + return self._dag_id + + def __str__(self) -> str: + return f"CompiledDAG({self._dag_id})" + + def _add_node(self, node: "ray.dag.DAGNode") -> None: + idx = self.counter + self.idx_to_task[idx] = CompiledTask(idx, node) + self.dag_node_to_idx[node] = idx + self.counter += 1 + + def _preprocess(self) -> None: + """Before compiling, preprocess the DAG to build an index from task to + upstream and downstream tasks, and to set the input and output node(s) + of the DAG. + + This function is idempotent. 
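+
+        Example:
+            For a DAG such as (hypothetical actor and method names):
+
+            ```python
+            with InputNode() as inp:
+                dag = MultiOutputNode([worker.fwd.bind(inp)])
+            ```
+
+            preprocessing sets `input_task_idx` to the InputNode's index and
+            `output_task_idx` to the MultiOutputNode's index.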
+ """ + from ray.dag import ( + DAGNode, + ClassMethodNode, + CollectiveOutputNode, + FunctionNode, + InputAttributeNode, + InputNode, + MultiOutputNode, + ) + from ray.dag.collective_node import _CollectiveOperation + + self.input_task_idx, self.output_task_idx = None, None + + nccl_actors_p2p: Set["ray.actor.ActorHandle"] = set() + collective_ops: Set[_CollectiveOperation] = set() + + input_attributes: Set[str] = set() + # Find the input node and input attribute nodes in the DAG. + for idx, task in self.idx_to_task.items(): + if isinstance(task.dag_node, InputNode): + assert self.input_task_idx is None, "More than one InputNode found" + self.input_task_idx = idx + # handle_unused_attributes: + # Save input attributes in a set. + input_node = task.dag_node + input_attributes.update(input_node.input_attribute_nodes.keys()) + elif isinstance(task.dag_node, InputAttributeNode): + self.input_attr_task_idxs.append(idx) + + # Find the (multi-)output node to the DAG. + for idx, task in self.idx_to_task.items(): + if idx == self.input_task_idx or isinstance( + task.dag_node, InputAttributeNode + ): + continue + if ( + len(task.downstream_task_idxs) == 0 + and task.dag_node.is_cgraph_output_node + ): + assert self.output_task_idx is None, "More than one output node found" + self.output_task_idx = idx + + assert self.output_task_idx is not None + output_node = self.idx_to_task[self.output_task_idx].dag_node + # Add an MultiOutputNode to the end of the DAG if it's not already there. + if not isinstance(output_node, MultiOutputNode): + output_node = MultiOutputNode([output_node]) + self._add_node(output_node) + self.output_task_idx = self.dag_node_to_idx[output_node] + else: + self._returns_list = True + + # TODO: Support no-input DAGs (use an empty object to signal). + if self.input_task_idx is None: + raise NotImplementedError( + "Compiled DAGs currently require exactly one InputNode" + ) + + # Whether the DAG binds directly to the InputNode(), versus binding to + # a positional arg or kwarg of the input. For example, a.foo.bind(inp) + # instead of a.foo.bind(inp[0]) or a.foo.bind(inp.key). + direct_input: Optional[bool] = None + # Collect the set of InputNode keys bound to DAG node args. + input_positional_args: Set[int] = set() + input_kwargs: Set[str] = set() + # Set of tasks with annotation of with_tensor_transport("auto"). + # These only correspond to ClassMethodNodes, but not InputNodes + # or InputAttributeNodes. + auto_transport_tasks: Set["CompiledTask"] = set() + + # For each task node, set its upstream and downstream task nodes. + # Also collect the set of tasks that produce torch.tensors. + for task_idx, task in self.idx_to_task.items(): + dag_node = task.dag_node + if not ( + isinstance(dag_node, InputNode) + or isinstance(dag_node, InputAttributeNode) + or isinstance(dag_node, MultiOutputNode) + or isinstance(dag_node, ClassMethodNode) + ): + if isinstance(dag_node, FunctionNode): + # TODO(swang): Support non-actor tasks. 
+                    raise NotImplementedError(
+                        "Compiled DAGs currently only support actor method nodes"
+                    )
+                else:
+                    raise ValueError(f"Found unsupported node of type {type(dag_node)}")
+
+            if isinstance(dag_node, ClassMethodNode) and dag_node.is_class_method_call:
+                actor_handle = dag_node._get_actor_handle()
+                if actor_handle is None:
+                    raise ValueError(
+                        "Compiled DAGs can only bind methods to an actor "
+                        "that is already created with Actor.remote()"
+                    )
+
+                if actor_handle not in self.actor_to_gpu_ids:
+                    self.actor_to_gpu_ids[actor_handle] = CompiledDAG._get_gpu_ids(
+                        actor_handle
+                    )
+
+                if isinstance(dag_node.type_hint, AutoTransportType):
+                    auto_transport_tasks.add(task)
+
+                # Collect actors for NCCL P2P methods.
+                if dag_node.type_hint.requires_nccl():
+                    nccl_actors_p2p.add(actor_handle)
+                    custom_communicator = dag_node.type_hint.get_custom_communicator()
+                    mixed_nccl_group_error_message = (
+                        "Compiled Graphs do not support mixed usage of "
+                        "type hints of the default NCCL group "
+                        '(i.e., TorchTensor(transport="nccl")) '
+                        "and a custom NCCL group "
+                        "(i.e., TorchTensor(transport=nccl_group)). "
+                        "Please check all the TorchTensor type hints and "
+                        "make sure only one type of NCCL transport is specified."
+                    )
+                    if custom_communicator is None:
+                        if self._custom_communicator_p2p is not None:
+                            raise ValueError(mixed_nccl_group_error_message)
+                        self._use_default_nccl_group = True
+                    else:
+                        if self._use_default_nccl_group:
+                            raise ValueError(mixed_nccl_group_error_message)
+                        if self._custom_communicator_p2p is not None:
+                            if self._custom_communicator_p2p != custom_communicator:
+                                raise ValueError(
+                                    "Compiled Graphs currently only support "
+                                    "a single custom NCCL group, but multiple "
+                                    "have been specified. Check all the "
+                                    "TorchTensor(transport=nccl_group) type hints "
+                                    "to make sure only one NCCL group is used."
+                                )
+                        self._custom_communicator_p2p = custom_communicator
+
+            # Collect NCCL collective operations.
+            if isinstance(dag_node, CollectiveOutputNode):
+                collective_ops.add(dag_node.collective_op)
+                assert not self._overlap_gpu_communication, (
+                    "Currently, the overlap_gpu_communication option is not "
+                    "supported for NCCL collective operations. Please set "
+                    "overlap_gpu_communication=False."
+                )
+            elif isinstance(dag_node, InputNode) or isinstance(
+                dag_node, InputAttributeNode
+            ):
+                if dag_node.type_hint.requires_nccl():
+                    raise ValueError(
+                        "DAG inputs cannot be transferred via NCCL because "
+                        "the driver cannot participate in the NCCL group"
+                    )
+                if isinstance(dag_node.type_hint, AutoTransportType):
+                    # Currently a driver on GPU is not supported, so we always
+                    # use shared memory to transfer tensors.
+                    dag_node.type_hint = TorchTensorType()
+
+            if type(dag_node.type_hint) is ChannelOutputType:
+                # No type hint specified by the user. Replace it
+                # with the default type hint for this DAG.
+                dag_node.type_hint = self._default_type_hint
+
+            for _, val in task.kwargs.items():
+                if isinstance(val, DAGNode):
+                    raise ValueError(
+                        "Compiled DAG currently does not support binding to "
+                        "other DAG nodes as kwargs"
+                    )
+
+            for _, arg in enumerate(task.args):
+                if not isinstance(arg, DAGNode):
+                    continue
+                upstream_node_idx = self.dag_node_to_idx[arg]
+                upstream_task = self.idx_to_task[upstream_node_idx]
+                downstream_actor_handle = None
+                if (
+                    isinstance(dag_node, ClassMethodNode)
+                    and dag_node.is_class_method_call
+                ):
+                    downstream_actor_handle = dag_node._get_actor_handle()
+
+                # Add the upstream node as an argument node of this task; its
+                # type hints may be updated when resolved lazily.
+                task.arg_nodes.append(upstream_task.dag_node)
+
+                if isinstance(upstream_task.dag_node, InputAttributeNode):
+                    # Record all of the keys used to index the InputNode.
+                    # During execution, we will check that the user provides
+                    # the same args and kwargs.
+                    if isinstance(upstream_task.dag_node.key, int):
+                        input_positional_args.add(upstream_task.dag_node.key)
+                    elif isinstance(upstream_task.dag_node.key, str):
+                        input_kwargs.add(upstream_task.dag_node.key)
+                    else:
+                        raise ValueError(
+                            "InputNode() can only be indexed using int "
+                            "for positional args or str for kwargs."
+                        )
+
+                    if direct_input is not None and direct_input:
+                        raise ValueError(
+                            "All tasks must either use InputNode() "
+                            "directly, or they must index to specific args or "
+                            "kwargs."
+                        )
+                    direct_input = False
+
+                    # If the upstream node is an InputAttributeNode, treat the
+                    # DAG's input node as the actual upstream node.
+                    upstream_task = self.idx_to_task[self.input_task_idx]
+
+                elif isinstance(upstream_task.dag_node, InputNode):
+                    if direct_input is not None and not direct_input:
+                        raise ValueError(
+                            "All tasks must either use InputNode() directly, "
+                            "or they must index to specific args or kwargs."
+                        )
+                    direct_input = True
+
+                upstream_task.downstream_task_idxs[task_idx] = downstream_actor_handle
+
+                if upstream_task.dag_node.type_hint.requires_nccl():
+                    # Add all readers to the NCCL P2P actors.
+                    nccl_actors_p2p.add(downstream_actor_handle)
+
+        # Check that all specified input attributes, e.g., InputNode()["x"],
+        # are used in the DAG.
+        _check_unused_dag_input_attributes(output_node, input_attributes)
+
+        # Collect all leaf nodes.
+        leaf_nodes: List[DAGNode] = []
+        for idx, task in self.idx_to_task.items():
+            if not isinstance(task.dag_node, ClassMethodNode):
+                continue
+            if (
+                len(task.downstream_task_idxs) == 0
+                and not task.dag_node.is_cgraph_output_node
+            ):
+                leaf_nodes.append(task.dag_node)
+        # Leaf nodes are not allowed because an exception thrown by a leaf
+        # node will not be propagated to the driver.
+        if len(leaf_nodes) != 0:
+            raise ValueError(
+                "Compiled DAG doesn't support leaf nodes, i.e., nodes that don't have "
+                "downstream nodes and are not output nodes. There are "
+                f"{len(leaf_nodes)} leaf nodes in the DAG. Please add the outputs of "
+                f"{[leaf_node.get_method_name() for leaf_node in leaf_nodes]} to "
+                "the MultiOutputNode."
+            )
+
+        type_hint_resolver = TypeHintResolver(self.actor_to_gpu_ids)
+        # Resolve AutoChannelType type hints and track the actors that use NCCL.
+        # This is needed so that the NCCL group can be initialized for the
+        # actors that use NCCL.
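+        # For illustration (assumed behavior, roughly): a task annotated with
+        # with_tensor_transport("auto") resolves to an NCCL transport when the
+        # writer and all readers are GPU actors, and falls back to a
+        # shared-memory transport otherwise.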
+ for task in auto_transport_tasks: + writer = task.dag_node._get_actor_handle() + readers = task.downstream_task_idxs.values() + writer_and_node = (writer, self._get_node_id(writer)) + reader_and_node_list = [ + (reader, self._get_node_id(reader)) for reader in readers + ] + # Update the type hint to the resolved one. This is needed because + # the resolved type hint's `register_custom_serializer` will be called + # in preparation for channel I/O. + task.dag_node.type_hint = type_hint_resolver.resolve( + task.dag_node.type_hint, + writer_and_node, + reader_and_node_list, + ) + if task.dag_node.type_hint.requires_nccl(): + nccl_actors_p2p.add(writer) + nccl_actors_p2p.update(readers) + + nccl_actors_p2p = list(nccl_actors_p2p) + if None in nccl_actors_p2p: + raise ValueError("Driver cannot participate in the NCCL group.") + + # Initialize and cache a NCCL group for each custom NCCL group. All the + # custom NCCL groups are initialized before the default NCCL groups. + custom_communicator_to_id: Dict[Communicator, str] = {} + # Initialize and cache a NCCL group for each set of actors. A set of actors + # can perform P2P send/recv and collective operations. If there are multiple + # custom NCCL groups for a set of actors, only one is cached. + actors_to_communicator_id: Dict[FrozenSet["ray.actor.ActorHandle"], str] = {} + + # If a custom NCCL group is specified for P2P actors, initialize and cache + # the NCCL group ID. + if nccl_actors_p2p and self._custom_communicator_p2p: + if not set(nccl_actors_p2p).issubset( + set(self._custom_communicator_p2p.get_actor_handles()) + ): + raise ValueError( + "Expected P2P actor handles to be a subset of the custom NCCL group" + ) + self._communicator_id_p2p = _init_communicator( + nccl_actors_p2p, + self._custom_communicator_p2p, + self._overlap_gpu_communication, + ) + custom_communicator_to_id[ + self._custom_communicator_p2p + ] = self._communicator_id_p2p + actors = frozenset(nccl_actors_p2p) + actors_to_communicator_id[actors] = self._communicator_id_p2p + + # If a custom communicator is specified for collective actors, initialize and + # cache the communicator ID. + for collective_op in collective_ops: + type_hint = collective_op.type_hint + custom_communicator = type_hint.get_custom_communicator() + if custom_communicator: + communicator_id = collective_op.init_communicator( + custom_communicator_to_id.get(custom_communicator, None) + ) + custom_communicator_to_id[custom_communicator] = communicator_id + actors = frozenset(collective_op.actor_handles) + if actors not in actors_to_communicator_id: + actors_to_communicator_id[actors] = communicator_id + + # If a NCCL group for P2P actors is not initialized, initialize and cache + # the NCCL group ID. + if nccl_actors_p2p and self._communicator_id_p2p is None: + actors = frozenset(nccl_actors_p2p) + if actors in actors_to_communicator_id: + self._communicator_id_p2p = actors_to_communicator_id[actors] + else: + self._communicator_id_p2p = _init_communicator( + nccl_actors_p2p, + self._custom_communicator_p2p, + self._overlap_gpu_communication, + ) + actors_to_communicator_id[actors] = self._communicator_id_p2p + + # If a NCCL group for collective actors is not initialized, initialize and + # cache the NCCL group ID. 
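+        # For illustration only (hypothetical actors and IDs), after all of
+        # these initialization steps we might end up with, e.g.:
+        #   custom_communicator_to_id = {my_comm: "comm-1"}
+        #   actors_to_communicator_id = {frozenset({a1, a2}): "comm-1"}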
+ for collective_op in collective_ops: + if collective_op.type_hint.communicator_id is None: + actors = frozenset(collective_op.actor_handles) + communicator_id = collective_op.init_communicator( + actors_to_communicator_id.get(actors, None) + ) + if actors not in actors_to_communicator_id: + actors_to_communicator_id[actors] = communicator_id + + # Store all the NCCL group IDs for P2P send/recv and collective operations. + self._communicator_ids = set(actors_to_communicator_id.values()).union( + set(custom_communicator_to_id.values()) + ) + + if direct_input: + self._input_num_positional_args = 1 + elif not input_positional_args: + self._input_num_positional_args = 0 + else: + self._input_num_positional_args = max(input_positional_args) + 1 + self._input_kwargs = tuple(input_kwargs) + + @staticmethod + def _get_gpu_ids(actor_handle: "ray.actor.ActorHandle") -> List[str]: + """ + Get the GPU IDs of an actor handle. + """ + accelerator_ids = ray.get( + actor_handle.__ray_call__.remote( + lambda self: ray.get_runtime_context().get_accelerator_ids() + ) + ) + return accelerator_ids.get("GPU", []) + + def _get_node_id(self, actor_handle: Optional["ray.actor.ActorHandle"]) -> str: + """ + Get the node ID of an actor handle and cache it. + + Args: + actor_handle: The actor handle, or None if the actor handle is the + driver. + Returns: + The node ID of the actor handle or driver. + """ + if actor_handle in self.actor_to_node_id: + return self.actor_to_node_id[actor_handle] + node_id = None + if actor_handle == self._proxy_actor or actor_handle is None: + node_id = ray.get_runtime_context().get_node_id() + else: + node_id = ray.get( + actor_handle.__ray_call__.remote( + lambda self: ray.get_runtime_context().get_node_id() + ) + ) + self.actor_to_node_id[actor_handle] = node_id + return node_id + + def _get_or_compile( + self, + ) -> None: + """Compile an execution path. This allocates channels for adjacent + tasks to send/receive values. An infinite task is submitted to each + actor in the DAG that repeatedly receives from input channel(s) and + sends to output channel(s). + + This function is idempotent and will cache the previously allocated + channels. After calling this function, _dag_submitter and + _dag_output_fetcher will be set and can be used to invoke and fetch + outputs for the DAG. + """ + from ray.dag import ( + DAGNode, + InputNode, + InputAttributeNode, + MultiOutputNode, + ClassMethodNode, + ) + + if self.input_task_idx is None: + self._preprocess() + assert self.input_task_idx is not None + + if self._dag_submitter is not None: + assert self._dag_output_fetcher is not None + return + + frontier = [self.input_task_idx] + visited = set() + # Create output buffers. This loop does a breadth-first search through the DAG. + while frontier: + cur_idx = frontier.pop(0) + if cur_idx in visited: + continue + visited.add(cur_idx) + + task = self.idx_to_task[cur_idx] + type_hint = task.dag_node.type_hint + if type_hint.requires_nccl(): + type_hint.set_communicator_id(self._communicator_id_p2p) + + if ( + isinstance(task.dag_node, ClassMethodNode) + and task.dag_node.is_class_method_call + ): + # Create output buffers for the actor method. + assert len(task.output_channels) == 0 + # `output_to_readers` stores the reader tasks for each output of + # the current node. If the current node returns one output, the + # readers are the downstream nodes of the current node. 
If the + # current node returns multiple outputs, the readers of each + # output are the downstream nodes of the ClassMethodNode that + # is a class method output. + output_to_readers: Dict[CompiledTask, List[CompiledTask]] = defaultdict( + list + ) + for idx in task.downstream_task_idxs: + downstream_task = self.idx_to_task[idx] + downstream_node = downstream_task.dag_node + if ( + isinstance(downstream_node, ClassMethodNode) + and downstream_node.is_class_method_output + ): + output_to_readers[downstream_task] = [ + self.idx_to_task[idx] + for idx in downstream_task.downstream_task_idxs + ] + else: + if task not in output_to_readers: + output_to_readers[task] = [] + output_to_readers[task].append(downstream_task) + fn = task.dag_node._get_remote_method("__ray_call__") + for output, readers in output_to_readers.items(): + reader_and_node_list: List[Tuple["ray.actor.ActorHandle", str]] = [] + # Use reader_handles_set to deduplicate readers on the + # same actor, because with CachedChannel each actor will + # only read from the upstream channel once. + reader_handles_set = set() + read_by_multi_output_node = False + for reader in readers: + if isinstance(reader.dag_node, MultiOutputNode): + read_by_multi_output_node = True + # inserting at 0 to make sure driver is first reader as + # expected by CompositeChannel read + reader_and_node_list.insert( + 0, + ( + self._proxy_actor, + self._get_node_id(self._proxy_actor), + ), + ) + else: + reader_handle = reader.dag_node._get_actor_handle() + if reader_handle not in reader_handles_set: + reader_handle = reader.dag_node._get_actor_handle() + reader_and_node_list.append( + (reader_handle, self._get_node_id(reader_handle)) + ) + reader_handles_set.add(reader_handle) + + # if driver is an actual actor, gets driver actor id + driver_actor_id = ( + ray.get_runtime_context().get_actor_id() + if read_by_multi_output_node + else None + ) + # Create an output channel for each output of the current node. + output_channel = ray.get( + fn.remote( + do_allocate_channel, + reader_and_node_list, + type_hint, + driver_actor_id, + ) + ) + output_idx = None + downstream_node = output.dag_node + if ( + isinstance(downstream_node, ClassMethodNode) + and downstream_node.is_class_method_output + ): + output_idx = downstream_node.output_idx + task.output_channels.append(output_channel) + task.output_idxs.append(output_idx) + task.output_node_idxs.append(self.dag_node_to_idx[downstream_node]) + actor_handle = task.dag_node._get_actor_handle() + assert actor_handle is not None + self.actor_refs.add(actor_handle) + self.actor_to_tasks[actor_handle].append(task) + elif ( + isinstance(task.dag_node, ClassMethodNode) + and task.dag_node.is_class_method_output + ): + task_node = task.dag_node + upstream_node = task_node.class_method_call + assert upstream_node + upstream_task = self.idx_to_task[self.dag_node_to_idx[upstream_node]] + for i in range(len(upstream_task.output_channels)): + if upstream_task.output_idxs[i] == task_node.output_idx: + task.output_channels.append(upstream_task.output_channels[i]) + task.output_idxs.append(upstream_task.output_idxs[i]) + assert len(task.output_channels) == 1 + elif isinstance(task.dag_node, InputNode): + # A dictionary that maps an InputNode or InputAttributeNode to its + # readers and the node on which the reader is running. Use `set` to + # deduplicate readers on the same actor because with CachedChannel + # each actor will only read from the shared memory once. 
+ input_node_to_reader_and_node_set: Dict[ + Union[InputNode, InputAttributeNode], + Set[Tuple["ray.actor.ActorHandle", str]], + ] = defaultdict(set) + + for idx in task.downstream_task_idxs: + reader_task = self.idx_to_task[idx] + assert isinstance(reader_task.dag_node, ClassMethodNode) + reader_handle = reader_task.dag_node._get_actor_handle() + reader_node_id = self._get_node_id(reader_handle) + for arg in reader_task.args: + if isinstance(arg, InputAttributeNode) or isinstance( + arg, InputNode + ): + input_node_to_reader_and_node_set[arg].add( + (reader_handle, reader_node_id) + ) + + # A single channel is responsible for sending the same data to + # corresponding consumers. Therefore, we create a channel for + # each InputAttributeNode, or a single channel for the entire + # input data if there are no InputAttributeNodes. + task.output_channels = [] + for input_dag_node in input_node_to_reader_and_node_set: + reader_and_node_list = list( + input_node_to_reader_and_node_set[input_dag_node] + ) + + output_channel = do_allocate_channel( + self, + reader_and_node_list, + input_dag_node.type_hint, + None, + ) + task.output_channels.append(output_channel) + task.output_idxs.append( + None + if isinstance(input_dag_node, InputNode) + else input_dag_node.key + ) + + # Update the InputAttributeNode's `output_channels`, which is + # used to determine whether to create a CachedChannel. + if isinstance(input_dag_node, InputAttributeNode): + input_attr_idx = self.dag_node_to_idx[input_dag_node] + input_attr_task = self.idx_to_task[input_attr_idx] + input_attr_task.output_channels.append(output_channel) + assert len(input_attr_task.output_channels) == 1 + else: + assert isinstance(task.dag_node, InputAttributeNode) or isinstance( + task.dag_node, MultiOutputNode + ) + + for idx in task.downstream_task_idxs: + frontier.append(idx) + + # Validate input channels for tasks that have not been visited + for node_idx, task in self.idx_to_task.items(): + if ( + node_idx == self.input_task_idx + or node_idx == self.output_task_idx + or isinstance(task.dag_node, InputAttributeNode) + ): + continue + if node_idx not in visited: + has_at_least_one_channel_input = False + for arg in task.args: + if isinstance(arg, DAGNode): + has_at_least_one_channel_input = True + if not has_at_least_one_channel_input: + raise ValueError( + "Compiled DAGs require each task to take a ray.dag.InputNode " + "or at least one other DAGNode as an input. " + "Invalid task node:\n" + f"{task.dag_node}\n" + "Please bind the task to proper DAG nodes." + ) + + from ray.dag.constants import RAY_CGRAPH_ENABLE_DETECT_DEADLOCK + + if RAY_CGRAPH_ENABLE_DETECT_DEADLOCK and self._detect_deadlock(): + raise ValueError( + "This DAG cannot be compiled because it will deadlock on NCCL " + "calls. If you believe this is a false positive, please disable " + "the graph verification by setting the environment variable " + "RAY_CGRAPH_ENABLE_DETECT_DEADLOCK to 0 and file an issue at " + "https://github.com/ray-project/ray/issues/new/." + ) + + input_task = self.idx_to_task[self.input_task_idx] + self.dag_input_channels = input_task.output_channels + assert self.dag_input_channels is not None + + # Create executable tasks for each actor + for actor_handle, tasks in self.actor_to_tasks.items(): + # Dict from arg to the set of tasks that consume it. + arg_to_consumers: Dict[DAGNode, Set[CompiledTask]] = defaultdict(set) + + # Step 1: populate `arg_to_consumers` and perform some validation. 
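+            # For illustration (hypothetical): if two tasks on this actor both
+            # consume the same upstream output, Step 1 yields
+            #   arg_to_consumers = {upstream_node: {task_a, task_b}}
+            # and Step 2 below wraps the shared channel in
+            # CachedChannel(2, arg_channel).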
+ for task in tasks: + has_at_least_one_channel_input = False + for arg in task.args: + if isinstance(arg, DAGNode): + has_at_least_one_channel_input = True + arg_to_consumers[arg].add(task) + arg_idx = self.dag_node_to_idx[arg] + upstream_task = self.idx_to_task[arg_idx] + assert len(upstream_task.output_channels) == 1 + arg_channel = upstream_task.output_channels[0] + assert arg_channel is not None + # TODO: Support no-input DAGs (use an empty object to signal). + if not has_at_least_one_channel_input: + raise ValueError( + "Compiled DAGs require each task to take a " + "ray.dag.InputNode or at least one other DAGNode as an " + "input" + ) + + # Step 2: create cached channels if needed + + # Dict from original channel to the channel to be used in execution. + # The value of this dict is either the original channel or a newly + # created CachedChannel (if the original channel is read more than once). + for arg, consumers in arg_to_consumers.items(): + arg_idx = self.dag_node_to_idx[arg] + upstream_task = self.idx_to_task[arg_idx] + assert len(upstream_task.output_channels) == 1 + arg_channel = upstream_task.output_channels[0] + assert arg_channel is not None + if len(consumers) > 1: + self._channel_dict[arg_channel] = CachedChannel( + len(consumers), + arg_channel, + ) + else: + self._channel_dict[arg_channel] = arg_channel + + # Step 3: create executable tasks for the actor + executable_tasks = [] + for task in tasks: + resolved_args: List[Any] = [] + for arg in task.args: + if isinstance(arg, DAGNode): + arg_idx = self.dag_node_to_idx[arg] + upstream_task = self.idx_to_task[arg_idx] + assert len(upstream_task.output_channels) == 1 + arg_channel = upstream_task.output_channels[0] + assert arg_channel is not None + arg_channel = self._channel_dict[arg_channel] + resolved_args.append(arg_channel) + else: + # Constant arg + resolved_args.append(arg) + executable_task = ExecutableTask( + task, + resolved_args, + task.kwargs, + ) + executable_tasks.append(executable_task) + # Sort executable tasks based on their bind index, i.e., submission order + # so that they will be executed in that order. + executable_tasks.sort(key=lambda task: task.bind_index) + self.actor_to_executable_tasks[actor_handle] = executable_tasks + + from ray.dag.constants import RAY_CGRAPH_ENABLE_PROFILING + + if RAY_CGRAPH_ENABLE_PROFILING: + exec_task_func = do_profile_tasks + else: + exec_task_func = do_exec_tasks + + # Build an execution schedule for each actor + self.actor_to_execution_schedule = self._build_execution_schedule() + for actor_handle, executable_tasks in self.actor_to_executable_tasks.items(): + self.worker_task_refs[actor_handle] = actor_handle.__ray_call__.options( + concurrency_group="_ray_system" + ).remote( + exec_task_func, + executable_tasks, + self.actor_to_execution_schedule[actor_handle], + self._overlap_gpu_communication, + ) + + assert self.output_task_idx is not None + self.dag_output_channels = [] + for output in self.idx_to_task[self.output_task_idx].args: + assert isinstance(output, DAGNode) + output_idx = self.dag_node_to_idx[output] + task = self.idx_to_task[output_idx] + assert len(task.output_channels) == 1 + self.dag_output_channels.append(task.output_channels[0]) + + # Register custom serializers for input, input attribute, and output nodes. 
+        self._register_input_output_custom_serializer()
+
+        assert self.dag_input_channels
+        assert self.dag_output_channels
+        assert all(
+            output_channel is not None for output_channel in self.dag_output_channels
+        )
+        # If no MultiOutputNode was specified during the DAG creation, there is
+        # only one output. Return a single output channel instead of a list of
+        # channels.
+        if not self._returns_list:
+            assert len(self.dag_output_channels) == 1
+
+        # The driver should ray.put on the input and ray.get/release on the output.
+        self._monitor = self._monitor_failures()
+        input_task = self.idx_to_task[self.input_task_idx]
+        if self._enable_asyncio:
+            self._dag_submitter = AwaitableBackgroundWriter(
+                self.dag_input_channels,
+                input_task.output_idxs,
+                is_input=True,
+            )
+            self._dag_output_fetcher = AwaitableBackgroundReader(
+                self.dag_output_channels,
+                self._fut_queue,
+            )
+        else:
+            self._dag_submitter = SynchronousWriter(
+                self.dag_input_channels, input_task.output_idxs, is_input=True
+            )
+            self._dag_output_fetcher = SynchronousReader(self.dag_output_channels)
+
+        self._dag_submitter.start()
+        self._dag_output_fetcher.start()
+
+    def _generate_dag_operation_graph_node(
+        self,
+    ) -> Dict["ray.actor.ActorHandle", List[List[_DAGOperationGraphNode]]]:
+        """
+        Generate READ, COMPUTE, and WRITE operations for each DAG node.
+
+        Returns:
+            A dictionary that maps an actor handle to a list of lists of
+            _DAGOperationGraphNode. For the same actor, the index of the
+            outer list corresponds to the index of the ExecutableTask in
+            the list of `executable_tasks` in `actor_to_executable_tasks`,
+            i.e. `exec_task_idx`. In the inner list, the order of operations
+            is READ, COMPUTE, and WRITE.
+
+            Example:
+            {
+                actor1: [
+                    [READ COMPUTE WRITE]  # exec_task_idx 0
+                    [READ COMPUTE WRITE]  # exec_task_idx 1
+                ]
+            }
+        """
+        from ray.dag.collective_node import CollectiveOutputNode, _CollectiveOperation
+
+        assert self.idx_to_task
+        assert self.actor_to_executable_tasks
+
+        actor_to_operation_nodes: Dict[
+            "ray.actor.ActorHandle", List[List[_DAGOperationGraphNode]]
+        ] = defaultdict(list)
+        collective_op_to_nodes: Dict[
+            _CollectiveOperation, Set[_DAGOperationGraphNode]
+        ] = defaultdict(set)
+        collective_op_to_idxs: Dict[
+            _CollectiveOperation, Set[Tuple[int, _DAGNodeOperationType]]
+        ] = defaultdict(set)
+
+        for actor_handle, executable_tasks in self.actor_to_executable_tasks.items():
+            for exec_task_idx, exec_task in enumerate(executable_tasks):
+                # Divide a DAG node into three _DAGOperationGraphNodes: READ,
+                # COMPUTE, and WRITE. Each _DAGOperationGraphNode has a
+                # _DAGNodeOperation.
+ task_idx = exec_task.task_idx + dag_node = self.idx_to_task[task_idx].dag_node + method_name = exec_task.method_name + actor_handle = dag_node._get_actor_handle() + requires_nccl = dag_node.type_hint.requires_nccl() + upstream_requires_nccl = False + for upstream_node in dag_node._upstream_nodes: + if upstream_node.type_hint.requires_nccl(): + upstream_requires_nccl = True + break + + read_node = _DAGOperationGraphNode( + _DAGNodeOperation( + exec_task_idx, _DAGNodeOperationType.READ, method_name + ), + task_idx, + actor_handle, + upstream_requires_nccl, + ) + compute_node = _DAGOperationGraphNode( + _DAGNodeOperation( + exec_task_idx, _DAGNodeOperationType.COMPUTE, method_name + ), + task_idx, + actor_handle, + isinstance(dag_node, CollectiveOutputNode), + ) + write_node = _DAGOperationGraphNode( + _DAGNodeOperation( + exec_task_idx, _DAGNodeOperationType.WRITE, method_name + ), + task_idx, + actor_handle, + requires_nccl, + ) + + actor_to_operation_nodes[actor_handle].append( + [read_node, compute_node, write_node] + ) + if isinstance(dag_node, CollectiveOutputNode): + collective_op_to_nodes[dag_node.collective_op].add(compute_node) + collective_op_to_idxs[dag_node.collective_op].add( + (task_idx, _DAGNodeOperationType.COMPUTE) + ) + + # Set collective nodes for all the NCCL collective operation nodes. + for collective_op, nodes in collective_op_to_nodes.items(): + idxs = collective_op_to_idxs[collective_op] + for node in nodes: + node.collective_idxs = idxs + + return actor_to_operation_nodes + + def _build_execution_schedule( + self, + ) -> Dict["ray.actor.ActorHandle", List[_DAGNodeOperation]]: + """ + Generate an execution schedule for each actor. The schedule is a list of + _DAGNodeOperation. + + Step 1: Generate a DAG node operation graph. Refer to the functions + `_generate_dag_operation_graph_node` and `_build_dag_node_operation_graph` + for more details. + + Step 2: Topological sort + + It is possible to have multiple _DAGOperationGraphNodes with zero in-degree. + Refer to the function `_select_next_nodes` for the logic of selecting nodes. + + Then, put the selected nodes into the corresponding actors' schedules. + + The schedule should be intuitive to users, meaning that the execution should + perform operations in ascending order of `bind_index` as much as possible. + + [Example]: + + See `test_execution_schedule` for more examples. + + Returns: + actor_to_execution_schedule: A dictionary that maps an actor handle to + the execution schedule which is a list of operations to be executed. 
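+
+        Example (hypothetical; one actor with two executable tasks):
+            {
+                actor1: [
+                    READ(task0), COMPUTE(task0), WRITE(task0),
+                    READ(task1), COMPUTE(task1), WRITE(task1),
+                ]
+            }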
+ """ + # Step 1: Build a graph of _DAGOperationGraphNode + actor_to_operation_nodes = self._generate_dag_operation_graph_node() + graph = _build_dag_node_operation_graph( + self.idx_to_task, actor_to_operation_nodes + ) + # Step 2: Generate an execution schedule for each actor using topological sort + actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph) + + # Step 3: Overlap GPU communication for the execution schedule if configured + actor_to_overlapped_schedule = None + if self._overlap_gpu_communication: + actor_to_overlapped_schedule = _generate_overlapped_execution_schedule( + actor_to_execution_schedule + ) + + if RAY_CGRAPH_VISUALIZE_SCHEDULE: + _visualize_execution_schedule( + actor_to_execution_schedule, actor_to_overlapped_schedule, graph + ) + + if actor_to_overlapped_schedule is not None: + return _extract_execution_schedule(actor_to_overlapped_schedule) + else: + return _extract_execution_schedule(actor_to_execution_schedule) + + def _detect_deadlock(self) -> bool: + """ + TODO (kevin85421): Avoid false negatives. + + Currently, a compiled graph may deadlock if there are NCCL channels, and the + readers have control dependencies on the same actor. For example: + + actor1.a ---> actor2.f1 + | + ---> actor2.f2 + + The control dependency between `actor2.f1` and `actor2.f2` is that `f1` should + run before `f2`. If `actor1.a` writes to `actor2.f2` before `actor2.f1`, a + deadlock will occur. + + Currently, the execution schedule is not granular enough to detect this + deadlock. + + Returns: + True if a deadlock is detected; otherwise, False. + """ + logger.warning("Deadlock detection has not been implemented yet.") + return False + + def _monitor_failures(self): + outer = weakref.proxy(self) + + class Monitor(threading.Thread): + def __init__(self): + super().__init__(daemon=True) + self.name = "CompiledGraphMonitorThread" + # Lock to make sure that we only perform teardown for this DAG + # once. + self._in_teardown_lock = threading.Lock() + self._teardown_done = False + + def wait_teardown(self, kill_actors: bool = False): + from ray.dag import DAGContext + + ctx = DAGContext.get_current() + teardown_timeout = ctx.teardown_timeout + for actor, ref in outer.worker_task_refs.items(): + timeout = False + try: + ray.get(ref, timeout=teardown_timeout) + except ray.exceptions.GetTimeoutError: + msg = ( + f"Compiled DAG actor {actor} is still running " + f"{teardown_timeout}s after teardown()." + ) + if kill_actors: + msg += ( + " Force-killing actor. " + "Increase RAY_CGRAPH_teardown_timeout if you want " + "teardown to wait longer." + ) + ray.kill(actor) + else: + msg += ( + " Teardown may hang. " + "Call teardown with kill_actors=True if force kill " + "is desired." + ) + + logger.warning(msg) + timeout = True + except Exception: + # We just want to check that the task has finished so + # we don't care if the actor task ended in an + # exception. + pass + + if not timeout: + continue + + try: + ray.get(ref) + except Exception: + pass + + def teardown(self, kill_actors: bool = False): + with self._in_teardown_lock: + if self._teardown_done: + return + + logger.info("Tearing down compiled DAG") + outer._dag_submitter.close() + outer._dag_output_fetcher.close() + + for actor in outer.actor_refs: + logger.info(f"Cancelling compiled worker on actor: {actor}") + # Cancel all actor loops in parallel. 
+ cancel_refs = [ + actor.__ray_call__.remote(do_cancel_executable_tasks, tasks) + for actor, tasks in outer.actor_to_executable_tasks.items() + ] + for cancel_ref in cancel_refs: + try: + ray.get(cancel_ref, timeout=30) + except RayChannelError: + # Channel error happens when a channel is closed + # or timed out. In this case, do not log. + pass + except Exception: + logger.exception("Error cancelling worker task") + pass + + for communicator_id in outer._communicator_ids: + _destroy_communicator(communicator_id) + + logger.info("Waiting for worker tasks to exit") + self.wait_teardown(kill_actors=kill_actors) + logger.info("Teardown complete") + self._teardown_done = True + + def run(self): + try: + ray.get(list(outer.worker_task_refs.values())) + except KeyboardInterrupt: + logger.info( + "Received KeyboardInterrupt, tearing down with kill_actors=True" + ) + self.teardown(kill_actors=True) + except Exception as e: + logger.debug(f"Handling exception from worker tasks: {e}") + self.teardown() + + monitor = Monitor() + monitor.start() + return monitor + + def _raise_if_too_many_inflight_executions(self): + num_inflight_executions = ( + self._execution_index - self._max_finished_execution_index + ) + len(self._result_buffer) + if num_inflight_executions >= self._max_inflight_executions: + raise ray.exceptions.RayCgraphCapacityExceeded( + "The compiled graph can't have more than " + f"{self._max_inflight_executions} in-flight executions, and you " + f"currently have {num_inflight_executions} in-flight executions. " + "Retrieve an output using ray.get before submitting more requests or " + "increase `_max_inflight_executions`. " + "`dag.experimental_compile(_max_inflight_executions=...)`" + ) + + def _has_execution_results( + self, + execution_index: int, + ) -> bool: + """Check whether there are results corresponding to the given execution + index stored in self._result_buffer. This helps avoid fetching and + caching results again. + + Args: + execution_index: The execution index corresponding to the result. + + Returns: + Whether the result for the given index has been fetched and cached. + """ + return execution_index in self._result_buffer + + def _cache_execution_results( + self, + execution_index: int, + result: Any, + ): + """Cache execution results in self._result_buffer. Results are converted + to dictionary format to allow efficient element removal and calculation of + the buffer size. This can only be called once per execution index. + + Args: + execution_index: The execution index corresponding to the result. + result: The results from all channels to be cached. + """ + if not self._has_execution_results(execution_index): + for chan_idx, res in enumerate(result): + # avoid caching for any CompiledDAGRef that has already been destructed. + if not ( + execution_index in self._destructed_ref_idxs + and chan_idx in self._destructed_ref_idxs[execution_index] + ): + self._result_buffer[execution_index][chan_idx] = res + + def _get_execution_results( + self, execution_index: int, channel_index: Optional[int] + ) -> List[Any]: + """Retrieve execution results from self._result_buffer and return the result. + Results are converted back to original list format ordered by output channel + index. + + Args: + execution_index: The execution index to retrieve results from. + channel_index: The index of the output channel corresponding to the result. + Channel indexing is consistent with the order of + self.dag_output_channels. None means that the result wraps outputs from + all output channels. 
+ + Returns: + The execution result corresponding to the given execution index and channel + index. + """ + # Although CompiledDAGRef and CompiledDAGFuture guarantee that the same + # execution index and channel index combination will not be requested multiple + # times and therefore self._result_buffer will always have execution_index as + # a key, we still do a sanity check to avoid misuses. + assert execution_index in self._result_buffer + + if channel_index is None: + # Convert results stored in self._result_buffer back to original + # list representation + result = [ + kv[1] + for kv in sorted( + self._result_buffer.pop(execution_index).items(), + key=lambda kv: kv[0], + ) + ] + else: + result = [self._result_buffer[execution_index].pop(channel_index)] + if len(self._result_buffer[execution_index]) == 0: + del self._result_buffer[execution_index] + return result + + def _next_execution_can_be_released(self) -> bool: + """ + Check if the next buffers for the next execution which will be completed + can be released. The next execution can be released if the next + execution index is in _destructed_ref_idxs and the number of destructed + channel indices is equal to the number of output channels. + """ + return ( + self._max_finished_execution_index + 1 in self._destructed_ref_idxs + and len(self._destructed_ref_idxs[self._max_finished_execution_index + 1]) + == len(self.dag_output_channels) + ) + + def _try_release_buffers(self): + """ + This will try to repeatedly release channel buffers as long as + max_finished_execution_index + 1 is in the set of destructed indices. + We should be checking to release buffers any time we are incrementing + or checking the max_finished_execution_index or the _destructed_ref_idxs. + """ + timeout = self._get_timeout + while self._next_execution_can_be_released(): + start_time = time.monotonic() + try: + self._dag_output_fetcher.release_channel_buffers(timeout) + except RayChannelTimeoutError as e: + raise RayChannelTimeoutError( + "Releasing native buffers corresponding to a stale CompiledDAGRef " + "is taking a long time. If this is expected, increase " + f"RAY_CGRAPH_get_timeout which is currently {self._get_timeout} " + "seconds. Otherwise, this may indicate that the execution " + "is hanging." + ) from e + + self._max_finished_execution_index += 1 + + if timeout != -1: + timeout -= time.monotonic() - start_time + timeout = max(timeout, 0) + + def _execute_until( + self, + execution_index: int, + channel_index: Optional[int] = None, + timeout: Optional[float] = None, + ): + """Repeatedly execute this DAG until the given execution index and + buffer results for all CompiledDagRef's. + If the DAG has already been executed up to the given index, it will do nothing. + + Note: If this comes across execution indices for which the corresponding + CompiledDAGRef's have been destructed, it will release the buffer and not + cache the result. + + Args: + execution_index: The execution index to execute until. + channel_index: The index of the output channel to get the result from. + Channel indexing is consistent with the order of + self.dag_output_channels. None means wrapping results from all output + channels into a single list. + timeout: The maximum time in seconds to wait for the execution. + None means using default timeout (DAGContext.get_timeout), + 0 means immediate timeout (immediate success or timeout without + blocking), -1 means infinite timeout (block indefinitely). 
+
+        TODO(rui): catch the case that user holds onto the CompiledDAGRefs
+        """
+        if timeout is None:
+            timeout = self._get_timeout
+        while self._max_finished_execution_index < execution_index:
+            start_time = time.monotonic()
+
+            # Fetch results from each output channel up to execution_index and
+            # cache them separately to enable individual retrieval.
+            # If a CompiledDAGRef for a specific execution index has been
+            # destructed, release the channel buffers for that execution index
+            # instead of caching the results.
+            try:
+                if self._next_execution_can_be_released():
+                    self._dag_output_fetcher.release_channel_buffers(timeout)
+                else:
+                    result = self._dag_output_fetcher.read(timeout)
+                    self._cache_execution_results(
+                        self._max_finished_execution_index + 1,
+                        result,
+                    )
+            except RayChannelTimeoutError as e:
+                raise RayChannelTimeoutError(
+                    "If the execution is expected to take a long time, increase "
+                    f"RAY_CGRAPH_get_timeout which is currently {self._get_timeout} "
+                    "seconds. Otherwise, this may indicate that the execution is "
+                    "hanging."
+                ) from e
+
+            self._max_finished_execution_index += 1
+
+            if timeout != -1:
+                timeout -= time.monotonic() - start_time
+                timeout = max(timeout, 0)
+
+    def execute(
+        self,
+        *args,
+        **kwargs,
+    ) -> Union[CompiledDAGRef, List[CompiledDAGRef]]:
+        """Execute this DAG using the compiled execution path.
+
+        Args:
+            args: Args to the InputNode.
+            kwargs: Kwargs to the InputNode.
+
+        Returns:
+            A CompiledDAGRef, or a list of CompiledDAGRefs if the DAG returns
+            multiple outputs, that can be used to read the DAG result.
+
+        Raises:
+            RayChannelTimeoutError: If the execution does not complete within
+                self._submit_timeout seconds.
+
+        NOTE: Not thread-safe due to _execution_index etc.
+        """
+        if self._enable_asyncio:
+            raise ValueError("Use execute_async if enable_asyncio=True")
+
+        self._get_or_compile()
+
+        self._check_inputs(args, kwargs)
+        if len(args) == 1 and len(kwargs) == 0:
+            # When serializing a tuple, the Ray serializer invokes pickle5, which
+            # adds several microseconds of overhead. One common case for Compiled
+            # Graphs is passing a single argument (often of type `bytes`, which
+            # requires no serialization). To avoid imposing this overhead on this
+            # common case, we create a fast path that avoids pickle5.
+            inp = args[0]
+        else:
+            inp = CompiledDAGArgs(args=args, kwargs=kwargs)
+
+        # Release any buffers we can at this point based on
+        # max_finished_execution_index so that the number of in-flight
+        # executions is up to date.
+        self._try_release_buffers()
+        self._raise_if_too_many_inflight_executions()
+        try:
+            self._dag_submitter.write(inp, self._submit_timeout)
+        except RayChannelTimeoutError as e:
+            raise RayChannelTimeoutError(
+                "If the execution is expected to take a long time, increase "
+                f"RAY_CGRAPH_submit_timeout which is currently {self._submit_timeout} "
+                "seconds. Otherwise, this may indicate that execution is hanging."
+            ) from e
+
+        self._execution_index += 1
+
+        if self._returns_list:
+            ref = [
+                CompiledDAGRef(self, self._execution_index, channel_index)
+                for channel_index in range(len(self.dag_output_channels))
+            ]
+        else:
+            ref = CompiledDAGRef(self, self._execution_index)
+
+        return ref
+
+    def _check_inputs(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
+        """
+        Helper method to check that the DAG args provided by the user during
+        execution are valid according to the defined DAG.
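+
+        For example (hypothetical): if the DAG binds `inp[0]` and `inp.key`,
+        then `dag.execute(x, key=y)` is valid, while `dag.execute(x)` raises
+        a ValueError because the kwarg `key` is missing.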
+ """ + if len(args) != self._input_num_positional_args: + raise ValueError( + "dag.execute() or dag.execute_async() must be " + f"called with {self._input_num_positional_args} positional args, got " + f"{len(args)}" + ) + + for kwarg in self._input_kwargs: + if kwarg not in kwargs: + raise ValueError( + "dag.execute() or dag.execute_async() " + f"must be called with kwarg `{kwarg}`" + ) + + async def execute_async( + self, + *args, + **kwargs, + ) -> Union[CompiledDAGFuture, List[CompiledDAGFuture]]: + """Execute this DAG using the compiled execution path. + + NOTE: Not thread-safe. + + Args: + args: Args to the InputNode. + kwargs: Kwargs to the InputNode. + + Returns: + A list of Channels that can be used to read the DAG result. + """ + if not self._enable_asyncio: + raise ValueError("Use execute if enable_asyncio=False") + + self._get_or_compile() + self._check_inputs(args, kwargs) + async with self._dag_submission_lock: + if len(args) == 1 and len(kwargs) == 0: + # When serializing a tuple, the Ray serializer invokes pickle5, which + # adds several microseconds of overhead. One common case for accelerated + # DAGs is passing a single argument (oftentimes of of type `bytes`, + # which requires no serialization). To avoid imposing this overhead on + # this common case, we create a fast path for this case that avoids + # pickle5. + inp = args[0] + else: + inp = CompiledDAGArgs(args=args, kwargs=kwargs) + + self._raise_if_too_many_inflight_executions() + await self._dag_submitter.write(inp) + # Allocate a future that the caller can use to get the result. + fut = asyncio.Future() + await self._fut_queue.put(fut) + + self._execution_index += 1 + + if self._returns_list: + fut = [ + CompiledDAGFuture(self, self._execution_index, fut, channel_index) + for channel_index in range(len(self.dag_output_channels)) + ] + else: + fut = CompiledDAGFuture(self, self._execution_index, fut) + + return fut + + def _visualize_ascii(self) -> str: + """ + Visualize the compiled graph in + ASCII format with directional markers. + + This function generates an ASCII visualization of a Compiled Graph, + where each task node is labeled, + and edges use `<` and `>` markers to show data flow direction. + + This method is called by: + - `compiled_dag.visualize(format="ascii")` + + + + High-Level Algorithm: + - Topological Sorting: Sort nodes topologically to organize + them into layers based on dependencies. + - Grid Initialization: Set up a 2D grid canvas with dimensions based + on the number of layers and the maximum number of nodes per layer. + - Node Placement: Position each node on the grid according to its + layer and relative position within that layer. + Spacing is added for readability, and directional markers (`<` and `>`) + are added to edges to show input/output flow clearly. + + This method should be called + **after** compiling the graph with `experimental_compile()`. + + Returns: + ASCII representation of the CG with Nodes Information, + Edges Information and Graph Built. + + Limitations: + - Note: This is only used for quick visualization for small graphs. + For complex graph (i.e. more than 20 tasks), please use graphviz. + - Scale: Works best for smaller CGs (typically fewer than 20 tasks). + Larger CGs may result in dense, less readable ASCII + outputs due to limited space for node and edge rendering. + - Shape: Ideal for relatively shallow CGs with clear dependency paths. + For deep, highly branched or densely connected CGs, + readability may suffer. 
+ - Edge Overlap: In cases with high fan-out (i.e., nodes with many children) + or fan-in (nodes with many parents), edge lines may intersect or overlap + in the ASCII visualization, potentially obscuring some connections. + - Multi-output Tasks: Multi-output tasks can be visualized, but positioning + may cause line breaks or overlap when a task has multiple outputs that + feed into nodes at varying depths. + + Example: + Basic Visualization: + ```python + # Print the CG structure in ASCII format + print(compiled_dag.visualize(format="ascii")) + ``` + + Example of Ordered Visualization (task is build in order + to reduce line intersection): + ```python + with InputNode() as i: + o1, o2, o3 = a.return_three.bind(i) + o4 = b.echo.bind(o1) + o5 = b.echo.bind(o2) + o6, o7 = b.return_two.bind(o3) + dag = MultiOutputNode([o4, o5, o6, o7]) + + compiled_dag = dag.experimental_compile() + compiled_dag.visualize(format="ascii",view=True) + + + # Output: + # 0:InputNode + # | + # 1:Actor_54777d:return_three + # |---------------------------->|---------------------------->| # noqa + # 2:Output[0] 3:Output[1] 4:Output[2] # noqa + # | | | # noqa + # 5:Actor_c927c9:echo 6:Actor_c927c9:echo 7:Actor_c927c9:return_two # noqa + # | | |---------------------------->| # noqa + # | | 9:Output[0] 10:Output[1] # noqa + # |<----------------------------|-----------------------------|-----------------------------| # noqa + # 8:MultiOutputNode + ``` + + Example of Anti-pattern Visualization (There are intersections): + # We can swtich the nodes ordering to reduce intersections, i.e. swap o2 and o3 + ```python + with InputNode() as i: + o1, o2, o3 = a.return_three.bind(i) + o4 = b.echo.bind(o1) + o5 = b.echo.bind(o3) + o6, o7 = b.return_two.bind(o2) + dag = MultiOutputNode([o4, o5, o6, o7]) + compiled_dag = dag.experimental_compile() + compiled_dag.visualize(format="ascii",view=True) + + # Output (Nodes 5, 7, 9, 10 should connect to Node 8): + # 0:InputNode + # | + # 1:Actor_84835a:return_three + # |---------------------------->|---------------------------->| # noqa + # 2:Output[0] 3:Output[1] 4:Output[2] # noqa + # | | | # noqa + # 5:Actor_02a6a1:echo 6:Actor_02a6a1:return_two 7:Actor_02a6a1:echo # noqa + # | |---------------------------->| # noqa + # | 9:Output[0] 10:Output[1] # noqa + # |<----------------------------------------------------------| # noqa + # 8:MultiOutputNod + ``` + """ + + from ray.dag import ( + InputAttributeNode, + InputNode, + MultiOutputNode, + ClassMethodNode, + DAGNode, + ) + + # Check that the DAG has been compiled + if not hasattr(self, "idx_to_task") or not self.idx_to_task: + raise ValueError( + "The DAG must be compiled before calling 'visualize()'. " + "Please call 'experimental_compile()' first." + ) + + # Check that each CompiledTask has a valid dag_node + for idx, task in self.idx_to_task.items(): + if not hasattr(task, "dag_node") or not isinstance(task.dag_node, DAGNode): + raise ValueError( + f"Task at index {idx} does not have a valid 'dag_node'. " + "Ensure that 'experimental_compile()' completed successfully." + ) + + from collections import defaultdict, deque + + # Create adjacency list representation of the DAG + # Adjacency list for DAG; maps a node index to its downstream nodes. + adj_list: Dict[int, List[int]] = defaultdict(list) + # Indegree count for topological sorting; maps a node index to its indegree. + indegree: Dict[int, int] = defaultdict(int) + + # Tracks whether a node is a multi-output node. 
+        is_multi_output: Dict[int, bool] = defaultdict(bool)
+        # Maps child node indices to their parent node indices.
+        child2parent: Dict[int, int] = defaultdict(int)
+        ascii_visualization = ""
+        # Node information; maps a node index to its descriptive label.
+        node_info: Dict[int, str] = {}
+        # Edge information; tuples of (upstream_index, downstream_index,
+        # edge_label).
+        edge_info: List[Tuple[int, int, str]] = []
+
+        for idx, task in self.idx_to_task.items():
+            dag_node = task.dag_node
+            label = f"Task {idx} "
+
+            # Determine the type and label of the node
+            if isinstance(dag_node, InputNode):
+                label += "InputNode"
+            elif isinstance(dag_node, InputAttributeNode):
+                label += f"InputAttributeNode[{dag_node.key}]"
+            elif isinstance(dag_node, MultiOutputNode):
+                label += "MultiOutputNode"
+            elif isinstance(dag_node, ClassMethodNode):
+                if dag_node.is_class_method_call:
+                    method_name = dag_node.get_method_name()
+                    actor_handle = dag_node._get_actor_handle()
+                    actor_id = (
+                        actor_handle._actor_id.hex()[:6] if actor_handle else "unknown"
+                    )
+                    label += f"Actor: {actor_id}... Method: {method_name}"
+                elif dag_node.is_class_method_output:
+                    label += f"ClassMethodOutputNode[{dag_node.output_idx}]"
+                else:
+                    label += "ClassMethodNode"
+            else:
+                label += type(dag_node).__name__
+
+            node_info[idx] = label
+
+            for arg_index, arg in enumerate(dag_node.get_args()):
+                if isinstance(arg, DAGNode):
+                    upstream_task_idx = self.dag_node_to_idx[arg]
+
+                    # Get the type hint for this argument
+                    if arg_index < len(task.arg_type_hints):
+                        if task.arg_type_hints[arg_index].requires_nccl():
+                            type_hint = "Nccl"
+                        else:
+                            type_hint = type(task.arg_type_hints[arg_index]).__name__
+                    else:
+                        type_hint = "UnknownType"
+
+                    adj_list[upstream_task_idx].append(idx)
+                    indegree[idx] += 1
+                    edge_info.append((upstream_task_idx, idx, type_hint))
+
+        width_adjust = 0
+        for upstream_task_idx, child_idx_list in adj_list.items():
+            # Mark as multi-output if the node has more than one output path
+            if len(child_idx_list) > 1:
+                for child in child_idx_list:
+                    is_multi_output[child] = True
+                    child2parent[child] = upstream_task_idx
+                width_adjust = max(width_adjust, len(child_idx_list))
+
+        # Topological sort to determine layers
+        layers = defaultdict(list)
+        zero_indegree = deque([idx for idx in self.idx_to_task if indegree[idx] == 0])
+        layer_index = 0
+
+        while zero_indegree:
+            next_layer = deque()
+            while zero_indegree:
+                task_idx = zero_indegree.popleft()
+                layers[layer_index].append(task_idx)
+                for downstream in adj_list[task_idx]:
+                    indegree[downstream] -= 1
+                    if indegree[downstream] == 0:
+                        next_layer.append(downstream)
+            zero_indegree = next_layer
+            layer_index += 1
+
+        # Print detailed node information
+        ascii_visualization += "Nodes Information:\n"
+        for idx, info in node_info.items():
+            ascii_visualization += f'{idx} [label="{info}"] \n'
+
+        # Print edges
+        ascii_visualization += "\nEdges Information:\n"
+        for upstream_task, downstream_task, type_hint in edge_info:
+            if type_hint == "Nccl":
+                edge_channel = "+++"
+            else:
+                edge_channel = "---"
+            ascii_visualization += f"{upstream_task} {edge_channel}> {downstream_task}\n"
+
+        # Add the legend to the output
+        ascii_visualization += "\nLegend:\n"
+        ascii_visualization += "+++> : Represents Nccl-type data channels\n"
+        ascii_visualization += "---> : Represents Shared Memory data channels\n"
+
+        # Find the maximum width (number of nodes in any layer)
+        max_width = max(len(layer) for layer in layers.values()) + width_adjust
+        height = len(layers)
+
+        # Build grid for
ASCII visualization + grid = [[" " for _ in range(max_width * 20)] for _ in range(height * 2 - 1)] + + # Place nodes in the grid with more details + task_to_pos = {} + for layer_num, layer_tasks in layers.items(): + layer_y = layer_num * 2 # Every second row is for nodes + for col_num, task_idx in enumerate(layer_tasks): + task = self.idx_to_task[task_idx] + task_info = f"{task_idx}:" + + # Determine if it's an actor method or a regular task + if isinstance(task.dag_node, ClassMethodNode): + if task.dag_node.is_class_method_call: + method_name = task.dag_node.get_method_name() + actor_handle = task.dag_node._get_actor_handle() + actor_id = ( + actor_handle._actor_id.hex()[:6] + if actor_handle + else "unknown" + ) + task_info += f"Actor_{actor_id}:{method_name}" + elif task.dag_node.is_class_method_output: + task_info += f"Output[{task.dag_node.output_idx}]" + else: + task_info += "UnknownMethod" + else: + task_info += type(task.dag_node).__name__ + + adjust_col_num = 0 + if task_idx in is_multi_output: + adjust_col_num = layers[layer_num - 1].index(child2parent[task_idx]) + col_x = (col_num + adjust_col_num) * 30 # Every 30th column for spacing + # Place the task information into the grid + for i, char in enumerate(task_info): + if col_x + i < len(grid[0]): # Ensure we don't overflow the grid + grid[layer_y][col_x + i] = char + + task_to_pos[task_idx] = (layer_y, col_x) + + # Connect the nodes with lines + for upstream_task, downstream_tasks in adj_list.items(): + upstream_y, upstream_x = task_to_pos[upstream_task] + for downstream_task in downstream_tasks: + downstream_y, downstream_x = task_to_pos[downstream_task] + + # Draw vertical line + for y in range(upstream_y + 1, downstream_y): + if grid[y][upstream_x] == " ": + grid[y][upstream_x] = "|" + + # Draw horizontal line with directional arrows + if upstream_x != downstream_x: + for x in range( + min(upstream_x, downstream_x) + 1, + max(upstream_x, downstream_x), + ): + grid[downstream_y - 1][x] = ( + "-" + if grid[downstream_y - 1][x] == " " + else grid[downstream_y - 1][x] + ) + + # Add arrows to indicate flow direction + if downstream_x > upstream_x: + grid[downstream_y - 1][downstream_x - 1] = ">" + else: + grid[downstream_y - 1][downstream_x + 1] = "<" + + # Draw connection to the next task + grid[downstream_y - 1][downstream_x] = "|" + + # Ensure proper multi-output task connection + for idx, task in self.idx_to_task.items(): + if isinstance(task.dag_node, MultiOutputNode): + output_tasks = task.dag_node.get_args() + for i, output_task in enumerate(output_tasks): + if isinstance(output_task, DAGNode): + output_task_idx = self.dag_node_to_idx[output_task] + if output_task_idx in task_to_pos: + output_y, output_x = task_to_pos[output_task_idx] + grid[output_y - 1][output_x] = "|" + + # Convert grid to string for printing + ascii_visualization += "\nGraph Built:\n" + ascii_visualization += "\n".join("".join(row) for row in grid) + + return ascii_visualization + + def get_channel_details( + self, channel: ChannelInterface, downstream_actor_id: str + ) -> str: + """ + Get details about outer and inner channel types and channel ids + based on the channel and the downstream actor ID. + Used for graph visualization. + Args: + channel: The channel to get details for. + downstream_actor_id: The downstream actor ID. + Returns: + A string with details about the channel based on its connection + to the actor provided. 
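+
+        Example (illustrative only; the channel types and the truncated
+        channel id below are hypothetical):
+            "CompositeChannel\nIntraProcessChannel, 3f9a1b..."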
+ """ + channel_details = type(channel).__name__ + # get outer channel + if channel in self._channel_dict and self._channel_dict[channel] != channel: + channel = self._channel_dict[channel] + channel_details += f"\n{type(channel).__name__}" + if type(channel) is CachedChannel: + channel_details += f", {channel._channel_id[:6]}..." + # get inner channel + if ( + type(channel) is CompositeChannel + and downstream_actor_id in channel._channel_dict + ): + inner_channel = channel._channel_dict[downstream_actor_id] + channel_details += f"\n{type(inner_channel).__name__}" + if type(inner_channel) is IntraProcessChannel: + channel_details += f", {inner_channel._channel_id[:6]}..." + return channel_details + + def visualize( + self, + filename="compiled_graph", + format="png", + view=False, + channel_details=False, + ) -> str: + """ + Visualize the compiled graph using Graphviz. + + For non-ASCII formats, the visualization will be saved to a file specified + by the `filename` argument. + + This method generates a graphical representation of the compiled graph, + showing tasks and their dependencies.This method should be called + **after** the graph has been compiled using `experimental_compile()`. + + Args: + filename: The name of the output file (without extension). + format: The format of the output file (e.g., 'png', 'pdf', 'ascii'). + view: For non-ascii: Whether to open the file with the default viewer. + For ascii: Whether to print the visualization and return None + or return the ascii visualization string directly. + channel_details: If True, adds channel details to edges. + + Returns: + str: + - For Graphviz-based formats (e.g., 'png', 'pdf', 'jpeg'), returns + the Graphviz DOT string representation of the compiled graph. + - For ASCII format, returns the ASCII string representation of the + compiled graph. + + Raises: + ValueError: If the graph is empty or not properly compiled. + ImportError: If the `graphviz` package is not installed. + + """ + if format == "ascii": + if channel_details: + raise ValueError( + "Parameters 'channel_details' are" + " not compatible with 'ascii' format." + ) + ascii_visualiztion_str = self._visualize_ascii() + if view: + print(ascii_visualiztion_str) + return ascii_visualiztion_str + try: + import graphviz + except ImportError: + raise ImportError( + "Please install graphviz to visualize the compiled graph. " + "You can install it by running `pip install graphviz`." + ) + from ray.dag import ( + InputAttributeNode, + InputNode, + MultiOutputNode, + ClassMethodNode, + DAGNode, + ) + + # Check that the DAG has been compiled + if not hasattr(self, "idx_to_task") or not self.idx_to_task: + raise ValueError( + "The DAG must be compiled before calling 'visualize()'. " + "Please call 'experimental_compile()' first." + ) + + # Check that each CompiledTask has a valid dag_node + for idx, task in self.idx_to_task.items(): + if not hasattr(task, "dag_node") or not isinstance(task.dag_node, DAGNode): + raise ValueError( + f"Task at index {idx} does not have a valid 'dag_node'. " + "Ensure that 'experimental_compile()' completed successfully." 
+ ) + + # Dot file for debugging + dot = graphviz.Digraph(name="compiled_graph", format=format) + # Give every actor a unique color, colors between 24k -> 40k tested as readable + # other colors may be too dark, especially when wrapping back around to 0 + actor_id_to_color = defaultdict( + lambda: f"#{((len(actor_id_to_color) * 2000 + 24000) % 0xFFFFFF):06X}" + ) + # Add nodes with task information + for idx, task in self.idx_to_task.items(): + dag_node = task.dag_node + # Initialize the label and attributes + label = f"Task {idx}\n" + shape = "oval" # Default shape + style = "filled" + fillcolor = "" + + # Handle different types of dag_node + if isinstance(dag_node, InputNode): + label += "InputNode" + shape = "rectangle" + fillcolor = "lightblue" + elif isinstance(dag_node, InputAttributeNode): + label += f"InputAttributeNode[{dag_node.key}]" + shape = "rectangle" + fillcolor = "lightblue" + elif isinstance(dag_node, MultiOutputNode): + label += "MultiOutputNode" + shape = "rectangle" + fillcolor = "yellow" + elif isinstance(dag_node, ClassMethodNode): + if dag_node.is_class_method_call: + # Class Method Call Node + method_name = dag_node.get_method_name() + actor = dag_node._get_actor_handle() + if actor: + class_name = ( + actor._ray_actor_creation_function_descriptor.class_name + ) + actor_id = actor._actor_id.hex() + label += f"Actor: {class_name}\n" + label += f"ID: {actor_id[:6]}...\n" + label += f"Method: {method_name}" + fillcolor = actor_id_to_color[actor_id] + else: + label += f"Method: {method_name}" + fillcolor = "lightgreen" + shape = "oval" + elif dag_node.is_class_method_output: + # Class Method Output Node + label += f"ClassMethodOutputNode[{dag_node.output_idx}]" + shape = "rectangle" + fillcolor = "orange" + else: + # Unexpected ClassMethodNode + label += "ClassMethodNode" + shape = "diamond" + fillcolor = "red" + else: + # Unexpected node type + label += type(dag_node).__name__ + shape = "diamond" + fillcolor = "red" + + # Add the node to the graph with attributes + dot.node(str(idx), label, shape=shape, style=style, fillcolor=fillcolor) + channel_type_str = ( + type(dag_node.type_hint).__name__ + if dag_node.type_hint + else "UnknownType" + ) + "\n" + + # This logic is built on the assumption that there will only be multiple + # output channels if the task has multiple returns + # case: task with one output + if len(task.output_channels) == 1: + for downstream_node in task.dag_node._downstream_nodes: + downstream_idx = self.dag_node_to_idx[downstream_node] + edge_label = channel_type_str + if channel_details: + edge_label += self.get_channel_details( + task.output_channels[0], + ( + downstream_node._get_actor_handle()._actor_id.hex() + if type(downstream_node) is ClassMethodNode + else self._proxy_actor._actor_id.hex() + ), + ) + dot.edge(str(idx), str(downstream_idx), label=edge_label) + # case: multi return, output channels connect to class method output nodes + elif len(task.output_channels) > 1: + assert len(task.output_idxs) == len(task.output_channels) + for output_channel, downstream_idx in zip( + task.output_channels, task.output_node_idxs + ): + edge_label = channel_type_str + if channel_details: + edge_label += self.get_channel_details( + output_channel, + task.dag_node._get_actor_handle()._actor_id.hex(), + ) + dot.edge(str(idx), str(downstream_idx), label=edge_label) + if type(task.dag_node) is InputAttributeNode: + # Add an edge from the InputAttributeNode to the InputNode + dot.edge(str(self.input_task_idx), str(idx)) + dot.render(filename, view=view) + 
return dot.source + + def _register_input_output_custom_serializer(self): + """ + Register custom serializers for input, input attribute, and output nodes. + """ + assert self.input_task_idx is not None + assert self.output_task_idx is not None + + # Register custom serializers for input node. + input_task = self.idx_to_task[self.input_task_idx] + input_task.dag_node.type_hint.register_custom_serializer() + + # Register custom serializers for input attribute nodes. + for input_attr_task_idx in self.input_attr_task_idxs: + input_attr_task = self.idx_to_task[input_attr_task_idx] + input_attr_task.dag_node.type_hint.register_custom_serializer() + + # Register custom serializers for output nodes. + for output in self.idx_to_task[self.output_task_idx].args: + output.type_hint.register_custom_serializer() + + def teardown(self, kill_actors: bool = False): + """Teardown and cancel all actor tasks for this DAG. After this + function returns, the actors should be available to execute new tasks + or compile a new DAG.""" + if self._is_teardown: + return + + monitor = getattr(self, "_monitor", None) + if monitor is not None: + from ray.dag import DAGContext + + ctx = DAGContext.get_current() + monitor.teardown(kill_actors=kill_actors) + monitor.join(timeout=ctx.teardown_timeout) + # We do not log a warning here if the thread is still alive because + # wait_teardown already logs upon teardown_timeout. + + self._is_teardown = True + + def __del__(self): + self.teardown() + + +@DeveloperAPI +def build_compiled_dag_from_ray_dag( + dag: "ray.dag.DAGNode", + submit_timeout: Optional[float] = None, + buffer_size_bytes: Optional[int] = None, + enable_asyncio: bool = False, + max_inflight_executions: Optional[int] = None, + overlap_gpu_communication: Optional[bool] = None, +) -> "CompiledDAG": + compiled_dag = CompiledDAG( + submit_timeout, + buffer_size_bytes, + enable_asyncio, + max_inflight_executions, + overlap_gpu_communication, + ) + + def _build_compiled_dag(node): + compiled_dag._add_node(node) + return node + + root = dag._find_root() + root.traverse_and_apply(_build_compiled_dag) + compiled_dag._get_or_compile() + global _compiled_dags + _compiled_dags[compiled_dag.get_id()] = compiled_dag + return compiled_dag diff --git a/.venv/lib/python3.11/site-packages/ray/dag/conftest.py b/.venv/lib/python3.11/site-packages/ray/dag/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..a350eb5be2d7b70a2647be9df26f53d02fa93a6f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dag/conftest.py @@ -0,0 +1,16 @@ +import os +import pytest + +import ray + +TEST_NAMESPACE = "ray_dag_test_namespace" + + +@pytest.fixture(scope="session") +def shared_ray_instance(): + # Remove ray address for test ray cluster in case we have + # lingering RAY_ADDRESS="http://127.0.0.1:8265" from previous local job + # submissions. + if "RAY_ADDRESS" in os.environ: + del os.environ["RAY_ADDRESS"] + yield ray.init(num_cpus=16, namespace=TEST_NAMESPACE, log_to_driver=True) diff --git a/.venv/lib/python3.11/site-packages/ray/dag/constants.py b/.venv/lib/python3.11/site-packages/ray/dag/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..299acf3137c67ffdbbf1fa4bcb5760ba3d216bf8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dag/constants.py @@ -0,0 +1,33 @@ +import os + +# Reserved keys used to handle ClassMethodNode in Ray DAG building. 
+PARENT_CLASS_NODE_KEY = "parent_class_node"
+PREV_CLASS_METHOD_CALL_KEY = "prev_class_method_call"
+BIND_INDEX_KEY = "bind_index"
+IS_CLASS_METHOD_OUTPUT_KEY = "is_class_method_output"
+
+# Reserved keys used to handle CollectiveOutputNode in Ray DAG building.
+COLLECTIVE_OPERATION_KEY = "collective_operation"
+
+# Reserved key to distinguish DAGNode type and avoid collision with user dict.
+DAGNODE_TYPE_KEY = "__dag_node_type__"
+
+# Feature flag for deadlock detection; set the variable to "0" to turn the
+# detection off.
+RAY_CGRAPH_ENABLE_DETECT_DEADLOCK = (
+    os.environ.get("RAY_CGRAPH_ENABLE_DETECT_DEADLOCK", "1") == "1"
+)
+
+# Feature flag to turn on profiling.
+RAY_CGRAPH_ENABLE_PROFILING = os.environ.get("RAY_CGRAPH_ENABLE_PROFILING", "0") == "1"
+
+# Feature flag to turn on NVTX (NVIDIA Tools Extension Library) profiling.
+# With this flag, Compiled Graph uses nvtx to automatically annotate and profile
+# function calls during each actor's execution loop.
+RAY_CGRAPH_ENABLE_NVTX_PROFILING = (
+    os.environ.get("RAY_CGRAPH_ENABLE_NVTX_PROFILING", "0") == "1"
+)
+
+# Feature flag to turn on visualization of the execution schedule.
+RAY_CGRAPH_VISUALIZE_SCHEDULE = (
+    os.environ.get("RAY_CGRAPH_VISUALIZE_SCHEDULE", "0") == "1"
+)
diff --git a/.venv/lib/python3.11/site-packages/ray/dag/context.py b/.venv/lib/python3.11/site-packages/ray/dag/context.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c92fc34737c716dea641a1e6eb6b2519c9659de
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/ray/dag/context.py
@@ -0,0 +1,101 @@
+from dataclasses import dataclass
+import os
+import threading
+from typing import Optional
+from ray.util.annotations import DeveloperAPI
+
+# The context singleton on this process.
+_default_context: "Optional[DAGContext]" = None
+_context_lock = threading.Lock()
+
+DEFAULT_SUBMIT_TIMEOUT_S = int(os.environ.get("RAY_CGRAPH_submit_timeout", 10))
+DEFAULT_GET_TIMEOUT_S = int(os.environ.get("RAY_CGRAPH_get_timeout", 10))
+DEFAULT_TEARDOWN_TIMEOUT_S = int(os.environ.get("RAY_CGRAPH_teardown_timeout", 30))
+DEFAULT_READ_ITERATION_TIMEOUT_S = float(
+    os.environ.get("RAY_CGRAPH_read_iteration_timeout_s", 0.1)
+)
+# Default buffer size is 1MB.
+DEFAULT_BUFFER_SIZE_BYTES = int(os.environ.get("RAY_CGRAPH_buffer_size_bytes", 1e6))
+# The default number of in-flight executions that can be submitted before
+# consuming the output.
+DEFAULT_MAX_INFLIGHT_EXECUTIONS = int(
+    os.environ.get("RAY_CGRAPH_max_inflight_executions", 10)
+)
+
+# NOTE: Compare against the string "1" rather than casting with bool();
+# bool() on any non-empty string (including "0") is True, which would make it
+# impossible to disable the flag through the environment variable.
+DEFAULT_OVERLAP_GPU_COMMUNICATION = (
+    os.environ.get("RAY_CGRAPH_overlap_gpu_communication", "0") == "1"
+)
+
+
+@DeveloperAPI
+@dataclass
+class DAGContext:
+    """Global settings for Ray DAG.
+
+    You can configure parameters in the DAGContext by setting environment
+    variables prefixed with `RAY_CGRAPH_` (e.g.,
+    `RAY_CGRAPH_buffer_size_bytes`) or from Python:
+
+    Examples:
+        >>> from ray.dag import DAGContext
+        >>> DAGContext.get_current().buffer_size_bytes
+        1000000
+        >>> DAGContext.get_current().buffer_size_bytes = 500
+        >>> DAGContext.get_current().buffer_size_bytes
+        500
+
+    Args:
+        submit_timeout: The maximum time in seconds to wait for execute()
+            calls.
+        get_timeout: The maximum time in seconds to wait when retrieving
+            a result from the DAG during `ray.get`. This should be set to a
+            value higher than the expected time to execute the entire DAG.
+        teardown_timeout: The maximum time in seconds to wait for the DAG to
+            cleanly shut down.
+        read_iteration_timeout: The timeout in seconds for each read iteration
+            that reads one of the input channels.
+            If the timeout is reached, the read operation will be interrupted
+            and will try to read the next input channel. It must be less than
+            or equal to `get_timeout`.
+        buffer_size_bytes: The initial buffer size in bytes for messages
+            that can be passed between tasks in the DAG. The buffers will
+            be automatically resized if larger messages are written to the
+            channel.
+        max_inflight_executions: The maximum number of in-flight executions that
+            can be submitted via `execute` or `execute_async` before consuming
+            the output using `ray.get()`. If the caller submits more executions,
+            `RayCgraphCapacityExceeded` is raised.
+        overlap_gpu_communication: (experimental) Whether to overlap GPU
+            communication with computation during DAG execution. If True, the
+            communication and computation can be overlapped, which can improve
+            the performance of the DAG execution.
+    """
+
+    submit_timeout: int = DEFAULT_SUBMIT_TIMEOUT_S
+    get_timeout: int = DEFAULT_GET_TIMEOUT_S
+    teardown_timeout: int = DEFAULT_TEARDOWN_TIMEOUT_S
+    read_iteration_timeout: float = DEFAULT_READ_ITERATION_TIMEOUT_S
+    buffer_size_bytes: int = DEFAULT_BUFFER_SIZE_BYTES
+    max_inflight_executions: int = DEFAULT_MAX_INFLIGHT_EXECUTIONS
+    overlap_gpu_communication: bool = DEFAULT_OVERLAP_GPU_COMMUNICATION
+
+    def __post_init__(self):
+        if self.read_iteration_timeout > self.get_timeout:
+            raise ValueError(
+                "RAY_CGRAPH_read_iteration_timeout_s "
+                f"({self.read_iteration_timeout}) must be less than or equal to "
+                f"RAY_CGRAPH_get_timeout ({self.get_timeout})"
+            )
+
+    @staticmethod
+    def get_current() -> "DAGContext":
+        """Get or create a singleton context.
+
+        If the context has not yet been created in this process, it will be
+        initialized with default settings.
+        """
+        global _default_context
+
+        with _context_lock:
+            if _default_context is None:
+                _default_context = DAGContext()
+
+        return _default_context
diff --git a/.venv/lib/python3.11/site-packages/ray/dag/dag_node.py b/.venv/lib/python3.11/site-packages/ray/dag/dag_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cd4fd0a7e97af9a3e2ad0793a167da878c71ae6
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/ray/dag/dag_node.py
@@ -0,0 +1,622 @@
+import copy
+from ray.experimental.channel.auto_transport_type import AutoTransportType
+from ray.experimental.channel.torch_tensor_type import TorchTensorType
+import ray
+from ray.dag.base import DAGNodeBase
+from ray.dag.py_obj_scanner import _PyObjScanner
+from ray.util.annotations import DeveloperAPI
+
+from itertools import chain
+
+from typing import (
+    Optional,
+    Union,
+    List,
+    Tuple,
+    Dict,
+    Any,
+    TypeVar,
+    Callable,
+)
+import uuid
+import asyncio
+
+from ray.dag.compiled_dag_node import build_compiled_dag_from_ray_dag
+from ray.experimental.channel import ChannelOutputType
+from ray.experimental.channel.communicator import Communicator
+
+T = TypeVar("T")
+
+
+@DeveloperAPI
+class DAGNode(DAGNodeBase):
+    """Abstract class for a node in a Ray task graph.
+
+    A node has a type (e.g., FunctionNode), data (e.g., function options and
+    body), arguments (Python values, DAGNodes, and DAGNodes nested within
+    Python argument values), and options (Ray API .options() used for
+    function, class, or class method).
+    """
+
+    def __init__(
+        self,
+        args: Tuple[Any],
+        kwargs: Dict[str, Any],
+        options: Dict[str, Any],
+        other_args_to_resolve: Dict[str, Any],
+    ):
+        """
+        Args:
+            args (Tuple[Any]): Bound node arguments.
+                ex: func_or_class.bind(1)
+
+            kwargs (Dict[str, Any]): Bound node keyword arguments.
+                ex: func_or_class.bind(a=1)
+            options (Dict[str, Any]): Bound node options arguments.
+                ex: func_or_class.options(num_cpus=2)
+            other_args_to_resolve (Dict[str, Any]): Bound kwargs to resolve
+                that are specific to a subclass implementation, without
+                exposing them as args in the base class. Example:
+                ClassMethodNode.
+        """
+        self._bound_args: Tuple[Any] = args or []
+        self._bound_kwargs: Dict[str, Any] = kwargs or {}
+        self._bound_options: Dict[str, Any] = options or {}
+        self._bound_other_args_to_resolve: Optional[Dict[str, Any]] = (
+            other_args_to_resolve or {}
+        )
+
+        # The list of nodes that use this DAG node as an argument.
+        self._downstream_nodes: List["DAGNode"] = []
+
+        # UUID that is not changed over copies of this node.
+        self._stable_uuid = uuid.uuid4().hex
+
+        # Indicates whether this DAG node contains nested DAG nodes.
+        # Nested DAG nodes are allowed in traditional DAGs but not
+        # in Ray Compiled Graphs, except for MultiOutputNode.
+        self._args_contain_nested_dag_node = False
+
+        # The list of nodes that this DAG node uses as an argument.
+        self._upstream_nodes: List["DAGNode"] = self._collect_upstream_nodes()
+
+        # Cached values from last call to execute()
+        self.cache_from_last_execute = {}
+
+        self._type_hint: ChannelOutputType = ChannelOutputType()
+
+        # If the original type hint is an AutoTransportType, we make a copy
+        # here when it is resolved to the actual type, as additional debugging
+        # information. Otherwise, it is None.
+        self._original_type_hint: Optional[ChannelOutputType] = None
+
+        # Whether this node calls `experimental_compile`.
+        self.is_cgraph_output_node = False
+
+    def _collect_upstream_nodes(self) -> List["DAGNode"]:
+        """
+        Retrieve upstream nodes and update their downstream dependencies.
+
+        Currently, the DAG assumes that all DAGNodes in `args`, `kwargs`, and
+        `other_args_to_resolve` are upstream nodes. However, Ray Compiled Graphs
+        builds the upstream/downstream relationship based only on args. Be cautious
+        when persisting DAGNodes in `other_args_to_resolve` and kwargs in the future.
+
+        TODO (kevin85421): Currently, the upstream nodes and downstream nodes have
+        circular references. Therefore, it relies on the garbage collector to clean
+        them up instead of reference counting. We should consider using weak references
+        to avoid circular references.
+        """
+        upstream_nodes: List["DAGNode"] = []
+
+        # Ray Compiled Graphs do not allow nested DAG nodes in arguments.
+        # Specifically, a DAGNode should not be placed inside any type of
+        # container. However, we only know if this is a compiled graph
+        # when calling `experimental_compile`. Therefore, we need to check
+        # in advance if the arguments contain nested DAG nodes and raise
+        # an error when the DAG is compiled.
+        assert hasattr(self._bound_args, "__iter__")
+        for arg in self._bound_args:
+            if isinstance(arg, DAGNode):
+                upstream_nodes.append(arg)
+            else:
+                scanner = _PyObjScanner()
+                dag_nodes = scanner.find_nodes(arg)
+                upstream_nodes.extend(dag_nodes)
+                scanner.clear()
+                # Accumulate across args: a nested DAGNode in any arg should
+                # mark this node. (A plain assignment would let a later arg
+                # without nested nodes reset the flag.)
+                self._args_contain_nested_dag_node |= len(dag_nodes) > 0
+
+        scanner = _PyObjScanner()
+        other_upstream_nodes: List["DAGNode"] = scanner.find_nodes(
+            [
+                self._bound_kwargs,
+                self._bound_other_args_to_resolve,
+            ]
+        )
+        upstream_nodes.extend(other_upstream_nodes)
+        scanner.clear()
+        # Update dependencies.
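+        # Each upstream node records this node as a downstream consumer so
+        # that the graph can later be walked in both directions (e.g., by
+        # `traverse_and_apply`).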
+        for upstream_node in upstream_nodes:
+            upstream_node._downstream_nodes.append(self)
+        return upstream_nodes
+
+    def with_tensor_transport(
+        self,
+        transport: Optional[Union[str, Communicator]] = "auto",
+        _static_shape: bool = False,
+        _direct_return: bool = False,
+    ):
+        """Configure the transport used for torch.Tensors in this node's output.
+
+        Args:
+            transport: "auto" to let Ray resolve the transport automatically,
+                "nccl" to use NCCL, or a custom Communicator instance.
+            _static_shape: A hint that the shapes and dtypes of the tensors
+                stay the same across executions.
+            _direct_return: A hint that the tensor is returned directly, not
+                nested inside other data.
+
+        Returns:
+            self, with its type hint set to the requested transport type.
+        """
+        if transport == "auto":
+            self._type_hint = AutoTransportType(
+                _static_shape=_static_shape,
+                _direct_return=_direct_return,
+            )
+        elif transport == "nccl":
+            self._type_hint = TorchTensorType(
+                transport=transport,
+                _static_shape=_static_shape,
+                _direct_return=_direct_return,
+            )
+        else:
+            if not isinstance(transport, Communicator):
+                raise ValueError(
+                    "transport must be 'auto', 'nccl' or a Communicator type"
+                )
+            self._type_hint = TorchTensorType(
+                transport=transport,
+                _static_shape=_static_shape,
+                _direct_return=_direct_return,
+            )
+        return self
+
+    @property
+    def type_hint(self) -> ChannelOutputType:
+        return self._type_hint
+
+    @type_hint.setter
+    def type_hint(self, type_hint: ChannelOutputType) -> None:
+        if isinstance(self._type_hint, AutoTransportType):
+            self._original_type_hint = self._type_hint
+        self._type_hint = type_hint
+
+    def get_args(self) -> Tuple[Any]:
+        """Return the tuple of arguments for this node."""
+
+        return self._bound_args
+
+    def get_kwargs(self) -> Dict[str, Any]:
+        """Return the dict of keyword arguments for this node."""
+
+        return self._bound_kwargs.copy()
+
+    def get_options(self) -> Dict[str, Any]:
+        """Return the dict of options arguments for this node."""
+
+        return self._bound_options.copy()
+
+    def get_other_args_to_resolve(self) -> Dict[str, Any]:
+        """Return the dict of other args to resolve arguments for this node."""
+        return self._bound_other_args_to_resolve.copy()
+
+    def get_stable_uuid(self) -> str:
+        """Return stable uuid for this node.
+        1) Generated only once at first instance creation
+        2) Stable across pickling, replacement and JSON serialization.
+        """
+        return self._stable_uuid
+
+    async def get_object_refs_from_last_execute(self) -> Dict[str, Any]:
+        """Gets cached object refs from the last call to execute().
+
+        After this DAG is executed through execute(), retrieves a map from
+        each node's UUID to a reference to the return value of the default
+        executor on that node.
+        """
+        cache = {}
+        for node_uuid, value in self.cache_from_last_execute.items():
+            if isinstance(value, asyncio.Task):
+                cache[node_uuid] = await value
+            else:
+                cache[node_uuid] = value
+
+        return cache
+
+    def clear_cache(self):
+        self.cache_from_last_execute = {}
+
+    def experimental_compile(
+        self,
+        _submit_timeout: Optional[float] = None,
+        _buffer_size_bytes: Optional[int] = None,
+        enable_asyncio: bool = False,
+        _max_inflight_executions: Optional[int] = None,
+        _overlap_gpu_communication: Optional[bool] = None,
+    ) -> "ray.dag.CompiledDAG":
+        """Compile an accelerated execution path for this DAG.
+
+        Args:
+            _submit_timeout: The maximum time in seconds to wait for execute() calls.
+                None means using default timeout, 0 means immediate timeout
+                (immediate success or timeout without blocking), -1 means
+                infinite timeout (block indefinitely).
+            _buffer_size_bytes: The initial buffer size in bytes for messages
+                that can be passed between tasks in the DAG. The buffers will
+                be automatically resized if larger messages are written to the
+                channel.
+            enable_asyncio: Whether to enable asyncio for this DAG.
+            _max_inflight_executions: The maximum number of in-flight executions that
+                can be submitted via `execute` or `execute_async` before consuming
+                the output using `ray.get()`.
+                If the caller submits more executions,
+                `RayCgraphCapacityExceeded` is raised.
+            _overlap_gpu_communication: (experimental) Whether to overlap GPU
+                communication with computation during DAG execution. If True, the
+                communication and computation can be overlapped, which can improve
+                the performance of the DAG execution. If None, the default value
+                will be used.
+
+        Returns:
+            A compiled DAG.
+        """
+        from ray.dag import DAGContext
+
+        ctx = DAGContext.get_current()
+        if _buffer_size_bytes is None:
+            _buffer_size_bytes = ctx.buffer_size_bytes
+
+        # Validate whether this DAG node has already been compiled.
+        if self.is_cgraph_output_node:
+            raise ValueError(
+                "It is not allowed to call `experimental_compile` on the same DAG "
+                "object multiple times, regardless of whether `teardown` is called. "
+                "Please reuse the existing compiled DAG or create a new one."
+            )
+        # Whether this node is an output node in the DAG. We cannot determine
+        # this in the constructor because the output node is determined when
+        # `experimental_compile` is called.
+        self.is_cgraph_output_node = True
+        return build_compiled_dag_from_ray_dag(
+            self,
+            _submit_timeout,
+            _buffer_size_bytes,
+            enable_asyncio,
+            _max_inflight_executions,
+            _overlap_gpu_communication,
+        )
+
+    def execute(
+        self, *args, _ray_cache_refs: bool = False, **kwargs
+    ) -> Union[ray.ObjectRef, "ray.actor.ActorHandle"]:
+        """Execute this DAG using the Ray default executor _execute_impl().
+
+        Args:
+            _ray_cache_refs: If true, stores the default executor's return values
+                on each node in this DAG in a cache. These should be a mix of:
+                - ray.ObjectRefs pointing to the outputs of method and function nodes
+                - Serve handles for class nodes
+                - resolved values representing user input at runtime
+        """
+
+        def executor(node):
+            return node._execute_impl(*args, **kwargs)
+
+        result = self.apply_recursive(executor)
+        if _ray_cache_refs:
+            self.cache_from_last_execute = executor.cache
+        return result
+
+    def _get_toplevel_child_nodes(self) -> List["DAGNode"]:
+        """Return the list of nodes specified as top-level args.
+
+        For example, in `f.remote(a, [b])`, only `a` is a top-level arg.
+
+        This list of nodes are those that are typically resolved prior to
+        task execution in Ray. This does not include nodes nested within args.
+        For that, use ``_get_all_child_nodes()``.
+        """
+
+        # we use List instead of Set here because the hash key of the node
+        # object changes each time we create it. So if using Set here, the
+        # order of returned children can be different if we create the same
+        # nodes and dag one more time.
+        children = []
+        for a in self.get_args():
+            if isinstance(a, DAGNode):
+                if a not in children:
+                    children.append(a)
+        for a in self.get_kwargs().values():
+            if isinstance(a, DAGNode):
+                if a not in children:
+                    children.append(a)
+        for a in self.get_other_args_to_resolve().values():
+            if isinstance(a, DAGNode):
+                if a not in children:
+                    children.append(a)
+        return children
+
+    def _get_all_child_nodes(self) -> List["DAGNode"]:
+        """Return the list of nodes referenced by the args, kwargs, and
+        args_to_resolve in the current node, even if they're deeply nested.
+
+        Examples:
+            f.remote(a, [b]) -> [a, b]
+            f.remote(a, [b], key={"nested": [c]}) -> [a, b, c]
+        """
+
+        scanner = _PyObjScanner()
+        # we use List instead of Set here, reason explained
+        # in `_get_toplevel_child_nodes`.
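+        # (A Set would make the order of the returned children
+        # nondeterministic when the same DAG is re-created.)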
+        children = []
+        for n in scanner.find_nodes(
+            [
+                self._bound_args,
+                self._bound_kwargs,
+                self._bound_other_args_to_resolve,
+            ]
+        ):
+            if n not in children:
+                children.append(n)
+        scanner.clear()
+        return children
+
+    def _apply_and_replace_all_child_nodes(
+        self, fn: "Callable[[DAGNode], T]"
+    ) -> "DAGNode":
+        """Apply and replace all immediate child nodes using a given function.
+
+        This is a shallow replacement only. To recursively transform nodes in
+        the DAG, use ``apply_recursive()``.
+
+        Args:
+            fn: Callable that will be applied once to each child of this node.
+
+        Returns:
+            New DAGNode after replacing all child nodes.
+        """
+
+        replace_table = {}
+        # CloudPickler scanner object for the current layer of DAGNode. The same
+        # scanner should be used for a full find & replace cycle.
+        scanner = _PyObjScanner()
+        # Find all first-level nested DAGNode children in args.
+        # Update replacement table and execute the replace.
+        for node in scanner.find_nodes(
+            [
+                self._bound_args,
+                self._bound_kwargs,
+                self._bound_other_args_to_resolve,
+            ]
+        ):
+            if node not in replace_table:
+                replace_table[node] = fn(node)
+        new_args, new_kwargs, new_other_args_to_resolve = scanner.replace_nodes(
+            replace_table
+        )
+        scanner.clear()
+
+        # Return updated copy of self.
+        return self._copy(
+            new_args, new_kwargs, self.get_options(), new_other_args_to_resolve
+        )
+
+    def apply_recursive(self, fn: "Callable[[DAGNode], T]") -> T:
+        """Apply callable on each node in this DAG in a bottom-up tree walk.
+
+        Args:
+            fn: Callable that will be applied once to each node in the
+                DAG. It will be applied recursively bottom-up, so nodes can
+                assume the fn has been applied to their args already.
+
+        Returns:
+            Return type of the fn after application to the tree.
+        """
+
+        if not type(fn).__name__ == "_CachingFn":
+
+            class _CachingFn:
+                def __init__(self, fn):
+                    self.cache = {}
+                    self.fn = fn
+                    self.fn.cache = self.cache
+                    self.input_node_uuid = None
+
+                def __call__(self, node: "DAGNode"):
+                    from ray.dag.input_node import InputNode
+
+                    if node._stable_uuid not in self.cache:
+                        self.cache[node._stable_uuid] = self.fn(node)
+                    if isinstance(node, InputNode):
+                        if not self.input_node_uuid:
+                            self.input_node_uuid = node._stable_uuid
+                        elif self.input_node_uuid != node._stable_uuid:
+                            raise AssertionError(
+                                "Each DAG should only have one unique InputNode."
+                            )
+                    return self.cache[node._stable_uuid]
+
+            fn = _CachingFn(fn)
+        else:
+            if self._stable_uuid in fn.cache:
+                return fn.cache[self._stable_uuid]
+
+        return fn(
+            self._apply_and_replace_all_child_nodes(
+                lambda node: node.apply_recursive(fn)
+            )
+        )
+
+    def traverse_and_apply(self, fn: "Callable[[DAGNode], T]"):
+        """
+        Traverse all nodes in the connected component of the DAG that contains
+        the `self` node, and apply the given function to each node.
+        """
+        visited = set()
+        queue = [self]
+        cgraph_output_node: Optional[DAGNode] = None
+
+        while queue:
+            node = queue.pop(0)
+            if node._args_contain_nested_dag_node:
+                self._raise_nested_dag_node_error(node._bound_args)
+
+            if node not in visited:
+                if node.is_cgraph_output_node:
+                    # Validate whether there are multiple nodes that call
+                    # `experimental_compile`.
+                    if cgraph_output_node is not None:
+                        raise ValueError(
+                            "The DAG was compiled more than once. The following two "
+                            "nodes call `experimental_compile`: "
+                            f"(1) {cgraph_output_node}, (2) {node}"
+                        )
+                    cgraph_output_node = node
+                fn(node)
+                visited.add(node)
+            """
+            Add all unseen downstream and upstream nodes to the queue.
+            This function should be called by the root of the DAG. However,
+            in some invalid cases, some nodes may not be descendants of the
+            root. Therefore, we also add upstream nodes to the queue so that
+            a meaningful error message can be raised when the DAG is compiled.
+
+            ```
+            with InputNode() as inp:
+                dag = MultiOutputNode([a1.inc.bind(inp), a2.inc.bind(1)])
+            ```
+
+            In the above example, `a2.inc` is not a descendant of inp. If we only
+            add downstream nodes to the queue, the `a2.inc` node will not be
+            visited, and the error message will be hard to understand, such as
+            a key error in the compiled DAG.
+            """
+            for neighbor in chain.from_iterable(
+                [node._downstream_nodes, node._upstream_nodes]
+            ):
+                if neighbor not in visited:
+                    queue.append(neighbor)
+
+    def _raise_nested_dag_node_error(self, args):
+        """
+        Raise an error for nested DAGNodes in Ray Compiled Graphs.
+
+        Args:
+            args: The arguments of the DAGNode.
+        """
+        for arg in args:
+            if isinstance(arg, DAGNode):
+                continue
+            else:
+                scanner = _PyObjScanner()
+                dag_nodes = scanner.find_nodes([arg])
+                scanner.clear()
+                if len(dag_nodes) > 0:
+                    raise ValueError(
+                        f"Found {len(dag_nodes)} DAGNodes from the arg {arg} "
+                        f"in {self}. Please ensure that the argument is a "
+                        "single DAGNode; a DAGNode must not be placed inside "
+                        "any type of container."
+                    )
+        raise AssertionError(
+            "Expected this DAGNode's args to contain nested DAGNodes, "
+            "but none were found during the compilation process. This is a "
+            "Ray internal error. Please report this issue to the Ray team."
+        )
+
+    def _find_root(self) -> "DAGNode":
+        """
+        Return the root node of the DAG. The root node must be an InputNode.
+        """
+        from ray.dag.input_node import InputNode
+
+        node = self
+        while not isinstance(node, InputNode):
+            if len(node._upstream_nodes) == 0:
+                raise ValueError(
+                    "No InputNode found in the DAG: when traversing upwards, "
+                    f"no upstream node was found for {node}."
+                )
+            node = node._upstream_nodes[0]
+        return node
+
+    def apply_functional(
+        self,
+        source_input_list: Any,
+        predictate_fn: Callable,
+        apply_fn: Callable,
+    ):
+        """
+        Apply a given function to DAGNodes in source_input_list, and return
+        the replaced inputs without mutating or copying any DAGNode.
+
+        Args:
+            source_input_list: Source inputs to extract and apply function on
+                all children DAGNode instances.
+            predictate_fn: Applied on each DAGNode instance found; determines
+                whether we should apply the function to it. Can be used to
+                filter node types.
+            apply_fn: Function to apply on the node's bound attributes. Example:
+                apply_fn = lambda node: node._get_serve_deployment_handle(
+                    node._deployment, node._bound_other_args_to_resolve
+                )
+
+        Returns:
+            replaced_inputs: Outputs of apply_fn on DAGNodes in
+                source_input_list that pass predictate_fn.
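+
+        Example (hypothetical; `is_class_node` and `to_handle` stand in for a
+        caller-provided predicate and transform):
+            replaced = dag.apply_functional(
+                [dag.get_args()],
+                predictate_fn=lambda node: is_class_node(node),
+                apply_fn=lambda node: to_handle(node),
+            )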
+ """ + replace_table = {} + scanner = _PyObjScanner() + for node in scanner.find_nodes(source_input_list): + if predictate_fn(node) and node not in replace_table: + replace_table[node] = apply_fn(node) + + replaced_inputs = scanner.replace_nodes(replace_table) + scanner.clear() + + return replaced_inputs + + def _execute_impl( + self, *args, **kwargs + ) -> Union[ray.ObjectRef, "ray.actor.ActorHandle"]: + """Execute this node, assuming args have been transformed already.""" + raise NotImplementedError + + def _copy_impl( + self, + new_args: List[Any], + new_kwargs: Dict[str, Any], + new_options: Dict[str, Any], + new_other_args_to_resolve: Dict[str, Any], + ) -> "DAGNode": + """Return a copy of this node with the given new args.""" + raise NotImplementedError + + def _copy( + self, + new_args: List[Any], + new_kwargs: Dict[str, Any], + new_options: Dict[str, Any], + new_other_args_to_resolve: Dict[str, Any], + ) -> "DAGNode": + """Return a copy of this node with the given new args.""" + instance = self._copy_impl( + new_args, new_kwargs, new_options, new_other_args_to_resolve + ) + instance._stable_uuid = self._stable_uuid + instance._type_hint = copy.deepcopy(self._type_hint) + instance._original_type_hint = copy.deepcopy(self._original_type_hint) + return instance + + def __getstate__(self): + """Required due to overriding `__getattr__` else pickling fails.""" + return self.__dict__ + + def __setstate__(self, d: Dict[str, Any]): + """Required due to overriding `__getattr__` else pickling fails.""" + self.__dict__.update(d) + + def __getattr__(self, attr: str): + if attr == "bind": + raise AttributeError(f".bind() cannot be used again on {type(self)} ") + elif attr == "remote": + raise AttributeError( + f".remote() cannot be used on {type(self)}. To execute the task " + "graph for this node, use .execute()." + ) + else: + return self.__getattribute__(attr) diff --git a/.venv/lib/python3.11/site-packages/ray/dag/dag_node_operation.py b/.venv/lib/python3.11/site-packages/ray/dag/dag_node_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..c79cca2c7eb05cf88c9dee4534f1c64057b2697e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dag/dag_node_operation.py @@ -0,0 +1,789 @@ +from functools import total_ordering +from enum import Enum +from typing import Set, Tuple, List, Dict, Optional +import copy +import logging +import ray +import heapq +from collections import defaultdict + + +logger = logging.getLogger(__name__) + + +class _DAGNodeOperationType(Enum): + """ + There are three types of operations that a DAG node can perform: + 1. READ: Read from an input channel. + 2. COMPUTE: Execute the method corresponding to the node. + 3. WRITE: Write to an output channel. + """ + + READ = "READ" + COMPUTE = "COMPUTE" + WRITE = "WRITE" + + def viz_str(self): + """ + A string representation of the operation type to be used in visualization. + + The result string is a single character because conciseness is preferred. + """ + if self == _DAGNodeOperationType.READ: + return "R" + elif self == _DAGNodeOperationType.COMPUTE: + return "C" + elif self == _DAGNodeOperationType.WRITE: + return "W" + assert False, f"Unknown operation type: {self}" + + +class _DAGNodeOperation: + def __init__( + self, + exec_task_idx: int, + operation_type: _DAGNodeOperationType, + method_name: Optional[str] = None, + ): + """ + Args: + exec_task_idx: The index of the task that this operation belongs to + in the actor's ExecutableTask list. 
+                The index is not the same as bind_index because there may be
+                more tasks bound to an actor than tasks that appear in the
+                current compiled DAG.
+            operation_type: The type of operation to perform.
+            method_name: The name of the method that this operation originates
+                from. This is only for visualization and debugging purposes.
+        """
+        self.exec_task_idx = exec_task_idx
+        self.type = operation_type
+        self.method_name = method_name
+
+    def __repr__(self):
+        return (
+            f"_DAGNodeOperation("
+            f"exec_task_idx: {self.exec_task_idx}, "
+            f"type: {self.type})"
+        )
+
+    def viz_str(self):
+        """
+        A string representation of the node to be used in visualization.
+        """
+        return f"[{self.exec_task_idx}] {self.method_name} {self.type.viz_str()}"
+
+    def __hash__(self):
+        return hash((self.exec_task_idx, self.type))
+
+    def __eq__(self, other):
+        # An operation is uniquely identified by its `exec_task_idx` and type.
+        # `method_name` is only for debugging purposes.
+        return self.exec_task_idx == other.exec_task_idx and self.type == other.type
+
+
+@total_ordering
+class _DAGOperationGraphNode:
+    def __init__(
+        self,
+        operation: _DAGNodeOperation,
+        task_idx: int,
+        actor_handle: "ray.actor.ActorHandle",
+        requires_nccl: bool,
+    ):
+        """
+        _DAGOperationGraphNode represents a node in the DAG operation graph.
+        It contains information about the node's in-degree, out-degree, edges,
+        and the operation it performs.
+
+        Args:
+            operation: The operation that this node performs. The operation
+                can be a READ, COMPUTE, or WRITE operation.
+            task_idx: A unique index which can be used to index into
+                `CompiledDAG.idx_to_task` to get the corresponding task.
+            actor_handle: The actor handle to which this operation belongs.
+            requires_nccl: Whether this operation requires NCCL.
+        """
+        self.operation = operation
+        self.task_idx = task_idx
+        self.actor_handle = actor_handle
+        self.requires_nccl = requires_nccl
+        # The in_edges and out_edges map tuples to tuples. Each key tuple
+        # contains an integer `task_idx`, which can be used to index into
+        # `idx_to_task` to get the corresponding task, and a
+        # `_DAGNodeOperationType`, which can be READ, COMPUTE, or WRITE.
+        # Each value is the visualization information of the edge: a tuple of
+        # the edge's label and a boolean indicating whether the edge is a
+        # control dependency.
+        self.in_edges: Dict[Tuple[int, _DAGNodeOperationType], Tuple[str, bool]] = {}
+        self.out_edges: Dict[Tuple[int, _DAGNodeOperationType], Tuple[str, bool]] = {}
+        # The collective nodes are the nodes that belong to the same collective
+        # operation. Each node is represented by a tuple of its task idx and type.
+        self.collective_idxs: Set[Tuple[int, _DAGNodeOperationType]] = set()
+        # The ready collective nodes are the nodes that are ready to be executed,
+        # i.e., their in-degrees are zero. When a collective node is ready, it
+        # will be added to the ready collective nodes of all the nodes in its
+        # collective operation.
+        self.ready_collective_idxs: Set[Tuple[int, _DAGNodeOperationType]] = set()
+
+    def __repr__(self):
+        return (
+            f"_DAGOperationGraphNode("
+            f"operation: {self.operation}, "
+            f"task_idx: {self.task_idx}, "
+            f"actor_handle: {self.actor_handle}, "
+            f"requires_nccl: {self.requires_nccl})"
+        )
+
+    def __lt__(self, other: "_DAGOperationGraphNode"):
+        """
+        This function defines the order of the nodes in the priority queue used in
+        `_select_next_nodes`.
The priority queue is a min-heap, so the node with + higher priority is considered "less than" the other node. + """ + + def compare(lhs: "_DAGOperationGraphNode", rhs: "_DAGOperationGraphNode"): + # If both nodes belong to the same actor, the node with the smaller + # `exec_task_idx` is prioritized. If two nodes belong to different + # actors, it approximates balancing the scheduled tasks across actors, + # by prioritizing the node with the smaller `exec_task_idx`. The tie + # is broken by the `task_idx`. + if lhs.operation.exec_task_idx != rhs.operation.exec_task_idx: + return lhs.operation.exec_task_idx < rhs.operation.exec_task_idx + return lhs.task_idx < rhs.task_idx + + if self.actor_handle == other.actor_handle: + # When both nodes belong to the same actor, use the default comparison. + return compare(self, other) + elif self.is_nccl_op != other.is_nccl_op: + # When one node is a NCCL operation and the other is not, prioritize + # the non-NCCL operation. + return not self.is_nccl_op + else: + # When either both nodes are NCCL operations or both nodes are not + # NCCL operations, use the default comparison. + return compare(self, other) + + def __eq__(self, other: "_DAGOperationGraphNode"): + """ + Two operations are equal only when they have the same `exec_task_idx` and `type` + and belong to the same actor. + """ + return ( + self.actor_handle == other.actor_handle + and self.operation.exec_task_idx == other.operation.exec_task_idx + and self.operation.type == other.operation.type + ) + + def __hash__(self): + """ + An operation is uniquely identified by its `task_idx` and type. + """ + return hash((self.operation, self.task_idx)) + + @property + def in_degree(self) -> int: + return len(self.in_edges) + + @property + def is_ready(self) -> bool: + """ + If a node is not a NCCL collective, it is ready when it has a zero + in-degree. If it is a NCCL collective, it is ready when all the nodes + in its collective operation have zero in-degrees. + """ + return self.in_degree == 0 and ( + len(self.ready_collective_idxs) == len(self.collective_idxs) + ) + + @property + def is_read(self) -> bool: + return self.operation.type == _DAGNodeOperationType.READ + + @property + def is_nccl_collective(self) -> bool: + """ + A node is a NCCL collective if it is a compute node and requires NCCL. + """ + return ( + self.operation.type == _DAGNodeOperationType.COMPUTE and self.requires_nccl + ) + + @property + def is_nccl_write(self) -> bool: + """ + A node is a NCCL write if it is a write node and requires NCCL. + """ + return self.operation.type == _DAGNodeOperationType.WRITE and self.requires_nccl + + @property + def is_nccl_op(self) -> bool: + return self.is_nccl_collective or self.is_nccl_write + + def viz_str(self): + """ + A string representation of the node to be used in visualization. + """ + return self.operation.viz_str() + + @property + def _actor_id(self): + return self.actor_handle._ray_actor_id.hex() + + +def _add_edge( + from_node: _DAGOperationGraphNode, + to_node: _DAGOperationGraphNode, + label: str = "", + control_dependency: bool = False, +): + """ + Add an edge from `from_node` to `to_node`. + + Args: + from_node: The node from which the edge originates. + to_node: The node to which the edge points. + label: The label of the edge. This will be used to annotate the edge + in the visualization of the execution schedule. 
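+        control_dependency: Whether the edge represents a control dependency
+            (an execution-ordering constraint) rather than a data dependency.
+            Control dependencies are drawn as dashed edges in the
+            visualization.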
+ """ + from_node.out_edges[(to_node.task_idx, to_node.operation.type)] = ( + label, + control_dependency, + ) + to_node.in_edges[(from_node.task_idx, from_node.operation.type)] = ( + label, + control_dependency, + ) + + +def _push_candidate_node_if_ready( + actor_to_candidates: Dict["ray._raylet.ActorID", List[_DAGOperationGraphNode]], + graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]], + node: _DAGOperationGraphNode, +) -> None: + # Collective operations are ready when all the collective nodes have zero + # in-degrees. Only one node per collective will be added as ready. + if node.is_nccl_collective: + for collective_node_metadata in node.collective_idxs: + task_idx, op_type = collective_node_metadata + collective_node = graph[task_idx][op_type] + collective_node.ready_collective_idxs.add( + (node.task_idx, node.operation.type) + ) + if node.is_ready: + heapq.heappush( + actor_to_candidates[node.actor_handle._actor_id], + node, + ) + + +def _select_next_nodes( + actor_to_candidates: Dict["ray._raylet.ActorID", List[_DAGOperationGraphNode]], + graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]], +) -> Optional[List[_DAGOperationGraphNode]]: + """ + This function selects the next nodes for the topological sort to generate + execution schedule. If there are multiple candidate _DAGOperationGraphNodes, + select the node with the top priority. The priority is defined in + `_DAGOperationGraphNode.__lt__`. + + For the implementation details, we maintain a priority queue for each actor, + where the head of the priority queue is the node with the smallest `exec_task_idx`. + When a node has a zero in-degree, it is added to the corresponding actor's + priority queue. For a node other than a NCCL collective node, it is ready to be + executed if it has a zero in-degree. For a NCCL collective node, it is ready to + be executed when all the nodes in its collective operation have zero in-degrees. + + If a node is a NCCL collective node, it updates the `ready_collective_nodes` of + all the nodes in its collective operation. Unless all the nodes in its collective + group have zero in-degrees, this node is removed from the candidate list. + Eventually, exactly one NCCL collective node from its collective operation is + selected from the candidate list. + + If the selected node is a NCCL write node, select all the downstream NCCL + read nodes. If the selected node is a NCCL collective node, select all the NCCL + compute nodes in its collective operation. + + Args: + actor_to_candidates: A dictionary mapping an actor id to a list of + candidate nodes. The list is maintained as a priority queue, so + the head of the queue, i.e., `candidates[0]`, is the node with + the smallest `bind_index`. + graph: A dictionary mapping the index of a task to a dictionary of its + _DAGOperationGraphNodes for different operations. + + Returns: + A list of _DAGOperationGraphNodes to be placed into the corresponding + execution schedules. + """ + top_priority_node = None + for _, candidates in actor_to_candidates.items(): + if len(candidates) == 0: + continue + if top_priority_node is None or candidates[0] < top_priority_node: + top_priority_node = candidates[0] + + if top_priority_node is None: + return None + next_nodes = [ + heapq.heappop(actor_to_candidates[top_priority_node.actor_handle._actor_id]) + ] + + if not top_priority_node.is_nccl_op: + # A non-NCCL operation node is picked. + assert len(next_nodes) == 1 + elif top_priority_node.is_nccl_write: + # a NCCL write node is picked. 
NCCL is a blocking operation, so we need + # to pick all the corresponding NCCL read nodes to avoid a deadlock. + for downstream_node_metadata in top_priority_node.out_edges: + task_idx, op_type = downstream_node_metadata + downstream_node = graph[task_idx][op_type] + assert downstream_node.is_read + next_nodes.append(downstream_node) + assert len(next_nodes) == 1 + len(top_priority_node.out_edges) + elif top_priority_node.is_nccl_collective: + # a NCCL collective node is picked. NCCL is a blocking operation, so we need + # to pick all the corresponding NCCL collective nodes in its collective + # operation to avoid a deadlock. + for collective_node_metadata in top_priority_node.collective_idxs: + task_idx, op_type = collective_node_metadata + collective_node = graph[task_idx][op_type] + assert collective_node.is_nccl_collective and collective_node.is_ready + if collective_node != top_priority_node: + next_nodes.append(collective_node) + assert len(next_nodes) == len(top_priority_node.collective_idxs) + + return next_nodes + + +def _build_dag_node_operation_graph( + idx_to_task: Dict[int, "ray.dag.compiled_dag_node.CompiledTask"], + actor_to_operation_nodes: Dict[ + "ray.actor.ActorHandle", List[List[_DAGOperationGraphNode]] + ], +) -> Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]]: + """ + Generate a DAG node operation graph by adding edges based on the + following rules: + + #1 Add edges from READ to COMPUTE, and from COMPUTE to WRITE, which + belong to the same task. + #2 Add an edge from COMPUTE with bind_index i to COMPUTE with bind_index + i+1 if they belong to the same actor. + #3 Add an edge from WRITE of the writer task to READ of the reader task. + + This is the step one of building an execution schedule for each actor. + + Args: + idx_to_task: A dictionary that maps the `task_idx` to the `CompiledTask`. + `CompiledTask` contains information about a DAGNode and its downstream + nodes. + + actor_to_operation_nodes: A dictionary that maps an actor handle to + a list of lists of _DAGOperationGraphNode. For the same actor, the + index of the outer list corresponds to the index of the ExecutableTask + in the list of `executable_tasks` in `actor_to_executable_tasks`. In + the inner list, the order of operations is READ, COMPUTE, and WRITE. + + Returns: + A graph where each node is a _DAGOperationGraphNode. The key is `task_idx`, + the index to retrieve its task from `idx_to_task`, and the value is a + dictionary that maps the _DAGNodeOperationType (READ, COMPUTE, or WRITE) + to the corresponding _DAGOperationGraphNode + """ + assert idx_to_task + graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]] = {} + + for _, operation_nodes_list in actor_to_operation_nodes.items(): + prev_compute_node = None + for operation_nodes in operation_nodes_list: + task_idx = operation_nodes[0].task_idx + read_node, compute_node, write_node = ( + operation_nodes[0], + operation_nodes[1], + operation_nodes[2], + ) + # Add edges from READ to COMPUTE, and from COMPUTE to WRITE, which + # belong to the same task. + _add_edge(read_node, compute_node) + _add_edge(compute_node, write_node) + # Add an edge from COMPUTE with `bind_index` i to COMPUTE with + # `bind_index` i+1 if they belong to the same actor. 
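+            # This edge is a control dependency: it preserves the actor's
+            # original submission order between consecutive COMPUTE operations.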
+ if prev_compute_node is not None: + _add_edge(prev_compute_node, compute_node, "", True) + prev_compute_node = compute_node + assert task_idx not in graph + graph[task_idx] = { + _DAGNodeOperationType.READ: read_node, + _DAGNodeOperationType.COMPUTE: compute_node, + _DAGNodeOperationType.WRITE: write_node, + } + + # Import `ray.dag` here to avoid circular import. + from ray.dag import ClassMethodNode, CollectiveOutputNode, MultiOutputNode + + # Add an edge from WRITE of the writer task to READ of the reader task. + for task_idx, task in idx_to_task.items(): + if not ( + isinstance(task.dag_node, ClassMethodNode) + or isinstance(task.dag_node, CollectiveOutputNode) + ): + # The graph is used to generate an execution schedule for each actor. + # The edge from the InputNode has no impact on the final execution + # schedule. + continue + if ( + isinstance(task.dag_node, ClassMethodNode) + and task.dag_node.is_class_method_output + ): + # Class method output node dependencies are handled at its upstream: + # i.e., class method node + continue + for downstream_task_idx in task.downstream_task_idxs: + downstream_dag_node = idx_to_task[downstream_task_idx].dag_node + if isinstance(downstream_dag_node, MultiOutputNode): + continue + if ( + isinstance(downstream_dag_node, ClassMethodNode) + and downstream_dag_node.is_class_method_output + ): + consumer_idxs = idx_to_task[downstream_task_idx].downstream_task_idxs + for consumer_idx in consumer_idxs: + if consumer_idx in graph: + _add_edge( + graph[task_idx][_DAGNodeOperationType.WRITE], + graph[consumer_idx][_DAGNodeOperationType.READ], + "nccl" + if graph[task_idx][ + _DAGNodeOperationType.WRITE + ].requires_nccl + else "shm", + ) + continue + _add_edge( + graph[task_idx][_DAGNodeOperationType.WRITE], + graph[downstream_task_idx][_DAGNodeOperationType.READ], + "nccl" + if graph[task_idx][_DAGNodeOperationType.WRITE].requires_nccl + else "shm", + ) + + return graph + + +def _actor_viz_label(actor: "ray.actor.ActorHandle"): + """ + Returns the label of an actor in the visualization of the execution schedule. + + Args: + actor: The actor to be represented. + """ + class_name = actor._ray_actor_creation_function_descriptor.class_name + actor_id = actor._ray_actor_id.hex() + return f"Actor class name: {class_name}\nActor ID: {actor_id}" + + +def _node_viz_id_and_label( + node: _DAGOperationGraphNode, idx: int, optimized_index: int +): + """ + Returns the visualization id and label of a node. The visualization id is unique + across all nodes. + + Args: + node: The node to be represented. + idx: The index of the node in the execution schedule. + optimized_index: The index of the node in the optimized execution schedule. + """ + node_viz_label = node.viz_str() + f" {idx},{optimized_index}" + node_viz_id = f"{node._actor_id}_{node_viz_label}" + return node_viz_id, node_viz_label + + +def _visualize_execution_schedule( + actor_to_execution_schedule: Dict[ + "ray.actor.ActorHandle", List[_DAGOperationGraphNode] + ], + actor_to_overlapped_schedule: Optional[ + Dict["ray.actor.ActorHandle", List[_DAGOperationGraphNode]] + ], + graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]], +): + """ + Visualize the execution schedule for each actor. + + The visualization will be saved as a PNG file named `compiled_graph_schedule.png`. 
+ Details of the visualization: # noqa + + Node description format: + [<task_index>] <method_name> <operation> <orig_index>, <overlap_index> + + Node description fields: + operation: R(READ), C(COMPUTE), or W(WRITE) + orig_index: the index in the original execution schedule + overlap_index: the index in the overlap-communication optimized execution schedule + If this is different from orig_index, the node is highlighted in red + + Node grouping: + The nodes belonging to the same actor are grouped in the same rectangle + The actor class name and the actor id are shown in the rectangle + + Edges: + black color (without label): data dependency + black color (annotated with "shm"): shared memory channel + blue color (annotated with "nccl"): NCCL channel + dashed edge: control dependency between compute operations + + Args: + actor_to_execution_schedule: A dictionary that maps an actor handle to + the execution schedule which is a list of operation nodes. + actor_to_overlapped_schedule: A dictionary that maps an actor handle to the + optimized execution schedule which is a list of operation nodes. + graph: A graph where each node is a _DAGOperationGraphNode. The key is + `task_idx`, the index to retrieve its task from `idx_to_task`, and + the value is a dictionary that maps the _DAGNodeOperationType (READ, + COMPUTE, or WRITE) to the corresponding _DAGOperationGraphNode. It is + generated by `_build_dag_node_operation_graph`. + """ + try: + import graphviz + except ImportError: + raise ImportError( + "Please install graphviz to visualize the execution schedule. " + "You can install it by running `pip install graphviz`." + ) + + dot = graphviz.Digraph(comment="DAG") + # A dictionary that maps a node to its visualization id + node_to_viz_id: Dict[_DAGOperationGraphNode, str] = {} + + if actor_to_overlapped_schedule is None: + # TODO(rui): make the visualization more concise by only displaying + # the original schedule + actor_to_overlapped_schedule = actor_to_execution_schedule + for actor, execution_nodes in actor_to_execution_schedule.items(): + overlapped_schedule = actor_to_overlapped_schedule[actor] + node_to_optimized_index = { + node: i for i, node in enumerate(overlapped_schedule) + } + + actor_id = actor._ray_actor_id.hex() + with dot.subgraph(name=f"cluster_{actor_id}") as subgraph: + subgraph.attr(rank=actor_id, label=_actor_viz_label(actor)) + for i, node in enumerate(execution_nodes): + optimized_index = node_to_optimized_index.get(node) + node_viz_id, node_viz_label = _node_viz_id_and_label( + node, i, optimized_index + ) + color = "red" if optimized_index != i else "black" + subgraph.node(node_viz_id, node_viz_label, color=color) + node_to_viz_id[node] = node_viz_id + + for actor, execution_nodes in actor_to_execution_schedule.items(): + for i, node in enumerate(execution_nodes): + node_viz_id = node_to_viz_id[node] + for out_edge, viz_info in node.out_edges.items(): + label, control_dependency = viz_info + out_task_idx, out_op_type = out_edge + out_node = graph[out_task_idx][out_op_type] + out_node_viz_id = node_to_viz_id[out_node] + color = "blue" if label == "nccl" else "black" + style = "dashed" if control_dependency else "solid" + dot.edge( + node_viz_id, out_node_viz_id, label=label, color=color, style=style + ) + + # Add legend + with dot.subgraph(name="cluster_legend") as legend: + legend.attr(label="Legend", labelloc="t", fontsize="20", bgcolor="lightgrey") + + # Single node and its explanation + legend.node("example_node", "[0] bwd C 10,10\n") + explanation = ( + "<" # noqa + '<table border="0" cellborder="0" cellspacing="0">' + '<tr><td align="left"><b>Node description format:</b></td></tr>' + '<tr><td align="left">[&lt;task_index&gt;] &lt;method_name&gt; &lt;operation&gt; &lt;orig_index&gt;, &lt;overlap_index&gt;</td></tr>' # noqa + '<tr><td align="left"><b>Node description fields:</b></td></tr>' + '<tr><td align="left">operation: R(READ), C(COMPUTE), or W(WRITE)</td></tr>' # noqa + '<tr><td align="left">orig_index: the index in the original execution schedule</td></tr>' # noqa + '<tr><td align="left">overlap_index: the index in the overlap-communication optimized execution schedule</td></tr>' # noqa + '<tr><td align="left">If this is different from orig_index, the node is highlighted in red</td></tr>' # noqa + '<tr><td align="left"><b>Node grouping:</b></td></tr>' + '<tr><td align="left">The nodes belonging to the same actor are grouped in the same rectangle</td></tr>' # noqa + '<tr><td align="left">The actor class name and the actor id are shown in the rectangle</td></tr>' # noqa + '<tr><td align="left"><b>Edges:</b></td></tr>' + '<tr><td align="left">black color (without label): data dependency</td></tr>' # noqa + '<tr><td align="left">black color (annotated with "shm"): shared memory channel</td></tr>' # noqa + '<tr><td align="left">blue color (annotated with "nccl"): NCCL channel</td></tr>' # noqa + '<tr><td align="left">dashed edge: control dependency between compute operations</td></tr>' # noqa + "</table>" + ">" + ) + + legend.node("example_explanation", explanation, shape="plaintext") + legend.edge("example_node", "example_explanation", style="invis") + + logger.info( + "Writing compiled graph schedule visualization " + "to compiled_graph_schedule.png" + ) + dot.render("compiled_graph_schedule", format="png", view=False) + + +def _generate_actor_to_execution_schedule( + graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]] +) -> Dict["ray.actor.ActorHandle", List[_DAGOperationGraphNode]]: + """ + Generate an execution schedule for each actor. The schedule is a list of + operation nodes to be executed. The function uses a topological sort + algorithm to generate the schedule. + + Args: + graph: A graph where each node is a _DAGOperationGraphNode. The key is + `task_idx`, the index to retrieve its task from `idx_to_task`, and + the value is a dictionary that maps the _DAGNodeOperationType (READ, + COMPUTE, or WRITE) to the corresponding _DAGOperationGraphNode. It is + generated by `_build_dag_node_operation_graph`. + + Returns: + actor_to_execution_schedule: A dictionary that maps an actor handle to + the execution schedule which is a list of operation nodes to be + executed. + """ + + # Mapping from the actor handle to the execution schedule which is a list + # of operations to be executed. + actor_to_execution_schedule: Dict[ + "ray.actor.ActorHandle", List[_DAGOperationGraphNode] + ] = defaultdict(list) + + # A dictionary mapping an actor id to a list of candidate nodes. The list + # is maintained as a priority queue, so the head of the queue, i.e., + # `candidates[0]`, is the node with the smallest `exec_task_idx`. + actor_to_candidates: Dict[ + "ray._raylet.ActorID", List[_DAGOperationGraphNode] + ] = defaultdict(list) + for _, node_dict in graph.items(): + for _, node in node_dict.items(): + # A node with zero in-degree means all of its dependencies + # have been satisfied, including both data and control dependencies. + # Therefore, it is a candidate for execution. + if node.in_degree == 0: + _push_candidate_node_if_ready(actor_to_candidates, graph, node) + + visited_nodes = set() + + # Use a topological sort algorithm to generate the execution schedule. + while True: + # Select a list of nodes to be executed. There are three cases: + # 1. If a selected node is not an NCCL operation, only the node itself + # is returned. + # 2. If a selected node is an NCCL write operation, the corresponding NCCL + # read operations are also returned. + # 3. If a selected node is an NCCL collective operation, all the nodes in + # its collective operation are returned. + # In cases 1 and 3, all the selected nodes are ready. In case 2, the NCCL + # write node is ready, while the NCCL read nodes are not ready until their + # in-degrees are updated. + nodes = _select_next_nodes(actor_to_candidates, graph) + if nodes is None: + break + # Filter out the visited nodes. + nodes = [node for node in nodes if node not in visited_nodes] + # Add the selected nodes to the execution schedule. + for node in nodes: + actor_to_execution_schedule[node.actor_handle].append(node) + visited_nodes.add(node) + # Update the in-degree of the downstream nodes. + for node in nodes: + for out_node_task_idx, out_node_type in node.out_edges: + out_node = graph[out_node_task_idx][out_node_type] + out_node.in_edges.pop((node.task_idx, node.operation.type)) + if out_node.in_degree == 0 and out_node not in visited_nodes: + # If the downstream node is already visited, it has been added + # to the execution schedule.
Such nodes are the NCCL read nodes in + # case 2. + _push_candidate_node_if_ready(actor_to_candidates, graph, out_node) + assert len(visited_nodes) == len(graph) * 3, "Expected all nodes to be visited" + for node in visited_nodes: + assert node.is_ready, f"Expected {node} to be ready" + for _, candidates in actor_to_candidates.items(): + assert len(candidates) == 0, "Expected all candidates to be empty" + + return actor_to_execution_schedule + + +def _generate_overlapped_execution_schedule( + actor_to_execution_schedule: Dict[ + "ray.actor.ActorHandle", List[_DAGOperationGraphNode] + ], +) -> Dict["ray.actor.ActorHandle", List[_DAGOperationGraphNode]]: + """ + From an existing execution schedule, generate a new schedule by overlapping + computation and communication. + + Currently, the algorithm generates a new schedule for each actor as follows: + For each NCCL read operation (i.e., recv), scan backwards to find the nearest + compute node to swap with so that the NCCL read operation can be overlapped + with computation. + + Collective operations are not yet supported. + + Args: + actor_to_execution_schedule: A dictionary that maps an actor handle to + the existing execution schedule for the actor. The schedule is a list + of operations to be executed. + + Returns: + A dictionary that maps an actor handle to the overlapped execution schedule + for the actor. + """ + + actor_to_overlapped_schedule: Dict[ + "ray.actor.ActorHandle", List[_DAGOperationGraphNode] + ] = copy.deepcopy(actor_to_execution_schedule) + for overlapped_schedule in actor_to_overlapped_schedule.values(): + for i in range(len(overlapped_schedule)): + if ( + overlapped_schedule[i].operation.type == _DAGNodeOperationType.READ + and overlapped_schedule[i].requires_nccl + ): + # For each NCCL read operation (i.e., recv), scan backwards + # to find the nearest compute node to swap with so that + # the NCCL read operation can be overlapped with computation. + for j in range(i - 1, -1, -1): + if ( + overlapped_schedule[j].operation.type + == _DAGNodeOperationType.COMPUTE + ): + # Found a desired compute operation, make the swap + nccl_read_op = overlapped_schedule[i] + prev_ops = overlapped_schedule[j:i] + overlapped_schedule[j + 1 : i + 1] = prev_ops + overlapped_schedule[j] = nccl_read_op + break + if ( + overlapped_schedule[j].operation.type + == _DAGNodeOperationType.READ + or overlapped_schedule[j].operation.type + == _DAGNodeOperationType.WRITE + ) and overlapped_schedule[j].requires_nccl: + # Found an NCCL read/write operation, skip the overlap + # optimization to keep the relative order of NCCL operations + break + return actor_to_overlapped_schedule + + +def _extract_execution_schedule( + actor_to_execution_schedule: Dict[ + "ray.actor.ActorHandle", List[_DAGOperationGraphNode] + ] +) -> Dict["ray.actor.ActorHandle", List[_DAGNodeOperation]]: + """ + Extract _DAGNodeOperation from _DAGOperationGraphNode in the schedule + and discard unnecessary information.
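+ + For illustration, a minimal sketch (the actor handle and node names are + placeholders): + + .. code-block:: python + + schedule = {actor: [read_node, compute_node, write_node]} + _extract_execution_schedule(schedule) + # -> {actor: [read_node.operation, compute_node.operation, + # write_node.operation]}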
+ """ + return { + actor: [node.operation for node in nodes] + for actor, nodes in actor_to_execution_schedule.items() + } diff --git a/.venv/lib/python3.11/site-packages/ray/dag/dag_operation_future.py b/.venv/lib/python3.11/site-packages/ray/dag/dag_operation_future.py new file mode 100644 index 0000000000000000000000000000000000000000..33d790515d3cb12fabb7db3e8d352a186632c264 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dag/dag_operation_future.py @@ -0,0 +1,95 @@ +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar +from ray.util.annotations import DeveloperAPI + + +if TYPE_CHECKING: + import cupy as cp + +T = TypeVar("T") + + +@DeveloperAPI +class DAGOperationFuture(ABC, Generic[T]): + """ + A future representing the result of a DAG operation. + + This is an abstraction that is internal to each actor, + and is not exposed to the DAG caller. + """ + + @abstractmethod + def wait(self): + """ + Wait for the future and return the result of the operation. + """ + raise NotImplementedError + + +@DeveloperAPI +class ResolvedFuture(DAGOperationFuture): + """ + A future that is already resolved. Calling `wait()` on this will + immediately return the result without blocking. + """ + + def __init__(self, result): + """ + Initialize a resolved future. + + Args: + result: The result of the future. + """ + self._result = result + + def wait(self): + """ + Wait and immediately return the result. This operation will not block. + """ + return self._result + + +@DeveloperAPI +class GPUFuture(DAGOperationFuture[Any]): + """ + A future for a GPU event on a CUDA stream. + + This future wraps a buffer, and records an event on the given stream + when it is created. When the future is waited on, it makes the current + CUDA stream wait on the event, then returns the buffer. + + The buffer must be a GPU tensor produced by an earlier operation launched + on the given stream, or it could be CPU data. Then the future guarantees + that when the wait() returns, the buffer is ready on the current stream. + + The `wait()` does not block CPU. + """ + + def __init__(self, buf: Any, stream: Optional["cp.cuda.Stream"] = None): + """ + Initialize a GPU future on the given stream. + + Args: + buf: The buffer to return when the future is resolved. + stream: The CUDA stream to record the event on, this event is waited + on when the future is resolved. If None, the current stream is used. + """ + import cupy as cp + + if stream is None: + stream = cp.cuda.get_current_stream() + + self._buf = buf + self._event = cp.cuda.Event() + self._event.record(stream) + + def wait(self) -> Any: + """ + Wait for the future on the current CUDA stream and return the result from + the GPU operation. This operation does not block CPU. 
+ """ + import cupy as cp + + current_stream = cp.cuda.get_current_stream() + current_stream.wait_event(self._event) + return self._buf diff --git a/.venv/lib/python3.11/site-packages/ray/dag/experimental/__init__.py b/.venv/lib/python3.11/site-packages/ray/dag/experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/dag/experimental/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dag/experimental/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5bcd32c0e6d677644946655ece33f7f8ac6482c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dag/experimental/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dag/format_utils.py b/.venv/lib/python3.11/site-packages/ray/dag/format_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1428317da1c37a93800c66d3162ddffa58ad82e8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dag/format_utils.py @@ -0,0 +1,155 @@ +from ray.dag import DAGNode +from ray.util.annotations import DeveloperAPI + + +@DeveloperAPI +def get_dag_node_str( + dag_node: DAGNode, + body_line, +): + indent = _get_indentation() + other_args_to_resolve_lines = _get_other_args_to_resolve_lines( + dag_node._bound_other_args_to_resolve + ) + return ( + f"({dag_node.__class__.__name__}, {dag_node._stable_uuid})(\n" + f"{indent}body={body_line}\n" + f"{indent}args={_get_args_lines(dag_node._bound_args)}\n" + f"{indent}kwargs={_get_kwargs_lines(dag_node._bound_kwargs)}\n" + f"{indent}options={_get_options_lines(dag_node._bound_options)}\n" + f"{indent}other_args_to_resolve={other_args_to_resolve_lines}\n" + f")" + ) + + +def _get_indentation(num_spaces=4): + return " " * num_spaces + + +def _get_args_lines(bound_args): + """Pretty prints bounded args of a DAGNode, and recursively handle + DAGNode in list / dict containers. + """ + indent = _get_indentation() + lines = [] + for arg in bound_args: + if isinstance(arg, DAGNode): + node_repr_lines = str(arg).split("\n") + for node_repr_line in node_repr_lines: + lines.append(f"{indent}" + node_repr_line) + elif isinstance(arg, list): + for ele in arg: + node_repr_lines = str(ele).split("\n") + for node_repr_line in node_repr_lines: + lines.append(f"{indent}" + node_repr_line) + elif isinstance(arg, dict): + for _, val in arg.items(): + node_repr_lines = str(val).split("\n") + for node_repr_line in node_repr_lines: + lines.append(f"{indent}" + node_repr_line) + # TODO: (jiaodong) Handle nested containers and other obj types + else: + lines.append(f"{indent}" + str(arg) + ", ") + + if len(lines) == 0: + args_line = "[]" + else: + args_line = "[" + for args in lines: + args_line += f"\n{indent}{args}" + args_line += f"\n{indent}]" + + return args_line + + +def _get_kwargs_lines(bound_kwargs): + """Pretty prints bounded kwargs of a DAGNode, and recursively handle + DAGNode in list / dict containers. + """ + # TODO: (jiaodong) Nits, we're missing keys and indentation was a bit off. 
+ if not bound_kwargs: + return "{}" + indent = _get_indentation() + kwargs_lines = [] + for key, val in bound_kwargs.items(): + if isinstance(val, DAGNode): + node_repr_lines = str(val).split("\n") + for index, node_repr_line in enumerate(node_repr_lines): + if index == 0: + kwargs_lines.append( + f"{indent}{key}:" + f"{indent}" + node_repr_line + ) + else: + kwargs_lines.append(f"{indent}{indent}" + node_repr_line) + + elif isinstance(val, list): + for ele in val: + node_repr_lines = str(ele).split("\n") + for node_repr_line in node_repr_lines: + kwargs_lines.append(f"{indent}" + node_repr_line) + elif isinstance(val, dict): + for _, inner_val in val.items(): + node_repr_lines = str(inner_val).split("\n") + for node_repr_line in node_repr_lines: + kwargs_lines.append(f"{indent}" + node_repr_line) + # TODO: (jiaodong) Handle nested containers and other obj types + else: + kwargs_lines.append(val) + + if len(kwargs_lines) > 0: + kwargs_line = "{" + for line in kwargs_lines: + kwargs_line += f"\n{indent}{line}" + kwargs_line += f"\n{indent}}}" + else: + kwargs_line = "{}" + + return kwargs_line + + +def _get_options_lines(bound_options): + """Pretty prints .options() in DAGNode. Only prints non-empty values.""" + if not bound_options: + return "{}" + indent = _get_indentation() + options_lines = [] + for key, val in bound_options.items(): + if val: + options_lines.append(f"{indent}{key}: " + str(val)) + + options_line = "{" + for line in options_lines: + options_line += f"\n{indent}{line}" + options_line += f"\n{indent}}}" + return options_line + + +def _get_other_args_to_resolve_lines(other_args_to_resolve): + if not other_args_to_resolve: + return "{}" + indent = _get_indentation() + other_args_to_resolve_lines = [] + for key, val in other_args_to_resolve.items(): + if isinstance(val, DAGNode): + node_repr_lines = str(val).split("\n") + for index, node_repr_line in enumerate(node_repr_lines): + if index == 0: + other_args_to_resolve_lines.append( + f"{indent}{key}:" + + f"{indent}" + + "\n" + + f"{indent}{indent}{indent}" + + node_repr_line + ) + else: + other_args_to_resolve_lines.append( + f"{indent}{indent}" + node_repr_line + ) + else: + other_args_to_resolve_lines.append(f"{indent}{key}: " + str(val)) + + other_args_to_resolve_line = "{" + for line in other_args_to_resolve_lines: + other_args_to_resolve_line += f"\n{indent}{line}" + other_args_to_resolve_line += f"\n{indent}}}" + return other_args_to_resolve_line diff --git a/.venv/lib/python3.11/site-packages/ray/dag/function_node.py b/.venv/lib/python3.11/site-packages/ray/dag/function_node.py new file mode 100644 index 0000000000000000000000000000000000000000..4565fcffe8ff3ac8e0f2c4821d27ae6d7d16a4f6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dag/function_node.py @@ -0,0 +1,60 @@ +from typing import Any, Dict, List + + +import ray +from ray.dag.dag_node import DAGNode +from ray.dag.format_utils import get_dag_node_str +from ray.util.annotations import DeveloperAPI + + +@DeveloperAPI +class FunctionNode(DAGNode): + """Represents a bound task node in a Ray task DAG.""" + + def __init__( + self, + func_body, + func_args, + func_kwargs, + func_options, + other_args_to_resolve=None, + ): + self._body = func_body + super().__init__( + func_args, + func_kwargs, + func_options, + other_args_to_resolve=other_args_to_resolve, + ) + + def _copy_impl( + self, + new_args: List[Any], + new_kwargs: Dict[str, Any], + new_options: Dict[str, Any], + new_other_args_to_resolve: Dict[str, Any], + ): + return FunctionNode( + self._body, 
+ new_args, + new_kwargs, + new_options, + other_args_to_resolve=new_other_args_to_resolve, + ) + + def _execute_impl(self, *args, **kwargs): + """Executor of FunctionNode via `ray.remote()`. + + Args and kwargs are here to match the base class signature, but are not + used in the implementation. All args and kwargs should be resolved and + replaced with values in bound_args and bound_kwargs via bottom-up + recursion when the current node is executed. + """ + return ( + ray.remote(self._body) + .options(**self._bound_options) + .remote(*self._bound_args, **self._bound_kwargs) + ) + + def __str__(self) -> str: + return get_dag_node_str(self, str(self._body)) diff --git a/.venv/lib/python3.11/site-packages/ray/dag/input_node.py b/.venv/lib/python3.11/site-packages/ray/dag/input_node.py new file mode 100644 index 0000000000000000000000000000000000000000..1ea1f2730422d3496f072801ae3d4d803990c659 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dag/input_node.py @@ -0,0 +1,321 @@ +from typing import Any, Dict, List, Union, Optional + +from ray.dag import DAGNode +from ray.dag.format_utils import get_dag_node_str +from ray.experimental.gradio_utils import type_to_string +from ray.util.annotations import DeveloperAPI + +IN_CONTEXT_MANAGER = "__in_context_manager__" + + +@DeveloperAPI +class InputNode(DAGNode): + r"""Ray dag node used in DAG building API to mark entrypoints of a DAG. + + Entrypoints should only be functions or class methods. A DAG can have + multiple entrypoints, but only one instance of InputNode exists per DAG, + shared among all DAGNodes. + + Example: + + .. code-block:: + + m1.forward + / \ + dag_input ensemble -> dag_output + \ / + m2.forward + + In this pipeline, each user input is broadcast to both m1.forward and + m2.forward as the first stop of the DAG, and is authored like + + .. code-block:: python + + import ray + + @ray.remote + class Model: + def __init__(self, val): + self.val = val + def forward(self, input): + return self.val * input + + @ray.remote + def combine(a, b): + return a + b + + with InputNode() as dag_input: + m1 = Model.bind(1) + m2 = Model.bind(2) + m1_output = m1.forward.bind(dag_input[0]) + m2_output = m2.forward.bind(dag_input.x) + ray_dag = combine.bind(m1_output, m2_output) + + # Pass a mix of args and kwargs as input. + ray_dag.execute(1, x=2) # 1 sent to m1, 2 sent to m2 + + # Alternatively, the user can also pass a single data object, list, or + # dict and access it via list index, object attribute, or dict key str. + ray_dag.execute(UserDataObject(m1=1, m2=2)) + # dag_input.m1, dag_input.m2 + ray_dag.execute([1, 2]) + # dag_input[0], dag_input[1] + ray_dag.execute({"m1": 1, "m2": 2}) + # dag_input["m1"], dag_input["m2"] + """ + + def __init__( + self, + *args, + input_type: Optional[Union[type, Dict[Union[int, str], type]]] = None, + _other_args_to_resolve=None, + **kwargs, + ): + """InputNode should only take attributes for validating and converting + input data, rather than the input data itself. User input should be + provided via `ray_dag.execute(user_input)`. + + Args: + input_type: Describes the data type of the inputs the user will provide. + - if given through a singular InputNode: the type of the InputNode + - if given through InputAttributeNodes: a map of key -> type + Used when deciding what Gradio block to represent the input nodes with. + _other_args_to_resolve: Internal only, to keep InputNode's execution + context throughout pickling, replacement, and serialization. + Users should not use or pass this field.
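+ + For illustration, a minimal sketch of `input_type` (the types here are + hypothetical, chosen only for the example): + + .. code-block:: python + + # Singular input: a single int is expected at execute() time. + with InputNode(input_type=int) as dag_input: + ... + + # Keyed inputs: dag_input[0] is an int and dag_input.x is a str. + with InputNode(input_type={0: int, "x": str}) as dag_input: + ...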
+ """ + if len(args) != 0 or len(kwargs) != 0: + raise ValueError("InputNode should not take any args or kwargs.") + + self.input_attribute_nodes = {} + + self.input_type = input_type + if input_type is not None and isinstance(input_type, type): + if _other_args_to_resolve is None: + _other_args_to_resolve = {} + _other_args_to_resolve["result_type_string"] = type_to_string(input_type) + + super().__init__([], {}, {}, other_args_to_resolve=_other_args_to_resolve) + + def _copy_impl( + self, + new_args: List[Any], + new_kwargs: Dict[str, Any], + new_options: Dict[str, Any], + new_other_args_to_resolve: Dict[str, Any], + ): + return InputNode(_other_args_to_resolve=new_other_args_to_resolve) + + def _execute_impl(self, *args, **kwargs): + """Executor of InputNode.""" + # Catch and assert singleton context at dag execution time. + assert self._in_context_manager(), ( + "InputNode is a singleton instance that should be only used in " + "context manager for dag building and execution. See the docstring " + "of class InputNode for examples." + ) + # If user only passed in one value, for simplicity we just return it. + if len(args) == 1 and len(kwargs) == 0: + return args[0] + + return DAGInputData(*args, **kwargs) + + def _in_context_manager(self) -> bool: + """Return if InputNode is created in context manager.""" + if ( + not self._bound_other_args_to_resolve + or IN_CONTEXT_MANAGER not in self._bound_other_args_to_resolve + ): + return False + else: + return self._bound_other_args_to_resolve[IN_CONTEXT_MANAGER] + + def set_context(self, key: str, val: Any): + """Set field in parent DAGNode attribute that can be resolved in both + pickle and JSON serialization + """ + self._bound_other_args_to_resolve[key] = val + + def __str__(self) -> str: + return get_dag_node_str(self, "__InputNode__") + + def __getattr__(self, key: str): + assert isinstance( + key, str + ), "Please only access dag input attributes with str key." + if key not in self.input_attribute_nodes: + self.input_attribute_nodes[key] = InputAttributeNode( + self, key, "__getattr__" + ) + return self.input_attribute_nodes[key] + + def __getitem__(self, key: Union[int, str]) -> Any: + assert isinstance(key, (str, int)), ( + "Please only use int index or str as first-level key to " + "access fields of dag input." + ) + + input_type = None + if self.input_type is not None and key in self.input_type: + input_type = type_to_string(self.input_type[key]) + + if key not in self.input_attribute_nodes: + self.input_attribute_nodes[key] = InputAttributeNode( + self, key, "__getitem__", input_type + ) + return self.input_attribute_nodes[key] + + def __enter__(self): + self.set_context(IN_CONTEXT_MANAGER, True) + return self + + def __exit__(self, *args): + pass + + def get_result_type(self) -> str: + """Get type of the output of this DAGNode. + + Generated by ray.experimental.gradio_utils.type_to_string(). + """ + if "result_type_string" in self._bound_other_args_to_resolve: + return self._bound_other_args_to_resolve["result_type_string"] + + +@DeveloperAPI +class InputAttributeNode(DAGNode): + """Represents partial access of user input based on an index (int), + object attribute or dict key (str). + + Examples: + + .. 
code-block:: python + + with InputNode() as dag_input: + a = dag_input[0] + b = dag_input.x + ray_dag = add.bind(a, b) + + # This makes a = 1 and b = 2 + ray_dag.execute(1, x=2) + + with InputNode() as dag_input: + a = dag_input[0] + b = dag_input[1] + ray_dag = add.bind(a, b) + + # This makes a = 2 and b = 3 + ray_dag.execute(2, 3) + + # Alternatively, you can input a single object + # and the inputs are automatically indexed from the object: + # This makes a = 2 and b = 3 + ray_dag.execute([2, 3]) + """ + + def __init__( + self, + dag_input_node: InputNode, + key: Union[int, str], + accessor_method: str, + input_type: Optional[str] = None, + ): + self._dag_input_node = dag_input_node + self._key = key + self._accessor_method = accessor_method + super().__init__( + [], + {}, + {}, + { + "dag_input_node": dag_input_node, + "key": key, + "accessor_method": accessor_method, + # Type of the input tied to this node. Used by + # gradio_visualize_graph.GraphVisualizer to determine which Gradio + # component should be used for this node. + "result_type_string": input_type, + }, + ) + + def _copy_impl( + self, + new_args: List[Any], + new_kwargs: Dict[str, Any], + new_options: Dict[str, Any], + new_other_args_to_resolve: Dict[str, Any], + ): + return InputAttributeNode( + new_other_args_to_resolve["dag_input_node"], + new_other_args_to_resolve["key"], + new_other_args_to_resolve["accessor_method"], + new_other_args_to_resolve["result_type_string"], + ) + + def _execute_impl(self, *args, **kwargs): + """Executor of InputAttributeNode. + + Args and kwargs are here to match the base class signature, but are not + used in the implementation. All args and kwargs should be resolved and + replaced with values in bound_args and bound_kwargs via bottom-up + recursion when the current node is executed. + """ + + if isinstance(self._dag_input_node, DAGInputData): + return self._dag_input_node[self._key] + else: + # dag.execute() is called with only one arg, thus when an + # InputAttributeNode is executed, its dependent InputNode is + # resolved with the original user input Python object. + user_input_python_object = self._dag_input_node + if isinstance(self._key, str): + if self._accessor_method == "__getitem__": + return user_input_python_object[self._key] + elif self._accessor_method == "__getattr__": + return getattr(user_input_python_object, self._key) + elif isinstance(self._key, int): + return user_input_python_object[self._key] + else: + raise ValueError( + "Please only use int index or str as first-level key to " + "access fields of dag input." + ) + + def __str__(self) -> str: + return get_dag_node_str(self, f'["{self._key}"]') + + def get_result_type(self) -> str: + """Get type of the output of this DAGNode. + + Generated by ray.experimental.gradio_utils.type_to_string(). + """ + if "result_type_string" in self._bound_other_args_to_resolve: + return self._bound_other_args_to_resolve["result_type_string"] + + @property + def key(self) -> Union[int, str]: + return self._key + + +@DeveloperAPI +class DAGInputData: + """If the user passes multiple args and kwargs directly to dag.execute(), we + generate this wrapper for all user inputs as one object, accessible via + int index or str key. + """ + + def __init__(self, *args, **kwargs): + self._args = list(args) + self._kwargs = kwargs + + def __getitem__(self, key: Union[int, str]) -> Any: + if isinstance(key, int): + # Access list args by index. + return self._args[key] + elif isinstance(key, str): + # Access kwarg by key.
+ return self._kwargs[key] + else: + raise ValueError( + "Please only use int index or str as first-level key to " + "access fields of dag input." + ) diff --git a/.venv/lib/python3.11/site-packages/ray/dag/output_node.py b/.venv/lib/python3.11/site-packages/ray/dag/output_node.py new file mode 100644 index 0000000000000000000000000000000000000000..f9abdf1643e092bcfc030b4855d6ebd86b3355ae --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dag/output_node.py @@ -0,0 +1,45 @@ +import ray +from typing import Any, Dict, List, Union, Tuple + +from ray.dag import DAGNode +from ray.dag.format_utils import get_dag_node_str +from ray.util.annotations import DeveloperAPI + + +@DeveloperAPI +class MultiOutputNode(DAGNode): + """Ray dag node used in DAG building API to mark the endpoint of DAG""" + + def __init__( + self, + args: Union[List[DAGNode], Tuple[DAGNode]], + other_args_to_resolve: Dict[str, Any] = None, + ): + if isinstance(args, tuple): + args = list(args) + if not isinstance(args, list): + raise ValueError(f"Invalid input type for `args`, {type(args)}.") + super().__init__( + args, + {}, + {}, + other_args_to_resolve=other_args_to_resolve or {}, + ) + + def _execute_impl( + self, *args, **kwargs + ) -> Union[ray.ObjectRef, "ray.actor.ActorHandle"]: + return self._bound_args + + def _copy_impl( + self, + new_args: List[Any], + new_kwargs: Dict[str, Any], + new_options: Dict[str, Any], + new_other_args_to_resolve: Dict[str, Any], + ) -> "DAGNode": + """Return a copy of this node with the given new args.""" + return MultiOutputNode(new_args, new_other_args_to_resolve) + + def __str__(self) -> str: + return get_dag_node_str(self, "__MultiOutputNode__") diff --git a/.venv/lib/python3.11/site-packages/ray/dag/py_obj_scanner.py b/.venv/lib/python3.11/site-packages/ray/dag/py_obj_scanner.py new file mode 100644 index 0000000000000000000000000000000000000000..6bd6b94ab535bd485d06989ec686da36955788f4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dag/py_obj_scanner.py @@ -0,0 +1,105 @@ +import io +from typing import Any, Dict, Generic, List, Tuple, Type, TypeVar, Union + +import pickle # noqa: F401 + +import ray +from ray.dag.base import DAGNodeBase + + +# Used in deserialization hooks to reference scanner instances. +_instances: Dict[int, "_PyObjScanner"] = {} + +# Generic types for the scanner to transform from and to. +SourceType = TypeVar("SourceType") +TransformedType = TypeVar("TransformedType") + + +def _get_node(instance_id: int, node_index: int) -> SourceType: + """Get the node instance. + + Note: This function should be static and globally importable, + otherwise the serialization overhead would be very significant. + """ + return _instances[instance_id]._replace_index(node_index) + + +class _PyObjScanner(ray.cloudpickle.CloudPickler, Generic[SourceType, TransformedType]): + """Utility to find and replace the `source_type` in Python objects. + + `source_type` can either be a single type or a tuple of multiple types. + + The caller must first call `find_nodes()`, then compute a replacement table and + pass it to `replace_nodes`. + + This uses cloudpickle under the hood, so all sub-objects that are not `source_type` + must be serializable. + + Args: + source_type: the type(s) of object to find and replace. Default to DAGNodeBase. + """ + + def __init__(self, source_type: Union[Type, Tuple] = DAGNodeBase): + self.source_type = source_type + # Buffer to keep intermediate serialized state. 
+ self._buf = io.BytesIO() + # List of top-level SourceType found during the serialization pass. + self._found = None + # List of other objects found during the serialization pass. + # This is used to store references to objects so they won't be + # serialized by cloudpickle. + self._objects = [] + # Replacement table to consult during deserialization. + self._replace_table: Dict[SourceType, TransformedType] = None + _instances[id(self)] = self + super().__init__(self._buf) + + def reducer_override(self, obj): + """Hook for reducing objects. + + Objects of `self.source_type` are saved to `self._found` and a global map so + they can later be replaced. + + All other objects fall back to the default `CloudPickler` serialization. + """ + if isinstance(obj, self.source_type): + index = len(self._found) + self._found.append(obj) + return _get_node, (id(self), index) + + return super().reducer_override(obj) + + def find_nodes(self, obj: Any) -> List[SourceType]: + """ + Serialize `obj` and store all instances of `source_type` found in `_found`. + + Args: + obj: The object to scan for `source_type`. + Returns: + A list of all instances of `source_type` found in `obj`. + """ + assert ( + self._found is None + ), "find_nodes cannot be called twice on the same PyObjScanner instance." + self._found = [] + self._objects = [] + self.dump(obj) + return self._found + + def replace_nodes(self, table: Dict[SourceType, TransformedType]) -> Any: + """Replace previously found DAGNodes per the given table.""" + assert self._found is not None, "find_nodes must be called first" + self._replace_table = table + self._buf.seek(0) + return pickle.load(self._buf) + + def _replace_index(self, i: int) -> SourceType: + return self._replace_table[self._found[i]] + + def clear(self): + """Clear the scanner from the _instances""" + if id(self) in _instances: + del _instances[id(self)] + + def __del__(self): + self.clear() diff --git a/.venv/lib/python3.11/site-packages/ray/dag/utils.py b/.venv/lib/python3.11/site-packages/ray/dag/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ce96b3c27a8aabc060901e75f4a675c943fa18c1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dag/utils.py @@ -0,0 +1,66 @@ +from typing import Dict + +from ray.dag import ( + DAGNode, + InputNode, + InputAttributeNode, + FunctionNode, + ClassNode, + ClassMethodNode, + MultiOutputNode, +) + + +class _DAGNodeNameGenerator(object): + """ + Generate unique suffix for each given Node in the DAG. + Apply monotonic increasing id suffix for duplicated names. + """ + + def __init__(self): + self.name_to_suffix: Dict[str, int] = dict() + + def get_node_name(self, node: DAGNode): + # InputNode should be unique. + if isinstance(node, InputNode): + return "INPUT_NODE" + if isinstance(node, MultiOutputNode): + return "MultiOutputNode" + # InputAttributeNode suffixes should match the user-defined key. + elif isinstance(node, InputAttributeNode): + return f"INPUT_ATTRIBUTE_NODE_{node._key}" + + # As class, method, and function nodes may have duplicated names, + # generate unique suffixes for such nodes. + if isinstance(node, ClassMethodNode): + node_name = node.get_options().get("name", None) or node._method_name + elif isinstance(node, (ClassNode, FunctionNode)): + node_name = node.get_options().get("name", None) or node._body.__name__ + # we use instance class name check here to avoid importing ServeNodes as + # serve components are not included in Ray Core. 
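+ # For example, DeploymentNode is defined in ray.serve, which may not be + # installed; matching on the class name avoids that import at runtime.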
+ elif type(node).__name__ in ("DeploymentNode", "DeploymentFunctionNode"): + node_name = node.get_deployment_name() + elif type(node).__name__ == "DeploymentFunctionExecutorNode": + node_name = node._deployment_function_handle.deployment_name + else: + raise ValueError( + "get_node_name() should only be called on DAGNode instances." + ) + + if node_name not in self.name_to_suffix: + self.name_to_suffix[node_name] = 0 + return node_name + else: + self.name_to_suffix[node_name] += 1 + suffix_num = self.name_to_suffix[node_name] + + return f"{node_name}_{suffix_num}" + + def reset(self): + self.name_to_suffix = dict() + + def __enter__(self): + return self + + def __exit__(self, *args): + self.reset() diff --git a/.venv/lib/python3.11/site-packages/ray/dag/vis_utils.py b/.venv/lib/python3.11/site-packages/ray/dag/vis_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0f83257295a10b8c3507dd3ec77a861ada7cf814 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dag/vis_utils.py @@ -0,0 +1,115 @@ +from ray.dag import DAGNode + +import os +import tempfile + +from ray.dag.utils import _DAGNodeNameGenerator +from ray.util.annotations import DeveloperAPI + + +@DeveloperAPI +def plot(dag: DAGNode, to_file=None): + if to_file is None: + tmp_file = tempfile.NamedTemporaryFile(suffix=".png") + to_file = tmp_file.name + extension = "png" + else: + _, extension = os.path.splitext(to_file) + if not extension: + extension = "png" + else: + extension = extension[1:] + + graph = _dag_to_dot(dag) + graph.write(to_file, format=extension) + + # Render the image directly if running inside a Jupyter notebook + try: + from IPython import display + + return display.Image(filename=to_file) + except ImportError: + pass + + # close temp file if needed + try: + tmp_file.close() + except NameError: + pass + + +def _check_pydot_and_graphviz(): + """Check if pydot and graphviz are installed. + + pydot and graphviz are required for plotting. We check this + at runtime rather than adding them to Ray dependencies. + + """ + try: + import pydot + except ImportError: + raise ImportError( + "pydot is required to plot the DAG, " + "install it with `pip install pydot`." + ) + try: + pydot.Dot.create(pydot.Dot()) + except (OSError, pydot.InvocationException): + raise ImportError( + "graphviz is required to plot the DAG, " + "download it from https://graphviz.gitlab.io/download/" + ) + + +def _get_nodes_and_edges(dag: DAGNode): + """Get all unique nodes and edges in the DAG. + + A basic DFS with memoization to get all unique nodes + and edges in the DAG. + Unique nodes will be used to generate unique names, + while edges will be used to construct the graph. + """ + + edges = [] + nodes = [] + + def _dfs(node): + nodes.append(node) + for child_node in node._get_all_child_nodes(): + edges.append((child_node, node)) + return node + + dag.apply_recursive(_dfs) + return nodes, edges + + +def _dag_to_dot(dag: DAGNode): + """Create a Dot graph from dag. + + TODO(lchu): + 1. add more Dot configs in kwargs, + e.g. rankdir, alignment, etc. + 2. add more contents to graph, + e.g.
args, kwargs and options of each node + + """ + # Step 0: check dependencies and init graph + _check_pydot_and_graphviz() + import pydot + + graph = pydot.Dot(rankdir="LR") + + # Step 1: generate unique name for each node in dag + nodes, edges = _get_nodes_and_edges(dag) + name_generator = _DAGNodeNameGenerator() + node_names = {} + for node in nodes: + node_names[node] = name_generator.get_node_name(node) + + # Step 2: create graph with all the edges + for edge in edges: + graph.add_edge(pydot.Edge(node_names[edge[0]], node_names[edge[1]])) + # if there is only one node + if len(nodes) == 1 and len(edges) == 0: + graph.add_node(pydot.Node(node_names[nodes[0]])) + + return graph diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/channel/__init__.py b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2683c40984342b629eecc4135995d32826204723 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__init__.py @@ -0,0 +1,39 @@ +from ray.experimental.channel.cached_channel import CachedChannel +from ray.experimental.channel.common import ( # noqa: F401 + AwaitableBackgroundReader, + AwaitableBackgroundWriter, + ChannelContext, + ChannelInterface, + ChannelOutputType, + CompiledDAGArgs, + ReaderInterface, + SynchronousReader, + SynchronousWriter, + WriterInterface, +) +from ray.experimental.channel.communicator import Communicator +from ray.experimental.channel.intra_process_channel import IntraProcessChannel +from ray.experimental.channel.shared_memory_channel import ( + BufferedSharedMemoryChannel, + Channel, + CompositeChannel, +) +from ray.experimental.channel.torch_tensor_nccl_channel import TorchTensorNcclChannel + +__all__ = [ + "AwaitableBackgroundReader", + "AwaitableBackgroundWriter", + "CachedChannel", + "Channel", + "Communicator", + "ReaderInterface", + "SynchronousReader", + "SynchronousWriter", + "WriterInterface", + "ChannelContext", + "TorchTensorNcclChannel", + "IntraProcessChannel", + "CompositeChannel", + "BufferedSharedMemoryChannel", + "CompiledDAGArgs", +] diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c888f8dc4e9a3042b1ab0a5a736ab6dd701b6b0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/auto_transport_type.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/auto_transport_type.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c2667f2f5fe2d346860d2f31993059cd961c3c2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/auto_transport_type.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/cached_channel.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/cached_channel.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..761d519cfdf06fc3248dbd8973593fb63a287fe5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/cached_channel.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/common.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/common.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6bf01788bf9c6f994182368f2f7894713a5198e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/common.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/communicator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/communicator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1487885d850ed485b45c94712000f7eed54b5bd6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/communicator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/conftest.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/conftest.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07e06c2d7fe8c756aa69be1152d2d5dcdcaa9374 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/conftest.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/cpu_communicator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/cpu_communicator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9c037cff4a3aeaa5008c14515a13b2f60ff4932 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/cpu_communicator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/intra_process_channel.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/intra_process_channel.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ff531599c59ddb4d82294da54d06743d7afc706 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/intra_process_channel.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/nccl_group.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/nccl_group.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff02eada9612b5e8a85608e41cffd5792f56172c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/nccl_group.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/serialization_context.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/serialization_context.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5de392f4f9c4f3f184359d1a0b3ce4a1ea4ef318 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/serialization_context.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/torch_tensor_type.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/torch_tensor_type.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..97946124e225e61e52f4abecc773e687e15bc46a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/torch_tensor_type.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7d42bd49963591edfd60c5dcc446cfc040b839e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/channel/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/channel/common.py b/.venv/lib/python3.11/site-packages/ray/experimental/channel/common.py new file mode 100644 index 0000000000000000000000000000000000000000..c395422d5daef00f506ac203d9bbb7328a3ce9e5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/experimental/channel/common.py @@ -0,0 +1,683 @@ +import asyncio +import concurrent +import sys +import threading +import time +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + NamedTuple, + Optional, + Tuple, + Union, +) + +import ray +import ray.exceptions +from ray.experimental.channel.communicator import Communicator +from ray.experimental.channel.serialization_context import _SerializationContext +from ray.util.annotations import DeveloperAPI, PublicAPI + +# The context singleton on this process. +_default_context: "Optional[ChannelContext]" = None +_context_lock = threading.Lock() + +if TYPE_CHECKING: + import torch + + +def retry_and_check_interpreter_exit(f: Callable[[], None]) -> bool: + """This function is only useful when f contains channel read/write. + + Keep retrying the channel read/write inside `f` and check whether the + interpreter is exiting. This is important when the read/write happens in a + separate thread pool. + See https://github.com/ray-project/ray/pull/47702 + + f should be a function that takes no input and returns nothing. + """ + exiting = False + while True: + try: + f() + break + except ray.exceptions.RayChannelTimeoutError: + if sys.is_finalizing(): + # The interpreter is exiting. We should ignore the error and + # stop reading so that the thread can join. + exiting = True + break + + return exiting + + +# Holds the input arguments for Compiled Graph +@PublicAPI(stability="alpha") +class CompiledDAGArgs(NamedTuple): + args: Tuple[Any, ...] + kwargs: Dict[str, Any] + + +@PublicAPI(stability="alpha") +class ChannelOutputType: + def register_custom_serializer(self) -> None: + """ + Register any custom serializers needed to pass data of this type. This + method should be run on the reader(s) and writer of a channel, which + are the driver and/or Ray actors. + + NOTE: When custom serializers are registered with Ray, the registered + deserializer is shipped with the serialized value and used on the + receiving end. Therefore, the deserializer function should *not* + capture state that is meant to be worker-local, such as the worker's + default device. Instead, these should be extracted from the + worker-local _SerializationContext. + """ + pass + + def create_channel( + self, + writer: Optional["ray.actor.ActorHandle"], + reader_and_node_list: List[Tuple["ray.actor.ActorHandle", str]], + driver_actor_id: Optional[str] = None, + ) -> "ChannelInterface": + """ + Instantiate a ChannelInterface class that can be used + to pass data of this type.
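+ + For illustration, a minimal sketch (assumes `writer` and `reader` are + existing actor handles, `node_id` is the reader's node ID, and + `SomeChannelOutputType` is a hypothetical concrete subclass): + + .. code-block:: python + + typ = SomeChannelOutputType() + channel = typ.create_channel(writer, [(reader, node_id)], None)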
+ + Args: + writer: The actor that may write to the channel. None signifies the driver. + reader_and_node_list: A list of tuples, where each tuple contains a reader + actor handle and the node ID where the actor is located. + driver_actor_id: If this is a CompositeChannel that is read by a driver and + that driver is an actual actor, this will be the actor ID of that + driver actor. + Returns: + A ChannelInterface that can be used to pass data + of this type. + """ + raise NotImplementedError + + def requires_nccl(self) -> bool: + # By default, channels do not require NCCL. + return False + + def get_custom_communicator(self) -> Optional[Communicator]: + """ + Return the custom NCCL group if one is specified. + """ + if self._contains_type is not None: + return self._contains_type.get_custom_nccl_group() + return None + + def set_communicator_id(self, group_id: str) -> None: + raise NotImplementedError + + +@DeveloperAPI +@dataclass +class ChannelContext: + serialization_context = _SerializationContext() + _torch_available: Optional[bool] = None + _torch_device: Optional["torch.device"] = None + _current_stream: Optional["torch.cuda.Stream"] = None + + def __init__(self): + # Used for the torch.Tensor NCCL transport. + self.communicators: Dict[str, "Communicator"] = {} + + @staticmethod + def get_current() -> "ChannelContext": + """Get or create a singleton context. + + If the context has not yet been created in this process, it will be + initialized with default settings. + """ + + global _default_context + + with _context_lock: + if _default_context is None: + _default_context = ChannelContext() + + return _default_context + + @property + def torch_available(self) -> bool: + """ + Check if the torch package is available. + """ + if self._torch_available is not None: + return self._torch_available + + try: + import torch # noqa: F401 + except ImportError: + self._torch_available = False + return False + self._torch_available = True + return True + + @property + def torch_device(self) -> "torch.device": + if self._torch_device is None: + + if not ray.get_gpu_ids(): + import torch + + # torch_utils defaults to returning GPU 0 if no GPU IDs were assigned + # by Ray. We instead want the default to be CPU. + self._torch_device = torch.device("cpu") + else: + from ray.air._internal import torch_utils + + self._torch_device = torch_utils.get_devices()[0] + + return self._torch_device + + def set_torch_device(self, device: "torch.device"): + self._torch_device = device + + +@PublicAPI(stability="alpha") +class ChannelInterface: + """ + Abstraction for a transport between a writer actor and some number of + reader actors. + """ + + def __init__( + self, + writer: Optional[ray.actor.ActorHandle], + readers: List[Optional[ray.actor.ActorHandle]], + typ: Optional["ChannelOutputType"], + ): + """ + Create a channel that can be read and written by a Ray driver or actor. + + Args: + writer: The actor that may write to the channel. None signifies the driver. + readers: The actors that may read from the channel. None signifies + the driver. + typ: Type information about the values passed through the channel. + """ + pass + + def ensure_registered_as_writer(self): + """ + Check whether the process is a valid writer. This method must be idempotent. + """ + raise NotImplementedError + + def ensure_registered_as_reader(self): + """ + Check whether the process is a valid reader. This method must be idempotent.
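+ + For illustration, a typical read-side calling pattern (a hedged sketch; + `chan` stands for a hypothetical concrete channel implementation): + + .. code-block:: python + + chan.ensure_registered_as_reader() + value = chan.read(timeout=10)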
+ """ + raise NotImplementedError + + def write(self, value: Any, timeout: Optional[float] = None) -> None: + """ + Write a value to the channel. + + Blocks if there are still pending readers for the previous value. The + writer may not write again until the specified number of readers have + read the value. + + Args: + value: The value to write. + timeout: The maximum time in seconds to wait to write the value. + None means using default timeout, 0 means immediate timeout + (immediate success or timeout without blocking), -1 means + infinite timeout (block indefinitely). + """ + raise NotImplementedError + + def read(self, timeout: Optional[float] = None) -> Any: + """ + Read the latest value from the channel. This call will block until a + value is available to read. + + Subsequent calls to read() may *block* if the deserialized object is + zero-copy (e.g., bytes or a numpy array) *and* the object is still in scope. + + Args: + timeout: The maximum time in seconds to wait to read the value. + None means using default timeout, 0 means immediate timeout + (immediate success or timeout without blocking), -1 means + infinite timeout (block indefinitely). + + Returns: + Any: The deserialized value. If the deserialized value is an + Exception, it will be returned directly instead of being raised. + """ + raise NotImplementedError + + def close(self) -> None: + """ + Close this channel. This method must not block and it must be made + idempotent. Any existing values in the channel may be lost after the + channel is closed. + """ + raise NotImplementedError + + +# Interfaces for channel I/O. +@DeveloperAPI +class ReaderInterface: + def __init__( + self, + input_channels: List[ChannelInterface], + ): + assert isinstance(input_channels, list) + for chan in input_channels: + assert isinstance(chan, ChannelInterface) + + self._input_channels = input_channels + self._closed = False + self._num_reads = 0 + + # A list of channels that were not read in the last `read` call + # because the reader returned immediately when a RayTaskError was found. + # These channels must be consumed before the next read to avoid reading + # stale data remaining from the last read. + self._leftover_channels: List[ChannelInterface] = [] + + def get_num_reads(self) -> int: + return self._num_reads + + def start(self): + raise NotImplementedError + + def _read_list(self, timeout: Optional[float] = None) -> List[Any]: + """ + Read a list of values from this reader. + + Args: + timeout: The maximum time in seconds to wait for reading. + None means using default timeout which is infinite, 0 means immediate + timeout (immediate success or timeout without blocking), -1 means + infinite timeout (block indefinitely). + + """ + raise NotImplementedError + + def read(self, timeout: Optional[float] = None) -> List[Any]: + """ + Read from this reader. + + Args: + timeout: The maximum time in seconds to wait for reading. + None means using default timeout, 0 means immediate timeout + (immediate success or timeout without blocking), -1 means + infinite timeout (block indefinitely). + """ + assert ( + timeout is None or timeout >= 0 or timeout == -1 + ), "Timeout must be non-negative or -1." 
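To make the None / 0 / -1 timeout convention concrete, here is a toy, single-process channel (illustrative only, not part of Ray) that maps None and -1 to a blocking wait and 0 to a non-blocking attempt; a real implementation would raise RayChannelTimeoutError where this one raises queue.Full or queue.Empty:

import queue
from typing import Any, Optional

class ToyChannel:
    def __init__(self) -> None:
        self._q: "queue.Queue[Any]" = queue.Queue(maxsize=1)

    def write(self, value: Any, timeout: Optional[float] = None) -> None:
        block = timeout != 0
        t = None if timeout in (None, -1) else timeout
        self._q.put(value, block=block, timeout=t)  # queue.Full on timeout

    def read(self, timeout: Optional[float] = None) -> Any:
        block = timeout != 0
        t = None if timeout in (None, -1) else timeout
        return self._q.get(block=block, timeout=t)  # queue.Empty on timeout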
+ outputs = self._read_list(timeout) + self._num_reads += 1 + return outputs + + def close(self) -> None: + self._closed = True + for channel in self._input_channels: + channel.close() + + def _consume_leftover_channels_if_needed( + self, timeout: Optional[float] = None + ) -> None: + # Consume the channels that were not read in the last `read` call because a + # RayTaskError was returned from another channel. If we don't do this, the + # read operation will read stale versions of the object refs. + # + # If a RayTaskError is returned from a leftover channel, it will be ignored. + # If a read operation times out, a RayChannelTimeoutError exception will be + # raised. + # + # TODO(kevin85421): Currently, a DAG with NCCL channels and fast fail enabled + # may not be reusable. Revisit this in the future. + for c in self._leftover_channels: + start_time = time.monotonic() + c.read(timeout) + if timeout is not None: + timeout -= time.monotonic() - start_time + timeout = max(timeout, 0) + self._leftover_channels = [] + + +@DeveloperAPI +class SynchronousReader(ReaderInterface): + def __init__( + self, + input_channels: List[ChannelInterface], + ): + super().__init__(input_channels) + + def start(self): + pass + + def _read_list(self, timeout: Optional[float] = None) -> List[Any]: + self._consume_leftover_channels_if_needed(timeout) + # We don't update `remaining_timeout` here because in the worst case, + # consuming leftover channels requires reading all `_input_channels`, + # which users expect to complete within the original `timeout`. Updating + # `remaining_timeout` could cause unexpected timeouts in subsequent read + # operations. + + # It is a special case that `timeout` is set to 0, which means + # read once for each channel. + is_zero_timeout = timeout == 0 + + results = [None for _ in range(len(self._input_channels))] + if timeout is None or timeout == -1: + timeout = float("inf") + timeout_point = time.monotonic() + timeout + remaining_timeout = timeout + + from ray.dag import DAGContext + + ctx = DAGContext.get_current() + iteration_timeout = ctx.read_iteration_timeout + + # Iterate over the input channels with a shorter timeout for each iteration + # to detect RayTaskError early and fail fast. + done_channels = set() + while len(done_channels) < len(self._input_channels): + for i, c in enumerate(self._input_channels): + if c in done_channels: + continue + try: + result = c.read(min(remaining_timeout, iteration_timeout)) + results[i] = result + done_channels.add(c) + if isinstance(result, ray.exceptions.RayTaskError): + # If we raise an exception immediately, it will be considered + # as a system error which will cause the execution loop to + # exit. Hence, return immediately and let `_process_return_vals` + # handle the exception. + # + # Return a list of RayTaskError so that the caller will not + # get an undefined partial result. 
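The timeout bookkeeping in `_read_list` follows a standard deadline pattern: convert the user-facing timeout into an absolute `time.monotonic()` deadline once, then derive each per-channel budget from it, capped by the short per-iteration timeout that enables fast failure. A standalone sketch of the same arithmetic (names are illustrative):

import time

def deadline_from(timeout):
    # None and -1 both mean "wait forever".
    if timeout is None or timeout == -1:
        timeout = float("inf")
    return time.monotonic() + timeout

def remaining(deadline):
    # Clamped at 0 so downstream calls see "try once, don't block".
    return max(deadline - time.monotonic(), 0)

deadline = deadline_from(5.0)
per_call_budget = min(remaining(deadline), 0.1)  # e.g. 100 ms iterations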
+ self._leftover_channels = [ + c for c in self._input_channels if c not in done_channels + ] + return [result for _ in range(len(self._input_channels))] + except ray.exceptions.RayChannelTimeoutError as e: + remaining_timeout = max(timeout_point - time.monotonic(), 0) + if remaining_timeout == 0: + raise e + continue + + remaining_timeout = max(timeout_point - time.monotonic(), 0) + if remaining_timeout == 0 and not is_zero_timeout: + raise ray.exceptions.RayChannelTimeoutError( + f"Cannot read all channels within {timeout} seconds" + ) + return results + + def release_channel_buffers(self, timeout: Optional[float] = None) -> None: + for c in self._input_channels: + start_time = time.monotonic() + c.release_buffer(timeout) + if timeout is not None: + timeout -= time.monotonic() - start_time + timeout = max(timeout, 0) + + +@DeveloperAPI +class AwaitableBackgroundReader(ReaderInterface): + """ + Asyncio-compatible channel reader. + + The reader is constructed with an async queue of futures whose values it + will fulfill. It uses a threadpool to execute the blocking calls to read + from the input channel(s). + """ + + def __init__( + self, + input_channels: List[ChannelInterface], + fut_queue: asyncio.Queue, + ): + super().__init__(input_channels) + self._fut_queue = fut_queue + self._background_task = None + self._background_task_executor = concurrent.futures.ThreadPoolExecutor( + max_workers=1, thread_name_prefix="channel.AwaitableBackgroundReader" + ) + + def start(self): + self._background_task = asyncio.ensure_future(self.run()) + + def _run(self): + # Give it a default timeout 60 seconds to release the buffers + # of the channels that were not read in the last `read` call. + self._consume_leftover_channels_if_needed(60) + + results = [None for _ in range(len(self._input_channels))] + + from ray.dag import DAGContext + + ctx = DAGContext.get_current() + iteration_timeout = ctx.read_iteration_timeout + + done_channels = set() + while len(done_channels) < len(self._input_channels): + for i, c in enumerate(self._input_channels): + if c in done_channels: + continue + try: + result = c.read(iteration_timeout) + results[i] = result + done_channels.add(c) + if isinstance(result, ray.exceptions.RayTaskError): + self._leftover_channels = [ + c for c in self._input_channels if c not in done_channels + ] + return [result for _ in range(len(self._input_channels))] + except ray.exceptions.RayChannelTimeoutError: + pass + if sys.is_finalizing(): + return results + return results + + async def run(self): + loop = asyncio.get_running_loop() + while not self._closed: + res, fut = await asyncio.gather( + loop.run_in_executor(self._background_task_executor, self._run), + self._fut_queue.get(), + return_exceptions=True, + ) + + # Set the result on the main thread. + fut.set_result(res) + # NOTE(swang): If the object is zero-copy deserialized, then it + # will stay in scope as long as ret and the future are in scope. + # Therefore, we must delete both here after fulfilling the future. + del res + del fut + + def close(self): + super().close() + self._background_task_executor.shutdown(cancel_futures=True) + self._background_task.cancel() + + +@DeveloperAPI +class WriterInterface: + def __init__( + self, + output_channels: List[ChannelInterface], + output_idxs: List[Optional[Union[int, str]]], + is_input=False, + ): + """ + Initialize the writer. + + Args: + output_channels: The output channels to write to. + output_idxs: The indices of the values to write to each channel. 
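The `AwaitableBackgroundReader.run` loop above bridges blocking channel reads into asyncio: it runs the read on a single-thread executor and fulfills a pre-created future. A self-contained sketch of that bridge, with a plain blocking function standing in for the channel read:

import asyncio
import concurrent.futures
import time

def blocking_read():
    time.sleep(0.1)  # stand-in for a blocking channel read
    return "value"

async def main():
    loop = asyncio.get_running_loop()
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    fut_queue: asyncio.Queue = asyncio.Queue()
    await fut_queue.put(loop.create_future())

    # Same shape as run(): gather the blocking result and the future it
    # should fulfill, then set the result on the event-loop thread.
    res, fut = await asyncio.gather(
        loop.run_in_executor(pool, blocking_read), fut_queue.get()
    )
    fut.set_result(res)
    print(await fut)
    pool.shutdown()

asyncio.run(main())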
+ This has the same length as `output_channels`. If `is_input` is True, + the index can be an integer or a string to retrieve the corresponding + value from `args` or `kwargs` in the DAG's input. If `is_input` + is False, the entire value is written if the index is None. Otherwise, + the value at the specified index in the tuple is written. + is_input: Whether the writer is DAG input writer or not. + """ + + assert len(output_channels) == len(output_idxs) + self._output_channels = output_channels + self._output_idxs = output_idxs + self._closed = False + self._num_writes = 0 + self._is_input = is_input + + def get_num_writes(self) -> int: + return self._num_writes + + def start(self): + raise NotImplementedError() + + def write(self, val: Any, timeout: Optional[float] = None) -> None: + """ + Write the value. + + Args: + timeout: The maximum time in seconds to wait for writing. 0 means + immediate timeout (immediate success or timeout without blocking). + -1 and None mean infinite timeout (blocks indefinitely). + """ + raise NotImplementedError() + + def close(self) -> None: + self._closed = True + for channel in self._output_channels: + channel.close() + + +def _adapt(raw_args: Any, key: Optional[Union[int, str]], is_input: bool): + """ + Adapt the raw arguments to the key. If `is_input` is True, this method will + retrieve the value from the input data for an InputAttributeNode. Otherwise, it + will retrieve either a partial value or the entire value from the output of + a ClassMethodNode. + + Args: + raw_args: The raw arguments to adapt. + key: The key to adapt. + is_input: Whether the writer is DAG input writer or not. + """ + if is_input: + if not isinstance(raw_args, CompiledDAGArgs): + # Fast path for a single input. + return raw_args + else: + args = raw_args.args + kwargs = raw_args.kwargs + + if isinstance(key, int): + return args[key] + else: + return kwargs[key] + else: + if key is not None: + return raw_args[key] + else: + return raw_args + + +@DeveloperAPI +class SynchronousWriter(WriterInterface): + def start(self): + for channel in self._output_channels: + channel.ensure_registered_as_writer() + + def write(self, val: Any, timeout: Optional[float] = None) -> None: + # If it is an exception, there's only 1 return value. + # We have to send the same data to all channels. 
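A quick worked example of `_adapt`'s routing rules (importing the private helper purely for illustration; the values are made up):

from ray.experimental.channel.common import CompiledDAGArgs, _adapt

# DAG input writer: pick out of args / kwargs by position or name.
raw = CompiledDAGArgs(args=(10, 20), kwargs={"x": 30})
assert _adapt(raw, 1, is_input=True) == 20    # integer key -> args[1]
assert _adapt(raw, "x", is_input=True) == 30  # string key -> kwargs["x"]
assert _adapt(5, None, is_input=True) == 5    # fast path: single input

# Task output writer: None means "send the entire value".
assert _adapt((1, 2, 3), 2, is_input=False) == 3
assert _adapt("whole", None, is_input=False) == "whole"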
+ if isinstance(val, Exception): + if len(self._output_channels) > 1: + val = tuple(val for _ in range(len(self._output_channels))) + + if not self._is_input: + if len(self._output_channels) > 1: + if not isinstance(val, tuple): + raise ValueError( + f"Expected a tuple of {len(self._output_channels)} outputs, " + f"but got {type(val)}" + ) + if len(val) != len(self._output_channels): + raise ValueError( + f"Expected {len(self._output_channels)} outputs, but got " + f"{len(val)} outputs" + ) + + for i, channel in enumerate(self._output_channels): + idx = self._output_idxs[i] + val_i = _adapt(val, idx, self._is_input) + channel.write(val_i, timeout) + self._num_writes += 1 + + +@DeveloperAPI +class AwaitableBackgroundWriter(WriterInterface): + def __init__( + self, + output_channels: List[ChannelInterface], + output_idxs: List[Optional[Union[int, str]]], + is_input=False, + ): + super().__init__(output_channels, output_idxs, is_input=is_input) + self._queue = asyncio.Queue() + self._background_task = None + self._background_task_executor = concurrent.futures.ThreadPoolExecutor( + max_workers=1, thread_name_prefix="channel.AwaitableBackgroundWriter" + ) + + def start(self): + for channel in self._output_channels: + channel.ensure_registered_as_writer() + self._background_task = asyncio.ensure_future(self.run()) + + def _run(self, res): + if not self._is_input: + if len(self._output_channels) > 1: + if not isinstance(res, tuple): + raise ValueError( + f"Expected a tuple of {len(self._output_channels)} outputs, " + f"but got {type(res)}" + ) + if len(res) != len(self._output_channels): + raise ValueError( + f"Expected {len(self._output_channels)} outputs, but got " + f"{len(res)} outputs" + ) + + for i, channel in enumerate(self._output_channels): + idx = self._output_idxs[i] + res_i = _adapt(res, idx, self._is_input) + exiting = retry_and_check_interpreter_exit( + lambda: channel.write(res_i, timeout=1) + ) + if exiting: + break + + async def run(self): + loop = asyncio.get_event_loop() + while True: + res = await self._queue.get() + await loop.run_in_executor(self._background_task_executor, self._run, res) + + async def write(self, val: Any) -> None: + if self._closed: + raise RuntimeError("DAG execution cancelled") + await self._queue.put(val) + self._num_writes += 1 + + def close(self): + self._background_task.cancel() + super().close() diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/channel/communicator.py b/.venv/lib/python3.11/site-packages/ray/experimental/channel/communicator.py new file mode 100644 index 0000000000000000000000000000000000000000..587f06256dcbec6936ecf99333e23c31a6652e84 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/experimental/channel/communicator.py @@ -0,0 +1,158 @@ +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Callable, List, Optional, Tuple + +import ray +from ray.experimental.util.types import ReduceOp +from ray.util.annotations import DeveloperAPI + +if TYPE_CHECKING: + import cupy as cp + import torch + + +# Signature for a torch.Tensor allocator is: +# (shape: Tuple[int], dtype: torch.dtype) -> torch.Tensor. +TorchTensorAllocator = Callable[[Tuple[int], "torch.dtype"], "torch.Tensor"] + + +@DeveloperAPI +class Communicator(ABC): + """ + Communicator for a group of Compiled Graph actors on Nvidia GPU. + + The Compiled Graph execution leverages this internally to support communication + between actors in the group. 
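The tuple-shape check that `SynchronousWriter.write` and `AwaitableBackgroundWriter._run` share is worth stating on its own: with more than one output channel, a non-input writer requires a tuple with exactly one element per channel. A condensed sketch of the same validation:

from typing import Any

def validate_outputs(val: Any, num_channels: int) -> None:
    # Mirrors the writer classes above: only multi-channel, non-input
    # writers insist on a tuple with one element per output channel.
    if num_channels > 1:
        if not isinstance(val, tuple):
            raise ValueError(
                f"Expected a tuple of {num_channels} outputs, but got {type(val)}"
            )
        if len(val) != num_channels:
            raise ValueError(
                f"Expected {num_channels} outputs, but got {len(val)} outputs"
            )

validate_outputs((1, 2, 3), 3)  # ok
# validate_outputs((1, 2), 3) would raise ValueError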
+ """ + + @abstractmethod + def initialize(self, rank: int) -> None: + """ + Initialize the communicator from the actor. + + This is called once by Compiled Graph on each actor to initialize the + communicator, before any other methods. + + Args: + rank: The rank of this actor in the group. + """ + raise NotImplementedError + + @abstractmethod + def get_actor_handles(self) -> List["ray.actor.ActorHandle"]: + """ + Get handles of all actors for this communicator group. + """ + raise NotImplementedError + + @abstractmethod + def get_rank(self, actor: ray.actor.ActorHandle) -> int: + """ + Return the given actor's rank in the group. + + Args: + actor: The actor handle to look up. + """ + raise NotImplementedError + + @abstractmethod + def get_self_rank(self) -> Optional[int]: + """ + Return this actor's rank. + """ + raise NotImplementedError + + @abstractmethod + def get_world_size(self) -> int: + """ + Return the number of ranks in the group. + """ + raise NotImplementedError + + @abstractmethod + def send(self, value: "torch.Tensor", peer_rank: int) -> None: + """ + Send a torch.Tensor to a peer. + + This returns when the send kernel has been queued, but the kernel may + not have completed. Therefore, the caller should ensure that there are + no concurrent writes to the sent `value` until the send has finished. + + Args: + value: The torch.Tensor to send. It should already be on this + actor's default device. + peer_rank: The rank of the actor to send to. + """ + raise NotImplementedError + + @abstractmethod + def recv( + self, + shape: Tuple[int], + dtype: "torch.dtype", + peer_rank: int, + allocator: Optional[TorchTensorAllocator] = None, + ) -> "torch.Tensor": + """ + Receive a torch.Tensor from a peer and synchronize. + + After this call returns, the receive buffer is safe to read from any + stream. A RayChannelError will be raised if an error occurred (e.g., the + remote actor died), in which case the buffer is not safe to read. + + Args: + shape: The shape of the tensor to receive. + dtype: The dtype of the tensor to receive. + peer_rank: The rank of the actor to receive from. + allocator: A function to allocate the tensor to receive into. + """ + raise NotImplementedError + + @property + @abstractmethod + def recv_stream(self) -> Optional["cp.cuda.ExternalStream"]: + """ + Return the cuda stream used for receiving tensors. + """ + raise NotImplementedError + + @property + @abstractmethod + def send_stream(self) -> Optional["cp.cuda.ExternalStream"]: + """ + Return the cuda stream used for sending tensors. + """ + raise NotImplementedError + + @abstractmethod + def allreduce( + self, + send_buf: "torch.Tensor", + recv_buf: "torch.Tensor", + op: ReduceOp, + ) -> None: + """ + Collectively allreduce the tensor across the group. + + Args: + send_buf: The input torch.Tensor to allreduce. It should already be + on this actor's default device. + recv_buf: The output torch.Tensor to store the allreduce result. + op: The reduce operation. + """ + raise NotImplementedError + + @abstractmethod + def destroy(self) -> None: + """ + Destroy the GPU communicator. + + Any destruction and cleanup for the GPU communicator should be + done here. Implement as a no-op if nothing is needed. + """ + raise NotImplementedError + + @abstractmethod + def get_transport_name(self) -> str: + """ + Return the type of the communicator (gpu or cpu).
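A `TorchTensorAllocator` is just a two-argument callable matching the signature declared above; the simplest conforming allocator (assuming torch is installed) is:

import torch

def zeros_allocator(shape, dtype):
    # (shape, dtype) -> torch.Tensor, as TorchTensorAllocator requires.
    return torch.zeros(shape, dtype=dtype)

buf = zeros_allocator((2, 3), torch.float32)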
+ """ + raise NotImplementedError diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/channel/conftest.py b/.venv/lib/python3.11/site-packages/ray/experimental/channel/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..2349d334f71e45197996680ab7fe506ee302c995 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/experimental/channel/conftest.py @@ -0,0 +1,164 @@ +import asyncio +from collections import defaultdict +from typing import Optional, Tuple +from unittest import mock + +import torch + +import ray +import ray.dag +import ray.experimental.channel as ray_channel +from ray.experimental.channel.communicator import TorchTensorAllocator + + +@ray.remote(num_cpus=0) +class Barrier: + """ + Barrier that blocks the given number of actors until all actors have + reached the barrier. This is used to mock out blocking NCCL ops. + """ + + def __init__(self, num_actors=2): + self.num_actors = num_actors + self.condition = asyncio.Condition() + # Buffer for the data that is "sent" between the actors, each entry is + # one p2p op. + self.data = {} + # Buffer for the number of actors seen, each entry is one p2p op. + self.num_actors_seen = defaultdict(int) + + async def wait(self, idx: int, data=None): + """ + Wait at barrier until all actors have sent `idx`. One actor should + provide `data`, and this value will be returned by this method for all + other actors. + """ + async with self.condition: + if data is not None: + assert idx not in self.data, (self.data, self.num_actors_seen) + self.data[idx] = data + self.num_actors_seen[idx] += 1 + + if self.num_actors_seen[idx] == self.num_actors: + # Wake up all tasks waiting on this condition. + self.condition.notify_all() + else: + await self.condition.wait_for( + lambda: self.num_actors_seen[idx] == self.num_actors + ) + + if data is None: + data = self.data[idx] + + return data + + +class MockCudaStream: + def __init__(self): + self.cuda_stream = 0 + + +class MockNcclGroup(ray_channel.nccl_group._NcclGroup): + """ + Mock the internal _NcclGroup to use a barrier actor instead of a NCCL group + for communication. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # We use the op index to synchronize the sender and receiver at the + # barrier. + self.num_ops = defaultdict(int) + self.barriers = set() + + def send(self, tensor: torch.Tensor, peer_rank: int): + # "Send" the tensor to the barrier actor. + barrier_key = sorted([self.get_self_rank(), peer_rank]) + barrier_key = f"barrier-{barrier_key[0]}-{barrier_key[1]}" + barrier = ray.get_actor(name=barrier_key) + self.barriers.add(barrier) + ray.get(barrier.wait.remote(self.num_ops[barrier_key], tensor)) + self.num_ops[barrier_key] += 1 + + def recv( + self, + shape: Tuple[int], + dtype: torch.dtype, + peer_rank: int, + allocator: Optional[TorchTensorAllocator] = None, + ): + # "Receive" the tensor from the barrier actor. 
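To see the rendezvous shape of the `Barrier` actor defined above: with `num_actors=2`, one caller supplies `data` for a given op index and the other receives it. A sketch, assuming a local `ray.init()` and that this conftest module is importable:

import ray
from ray.experimental.channel.conftest import Barrier

ray.init()
barrier = Barrier.options(name="barrier-0-1").remote(num_actors=2)

# Both sides rendezvous on op index 0; the receiver gets the sender's value.
send_ref = barrier.wait.remote(0, data=b"tensor-bytes")
recv_ref = barrier.wait.remote(0)
assert ray.get(recv_ref) == b"tensor-bytes"
ray.kill(barrier)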
+ barrier_key = sorted([self.get_self_rank(), peer_rank]) + barrier_key = f"barrier-{barrier_key[0]}-{barrier_key[1]}" + barrier = ray.get_actor(name=barrier_key) + self.barriers.add(barrier) + received_tensor = ray.get(barrier.wait.remote(self.num_ops[barrier_key])) + assert ( + allocator is not None + ), "torch tensor allocator is required for MockNcclGroup" + buf = allocator(shape, dtype) + buf[:] = received_tensor[:] + self.num_ops[barrier_key] += 1 + return buf + + def destroy(self) -> None: + for barrier in self.barriers: + ray.kill(barrier) + + +def start_nccl_mock(): + """ + Patch methods that require CUDA. + """ + # Mock cupy dependencies. + nccl_mock = mock.MagicMock() + nccl_mock.nccl.get_unique_id.return_value = 0 + cp_patcher = mock.patch.dict( + "sys.modules", + { + "cupy.cuda": nccl_mock, + "cupy": mock.MagicMock(), + "ray.util.collective.collective_group": mock.MagicMock(), + }, + ) + cp_patcher.start() + + # Mock send/recv ops to use an actor instead of NCCL. + ray.experimental.channel.torch_tensor_nccl_channel._NcclGroup = MockNcclGroup + + # PyTorch mocks. + stream_patcher = mock.patch( + "torch.cuda.current_stream", new_callable=lambda: MockCudaStream + ) + stream_patcher.start() + new_stream_patcher = mock.patch( + "torch.cuda.Stream", new_callable=lambda: MockCudaStream + ) + new_stream_patcher.start() + tensor_patcher = mock.patch("torch.Tensor.device", torch.device("cuda")) + tensor_patcher.start() + tensor_patcher = mock.patch("torch.Tensor.is_cuda", True) + tensor_patcher.start() + tensor_allocator_patcher = mock.patch( + "ray.experimental.channel.torch_tensor_nccl_channel._torch_zeros_allocator", + lambda shape, dtype: torch.zeros(shape, dtype=dtype), + ) + tensor_allocator_patcher.start() + + ctx = ray_channel.ChannelContext.get_current() + ctx.set_torch_device(torch.device("cuda")) + + +class TracedChannel(ray_channel.shared_memory_channel.Channel): + """ + Patched Channel that records all write ops for testing. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.ops = [] + + def write(self, *args, **kwargs): + self.ops.append((args, kwargs)) + return super().write(*args, **kwargs) diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/channel/cpu_communicator.py b/.venv/lib/python3.11/site-packages/ray/experimental/channel/cpu_communicator.py new file mode 100644 index 0000000000000000000000000000000000000000..3ff14b3bfd1a1bb41a663b65f3e3f69e8f82d2a8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/experimental/channel/cpu_communicator.py @@ -0,0 +1,186 @@ +import asyncio +from collections import defaultdict +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple + +import ray +from ray.experimental.channel.communicator import ( + Communicator, + ReduceOp, + TorchTensorAllocator, +) + +if TYPE_CHECKING: + import torch + + +@ray.remote(num_cpus=0) +class CPUCommBarrier: + """ + Barrier actor that blocks the given number of actors until all actors have + reached the Barrier. + + p2p operations are not done here (completed via shared memory channel). 
+ """ + + def __init__(self, num_actors: int): + self.num_actors = num_actors + self.condition = asyncio.Condition() + # Stores the data for each collective operation + self.collective_data: Dict[int, List["torch.Tensor"]] = defaultdict(list) + # Stores the shape of data for each collective operation + self.collective_data_shape: Dict[int, "torch.Tensor.type"] = {} + # Buffer for the number of actors seen + self.num_actors_seen = defaultdict(int) + # Number of actors who have read the result, and are about to exit the function. + # State is kept so we only garbage collect after the last actor has read the + # relevant data. + self.num_actors_read = defaultdict(int) + + async def wait_collective(self, op_id: int, data: "torch.Tensor", op: ReduceOp): + """ + Wait at the barrier until all actors have sent `op_id` and `data`. + Once data from all actors is received, execute the collective `op` + on the barrier actor and return the result. + """ + async with self.condition: + self.collective_data[op_id].append(data) + self.num_actors_seen[op_id] += 1 + + if self.num_actors_seen[op_id] == self.num_actors: + # Apply the collective operation across all gathered tensors + data = self._apply_op(op, self.collective_data[op_id]) + self.collective_data[op_id] = data + self.condition.notify_all() + else: + await self.condition.wait_for( + lambda: self.num_actors_seen[op_id] == self.num_actors + ) + + data = self.collective_data[op_id] + self.num_actors_read[op_id] += 1 + + if self.num_actors_read[op_id] == self.num_actors: + del self.collective_data[op_id] + del self.num_actors_seen[op_id] + del self.num_actors_read[op_id] + + return data + + def _apply_op(self, op: ReduceOp, tensors: List["torch.Tensor"]) -> "torch.Tensor": + """Apply the specified reduction operation across a list of tensors.""" + # The module-level torch import is under TYPE_CHECKING only, so import + # it here for the torch.max/torch.min calls below. + import torch + + result = tensors[0].clone() + if op == ReduceOp.SUM: + for tensor in tensors[1:]: + result += tensor + elif op == ReduceOp.PRODUCT: + for tensor in tensors[1:]: + result *= tensor + elif op == ReduceOp.MAX: + for tensor in tensors[1:]: + result = torch.max(result, tensor) + elif op == ReduceOp.MIN: + for tensor in tensors[1:]: + result = torch.min(result, tensor) + elif op == ReduceOp.AVG: + result = sum(tensors) / len(tensors) + else: + raise ValueError(f"Operation {op} not supported") + return result + + +class CPUCommunicator(Communicator): + """ + Uses a CPU-based communicator actor instead of an NCCL group. + """ + + def __init__(self, world_size: int, actor_handles: List["ray.actor.ActorHandle"]): + """We use the op index to synchronize the sender and receiver at the + communicator actor.""" + self._world_size = world_size + self._actor_handles = actor_handles + self.num_ops = defaultdict(int) + + # For collective communication, one barrier will be created for + # each unique group of participants.
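A worked example of the reduction step in `_apply_op`, standalone (requires torch):

import torch

tensors = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]

# SUM: clone the first tensor, then accumulate the rest in place.
result = tensors[0].clone()
for t in tensors[1:]:
    result += t
assert torch.equal(result, torch.tensor([4.0, 6.0]))

# AVG is computed over the whole list at once.
avg = sum(tensors) / len(tensors)
assert torch.equal(avg, torch.tensor([2.0, 3.0]))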
+ self.barriers = set() + self._rank = None + + def send(self, tensor: "torch.Tensor", peer_rank: int): + # p2p operations are done via a shared memory channel, initialized in + # `create_channel` of `TorchTensorType` + pass + + def recv( + self, + shape: Tuple[int], + dtype: "torch.dtype", + peer_rank: int, + allocator: Optional[TorchTensorAllocator] = None, + ): + # See the comment on `send` + pass + + def allreduce( + self, + send_buf: "torch.Tensor", + recv_buf: "torch.Tensor", + op: ReduceOp = ReduceOp.SUM, + ): + all_ranks = [ + self.get_rank(actor_handle) for actor_handle in self.get_actor_handles() + ] + barrier_key = "barrier-collective-" + "-".join(map(str, sorted(all_ranks))) + barrier = CPUCommBarrier.options(name=barrier_key, get_if_exists=True).remote( + self._world_size + ) + self.barriers.add(barrier) + + result = ray.get( + barrier.wait_collective.remote(self.num_ops[barrier_key], send_buf, op) + ) + assert recv_buf is not None, "Receiving buffer required for CPUCommunicator" + recv_buf[:] = result[:] + self.num_ops[barrier_key] += 1 + + def destroy(self) -> None: + for barrier in self.barriers: + ray.kill(barrier) + + def initialize(self, rank: int) -> None: + self._rank = rank + + def get_actor_handles(self) -> List["ray.actor.ActorHandle"]: + return self._actor_handles + + def get_rank(self, actor: ray.actor.ActorHandle) -> int: + """ + Return the given actor's rank in the CPU communicator. + + Args: + actor: The actor handle to look up. + """ + actor_ids = [a._ray_actor_id for a in self._actor_handles] + try: + rank = actor_ids.index(actor._ray_actor_id) + except ValueError: + raise ValueError("Actor is not in the CPUCommunicator group.") + return rank + + def get_self_rank(self) -> Optional[int]: + return self._rank + + def get_world_size(self) -> int: + """ + Return the number of ranks in the CPU communicator. + """ + return self._world_size + + def get_transport_name(self) -> str: + return "cpu" + + def recv_stream(self): + raise NotImplementedError + + def send_stream(self): + raise NotImplementedError diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/channel/nccl_group.py b/.venv/lib/python3.11/site-packages/ray/experimental/channel/nccl_group.py new file mode 100644 index 0000000000000000000000000000000000000000..d44c680d1089fda6a2b9e499d9df060ea719e704 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/experimental/channel/nccl_group.py @@ -0,0 +1,318 @@ +import logging +from types import ModuleType +from typing import TYPE_CHECKING, List, Optional, Tuple + +import ray +from ray.exceptions import RayChannelError +from ray.experimental.channel.communicator import Communicator, TorchTensorAllocator +from ray.experimental.util.types import ReduceOp + +if TYPE_CHECKING: + import cupy as cp + import torch + + +# Logger for this module. It should be configured at the entry point +# into the program using Ray. Ray provides a default configuration at +# entry/init points. +logger = logging.getLogger(__name__) + + +class _NcclGroup(Communicator): + """ + Represents an actor's NCCL communicator. This is the default NCCL communicator + to be used in Compiled Graph if a custom communicator is not provided. + + This class is not thread-safe. + """ + + def __init__( + self, + world_size: int, + comm_id: int, + rank: Optional[int], + actor_handles: List["ray.actor.ActorHandle"], + cuda_stream: Optional[int], + use_communication_streams: bool = False, + ): + """ + Initialize a NCCL communicator that can be used to communicate p2p with + other GPU actors. 
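For reference, the collective barrier name used by `allreduce` above is derived purely from the sorted participant ranks, so every member of the same group computes the same actor name:

all_ranks = [2, 0, 1]
barrier_key = "barrier-collective-" + "-".join(map(str, sorted(all_ranks)))
assert barrier_key == "barrier-collective-0-1-2"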
+ + This method blocks until the same call has been made on all other + actors in the group, with the same arguments for world_size and + comm_id. + + NOTE: A concurrent NCCL group can coexist with this one but using the + two groups concurrently on different CUDA streams may cause deadlock. + See + https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html + #using-multiple-nccl-communicators-concurrently. + + If the user can guarantee that all involved actors execute the same ops + in the same order, then the other NCCL group should use the given + `cuda_stream`, and there will not be a concurrency issue. Otherwise, + the other stream needs to synchronize with the given `cuda_stream` + before and after it launches NCCL ops, e.g., at the beginning and end + of a DAG task. + + Args: + world_size: The number of participating actors/devices. + comm_id: A unique communicator ID returned by + cupy.cuda.nccl.get_unique_id(). + rank: The rank of this actor. If None, then the caller is not a + participant of the NCCL group. + actor_handles: A list of actor handles, in rank order. + cuda_stream: A raw CUDA stream to dispatch NCCL ops to. If rank is + specified, then this must be specified too. + use_communication_streams: Whether to use dedicated send and recv + streams for communication. If True, communication and computation + can be overlapped to improve performance. + """ + self._world_size = world_size + self._rank: Optional[int] = rank + self.nccl_util: Optional[ModuleType] = None + self._actor_handles = actor_handles + self._use_communication_streams = use_communication_streams + + if rank is not None: + assert ray.get_gpu_ids(), "NCCL actor has no GPUs assigned" + assert cuda_stream is not None, "NCCL actor must specify cuda_stream" + + expected_rank = self.get_rank(ray.get_runtime_context().current_actor) + assert ( + rank == expected_rank + ), f"NCCL actor's rank {rank} does not match expected rank {expected_rank}" + + from ray.util.collective.collective_group import nccl_util + + self.nccl_util = nccl_util + self._comm = self.nccl_util.NcclCommunicator(world_size, comm_id, rank) + else: + # Driver does not have a rank. + self._comm = None + + self._cuda_stream: Optional["cp.cuda.ExternalStream"] = None + self._send_stream: Optional["cp.cuda.ExternalStream"] = None + self._recv_stream: Optional["cp.cuda.ExternalStream"] = None + if cuda_stream is not None: + assert rank is not None, "NCCL actor has no rank assigned" + + import cupy as cp + + from ray.air._internal import torch_utils + + # TODO(swang): Allow default device to be overridden. + device = torch_utils.get_devices()[0] + self._cuda_stream = cp.cuda.ExternalStream( + cuda_stream, device_id=device.index + ) + + if use_communication_streams: + import torch + + self._send_stream = cp.cuda.ExternalStream( + torch.cuda.Stream().cuda_stream, device_id=device.index + ) + self._recv_stream = cp.cuda.ExternalStream( + torch.cuda.Stream().cuda_stream, device_id=device.index + ) + else: + self._send_stream = self._cuda_stream + self._recv_stream = self._cuda_stream + + self._closed = False + + def initialize(self, rank: int) -> None: + # No additional initialization is needed. + pass + + def get_actor_handles(self) -> List["ray.actor.ActorHandle"]: + return self._actor_handles + + def get_rank(self, actor: ray.actor.ActorHandle) -> int: + """ + Return the given actor's rank in the NCCL communicator. + + Args: + actor: The actor handle to look up. 
+ """ + actor_ids = [a._ray_actor_id for a in self._actor_handles] + try: + rank = actor_ids.index(actor._ray_actor_id) + except ValueError: + raise ValueError("Actor is not in the NCCL group.") + return rank + + def get_self_rank(self) -> Optional[int]: + """ + Return this actor's rank. + """ + return self._rank + + def get_world_size(self) -> int: + """ + Return the number of ranks in the NCCL communicator. + """ + return self._world_size + + def send(self, buf: "torch.Tensor", peer_rank: int) -> None: + """ + Send a torch.Tensor to a peer. + + This returns when the send kernel has been queued, but the kernel may + not have completed. Therefore, the caller should ensure that there are + no concurrent writes to the sent `buf` until the send has finished. + That is, either all writes should be submitted on the current stream + (self._cuda_stream) or, if on a different stream, that stream should + synchronize with the current stream. + + Args: + buf: The torch.Tensor to send. It should already be on this + actor's default device. + peer_rank: The rank of the actor to send to. + """ + if self._closed: + raise RayChannelError("NCCL group has been destroyed.") + + if self._use_communication_streams: + # We observed that if all recv/compute/send operations run on GPU, + # since there is no synchronization, the CPU execution loop may be + # far ahead of the GPU operations and lead to runtime failures. + # To avoid that, we synchronize on the send stream. + # TODO(rui): find a better approach + self._send_stream.synchronize() + + # TODO(swang): Handle send/recv async NCCL errors such as network + # failures. + self._comm.send( + self.nccl_util.get_tensor_ptr(buf), + buf.numel(), + self.nccl_util.get_nccl_tensor_dtype(buf), + peer_rank, + self._send_stream.ptr, + ) + + def recv( + self, + shape: Tuple[int], + dtype: "torch.dtype", + peer_rank: int, + allocator: Optional[TorchTensorAllocator] = None, + ) -> "torch.Tensor": + """ + Receive a torch.Tensor from a peer and synchronize the current stream. + + After this call returns, the receive buffer is safe to read from any + stream. A RayChannelError will be raised if an error occurred (e.g., the + remote actor died), in which case the buffer is not safe to read. + + Args: + shape: The shape of the tensor to receive. + dtype: The dtype of the tensor to receive. + peer_rank: The rank of the actor to receive from. + allocator: A function to allocate the tensor to receive into. + """ + if self._closed: + raise RayChannelError("NCCL group has been destroyed.") + assert allocator is not None, "NCCL group requires a tensor allocator" + buf = allocator(shape, dtype) + + if self._use_communication_streams: + # We observed that if all recv/compute/send operations run on GPU, + # since there is no synchronization, the CPU execution loop may be + # far ahead of the GPU operations and lead to runtime failures. + # To avoid that, we synchronize on the recv stream. + # TODO(rui): find a better approach + self._recv_stream.synchronize() + + self._comm.recv( + self.nccl_util.get_tensor_ptr(buf), + buf.numel(), + self.nccl_util.get_nccl_tensor_dtype(buf), + peer_rank, + self._recv_stream.ptr, + ) + else: + self._comm.recv( + self.nccl_util.get_tensor_ptr(buf), + buf.numel(), + self.nccl_util.get_nccl_tensor_dtype(buf), + peer_rank, + self._recv_stream.ptr, + ) + + # Buffer values are undefined if NCCL ops are aborted. Therefore, we + # need to synchronize here and check that the channel is still open to + # ensure that the receive buffer is valid. + # TODO(swang): Avoid CUDA synchronization.
+ self._cuda_stream.synchronize() + + if self._closed: + raise RayChannelError("NCCL group has been destroyed.") + return buf + + def allreduce( + self, + send_buf: "torch.Tensor", + recv_buf: "torch.Tensor", + op: ReduceOp = ReduceOp.SUM, + ): + if self._closed: + raise RayChannelError("NCCL group has been destroyed.") + + assert send_buf.dtype == recv_buf.dtype, ( + "Ray Compiled Graph derived the dtype of recv_buf from send_buf, " + "so send_buf and recv_buf must have the same dtype. " + "If you see this error, please file an issue at Ray repository." + ) + self._comm.allReduce( + self.nccl_util.get_tensor_ptr(send_buf), + self.nccl_util.get_tensor_ptr(recv_buf), + send_buf.numel(), + self.nccl_util.get_nccl_tensor_dtype(send_buf), + op.value, + self._cuda_stream.ptr, + ) + + # Buffer values are undefined if NCCL ops are aborted. Therefore, we + # need to synchronize here and check that the channel is still open to + # ensure that the receive buffer is valid. + # TODO(swang): Avoid CUDA synchronization. + # TODO(wxdeng): Use check_async_error. + self._cuda_stream.synchronize() + if self._closed: + raise RayChannelError( + "NCCL group has been destroyed during allreduce operation. " + "There may be a dtype mismatch between input tensors from " + "different ranks." + ) + + @property + def recv_stream(self) -> Optional["cp.cuda.ExternalStream"]: + return self._recv_stream + + @property + def send_stream(self) -> Optional["cp.cuda.ExternalStream"]: + return self._send_stream + + def destroy(self) -> None: + """ + Destroy the NCCL group. + """ + if self._closed: + return + + self._closed = True + + if self._comm is not None: + logger.info( + "Destroying NCCL group on actor: " + f"{ray.get_runtime_context().current_actor}" + ) + # Abort *after* setting the _closed flag. This ensures that NCCL + # ops that were blocked on a remote peer will see that the _closed + # flag is True when they exit from the abort. + self._comm.abort() + self._comm.destroy() + + def get_transport_name(self) -> str: + return "nccl" diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/channel/shared_memory_channel.py b/.venv/lib/python3.11/site-packages/ray/experimental/channel/shared_memory_channel.py new file mode 100644 index 0000000000000000000000000000000000000000..661437fae3cde1e33fcc88ed6302e3d54e57b843 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/experimental/channel/shared_memory_channel.py @@ -0,0 +1,775 @@ +import io +import logging +import time +from collections import defaultdict, namedtuple +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +import ray +import ray.exceptions +from ray._raylet import SerializedObject +from ray.experimental.channel import utils +from ray.experimental.channel.common import ChannelInterface, ChannelOutputType +from ray.experimental.channel.intra_process_channel import IntraProcessChannel +from ray.experimental.channel.utils import get_self_actor +from ray.util.annotations import DeveloperAPI, PublicAPI + +# Logger for this module. It should be configured at the entry point +# into the program using Ray. Ray provides a default configuration at +# entry/init points. +logger = logging.getLogger(__name__) + +DEFAULT_MAX_BUFFER_SIZE = int(1e6) # 1 MB +# The min buffer size must be large enough to at least fit an instance of the +# _ResizeChannel class along with any metadata. +MIN_BUFFER_SIZE = int(1000) # 1000 bytes +# For shared memory channels, the default number of buffers per channel to +# allocate.
+DEFAULT_NUM_SHM_BUFFERS = 1 + + +def _create_channel_ref( + self, + buffer_size_bytes: int, +) -> "ray.ObjectRef": + """ + Create a channel that can be read and written through Ray's shared-memory + object store. + + The channel has no buffer, so the writer will block until reader(s) have + read the previous value. + + A writer and colocated readers can communicate via a shared memory buffer. + If the readers are remote, then RPC is used to synchronize the writer and + readers' buffers. + + Args: + buffer_size_bytes: The initial buffer size in bytes for messages + that can be passed between tasks in the DAG. The buffers will + be automatically resized if larger messages are written to the + channel. + Returns: + Channel: A wrapper around ray.ObjectRef. + """ + worker = ray._private.worker.global_worker + worker.check_connected() + + value = b"0" * buffer_size_bytes + + try: + object_ref = worker.put_object( + value, owner_address=None, _is_experimental_channel=True + ) + except ray.exceptions.ObjectStoreFullError: + logger.info( + "Put failed since the value was either too large or the " + "store was full of pinned objects." + ) + raise + return object_ref + + +# Compiled Graph maintains 1 reader object reference (also called buffer) per node. +# reader_ref: The object reference. +# ref_owner_actor_id: The actor who created the object reference. +# num_readers: The number of reader actors who reads this object reference. +ReaderRefInfo = namedtuple( + "ReaderRefInfo", ["reader_ref", "ref_owner_actor_id", "num_reader_actors"] +) + + +class _ResizeChannel: + """ + When a channel must be resized, the channel backing store must be resized on both + the writer and the reader nodes. The writer first resizes its own backing store. The + writer then uses an instance of this class as a sentinel value to tell the reader to + resize its own backing store. The class instance is sent through the channel. + """ + + def __init__( + self, + _node_id_to_reader_ref_info: Dict[str, ReaderRefInfo], + ): + """ + Args: + _node_id_to_reader_ref_info: A node id to ReaderRefInfo. + """ + self._node_id_to_reader_ref_info = _node_id_to_reader_ref_info + + +class SharedMemoryType(ChannelOutputType): + def __init__( + self, + *, + buffer_size_bytes: Optional[int] = None, + num_shm_buffers: Optional[int] = None, + ): + """ + Args: + buffer_size_bytes: The initial buffer size in bytes for messages + that can be passed between tasks in the DAG. The buffers will + be automatically resized if larger messages are written to the + channel. + num_shm_buffers: The number of shared memory buffer per channel. + """ + super().__init__() + if buffer_size_bytes is None: + buffer_size_bytes = DEFAULT_MAX_BUFFER_SIZE + self.buffer_size_bytes = buffer_size_bytes + if num_shm_buffers is None: + num_shm_buffers = DEFAULT_NUM_SHM_BUFFERS + self._num_shm_buffers = num_shm_buffers + + def create_channel( + self, + writer: Optional["ray.actor.ActorHandle"], + reader_and_node_list: List[Tuple["ray.actor.ActorHandle", str]], + driver_actor_id: Optional[str] = None, + ) -> "Channel": + """ + Instantiate a ChannelInterface class that can be used + to pass data of this type. + + Args: + writer: The actor that may write to the channel. None signifies the driver. + reader_and_node_list: A list of tuples, where each tuple contains a reader + actor handle and the node ID where the actor is located. + driver_actor_id: If this channel is read by a driver and that driver is an + actual actor, this will be the actor ID of that driver actor. 
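An aside on the sizing constants defined above: a value's footprint is its serialized data bytes plus metadata bytes, and `_resize_channel_if_needed` (further below) grows the backing store whenever that sum exceeds the current buffer size. Illustrative numbers:

DEFAULT_MAX_BUFFER_SIZE = int(1e6)  # 1 MB, as above
data_bytes, metadata_bytes = 999_500, 600
needs_resize = (data_bytes + metadata_bytes) > DEFAULT_MAX_BUFFER_SIZE
assert needs_resize  # 1_000_100 > 1_000_000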
+ + Returns: + A ChannelInterface that can be used to pass data + of this type. + """ + return CompositeChannel( + writer, + reader_and_node_list, + self._num_shm_buffers, + driver_actor_id, + ) + + +@PublicAPI(stability="alpha") +class Channel(ChannelInterface): + """ + A wrapper type for ray.ObjectRef. Currently supports ray.get but not + ray.wait. + """ + + def __init__( + self, + writer: Optional[ray.actor.ActorHandle], + reader_and_node_list: List[Tuple["ray.actor.ActorHandle", str]], + typ: Optional[Union[int, SharedMemoryType]] = None, + _writer_node_id: Optional["ray.NodeID"] = None, + _writer_ref: Optional["ray.ObjectRef"] = None, + _node_id_to_reader_ref_info: Optional[Dict[str, ReaderRefInfo]] = None, + _writer_registered: bool = False, + _reader_registered: bool = False, + ): + """ + Create a channel that can be read and written by co-located Ray processes. + + Anyone may write to or read from the channel. The channel has no + buffer, so the writer will block until reader(s) have read the previous + value. + + Args: + writer: The actor that may write to the channel. None signifies the driver. + reader_and_node_list: A list of tuples, where each tuple contains a reader + actor handle and the node ID where the actor is located. + typ: Type information about the values passed through the channel. + Either an integer representing the max buffer size in bytes + allowed, or a SharedMemoryType. + Returns: + Channel: A wrapper around ray.ObjectRef. + """ + assert len(reader_and_node_list) > 0 + for reader, _ in reader_and_node_list: + assert isinstance(reader, ray.actor.ActorHandle) + + if typ is None: + typ = SharedMemoryType() + elif isinstance(typ, int): + typ = SharedMemoryType(buffer_size_bytes=typ) + + if typ.buffer_size_bytes < MIN_BUFFER_SIZE: + raise ValueError( + "typ.buffer_size_bytes must be at least MIN_BUFFER_SIZE " + f"({MIN_BUFFER_SIZE} bytes)" + ) + + self._writer = writer + self._reader_and_node_list = reader_and_node_list + self._typ = typ + + self._worker = ray._private.worker.global_worker + self._worker.check_connected() + + self._writer_registered = _writer_registered + self._reader_registered = _reader_registered + # NodeID -> ReaderRefInfo on that node. Note that there's only 1 + # reader ref per node. + self._node_id_to_reader_ref_info: Dict[str, ReaderRefInfo] = ( + _node_id_to_reader_ref_info or {} + ) + + # Node ID -> a list of reader actors. + self._node_id_to_readers: Dict[str, List["ray.actor.ActorHandle"]] = defaultdict(list) + for reader, node_id in self._reader_and_node_list: + self._node_id_to_readers[node_id].append(reader) + + # Number of readers in a local node. + self._num_local_readers = 0 + + if _writer_ref is None: + # We are the writer. Check that the passed handle matches the + # current actor (or it is the driver). + # TODO(swang): Channels must be initially constructed by the writer + # actor, so we shouldn't need to include `writer` in the + # constructor args. Either support Channels being constructed by + # someone other than the writer or remove it from the args. + self_actor = get_self_actor() + assert writer == self_actor + + self._writer_node_id = ( + ray.runtime_context.get_runtime_context().get_node_id() + ) + self._writer_ref = _create_channel_ref(self, typ.buffer_size_bytes) + + self._create_reader_refs(typ.buffer_size_bytes) + else: + assert _writer_node_id is not None, ( + "_writer_node_id must also be passed to the constructor when " + "_writer_ref is." + )
+ assert _node_id_to_reader_ref_info is not None, ( + "_node_id_to_reader_ref_info must also be passed to the constructor " + "when _writer_ref is." + ) + + self._writer_ref = _writer_ref + self._writer_node_id = _writer_node_id + self._node_id_to_reader_ref_info = _node_id_to_reader_ref_info + + assert self._num_local_readers == 0 + remote_node_exists = False + for node_id, readers in self._node_id_to_readers.items(): + if self.is_local_node(node_id): + self._num_local_readers += len(readers) + else: + remote_node_exists = True + # If remote node exists, we have 1 additional reader that listens + # to object changes and push them to remote nodes. + if remote_node_exists: + self._num_local_readers += 1 + # There must be at least 1 local reader + assert self._num_local_readers > 0 + + self._local_reader_ref: Optional["ray.ObjectRef"] = self._get_local_reader_ref( + self._node_id_to_reader_ref_info + ) + + def _get_local_reader_ref( + self, _node_id_to_reader_ref_info: Dict[str, ReaderRefInfo] + ) -> Optional["ray.ObjectRef"]: + for node_id, reader_ref_info in _node_id_to_reader_ref_info.items(): + if self.is_local_node(node_id): + return reader_ref_info.reader_ref + return None + + def _create_reader_refs( + self, + buffer_size_bytes: int, + ): + # TODO(jhumphri): Free the current reader ref once the reference to it is + # destroyed below. + + for node_id, readers in self._node_id_to_readers.items(): + if not self.is_local_node(node_id): + # Find 1 reader in a remote node to create a reference that's + # shared by all readers. When a new value is written to a reference, + # it is sent to this reference. + reader = readers[0] + fn = reader.__ray_call__ + self._node_id_to_reader_ref_info[node_id] = ReaderRefInfo( + reader_ref=ray.get( + fn.remote(_create_channel_ref, buffer_size_bytes) + ), + ref_owner_actor_id=reader._actor_id, + num_reader_actors=len(readers), + ) + else: + writer_id = ray.ActorID.nil() + if self._writer is not None: + writer_id = self._writer._actor_id + self._node_id_to_reader_ref_info[node_id] = ReaderRefInfo( + reader_ref=self._writer_ref, + ref_owner_actor_id=writer_id, + num_reader_actors=len(readers), + ) + # There must be only 1 node reader reference per node. + assert len(self._node_id_to_reader_ref_info) == len(self._node_id_to_readers) + + # We need to register the new writer_ref. + self._writer_registered = False + self.ensure_registered_as_writer() + + @staticmethod + def is_local_node(node_id): + return ray.runtime_context.get_runtime_context().get_node_id() == node_id + + def ensure_registered_as_writer(self) -> None: + if self._writer_registered: + return + + if not self.is_local_node(self._writer_node_id): + raise ValueError( + "`ensure_registered_as_writer()` must only be called on the node that " + "the writer is on." 
+ ) + + remote_reader_ref_info: Dict[str, ReaderRefInfo] = {} + for node_id, reader_ref_info in self._node_id_to_reader_ref_info.items(): + if self.is_local_node(node_id): + continue + remote_reader_ref_info[node_id] = reader_ref_info + + self._worker.core_worker.experimental_channel_register_writer( + self._writer_ref, + remote_reader_ref_info, + ) + self._writer_registered = True + + def ensure_registered_as_reader(self) -> None: + if self._reader_registered: + return + + for node_id, reader_ref_info in self._node_id_to_reader_ref_info.items(): + if self.is_local_node(node_id): + self._worker.core_worker.experimental_channel_register_reader( + reader_ref_info.reader_ref, + ) + self._reader_registered = True + + @staticmethod + def _deserialize_reader_channel( + writer: ray.actor.ActorHandle, + reader_and_node_list: List[Tuple["ray.actor.ActorHandle", str]], + typ: int, + writer_node_id, + writer_ref: "ray.ObjectRef", + node_id_to_reader_ref_info: Dict[str, ReaderRefInfo], + writer_registered: bool, + reader_registered: bool, + ) -> "Channel": + chan = Channel( + writer, + reader_and_node_list, + typ, + _writer_node_id=writer_node_id, + _writer_ref=writer_ref, + _node_id_to_reader_ref_info=node_id_to_reader_ref_info, + _writer_registered=writer_registered, + _reader_registered=reader_registered, + ) + return chan + + def __reduce__(self): + assert self._node_id_to_reader_ref_info is not None + return self._deserialize_reader_channel, ( + self._writer, + self._reader_and_node_list, + self._typ, + self._writer_node_id, + self._writer_ref, + self._node_id_to_reader_ref_info, + self._writer_registered, + self._reader_registered, + ) + + def __str__(self) -> str: + return ( + f"Channel(_node_id_to_reader_ref_info={self._node_id_to_reader_ref_info}, " + f"_writer_ref={self._writer_ref})" + ) + + def _resize_channel_if_needed(self, serialized_value: str, timeout_ms: int): + # serialized_value.total_bytes *only* includes the size of the data. It does not + # include the size of the metadata, so we must account for the size of the + # metadata explicitly. + size = serialized_value.total_bytes + len(serialized_value.metadata) + if size > self._typ.buffer_size_bytes: + # Now make the channel backing store larger. + self._typ.buffer_size_bytes = size + # TODO(jhumphri): Free the current writer ref once the reference to it is + # destroyed below. + # TODO(sang): Support different policies such as 2X buffer size. + prev_writer_ref = self._writer_ref + self._writer_ref = _create_channel_ref(self, self._typ.buffer_size_bytes) + self._create_reader_refs(self._typ.buffer_size_bytes) + self._local_reader_ref = self._get_local_reader_ref( + self._node_id_to_reader_ref_info + ) + + # Write a special message to the channel so that the readers know to + # stop using the current reader_ref. + special_message = _ResizeChannel(self._node_id_to_reader_ref_info) + special_message_serialized = ( + self._worker.get_serialization_context().serialize(special_message) + ) + self._worker.core_worker.experimental_channel_put_serialized( + special_message_serialized, + prev_writer_ref, + self._num_local_readers, + timeout_ms, + ) + # TODO(sang): Clean the previous ref that won't be used. + # Right now, if we just close it here, it will not work because + # of race conditions. 
+ # self._worker.core_worker.experimental_channel_set_error( + # prev_writer_ref + # ) + + def write(self, value: Any, timeout: Optional[float] = None) -> None: + self.ensure_registered_as_writer() + assert ( + timeout is None or timeout >= 0 or timeout == -1 + ), "Timeout must be non-negative or -1." + # -1 means no timeout (block indefinitely) + timeout_ms = int(timeout * 1000) if timeout is not None else -1 + + if not isinstance(value, SerializedObject): + try: + serialized_value = self._worker.get_serialization_context().serialize( + value + ) + except TypeError as e: + sio = io.StringIO() + ray.util.inspect_serializability(value, print_file=sio) + msg = ( + "Could not serialize the put value " + f"{repr(value)}:\n" + f"{sio.getvalue()}" + ) + raise TypeError(msg) from e + else: + serialized_value = value + + start_time = time.monotonic() + self._resize_channel_if_needed(serialized_value, timeout_ms) + if timeout is not None: + timeout_ms -= int((time.monotonic() - start_time) * 1000) + timeout_ms = max(timeout_ms, 0) + + self._worker.core_worker.experimental_channel_put_serialized( + serialized_value, + self._writer_ref, + self._num_local_readers, + timeout_ms, + ) + + def read(self, timeout: Optional[float] = None) -> Any: + assert ( + timeout is None or timeout >= 0 or timeout == -1 + ), "Timeout must be non-negative or -1." + self.ensure_registered_as_reader() + + start_time = time.monotonic() + ret = self._worker.get_objects( + [self._local_reader_ref], timeout=timeout, return_exceptions=True + )[0][0] + + if isinstance(ret, _ResizeChannel): + self._node_id_to_reader_ref_info = ret._node_id_to_reader_ref_info + self._local_reader_ref = self._get_local_reader_ref( + self._node_id_to_reader_ref_info + ) + # We need to register the new reader_ref. + self._reader_registered = False + self.ensure_registered_as_reader() + if timeout is not None: + timeout -= time.monotonic() - start_time + timeout = max(timeout, 0) + ret = self._worker.get_objects( + [self._local_reader_ref], timeout=timeout, return_exceptions=True + )[0][0] + + return ret + + def release_buffer(self, timeout: Optional[float] = None) -> None: + assert ( + timeout is None or timeout >= 0 or timeout == -1 + ), "Timeout must be non-negative or -1." + self.ensure_registered_as_reader() + self._worker.get_objects( + [self._local_reader_ref], + timeout=timeout, + return_exceptions=True, + skip_deserialization=True, + ) + + def close(self) -> None: + """ + Close this channel by setting the error bit on both the writer_ref and the + reader_ref. + """ + self._worker.core_worker.experimental_channel_set_error(self._writer_ref) + is_local_node_reader = False + + for node_id in self._node_id_to_readers.keys(): + if self.is_local_node(node_id): + is_local_node_reader = True + if is_local_node_reader: + self.ensure_registered_as_reader() + + for reader_ref_info in self._node_id_to_reader_ref_info.values(): + self._worker.core_worker.experimental_channel_set_error( + reader_ref_info.reader_ref + ) + + +@DeveloperAPI +class BufferedSharedMemoryChannel(ChannelInterface): + """A channel that can be read and written by Ray processes. + + It creates `num_shm_buffers` number of buffers and allows buffered read and + write APIs. I.e., read and write APIs are non-blocking as long as it can write to + next buffer or read from a next buffer. See `read` and `write` APIs for + more details. + + Args: + writer: The actor that may write to the channel. None signifies the driver. 
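Stepping back from the details above: the resize protocol pairs the writer-side `_resize_channel_if_needed` with the sentinel handling in `read`. The writer grows its buffer, publishes a `_ResizeChannel` sentinel on the old ref, and a reader that sees the sentinel swaps refs, re-registers, and reads again. A toy simulation of the sentinel-swap idea (illustrative only, using plain lists in place of object refs):

class Sentinel:
    def __init__(self, new_source):
        self.new_source = new_source

def read(source):
    value = source.pop(0)
    if isinstance(value, Sentinel):  # the writer switched buffers
        source = value.new_source    # follow it, then read the real value
        value = source.pop(0)
    return value, source

old, new = [], ["payload"]
old.append(Sentinel(new))
value, current = read(old)
assert value == "payload" and current is new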
+ reader_and_node_list: A list of tuples, where each tuple contains a reader + actor handle and the node ID where the actor is located. + num_shm_buffers: Number of shared memory buffers to read/write. + typ: Type information about the values passed through the channel. + Either an integer representing the max buffer size in bytes + allowed, or a SharedMemoryType. + """ + + def __init__( + self, + writer: Optional[ray.actor.ActorHandle], + reader_and_node_list: List[Tuple["ray.actor.ActorHandle", str]], + num_shm_buffers: int, + typ: Optional[Union[int, SharedMemoryType]] = None, + ): + self._num_shm_buffers = num_shm_buffers + self._buffers = [ + # We use Channel directly as a buffer implementation as + # channel only allows to have 1 shared memory buffer. + Channel(writer, reader_and_node_list, typ) + for _ in range(num_shm_buffers) + ] + # The next index to write from self._buffers. + self._next_write_index = 0 + # The next index to read from self._buffers. + self._next_read_index = 0 + + def ensure_registered_as_writer(self): + """ + Check whether the process is a valid writer. This method must be idempotent. + """ + for buffer in self._buffers: + buffer.ensure_registered_as_writer() + + def ensure_registered_as_reader(self): + """ + Check whether the process is a valid reader. This method must be idempotent. + """ + for buffer in self._buffers: + buffer.ensure_registered_as_reader() + + def write(self, value: Any, timeout: Optional[float] = None) -> None: + """Write a value to a channel. + + If the next buffer is available, it returns immediately. If the next + buffer is not read by downstream consumers, it blocks until a buffer is + available to write. If a buffer is not available within timeout, it raises + RayChannelTimeoutError. + """ + self.ensure_registered_as_writer() + # A single channel is not supposed to read and write at the same time. + assert self._next_read_index == 0 + self._buffers[self._next_write_index].write(value, timeout) + self._next_write_index += 1 + self._next_write_index %= self._num_shm_buffers + + def read(self, timeout: Optional[float] = None) -> Any: + """Read a value from a channel. + + If the next buffer is available, it returns immediately. If the next + buffer is not written by an upstream producer, it blocks until a buffer is + available to read. If a buffer is not available within timeout, it raises + RayChannelTimeoutError. + """ + self.ensure_registered_as_reader() + # A single channel is not supposed to read and write at the same time. + assert self._next_write_index == 0 + output = self._buffers[self._next_read_index].read(timeout) + self._next_read_index += 1 + self._next_read_index %= self._num_shm_buffers + return output + + def release_buffer(self, timeout: Optional[float] = None): + """Release the native buffer of the channel to allow the buffer to be reused for + future data. + + If the next buffer is available, it returns immediately. If the next + buffer is not written by an upstream producer, it blocks until a buffer is + available to be released. If a buffer is not available within timeout, it raises + RayChannelTimeoutError. + """ + # A single channel is not supposed to read and write at the same time. 
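+        # (Descriptive note: each process is expected to use this channel
+        # either purely as a writer or purely as a reader, so on the reader
+        # side the write index stays at its initial value of 0. For example,
+        # after three reads with num_shm_buffers=2, the expected state is
+        # _next_read_index == 1 and _next_write_index == 0.)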
+        assert self._next_write_index == 0
+        self._buffers[self._next_read_index].release_buffer(timeout)
+        self._next_read_index += 1
+        self._next_read_index %= self._num_shm_buffers
+
+    def close(self) -> None:
+        for buffer in self._buffers:
+            buffer.close()
+
+    @property
+    def next_write_index(self):
+        # Testing only
+        return self._next_write_index
+
+    @property
+    def next_read_index(self):
+        # Testing only
+        return self._next_read_index
+
+
+@PublicAPI(stability="alpha")
+class CompositeChannel(ChannelInterface):
+    """
+    Can be used to send data to different readers via different channels.
+    For example, if the reader is in the same worker process as the writer,
+    the data can be sent via IntraProcessChannel. If the reader is in a different
+    worker process, the data can be sent via a shared memory channel.
+
+    Args:
+        writer: The actor that may write to the channel. None signifies the driver.
+        reader_and_node_list: A list of tuples, where each tuple contains a reader
+            actor handle and the node ID where the actor is located.
+        num_shm_buffers: The number of shared memory buffers to use per
+            shared-memory channel.
+        driver_actor_id: If this channel is read by a driver and that driver is an
+            actual actor, this will be the actor ID of that driver actor.
+    """
+
+    def __init__(
+        self,
+        writer: Optional[ray.actor.ActorHandle],
+        reader_and_node_list: List[Tuple["ray.actor.ActorHandle", str]],
+        num_shm_buffers: int,
+        driver_actor_id: Optional[str] = None,
+        _channel_dict: Optional[Dict[ray.ActorID, ChannelInterface]] = None,
+        _channels: Optional[Set[ChannelInterface]] = None,
+        _writer_registered: bool = False,
+        _reader_registered: bool = False,
+    ):
+        self._writer = writer
+        self._reader_and_node_list = reader_and_node_list
+        self._num_shm_buffers = num_shm_buffers
+        self._driver_actor_id = driver_actor_id
+        self._writer_registered = _writer_registered
+        self._reader_registered = _reader_registered
+        # A dictionary that maps the actor ID to the channel object.
+        self._channel_dict = _channel_dict or {}
+        # The set of channels is a deduplicated version of the _channel_dict values.
+        self._channels = _channels or set()
+        if self._channels:
+            # This CompositeChannel object was created by deserialization.
+            # We don't need to create the channels again.
+            return
+
+        (
+            remote_reader_and_node_list,
+            local_reader_and_node_list,
+        ) = utils.split_readers_by_locality(self._writer, self._reader_and_node_list)
+        # There are some local readers which are in the same worker process as the
+        # writer. Create a local channel for the writer and the local readers.
+        num_local_readers = len(local_reader_and_node_list)
+        if num_local_readers > 0:
+            # Use num_readers = 1 when creating the local channel,
+            # because we have a channel cache to support reading
+            # from the same channel multiple times.
+            local_channel = IntraProcessChannel(num_readers=1)
+            self._channels.add(local_channel)
+            actor_id = self._get_actor_id(self._writer)
+            self._channel_dict[actor_id] = local_channel
+        # There are some remote readers which are not the same Ray actor as the writer.
+        # Create a shared memory channel for the writer and the remote readers.
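+        # Illustrative example (hypothetical actors A, B, C, where A is the
+        # writer and all three read): after construction, _channel_dict maps
+        #     A's actor ID -> IntraProcessChannel
+        #     B's actor ID -> BufferedSharedMemoryChannel
+        #     C's actor ID -> BufferedSharedMemoryChannel
+        # with B and C sharing one channel object, so _channels holds exactly
+        # two entries.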
+ if len(remote_reader_and_node_list) != 0: + remote_channel = BufferedSharedMemoryChannel( + self._writer, remote_reader_and_node_list, num_shm_buffers + ) + self._channels.add(remote_channel) + + for reader, _ in remote_reader_and_node_list: + actor_id = self._get_actor_id(reader) + self._channel_dict[actor_id] = remote_channel + + def _get_actor_id(self, reader: ray.actor.ActorHandle) -> str: + return reader._actor_id.hex() + + def ensure_registered_as_writer(self) -> None: + if self._writer_registered: + return + for channel in self._channels: + channel.ensure_registered_as_writer() + self._writer_registered = True + + def ensure_registered_as_reader(self) -> None: + if self._reader_registered: + return + for channel in self._channels: + channel.ensure_registered_as_reader() + self._reader_registered = True + + def __reduce__(self): + return CompositeChannel, ( + self._writer, + self._reader_and_node_list, + self._num_shm_buffers, + self._driver_actor_id, + self._channel_dict, + self._channels, + self._writer_registered, + self._reader_registered, + ) + + def __str__(self) -> str: + return ( + "CompositeChannel(_channels=" + f"{[str(channel) for channel in self._channels]})" + ) + + def write(self, value: Any, timeout: Optional[float] = None) -> None: + self.ensure_registered_as_writer() + for channel in self._channels: + channel.write(value, timeout) + + def read(self, timeout: Optional[float] = None) -> Any: + self.ensure_registered_as_reader() + return self._channel_dict[self._resolve_actor_id()].read(timeout) + + def release_buffer(self, timeout: Optional[float] = None): + self.ensure_registered_as_reader() + self._channel_dict[self._resolve_actor_id()].release_buffer(timeout) + + def _resolve_actor_id(self) -> str: + actor_id = ray.get_runtime_context().get_actor_id() + # If actor_id is None, read was called by the driver + # If the driver is an actor, driver_actor_id will be set to that actor id + if actor_id is None or actor_id == self._driver_actor_id: + # Use the actor ID of the DAGDriverProxyActor. + # The proxy actor is always the first actor in the reader_and_node_list. 
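+            # (Descriptive note: a plain driver process has no actor ID, and a
+            # driver that is itself an actor matches _driver_actor_id; in both
+            # cases reads are served from the channel registered for the proxy
+            # actor resolved below.)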
+ assert len(self._reader_and_node_list) >= 1 + driver_proxy_actor = self._reader_and_node_list[0][0] + actor_id = self._get_actor_id(driver_proxy_actor) + return actor_id + + def close(self) -> None: + for channel in self._channels: + channel.close() diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/channel/torch_tensor_nccl_channel.py b/.venv/lib/python3.11/site-packages/ray/experimental/channel/torch_tensor_nccl_channel.py new file mode 100644 index 0000000000000000000000000000000000000000..dcf42f7ec47eded05a830bed2ca14fb1de80e151 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/experimental/channel/torch_tensor_nccl_channel.py @@ -0,0 +1,837 @@ +import io +import logging +import uuid +from dataclasses import dataclass +from types import ModuleType +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union + +import ray +import ray.util.serialization +from ray.experimental.channel import ChannelContext, utils +from ray.experimental.channel.common import ChannelInterface +from ray.experimental.channel.communicator import Communicator +from ray.experimental.channel.cpu_communicator import CPUCommunicator +from ray.experimental.channel.intra_process_channel import IntraProcessChannel +from ray.experimental.channel.nccl_group import _NcclGroup +from ray.experimental.channel.shared_memory_channel import SharedMemoryType +from ray.experimental.channel.torch_tensor_type import TorchTensorType +from ray.util.annotations import DeveloperAPI + +if TYPE_CHECKING: + import torch + + from ray.experimental.channel.shared_memory_channel import Channel + + +# Logger for this module. It should be configured at the entry point +# into the program using Ray. Ray provides a default configuration at +# entry/init points. +logger = logging.getLogger(__name__) + + +@dataclass +class _TorchTensorMetadata: + """ + Metadata for torch.Tensors that can be sent between processes to determine + how large of a buffer to allocate on the receiver(s). + """ + + shape: Union[int, Tuple[int]] + dtype: "torch.dtype" + + +@DeveloperAPI +class TorchTensorNcclChannel(ChannelInterface): + def __init__( + self, + writer: ray.actor.ActorHandle, + reader_and_node_list: List[Tuple["ray.actor.ActorHandle", str]], + typ: "TorchTensorType", + driver_actor_id: str, + tensor_metadata_channel: Optional["Channel"] = None, + _cpu_data_channel: Optional["Channel"] = None, + _gpu_data_channel: Optional["_TorchTensorNcclChannel"] = None, + _local_channel: Optional["IntraProcessChannel"] = None, + ): + """ + Can be used to send GPU tensors nested inside other data. The data is + sent via shared memory while the GPU tensors are sent through a P2P + transport (NCCL). + + NOTE: This class is currently not thread-safe because it reads and + writes the worker-local + ray.experimental.channel.serialization_context._SerializationContext + when serializing data. + + Args: + writer: The actor that may write to the channel. None signifies the + driver. + reader_and_node_list: A list of tuples, where each tuple contains a reader + actor handle and the node ID where the actor is located. + typ: Type information about the values passed through the channel. + driver_actor_id: The actor ID of the DAGDriverProxyActor. + tensor_metadata_channel: A shared-memory channel for sending tensor + metadata. + _cpu_data_channel: A shared-memory channel for sending + non-tensor data. Its writer and readers should match the given + writer and readers. If None is provided, then we assume that + there is no CPU-specific data, i.e. 
the task directly returned + a CUDA torch.Tensor. + _gpu_data_channel: A channel for sending torch.Tensors via NCCL. + _local_channel: A channel for sending data between the writer and + local readers. + + NOTE: `tensor_metadata_channel` will be set only for testing purposes. + `_cpu_data_channel` is set for testing purposes and for deserialization. + `_gpu_data_channel` and `_local_channel` are set only during deserialization. + """ + self._writer = writer + self._reader_and_node_list = reader_and_node_list + self._typ = typ + + ( + remote_reader_and_node_list, + local_reader_and_node_list, + ) = utils.split_readers_by_locality(self._writer, self._reader_and_node_list) + + num_local_readers = len(local_reader_and_node_list) + self._local_channel = _local_channel + if self._local_channel is None and num_local_readers > 0: + # There are some local readers which are the same worker process as + # the writer. Create a local channel for the writer and the local readers. + # + # Use num_readers = 1 when creating the local channel, + # because we have channel cache to support reading + # from the same channel multiple times. + self._local_channel = IntraProcessChannel(num_readers=1) + + assert len(remote_reader_and_node_list) > 0, ( + "All readers are from the same actor. " + "The TorchTensorType type hint is not needed. " + "No NCCL channel will be created." + ) + self._gpu_data_channel = _gpu_data_channel + if self._gpu_data_channel is None: + self._gpu_data_channel: _TorchTensorNcclChannel = _TorchTensorNcclChannel( + writer, + remote_reader_and_node_list, + typ, + _meta_channel=tensor_metadata_channel, + ) + + self._cpu_data_channel: Optional["Channel"] = _cpu_data_channel + if self._cpu_data_channel is not None: + assert ( + not self._typ.direct_return + ), "CPU channel should be None if direct return is enabled" + + if self._cpu_data_channel is None and not self._typ.direct_return: + # Create a CPU channel to send non-tensor data. + self._cpu_data_channel = SharedMemoryType().create_channel( + writer, remote_reader_and_node_list, driver_actor_id + ) + + # Used for serialization. + self._worker = ray._private.worker.global_worker + self._worker.check_connected() + + ctx = ChannelContext.get_current() + self.serialization_ctx = ctx.serialization_context + assert self.serialization_ctx is not None + + def __reduce__(self): + return ( + TorchTensorNcclChannel, + ( + self._writer, + self._reader_and_node_list, + self._typ, + # driver_actor_id and tensor_metadata_channel are used to initialize + # the _cpu_data_channel and _gpu_data_channel, so we don't need to + # pass them in here. + None, + None, + self._cpu_data_channel, + self._gpu_data_channel, + self._local_channel, + ), + ) + + def ensure_registered_as_writer(self): + if self._local_channel is not None: + self._local_channel.ensure_registered_as_writer() + self._gpu_data_channel.ensure_registered_as_writer() + if self._cpu_data_channel is not None: + self._cpu_data_channel.ensure_registered_as_writer() + + def ensure_registered_as_reader(self): + reader = utils.get_self_actor() + if reader == self._writer: + self._local_channel.ensure_registered_as_reader() + return + self._gpu_data_channel.ensure_registered_as_reader() + if self._cpu_data_channel is not None: + self._cpu_data_channel.ensure_registered_as_reader() + + def _send_cpu_and_gpu_data(self, value: Any, timeout: Optional[float]): + self.serialization_ctx.reset_out_of_band_tensors([]) + # All tensors found in `value` will be transferred via NCCL. 
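+        # (Descriptive note: with external transport enabled, a value such as
+        # {"x": cuda_tensor, "y": 1} serializes into CPU data that holds a
+        # placeholder for "x" plus the plain value 1, while the tensor itself
+        # is collected out of band and sent via the GPU channel below.)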
+ self.serialization_ctx.set_use_external_transport(True) + + try: + # Serialize the data. All tensors that match our current device + # will be extracted into the serialization context and replaced + # with a placeholder. + cpu_data = self._worker.get_serialization_context().serialize(value) + except TypeError as e: + sio = io.StringIO() + ray.util.inspect_serializability(value, print_file=sio) + msg = ( + "Could not serialize the put value " + f"{repr(value)}:\n" + f"{sio.getvalue()}" + ) + raise TypeError(msg) from e + finally: + # Pop the tensors that were found during serialization of `value`. + gpu_tensors, _ = self.serialization_ctx.reset_out_of_band_tensors([]) + # Reset the serialization method to now serialize torch.Tensors + # normally. + self.serialization_ctx.set_use_external_transport(False) + + # First send the extracted tensors through a GPU-specific channel. + self._gpu_data_channel.write(gpu_tensors) + # Next send the non-tensor data through a CPU-specific channel. The + # data contains placeholders for the extracted tensors. + self._cpu_data_channel.write(cpu_data) + + def write(self, value: Any, timeout: Optional[float] = None) -> None: + """ + Send a value that may contain torch.Tensors that should be sent via + external transport. + + Case 1: Use `_local_channel` to send the data to local readers. + + Case 2: Otherwise, use the following method to send the data to remote readers. + + 1) Serializes `value`. During serialization, all torch.Tensors that are + on the default device are extracted and replaced with a unique + placeholder. Thus, the serialized value will contain all non-tensor + data, and any tensors that were not on the default device (e.g., CPU + tensor returned by a GPU actor). + 2) Sends extracted torch.Tensors via the tensor data channel (e.g., + NCCL). + 3) Sends the non-tensor data via the non-tensor data channel. + + If static_non_tensor_data=True was specified, then we only perform step + (3) on the first `write` call. The reader is expected to reuse the sent + data for subsequent messages. + """ + self.ensure_registered_as_writer() + + if self._local_channel is not None: + self._local_channel.write(value) + + if isinstance(value, ray.exceptions.RayTaskError): + if self._typ.static_shape or self._typ.direct_return: + # Raise a fatal error to teardown the DAG. + # This error will also be caught from `CompiledDAGRef.get()` + # and raised to the user + # TODO(swang): Write exceptions to the tensor metadata or + # non-tensor data channel if it is available to make these + # exceptions recoverable. + raise value + + if self._cpu_data_channel is None: + # Handle the case where _direct_return=True. In this case, we check + # that the task returned a CUDA torch.Tensor and just send it + # directly without trying to serialize it first. + import torch + + # These ValueErrors will also be caught from `CompiledDAGRef.get()` + # and raised to the user + if not isinstance(value, torch.Tensor): + # TODO(swang): These errors are currently fatal for the DAG. + # This could be improved by sending the exception through the + # gpu_data_channel's CPU-based metadata channel, if one exists. + raise ValueError( + "Task annotated with _direct_return=True must " + "return a CUDA torch.Tensor, instead found value " + f"`{value}`. DAG will shut down." + ) + elif not value.is_cuda: + raise ValueError( + "Task annotated with _direct_return=True must " + "return a CUDA torch.Tensor, instead found CPU tensor. " + "DAG will shut down." 
+ ) + self._gpu_data_channel.write([value], timeout=timeout) + else: + self._send_cpu_and_gpu_data(value, timeout) + + def _recv_cpu_and_gpu_data( + self, tensors: List["torch.Tensor"], timeout: Optional[float] = None + ) -> Any: + """ + Helper method to receive data that contains a mix of CPU and GPU data. + + Args: + tensors: The GPU data. This is a list of the torch.Tensors that + were found in the sent data. + timeout: Timeout for channel receive. + """ + self.serialization_ctx.reset_out_of_band_tensors(tensors) + + # Next, read and deserialize the non-tensor data. The registered custom + # deserializer will replace the found tensor placeholders with + # `tensors`. + data = self._cpu_data_channel.read( + timeout=timeout, + ) + # Check that all placeholders had a corresponding tensor. + ( + _, + deserialized_tensor_placeholders, + ) = self.serialization_ctx.reset_out_of_band_tensors([]) + assert deserialized_tensor_placeholders == set(range(len(tensors))) + + return data + + def read(self, timeout: Optional[float] = None) -> Any: + """ + Read a value that may contain torch.Tensors sent via external + transport. + + Case 1: If the reader is a local reader and is the same actor as the writer, + then use the `_local_channel` to read the data. + + Case 2: Otherwise, use the following method to read data from remote readers. + + 1) Receives torch.Tensors via the tensor data channel (e.g., NCCL). + 2) Reads the serialized non-tensor data. + 3) Deserializes the non-tensor data. During deserialization, replaces + all found placeholders with the received torch.Tensors. + + If _direct_return=True was specified, then we skip step (2) and (3) and + directly return the data received in (1). + """ + self.ensure_registered_as_reader() + + # If the reader is the same actor as the writer, then we can use the + # local channel to read the data. + reader = utils.get_self_actor() + if reader == self._writer: + assert self._local_channel is not None + return self._local_channel.read() + + # First, read the tensor data. + tensors = self._gpu_data_channel.read(timeout) + + if self._cpu_data_channel is None: + # Handle _direct_return=True. In this case, we expect to receive + # only one tensor, and we return it directly. + assert len(tensors) == 1 + data = tensors[0] + else: + data = self._recv_cpu_and_gpu_data(tensors, timeout) + + return data + + def close(self) -> None: + self._gpu_data_channel.close() + if self._cpu_data_channel is not None: + self._cpu_data_channel.close() + if self._local_channel is not None: + self._local_channel.close() + + +def _torch_zeros_allocator( + shape: Union[int, Tuple[int]], + dtype: "torch.dtype", +): + """ + Allocate a zeros tensor buffer matching the given metadata. + """ + import torch + + ctx = ChannelContext.get_current() + return torch.zeros(shape, dtype=dtype, device=ctx.torch_device) + + +class _TorchTensorNcclChannel(ChannelInterface): + def __init__( + self, + writer: ray.actor.ActorHandle, + reader_and_node_list: List[Tuple["ray.actor.ActorHandle", str]], + typ: "TorchTensorType", + _meta_channel: Optional["Channel"] = None, + ): + """ + A helper channel for TorchTensorNcclChannel that is used to transfer + lists of torch.Tensors via NCCL. This class can only transfer + torch.Tensors and cannot transfer other CPU data, such as Exception + objects or tensors nested inside of a dictionary. + + Args: + writer: The actor that may write to the channel. None signifies the driver. 
+            reader_and_node_list: A list of tuples, where each tuple contains a reader
+                actor handle and the node ID where the actor is located.
+            typ: Type information about the values passed through the channel.
+            _meta_channel: A channel used to send metadata for the tensors,
+                i.e. shape and dtype. If not provided, and if the typ does not
+                specify a static shape and dtype, then a metadata channel based
+                on shared memory will be created.
+        """
+        import torch
+
+        self.torch: ModuleType = torch
+
+        self._writer = writer
+        self._writer_rank: Optional[int] = None
+        self._reader_and_node_list = reader_and_node_list
+        self._reader_ranks: Optional[List[int]] = None
+        self._writer_registered: bool = False
+        self._reader_registered: bool = False
+
+        ctx = ChannelContext.get_current()
+        assert isinstance(
+            typ.communicator_id, str
+        ), f"NCCL group ID ({typ.communicator_id}) must be a str."
+        self._typ = typ
+
+        assert self._typ.communicator_id is not None, "No NCCL group specified."
+        self._nccl_group_id: str = self._typ.communicator_id
+        self._nccl_group: "Communicator" = ctx.communicators[self._typ.communicator_id]
+        assert (
+            self._nccl_group is not None
+        ), "ChannelContext.nccl_group is not initialized."
+
+        self._static_shape = typ.static_shape
+
+        self._writer_rank = self._nccl_group.get_rank(self._writer)
+        self._reader_ranks = [
+            self._nccl_group.get_rank(reader)
+            for reader, _ in self._reader_and_node_list
+        ]
+
+        if (
+            self._writer_rank is not None
+            and self._writer_rank == self._nccl_group.get_self_rank()
+        ):
+            self._writer_registered = True
+
+        if (
+            self._reader_ranks
+            and self._nccl_group.get_self_rank() in self._reader_ranks
+        ):
+            self._reader_registered = True
+
+        # If the channel type specifies that the tensor shape is static, then the
+        # receiver can allocate buffers without needing to coordinate with the
+        # sender. We set the metadata on the first send-recv op. Thereafter,
+        # the sender must ensure that sent tensors match this metadata, and the
+        # receiver will allocate tensors with this shape.
+        self._static_tensor_metadata: Optional[List[_TorchTensorMetadata]] = None
+        self._meta_channel: Optional[Channel] = _meta_channel
+        if self._meta_channel is None and self._writer_registered:
+            # We are the writer. Therefore, we also need to allocate a metadata
+            # channel that will be used to send the shape and dtype of the
+            # tensor to the receiver(s).
+            metadata_type = SharedMemoryType()
+            self._meta_channel = metadata_type.create_channel(
+                self._writer,
+                self._reader_and_node_list,
+                None,
+            )
+
+    def ensure_registered_as_writer(self):
+        assert self._nccl_group is not None, "Actor is not part of an NCCL group"
+        assert self._writer_registered
+        ctx = ChannelContext.get_current()
+        assert ctx.torch_device.type == "cuda"
+
+    def ensure_registered_as_reader(self) -> None:
+        assert self._nccl_group is not None, "Actor is not part of an NCCL group"
+        assert self._reader_registered
+        ctx = ChannelContext.get_current()
+        assert ctx.torch_device.type == "cuda"
+
+    def __reduce__(self):
+        return (
+            self.__class__,
+            (
+                self._writer,
+                self._reader_and_node_list,
+                self._typ,
+                self._meta_channel,
+            ),
+        )
+
+    def _get_send_tensors_metadata(
+        self, tensors: List["torch.Tensor"]
+    ) -> Optional[List[_TorchTensorMetadata]]:
+        """
+        Helper method to get the metadata that should be sent to the reader so
+        that they can allocate the proper-sized buffer(s). Raises an error if
+        static_shape=True was set and the given tensors do not match the
+        inferred shapes.
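+
+        For example (illustrative): sending
+        [torch.zeros(4, 8, dtype=torch.float16)] yields
+        [_TorchTensorMetadata(shape=torch.Size([4, 8]), dtype=torch.float16)]
+        on the first call; with static_shape=True, later calls return None.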
+ + Returns: The metadata to send to the reader. None means that we should + not send any metadata message to the reader. + """ + ctx = ChannelContext.get_current() + + # TODO(swang): Currently any exceptions thrown during this method are + # fatal for the DAG because there is no way for the receiver to receive + # the exception. This can be improved by sending the exception through + # the CPU-based non-tensor-data channel, if one exists. The tensor + # channel can send empty data alongside the exception to avoid hanging. + + # Get the shape and dtype of each tensor to send. + metadata_list = [] + for tensor in tensors: + # Basic type checking. + if not isinstance(tensor, self.torch.Tensor): + raise ValueError("Task must return torch.Tensors") + + if tensor.device != ctx.torch_device: + raise ValueError( + f"torch.Tensor must be on the default device: {ctx.torch_device}" + ) + + metadata = _TorchTensorMetadata(tensor.shape, tensor.dtype) + metadata_list.append(metadata) + + if self._static_tensor_metadata is not None: + if metadata_list != self._static_tensor_metadata: + metadata_str = [ + f"(shape={m.shape}, dtype={m.dtype})" for m in metadata_list + ] + expected_str = [ + f"(shape={m.shape}, dtype={m.dtype})" + for m in self._static_tensor_metadata + ] + raise ValueError( + "Expected torch.Tensors with shapes and dtypes: " + "[" + ", ".join(expected_str) + "], " + "found: [" + ", ".join(metadata_str) + "]. " + "DAG will shut down." + ) + # The receiver has already determined the shape and dtype of the + # tensors from a previous send, so no need to send the metadata + # again. + return None + + if self._static_shape: + # The shape and dtype is static. This is the first send op and + # afterwards, a ValueError will be thrown if the sent tensors do + # not match this metadata. + self._static_tensor_metadata = metadata_list + return metadata_list + + def write( + self, + tensors: List["torch.Tensor"], + timeout: Optional[float] = None, + ): + """ + Write a list of tensors via NCCL: + + 1) Send the tensor metadata, i.e. the shape and dtypes of all tensors + via the shared-memory metadata channel. + 2) Send the tensor data via NCCL. + + If static_shape=True was set, then we only perform step (1) on the + first message. The reader is expected to reuse the sent metadata for + subsequent messages. + """ + self.ensure_registered_as_writer() + + import torch + + for tensor in tensors: + assert isinstance( + tensor, torch.Tensor + ), f"{tensor} must be instance of torch.Tensor" + + # Send the tensors metadata so that the receiver knows what buffers to + # allocate. + metadata = self._get_send_tensors_metadata(tensors) + if metadata is not None: + self._meta_channel.write(metadata) + + # NOTE(swang): We must send the metadata *before* launching the NCCL + # send. We are using blocking NCCL ops, so the following calls will + # block until the kernel has been enqueued. Also, peers must launch the + # kernel together before either can proceed. Therefore, we send the + # metadata first so that the receiver can read the metadata and then + # launch the same NCCL op. + for tensor in tensors: + # TODO: If there are multiple readers, can replace with a + # broadcast. + for rank in self._reader_ranks: + self._nccl_group.send(tensor, rank) + + def _get_recv_tensors_metadata( + self, timeout: Optional[float] = None + ) -> List[_TorchTensorMetadata]: + """ + Get the shape(s) and dtype(s) of the tensors to receive from the + metadata channel. 
If static_shape=True was set, then we reuse the first
+        metadata received.
+        """
+        if self._static_tensor_metadata is not None:
+            return self._static_tensor_metadata
+
+        meta = self._meta_channel.read(timeout)
+
+        if self._static_shape:
+            self._static_tensor_metadata = meta
+
+        return meta
+
+    def read(
+        self,
+        timeout: Optional[float] = None,
+    ) -> Union["torch.Tensor", List["torch.Tensor"]]:
+        """
+        Receive a list of tensors.
+
+        (1) Receive the tensor metadata via the shared-memory metadata channel.
+        (2) Allocate buffers on our default device according to the received
+        tensor metadata.
+        (3) Receive the tensor data via NCCL.
+
+        If static_shape=True was set, then we only perform step (1) on the first
+        message. Subsequent messages reuse the same metadata.
+
+        NOTE: Currently `timeout` only applies to receiving the CPU-based
+        tensor metadata. The GPU recv may exceed the timeout without throwing
+        an error.
+        """
+        self.ensure_registered_as_reader()
+
+        meta_list: List[_TorchTensorMetadata] = self._get_recv_tensors_metadata(
+            timeout
+        )
+
+        bufs: List["torch.Tensor"] = []
+        for meta in meta_list:
+            buf = self._nccl_group.recv(
+                meta.shape, meta.dtype, self._writer_rank, _torch_zeros_allocator
+            )
+            bufs.append(buf)
+        # TODO: Sync CUDA stream after receiving all tensors, instead of after
+        # each tensor.
+        return bufs
+
+    def close(self) -> None:
+        self._meta_channel.close()
+
+        self._nccl_group.destroy()
+        ctx = ChannelContext.get_current()
+        if self._nccl_group_id in ctx.communicators:
+            del ctx.communicators[self._nccl_group_id]
+
+
+def _do_init_communicator(
+    self,
+    group_id,
+    world_size,
+    comm_id,
+    rank,
+    actor_handles,
+    use_communication_streams,
+    custom_communicator: Optional[Communicator] = None,
+):
+    import torch
+
+    if not custom_communicator:
+        assert (
+            ray.get_gpu_ids()
+        ), "Actors participating in NCCL group must have at least one GPU assigned"
+
+    ctx = ChannelContext.get_current()
+    if custom_communicator is not None:
+        custom_communicator.initialize(rank)
+        ctx.communicators[group_id] = custom_communicator
+    else:
+        # Default to NcclGroup.
+        ctx.communicators[group_id] = _NcclGroup(
+            world_size,
+            comm_id,
+            rank,
+            actor_handles,
+            torch.cuda.current_stream().cuda_stream,
+            use_communication_streams,
+        )
+
+
+def _do_destroy_communicator(self, group_id):
+    ctx = ChannelContext.get_current()
+    if group_id not in ctx.communicators:
+        return
+    ctx.communicators[group_id].destroy()
+
+    # Keep the NCCL group in the map after destruction in case there is still a
+    # task loop running.
+
+
+def _do_check_has_gpu(self) -> bool:
+    return bool(ray.get_gpu_ids())
+
+
+def _do_get_unique_nccl_id(self):
+    from cupy.cuda import nccl
+
+    return nccl.get_unique_id()
+
+
+def _get_ranks(
+    actors: List[ray.actor.ActorHandle], custom_nccl_group: Optional[Communicator]
+) -> List[int]:
+    """
+    Get ranks for the NCCL group to use. If custom_nccl_group is specified,
+    return the ranks of the actors in the custom NCCL group, in the same
+    order as the actors; otherwise, return list(range(len(actors))).
+
+    Args:
+        actors: A list of actors that participate in the NCCL group.
+        custom_nccl_group: The custom NCCL group to use.
+    """
+    if custom_nccl_group is None:
+        return list(range(len(actors)))
+
+    assert len(actors) == custom_nccl_group.get_world_size(), (
+        "The world size of the custom NCCL group "
+        f"({custom_nccl_group.get_world_size()}) "
+        "does not match the number of actors "
+        f"({len(actors)})."
+    )
+    ranks = []
+    for actor in actors:
+        rank = custom_nccl_group.get_rank(actor)
+        assert rank not in ranks, "Duplicate rank in custom NCCL group"
+        ranks.append(rank)
+    return ranks
+
+
+def _init_communicator(
+    actors: List[ray.actor.ActorHandle],
+    custom_communicator: Optional[Communicator] = None,
+    use_communication_streams: bool = False,
+) -> str:
+    """
+    Initialize an NCCL group with the given actors. If a custom NCCL group is
+    provided, then it will be used; otherwise, a new NCCL group will be created.
+
+    Args:
+        actors: A list of actors that participate in the NCCL group.
+        custom_communicator: A custom NCCL group to initialize.
+        use_communication_streams: Whether to use dedicated send and recv
+            streams for communication. If True, communication and computation
+            can be overlapped to improve performance.
+    """
+    ctx = ChannelContext.get_current()
+
+    is_cpu_communicator = custom_communicator and isinstance(
+        custom_communicator, CPUCommunicator
+    )
+
+    has_gpus = ray.get(
+        [actor.__ray_call__.remote(_do_check_has_gpu) for actor in actors]
+    )
+    for has_gpu, actor in zip(has_gpus, actors):
+        if not has_gpu and not is_cpu_communicator:
+            raise ValueError(
+                f"Actor {actor} returns a tensor with type hint "
+                'TorchTensor(transport="nccl") or '
+                "TorchTensor(transport=nccl_group_handle), "
+                "but the actor does not have a GPU assigned by Ray."
+            )
+
+    actor_ids = {actor._ray_actor_id for actor in actors}
+    assert len(actor_ids) == len(actors), "Actors must be unique"
+
+    # Allocate a communicator ID on one of the actors that will participate in
+    # the group. This is in case the driver is not on the same node as one of
+    # the NCCL actors.
+    nccl_comm_id = (
+        ray.get(actors[0].__ray_call__.remote(_do_get_unique_nccl_id))
+        if not is_cpu_communicator
+        else str(uuid.uuid4())
+    )
+    # Used to uniquely identify this NCCL group.
+    group_id = str(uuid.uuid4())
+
+    if custom_communicator is not None:
+        logger.info(f"Initializing custom NCCL group {group_id} on actors: {actors}")
+    else:
+        logger.info(f"Creating NCCL group {group_id} on actors: {actors}")
+
+    world_size = len(actors)
+    ranks = _get_ranks(actors, custom_communicator)
+    init_tasks = [
+        actor.__ray_call__.remote(
+            _do_init_communicator,
+            group_id,
+            world_size,
+            nccl_comm_id,
+            rank,
+            actors,
+            use_communication_streams,
+            custom_communicator,
+        )
+        for rank, actor in zip(ranks, actors)
+    ]
+    try:
+        ray.get(init_tasks, timeout=30)
+    except ray.exceptions.GetTimeoutError:
+        logger.warning(
+            "NCCL group creation not done after 30s. NCCL group creation may be hung."
+        )
+        ray.get(init_tasks)
+
+    logger.info("NCCL group initialized.")
+
+    if custom_communicator is not None:
+        ctx.communicators[group_id] = custom_communicator
+    else:
+        ctx.communicators[group_id] = _NcclGroup(
+            world_size,
+            nccl_comm_id,
+            rank=None,
+            actor_handles=actors,
+            cuda_stream=None,
+        )
+    return group_id
+
+
+def _destroy_communicator(group_id: str) -> None:
+    """
+    Destroy the NCCL group with the given ID.
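+
+    Example (illustrative sketch; ``a`` and ``b`` are hypothetical actor
+    handles that participate in the group):
+
+        group_id = _init_communicator([a, b])
+        ...  # run compiled-DAG work that uses the group
+        _destroy_communicator(group_id)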
+ """ + ctx = ChannelContext.get_current() + if group_id not in ctx.communicators: + return + + group = ctx.communicators[group_id] + actors = group.get_actor_handles() + destroy_tasks = [ + actor.__ray_call__.remote( + _do_destroy_communicator, + group_id, + ) + for actor in actors + ] + + _, unready = ray.wait(destroy_tasks, timeout=30, num_returns=len(destroy_tasks)) + if unready: + logger.warning( + "NCCL group destruction not done after 30s. NCCL group destruction " + "may be hung." + ) + + del ctx.communicators[group_id] diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/channel/torch_tensor_type.py b/.venv/lib/python3.11/site-packages/ray/experimental/channel/torch_tensor_type.py new file mode 100644 index 0000000000000000000000000000000000000000..4d5db286865cb0e49f085dddccd4c5ba7a916df0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/experimental/channel/torch_tensor_type.py @@ -0,0 +1,180 @@ +import logging +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import ray +from ray.experimental.channel import ChannelContext, ChannelOutputType +from ray.experimental.channel.communicator import Communicator +from ray.experimental.channel.shared_memory_channel import SharedMemoryType +from ray.util.annotations import PublicAPI + +if TYPE_CHECKING: + from ray.experimental.channel.shared_memory_channel import Channel + +logger = logging.getLogger(__name__) + + +@PublicAPI(stability="alpha") +class TorchTensorType(ChannelOutputType): + AUTO = "auto" + NCCL = "nccl" + CPU = "cpu" + + def __init__( + self, + transport: Optional[Union[str, Communicator]] = AUTO, + _static_shape: bool = False, + _direct_return: Optional[bool] = False, + ): + """ + A type hint that can be used to annotate DAG nodes that return a + torch.Tensor. + + NOTE: Use of this type in the DAG will register a custom serializer for + torch.Tensor that moves the tensor to the correct device on the + receiver. If you are using ray.cloudpickle to serialize objects and you + do not want this behavior, deregister the custom serializer using + ray.util.serialization.deregister_serializer(torch.Tensor). + + Args: + transport: "auto" (default) means that tensors will be passed via + host memory, using numpy as the serialization format. Pass + TorchTensorType.NCCL or "nccl" to use NCCL instead, avoiding + the host memory copy. + _static_shape: A hint indicating whether the shape(s) and dtype(s) + of tensor(s) contained in this value always remain the same + across different executions of the DAG. + _direct_return: Whether the tensor is sent directly or inside of + other data. If a non-default `transport` is used, this allows + the sender and receiver to eliminate performance overhead from + an additional data transfer. + + NOTE: Setting static_shape=True and _direct_return=True can improve + performance if a non-default transport is used. However, if either flag + is set, then the user must ensure that the condition is met. + + If using this type as a Compiled Graph annotation, an exception will + be thrown in the following cases, and the DAG will be torn down. To + continue execution, a new DAG must be created: + 1. If _static_shape=True, and the found tensors don't match the + previous shape or dtype(s). + 2. If _direct_return=True, and the returned value is not a + torch.Tensor. 
+ """ + super().__init__() + + self._static_shape = _static_shape + self._direct_return = _direct_return + + self._communicator: Optional[Communicator] = None + if isinstance(transport, Communicator): + self._communicator = transport + transport = transport.get_transport_name() + + if transport not in [self.AUTO, self.NCCL, self.CPU]: + raise ValueError( + "`transport` must be TorchTensorType.AUTO, TorchTensorType.NCCL, " + "or TorchTensorType.CPU" + ) + self.transport = transport + + self._communicator_id: Optional[str] = None + + if self._static_shape and self.transport == self.AUTO: + logger.info( + "TorchTensorType(_static_shape=True) has no effect when " + "`transport` is TorchTensorType.AUTO (default)." + ) + if self._direct_return and self.transport == self.AUTO: + logger.info( + "TorchTensorType(_direct_return=True) has no effect when " + "`transport` is TorchTensorType.AUTO (default)." + ) + + @property + def static_shape(self): + return self._static_shape + + @property + def direct_return(self): + return self._direct_return + + def register_custom_serializer(self) -> None: + super().register_custom_serializer() + + import torch + + def serialize(t): + ctx = ChannelContext.get_current() + return ctx.serialization_context.serialize_tensor(t) + + def deserialize(b): + ctx = ChannelContext.get_current() + return ctx.serialization_context.deserialize_tensor(b) + + ray.util.serialization.register_serializer( + torch.Tensor, + serializer=serialize, + deserializer=deserialize, + ) + + def set_contains_type(self, typ: "ChannelOutputType") -> None: + raise ValueError("TorchTensorType cannot contain other types") + + def create_channel( + self, + writer: Optional["ray.actor.ActorHandle"], + reader_and_node_list: List[Tuple["ray.actor.ActorHandle", str]], + driver_actor_id: Optional[str] = None, + _cpu_data_channel: Optional["Channel"] = None, + _tensor_metadata_channel: Optional["Channel"] = None, + ) -> type: + if self.requires_nccl(): + from ray.experimental.channel.torch_tensor_nccl_channel import ( + TorchTensorNcclChannel, + ) + + return TorchTensorNcclChannel( + writer, + reader_and_node_list, + self, + driver_actor_id, + _tensor_metadata_channel, + _cpu_data_channel, + ) + + # Data does not require NCCL. Transfer via host memory using a + # shared-memory channel. + # TODO(swang): Allow the initial max buffer size to be overridden. + typ = SharedMemoryType() + return typ.create_channel(writer, reader_and_node_list, driver_actor_id) + + def requires_nccl(self) -> bool: + return self.transport == self.NCCL + + def get_custom_communicator(self) -> Optional[Communicator]: + """ + Return the NCCL group if one is specified. + """ + return self._communicator + + def set_communicator_id(self, group_id: str) -> None: + self._communicator_id = group_id + + @property + def communicator_id(self) -> Optional[str]: + return self._communicator_id + + def __deepcopy__(self, memo): + """ + Deep copy all the fields except for the NCCL group. The NCCL group + should not be deep copied because it can be shared across + `TorchTensorType` instances. 
+ """ + copy = TorchTensorType( + transport=self.transport, + _static_shape=self._static_shape, + _direct_return=self._direct_return, + ) + copy._communicator = self._communicator + copy._communicator_id = self._communicator_id + return copy diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/channel/utils.py b/.venv/lib/python3.11/site-packages/ray/experimental/channel/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..88560b3bc1c879e6dfe3089dad38d948aeb16443 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/experimental/channel/utils.py @@ -0,0 +1,43 @@ +from typing import List, Optional, Tuple + +import ray + + +def get_self_actor() -> Optional["ray.actor.ActorHandle"]: + """ + Get the current actor handle in this worker. + If this is called in a driver process, it will return None. + """ + try: + return ray.get_runtime_context().current_actor + except RuntimeError: + return None + + +def split_readers_by_locality( + writer: "ray.actor.ActorHandle", + reader_and_node_list: List[Tuple["ray.actor.ActorHandle", str]], +) -> Tuple[ + List[Tuple["ray.actor.ActorHandle", str]], List[Tuple["ray.actor.ActorHandle", str]] +]: + """Split readers into remote and local readers based on writer. + + Args: + writer: The actor handle of the writer + reader_and_node_list: List of (reader, node) tuples + + Returns: + Tuple containing: + - List of (reader, node) tuples for remote readers + - List of (reader, node) tuples for local readers + """ + remote_readers = [] + local_readers = [] + + for reader, node in reader_and_node_list: + if reader != writer: + remote_readers.append((reader, node)) + else: + local_readers.append((reader, node)) + + return remote_readers, local_readers diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/collective/__init__.py b/.venv/lib/python3.11/site-packages/ray/experimental/collective/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e824152b7473a6debd663be5f37e2d317671b7b5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/experimental/collective/__init__.py @@ -0,0 +1,3 @@ +from ray.experimental.collective.allreduce import allreduce + +__all__ = ["allreduce"] diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/collective/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/collective/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac71627544be8226ae105ded2dc97db0d561fd8a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/collective/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/collective/__pycache__/allreduce.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/collective/__pycache__/allreduce.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9bef73d9662813e7deab1c2e86023e0bbfffc1df Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/collective/__pycache__/allreduce.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/collective/__pycache__/conftest.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/collective/__pycache__/conftest.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26e5578da50b3bc60d5ae124bde05abdab8a5ce3 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/experimental/collective/__pycache__/conftest.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/collective/allreduce.py b/.venv/lib/python3.11/site-packages/ray/experimental/collective/allreduce.py new file mode 100644 index 0000000000000000000000000000000000000000..83e85d780db0641c3ec5bd524164da70eafbfcb4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/experimental/collective/allreduce.py @@ -0,0 +1,92 @@ +import logging +from typing import List, Optional, Union + +import ray +from ray.dag.collective_node import CollectiveOutputNode, _CollectiveOperation +from ray.dag.constants import ( + BIND_INDEX_KEY, + COLLECTIVE_OPERATION_KEY, + PARENT_CLASS_NODE_KEY, +) +from ray.experimental.channel.torch_tensor_type import Communicator, TorchTensorType +from ray.experimental.util.types import ReduceOp +from ray.util.collective.types import ReduceOp as RayReduceOp + +# TODO(wxdeng): Unify `ReduceOp` and `RayReduceOp`. Directly importing `RayReduceOp` +# has dependency issues for some tests. + +logger = logging.getLogger(__name__) + + +class AllReduceWrapper: + """Wrapper for NCCL all-reduce.""" + + def bind( + self, + input_nodes: List["ray.dag.DAGNode"], + op: ReduceOp = ReduceOp.SUM, + transport: Optional[Union[str, Communicator]] = None, + ) -> List[CollectiveOutputNode]: + """ + Bind input nodes with a collective operation. The collective operation is + directly applied to the torch tensors from the input nodes. The output nodes + are the results of the collective operation in the same torch tensors. + + Requirements: + 1. Each input node returns a torch tensor. + 2. Each input node is from a different actor. + 3. If a custom transport is specified, its actor set matches the actor set + of the input nodes. + 4. All tensors have the same shape. + + Requirements 1-3 are checked in the `CollectiveGroup` constructor. + Requirement 4 is not checked yet. + + Args: + input_nodes: A list of DAG nodes. + op: The collective operation. + transport: GPU communicator for the collective operation. If not + specified, the default NCCL is used. + + Returns: + A list of collective output nodes. 
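+
+        Example (illustrative sketch; ``w1`` and ``w2`` are hypothetical GPU
+        actors and ``inp`` is an InputNode inside a DAG definition):
+
+            t1 = w1.compute.bind(inp)
+            t2 = w2.compute.bind(inp)
+            r1, r2 = allreduce.bind([t1, t2], ReduceOp.SUM)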
+ """ + if transport is None: + transport = TorchTensorType.NCCL + collective_op = _CollectiveOperation(input_nodes, op, transport) + collective_output_nodes: List[CollectiveOutputNode] = [] + + for input_node in input_nodes: + actor_handle: Optional[ + "ray.actor.ActorHandle" + ] = input_node._get_actor_handle() + if actor_handle is None: + raise ValueError("Expected an actor handle from the input node") + collective_output_node = CollectiveOutputNode( + method_name=f"allreduce.{op}", + method_args=(input_node,), + method_kwargs=dict(), + method_options=dict(), + other_args_to_resolve={ + PARENT_CLASS_NODE_KEY: actor_handle, + BIND_INDEX_KEY: actor_handle._ray_dag_bind_index, + COLLECTIVE_OPERATION_KEY: collective_op, + }, + ) + actor_handle._ray_dag_bind_index += 1 + collective_output_nodes.append(collective_output_node) + + return collective_output_nodes + + def __call__( + self, + tensor, + group_name: str = "default", + op: RayReduceOp = RayReduceOp.SUM, + ): + from ray.util.collective.collective import allreduce + + return allreduce(tensor, group_name, op) + + +allreduce = AllReduceWrapper() diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/collective/conftest.py b/.venv/lib/python3.11/site-packages/ray/experimental/collective/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..70013d600018b4e26174df4bc342676e295cd0ca --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/experimental/collective/conftest.py @@ -0,0 +1,253 @@ +import copy +import uuid +from typing import Dict, FrozenSet, List, Optional, Set, Tuple + +import torch + +import ray +from ray.experimental.channel.common import ChannelContext +from ray.experimental.channel.communicator import ( + Communicator, + ReduceOp, + TorchTensorAllocator, +) + + +class AbstractNcclGroup(Communicator): + """ + A dummy NCCL group for testing. + """ + + import cupy as cp + + def __init__(self, actor_handles: List[ray.actor.ActorHandle]): + self._actor_handles = actor_handles + self._rank = None + + def initialize(self, rank: int) -> None: + self._rank = rank + + def get_rank(self, actor: ray.actor.ActorHandle) -> int: + return self._actor_handles.index(actor) + + def get_world_size(self) -> int: + return len(self._actor_handles) + + def get_self_rank(self) -> Optional[int]: + return self._rank + + def get_actor_handles(self) -> List["ray.actor.ActorHandle"]: + return self._actor_handles + + def send(self, value: "torch.Tensor", peer_rank: int) -> None: + raise NotImplementedError + + def recv( + self, + shape: Tuple[int], + dtype: "torch.dtype", + peer_rank: int, + allocator: Optional[TorchTensorAllocator] = None, + ) -> "torch.Tensor": + raise NotImplementedError + + def allreduce( + self, + send_buf: "torch.Tensor", + recv_buf: "torch.Tensor", + op: ReduceOp = ReduceOp.SUM, + ) -> None: + raise NotImplementedError + + @property + def recv_stream(self) -> Optional["cp.cuda.ExternalStream"]: + return None + + @property + def send_stream(self) -> Optional["cp.cuda.ExternalStream"]: + return None + + def destroy(self) -> None: + pass + + def get_transport_name(self) -> str: + return "nccl" + + +class MockNcclGroupSet: + def __init__(self): + # Represents a mapping from a NCCL group ID to a set of actors and a custom + # NCCL group. 
+ self.ids_to_actors_and_custom_comms: Dict[ + str, Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[Communicator]] + ] = {} + + def __call__( + self, + actors: List["ray.actor.ActorHandle"], + custom_nccl_group: Optional[Communicator] = None, + use_communication_streams: bool = False, + ) -> str: + group_id = str(uuid.uuid4()) + self.ids_to_actors_and_custom_comms[group_id] = ( + frozenset(actors), + custom_nccl_group, + ) + + if custom_nccl_group is None: + ranks = list(range(len(actors))) + else: + ranks = [custom_nccl_group.get_rank(actor) for actor in actors] + init_tasks = [ + actor.__ray_call__.remote( + mock_do_init_nccl_group, + group_id, + rank, + actors, + custom_nccl_group, + ) + for rank, actor in zip(ranks, actors) + ] + ray.get(init_tasks, timeout=30) + + ctx = ChannelContext.get_current() + if custom_nccl_group is not None: + ctx.communicators[group_id] = custom_nccl_group + else: + ctx.communicators[group_id] = AbstractNcclGroup(actors) + + return group_id + + def mock_destroy_nccl_group(self, group_id: str) -> None: + ctx = ChannelContext.get_current() + if group_id not in ctx.communicators: + return + + actors, _ = self.ids_to_actors_and_custom_comms[group_id] + destroy_tasks = [ + actor.__ray_call__.remote( + mock_do_destroy_nccl_group, + group_id, + ) + for actor in actors + ] + ray.wait(destroy_tasks, timeout=30) + + if group_id in self.ids_to_actors_and_custom_comms: + del self.ids_to_actors_and_custom_comms[group_id] + ctx.communicators[group_id].destroy() + del ctx.communicators[group_id] + + def check_init( + self, + compiled_dag: "ray.dag.CompiledDAG", + actors_and_custom_comms: Set[ + Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[Communicator]] + ], + p2p_actors_and_custom_comm: Optional[ + Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[Communicator]] + ], + ) -> None: + assert len(self.ids_to_actors_and_custom_comms) == len(actors_and_custom_comms) + assert ( + set(self.ids_to_actors_and_custom_comms.values()) == actors_and_custom_comms + ) + + nccl_group_id_p2p = compiled_dag.communicator_id_p2p + if p2p_actors_and_custom_comm is None: + assert nccl_group_id_p2p is None + else: + assert nccl_group_id_p2p + assert ( + self.ids_to_actors_and_custom_comms[nccl_group_id_p2p] + == p2p_actors_and_custom_comm + ) + + def check_teardown(self, nccl_group_ids: List[str]) -> None: + ctx = ChannelContext.get_current() + for nccl_group_id in nccl_group_ids: + assert nccl_group_id not in self.ids_to_actors_and_custom_comms + assert nccl_group_id not in ctx.communicators + + +@ray.remote +class CPUTorchTensorWorker: + def __init__(self): + self.device = "cpu" + + def return_tensor(self, size: int) -> torch.Tensor: + return torch.ones(size, device=self.device) + + def recv(self, tensor: torch.Tensor) -> Tuple[int, int]: + assert tensor.device == self.device + return tensor.shape, tensor[0] + + +def mock_do_init_nccl_group( + self, + group_id: str, + rank: int, + actors: List[ray.actor.ActorHandle], + custom_nccl_group: Optional[Communicator], +) -> None: + ctx = ChannelContext.get_current() + if custom_nccl_group is None: + nccl_group = AbstractNcclGroup(actors) + nccl_group.initialize(rank) + ctx.communicators[group_id] = nccl_group + else: + custom_nccl_group.initialize(rank) + ctx.communicators[group_id] = custom_nccl_group + + +def mock_do_destroy_nccl_group(self, group_id: str) -> None: + ctx = ChannelContext.get_current() + if group_id not in ctx.communicators: + return + ctx.communicators[group_id].destroy() + del ctx.communicators[group_id] + + +def 
check_nccl_group_init(
+    monkeypatch,
+    dag: "ray.dag.DAGNode",
+    actors_and_custom_comms: Set[
+        Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[Communicator]]
+    ],
+    p2p_actors_and_custom_comm: Optional[
+        Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[Communicator]]
+    ] = None,
+) -> Tuple["ray.dag.CompiledDAG", MockNcclGroupSet]:
+    mock_nccl_group_set = MockNcclGroupSet()
+    monkeypatch.setattr(
+        "ray.dag.compiled_dag_node._init_communicator",
+        mock_nccl_group_set,
+    )
+    monkeypatch.setattr(
+        "ray.dag.collective_node._init_communicator",
+        mock_nccl_group_set,
+    )
+
+    compiled_dag = dag.experimental_compile()
+    mock_nccl_group_set.check_init(
+        compiled_dag,
+        actors_and_custom_comms,
+        p2p_actors_and_custom_comm,
+    )
+
+    return compiled_dag, mock_nccl_group_set
+
+
+def check_nccl_group_teardown(
+    monkeypatch,
+    compiled_dag: "ray.dag.CompiledDAG",
+    mock_nccl_group_set: MockNcclGroupSet,
+):
+    monkeypatch.setattr(
+        "ray.dag.compiled_dag_node._destroy_communicator",
+        mock_nccl_group_set.mock_destroy_nccl_group,
+    )
+
+    nccl_group_ids = copy.deepcopy(compiled_dag.communicator_ids)
+    compiled_dag.teardown()
+    mock_nccl_group_set.check_teardown(nccl_group_ids)
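+
+
+# Illustrative usage sketch: the two helpers above are meant to be paired in a
+# test. Here `monkeypatch` is the standard pytest fixture, and `dag`, `w1`,
+# `w2` are hypothetical (a compiled-graph DAG built over two worker actors):
+#
+#     def test_nccl_group_lifecycle(monkeypatch, dag, w1, w2):
+#         compiled_dag, mock_set = check_nccl_group_init(
+#             monkeypatch, dag, {(frozenset([w1, w2]), None)},
+#             (frozenset([w1, w2]), None),
+#         )
+#         check_nccl_group_teardown(monkeypatch, compiled_dag, mock_set)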