diff --git a/.venv/lib/python3.11/site-packages/ray/_private/__init__.py b/.venv/lib/python3.11/site-packages/ray/_private/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/_private/arrow_serialization.py b/.venv/lib/python3.11/site-packages/ray/_private/arrow_serialization.py new file mode 100644 index 0000000000000000000000000000000000000000..dfcd0972502c14ea5082793af97107cac8f0f7b6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/arrow_serialization.py @@ -0,0 +1,816 @@ +# arrow_serialization.py must resides outside of ray.data, otherwise +# it causes circular dependency issues for AsyncActors due to +# ray.data's lazy import. +# see https://github.com/ray-project/ray/issues/30498 for more context. +from dataclasses import dataclass +import logging +import os +import sys +from typing import List, Tuple, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + import pyarrow + from ray.data.extensions import ArrowTensorArray + +RAY_DISABLE_CUSTOM_ARROW_JSON_OPTIONS_SERIALIZATION = ( + "RAY_DISABLE_CUSTOM_ARROW_JSON_OPTIONS_SERIALIZATION" +) +RAY_DISABLE_CUSTOM_ARROW_DATA_SERIALIZATION = ( + "RAY_DISABLE_CUSTOM_ARROW_DATA_SERIALIZATION" +) + +logger = logging.getLogger(__name__) + +# Whether we have already warned the user about bloated fallback serialization. +_serialization_fallback_set = set() + +# Whether we're currently running in a test, either local or CI. +_in_test = None + + +def _is_in_test(): + global _in_test + + if _in_test is None: + _in_test = any( + env_var in os.environ + # These environment variables are always set by pytest and Buildkite, + # respectively. 
+ for env_var in ("PYTEST_CURRENT_TEST", "BUILDKITE") + ) + return _in_test + + +def _register_custom_datasets_serializers(serialization_context): + try: + import pyarrow as pa # noqa: F401 + except ModuleNotFoundError: + # No pyarrow installed so not using Arrow, so no need for custom serializers. + return + + # Register all custom serializers required by Datasets. + _register_arrow_data_serializer(serialization_context) + _register_arrow_json_readoptions_serializer(serialization_context) + _register_arrow_json_parseoptions_serializer(serialization_context) + + +# Register custom Arrow JSON ReadOptions serializer to workaround it not being picklable +# in Arrow < 8.0.0. +def _register_arrow_json_readoptions_serializer(serialization_context): + if ( + os.environ.get( + RAY_DISABLE_CUSTOM_ARROW_JSON_OPTIONS_SERIALIZATION, + "0", + ) + == "1" + ): + return + + import pyarrow.json as pajson + + serialization_context._register_cloudpickle_serializer( + pajson.ReadOptions, + custom_serializer=lambda opts: (opts.use_threads, opts.block_size), + custom_deserializer=lambda args: pajson.ReadOptions(*args), + ) + + +def _register_arrow_json_parseoptions_serializer(serialization_context): + if ( + os.environ.get( + RAY_DISABLE_CUSTOM_ARROW_JSON_OPTIONS_SERIALIZATION, + "0", + ) + == "1" + ): + return + + import pyarrow.json as pajson + + serialization_context._register_cloudpickle_serializer( + pajson.ParseOptions, + custom_serializer=lambda opts: ( + opts.explicit_schema, + opts.newlines_in_values, + opts.unexpected_field_behavior, + ), + custom_deserializer=lambda args: pajson.ParseOptions(*args), + ) + + +# Register custom Arrow data serializer to work around zero-copy slice pickling bug. +# See https://issues.apache.org/jira/browse/ARROW-10739. +def _register_arrow_data_serializer(serialization_context): + """Custom reducer for Arrow data that works around a zero-copy slicing pickling + bug by using the Arrow IPC format for the underlying serialization. 
+ + Background: + Arrow has both array-level slicing and buffer-level slicing; both are zero-copy, + but the former has a serialization bug where the entire buffer is serialized + instead of just the slice, while the latter's serialization works as expected + and only serializes the slice of the buffer. I.e., array-level slicing doesn't + propagate the slice down to the buffer when serializing the array. + + We work around this by registering a custom cloudpickle reducers for Arrow + Tables that delegates serialization to the Arrow IPC format; thankfully, Arrow's + IPC serialization has fixed this buffer truncation bug. + + See https://issues.apache.org/jira/browse/ARROW-10739. + """ + if os.environ.get(RAY_DISABLE_CUSTOM_ARROW_DATA_SERIALIZATION, "0") == "1": + return + + import pyarrow as pa + + serialization_context._register_cloudpickle_reducer(pa.Table, _arrow_table_reduce) + + +def _arrow_table_reduce(t: "pyarrow.Table"): + """Custom reducer for Arrow Tables that works around a zero-copy slice pickling bug. + Background: + Arrow has both array-level slicing and buffer-level slicing; both are zero-copy, + but the former has a serialization bug where the entire buffer is serialized + instead of just the slice, while the latter's serialization works as expected + and only serializes the slice of the buffer. I.e., array-level slicing doesn't + propagate the slice down to the buffer when serializing the array. + All that these copy methods do is, at serialization time, take the array-level + slicing and translate them to buffer-level slicing, so only the buffer slice is + sent over the wire instead of the entire buffer. + See https://issues.apache.org/jira/browse/ARROW-10739. + """ + global _serialization_fallback_set + + # Reduce the ChunkedArray columns. + reduced_columns = [] + for column_name in t.column_names: + column = t[column_name] + try: + # Delegate to ChunkedArray reducer. 
+ reduced_column = _arrow_chunked_array_reduce(column) + except Exception as e: + if not _is_dense_union(column.type) and _is_in_test(): + # If running in a test and the column is not a dense union array + # (which we expect to need a fallback), we want to raise the error, + # not fall back. + raise e from None + if type(column.type) not in _serialization_fallback_set: + logger.warning( + "Failed to complete optimized serialization of Arrow Table, " + f"serialization of column '{column_name}' of type {column.type} " + "failed, so we're falling back to Arrow IPC serialization for the " + "table. Note that this may result in slower serialization and more " + "worker memory utilization. Serialization error:", + exc_info=True, + ) + _serialization_fallback_set.add(type(column.type)) + # Fall back to Arrow IPC-based workaround for the entire table. + return _arrow_table_ipc_reduce(t) + else: + # Column reducer succeeded, add reduced column to list. + reduced_columns.append(reduced_column) + return _reconstruct_table, (reduced_columns, t.schema) + + +def _reconstruct_table( + reduced_columns: List[Tuple[List["pyarrow.Array"], "pyarrow.DataType"]], + schema: "pyarrow.Schema", +) -> "pyarrow.Table": + """Restore a serialized Arrow Table, reconstructing each reduced column.""" + import pyarrow as pa + + # Reconstruct each reduced column. + columns = [] + for chunks_payload, type_ in reduced_columns: + columns.append(_reconstruct_chunked_array(chunks_payload, type_)) + + return pa.Table.from_arrays(columns, schema=schema) + + +def _arrow_chunked_array_reduce( + ca: "pyarrow.ChunkedArray", +) -> Tuple[List["PicklableArrayPayload"], "pyarrow.DataType"]: + """Custom reducer for Arrow ChunkedArrays that works around a zero-copy slice + pickling bug. This reducer does not return a reconstruction function, since it's + expected to be reconstructed by the Arrow Table reconstructor. + """ + # Convert chunks to serialization payloads. 
+ chunk_payloads = [] + for chunk in ca.chunks: + chunk_payload = PicklableArrayPayload.from_array(chunk) + chunk_payloads.append(chunk_payload) + return chunk_payloads, ca.type + + +def _reconstruct_chunked_array( + chunks: List["PicklableArrayPayload"], type_: "pyarrow.DataType" +) -> "pyarrow.ChunkedArray": + """Restore a serialized Arrow ChunkedArray from chunks and type.""" + import pyarrow as pa + + # Reconstruct chunks from serialization payloads. + chunks = [chunk.to_array() for chunk in chunks] + + return pa.chunked_array(chunks, type_) + + +@dataclass +class PicklableArrayPayload: + """Picklable array payload, holding data buffers and array metadata. + + This is a helper container for pickling and reconstructing nested Arrow Arrays while + ensuring that the buffers that underly zero-copy slice views are properly truncated. + """ + + # Array type. + type: "pyarrow.DataType" + # Length of array. + length: int + # Underlying data buffers. + buffers: List["pyarrow.Buffer"] + # Cached null count. + null_count: int + # Slice offset into base array. + offset: int + # Serialized array payloads for nested (child) arrays. + children: List["PicklableArrayPayload"] + + @classmethod + def from_array(self, a: "pyarrow.Array") -> "PicklableArrayPayload": + """Create a picklable array payload from an Arrow Array. + + This will recursively accumulate data buffer and metadata payloads that are + ready for pickling; namely, the data buffers underlying zero-copy slice views + will be properly truncated. 
+ """ + return _array_to_array_payload(a) + + def to_array(self) -> "pyarrow.Array": + """Reconstruct an Arrow Array from this picklable payload.""" + return _array_payload_to_array(self) + + +def _array_payload_to_array(payload: "PicklableArrayPayload") -> "pyarrow.Array": + """Reconstruct an Arrow Array from a possibly nested PicklableArrayPayload.""" + import pyarrow as pa + from ray.air.util.tensor_extensions.arrow import get_arrow_extension_tensor_types + + children = [child_payload.to_array() for child_payload in payload.children] + + tensor_extension_types = get_arrow_extension_tensor_types() + + if pa.types.is_dictionary(payload.type): + # Dedicated path for reconstructing a DictionaryArray, since + # Array.from_buffers() doesn't work for DictionaryArrays. + assert len(children) == 2, len(children) + indices, dictionary = children + return pa.DictionaryArray.from_arrays(indices, dictionary) + elif pa.types.is_map(payload.type) and len(children) > 1: + # In pyarrow<7.0.0, the underlying map child array is not exposed, so we work + # with the key and item arrays. + assert len(children) == 3, len(children) + offsets, keys, items = children + return pa.MapArray.from_arrays(offsets, keys, items) + elif isinstance( + payload.type, + tensor_extension_types, + ): + # Dedicated path for reconstructing an ArrowTensorArray or + # ArrowVariableShapedTensorArray, both of which can't be reconstructed by the + # Array.from_buffers() API. + assert len(children) == 1, len(children) + storage = children[0] + return pa.ExtensionArray.from_storage(payload.type, storage) + else: + # Common case: use Array.from_buffers() to construct an array of a certain type. 
+ return pa.Array.from_buffers( + type=payload.type, + length=payload.length, + buffers=payload.buffers, + null_count=payload.null_count, + offset=payload.offset, + children=children, + ) + + +def _array_to_array_payload(a: "pyarrow.Array") -> "PicklableArrayPayload": + """Serialize an Arrow Array to an PicklableArrayPayload for later pickling. + + This function's primary purpose is to dispatch to the handler for the input array + type. + """ + import pyarrow as pa + + from ray.air.util.tensor_extensions.arrow import get_arrow_extension_tensor_types + + tensor_extension_types = get_arrow_extension_tensor_types() + + if _is_dense_union(a.type): + # Dense unions are not supported. + # TODO(Clark): Support dense unions. + raise NotImplementedError( + "Custom slice view serialization of dense union arrays is not yet " + "supported." + ) + + # Dispatch to handler for array type. + if pa.types.is_null(a.type): + return _null_array_to_array_payload(a) + elif _is_primitive(a.type): + return _primitive_array_to_array_payload(a) + elif _is_binary(a.type): + return _binary_array_to_array_payload(a) + elif pa.types.is_list(a.type) or pa.types.is_large_list(a.type): + return _list_array_to_array_payload(a) + elif pa.types.is_fixed_size_list(a.type): + return _fixed_size_list_array_to_array_payload(a) + elif pa.types.is_struct(a.type): + return _struct_array_to_array_payload(a) + elif pa.types.is_union(a.type): + return _union_array_to_array_payload(a) + elif pa.types.is_dictionary(a.type): + return _dictionary_array_to_array_payload(a) + elif pa.types.is_map(a.type): + return _map_array_to_array_payload(a) + elif isinstance(a.type, tensor_extension_types): + return _tensor_array_to_array_payload(a) + elif isinstance(a.type, pa.ExtensionType): + return _extension_array_to_array_payload(a) + else: + raise ValueError("Unhandled Arrow array type:", a.type) + + +def _is_primitive(type_: "pyarrow.DataType") -> bool: + """Whether the provided Array type is primitive (boolean, numeric, 
temporal or + fixed-size binary).""" + import pyarrow as pa + + return ( + pa.types.is_integer(type_) + or pa.types.is_floating(type_) + or pa.types.is_decimal(type_) + or pa.types.is_boolean(type_) + or pa.types.is_temporal(type_) + or pa.types.is_fixed_size_binary(type_) + ) + + +def _is_binary(type_: "pyarrow.DataType") -> bool: + """Whether the provided Array type is a variable-sized binary type.""" + import pyarrow as pa + + return ( + pa.types.is_string(type_) + or pa.types.is_large_string(type_) + or pa.types.is_binary(type_) + or pa.types.is_large_binary(type_) + ) + + +def _null_array_to_array_payload(a: "pyarrow.NullArray") -> "PicklableArrayPayload": + """Serialize null array to PicklableArrayPayload.""" + # Buffer scheme: [None] + return PicklableArrayPayload( + type=a.type, + length=len(a), + buffers=[None], # Single null buffer is expected. + null_count=a.null_count, + offset=0, + children=[], + ) + + +def _primitive_array_to_array_payload(a: "pyarrow.Array") -> "PicklableArrayPayload": + """Serialize primitive (numeric, temporal, boolean) arrays to + PicklableArrayPayload. + """ + assert _is_primitive(a.type), a.type + # Buffer scheme: [bitmap, data] + buffers = a.buffers() + assert len(buffers) == 2, len(buffers) + + # Copy bitmap buffer, if needed. + bitmap_buf = buffers[0] + if a.null_count > 0: + bitmap_buf = _copy_bitpacked_buffer_if_needed(bitmap_buf, a.offset, len(a)) + else: + bitmap_buf = None + + # Copy data buffer, if needed. + data_buf = buffers[1] + if data_buf is not None: + data_buf = _copy_buffer_if_needed(buffers[1], a.type, a.offset, len(a)) + + return PicklableArrayPayload( + type=a.type, + length=len(a), + buffers=[bitmap_buf, data_buf], + null_count=a.null_count, + offset=0, + children=[], + ) + + +def _binary_array_to_array_payload(a: "pyarrow.Array") -> "PicklableArrayPayload": + """Serialize binary (variable-sized binary, string) arrays to + PicklableArrayPayload. 
+ """ + assert _is_binary(a.type), a.type + # Buffer scheme: [bitmap, value_offsets, data] + buffers = a.buffers() + assert len(buffers) == 3, len(buffers) + + # Copy bitmap buffer, if needed. + if a.null_count > 0: + bitmap_buf = _copy_bitpacked_buffer_if_needed(buffers[0], a.offset, len(a)) + else: + bitmap_buf = None + + # Copy offset buffer, if needed. + offset_buf = buffers[1] + offset_buf, data_offset, data_length = _copy_offsets_buffer_if_needed( + offset_buf, a.type, a.offset, len(a) + ) + data_buf = buffers[2] + data_buf = _copy_buffer_if_needed(data_buf, None, data_offset, data_length) + return PicklableArrayPayload( + type=a.type, + length=len(a), + buffers=[bitmap_buf, offset_buf, data_buf], + null_count=a.null_count, + offset=0, + children=[], + ) + + +def _list_array_to_array_payload(a: "pyarrow.Array") -> "PicklableArrayPayload": + """Serialize list (regular and large) arrays to PicklableArrayPayload.""" + # Dedicated path for ListArrays. These arrays have a nested set of bitmap and + # offset buffers, eventually bottoming out on a data buffer. + # Buffer scheme: + # [bitmap, offsets, bitmap, offsets, ..., bitmap, data] + buffers = a.buffers() + assert len(buffers) > 1, len(buffers) + + # Copy bitmap buffer, if needed. + if a.null_count > 0: + bitmap_buf = _copy_bitpacked_buffer_if_needed(buffers[0], a.offset, len(a)) + else: + bitmap_buf = None + + # Copy offset buffer, if needed. + offset_buf = buffers[1] + offset_buf, child_offset, child_length = _copy_offsets_buffer_if_needed( + offset_buf, a.type, a.offset, len(a) + ) + + # Propagate slice to child. 
+ child = a.values.slice(child_offset, child_length) + + return PicklableArrayPayload( + type=a.type, + length=len(a), + buffers=[bitmap_buf, offset_buf], + null_count=a.null_count, + offset=0, + children=[_array_to_array_payload(child)], + ) + + +def _fixed_size_list_array_to_array_payload( + a: "pyarrow.FixedSizeListArray", +) -> "PicklableArrayPayload": + """Serialize fixed size list arrays to PicklableArrayPayload.""" + # Dedicated path for fixed-size lists. + # Buffer scheme: + # [bitmap, values_bitmap, values_data, values_subbuffers...] + buffers = a.buffers() + assert len(buffers) >= 1, len(buffers) + + # Copy bitmap buffer, if needed. + if a.null_count > 0: + bitmap_buf = _copy_bitpacked_buffer_if_needed(buffers[0], a.offset, len(a)) + else: + bitmap_buf = None + + # Propagate slice to child. + child_offset = a.type.list_size * a.offset + child_length = a.type.list_size * len(a) + child = a.values.slice(child_offset, child_length) + + return PicklableArrayPayload( + type=a.type, + length=len(a), + buffers=[bitmap_buf], + null_count=a.null_count, + offset=0, + children=[_array_to_array_payload(child)], + ) + + +def _struct_array_to_array_payload(a: "pyarrow.StructArray") -> "PicklableArrayPayload": + """Serialize struct arrays to PicklableArrayPayload.""" + # Dedicated path for StructArrays. + # StructArrays have a top-level bitmap buffer and one or more children arrays. + # Buffer scheme: [bitmap, None, child_bitmap, child_data, ...] + buffers = a.buffers() + assert len(buffers) >= 1, len(buffers) + + # Copy bitmap buffer, if needed. + if a.null_count > 0: + bitmap_buf = _copy_bitpacked_buffer_if_needed(buffers[0], a.offset, len(a)) + else: + bitmap_buf = None + + # Get field children payload. + # Offsets and truncations are already propagated to the field arrays, so we can + # serialize them as-is. 
+ children = [_array_to_array_payload(a.field(i)) for i in range(a.type.num_fields)] + return PicklableArrayPayload( + type=a.type, + length=len(a), + buffers=[bitmap_buf], + null_count=a.null_count, + offset=0, + children=children, + ) + + +def _union_array_to_array_payload(a: "pyarrow.UnionArray") -> "PicklableArrayPayload": + """Serialize union arrays to PicklableArrayPayload.""" + import pyarrow as pa + + # Dedicated path for UnionArrays. + # UnionArrays have a top-level bitmap buffer and type code buffer, and one or + # more children arrays. + # Buffer scheme: [None, typecodes, child_bitmap, child_data, ...] + assert not _is_dense_union(a.type) + buffers = a.buffers() + assert len(buffers) > 1, len(buffers) + + bitmap_buf = buffers[0] + assert bitmap_buf is None, bitmap_buf + + # Copy type code buffer, if needed. + type_code_buf = buffers[1] + type_code_buf = _copy_buffer_if_needed(type_code_buf, pa.int8(), a.offset, len(a)) + + # Get field children payload. + # Offsets and truncations are already propagated to the field arrays, so we can + # serialize them as-is. + children = [_array_to_array_payload(a.field(i)) for i in range(a.type.num_fields)] + return PicklableArrayPayload( + type=a.type, + length=len(a), + buffers=[bitmap_buf, type_code_buf], + null_count=a.null_count, + offset=0, + children=children, + ) + + +def _dictionary_array_to_array_payload( + a: "pyarrow.DictionaryArray", +) -> "PicklableArrayPayload": + """Serialize dictionary arrays to PicklableArrayPayload.""" + # Dedicated path for DictionaryArrays. 
+ # Buffer scheme: [indices_bitmap, indices_data] (dictionary stored separately) + indices_payload = _array_to_array_payload(a.indices) + dictionary_payload = _array_to_array_payload(a.dictionary) + return PicklableArrayPayload( + type=a.type, + length=len(a), + buffers=[], + null_count=a.null_count, + offset=0, + children=[indices_payload, dictionary_payload], + ) + + +def _map_array_to_array_payload(a: "pyarrow.MapArray") -> "PicklableArrayPayload": + """Serialize map arrays to PicklableArrayPayload.""" + import pyarrow as pa + + # Dedicated path for MapArrays. + # Buffer scheme: [bitmap, offsets, child_struct_array_buffers, ...] + buffers = a.buffers() + assert len(buffers) > 0, len(buffers) + + # Copy bitmap buffer, if needed. + if a.null_count > 0: + bitmap_buf = _copy_bitpacked_buffer_if_needed(buffers[0], a.offset, len(a)) + else: + bitmap_buf = None + + new_buffers = [bitmap_buf] + + # Copy offsets buffer, if needed. + offset_buf = buffers[1] + offset_buf, data_offset, data_length = _copy_offsets_buffer_if_needed( + offset_buf, a.type, a.offset, len(a) + ) + + if isinstance(a, pa.lib.ListArray): + # Map arrays directly expose the one child struct array in pyarrow>=7.0.0, which + # is easier to work with than the raw buffers. + new_buffers.append(offset_buf) + children = [_array_to_array_payload(a.values.slice(data_offset, data_length))] + else: + # In pyarrow<7.0.0, the child struct array is not exposed, so we work with the + # key and item arrays. + buffers = a.buffers() + assert len(buffers) > 2, len(buffers) + # Reconstruct offsets array. + offsets = pa.Array.from_buffers( + pa.int32(), len(a) + 1, [bitmap_buf, offset_buf] + ) + # Propagate slice to keys. + keys = a.keys.slice(data_offset, data_length) + # Propagate slice to items. 
+ items = a.items.slice(data_offset, data_length) + children = [ + _array_to_array_payload(offsets), + _array_to_array_payload(keys), + _array_to_array_payload(items), + ] + return PicklableArrayPayload( + type=a.type, + length=len(a), + buffers=new_buffers, + null_count=a.null_count, + offset=0, + children=children, + ) + + +def _tensor_array_to_array_payload(a: "ArrowTensorArray") -> "PicklableArrayPayload": + """Serialize tensor arrays to PicklableArrayPayload.""" + # Offset is propagated to storage array, and the storage array items align with the + # tensor elements, so we only need to do the straightforward creation of the storage + # array payload. + storage_payload = _array_to_array_payload(a.storage) + return PicklableArrayPayload( + type=a.type, + length=len(a), + buffers=[], + null_count=a.null_count, + offset=0, + children=[storage_payload], + ) + + +def _extension_array_to_array_payload( + a: "pyarrow.ExtensionArray", +) -> "PicklableArrayPayload": + payload = _array_to_array_payload(a.storage) + payload.type = a.type + payload.length = len(a) + payload.null_count = a.null_count + return payload + + +def _copy_buffer_if_needed( + buf: "pyarrow.Buffer", + type_: Optional["pyarrow.DataType"], + offset: int, + length: int, +) -> "pyarrow.Buffer": + """Copy buffer, if needed.""" + import pyarrow as pa + + if type_ is not None and pa.types.is_boolean(type_): + # Arrow boolean array buffers are bit-packed, with 8 entries per byte, + # and are accessed via bit offsets. 
+ buf = _copy_bitpacked_buffer_if_needed(buf, offset, length) + else: + type_bytewidth = type_.bit_width // 8 if type_ is not None else 1 + buf = _copy_normal_buffer_if_needed(buf, type_bytewidth, offset, length) + return buf + + +def _copy_normal_buffer_if_needed( + buf: "pyarrow.Buffer", + byte_width: int, + offset: int, + length: int, +) -> "pyarrow.Buffer": + """Copy buffer, if needed.""" + byte_offset = offset * byte_width + byte_length = length * byte_width + if offset > 0 or byte_length < buf.size: + # Array is a zero-copy slice, so we need to copy to a new buffer before + # serializing; this slice of the underlying buffer (not the array) will ensure + # that the buffer is properly copied at pickle-time. + buf = buf.slice(byte_offset, byte_length) + return buf + + +def _copy_bitpacked_buffer_if_needed( + buf: "pyarrow.Buffer", + offset: int, + length: int, +) -> "pyarrow.Buffer": + """Copy bit-packed binary buffer, if needed.""" + bit_offset = offset % 8 + byte_offset = offset // 8 + byte_length = _bytes_for_bits(bit_offset + length) // 8 + if offset > 0 or byte_length < buf.size: + buf = buf.slice(byte_offset, byte_length) + if bit_offset != 0: + # Need to manually shift the buffer to eliminate the bit offset. + buf = _align_bit_offset(buf, bit_offset, byte_length) + return buf + + +def _copy_offsets_buffer_if_needed( + buf: "pyarrow.Buffer", + arr_type: "pyarrow.DataType", + offset: int, + length: int, +) -> Tuple["pyarrow.Buffer", int, int]: + """Copy the provided offsets buffer, returning the copied buffer and the + offset + length of the underlying data. + """ + import pyarrow as pa + import pyarrow.compute as pac + + if ( + pa.types.is_large_list(arr_type) + or pa.types.is_large_string(arr_type) + or pa.types.is_large_binary(arr_type) + or pa.types.is_large_unicode(arr_type) + ): + offset_type = pa.int64() + else: + offset_type = pa.int32() + # Copy offset buffer, if needed. 
+ buf = _copy_buffer_if_needed(buf, offset_type, offset, length + 1) + # Reconstruct the offset array so we can determine the offset and length + # of the child array. + offsets = pa.Array.from_buffers(offset_type, length + 1, [None, buf]) + child_offset = offsets[0].as_py() + child_length = offsets[-1].as_py() - child_offset + # Create new offsets aligned to 0 for the copied data buffer slice. + offsets = pac.subtract(offsets, child_offset) + if pa.types.is_int32(offset_type): + # We need to cast the resulting Int64Array back down to an Int32Array. + offsets = offsets.cast(offset_type, safe=False) + buf = offsets.buffers()[1] + return buf, child_offset, child_length + + +def _bytes_for_bits(n: int) -> int: + """Round up n to the nearest multiple of 8. + This is used to get the byte-padded number of bits for n bits. + """ + return (n + 7) & (-8) + + +def _align_bit_offset( + buf: "pyarrow.Buffer", + bit_offset: int, + byte_length: int, +) -> "pyarrow.Buffer": + """Align the bit offset into the buffer with the front of the buffer by shifting + the buffer and eliminating the offset. + """ + import pyarrow as pa + + bytes_ = buf.to_pybytes() + bytes_as_int = int.from_bytes(bytes_, sys.byteorder) + bytes_as_int >>= bit_offset + bytes_ = bytes_as_int.to_bytes(byte_length, sys.byteorder) + return pa.py_buffer(bytes_) + + +def _arrow_table_ipc_reduce(table: "pyarrow.Table"): + """Custom reducer for Arrow Table that works around a zero-copy slicing pickling + bug by using the Arrow IPC format for the underlying serialization. + + This is currently used as a fallback for unsupported types (or unknown bugs) for + the manual buffer truncation workaround, e.g. for dense unions. 
+ """ + from pyarrow.ipc import RecordBatchStreamWriter + from pyarrow.lib import BufferOutputStream + + output_stream = BufferOutputStream() + with RecordBatchStreamWriter(output_stream, schema=table.schema) as wr: + wr.write_table(table) + # NOTE: output_stream.getvalue() materializes the serialized table to a single + # contiguous bytestring, resulting in a few copy. This adds 1-2 extra copies on the + # serialization side, and 1 extra copy on the deserialization side. + return _restore_table_from_ipc, (output_stream.getvalue(),) + + +def _restore_table_from_ipc(buf: bytes) -> "pyarrow.Table": + """Restore an Arrow Table serialized to Arrow IPC format.""" + from pyarrow.ipc import RecordBatchStreamReader + + with RecordBatchStreamReader(buf) as reader: + return reader.read_all() + + +def _is_dense_union(type_: "pyarrow.DataType") -> bool: + """Whether the provided Arrow type is a dense union.""" + import pyarrow as pa + + return pa.types.is_union(type_) and type_.mode == "dense" diff --git a/.venv/lib/python3.11/site-packages/ray/_private/async_compat.py b/.venv/lib/python3.11/site-packages/ray/_private/async_compat.py new file mode 100644 index 0000000000000000000000000000000000000000..a9081c2719b34795015394239405af3acbf45ac8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/async_compat.py @@ -0,0 +1,52 @@ +""" +This file should only be imported from Python 3. +It will raise SyntaxError when importing from Python 2. +""" +import asyncio +import inspect +from functools import lru_cache + +try: + import uvloop +except ImportError: + uvloop = None + + +def get_new_event_loop(): + """Construct a new event loop. 
Ray will use uvloop if it exists""" + if uvloop: + return uvloop.new_event_loop() + else: + return asyncio.new_event_loop() + + +def try_install_uvloop(): + """Installs uvloop as event-loop implementation for asyncio (if available)""" + if uvloop: + uvloop.install() + else: + pass + + +def is_async_func(func) -> bool: + """Return True if the function is an async or async generator method.""" + return inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func) + + +@lru_cache(maxsize=2**10) +def has_async_methods(cls: object) -> bool: + """Return True if the class has any async methods.""" + return len(inspect.getmembers(cls, predicate=is_async_func)) > 0 + + +@lru_cache(maxsize=2**10) +def sync_to_async(func): + """Wrap a blocking function in an async function""" + + if is_async_func(func): + return func + + async def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper diff --git a/.venv/lib/python3.11/site-packages/ray/_private/async_utils.py b/.venv/lib/python3.11/site-packages/ray/_private/async_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..014abbdfaed78c80db7386893ca2d1f11410f0fb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/async_utils.py @@ -0,0 +1,52 @@ +# Adapted from [aiodebug](https://gitlab.com/quantlane/libs/aiodebug) + +# Copyright 2016-2022 Quantlane s.r.o. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Modifications: +# - Removed the dependency to `logwood`. 
+# - Renamed `monitor_loop_lag.enable()` to just `enable_monitor_loop_lag()`. +# - Miscellaneous changes to make it work with Ray. + +from typing import Callable, Optional +import asyncio +import asyncio.events + + +def enable_monitor_loop_lag( + callback: Callable[[float], None], + interval_s: float = 0.25, + loop: Optional[asyncio.AbstractEventLoop] = None, +) -> None: + """ + Start logging event loop lags to the callback. In ideal circumstances they should be + very close to zero. Lags may increase if event loop callbacks block for too long. + + Note: this works for all event loops, including uvloop. + + :param callback: Callback to call with the lag in seconds. + """ + if loop is None: + loop = asyncio.get_running_loop() + if loop is None: + raise ValueError("No provided loop, nor running loop found.") + + async def monitor(): + while loop.is_running(): + t0 = loop.time() + await asyncio.sleep(interval_s) + lag = loop.time() - t0 - interval_s # Should be close to zero. + callback(lag) + + loop.create_task(monitor(), name="async_utils.monitor_loop_lag") diff --git a/.venv/lib/python3.11/site-packages/ray/_private/auto_init_hook.py b/.venv/lib/python3.11/site-packages/ray/_private/auto_init_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..68c6612bc077c7f6060112ff8f750364ac769de2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/auto_init_hook.py @@ -0,0 +1,31 @@ +import ray +import os +from functools import wraps +import threading + +auto_init_lock = threading.Lock() +enable_auto_connect = os.environ.get("RAY_ENABLE_AUTO_CONNECT", "") != "0" + + +def auto_init_ray(): + if enable_auto_connect and not ray.is_initialized(): + with auto_init_lock: + if not ray.is_initialized(): + ray.init() + + +def wrap_auto_init(fn): + @wraps(fn) + def auto_init_wrapper(*args, **kwargs): + auto_init_ray() + return fn(*args, **kwargs) + + return auto_init_wrapper + + +def wrap_auto_init_for_all_apis(api_names): + """Wrap public APIs with 
automatic ray.init.""" + for api_name in api_names: + api = getattr(ray, api_name, None) + assert api is not None, api_name + setattr(ray, api_name, wrap_auto_init(api)) diff --git a/.venv/lib/python3.11/site-packages/ray/_private/client_mode_hook.py b/.venv/lib/python3.11/site-packages/ray/_private/client_mode_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..052aa01b0b75e0b735a08a6e3516ab002ce9f030 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/client_mode_hook.py @@ -0,0 +1,184 @@ +import os +import threading +from contextlib import contextmanager +from functools import wraps +from ray._private.auto_init_hook import auto_init_ray + +# Attr set on func defs to mark they have been converted to client mode. +RAY_CLIENT_MODE_ATTR = "__ray_client_mode_key__" + +# Global setting of whether client mode is enabled. This default to OFF, +# but is enabled upon ray.client(...).connect() or in tests. +is_client_mode_enabled = os.environ.get("RAY_CLIENT_MODE", "0") == "1" + +# When RAY_CLIENT_MODE == 1, we treat it as default enabled client mode +# This is useful for testing +is_client_mode_enabled_by_default = is_client_mode_enabled +os.environ.update({"RAY_CLIENT_MODE": "0"}) + +is_init_called = False + +# Local setting of whether to ignore client hook conversion. This defaults +# to TRUE and is disabled when the underlying 'real' Ray function is needed. +_client_hook_status_on_thread = threading.local() +_client_hook_status_on_thread.status = True + + +def _get_client_hook_status_on_thread(): + """Get's the value of `_client_hook_status_on_thread`. + Since `_client_hook_status_on_thread` is a thread-local variable, we may + need to add and set the 'status' attribute. 
+ """ + global _client_hook_status_on_thread + if not hasattr(_client_hook_status_on_thread, "status"): + _client_hook_status_on_thread.status = True + return _client_hook_status_on_thread.status + + +def _set_client_hook_status(val: bool): + global _client_hook_status_on_thread + _client_hook_status_on_thread.status = val + + +def _disable_client_hook(): + global _client_hook_status_on_thread + out = _get_client_hook_status_on_thread() + _client_hook_status_on_thread.status = False + return out + + +def _explicitly_enable_client_mode(): + """Force client mode to be enabled. + NOTE: This should not be used in tests, use `enable_client_mode`. + """ + global is_client_mode_enabled + is_client_mode_enabled = True + + +def _explicitly_disable_client_mode(): + global is_client_mode_enabled + is_client_mode_enabled = False + + +@contextmanager +def disable_client_hook(): + val = _disable_client_hook() + try: + yield None + finally: + _set_client_hook_status(val) + + +@contextmanager +def enable_client_mode(): + _explicitly_enable_client_mode() + try: + yield None + finally: + _explicitly_disable_client_mode() + + +def client_mode_hook(func: callable): + """Decorator for whether to use the 'regular' ray version of a function, + or the Ray Client version of that function. + + Args: + func: This function. This is set when this function is used + as a decorator. + """ + + from ray.util.client import ray + + @wraps(func) + def wrapper(*args, **kwargs): + # NOTE(hchen): DO NOT use "import" inside this function. + # Because when it's called within a `__del__` method, this error + # will be raised (see #35114): + # ImportError: sys.meta_path is None, Python is likely shutting down. 
+ if client_mode_should_convert(): + # Legacy code + # we only convert init function if RAY_CLIENT_MODE=1 + if func.__name__ != "init" or is_client_mode_enabled_by_default: + return getattr(ray, func.__name__)(*args, **kwargs) + return func(*args, **kwargs) + + return wrapper + + +def client_mode_should_convert(): + """Determines if functions should be converted to client mode.""" + + # `is_client_mode_enabled_by_default` is used for testing with + # `RAY_CLIENT_MODE=1`. This flag means all tests run with client mode. + return ( + is_client_mode_enabled or is_client_mode_enabled_by_default + ) and _get_client_hook_status_on_thread() + + +def client_mode_wrap(func): + """Wraps a function called during client mode for execution as a remote + task. + + Can be used to implement public features of ray client which do not + belong in the main ray API (`ray.*`), yet require server-side execution. + An example is the creation of placement groups: + `ray.util.placement_group.placement_group()`. When called on the client + side, this function is wrapped in a task to facilitate interaction with + the GCS. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + from ray.util.client import ray + + auto_init_ray() + # Directly pass this through since `client_mode_wrap` is for + # Placement Group APIs + if client_mode_should_convert(): + f = ray.remote(num_cpus=0)(func) + ref = f.remote(*args, **kwargs) + return ray.get(ref) + return func(*args, **kwargs) + + return wrapper + + +def client_mode_convert_function(func_cls, in_args, in_kwargs, **kwargs): + """Runs a preregistered ray RemoteFunction through the ray client. + + The common case for this is to transparently convert that RemoteFunction + to a ClientRemoteFunction. This happens in circumstances where the + RemoteFunction is declared early, in a library and only then is Ray used in + client mode -- necessitating a conversion. 
+ """ + from ray.util.client import ray + + key = getattr(func_cls, RAY_CLIENT_MODE_ATTR, None) + + # Second part of "or" is needed in case func_cls is reused between Ray + # client sessions in one Python interpreter session. + if (key is None) or (not ray._converted_key_exists(key)): + key = ray._convert_function(func_cls) + setattr(func_cls, RAY_CLIENT_MODE_ATTR, key) + client_func = ray._get_converted(key) + return client_func._remote(in_args, in_kwargs, **kwargs) + + +def client_mode_convert_actor(actor_cls, in_args, in_kwargs, **kwargs): + """Runs a preregistered actor class on the ray client + + The common case for this decorator is for instantiating an ActorClass + transparently as a ClientActorClass. This happens in circumstances where + the ActorClass is declared early, in a library and only then is Ray used in + client mode -- necessitating a conversion. + """ + from ray.util.client import ray + + key = getattr(actor_cls, RAY_CLIENT_MODE_ATTR, None) + # Second part of "or" is needed in case actor_cls is reused between Ray + # client sessions in one Python interpreter session. 
+ if (key is None) or (not ray._converted_key_exists(key)): + key = ray._convert_actor(actor_cls) + setattr(actor_cls, RAY_CLIENT_MODE_ATTR, key) + client_actor = ray._get_converted(key) + return client_actor._remote(in_args, in_kwargs, **kwargs) diff --git a/.venv/lib/python3.11/site-packages/ray/_private/collections_utils.py b/.venv/lib/python3.11/site-packages/ray/_private/collections_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9b7ff6f3c3108b52bc55d52df308da00a03fb4d3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/collections_utils.py @@ -0,0 +1,10 @@ +from typing import List, Any + + +def split(items: List[Any], chunk_size: int): + """Splits provided list into chunks of given size""" + + assert chunk_size > 0, "Chunk size has to be > 0" + + for i in range(0, len(items), chunk_size): + yield items[i : i + chunk_size] diff --git a/.venv/lib/python3.11/site-packages/ray/_private/compat.py b/.venv/lib/python3.11/site-packages/ray/_private/compat.py new file mode 100644 index 0000000000000000000000000000000000000000..f0a3896cce09fc662f2337ad5293dab76421d588 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/compat.py @@ -0,0 +1,40 @@ +import io +import platform + + +def patch_psutil(): + """WSL's /proc/meminfo has an inconsistency where it + nondeterministically omits a space after colons (after "SwapFree:" + in my case). + psutil then splits on spaces and then parses the wrong field, + crashing on the 'int(fields[1])' expression in + psutil._pslinux.virtual_memory(). + Workaround: We ensure there is a space following each colon. 
+ """ + assert ( + platform.system() == "Linux" + and "Microsoft".lower() in platform.release().lower() + ) + + try: + import psutil._pslinux + except ImportError: + psutil = None + psutil_open_binary = None + if psutil: + try: + psutil_open_binary = psutil._pslinux.open_binary + except AttributeError: + pass + # Only patch it if it doesn't seem to have been patched already + if psutil_open_binary and psutil_open_binary.__name__ == "open_binary": + + def psutil_open_binary_patched(fname, *args, **kwargs): + f = psutil_open_binary(fname, *args, **kwargs) + if fname == "/proc/meminfo": + with f: + # Make sure there's a space after colons + return io.BytesIO(f.read().replace(b":", b": ")) + return f + + psutil._pslinux.open_binary = psutil_open_binary_patched diff --git a/.venv/lib/python3.11/site-packages/ray/_private/conftest_utils.py b/.venv/lib/python3.11/site-packages/ray/_private/conftest_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9036dd4954a23f4712ad0a0bf7f11290086d1e2e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/conftest_utils.py @@ -0,0 +1,14 @@ +import pytest +import ray._private.ray_constants as ray_constants + + +@pytest.fixture +def set_override_dashboard_url(monkeypatch, request): + override_url = getattr(request, "param", "https://external_dashboard_url") + with monkeypatch.context() as m: + if override_url: + m.setenv( + ray_constants.RAY_OVERRIDE_DASHBOARD_URL, + override_url, + ) + yield diff --git a/.venv/lib/python3.11/site-packages/ray/_private/dict.py b/.venv/lib/python3.11/site-packages/ray/_private/dict.py new file mode 100644 index 0000000000000000000000000000000000000000..3d102b32961f034f4e7af469898cb2fb66c5ec93 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/dict.py @@ -0,0 +1,247 @@ +import copy +from collections import deque +from collections.abc import Mapping, Sequence +from typing import Dict, List, Optional, TypeVar, Union + +from ray.util.annotations import 
Deprecated + +T = TypeVar("T") + + +@Deprecated +def merge_dicts(d1: dict, d2: dict) -> dict: + """ + Args: + d1 (dict): Dict 1. + d2 (dict): Dict 2. + + Returns: + dict: A new dict that is d1 and d2 deep merged. + """ + merged = copy.deepcopy(d1) + deep_update(merged, d2, True, []) + return merged + + +@Deprecated +def deep_update( + original: dict, + new_dict: dict, + new_keys_allowed: bool = False, + allow_new_subkey_list: Optional[List[str]] = None, + override_all_if_type_changes: Optional[List[str]] = None, + override_all_key_list: Optional[List[str]] = None, +) -> dict: + """Updates original dict with values from new_dict recursively. + + If new key is introduced in new_dict, then if new_keys_allowed is not + True, an error will be thrown. Further, for sub-dicts, if the key is + in the allow_new_subkey_list, then new subkeys can be introduced. + + Args: + original: Dictionary with default values. + new_dict: Dictionary with values to be updated + new_keys_allowed: Whether new keys are allowed. + allow_new_subkey_list: List of keys that + correspond to dict values where new subkeys can be introduced. + This is only at the top level. + override_all_if_type_changes: List of top level + keys with value=dict, for which we always simply override the + entire value (dict), iff the "type" key in that value dict changes. + override_all_key_list: List of top level keys + for which we override the entire value if the key is in the new_dict. + """ + allow_new_subkey_list = allow_new_subkey_list or [] + override_all_if_type_changes = override_all_if_type_changes or [] + override_all_key_list = override_all_key_list or [] + + for k, value in new_dict.items(): + if k not in original and not new_keys_allowed: + raise Exception("Unknown config parameter `{}` ".format(k)) + + # Both orginal value and new one are dicts. + if ( + isinstance(original.get(k), dict) + and isinstance(value, dict) + and k not in override_all_key_list + ): + # Check old type vs old one. 
If different, override entire value. + if ( + k in override_all_if_type_changes + and "type" in value + and "type" in original[k] + and value["type"] != original[k]["type"] + ): + original[k] = value + # Allowed key -> ok to add new subkeys. + elif k in allow_new_subkey_list: + deep_update( + original[k], + value, + True, + override_all_key_list=override_all_key_list, + ) + # Non-allowed key. + else: + deep_update( + original[k], + value, + new_keys_allowed, + override_all_key_list=override_all_key_list, + ) + # Original value not a dict OR new value not a dict: + # Override entire value. + else: + original[k] = value + return original + + +@Deprecated +def flatten_dict( + dt: Dict, + delimiter: str = "/", + prevent_delimiter: bool = False, + flatten_list: bool = False, +): + """Flatten dict. + + Output and input are of the same dict type. + Input dict remains the same after the operation. + """ + + def _raise_delimiter_exception(): + raise ValueError( + f"Found delimiter `{delimiter}` in key when trying to flatten " + f"array. Please avoid using the delimiter in your specification." 
+    )
+
+    dt = copy.copy(dt)
+    if prevent_delimiter and any(delimiter in key for key in dt):
+        # Raise if delimiter is any of the keys
+        _raise_delimiter_exception()
+
+    while_check = (dict, list) if flatten_list else dict
+
+    while any(isinstance(v, while_check) for v in dt.values()):
+        remove = []
+        add = {}
+        for key, value in dt.items():
+            if isinstance(value, dict):
+                for subkey, v in value.items():
+                    if prevent_delimiter and delimiter in subkey:
+                        # Raise if delimiter is in any of the subkeys
+                        _raise_delimiter_exception()
+
+                    add[delimiter.join([key, str(subkey)])] = v
+                remove.append(key)
+            elif flatten_list and isinstance(value, list):
+                for i, v in enumerate(value):
+                    if prevent_delimiter and delimiter in str(i):
+                        # was `subkey`: stale from the dict branch, possibly unbound
+                        _raise_delimiter_exception()
+
+                    add[delimiter.join([key, str(i)])] = v
+                remove.append(key)
+
+        dt.update(add)
+        for k in remove:
+            del dt[k]
+    return dt
+
+
+@Deprecated
+def unflatten_dict(dt: Dict[str, T], delimiter: str = "/") -> Dict[str, T]:
+    """Unflatten dict. Does not support unflattening lists."""
+    dict_type = type(dt)
+    out = dict_type()
+    for key, val in dt.items():
+        path = key.split(delimiter)
+        item = out
+        for k in path[:-1]:
+            item = item.setdefault(k, dict_type())
+            if not isinstance(item, dict_type):
+                raise TypeError(
+                    f"Cannot unflatten dict due the key '{key}' "
+                    f"having a parent key '{k}', which value is not "
+                    f"of type {dict_type} (got {type(item)}). "
+                    "Change the key names to resolve the conflict."
+                )
+        item[path[-1]] = val
+    return out
+
+
+@Deprecated
+def unflatten_list_dict(dt: Dict[str, T], delimiter: str = "/") -> Dict[str, T]:
+    """Unflatten nested dict and list.
+
+    This function now has some limitations:
+        (1) The keys of dt must be str.
+        (2) If unflattened dt (the result) contains list, the index order must be
+            ascending when accessing dt. Otherwise, this function will throw
+            AssertionError.
+ (3) The unflattened dt (the result) shouldn't contain dict with number + keys. + + Be careful to use this function. If you want to improve this function, + please also improve the unit test. See #14487 for more details. + + Args: + dt: Flattened dictionary that is originally nested by multiple + list and dict. + delimiter: Delimiter of keys. + + Example: + >>> dt = {"aaa/0/bb": 12, "aaa/1/cc": 56, "aaa/1/dd": 92} + >>> unflatten_list_dict(dt) + {'aaa': [{'bb': 12}, {'cc': 56, 'dd': 92}]} + """ + out_type = list if list(dt)[0].split(delimiter, 1)[0].isdigit() else type(dt) + out = out_type() + for key, val in dt.items(): + path = key.split(delimiter) + + item = out + for i, k in enumerate(path[:-1]): + next_type = list if path[i + 1].isdigit() else dict + if isinstance(item, dict): + item = item.setdefault(k, next_type()) + elif isinstance(item, list): + if int(k) >= len(item): + item.append(next_type()) + assert int(k) == len(item) - 1 + item = item[int(k)] + + if isinstance(item, dict): + item[path[-1]] = val + elif isinstance(item, list): + item.append(val) + assert int(path[-1]) == len(item) - 1 + return out + + +@Deprecated +def unflattened_lookup( + flat_key: str, lookup: Union[Mapping, Sequence], delimiter: str = "/", **kwargs +) -> Union[Mapping, Sequence]: + """ + Unflatten `flat_key` and iteratively look up in `lookup`. E.g. + `flat_key="a/0/b"` will try to return `lookup["a"][0]["b"]`. 
+ """ + if flat_key in lookup: + return lookup[flat_key] + keys = deque(flat_key.split(delimiter)) + base = lookup + while keys: + key = keys.popleft() + try: + if isinstance(base, Mapping): + base = base[key] + elif isinstance(base, Sequence): + base = base[int(key)] + else: + raise KeyError() + except KeyError as e: + if "default" in kwargs: + return kwargs["default"] + raise e + return base diff --git a/.venv/lib/python3.11/site-packages/ray/_private/external_storage.py b/.venv/lib/python3.11/site-packages/ray/_private/external_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..d86430a6fd5c8ca1af0cf8307c649bded4da92cd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/external_storage.py @@ -0,0 +1,707 @@ +import abc +import logging +import os +import random +import shutil +import time +import urllib +import uuid +from collections import namedtuple +from typing import IO, List, Optional, Tuple, Union + +import ray +from ray._private.ray_constants import DEFAULT_OBJECT_PREFIX +from ray._raylet import ObjectRef + +ParsedURL = namedtuple("ParsedURL", "base_url, offset, size") +logger = logging.getLogger(__name__) + + +def create_url_with_offset(*, url: str, offset: int, size: int) -> str: + """Methods to create a URL with offset. + + When ray spills objects, it fuses multiple objects + into one file to optimize the performance. That says, each object + needs to keep tracking of its own special url to store metadata. + + This method creates an url_with_offset, which is used internally + by Ray. + + Created url_with_offset can be passed to the self._get_base_url method + to parse the filename used to store files. + + Example) file://path/to/file?offset=""&size="" + + Args: + url: url to the object stored in the external storage. + offset: Offset from the beginning of the file to + the first bytes of this object. + size: Size of the object that is stored in the url. + It is used to calculate the last offset. 
+ + Returns: + url_with_offset stored internally to find + objects from external storage. + """ + return f"{url}?offset={offset}&size={size}" + + +def parse_url_with_offset(url_with_offset: str) -> Tuple[str, int, int]: + """Parse url_with_offset to retrieve information. + + base_url is the url where the object ref + is stored in the external storage. + + Args: + url_with_offset: url created by create_url_with_offset. + + Returns: + named tuple of base_url, offset, and size. + """ + parsed_result = urllib.parse.urlparse(url_with_offset) + query_dict = urllib.parse.parse_qs(parsed_result.query) + # Split by ? to remove the query from the url. + base_url = parsed_result.geturl().split("?")[0] + if "offset" not in query_dict or "size" not in query_dict: + raise ValueError(f"Failed to parse URL: {url_with_offset}") + offset = int(query_dict["offset"][0]) + size = int(query_dict["size"][0]) + return ParsedURL(base_url=base_url, offset=offset, size=size) + + +class ExternalStorage(metaclass=abc.ABCMeta): + """The base class for external storage. + + This class provides some useful functions for zero-copy object + put/get from plasma store. Also it specifies the interface for + object spilling. + + When inheriting this class, please make sure to implement validation + logic inside __init__ method. When ray instance starts, it will + instantiating external storage to validate the config. + + Raises: + ValueError: when given configuration for + the external storage is invalid. + """ + + HEADER_LENGTH = 24 + + def _get_objects_from_store(self, object_refs): + worker = ray._private.worker.global_worker + # Since the object should always exist in the plasma store before + # spilling, it can directly get the object from the local plasma + # store. 
+ # issue: https://github.com/ray-project/ray/pull/13831 + ray_object_pairs = worker.core_worker.get_if_local(object_refs) + return ray_object_pairs + + def _put_object_to_store( + self, metadata, data_size, file_like, object_ref, owner_address + ): + worker = ray._private.worker.global_worker + worker.core_worker.put_file_like_object( + metadata, data_size, file_like, object_ref, owner_address + ) + + def _write_multiple_objects( + self, f: IO, object_refs: List[ObjectRef], owner_addresses: List[str], url: str + ) -> List[str]: + """Fuse all given objects into a given file handle. + + Args: + f: File handle to fusion all given object refs. + object_refs: Object references to fusion to a single file. + owner_addresses: Owner addresses for the provided objects. + url: url where the object ref is stored + in the external storage. + + Return: + List of urls_with_offset of fused objects. + The order of returned keys are equivalent to the one + with given object_refs. + """ + keys = [] + offset = 0 + ray_object_pairs = self._get_objects_from_store(object_refs) + for ref, (buf, metadata), owner_address in zip( + object_refs, ray_object_pairs, owner_addresses + ): + address_len = len(owner_address) + metadata_len = len(metadata) + if buf is None and len(metadata) == 0: + error = f"Object {ref.hex()} does not exist." + raise ValueError(error) + buf_len = 0 if buf is None else len(buf) + payload = ( + address_len.to_bytes(8, byteorder="little") + + metadata_len.to_bytes(8, byteorder="little") + + buf_len.to_bytes(8, byteorder="little") + + owner_address + + metadata + + (memoryview(buf) if buf_len else b"") + ) + # 24 bytes to store owner address, metadata, and buffer lengths. 
+ payload_len = len(payload) + assert ( + self.HEADER_LENGTH + address_len + metadata_len + buf_len == payload_len + ) + written_bytes = f.write(payload) + assert written_bytes == payload_len + url_with_offset = create_url_with_offset( + url=url, offset=offset, size=written_bytes + ) + keys.append(url_with_offset.encode()) + offset += written_bytes + # Necessary because pyarrow.io.NativeFile does not flush() on close(). + f.flush() + return keys + + def _size_check(self, address_len, metadata_len, buffer_len, obtained_data_size): + """Check whether or not the obtained_data_size is as expected. + + Args: + metadata_len: Actual metadata length of the object. + buffer_len: Actual buffer length of the object. + obtained_data_size: Data size specified in the + url_with_offset. + + Raises: + ValueError if obtained_data_size is different from + address_len + metadata_len + buffer_len + + 24 (first 8 bytes to store length). + """ + data_size_in_bytes = ( + address_len + metadata_len + buffer_len + self.HEADER_LENGTH + ) + if data_size_in_bytes != obtained_data_size: + raise ValueError( + f"Obtained data has a size of {data_size_in_bytes}, " + "although it is supposed to have the " + f"size of {obtained_data_size}." + ) + + @abc.abstractmethod + def spill_objects(self, object_refs, owner_addresses) -> List[str]: + """Spill objects to the external storage. Objects are specified + by their object refs. + + Args: + object_refs: The list of the refs of the objects to be spilled. + owner_addresses: Owner addresses for the provided objects. + Returns: + A list of internal URLs with object offset. + """ + + @abc.abstractmethod + def restore_spilled_objects( + self, object_refs: List[ObjectRef], url_with_offset_list: List[str] + ) -> int: + """Restore objects from the external storage. + + Args: + object_refs: List of object IDs (note that it is not ref). + url_with_offset_list: List of url_with_offset. + + Returns: + The total number of bytes restored. 
+ """ + + @abc.abstractmethod + def delete_spilled_objects(self, urls: List[str]): + """Delete objects that are spilled to the external storage. + + Args: + urls: URLs that store spilled object files. + + NOTE: This function should not fail if some of the urls + do not exist. + """ + + @abc.abstractmethod + def destroy_external_storage(self): + """Destroy external storage when a head node is down. + + NOTE: This is currently working when the cluster is + started by ray.init + """ + + +class NullStorage(ExternalStorage): + """The class that represents an uninitialized external storage.""" + + def spill_objects(self, object_refs, owner_addresses) -> List[str]: + raise NotImplementedError("External storage is not initialized") + + def restore_spilled_objects(self, object_refs, url_with_offset_list): + raise NotImplementedError("External storage is not initialized") + + def delete_spilled_objects(self, urls: List[str]): + raise NotImplementedError("External storage is not initialized") + + def destroy_external_storage(self): + raise NotImplementedError("External storage is not initialized") + + +class FileSystemStorage(ExternalStorage): + """The class for filesystem-like external storage. + + Raises: + ValueError: Raises directory path to + spill objects doesn't exist. + """ + + def __init__( + self, + node_id: str, + directory_path: Union[str, List[str]], + buffer_size: Optional[int] = None, + ): + # -- A list of directory paths to spill objects -- + self._directory_paths = [] + # -- Current directory to spill objects -- + self._current_directory_index = 0 + # -- File buffer size to spill objects -- + self._buffer_size = -1 + + # Validation. + assert ( + directory_path is not None + ), "directory_path should be provided to use object spilling." 
+ if isinstance(directory_path, str): + directory_path = [directory_path] + assert isinstance( + directory_path, list + ), "Directory_path must be either a single string or a list of strings" + if buffer_size is not None: + assert isinstance(buffer_size, int), "buffer_size must be an integer." + self._buffer_size = buffer_size + + # Create directories. + for path in directory_path: + full_dir_path = os.path.join(path, f"{DEFAULT_OBJECT_PREFIX}_{node_id}") + os.makedirs(full_dir_path, exist_ok=True) + if not os.path.exists(full_dir_path): + raise ValueError( + "The given directory path to store objects, " + f"{full_dir_path}, could not be created." + ) + self._directory_paths.append(full_dir_path) + assert len(self._directory_paths) == len(directory_path) + # Choose the current directory. + # It chooses a random index to maximize multiple directories that are + # mounted at different point. + self._current_directory_index = random.randrange(0, len(self._directory_paths)) + + def spill_objects(self, object_refs, owner_addresses) -> List[str]: + if len(object_refs) == 0: + return [] + # Choose the current directory path by round robin order. + self._current_directory_index = (self._current_directory_index + 1) % len( + self._directory_paths + ) + directory_path = self._directory_paths[self._current_directory_index] + + filename = _get_unique_spill_filename(object_refs) + url = f"{os.path.join(directory_path, filename)}" + with open(url, "wb", buffering=self._buffer_size) as f: + return self._write_multiple_objects(f, object_refs, owner_addresses, url) + + def restore_spilled_objects( + self, object_refs: List[ObjectRef], url_with_offset_list: List[str] + ): + total = 0 + for i in range(len(object_refs)): + object_ref = object_refs[i] + url_with_offset = url_with_offset_list[i].decode() + # Retrieve the information needed. 
+ parsed_result = parse_url_with_offset(url_with_offset) + base_url = parsed_result.base_url + offset = parsed_result.offset + # Read a part of the file and recover the object. + with open(base_url, "rb") as f: + f.seek(offset) + address_len = int.from_bytes(f.read(8), byteorder="little") + metadata_len = int.from_bytes(f.read(8), byteorder="little") + buf_len = int.from_bytes(f.read(8), byteorder="little") + self._size_check(address_len, metadata_len, buf_len, parsed_result.size) + total += buf_len + owner_address = f.read(address_len) + metadata = f.read(metadata_len) + # read remaining data to our buffer + self._put_object_to_store( + metadata, buf_len, f, object_ref, owner_address + ) + return total + + def delete_spilled_objects(self, urls: List[str]): + for url in urls: + path = parse_url_with_offset(url.decode()).base_url + try: + os.remove(path) + except FileNotFoundError: + # Occurs when the urls are retried during worker crash/failure. + pass + + def destroy_external_storage(self): + for directory_path in self._directory_paths: + self._destroy_external_storage(directory_path) + + def _destroy_external_storage(self, directory_path): + # There's a race condition where IO workers are still + # deleting each objects while we try deleting the + # whole directory. So we should keep trying it until + # The directory is actually deleted. + while os.path.isdir(directory_path): + try: + shutil.rmtree(directory_path) + except (FileNotFoundError): + # If exception occurs when other IO workers are + # deleting the file at the same time. + pass + except Exception: + logger.exception( + "Error cleaning up spill files. " + "You might still have remaining spilled " + "objects inside `ray_spilled_objects` directory." 
+ ) + break + + +class ExternalStorageRayStorageImpl(ExternalStorage): + """Implements the external storage interface using the ray storage API.""" + + def __init__( + self, + node_id: str, + session_name: str, + # For remote spilling, at least 1MB is recommended. + buffer_size=1024 * 1024, + # Override the storage config for unit tests. + _force_storage_for_testing: Optional[str] = None, + ): + from ray._private import storage + + if _force_storage_for_testing: + storage._reset() + storage._init_storage(_force_storage_for_testing, True) + + self._fs, storage_prefix = storage._get_filesystem_internal() + self._buffer_size = buffer_size + self._prefix = os.path.join( + storage_prefix, f"{DEFAULT_OBJECT_PREFIX}_{node_id}", session_name + ) + self._fs.create_dir(self._prefix) + + def spill_objects(self, object_refs, owner_addresses) -> List[str]: + if len(object_refs) == 0: + return [] + filename = _get_unique_spill_filename(object_refs) + url = f"{os.path.join(self._prefix, filename)}" + with self._fs.open_output_stream(url, buffer_size=self._buffer_size) as f: + return self._write_multiple_objects(f, object_refs, owner_addresses, url) + + def restore_spilled_objects( + self, object_refs: List[ObjectRef], url_with_offset_list: List[str] + ): + total = 0 + for i in range(len(object_refs)): + object_ref = object_refs[i] + url_with_offset = url_with_offset_list[i].decode() + # Retrieve the information needed. + parsed_result = parse_url_with_offset(url_with_offset) + base_url = parsed_result.base_url + offset = parsed_result.offset + # Read a part of the file and recover the object. 
+ with self._fs.open_input_file(base_url) as f: + f.seek(offset) + address_len = int.from_bytes(f.read(8), byteorder="little") + metadata_len = int.from_bytes(f.read(8), byteorder="little") + buf_len = int.from_bytes(f.read(8), byteorder="little") + self._size_check(address_len, metadata_len, buf_len, parsed_result.size) + total += buf_len + owner_address = f.read(address_len) + metadata = f.read(metadata_len) + # read remaining data to our buffer + self._put_object_to_store( + metadata, buf_len, f, object_ref, owner_address + ) + return total + + def delete_spilled_objects(self, urls: List[str]): + for url in urls: + path = parse_url_with_offset(url.decode()).base_url + try: + self._fs.delete_file(path) + except FileNotFoundError: + # Occurs when the urls are retried during worker crash/failure. + pass + + def destroy_external_storage(self): + try: + self._fs.delete_dir(self._prefix) + except Exception: + logger.exception( + "Error cleaning up spill files. " + "You might still have remaining spilled " + "objects inside `{}`.".format(self._prefix) + ) + + +class ExternalStorageSmartOpenImpl(ExternalStorage): + """The external storage class implemented by smart_open. + (https://github.com/RaRe-Technologies/smart_open) + + Smart open supports multiple backend with the same APIs. + + To use this implementation, you should pre-create the given uri. + For example, if your uri is a local file path, you should pre-create + the directory. + + Args: + uri: Storage URI used for smart open. + prefix: Prefix of objects that are stored. + override_transport_params: Overriding the default value of + transport_params for smart-open library. + + Raises: + ModuleNotFoundError: If it fails to setup. + For example, if smart open library + is not downloaded, this will fail. + """ + + def __init__( + self, + node_id: str, + uri: str or list, + override_transport_params: dict = None, + buffer_size=1024 * 1024, # For remote spilling, at least 1MB is recommended. 
class ExternalStorageSmartOpenImpl(ExternalStorage):
    """External storage implemented via the smart_open library.
    (https://github.com/RaRe-Technologies/smart_open)

    smart_open exposes the same API over multiple storage backends.

    The given URI must be pre-created before use. For example, if the
    URI is a local file path, the directory must already exist.

    Args:
        uri: Storage URI used for smart open.
        prefix: Prefix of objects that are stored.
        override_transport_params: Overrides for the default
            transport_params passed to smart_open.

    Raises:
        ModuleNotFoundError: If smart_open (and boto3, for S3 URIs) is
            not installed.
    """

    def __init__(
        self,
        node_id: str,
        uri: str or list,
        override_transport_params: dict = None,
        buffer_size=1024 * 1024,  # For remote spilling, at least 1MB is recommended.
    ):
        try:
            from smart_open import open  # noqa
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                "Smart open is chosen to be a object spilling "
                "external storage, but smart_open and boto3 "
                f"is not downloaded. Original error: {e}"
            )

        # Validation
        assert uri is not None, "uri should be provided to use object spilling."
        if isinstance(uri, str):
            uri = [uri]
        assert isinstance(uri, list), "uri must be a single string or list of strings."
        assert isinstance(buffer_size, int), "buffer_size must be an integer."

        # Either every URI is S3 or none is; mixing is not supported.
        s3_flags = [u.startswith("s3://") for u in uri]
        self.is_for_s3 = all(s3_flags)
        if self.is_for_s3:
            self._uris = [u.strip("/") for u in uri]
        else:
            assert not any(s3_flags), "all uri's must be s3 or none can be s3."
            self._uris = uri
        assert len(self._uris) == len(uri)

        self._current_uri_index = random.randrange(0, len(self._uris))
        self.prefix = f"{DEFAULT_OBJECT_PREFIX}_{node_id}"
        self.override_transport_params = override_transport_params or {}

        if self.is_for_s3:
            import boto3  # noqa

            # Reuse a single boto3 session; without it, smart_open would
            # create a new session for every open() call.
            self.s3 = boto3.resource(service_name="s3")

            # smart_open always seeks to 0 unless defer_seek is set, which
            # triggers an unnecessary Object.get; seek lazily before reads
            # instead.
            self.transport_params = {
                "defer_seek": True,
                "resource": self.s3,
                "buffer_size": buffer_size,
            }
        else:
            self.transport_params = {}

        self.transport_params.update(self.override_transport_params)

    def spill_objects(self, object_refs, owner_addresses) -> List[str]:
        """Spill all given objects into one file on the next URI."""
        if not object_refs:
            return []
        from smart_open import open

        # Rotate through the configured URIs round robin.
        self._current_uri_index = (self._current_uri_index + 1) % len(self._uris)
        uri = self._uris[self._current_uri_index]

        key = f"{self.prefix}-{_get_unique_spill_filename(object_refs)}"
        url = f"{uri}/{key}"

        with open(
            url,
            mode="wb",
            transport_params=self.transport_params,
        ) as file_like:
            return self._write_multiple_objects(
                file_like, object_refs, owner_addresses, url
            )

    def restore_spilled_objects(
        self,
        object_refs: List["ObjectRef"],
        url_with_offset_list: List[str],
    ):
        """Read objects back from their spill URLs; return total payload bytes."""
        from smart_open import open

        total = 0
        for object_ref, raw_url in zip(object_refs, url_with_offset_list):
            parsed = parse_url_with_offset(raw_url.decode())
            with open(
                parsed.base_url, "rb", transport_params=self.transport_params
            ) as f:
                # With defer_seek, smart_open fetches offset..EOF only once
                # seek() is actually invoked.
                f.seek(parsed.offset)
                address_len = int.from_bytes(f.read(8), byteorder="little")
                metadata_len = int.from_bytes(f.read(8), byteorder="little")
                buf_len = int.from_bytes(f.read(8), byteorder="little")
                self._size_check(
                    address_len, metadata_len, buf_len, parsed.size
                )
                owner_address = f.read(address_len)
                total += buf_len
                metadata = f.read(metadata_len)
                # Stream the remaining payload into the object store.
                self._put_object_to_store(
                    metadata, buf_len, f, object_ref, owner_address
                )
        return total

    def delete_spilled_objects(self, urls: List[str]):
        # Remote spill files are not deleted individually.
        pass

    def destroy_external_storage(self):
        pass


_external_storage = NullStorage()
class UnstableFileStorage(FileSystemStorage):
    """Test-only storage that injects spill failures.

    With probability ``_failure_rate`` the whole spill raises IOError;
    otherwise, with probability ``_partial_failure_ratio`` only a random
    prefix of the objects is spilled.
    """

    def __init__(self, node_id: str, **kwargs):
        super().__init__(node_id, **kwargs)
        self._failure_rate = 0.1
        self._partial_failure_ratio = 0.2

    def spill_objects(self, object_refs, owner_addresses) -> List[str]:
        # BUG FIX: `r` was previously `random.random() < self._failure_rate`
        # (a bool), so the comparisons below compared a bool against the
        # rates and effectively inverted the failure probability (failing
        # ~90% of the time instead of 10%). `r` must be the raw draw.
        r = random.random()
        failed = r < self._failure_rate
        partial_failed = r < self._partial_failure_ratio
        if failed:
            raise IOError("Spilling object failed")
        elif partial_failed:
            # Spill only a random prefix (possibly empty) of the objects.
            i = random.choice(range(len(object_refs)))
            return super().spill_objects(object_refs[:i], owner_addresses)
        else:
            return super().spill_objects(object_refs, owner_addresses)


class SlowFileStorage(FileSystemStorage):
    """Test-only storage that delays each spill by 1-2 seconds."""

    def __init__(self, node_id: str, **kwargs):
        super().__init__(node_id, **kwargs)
        self._min_delay = 1
        self._max_delay = 2

    def spill_objects(self, object_refs, owner_addresses) -> List[str]:
        # Uniform random delay in [_min_delay, _max_delay).
        delay = random.random() * (self._max_delay - self._min_delay) + self._min_delay
        time.sleep(delay)
        return super().spill_objects(object_refs, owner_addresses)
+ _external_storage = SlowFileStorage(node_id, **config["params"]) + else: + raise ValueError(f"Unknown external storage type: {storage_type}") + else: + _external_storage = NullStorage() + return _external_storage + + +def reset_external_storage(): + global _external_storage + _external_storage = NullStorage() + + +def spill_objects(object_refs, owner_addresses): + """Spill objects to the external storage. Objects are specified + by their object refs. + + Args: + object_refs: The list of the refs of the objects to be spilled. + owner_addresses: The owner addresses of the provided object refs. + Returns: + A list of keys corresponding to the input object refs. + """ + return _external_storage.spill_objects(object_refs, owner_addresses) + + +def restore_spilled_objects( + object_refs: List[ObjectRef], url_with_offset_list: List[str] +): + """Restore objects from the external storage. + + Args: + object_refs: List of object IDs (note that it is not ref). + url_with_offset_list: List of url_with_offset. + """ + return _external_storage.restore_spilled_objects(object_refs, url_with_offset_list) + + +def delete_spilled_objects(urls: List[str]): + """Delete objects that are spilled to the external storage. + + Args: + urls: URLs that store spilled object files. + """ + _external_storage.delete_spilled_objects(urls) + + +def _get_unique_spill_filename(object_refs: List[ObjectRef]): + """Generate a unqiue spill file name. + + Args: + object_refs: objects to be spilled in this file. 
+ """ + return f"{uuid.uuid4().hex}-multi-{len(object_refs)}" diff --git a/.venv/lib/python3.11/site-packages/ray/_private/function_manager.py b/.venv/lib/python3.11/site-packages/ray/_private/function_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..63a0c497e7d84f1c28350e5ef04b06b5686add61 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/function_manager.py @@ -0,0 +1,706 @@ +import dis +import sys +import hashlib +import importlib +import inspect +import json +import logging +import os +import threading +import time +import traceback +from collections import defaultdict, namedtuple +from typing import Optional, Callable + +import ray +from ray.remote_function import RemoteFunction +import ray._private.profiling as profiling +from ray import cloudpickle as pickle +from ray._private import ray_constants +from ray._private.inspect_util import ( + is_class_method, + is_function_or_method, + is_static_method, +) +from ray._private.ray_constants import KV_NAMESPACE_FUNCTION_TABLE +from ray._private.utils import ( + check_oversized_function, + ensure_str, + format_error_message, +) +from ray._private.serialization import pickle_dumps +from ray._raylet import ( + JobID, + PythonFunctionDescriptor, + WORKER_PROCESS_SETUP_HOOK_KEY_NAME_GCS, +) + +FunctionExecutionInfo = namedtuple( + "FunctionExecutionInfo", ["function", "function_name", "max_calls"] +) +ImportedFunctionInfo = namedtuple( + "ImportedFunctionInfo", + ["job_id", "function_id", "function_name", "function", "module", "max_calls"], +) + +"""FunctionExecutionInfo: A named tuple storing remote function information.""" + +logger = logging.getLogger(__name__) + + +def make_function_table_key(key_type: bytes, job_id: JobID, key: Optional[bytes]): + if key is None: + return b":".join([key_type, job_id.hex().encode()]) + else: + return b":".join([key_type, job_id.hex().encode(), key]) + + +class FunctionActorManager: + """A class used to export/load remote functions 
and actors. + Attributes: + _worker: The associated worker that this manager related. + _functions_to_export: The remote functions to export when + the worker gets connected. + _actors_to_export: The actors to export when the worker gets + connected. + _function_execution_info: The function_id + and execution_info. + _num_task_executions: The function + execution times. + imported_actor_classes: The set of actor classes keys (format: + ActorClass:function_id) that are already in GCS. + """ + + def __init__(self, worker): + self._worker = worker + self._functions_to_export = [] + self._actors_to_export = [] + # This field is a dictionary that maps function IDs + # to a FunctionExecutionInfo object. This should only be used on + # workers that execute remote functions. + self._function_execution_info = defaultdict(lambda: {}) + self._num_task_executions = defaultdict(lambda: {}) + # A set of all of the actor class keys that have been imported by the + # import thread. It is safe to convert this worker into an actor of + # these types. + self.imported_actor_classes = set() + self._loaded_actor_classes = {} + # Deserialize an ActorHandle will call load_actor_class(). If a + # function closure captured an ActorHandle, the deserialization of the + # function will be: + # -> fetch_and_register_remote_function (acquire lock) + # -> _load_actor_class_from_gcs (acquire lock, too) + # So, the lock should be a reentrant lock. + self.lock = threading.RLock() + + self.execution_infos = {} + # This is the counter to keep track of how many keys have already + # been exported so that we can find next key quicker. 
+ self._num_exported = 0 + # This is to protect self._num_exported when doing exporting + self._export_lock = threading.Lock() + + def increase_task_counter(self, function_descriptor): + function_id = function_descriptor.function_id + self._num_task_executions[function_id] += 1 + + def get_task_counter(self, function_descriptor): + function_id = function_descriptor.function_id + return self._num_task_executions[function_id] + + def compute_collision_identifier(self, function_or_class): + """The identifier is used to detect excessive duplicate exports. + The identifier is used to determine when the same function or class is + exported many times. This can yield false positives. + Args: + function_or_class: The function or class to compute an identifier + for. + Returns: + The identifier. Note that different functions or classes can give + rise to same identifier. However, the same function should + hopefully always give rise to the same identifier. TODO(rkn): + verify if this is actually the case. Note that if the + identifier is incorrect in any way, then we may give warnings + unnecessarily or fail to give warnings, but the application's + behavior won't change. + """ + import io + + string_file = io.StringIO() + dis.dis(function_or_class, file=string_file, depth=2) + collision_identifier = function_or_class.__name__ + ":" + string_file.getvalue() + + # Return a hash of the identifier in case it is too large. 
+ return hashlib.sha1(collision_identifier.encode("utf-8")).digest() + + def load_function_or_class_from_local(self, module_name, function_or_class_name): + """Try to load a function or class in the module from local.""" + module = importlib.import_module(module_name) + parts = [part for part in function_or_class_name.split(".") if part] + object = module + try: + for part in parts: + object = getattr(object, part) + return object + except Exception: + return None + + def export_setup_func( + self, setup_func: Callable, timeout: Optional[int] = None + ) -> bytes: + """Export the setup hook function and return the key.""" + pickled_function = pickle_dumps( + setup_func, + "Cannot serialize the worker_process_setup_hook " f"{setup_func.__name__}", + ) + + function_to_run_id = hashlib.shake_128(pickled_function).digest( + ray_constants.ID_SIZE + ) + key = make_function_table_key( + # This value should match with gcs_function_manager.h. + # Otherwise, it won't be GC'ed. + WORKER_PROCESS_SETUP_HOOK_KEY_NAME_GCS.encode(), + # b"FunctionsToRun", + self._worker.current_job_id.binary(), + function_to_run_id, + ) + + check_oversized_function( + pickled_function, setup_func.__name__, "function", self._worker + ) + + try: + self._worker.gcs_client.internal_kv_put( + key, + pickle.dumps( + { + "job_id": self._worker.current_job_id.binary(), + "function_id": function_to_run_id, + "function": pickled_function, + } + ), + # overwrite + True, + ray_constants.KV_NAMESPACE_FUNCTION_TABLE, + timeout=timeout, + ) + except Exception as e: + logger.exception( + "Failed to export the setup hook " f"{setup_func.__name__}." + ) + raise e + + return key + + def export(self, remote_function): + """Pickle a remote function and export it to redis. + Args: + remote_function: the RemoteFunction object. 
+ """ + if self._worker.load_code_from_local: + function_descriptor = remote_function._function_descriptor + module_name, function_name = ( + function_descriptor.module_name, + function_descriptor.function_name, + ) + # If the function is dynamic, we still export it to GCS + # even if load_code_from_local is set True. + if ( + self.load_function_or_class_from_local(module_name, function_name) + is not None + ): + return + function = remote_function._function + pickled_function = remote_function._pickled_function + + check_oversized_function( + pickled_function, + remote_function._function_name, + "remote function", + self._worker, + ) + key = make_function_table_key( + b"RemoteFunction", + self._worker.current_job_id, + remote_function._function_descriptor.function_id.binary(), + ) + if self._worker.gcs_client.internal_kv_exists(key, KV_NAMESPACE_FUNCTION_TABLE): + return + val = pickle.dumps( + { + "job_id": self._worker.current_job_id.binary(), + "function_id": remote_function._function_descriptor.function_id.binary(), # noqa: E501 + "function_name": remote_function._function_name, + "module": function.__module__, + "function": pickled_function, + "collision_identifier": self.compute_collision_identifier(function), + "max_calls": remote_function._max_calls, + } + ) + self._worker.gcs_client.internal_kv_put( + key, val, True, KV_NAMESPACE_FUNCTION_TABLE + ) + + def fetch_registered_method( + self, key: str, timeout: Optional[int] = None + ) -> Optional[ImportedFunctionInfo]: + vals = self._worker.gcs_client.internal_kv_get( + key, KV_NAMESPACE_FUNCTION_TABLE, timeout=timeout + ) + if vals is None: + return None + else: + vals = pickle.loads(vals) + fields = [ + "job_id", + "function_id", + "function_name", + "function", + "module", + "max_calls", + ] + return ImportedFunctionInfo._make(vals.get(field) for field in fields) + + def fetch_and_register_remote_function(self, key): + """Import a remote function.""" + remote_function_info = 
self.fetch_registered_method(key) + if not remote_function_info: + return False + ( + job_id_str, + function_id_str, + function_name, + serialized_function, + module, + max_calls, + ) = remote_function_info + + function_id = ray.FunctionID(function_id_str) + job_id = ray.JobID(job_id_str) + max_calls = int(max_calls) + + # This function is called by ImportThread. This operation needs to be + # atomic. Otherwise, there is race condition. Another thread may use + # the temporary function above before the real function is ready. + with self.lock: + self._num_task_executions[function_id] = 0 + + try: + function = pickle.loads(serialized_function) + except Exception: + # If an exception was thrown when the remote function was + # imported, we record the traceback and notify the scheduler + # of the failure. + traceback_str = format_error_message(traceback.format_exc()) + + def f(*args, **kwargs): + raise RuntimeError( + "The remote function failed to import on the " + "worker. This may be because needed library " + "dependencies are not installed in the worker " + "environment or cannot be found from sys.path " + f"{sys.path}:\n\n{traceback_str}" + ) + + # Use a placeholder method when function pickled failed + self._function_execution_info[function_id] = FunctionExecutionInfo( + function=f, function_name=function_name, max_calls=max_calls + ) + + # Log the error message. Log at DEBUG level to avoid overly + # spamming the log on import failure. The user gets the error + # via the RuntimeError message above. + logger.debug( + "Failed to unpickle the remote function " + f"'{function_name}' with " + f"function ID {function_id.hex()}. " + f"Job ID:{job_id}." + f"Traceback:\n{traceback_str}. " + ) + else: + # The below line is necessary. Because in the driver process, + # if the function is defined in the file where the python + # script was started from, its module is `__main__`. 
+ # However in the worker process, the `__main__` module is a + # different module, which is `default_worker.py` + function.__module__ = module + self._function_execution_info[function_id] = FunctionExecutionInfo( + function=function, function_name=function_name, max_calls=max_calls + ) + return True + + def get_execution_info(self, job_id, function_descriptor): + """Get the FunctionExecutionInfo of a remote function. + Args: + job_id: ID of the job that the function belongs to. + function_descriptor: The FunctionDescriptor of the function to get. + Returns: + A FunctionExecutionInfo object. + """ + function_id = function_descriptor.function_id + # If the function has already been loaded, + # There's no need to load again + if function_id in self._function_execution_info: + return self._function_execution_info[function_id] + if self._worker.load_code_from_local: + # Load function from local code. + if not function_descriptor.is_actor_method(): + # If the function is not able to be loaded, + # try to load it from GCS, + # even if load_code_from_local is set True + if self._load_function_from_local(function_descriptor) is True: + return self._function_execution_info[function_id] + # Load function from GCS. + # Wait until the function to be executed has actually been + # registered on this worker. We will push warnings to the user if + # we spend too long in this loop. + # The driver function may not be found in sys.path. Try to load + # the function from GCS. + with profiling.profile("wait_for_function"): + self._wait_for_function(function_descriptor, job_id) + try: + function_id = function_descriptor.function_id + info = self._function_execution_info[function_id] + except KeyError as e: + message = ( + "Error occurs in get_execution_info: " + "job_id: %s, function_descriptor: %s. 
Message: %s" + % (job_id, function_descriptor, e) + ) + raise KeyError(message) + return info + + def _load_function_from_local(self, function_descriptor): + assert not function_descriptor.is_actor_method() + function_id = function_descriptor.function_id + + module_name, function_name = ( + function_descriptor.module_name, + function_descriptor.function_name, + ) + + object = self.load_function_or_class_from_local(module_name, function_name) + if object is not None: + # Directly importing from local may break function with dynamic ray.remote, + # such as the _start_controller function utilized for the Ray service. + if isinstance(object, RemoteFunction): + function = object._function + else: + function = object + self._function_execution_info[function_id] = FunctionExecutionInfo( + function=function, + function_name=function_name, + max_calls=0, + ) + self._num_task_executions[function_id] = 0 + return True + else: + return False + + def _wait_for_function(self, function_descriptor, job_id: str, timeout=10): + """Wait until the function to be executed is present on this worker. + This method will simply loop until the import thread has imported the + relevant function. If we spend too long in this loop, that may indicate + a problem somewhere and we will push an error message to the user. + If this worker is an actor, then this will wait until the actor has + been defined. + Args: + function_descriptor : The FunctionDescriptor of the function that + we want to execute. + job_id: The ID of the job to push the error message to + if this times out. + """ + start_time = time.time() + # Only send the warning once. 
+ warning_sent = False + while True: + with self.lock: + if self._worker.actor_id.is_nil(): + if function_descriptor.function_id in self._function_execution_info: + break + else: + key = make_function_table_key( + b"RemoteFunction", + job_id, + function_descriptor.function_id.binary(), + ) + if self.fetch_and_register_remote_function(key) is True: + break + else: + assert not self._worker.actor_id.is_nil() + # Actor loading will happen when execute_task is called. + assert self._worker.actor_id in self._worker.actors + break + + if time.time() - start_time > timeout: + warning_message = ( + "This worker was asked to execute a function " + f"that has not been registered ({function_descriptor}, " + f"node={self._worker.node_ip_address}, " + f"worker_id={self._worker.worker_id.hex()}, " + f"pid={os.getpid()}). You may have to restart Ray." + ) + if not warning_sent: + logger.error(warning_message) + ray._private.utils.push_error_to_driver( + self._worker, + ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR, + warning_message, + job_id=job_id, + ) + warning_sent = True + time.sleep(0.001) + + def export_actor_class( + self, Class, actor_creation_function_descriptor, actor_method_names + ): + if self._worker.load_code_from_local: + module_name, class_name = ( + actor_creation_function_descriptor.module_name, + actor_creation_function_descriptor.class_name, + ) + # If the class is dynamic, we still export it to GCS + # even if load_code_from_local is set True. + if ( + self.load_function_or_class_from_local(module_name, class_name) + is not None + ): + return + + # `current_job_id` shouldn't be NIL, unless: + # 1) This worker isn't an actor; + # 2) And a previous task started a background thread, which didn't + # finish before the task finished, and still uses Ray API + # after that. + assert not self._worker.current_job_id.is_nil(), ( + "You might have started a background thread in a non-actor " + "task, please make sure the thread finishes before the " + "task finishes." 
+ ) + job_id = self._worker.current_job_id + key = make_function_table_key( + b"ActorClass", + job_id, + actor_creation_function_descriptor.function_id.binary(), + ) + serialized_actor_class = pickle_dumps( + Class, + f"Could not serialize the actor class " + f"{actor_creation_function_descriptor.repr}", + ) + actor_class_info = { + "class_name": actor_creation_function_descriptor.class_name.split(".")[-1], + "module": actor_creation_function_descriptor.module_name, + "class": serialized_actor_class, + "job_id": job_id.binary(), + "collision_identifier": self.compute_collision_identifier(Class), + "actor_method_names": json.dumps(list(actor_method_names)), + } + + check_oversized_function( + actor_class_info["class"], + actor_class_info["class_name"], + "actor", + self._worker, + ) + + self._worker.gcs_client.internal_kv_put( + key, pickle.dumps(actor_class_info), True, KV_NAMESPACE_FUNCTION_TABLE + ) + # TODO(rkn): Currently we allow actor classes to be defined + # within tasks. I tried to disable this, but it may be necessary + # because of https://github.com/ray-project/ray/issues/1146. + + def load_actor_class(self, job_id, actor_creation_function_descriptor): + """Load the actor class. + Args: + job_id: job ID of the actor. + actor_creation_function_descriptor: Function descriptor of + the actor constructor. + Returns: + The actor class. + """ + function_id = actor_creation_function_descriptor.function_id + # Check if the actor class already exists in the cache. + actor_class = self._loaded_actor_classes.get(function_id, None) + if actor_class is None: + # Load actor class. + if self._worker.load_code_from_local: + # Load actor class from local code first. 
+ actor_class = self._load_actor_class_from_local( + actor_creation_function_descriptor + ) + # If the actor is unable to be loaded + # from local, try to load it + # from GCS even if load_code_from_local is set True + if actor_class is None: + actor_class = self._load_actor_class_from_gcs( + job_id, actor_creation_function_descriptor + ) + + else: + # Load actor class from GCS. + actor_class = self._load_actor_class_from_gcs( + job_id, actor_creation_function_descriptor + ) + # Save the loaded actor class in cache. + self._loaded_actor_classes[function_id] = actor_class + + # Generate execution info for the methods of this actor class. + module_name = actor_creation_function_descriptor.module_name + actor_class_name = actor_creation_function_descriptor.class_name + actor_methods = inspect.getmembers( + actor_class, predicate=is_function_or_method + ) + for actor_method_name, actor_method in actor_methods: + # Actor creation function descriptor use a unique function + # hash to solve actor name conflict. When constructing an + # actor, the actor creation function descriptor will be the + # key to find __init__ method execution info. So, here we + # use actor creation function descriptor as method descriptor + # for generating __init__ method execution info. 
+ if actor_method_name == "__init__": + method_descriptor = actor_creation_function_descriptor + else: + method_descriptor = PythonFunctionDescriptor( + module_name, actor_method_name, actor_class_name + ) + method_id = method_descriptor.function_id + executor = self._make_actor_method_executor( + actor_method_name, + actor_method, + actor_imported=True, + ) + self._function_execution_info[method_id] = FunctionExecutionInfo( + function=executor, + function_name=actor_method_name, + max_calls=0, + ) + self._num_task_executions[method_id] = 0 + self._num_task_executions[function_id] = 0 + return actor_class + + def _load_actor_class_from_local(self, actor_creation_function_descriptor): + """Load actor class from local code.""" + module_name, class_name = ( + actor_creation_function_descriptor.module_name, + actor_creation_function_descriptor.class_name, + ) + + object = self.load_function_or_class_from_local(module_name, class_name) + + if object is not None: + if isinstance(object, ray.actor.ActorClass): + return object.__ray_metadata__.modified_class + else: + return object + else: + return None + + def _create_fake_actor_class( + self, actor_class_name, actor_method_names, traceback_str + ): + class TemporaryActor: + pass + + def temporary_actor_method(*args, **kwargs): + raise RuntimeError( + f"The actor with name {actor_class_name} " + "failed to import on the worker. This may be because " + "needed library dependencies are not installed in the " + f"worker environment:\n\n{traceback_str}" + ) + + for method in actor_method_names: + setattr(TemporaryActor, method, temporary_actor_method) + + return TemporaryActor + + def _load_actor_class_from_gcs(self, job_id, actor_creation_function_descriptor): + """Load actor class from GCS.""" + key = make_function_table_key( + b"ActorClass", + job_id, + actor_creation_function_descriptor.function_id.binary(), + ) + + # Fetch raw data from GCS. 
+ vals = self._worker.gcs_client.internal_kv_get(key, KV_NAMESPACE_FUNCTION_TABLE) + fields = ["job_id", "class_name", "module", "class", "actor_method_names"] + if vals is None: + vals = {} + else: + vals = pickle.loads(vals) + (job_id_str, class_name, module, pickled_class, actor_method_names) = ( + vals.get(field) for field in fields + ) + + class_name = ensure_str(class_name) + module_name = ensure_str(module) + job_id = ray.JobID(job_id_str) + actor_method_names = json.loads(ensure_str(actor_method_names)) + + actor_class = None + try: + with self.lock: + actor_class = pickle.loads(pickled_class) + except Exception: + logger.debug("Failed to load actor class %s.", class_name) + # If an exception was thrown when the actor was imported, we record + # the traceback and notify the scheduler of the failure. + traceback_str = format_error_message(traceback.format_exc()) + # The actor class failed to be unpickled, create a fake actor + # class instead (just to produce error messages and to prevent + # the driver from hanging). + actor_class = self._create_fake_actor_class( + class_name, actor_method_names, traceback_str + ) + + # The below line is necessary. Because in the driver process, + # if the function is defined in the file where the python script + # was started from, its module is `__main__`. + # However in the worker process, the `__main__` module is a + # different module, which is `default_worker.py` + actor_class.__module__ = module_name + return actor_class + + def _make_actor_method_executor( + self, method_name: str, method, actor_imported: bool + ): + """Make an executor that wraps a user-defined actor method. + The wrapped method updates the worker's internal state and performs any + necessary checkpointing operations. + Args: + method_name: The name of the actor method. + method: The actor method to wrap. This should be a + method defined on the actor class and should therefore take an + instance of the actor as the first argument. 
+ actor_imported: Whether the actor has been imported. + Checkpointing operations will not be run if this is set to + False. + Returns: + A function that executes the given actor method on the worker's + stored instance of the actor. The function also updates the + worker's internal state to record the executed method. + """ + + def actor_method_executor(__ray_actor, *args, **kwargs): + # Execute the assigned method. + is_bound = is_class_method(method) or is_static_method( + type(__ray_actor), method_name + ) + if is_bound: + return method(*args, **kwargs) + else: + return method(__ray_actor, *args, **kwargs) + + # Set method_name and method as attributes to the executor closure + # so we can make decision based on these attributes in task executor. + # Precisely, asyncio support requires to know whether: + # - the method is a ray internal method: starts with __ray + # - the method is a coroutine function: defined by async def + actor_method_executor.name = method_name + actor_method_executor.method = method + + return actor_method_executor diff --git a/.venv/lib/python3.11/site-packages/ray/_private/gcs_aio_client.py b/.venv/lib/python3.11/site-packages/ray/_private/gcs_aio_client.py new file mode 100644 index 0000000000000000000000000000000000000000..5ef7b64b1017cef238ad946a113bd223a52248bb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/gcs_aio_client.py @@ -0,0 +1,47 @@ +import logging +from typing import Optional +import ray +from ray._raylet import InnerGcsClient + +logger = logging.getLogger(__name__) + + +class GcsAioClient: + """ + Async GCS client. + + Historical note: there was a `ray::gcs::PythonGcsClient` C++ binding which has only + sync API and in Python we wrap it with ThreadPoolExecutor. It's been removed in + favor of `ray::gcs::GcsClient` which contains async API. 
+ """ + + def __init__( + self, + address: str = None, + loop=None, + executor=None, + nums_reconnect_retry: int = 5, + cluster_id: Optional[str] = None, + ): + # This must be consistent with GcsClient.__cinit__ in _raylet.pyx + timeout_ms = ray._config.py_gcs_connect_timeout_s() * 1000 + self.inner = InnerGcsClient.standalone( + str(address), cluster_id=cluster_id, timeout_ms=timeout_ms + ) + # Forwarded Methods. Not using __getattr__ because we want one fewer layer of + # indirection. + self.internal_kv_get = self.inner.async_internal_kv_get + self.internal_kv_multi_get = self.inner.async_internal_kv_multi_get + self.internal_kv_put = self.inner.async_internal_kv_put + self.internal_kv_del = self.inner.async_internal_kv_del + self.internal_kv_exists = self.inner.async_internal_kv_exists + self.internal_kv_keys = self.inner.async_internal_kv_keys + self.check_alive = self.inner.async_check_alive + self.get_all_job_info = self.inner.async_get_all_job_info + # Forwarded Properties. + self.address = self.inner.address + self.cluster_id = self.inner.cluster_id + # Note: these only exists in the new client. 
+ self.get_all_actor_info = self.inner.async_get_all_actor_info + self.get_all_node_info = self.inner.async_get_all_node_info + self.kill_actor = self.inner.async_kill_actor diff --git a/.venv/lib/python3.11/site-packages/ray/_private/gcs_pubsub.py b/.venv/lib/python3.11/site-packages/ray/_private/gcs_pubsub.py new file mode 100644 index 0000000000000000000000000000000000000000..27d53c9763b12503b2dc848b84a9ba31d63800d2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/gcs_pubsub.py @@ -0,0 +1,311 @@ +import asyncio +from collections import deque +import logging +import random +from typing import Tuple, List + +import grpc +from ray._private.utils import get_or_create_event_loop + +try: + from grpc import aio as aiogrpc +except ImportError: + from grpc.experimental import aio as aiogrpc + +import ray._private.gcs_utils as gcs_utils +from ray.core.generated import gcs_service_pb2_grpc +from ray.core.generated import gcs_service_pb2 +from ray.core.generated import gcs_pb2 +from ray.core.generated import common_pb2 +from ray.core.generated import pubsub_pb2 + +logger = logging.getLogger(__name__) + +# Max retries for GCS publisher connection error +MAX_GCS_PUBLISH_RETRIES = 60 + + +class _PublisherBase: + @staticmethod + def _create_node_resource_usage_request(key: str, json: str): + return gcs_service_pb2.GcsPublishRequest( + pub_messages=[ + pubsub_pb2.PubMessage( + channel_type=pubsub_pb2.RAY_NODE_RESOURCE_USAGE_CHANNEL, + key_id=key.encode(), + node_resource_usage_message=common_pb2.NodeResourceUsage(json=json), + ) + ] + ) + + +class _SubscriberBase: + def __init__(self, worker_id: bytes = None): + self._worker_id = worker_id + # self._subscriber_id needs to match the binary format of a random + # SubscriberID / UniqueID, which is 28 (kUniqueIDSize) random bytes. 
+ self._subscriber_id = bytes(bytearray(random.getrandbits(8) for _ in range(28))) + self._last_batch_size = 0 + self._max_processed_sequence_id = 0 + self._publisher_id = b"" + + # Batch size of the result from last poll. Used to indicate whether the + # subscriber can keep up. + @property + def last_batch_size(self): + return self._last_batch_size + + def _subscribe_request(self, channel): + cmd = pubsub_pb2.Command(channel_type=channel, subscribe_message={}) + req = gcs_service_pb2.GcsSubscriberCommandBatchRequest( + subscriber_id=self._subscriber_id, sender_id=self._worker_id, commands=[cmd] + ) + return req + + def _poll_request(self): + return gcs_service_pb2.GcsSubscriberPollRequest( + subscriber_id=self._subscriber_id, + max_processed_sequence_id=self._max_processed_sequence_id, + publisher_id=self._publisher_id, + ) + + def _unsubscribe_request(self, channels): + req = gcs_service_pb2.GcsSubscriberCommandBatchRequest( + subscriber_id=self._subscriber_id, sender_id=self._worker_id, commands=[] + ) + for channel in channels: + req.commands.append( + pubsub_pb2.Command(channel_type=channel, unsubscribe_message={}) + ) + return req + + @staticmethod + def _should_terminate_polling(e: grpc.RpcError) -> None: + # Caller only expects polling to be terminated after deadline exceeded. + if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED: + return True + # Could be a temporary connection issue. Suppress error. + # TODO: reconnect GRPC channel? + if e.code() == grpc.StatusCode.UNAVAILABLE: + return True + return False + + +class GcsAioPublisher(_PublisherBase): + """Publisher to GCS. 
Uses async io.""" + + def __init__(self, address: str = None, channel: aiogrpc.Channel = None): + if address: + assert channel is None, "address and channel cannot both be specified" + channel = gcs_utils.create_gcs_channel(address, aio=True) + else: + assert channel is not None, "One of address and channel must be specified" + self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel) + + async def publish_resource_usage(self, key: str, json: str) -> None: + """Publishes logs to GCS.""" + req = self._create_node_resource_usage_request(key, json) + await self._stub.GcsPublish(req) + + +class _AioSubscriber(_SubscriberBase): + """Async io subscriber to GCS. + + Usage example common to Aio subscribers: + subscriber = GcsAioXxxSubscriber(address="...") + await subscriber.subscribe() + while running: + ...... = await subscriber.poll() + ...... + await subscriber.close() + """ + + def __init__( + self, + pubsub_channel_type, + worker_id: bytes = None, + address: str = None, + channel: aiogrpc.Channel = None, + ): + super().__init__(worker_id) + + if address: + assert channel is None, "address and channel cannot both be specified" + channel = gcs_utils.create_gcs_channel(address, aio=True) + else: + assert channel is not None, "One of address and channel must be specified" + # GRPC stub to GCS pubsub. + self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel) + + # Type of the channel. + self._channel = pubsub_channel_type + # A queue of received PubMessage. + self._queue = deque() + # Indicates whether the subscriber has closed. + self._close = asyncio.Event() + + async def subscribe(self) -> None: + """Registers a subscription for the subscriber's channel type. + + Before the registration, published messages in the channel will not be + saved for the subscriber. 
+ """ + if self._close.is_set(): + return + req = self._subscribe_request(self._channel) + await self._stub.GcsSubscriberCommandBatch(req, timeout=30) + + async def _poll_call(self, req, timeout=None): + # Wrap GRPC _AioCall as a coroutine. + return await self._stub.GcsSubscriberPoll(req, timeout=timeout) + + async def _poll(self, timeout=None) -> None: + while len(self._queue) == 0: + req = self._poll_request() + poll = get_or_create_event_loop().create_task( + self._poll_call(req, timeout=timeout) + ) + close = get_or_create_event_loop().create_task(self._close.wait()) + done, others = await asyncio.wait( + [poll, close], timeout=timeout, return_when=asyncio.FIRST_COMPLETED + ) + # Cancel the other task if needed to prevent memory leak. + other_task = others.pop() + if not other_task.done(): + other_task.cancel() + if poll not in done or close in done: + # Request timed out or subscriber closed. + break + try: + self._last_batch_size = len(poll.result().pub_messages) + if poll.result().publisher_id != self._publisher_id: + if self._publisher_id != "": + logger.debug( + f"replied publisher_id {poll.result().publisher_id}" + f"different from {self._publisher_id}, this should " + "only happens during gcs failover." + ) + self._publisher_id = poll.result().publisher_id + self._max_processed_sequence_id = 0 + for msg in poll.result().pub_messages: + if msg.sequence_id <= self._max_processed_sequence_id: + logger.warning(f"Ignoring out of order message {msg}") + continue + self._max_processed_sequence_id = msg.sequence_id + self._queue.append(msg) + except grpc.RpcError as e: + if self._should_terminate_polling(e): + return + raise + + async def close(self) -> None: + """Closes the subscriber and its active subscription.""" + + # Mark close to terminate inflight polling and prevent future requests. 
+ if self._close.is_set(): + return + self._close.set() + req = self._unsubscribe_request(channels=[self._channel]) + try: + await self._stub.GcsSubscriberCommandBatch(req, timeout=5) + except Exception: + pass + self._stub = None + + +class GcsAioResourceUsageSubscriber(_AioSubscriber): + def __init__( + self, + worker_id: bytes = None, + address: str = None, + channel: grpc.Channel = None, + ): + super().__init__( + pubsub_pb2.RAY_NODE_RESOURCE_USAGE_CHANNEL, worker_id, address, channel + ) + + async def poll(self, timeout=None) -> Tuple[bytes, str]: + """Polls for new resource usage message. + + Returns: + A tuple of string reporter ID and resource usage json string. + """ + await self._poll(timeout=timeout) + return self._pop_resource_usage(self._queue) + + @staticmethod + def _pop_resource_usage(queue): + if len(queue) == 0: + return None, None + msg = queue.popleft() + return msg.key_id.decode(), msg.node_resource_usage_message.json + + +class GcsAioActorSubscriber(_AioSubscriber): + def __init__( + self, + worker_id: bytes = None, + address: str = None, + channel: grpc.Channel = None, + ): + super().__init__(pubsub_pb2.GCS_ACTOR_CHANNEL, worker_id, address, channel) + + @property + def queue_size(self): + return len(self._queue) + + async def poll( + self, batch_size, timeout=None + ) -> List[Tuple[bytes, gcs_pb2.ActorTableData]]: + """Polls for new actor message. + + Returns: + A list of tuples of binary actor ID and actor table data. 
+ """ + await self._poll(timeout=timeout) + return self._pop_actors(self._queue, batch_size=batch_size) + + @staticmethod + def _pop_actors(queue, batch_size): + if len(queue) == 0: + return [] + popped = 0 + msgs = [] + while len(queue) > 0 and popped < batch_size: + msg = queue.popleft() + msgs.append((msg.key_id, msg.actor_message)) + popped += 1 + return msgs + + +class GcsAioNodeInfoSubscriber(_AioSubscriber): + def __init__( + self, + worker_id: bytes = None, + address: str = None, + channel: grpc.Channel = None, + ): + super().__init__(pubsub_pb2.GCS_NODE_INFO_CHANNEL, worker_id, address, channel) + + async def poll( + self, batch_size, timeout=None + ) -> List[Tuple[bytes, gcs_pb2.GcsNodeInfo]]: + """Polls for new node info message. + + Returns: + A list of tuples of (node_id, GcsNodeInfo). + """ + await self._poll(timeout=timeout) + return self._pop_node_infos(self._queue, batch_size=batch_size) + + @staticmethod + def _pop_node_infos(queue, batch_size): + if len(queue) == 0: + return [] + popped = 0 + msgs = [] + while len(queue) > 0 and popped < batch_size: + msg = queue.popleft() + msgs.append((msg.key_id, msg.node_info_message)) + popped += 1 + return msgs diff --git a/.venv/lib/python3.11/site-packages/ray/_private/gcs_utils.py b/.venv/lib/python3.11/site-packages/ray/_private/gcs_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..307e0bdde11b98eeefa8c807a4bdb5f2c9e3f9c8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/gcs_utils.py @@ -0,0 +1,163 @@ +import logging +from typing import Optional + +from ray._private import ray_constants + +import ray._private.gcs_aio_client + +from ray.core.generated.common_pb2 import ErrorType, JobConfig +from ray.core.generated.gcs_pb2 import ( + ActorTableData, + AvailableResources, + TotalResources, + ErrorTableData, + GcsEntry, + GcsNodeInfo, + JobTableData, + PlacementGroupTableData, + PubSubMessage, + ResourceDemand, + ResourceLoad, + ResourcesData, + 
ResourceUsageBatchData, + TablePrefix, + TablePubsub, + TaskEvents, + WorkerTableData, +) + +logger = logging.getLogger(__name__) + +__all__ = [ + "ActorTableData", + "GcsNodeInfo", + "AvailableResources", + "TotalResources", + "JobTableData", + "JobConfig", + "ErrorTableData", + "ErrorType", + "GcsEntry", + "ResourceUsageBatchData", + "ResourcesData", + "TablePrefix", + "TablePubsub", + "TaskEvents", + "ResourceDemand", + "ResourceLoad", + "PubSubMessage", + "WorkerTableData", + "PlacementGroupTableData", +] + + +WORKER = 0 +DRIVER = 1 + +# Cap messages at 512MB +_MAX_MESSAGE_LENGTH = 512 * 1024 * 1024 +# Send keepalive every 60s +_GRPC_KEEPALIVE_TIME_MS = 60 * 1000 +# Keepalive should be replied < 60s +_GRPC_KEEPALIVE_TIMEOUT_MS = 60 * 1000 + +# Also relying on these defaults: +# grpc.keepalive_permit_without_calls=0: No keepalive without inflight calls. +# grpc.use_local_subchannel_pool=0: Subchannels are shared. +_GRPC_OPTIONS = [ + *ray_constants.GLOBAL_GRPC_OPTIONS, + ("grpc.max_send_message_length", _MAX_MESSAGE_LENGTH), + ("grpc.max_receive_message_length", _MAX_MESSAGE_LENGTH), + ("grpc.keepalive_time_ms", _GRPC_KEEPALIVE_TIME_MS), + ("grpc.keepalive_timeout_ms", _GRPC_KEEPALIVE_TIMEOUT_MS), +] + + +def create_gcs_channel(address: str, aio=False): + """Returns a GRPC channel to GCS. + + Args: + address: GCS address string, e.g. ip:port + aio: Whether using grpc.aio + Returns: + grpc.Channel or grpc.aio.Channel to GCS + """ + from ray._private.utils import init_grpc_channel + + return init_grpc_channel(address, options=_GRPC_OPTIONS, asynchronous=aio) + + +class GcsChannel: + def __init__(self, gcs_address: Optional[str] = None, aio: bool = False): + self._gcs_address = gcs_address + self._aio = aio + + @property + def address(self): + return self._gcs_address + + def connect(self): + # GCS server uses a cached port, so it should use the same port after + # restarting. This means GCS address should stay the same for the + # lifetime of the Ray cluster. 
+ self._channel = create_gcs_channel(self._gcs_address, self._aio) + + def channel(self): + return self._channel + + +# re-export +GcsAioClient = ray._private.gcs_aio_client.GcsAioClient + + +def cleanup_redis_storage( + host: str, + port: int, + password: str, + use_ssl: bool, + storage_namespace: str, + username: Optional[str] = None, +): + """This function is used to cleanup the storage. Before we having + a good design for storage backend, it can be used to delete the old + data. It support redis cluster and non cluster mode. + + Args: + host: The host address of the Redis. + port: The port of the Redis. + username: The username of the Redis. + password: The password of the Redis. + use_ssl: Whether to encrypt the connection. + storage_namespace: The namespace of the storage to be deleted. + """ + + from ray._raylet import del_key_prefix_from_storage # type: ignore + + if not isinstance(host, str): + raise ValueError("Host must be a string") + + if username is None: + username = "" + + if not isinstance(username, str): + raise ValueError("Username must be a string") + + if not isinstance(password, str): + raise ValueError("Password must be a string") + + if port < 0: + raise ValueError(f"Invalid port: {port}") + + if not isinstance(use_ssl, bool): + raise TypeError("use_ssl must be a boolean") + + if not isinstance(storage_namespace, str): + raise ValueError("storage namespace must be a string") + + # Right now, GCS stores all data into multiple hashes with keys prefixed by + # storage_namespace. So we only need to delete the specific key prefix to cleanup + # the cluster. + # Note this deletes all keys with prefix `RAY{key_prefix}@`, not `{key_prefix}`. 
+ return del_key_prefix_from_storage( + host, port, username, password, use_ssl, storage_namespace + ) diff --git a/.venv/lib/python3.11/site-packages/ray/_private/inspect_util.py b/.venv/lib/python3.11/site-packages/ray/_private/inspect_util.py new file mode 100644 index 0000000000000000000000000000000000000000..6ae603f0160073a547131182556a3afff371390d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/inspect_util.py @@ -0,0 +1,49 @@ +import inspect + + +def is_cython(obj): + """Check if an object is a Cython function or method""" + + # TODO(suo): We could split these into two functions, one for Cython + # functions and another for Cython methods. + # TODO(suo): There doesn't appear to be a Cython function 'type' we can + # check against via isinstance. Please correct me if I'm wrong. + def check_cython(x): + return type(x).__name__ == "cython_function_or_method" + + # Check if function or method, respectively + return check_cython(obj) or ( + hasattr(obj, "__func__") and check_cython(obj.__func__) + ) + + +def is_function_or_method(obj): + """Check if an object is a function or method. + + Args: + obj: The Python object in question. + + Returns: + True if the object is an function or method. + """ + return inspect.isfunction(obj) or inspect.ismethod(obj) or is_cython(obj) + + +def is_class_method(f): + """Returns whether the given method is a class_method.""" + return hasattr(f, "__self__") and f.__self__ is not None + + +def is_static_method(cls, f_name): + """Returns whether the class has a static method with the given name. + + Args: + cls: The Python class (i.e. object of type `type`) to + search for the method in. + f_name: The name of the method to look up in this class + and check whether or not it is static. 
+ """ + for base_cls in inspect.getmro(cls): + if f_name in base_cls.__dict__: + return isinstance(base_cls.__dict__[f_name], staticmethod) + return False diff --git a/.venv/lib/python3.11/site-packages/ray/_private/internal_api.py b/.venv/lib/python3.11/site-packages/ray/_private/internal_api.py new file mode 100644 index 0000000000000000000000000000000000000000..f4efbde4db21d6f3239f9893e6d27981667f0e60 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/internal_api.py @@ -0,0 +1,255 @@ +from typing import List, Tuple + +import ray +import ray._private.profiling as profiling +import ray._private.services as services +import ray._private.utils as utils +import ray._private.worker +from ray._private.state import GlobalState +from ray._raylet import GcsClientOptions +from ray.core.generated import common_pb2 + +__all__ = ["free", "global_gc"] +MAX_MESSAGE_LENGTH = ray._config.max_grpc_message_size() + + +def global_gc(): + """Trigger gc.collect() on all workers in the cluster.""" + + worker = ray._private.worker.global_worker + worker.core_worker.global_gc() + + +def get_state_from_address(address=None): + address = services.canonicalize_bootstrap_address_or_die(address) + + state = GlobalState() + options = GcsClientOptions.create( + address, None, allow_cluster_id_nil=True, fetch_cluster_id_if_nil=False + ) + state._initialize_global_state(options) + return state + + +def memory_summary( + address=None, + group_by="NODE_ADDRESS", + sort_by="OBJECT_SIZE", + units="B", + line_wrap=True, + stats_only=False, + num_entries=None, +): + from ray.dashboard.memory_utils import memory_summary + + state = get_state_from_address(address) + reply = get_memory_info_reply(state) + + if stats_only: + return store_stats_summary(reply) + return memory_summary( + state, group_by, sort_by, line_wrap, units, num_entries + ) + store_stats_summary(reply) + + +def get_memory_info_reply(state, node_manager_address=None, node_manager_port=None): + """Returns global memory 
info.""" + + from ray.core.generated import node_manager_pb2, node_manager_pb2_grpc + + # We can ask any Raylet for the global memory info, that Raylet internally + # asks all nodes in the cluster for memory stats. + if node_manager_address is None or node_manager_port is None: + # We should ask for a raylet that is alive. + raylet = None + for node in state.node_table(): + if node["Alive"]: + raylet = node + break + assert raylet is not None, "Every raylet is dead" + raylet_address = "{}:{}".format( + raylet["NodeManagerAddress"], raylet["NodeManagerPort"] + ) + else: + raylet_address = "{}:{}".format(node_manager_address, node_manager_port) + + channel = utils.init_grpc_channel( + raylet_address, + options=[ + ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH), + ("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH), + ], + ) + + stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel) + reply = stub.FormatGlobalMemoryInfo( + node_manager_pb2.FormatGlobalMemoryInfoRequest(include_memory_info=False), + timeout=60.0, + ) + return reply + + +def node_stats( + node_manager_address=None, node_manager_port=None, include_memory_info=True +): + """Returns NodeStats object describing memory usage in the cluster.""" + + from ray.core.generated import node_manager_pb2, node_manager_pb2_grpc + + # We can ask any Raylet for the global memory info. 
+ assert node_manager_address is not None and node_manager_port is not None + raylet_address = "{}:{}".format(node_manager_address, node_manager_port) + channel = utils.init_grpc_channel( + raylet_address, + options=[ + ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH), + ("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH), + ], + ) + + stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel) + node_stats = stub.GetNodeStats( + node_manager_pb2.GetNodeStatsRequest(include_memory_info=include_memory_info), + timeout=30.0, + ) + return node_stats + + +def store_stats_summary(reply): + """Returns formatted string describing object store stats in all nodes.""" + store_summary = "--- Aggregate object store stats across all nodes ---\n" + # TODO(ekl) it would be nice if we could provide a full memory usage + # breakdown by type (e.g., pinned by worker, primary, etc.) + store_summary += ( + "Plasma memory usage {} MiB, {} objects, {}% full, {}% " + "needed\n".format( + int(reply.store_stats.object_store_bytes_used / (1024 * 1024)), + reply.store_stats.num_local_objects, + round( + 100 + * reply.store_stats.object_store_bytes_used + / reply.store_stats.object_store_bytes_avail, + 2, + ), + round( + 100 + * reply.store_stats.object_store_bytes_primary_copy + / reply.store_stats.object_store_bytes_avail, + 2, + ), + ) + ) + if reply.store_stats.object_store_bytes_fallback > 0: + store_summary += "Plasma filesystem mmap usage: {} MiB\n".format( + int(reply.store_stats.object_store_bytes_fallback / (1024 * 1024)) + ) + if reply.store_stats.spill_time_total_s > 0: + store_summary += ( + "Spilled {} MiB, {} objects, avg write throughput {} MiB/s\n".format( + int(reply.store_stats.spilled_bytes_total / (1024 * 1024)), + reply.store_stats.spilled_objects_total, + int( + reply.store_stats.spilled_bytes_total + / (1024 * 1024) + / reply.store_stats.spill_time_total_s + ), + ) + ) + if reply.store_stats.restore_time_total_s > 0: + store_summary += ( + "Restored {} MiB, {} 
objects, avg read throughput {} MiB/s\n".format( + int(reply.store_stats.restored_bytes_total / (1024 * 1024)), + reply.store_stats.restored_objects_total, + int( + reply.store_stats.restored_bytes_total + / (1024 * 1024) + / reply.store_stats.restore_time_total_s + ), + ) + ) + if reply.store_stats.consumed_bytes > 0: + store_summary += "Objects consumed by Ray tasks: {} MiB.\n".format( + int(reply.store_stats.consumed_bytes / (1024 * 1024)) + ) + if reply.store_stats.object_pulls_queued: + store_summary += "Object fetches queued, waiting for available memory." + + return store_summary + + +def free(object_refs: list, local_only: bool = False): + """Free a list of IDs from the in-process and plasma object stores. + + This function is a low-level API which should be used in restricted + scenarios. + + If local_only is false, the request will be send to all object stores. + + This method will not return any value to indicate whether the deletion is + successful or not. This function is an instruction to the object store. If + some of the objects are in use, the object stores will delete them later + when the ref count is down to 0. + + Examples: + + .. testcode:: + + import ray + + @ray.remote + def f(): + return 0 + + obj_ref = f.remote() + ray.get(obj_ref) # wait for object to be created first + free([obj_ref]) # unpin & delete object globally + + Args: + object_refs (List[ObjectRef]): List of object refs to delete. + local_only: Whether only deleting the list of objects in local + object store or all object stores. + """ + worker = ray._private.worker.global_worker + + if isinstance(object_refs, ray.ObjectRef): + object_refs = [object_refs] + + if not isinstance(object_refs, list): + raise TypeError( + "free() expects a list of ObjectRef, got {}".format(type(object_refs)) + ) + + # Make sure that the values are object refs. 
+ for object_ref in object_refs: + if not isinstance(object_ref, ray.ObjectRef): + raise TypeError( + "Attempting to call `free` on the value {}, " + "which is not an ray.ObjectRef.".format(object_ref) + ) + + worker.check_connected() + with profiling.profile("ray.free"): + if len(object_refs) == 0: + return + + worker.core_worker.free_objects(object_refs, local_only) + + +def get_local_ongoing_lineage_reconstruction_tasks() -> List[ + Tuple[common_pb2.LineageReconstructionTask, int] +]: + """Return the locally submitted ongoing retry tasks + triggered by lineage reconstruction. + + NOTE: for the lineage reconstruction task status, + this method only returns the status known to the submitter + (i.e. it returns SUBMITTED_TO_WORKER instead of RUNNING). + + The return type is a list of pairs where pair.first is the + lineage reconstruction task info and pair.second is the number + of ongoing lineage reconstruction tasks of this type. + """ + + worker = ray._private.worker.global_worker + worker.check_connected() + return worker.core_worker.get_local_ongoing_lineage_reconstruction_tasks() diff --git a/.venv/lib/python3.11/site-packages/ray/_private/log.py b/.venv/lib/python3.11/site-packages/ray/_private/log.py new file mode 100644 index 0000000000000000000000000000000000000000..e475b42a8a5a422f60d749eb36a4e5fc42173c09 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/log.py @@ -0,0 +1,117 @@ +import logging +import threading +from typing import Union +import time + +INTERNAL_TIMESTAMP_LOG_KEY = "_ray_timestamp_ns" + + +def _print_loggers(): + """Print a formatted list of loggers and their handlers for debugging.""" + loggers = {logging.root.name: logging.root} + loggers.update(dict(sorted(logging.root.manager.loggerDict.items()))) + for name, logger in loggers.items(): + if isinstance(logger, logging.Logger): + print(f" {name}: disabled={logger.disabled}, propagate={logger.propagate}") + for handler in logger.handlers: + print(f" {handler}") + + +def 
clear_logger(logger: Union[str, logging.Logger]): + """Reset a logger, clearing its handlers and enabling propagation. + + Args: + logger: Logger to be cleared + """ + if isinstance(logger, str): + logger = logging.getLogger(logger) + logger.propagate = True + logger.handlers.clear() + + +class PlainRayHandler(logging.StreamHandler): + """A plain log handler. + + This handler writes to whatever sys.stderr points to at emit-time, + not at instantiation time. See docs for logging._StderrHandler. + """ + + def __init__(self): + super().__init__() + self.plain_handler = logging._StderrHandler() + self.plain_handler.level = self.level + self.plain_handler.formatter = logging.Formatter(fmt="%(message)s") + + def emit(self, record: logging.LogRecord): + """Emit the log message. + + If this is a worker, bypass fancy logging and just emit the log record. + If this is the driver, emit the message using the appropriate console handler. + + Args: + record: Log record to be emitted + """ + import ray + + if ( + hasattr(ray, "_private") + and hasattr(ray._private, "worker") + and ray._private.worker.global_worker.mode + == ray._private.worker.WORKER_MODE + ): + self.plain_handler.emit(record) + else: + logging._StderrHandler.emit(self, record) + + +logger_initialized = False +logging_config_lock = threading.Lock() + + +def _setup_log_record_factory(): + """Setup log record factory to add _ray_timestamp_ns to LogRecord.""" + old_factory = logging.getLogRecordFactory() + + def record_factory(*args, **kwargs): + record = old_factory(*args, **kwargs) + # Python logging module starts to use `time.time_ns()` to generate `created` + # from Python 3.13 to avoid the precision loss caused by the float type. + # Here, we generate the `created` for the LogRecord to support older Python + # versions. 
+ ct = time.time_ns() + record.created = ct / 1e9 + + record.__dict__[INTERNAL_TIMESTAMP_LOG_KEY] = ct + + return record + + logging.setLogRecordFactory(record_factory) + + +def generate_logging_config(): + """Generate the default Ray logging configuration.""" + with logging_config_lock: + global logger_initialized + if logger_initialized: + return + logger_initialized = True + + plain_formatter = logging.Formatter( + "%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s -- %(message)s" + ) + + default_handler = PlainRayHandler() + default_handler.setFormatter(plain_formatter) + + ray_logger = logging.getLogger("ray") + ray_logger.setLevel(logging.INFO) + ray_logger.addHandler(default_handler) + ray_logger.propagate = False + + # Special handling for ray.rllib: only warning-level messages passed through + # See https://github.com/ray-project/ray/pull/31858 for related PR + rllib_logger = logging.getLogger("ray.rllib") + rllib_logger.setLevel(logging.WARN) + + # Set up the LogRecord factory. + _setup_log_record_factory() diff --git a/.venv/lib/python3.11/site-packages/ray/_private/log_monitor.py b/.venv/lib/python3.11/site-packages/ray/_private/log_monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..80fcf09a5a7a719fde63724462c735d340ddc18e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/log_monitor.py @@ -0,0 +1,581 @@ +import argparse +import errno +import glob +import logging +import logging.handlers +import os +import platform +import re +import shutil +import time +import traceback +from typing import Callable, List, Optional, Set + +from ray._raylet import GcsClient +import ray._private.ray_constants as ray_constants +import ray._private.services as services +import ray._private.utils +from ray._private.ray_logging import setup_component_logger + +# Logger for this module. It should be configured at the entry point +# into the program using Ray. Ray provides a default configuration at +# entry/init points. 
+logger = logging.getLogger(__name__) + +# The groups are job id, and pid. +WORKER_LOG_PATTERN = re.compile(".*worker.*-([0-9a-f]+)-(\d+)") +# The groups are job id. +RUNTIME_ENV_SETUP_PATTERN = re.compile(".*runtime_env_setup-(\d+).log") +# Log name update interval under pressure. +# We need it because log name update is CPU intensive and uses 100% +# of cpu when there are many log files. +LOG_NAME_UPDATE_INTERVAL_S = float(os.getenv("LOG_NAME_UPDATE_INTERVAL_S", 0.5)) +# Once there are more files than this threshold, +# log monitor start giving backpressure to lower cpu usages. +RAY_LOG_MONITOR_MANY_FILES_THRESHOLD = int( + os.getenv("RAY_LOG_MONITOR_MANY_FILES_THRESHOLD", 1000) +) +RAY_RUNTIME_ENV_LOG_TO_DRIVER_ENABLED = int( + os.getenv("RAY_RUNTIME_ENV_LOG_TO_DRIVER_ENABLED", 0) +) + + +class LogFileInfo: + def __init__( + self, + filename=None, + size_when_last_opened=None, + file_position=None, + file_handle=None, + is_err_file=False, + job_id=None, + worker_pid=None, + ): + assert ( + filename is not None + and size_when_last_opened is not None + and file_position is not None + ) + self.filename = filename + self.size_when_last_opened = size_when_last_opened + self.file_position = file_position + self.file_handle = file_handle + self.is_err_file = is_err_file + self.job_id = job_id + self.worker_pid = worker_pid + self.actor_name = None + self.task_name = None + + def reopen_if_necessary(self): + """Check if the file's inode has changed and reopen it if necessary. + There are a variety of reasons what we would logically consider a file + would have different inodes, such as log rotation or file syncing + semantics. 
+ """ + try: + open_inode = None + if self.file_handle and not self.file_handle.closed: + open_inode = os.fstat(self.file_handle.fileno()).st_ino + + new_inode = os.stat(self.filename).st_ino + if open_inode != new_inode: + self.file_handle = open(self.filename, "rb") + self.file_handle.seek(self.file_position) + except Exception: + logger.debug(f"file no longer exists, skip re-opening of {self.filename}") + + def __repr__(self): + return ( + "FileInfo(\n" + f"\tfilename: {self.filename}\n" + f"\tsize_when_last_opened: {self.size_when_last_opened}\n" + f"\tfile_position: {self.file_position}\n" + f"\tfile_handle: {self.file_handle}\n" + f"\tis_err_file: {self.is_err_file}\n" + f"\tjob_id: {self.job_id}\n" + f"\tworker_pid: {self.worker_pid}\n" + f"\tactor_name: {self.actor_name}\n" + f"\ttask_name: {self.task_name}\n" + ")" + ) + + +class LogMonitor: + """A monitor process for monitoring Ray log files. + + This class maintains a list of open files and a list of closed log files. We + can't simply leave all files open because we'll run out of file + descriptors. + + The "run" method of this class will cycle between doing several things: + 1. First, it will check if any new files have appeared in the log + directory. If so, they will be added to the list of closed files. + 2. Then, if we are unable to open any new files, we will close all of the + files. + 3. Then, we will open as many closed files as we can that may have new + lines (judged by an increase in file size since the last time the file + was opened). + 4. Then we will loop through the open files and see if there are any new + lines in the file. If so, we will publish them to Ray pubsub. + + Attributes: + ip: The hostname of this machine, for grouping log messages. + logs_dir: The directory that the log files are in. + log_filenames: This is the set of filenames of all files in + open_file_infos and closed_file_infos. + open_file_infos (list[LogFileInfo]): Info for all of the open files. 
+ closed_file_infos (list[LogFileInfo]): Info for all of the closed + files. + can_open_more_files: True if we can still open more files and + false otherwise. + max_files_open: The maximum number of files that can be open. + """ + + def __init__( + self, + node_ip_address: str, + logs_dir: str, + gcs_publisher: ray._raylet.GcsPublisher, + is_proc_alive_fn: Callable[[int], bool], + max_files_open: int = ray_constants.LOG_MONITOR_MAX_OPEN_FILES, + gcs_address: Optional[str] = None, + ): + """Initialize the log monitor object.""" + self.ip: str = node_ip_address + self.logs_dir: str = logs_dir + self.publisher = gcs_publisher + self.log_filenames: Set[str] = set() + self.open_file_infos: List[LogFileInfo] = [] + self.closed_file_infos: List[LogFileInfo] = [] + self.can_open_more_files: bool = True + self.max_files_open: int = max_files_open + self.is_proc_alive_fn: Callable[[int], bool] = is_proc_alive_fn + self.is_autoscaler_v2: bool = self.get_is_autoscaler_v2(gcs_address) + + logger.info( + f"Starting log monitor with [max open files={max_files_open}]," + f" [is_autoscaler_v2={self.is_autoscaler_v2}]" + ) + + def get_is_autoscaler_v2(self, gcs_address: Optional[str]) -> bool: + """Check if autoscaler v2 is enabled.""" + if gcs_address is None: + return False + + if not ray.experimental.internal_kv._internal_kv_initialized(): + gcs_client = GcsClient(address=gcs_address) + ray.experimental.internal_kv._initialize_internal_kv(gcs_client) + from ray.autoscaler.v2.utils import is_autoscaler_v2 + + return is_autoscaler_v2() + + def _close_all_files(self): + """Close all open files (so that we can open more).""" + while len(self.open_file_infos) > 0: + file_info = self.open_file_infos.pop(0) + file_info.file_handle.close() + file_info.file_handle = None + + proc_alive = True + # Test if the worker process that generated the log file + # is still alive. Only applies to worker processes. + # For all other system components, we always assume they are alive. 
+ if ( + file_info.worker_pid != "raylet" + and file_info.worker_pid != "gcs_server" + and file_info.worker_pid != "autoscaler" + and file_info.worker_pid != "runtime_env" + and file_info.worker_pid is not None + ): + assert not isinstance(file_info.worker_pid, str), ( + "PID should be an int type. " f"Given PID: {file_info.worker_pid}." + ) + proc_alive = self.is_proc_alive_fn(file_info.worker_pid) + if not proc_alive: + # The process is not alive any more, so move the log file + # out of the log directory so glob.glob will not be slowed + # by it. + target = os.path.join( + self.logs_dir, "old", os.path.basename(file_info.filename) + ) + try: + shutil.move(file_info.filename, target) + except (IOError, OSError) as e: + if e.errno == errno.ENOENT: + logger.warning( + f"Warning: The file {file_info.filename} was not found." + ) + else: + raise e + + if proc_alive: + self.closed_file_infos.append(file_info) + + self.can_open_more_files = True + + def update_log_filenames(self): + """Update the list of log files to monitor.""" + monitor_log_paths = [] + # output of user code is written here + monitor_log_paths += glob.glob( + f"{self.logs_dir}/worker*[.out|.err]" + ) + glob.glob(f"{self.logs_dir}/java-worker*.log") + # segfaults and other serious errors are logged here + monitor_log_paths += glob.glob(f"{self.logs_dir}/raylet*.err") + # monitor logs are needed to report autoscaler events + # TODO(rickyx): remove this after migration. + if not self.is_autoscaler_v2: + # We publish monitor logs in autoscaler v1 + monitor_log_paths += glob.glob(f"{self.logs_dir}/monitor.log") + else: + # We publish autoscaler events directly in autoscaler v2 + monitor_log_paths += glob.glob( + f"{self.logs_dir}/events/event_AUTOSCALER.log" + ) + + # If gcs server restarts, there can be multiple log files. 
+ monitor_log_paths += glob.glob(f"{self.logs_dir}/gcs_server*.err") + + # runtime_env setup process is logged here + if RAY_RUNTIME_ENV_LOG_TO_DRIVER_ENABLED: + monitor_log_paths += glob.glob(f"{self.logs_dir}/runtime_env*.log") + for file_path in monitor_log_paths: + if os.path.isfile(file_path) and file_path not in self.log_filenames: + worker_match = WORKER_LOG_PATTERN.match(file_path) + if worker_match: + worker_pid = int(worker_match.group(2)) + else: + worker_pid = None + job_id = None + + # Perform existence check first because most file will not be + # including runtime_env. This saves some cpu cycle. + if "runtime_env" in file_path: + runtime_env_job_match = RUNTIME_ENV_SETUP_PATTERN.match(file_path) + if runtime_env_job_match: + job_id = runtime_env_job_match.group(1) + + is_err_file = file_path.endswith("err") + + self.log_filenames.add(file_path) + self.closed_file_infos.append( + LogFileInfo( + filename=file_path, + size_when_last_opened=0, + file_position=0, + file_handle=None, + is_err_file=is_err_file, + job_id=job_id, + worker_pid=worker_pid, + ) + ) + log_filename = os.path.basename(file_path) + logger.info(f"Beginning to track file {log_filename}") + + def open_closed_files(self): + """Open some closed files if they may have new lines. + + Opening more files may require us to close some of the already open + files. + """ + if not self.can_open_more_files: + # If we can't open any more files. Close all of the files. + self._close_all_files() + + files_with_no_updates = [] + while len(self.closed_file_infos) > 0: + if len(self.open_file_infos) >= self.max_files_open: + self.can_open_more_files = False + break + + file_info = self.closed_file_infos.pop(0) + assert file_info.file_handle is None + # Get the file size to see if it has gotten bigger since we last + # opened it. + try: + file_size = os.path.getsize(file_info.filename) + except (IOError, OSError) as e: + # Catch "file not found" errors. 
+ if e.errno == errno.ENOENT: + logger.warning( + f"Warning: The file {file_info.filename} was not found." + ) + self.log_filenames.remove(file_info.filename) + continue + raise e + + # If some new lines have been added to this file, try to reopen the + # file. + if file_size > file_info.size_when_last_opened: + try: + f = open(file_info.filename, "rb") + except (IOError, OSError) as e: + if e.errno == errno.ENOENT: + logger.warning( + f"Warning: The file {file_info.filename} was not found." + ) + self.log_filenames.remove(file_info.filename) + continue + else: + raise e + + f.seek(file_info.file_position) + file_info.size_when_last_opened = file_size + file_info.file_handle = f + self.open_file_infos.append(file_info) + else: + files_with_no_updates.append(file_info) + + if len(self.open_file_infos) >= self.max_files_open: + self.can_open_more_files = False + # Add the files with no changes back to the list of closed files. + self.closed_file_infos += files_with_no_updates + + def check_log_files_and_publish_updates(self): + """Gets updates to the log files and publishes them. + + Returns: + True if anything was published and false otherwise. 
+ """ + anything_published = False + lines_to_publish = [] + + def flush(): + nonlocal lines_to_publish + nonlocal anything_published + if len(lines_to_publish) > 0: + data = { + "ip": self.ip, + "pid": file_info.worker_pid, + "job": file_info.job_id, + "is_err": file_info.is_err_file, + "lines": lines_to_publish, + "actor_name": file_info.actor_name, + "task_name": file_info.task_name, + } + try: + self.publisher.publish_logs(data) + except Exception: + logger.exception(f"Failed to publish log messages {data}") + anything_published = True + lines_to_publish = [] + + for file_info in self.open_file_infos: + assert not file_info.file_handle.closed + file_info.reopen_if_necessary() + + max_num_lines_to_read = ray_constants.LOG_MONITOR_NUM_LINES_TO_READ + for _ in range(max_num_lines_to_read): + try: + next_line = file_info.file_handle.readline() + # Replace any characters not in UTF-8 with + # a replacement character, see + # https://stackoverflow.com/a/38565489/10891801 + next_line = next_line.decode("utf-8", "replace") + if next_line == "": + break + next_line = next_line.rstrip("\r\n") + + if next_line.startswith(ray_constants.LOG_PREFIX_ACTOR_NAME): + flush() # Possible change of task/actor name. + file_info.actor_name = next_line.split( + ray_constants.LOG_PREFIX_ACTOR_NAME, 1 + )[1] + file_info.task_name = None + elif next_line.startswith(ray_constants.LOG_PREFIX_TASK_NAME): + flush() # Possible change of task/actor name. + file_info.task_name = next_line.split( + ray_constants.LOG_PREFIX_TASK_NAME, 1 + )[1] + elif next_line.startswith(ray_constants.LOG_PREFIX_JOB_ID): + file_info.job_id = next_line.split( + ray_constants.LOG_PREFIX_JOB_ID, 1 + )[1] + elif next_line.startswith( + "Windows fatal exception: access violation" + ): + # We are suppressing the + # 'Windows fatal exception: access violation' + # message on workers on Windows here. 
+ # As far as we know it is harmless, + # but is frequently popping up if Python + # functions are run inside the core + # worker C extension. See the investigation in + # github.com/ray-project/ray/issues/18944 + # Also skip the following line, which is an + # empty line. + file_info.file_handle.readline() + else: + lines_to_publish.append(next_line) + except Exception: + logger.error( + f"Error: Reading file: {file_info.filename}, " + f"position: {file_info.file_info.file_handle.tell()} " + "failed." + ) + raise + + if file_info.file_position == 0: + # make filename windows-agnostic + filename = file_info.filename.replace("\\", "/") + if "/raylet" in filename: + file_info.worker_pid = "raylet" + elif "/gcs_server" in filename: + file_info.worker_pid = "gcs_server" + elif "/monitor" in filename or "event_AUTOSCALER" in filename: + file_info.worker_pid = "autoscaler" + elif "/runtime_env" in filename: + file_info.worker_pid = "runtime_env" + + # Record the current position in the file. + file_info.file_position = file_info.file_handle.tell() + flush() + + return anything_published + + def should_update_filenames(self, last_file_updated_time: float) -> bool: + """Return true if filenames should be updated. + + This method is used to apply the backpressure on file updates because + that requires heavy glob operations which use lots of CPUs. + + Args: + last_file_updated_time: The last time filenames are updated. + + Returns: + True if filenames should be updated. False otherwise. + """ + elapsed_seconds = float(time.time() - last_file_updated_time) + return ( + len(self.log_filenames) < RAY_LOG_MONITOR_MANY_FILES_THRESHOLD + or elapsed_seconds > LOG_NAME_UPDATE_INTERVAL_S + ) + + def run(self): + """Run the log monitor. + + This will scan the file system once every LOG_NAME_UPDATE_INTERVAL_S to + check if there are new log files to monitor. It will also publish new + log lines. 
+ """ + last_updated = time.time() + while True: + if self.should_update_filenames(last_updated): + self.update_log_filenames() + last_updated = time.time() + + self.open_closed_files() + anything_published = self.check_log_files_and_publish_updates() + # If nothing was published, then wait a little bit before checking + # for logs to avoid using too much CPU. + if not anything_published: + time.sleep(0.1) + + +def is_proc_alive(pid): + # Import locally to make sure the bundled version is used if needed + import psutil + + try: + return psutil.Process(pid).is_running() + except psutil.NoSuchProcess: + # The process does not exist. + return False + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=("Parse GCS server address for the log monitor to connect to.") + ) + parser.add_argument( + "--gcs-address", required=False, type=str, help="The address (ip:port) of GCS." + ) + parser.add_argument( + "--logging-level", + required=False, + type=str, + default=ray_constants.LOGGER_LEVEL, + choices=ray_constants.LOGGER_LEVEL_CHOICES, + help=ray_constants.LOGGER_LEVEL_HELP, + ) + parser.add_argument( + "--logging-format", + required=False, + type=str, + default=ray_constants.LOGGER_FORMAT, + help=ray_constants.LOGGER_FORMAT_HELP, + ) + parser.add_argument( + "--logging-filename", + required=False, + type=str, + default=ray_constants.LOG_MONITOR_LOG_FILE_NAME, + help="Specify the name of log file, " + "log to stdout if set empty, default is " + f'"{ray_constants.LOG_MONITOR_LOG_FILE_NAME}"', + ) + parser.add_argument( + "--session-dir", + required=True, + type=str, + help="Specify the path of the session directory used by Ray processes.", + ) + parser.add_argument( + "--logs-dir", + required=True, + type=str, + help="Specify the path of the log directory used by Ray processes.", + ) + parser.add_argument( + "--logging-rotate-bytes", + required=False, + type=int, + default=ray_constants.LOGGING_ROTATE_BYTES, + help="Specify the max bytes for 
rotating " + "log file, default is " + f"{ray_constants.LOGGING_ROTATE_BYTES} bytes.", + ) + parser.add_argument( + "--logging-rotate-backup-count", + required=False, + type=int, + default=ray_constants.LOGGING_ROTATE_BACKUP_COUNT, + help="Specify the backup count of rotated log file, default is " + f"{ray_constants.LOGGING_ROTATE_BACKUP_COUNT}.", + ) + args = parser.parse_args() + setup_component_logger( + logging_level=args.logging_level, + logging_format=args.logging_format, + log_dir=args.logs_dir, + filename=args.logging_filename, + max_bytes=args.logging_rotate_bytes, + backup_count=args.logging_rotate_backup_count, + ) + + node_ip = services.get_cached_node_ip_address(args.session_dir) + log_monitor = LogMonitor( + node_ip, + args.logs_dir, + ray._raylet.GcsPublisher(address=args.gcs_address), + is_proc_alive, + gcs_address=args.gcs_address, + ) + + try: + log_monitor.run() + except Exception as e: + # Something went wrong, so push an error to all drivers. + gcs_publisher = ray._raylet.GcsPublisher(address=args.gcs_address) + traceback_str = ray._private.utils.format_error_message(traceback.format_exc()) + message = ( + f"The log monitor on node {platform.node()} " + f"failed with the following error:\n{traceback_str}" + ) + ray._private.utils.publish_error_to_driver( + ray_constants.LOG_MONITOR_DIED_ERROR, + message, + gcs_publisher=gcs_publisher, + ) + logger.error(message) + raise e diff --git a/.venv/lib/python3.11/site-packages/ray/_private/logging_utils.py b/.venv/lib/python3.11/site-packages/ray/_private/logging_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..20bf2159acd0357b503f0d3dccaa053b86071765 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/logging_utils.py @@ -0,0 +1,29 @@ +from ray.core.generated.logging_pb2 import LogBatch + + +def log_batch_dict_to_proto(log_json: dict) -> LogBatch: + """Converts a dict containing a batch of logs to a LogBatch proto.""" + return LogBatch( + 
ip=log_json.get("ip"), + # Cast to support string pid like "gcs". + pid=str(log_json.get("pid")) if log_json.get("pid") else None, + # Job ID as a hex string. + job_id=log_json.get("job"), + is_error=bool(log_json.get("is_err")), + lines=log_json.get("lines"), + actor_name=log_json.get("actor_name"), + task_name=log_json.get("task_name"), + ) + + +def log_batch_proto_to_dict(log_batch: LogBatch) -> dict: + """Converts a LogBatch proto to a dict containing a batch of logs.""" + return { + "ip": log_batch.ip, + "pid": log_batch.pid, + "job": log_batch.job_id, + "is_err": log_batch.is_error, + "lines": log_batch.lines, + "actor_name": log_batch.actor_name, + "task_name": log_batch.task_name, + } diff --git a/.venv/lib/python3.11/site-packages/ray/_private/memory_monitor.py b/.venv/lib/python3.11/site-packages/ray/_private/memory_monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..b09cb9893c86eebc5bbc48bfdd0aa83a1138d75d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/memory_monitor.py @@ -0,0 +1,162 @@ +import logging +import os +import platform +import sys +import time + +# Import ray before psutil will make sure we use psutil's bundled version +import ray # noqa F401 +import psutil # noqa E402 + +logger = logging.getLogger(__name__) + + +def get_rss(memory_info): + """Get the estimated non-shared memory usage from psutil memory_info.""" + mem = memory_info.rss + # OSX doesn't have the shared attribute + if hasattr(memory_info, "shared"): + mem -= memory_info.shared + return mem + + +def get_shared(virtual_memory): + """Get the estimated shared memory usage from psutil virtual mem info.""" + # OSX doesn't have the shared attribute + if hasattr(virtual_memory, "shared"): + return virtual_memory.shared + else: + return 0 + + +def get_top_n_memory_usage(n: int = 10): + """Get the top n memory usage of the process + + Params: + n: Number of top n process memory usage to return. 
+ Returns: + (str) The formatted string of top n process memory usage. + """ + pids = psutil.pids() + proc_stats = [] + for pid in pids: + try: + proc = psutil.Process(pid) + proc_stats.append((get_rss(proc.memory_info()), pid, proc.cmdline())) + except psutil.NoSuchProcess: + # We should skip the process that has exited. Refer this + # issue for more detail: + # https://github.com/ray-project/ray/issues/14929 + continue + except psutil.AccessDenied: + # On MacOS, the proc_pidinfo call (used to get per-process + # memory info) fails with a permission denied error when used + # on a process that isn’t owned by the same user. For now, we + # drop the memory info of any such process, assuming that + # processes owned by other users (e.g. root) aren't Ray + # processes and will be of less interest when an OOM happens + # on a Ray node. + # See issue for more detail: + # https://github.com/ray-project/ray/issues/11845#issuecomment-849904019 # noqa: E501 + continue + proc_str = "PID\tMEM\tCOMMAND" + for rss, pid, cmdline in sorted(proc_stats, reverse=True)[:n]: + proc_str += "\n{}\t{}GiB\t{}".format( + pid, round(rss / (1024**3), 2), " ".join(cmdline)[:100].strip() + ) + return proc_str + + +class RayOutOfMemoryError(Exception): + def __init__(self, msg): + Exception.__init__(self, msg) + + @staticmethod + def get_message(used_gb, total_gb, threshold): + proc_str = get_top_n_memory_usage(n=10) + return ( + "More than {}% of the memory on ".format(int(100 * threshold)) + + "node {} is used ({} / {} GB). 
".format( + platform.node(), round(used_gb, 2), round(total_gb, 2) + ) + + f"The top 10 memory consumers are:\n\n{proc_str}" + + "\n\nIn addition, up to {} GiB of shared memory is ".format( + round(get_shared(psutil.virtual_memory()) / (1024**3), 2) + ) + + "currently being used by the Ray object store.\n---\n" + "--- Tip: Use the `ray memory` command to list active " + "objects in the cluster.\n" + "--- To disable OOM exceptions, set " + "RAY_DISABLE_MEMORY_MONITOR=1.\n---\n" + ) + + +class MemoryMonitor: + """Helper class for raising errors on low memory. + + This presents a much cleaner error message to users than what would happen + if we actually ran out of memory. + + The monitor tries to use the cgroup memory limit and usage if it is set + and available so that it is more reasonable inside containers. Otherwise, + it uses `psutil` to check the memory usage. + + The environment variable `RAY_MEMORY_MONITOR_ERROR_THRESHOLD` can be used + to overwrite the default error_threshold setting. + + Used by test only. For production code use memory_monitor.cc + """ + + def __init__(self, error_threshold=0.95, check_interval=1): + # Note: it takes ~50us to check the memory usage through psutil, so + # throttle this check at most once a second or so. + self.check_interval = check_interval + self.last_checked = 0 + try: + self.error_threshold = float( + os.getenv("RAY_MEMORY_MONITOR_ERROR_THRESHOLD") + ) + except (ValueError, TypeError): + self.error_threshold = error_threshold + # Try to read the cgroup memory limit if it is available. + try: + with open("/sys/fs/cgroup/memory/memory.limit_in_bytes", "rb") as f: + self.cgroup_memory_limit_gb = int(f.read()) / (1024**3) + except IOError: + self.cgroup_memory_limit_gb = sys.maxsize / (1024**3) + if not psutil: + logger.warn( + "WARNING: Not monitoring node memory since `psutil` " + "is not installed. Install this with " + "`pip install psutil` to enable " + "debugging of memory-related crashes." 
+ ) + self.disabled = ( + "RAY_DEBUG_DISABLE_MEMORY_MONITOR" in os.environ + or "RAY_DISABLE_MEMORY_MONITOR" in os.environ + ) + + def get_memory_usage(self): + from ray._private.utils import get_system_memory, get_used_memory + + total_gb = get_system_memory() / (1024**3) + used_gb = get_used_memory() / (1024**3) + + return used_gb, total_gb + + def raise_if_low_memory(self): + if self.disabled: + return + + if time.time() - self.last_checked > self.check_interval: + self.last_checked = time.time() + used_gb, total_gb = self.get_memory_usage() + + if used_gb > total_gb * self.error_threshold: + raise RayOutOfMemoryError( + RayOutOfMemoryError.get_message( + used_gb, total_gb, self.error_threshold + ) + ) + else: + logger.debug(f"Memory usage is {used_gb} / {total_gb}") diff --git a/.venv/lib/python3.11/site-packages/ray/_private/metrics_agent.py b/.venv/lib/python3.11/site-packages/ray/_private/metrics_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..8ee89ca72a9ea300569f24d73285766c9e877815 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/metrics_agent.py @@ -0,0 +1,675 @@ +import json +import logging +import os +import re +import threading +import time +import traceback +from collections import namedtuple +from typing import List, Tuple, Any, Dict, Set + +from prometheus_client.core import ( + CounterMetricFamily, + GaugeMetricFamily, + HistogramMetricFamily, +) +from opencensus.metrics.export.value import ValueDouble +from opencensus.metrics.export.metric_descriptor import MetricDescriptorType +from opencensus.stats import aggregation +from opencensus.stats import measure as measure_module +from opencensus.stats.view_manager import ViewManager +from opencensus.stats.stats_recorder import StatsRecorder +from opencensus.stats.base_exporter import StatsExporter +from prometheus_client.core import Metric as PrometheusMetric +from opencensus.stats.aggregation_data import ( + CountAggregationData, + 
    DistributionAggregationData,
    LastValueAggregationData,
    SumAggregationData,
)
from opencensus.stats.view import View
from opencensus.tags import tag_key as tag_key_module
from opencensus.tags import tag_map as tag_map_module
from opencensus.tags import tag_value as tag_value_module

import ray
from ray._raylet import GcsClient

from ray.core.generated.metrics_pb2 import Metric
from ray._private.ray_constants import env_bool

logger = logging.getLogger(__name__)

# Env var key to decide worker timeout.
# If the worker doesn't report for more than
# this time, we treat workers as dead.
RAY_WORKER_TIMEOUT_S = "RAY_WORKER_TIMEOUT_S"
# Component id used for metrics not tied to a specific worker (raylet, GCS).
GLOBAL_COMPONENT_KEY = "CORE"
# Any character Prometheus disallows in metric names; used to sanitize
# gRPC OpenCensus metric names (see fix_grpc_metric below).
RE_NON_ALPHANUMS = re.compile(r"[^a-zA-Z0-9]")


class Gauge(View):
    """Gauge representation of opencensus view.

    This class is used to collect process metrics from the reporter agent.
    Cpp metrics should be collected in a different way.
    """

    def __init__(self, name, description, unit, tags: List[str]):
        # The underlying measure must be created before the View, since
        # the View constructor reads it via the `measure` property.
        self._measure = measure_module.MeasureInt(name, description, unit)
        tags = [tag_key_module.TagKey(tag) for tag in tags]
        # A gauge reports the last observed value, hence LastValueAggregation.
        self._view = View(
            name, description, tags, self.measure, aggregation.LastValueAggregation()
        )

    @property
    def measure(self):
        return self._measure

    @property
    def view(self):
        return self._view

    @property
    def name(self):
        return self.measure.name


# A single metric observation: the Gauge it belongs to, the observed value,
# and the tag map (label values) it was recorded with.
Record = namedtuple("Record", ["gauge", "value", "tags"])


def fix_grpc_metric(metric: Metric):
    """
    Fix the inbound `opencensus.proto.metrics.v1.Metric` protos to make it acceptable
    by opencensus.stats.DistributionAggregationData.

    - metric name: gRPC OpenCensus metrics have names with slashes and dots, e.g.
    `grpc.io/client/server_latency`[1]. However Prometheus metric names only take
    alphanums, underscores and colons[2]. We sanitize the name by replacing non-alphanum
    chars to underscore, like the official opencensus prometheus exporter[3].
+ - distribution bucket bounds: The Metric proto asks distribution bucket bounds to + be > 0 [4]. However, gRPC OpenCensus metrics have their first bucket bound == 0 [1]. + This makes the `DistributionAggregationData` constructor to raise Exceptions. This + applies to all bytes and milliseconds (latencies). The fix: we update the initial 0 + bounds to be 0.000_000_1. This will not affect the precision of the metrics, since + we don't expect any less-than-1 bytes, or less-than-1-nanosecond times. + + [1] https://github.com/census-instrumentation/opencensus-specs/blob/master/stats/gRPC.md#units # noqa: E501 + [2] https://prometheus.io/docs/concepts/data_model/#metric-names-and-labels + [3] https://github.com/census-instrumentation/opencensus-cpp/blob/50eb5de762e5f87e206c011a4f930adb1a1775b1/opencensus/exporters/stats/prometheus/internal/prometheus_utils.cc#L39 # noqa: E501 + [4] https://github.com/census-instrumentation/opencensus-proto/blob/master/src/opencensus/proto/metrics/v1/metrics.proto#L218 # noqa: E501 + """ + + if not metric.metric_descriptor.name.startswith("grpc.io/"): + return + + metric.metric_descriptor.name = RE_NON_ALPHANUMS.sub( + "_", metric.metric_descriptor.name + ) + + for series in metric.timeseries: + for point in series.points: + if point.HasField("distribution_value"): + dist_value = point.distribution_value + bucket_bounds = dist_value.bucket_options.explicit.bounds + if len(bucket_bounds) > 0 and bucket_bounds[0] == 0: + bucket_bounds[0] = 0.000_000_1 + + +class OpencensusProxyMetric: + def __init__(self, name: str, desc: str, unit: str, label_keys: List[str]): + """Represents the OpenCensus metrics that will be proxy exported.""" + self._name = name + self._desc = desc + self._unit = unit + # -- The label keys of the metric -- + self._label_keys = label_keys + # -- The data that needs to be proxy exported -- + # tuple of label values -> data (OpenCesnsus Aggregation data) + self._data = {} + + @property + def name(self): + return 
self._name + + @property + def desc(self): + return self._desc + + @property + def unit(self): + return self._unit + + @property + def label_keys(self): + return self._label_keys + + @property + def data(self): + return self._data + + def record(self, metric: Metric): + """Parse the Opencensus Protobuf and store the data. + + The data can be accessed via `data` API once recorded. + """ + timeseries = metric.timeseries + + if len(timeseries) == 0: + return + + # Create the aggregation and fill it in the our stats + for series in timeseries: + labels = tuple(val.value for val in series.label_values) + + # Aggregate points. + for point in series.points: + if ( + metric.metric_descriptor.type + == MetricDescriptorType.CUMULATIVE_INT64 + ): + data = CountAggregationData(point.int64_value) + elif ( + metric.metric_descriptor.type + == MetricDescriptorType.CUMULATIVE_DOUBLE + ): + data = SumAggregationData(ValueDouble, point.double_value) + elif metric.metric_descriptor.type == MetricDescriptorType.GAUGE_DOUBLE: + data = LastValueAggregationData(ValueDouble, point.double_value) + elif ( + metric.metric_descriptor.type + == MetricDescriptorType.CUMULATIVE_DISTRIBUTION + ): + dist_value = point.distribution_value + counts_per_bucket = [bucket.count for bucket in dist_value.buckets] + bucket_bounds = dist_value.bucket_options.explicit.bounds + data = DistributionAggregationData( + dist_value.sum / dist_value.count, + dist_value.count, + dist_value.sum_of_squared_deviation, + counts_per_bucket, + bucket_bounds, + ) + else: + raise ValueError("Summary is not supported") + self._data[labels] = data + + +class Component: + def __init__(self, id: str): + """Represent a component that requests to proxy export metrics + + Args: + id: Id of this component. + """ + self.id = id + # -- The time this component reported its metrics last time -- + # It is used to figure out if this component is stale. 
+ self._last_reported_time = time.monotonic() + # -- Metrics requested to proxy export from this component -- + # metrics_name (str) -> metric (OpencensusProxyMetric) + self._metrics = {} + + @property + def metrics(self) -> Dict[str, OpencensusProxyMetric]: + """Return the metrics requested to proxy export from this component.""" + return self._metrics + + @property + def last_reported_time(self): + return self._last_reported_time + + def record(self, metrics: List[Metric]): + """Parse the Opencensus protobuf and store metrics. + + Metrics can be accessed via `metrics` API for proxy export. + + Args: + metrics: A list of Opencensus protobuf for proxy export. + """ + self._last_reported_time = time.monotonic() + for metric in metrics: + fix_grpc_metric(metric) + descriptor = metric.metric_descriptor + name = descriptor.name + label_keys = [label_key.key for label_key in descriptor.label_keys] + + if name not in self._metrics: + self._metrics[name] = OpencensusProxyMetric( + name, descriptor.description, descriptor.unit, label_keys + ) + self._metrics[name].record(metric) + + +class OpenCensusProxyCollector: + def __init__(self, namespace: str, component_timeout_s: int = 60): + """Prometheus collector implementation for opencensus proxy export. + + Prometheus collector requires to implement `collect` which is + invoked whenever Prometheus queries the endpoint. + + The class is thread-safe. + + Args: + namespace: Prometheus namespace. + """ + # -- Protect `self._components` -- + self._components_lock = threading.Lock() + # -- Timeout until the component is marked as stale -- + # Once the component is considered as stale, + # the metrics from that worker won't be exported. + self._component_timeout_s = component_timeout_s + # -- Prometheus namespace -- + self._namespace = namespace + # -- Component that requests to proxy export metrics -- + # Component means core worker, raylet, and GCS. + # component_id -> Components + # For workers, they contain worker ids. 
+ # For other components (raylet, GCS), + # they contain the global key `GLOBAL_COMPONENT_KEY`. + self._components = {} + # Whether we want to export counter as gauge. + # This is for bug compatibility. + # See https://github.com/ray-project/ray/pull/43795. + self._export_counter_as_gauge = env_bool("RAY_EXPORT_COUNTER_AS_GAUGE", True) + + def record(self, metrics: List[Metric], worker_id_hex: str = None): + """Record the metrics reported from the component that reports it. + + Args: + metrics: A list of opencensus protobuf to proxy export metrics. + worker_id_hex: A worker id that reports these metrics. + If None, it means they are reported from Raylet or GCS. + """ + key = GLOBAL_COMPONENT_KEY if not worker_id_hex else worker_id_hex + with self._components_lock: + if key not in self._components: + self._components[key] = Component(key) + self._components[key].record(metrics) + + def clean_stale_components(self): + """Clean up stale components. + + Stale means the component is dead or unresponsive. + + Stale components won't be reported to Prometheus anymore. + """ + with self._components_lock: + stale_components = [] + stale_component_ids = [] + for id, component in self._components.items(): + elapsed = time.monotonic() - component.last_reported_time + if elapsed > self._component_timeout_s: + stale_component_ids.append(id) + logger.info( + "Metrics from a worker ({}) is cleaned up due to " + "timeout. Time since last report {}s".format(id, elapsed) + ) + for id in stale_component_ids: + stale_components.append(self._components.pop(id)) + return stale_components + + # TODO(sang): add start and end timestamp + def to_metrics( + self, + metric_name: str, + metric_description: str, + label_keys: List[str], + metric_units: str, + label_values: Tuple[tag_value_module.TagValue], + agg_data: Any, + metrics_map: Dict[str, List[PrometheusMetric]], + ): + """to_metric translate the data that OpenCensus create + to Prometheus format, using Prometheus Metric object. 
+ + This method is from Opencensus Prometheus Exporter. + + Args: + metric_name: Name of the metric. + metric_description: Description of the metric. + label_keys: The fixed label keys of the metric. + metric_units: Units of the metric. + label_values: The values of `label_keys`. + agg_data: `opencensus.stats.aggregation_data.AggregationData` object. + Aggregated data that needs to be converted as Prometheus samples + metrics_map: The converted metric is added to this map. + + """ + assert self._components_lock.locked() + metric_name = f"{self._namespace}_{metric_name}" + assert len(label_values) == len(label_keys), (label_values, label_keys) + # Prometheus requires that all tag values be strings hence + # the need to cast none to the empty string before exporting. See + # https://github.com/census-instrumentation/opencensus-python/issues/480 + label_values = [tv if tv else "" for tv in label_values] + + if isinstance(agg_data, CountAggregationData): + metrics = metrics_map.get(metric_name) + if not metrics: + metric = CounterMetricFamily( + name=metric_name, + documentation=metric_description, + unit=metric_units, + labels=label_keys, + ) + metrics = [metric] + metrics_map[metric_name] = metrics + metrics[0].add_metric(labels=label_values, value=agg_data.count_data) + return + + if isinstance(agg_data, SumAggregationData): + # This should be emitted as prometheus counter + # but we used to emit it as prometheus gauge. + # To keep the backward compatibility + # (changing from counter to gauge changes the metric name + # since prometheus client will add "_total" suffix to counter + # per OpenMetrics specification), + # we now emit both counter and gauge and in the + # next major Ray release (3.0) we can stop emitting gauge. + # This leaves people enough time to migrate their dashboards. + # See https://github.com/ray-project/ray/pull/43795. 
+ metrics = metrics_map.get(metric_name) + if not metrics: + metric = CounterMetricFamily( + name=metric_name, + documentation=metric_description, + labels=label_keys, + ) + metrics = [metric] + metrics_map[metric_name] = metrics + metrics[0].add_metric(labels=label_values, value=agg_data.sum_data) + + if not self._export_counter_as_gauge: + pass + elif metric_name.endswith("_total"): + # In this case, we only need to emit prometheus counter + # since for metric name already ends with _total suffix + # prometheus client won't change it + # so there is no backward compatibility issue. + # See https://prometheus.github.io/client_python/instrumenting/counter/ + pass + else: + if len(metrics) == 1: + metric = GaugeMetricFamily( + name=metric_name, + documentation=( + f"(DEPRECATED, use {metric_name}_total metric instead) " + f"{metric_description}" + ), + labels=label_keys, + ) + metrics.append(metric) + assert len(metrics) == 2 + metrics[1].add_metric(labels=label_values, value=agg_data.sum_data) + return + + elif isinstance(agg_data, DistributionAggregationData): + + assert agg_data.bounds == sorted(agg_data.bounds) + # buckets are a list of buckets. Each bucket is another list with + # a pair of bucket name and value, or a triple of bucket name, + # value, and exemplar. buckets need to be in order. + buckets = [] + cum_count = 0 # Prometheus buckets expect cumulative count. + for ii, bound in enumerate(agg_data.bounds): + cum_count += agg_data.counts_per_bucket[ii] + bucket = [str(bound), cum_count] + buckets.append(bucket) + # Prometheus requires buckets to be sorted, and +Inf present. + # In OpenCensus we don't have +Inf in the bucket bonds so need to + # append it here. 
+ buckets.append(["+Inf", agg_data.count_data]) + metrics = metrics_map.get(metric_name) + if not metrics: + metric = HistogramMetricFamily( + name=metric_name, + documentation=metric_description, + labels=label_keys, + ) + metrics = [metric] + metrics_map[metric_name] = metrics + metrics[0].add_metric( + labels=label_values, + buckets=buckets, + sum_value=agg_data.sum, + ) + return + + elif isinstance(agg_data, LastValueAggregationData): + metrics = metrics_map.get(metric_name) + if not metrics: + metric = GaugeMetricFamily( + name=metric_name, + documentation=metric_description, + labels=label_keys, + ) + metrics = [metric] + metrics_map[metric_name] = metrics + metrics[0].add_metric(labels=label_values, value=agg_data.value) + return + + else: + raise ValueError(f"unsupported aggregation type {type(agg_data)}") + + def collect(self): # pragma: NO COVER + """Collect fetches the statistics from OpenCensus + and delivers them as Prometheus Metrics. + Collect is invoked every time a prometheus.Gatherer is run + for example when the HTTP endpoint is invoked by Prometheus. + + This method is required as a Prometheus Collector. + """ + with self._components_lock: + metrics_map = {} + for component in self._components.values(): + for metric in component.metrics.values(): + for label_values, data in metric.data.items(): + self.to_metrics( + metric.name, + metric.desc, + metric.label_keys, + metric.unit, + label_values, + data, + metrics_map, + ) + + for metrics in metrics_map.values(): + for metric in metrics: + yield metric + + +class MetricsAgent: + def __init__( + self, + view_manager: ViewManager, + stats_recorder: StatsRecorder, + stats_exporter: StatsExporter = None, + ): + """A class to record and export metrics. + + The class exports metrics in 2 different ways. + - Directly record and export metrics using OpenCensus. + - Proxy metrics from other core components + (e.g., raylet, GCS, core workers). + + This class is thread-safe. 
+ """ + # Lock required because gRPC server uses + # multiple threads to process requests. + self._lock = threading.Lock() + + # + # Opencensus components to record metrics. + # + + # Managing views to export metrics + # If the stats_exporter is None, we disable all metrics export. + self.view_manager = view_manager + # A class that's used to record metrics + # emitted from the current process. + self.stats_recorder = stats_recorder + # A class to export metrics. + self.stats_exporter = stats_exporter + # -- A Prometheus custom collector to proxy export metrics -- + # `None` if the prometheus server is not started. + self.proxy_exporter_collector = None + + if self.stats_exporter is None: + # If the exporter is not given, + # we disable metrics collection. + self.view_manager = None + else: + self.view_manager.register_exporter(stats_exporter) + self.proxy_exporter_collector = OpenCensusProxyCollector( + self.stats_exporter.options.namespace, + component_timeout_s=int(os.getenv(RAY_WORKER_TIMEOUT_S, 120)), + ) + + # Registered view names. 
+ self._registered_views: Set[str] = set() + + def record_and_export(self, records: List[Record], global_tags=None): + """Directly record and export stats from the same process.""" + global_tags = global_tags or {} + with self._lock: + if not self.view_manager: + return + + for record in records: + gauge = record.gauge + value = record.value + tags = record.tags + self._record_gauge(gauge, value, {**tags, **global_tags}) + + def _record_gauge(self, gauge: Gauge, value: float, tags: dict): + if gauge.name not in self._registered_views: + self.view_manager.register_view(gauge.view) + self._registered_views.add(gauge.name) + measurement_map = self.stats_recorder.new_measurement_map() + tag_map = tag_map_module.TagMap() + for key, tag_val in tags.items(): + tag_key = tag_key_module.TagKey(key) + tag_value = tag_value_module.TagValue(tag_val) + tag_map.insert(tag_key, tag_value) + measurement_map.measure_float_put(gauge.measure, value) + # NOTE: When we record this metric, timestamp will be renewed. + measurement_map.record(tag_map) + + def proxy_export_metrics(self, metrics: List[Metric], worker_id_hex: str = None): + """Proxy export metrics specified by a Opencensus Protobuf. + + This API is used to export metrics emitted from + core components. + + Args: + metrics: A list of protobuf Metric defined from OpenCensus. + worker_id_hex: The worker ID it proxies metrics export. None + if the metric is not from a worker (i.e., raylet, GCS). + """ + with self._lock: + if not self.view_manager: + return + + self._proxy_export_metrics(metrics, worker_id_hex) + + def _proxy_export_metrics(self, metrics: List[Metric], worker_id_hex: str = None): + self.proxy_exporter_collector.record(metrics, worker_id_hex) + + def clean_all_dead_worker_metrics(self): + """Clean dead worker's metrics. + + Worker metrics are cleaned up and won't be exported once + it is considered as dead. + + This method has to be periodically called by a caller. 
+ """ + with self._lock: + if not self.view_manager: + return + + self.proxy_exporter_collector.clean_stale_components() + + +class PrometheusServiceDiscoveryWriter(threading.Thread): + """A class to support Prometheus service discovery. + + It supports file-based service discovery. Checkout + https://prometheus.io/docs/guides/file-sd/ for more details. + + Args: + gcs_address: Gcs address for this cluster. + temp_dir: Temporary directory used by + Ray to store logs and metadata. + """ + + def __init__(self, gcs_address, temp_dir): + gcs_client_options = ray._raylet.GcsClientOptions.create( + gcs_address, None, allow_cluster_id_nil=True, fetch_cluster_id_if_nil=False + ) + self.gcs_address = gcs_address + + ray._private.state.state._initialize_global_state(gcs_client_options) + self.temp_dir = temp_dir + self.default_service_discovery_flush_period = 5 + super().__init__() + + def get_file_discovery_content(self): + """Return the content for Prometheus service discovery.""" + nodes = ray.nodes() + metrics_export_addresses = [ + "{}:{}".format(node["NodeManagerAddress"], node["MetricsExportPort"]) + for node in nodes + if node["alive"] is True + ] + gcs_client = GcsClient(address=self.gcs_address) + autoscaler_addr = gcs_client.internal_kv_get(b"AutoscalerMetricsAddress", None) + if autoscaler_addr: + metrics_export_addresses.append(autoscaler_addr.decode("utf-8")) + dashboard_addr = gcs_client.internal_kv_get(b"DashboardMetricsAddress", None) + if dashboard_addr: + metrics_export_addresses.append(dashboard_addr.decode("utf-8")) + return json.dumps( + [{"labels": {"job": "ray"}, "targets": metrics_export_addresses}] + ) + + def write(self): + # Write a file based on https://prometheus.io/docs/guides/file-sd/ + # Write should be atomic. Otherwise, Prometheus raises an error that + # json file format is invalid because it reads a file when + # file is re-written. Note that Prometheus still works although we + # have this error. 
+ temp_file_name = self.get_temp_file_name() + with open(temp_file_name, "w") as json_file: + json_file.write(self.get_file_discovery_content()) + # NOTE: os.replace is atomic on both Linux and Windows, so we won't + # have race condition reading this file. + os.replace(temp_file_name, self.get_target_file_name()) + + def get_target_file_name(self): + return os.path.join( + self.temp_dir, ray._private.ray_constants.PROMETHEUS_SERVICE_DISCOVERY_FILE + ) + + def get_temp_file_name(self): + return os.path.join( + self.temp_dir, + "{}_{}".format( + "tmp", ray._private.ray_constants.PROMETHEUS_SERVICE_DISCOVERY_FILE + ), + ) + + def run(self): + while True: + # This thread won't be broken by exceptions. + try: + self.write() + except Exception as e: + logger.warning( + "Writing a service discovery file, {}," + "failed.".format(self.get_target_file_name()) + ) + logger.warning(traceback.format_exc()) + logger.warning(f"Error message: {e}") + time.sleep(self.default_service_discovery_flush_period) diff --git a/.venv/lib/python3.11/site-packages/ray/_private/node.py b/.venv/lib/python3.11/site-packages/ray/_private/node.py new file mode 100644 index 0000000000000000000000000000000000000000..b9097eae54ea46b1992c371a4544a5883f449ce3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/node.py @@ -0,0 +1,1862 @@ +import atexit +import collections +import datetime +import errno +import json +import logging +import os +import random +import signal +import socket +import subprocess +import sys +import tempfile +import threading +import time +import traceback +from collections import defaultdict +from typing import Dict, Optional, Tuple, IO, AnyStr + +from filelock import FileLock + +import ray +import ray._private.ray_constants as ray_constants +import ray._private.services +from ray._private import storage +from ray._raylet import GcsClient, get_session_key_from_storage +from ray._private.resource_spec import ResourceSpec +from ray._private.services import 
serialize_config, get_address +from ray._private.utils import open_log, try_to_create_directory, try_to_symlink + +# Logger for this module. It should be configured at the entry point +# into the program using Ray. Ray configures it by default automatically +# using logging.basicConfig in its entry/init points. +logger = logging.getLogger(__name__) + + +class Node: + """An encapsulation of the Ray processes on a single node. + + This class is responsible for starting Ray processes and killing them, + and it also controls the temp file policy. + + Attributes: + all_processes: A mapping from process type (str) to a list of + ProcessInfo objects. All lists have length one except for the Redis + server list, which has multiple. + """ + + def __init__( + self, + ray_params, + head: bool = False, + shutdown_at_exit: bool = True, + spawn_reaper: bool = True, + connect_only: bool = False, + default_worker: bool = False, + ray_init_cluster: bool = False, + ): + """Start a node. + + Args: + ray_params: The RayParams to use to configure the node. + head: True if this is the head node, which means it will + start additional processes like the Redis servers, monitor + processes, and web UI. + shutdown_at_exit: If true, spawned processes will be cleaned + up if this process exits normally. + spawn_reaper: If true, spawns a process that will clean up + other spawned processes if this process dies unexpectedly. + connect_only: If true, connect to the node without starting + new processes. + default_worker: Whether it's running from a ray worker or not + ray_init_cluster: Whether it's a cluster created by ray.init() + """ + if shutdown_at_exit: + if connect_only: + raise ValueError( + "'shutdown_at_exit' and 'connect_only' cannot both be true." 
+ ) + self._register_shutdown_hooks() + self._default_worker = default_worker + self.head = head + self.kernel_fate_share = bool( + spawn_reaper and ray._private.utils.detect_fate_sharing_support() + ) + self.all_processes: dict = {} + self.removal_lock = threading.Lock() + + self.ray_init_cluster = ray_init_cluster + if ray_init_cluster: + assert head, "ray.init() created cluster only has the head node" + + # Set up external Redis when `RAY_REDIS_ADDRESS` is specified. + redis_address_env = os.environ.get("RAY_REDIS_ADDRESS") + if ray_params.external_addresses is None and redis_address_env is not None: + external_redis = redis_address_env.split(",") + + # Reuse primary Redis as Redis shard when there's only one + # instance provided. + if len(external_redis) == 1: + external_redis.append(external_redis[0]) + [primary_redis_ip, port] = external_redis[0].rsplit(":", 1) + ray_params.external_addresses = external_redis + ray_params.num_redis_shards = len(external_redis) - 1 + + if ( + ray_params._system_config + and len(ray_params._system_config) > 0 + and (not head and not connect_only) + ): + raise ValueError( + "System config parameters can only be set on the head node." + ) + + ray_params.update_if_absent( + include_log_monitor=True, + resources={}, + worker_path=os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "workers", + "default_worker.py", + ), + setup_worker_path=os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "workers", + ray_constants.SETUP_WORKER_FILENAME, + ), + ) + + self._resource_spec = None + self._localhost = socket.gethostbyname("localhost") + self._ray_params = ray_params + self._config = ray_params._system_config or {} + + self._dashboard_agent_listen_port = ray_params.dashboard_agent_listen_port + self._dashboard_grpc_port = ray_params.dashboard_grpc_port + + # Configure log rotation parameters. 
+ self.max_bytes = int( + os.getenv("RAY_ROTATION_MAX_BYTES", ray_constants.LOGGING_ROTATE_BYTES) + ) + self.backup_count = int( + os.getenv( + "RAY_ROTATION_BACKUP_COUNT", ray_constants.LOGGING_ROTATE_BACKUP_COUNT + ) + ) + + assert self.max_bytes >= 0 + assert self.backup_count >= 0 + + self._redis_address = ray_params.redis_address + if head: + ray_params.update_if_absent(num_redis_shards=1) + self._gcs_address = ray_params.gcs_address + self._gcs_client = None + + if not self.head: + self.validate_ip_port(self.address) + self._init_gcs_client() + + # Register the temp dir. + self._session_name = ray_params.session_name + if self._session_name is None: + if head: + # We expect this the first time we initialize a cluster, but not during + # subsequent restarts of the head node. + maybe_key = self.check_persisted_session_name() + if maybe_key is None: + # date including microsecond + date_str = datetime.datetime.today().strftime( + "%Y-%m-%d_%H-%M-%S_%f" + ) + self._session_name = f"session_{date_str}_{os.getpid()}" + else: + self._session_name = ray._private.utils.decode(maybe_key) + else: + assert not self._default_worker + session_name = ray._private.utils.internal_kv_get_with_retry( + self.get_gcs_client(), + "session_name", + ray_constants.KV_NAMESPACE_SESSION, + num_retries=ray_constants.NUM_REDIS_GET_RETRIES, + ) + self._session_name = ray._private.utils.decode(session_name) + + # Initialize webui url + if head: + self._webui_url = None + else: + if ray_params.webui is None: + assert not self._default_worker + self._webui_url = ray._private.services.get_webui_url_from_internal_kv() + else: + self._webui_url = ( + f"{ray_params.dashboard_host}:{ray_params.dashboard_port}" + ) + + # It creates a session_dir. 
+ self._init_temp() + + node_ip_address = ray_params.node_ip_address + if node_ip_address is None: + if connect_only: + node_ip_address = self._wait_and_get_for_node_address() + else: + node_ip_address = ray.util.get_node_ip_address() + + assert node_ip_address is not None + ray_params.update_if_absent( + node_ip_address=node_ip_address, raylet_ip_address=node_ip_address + ) + self._node_ip_address = node_ip_address + if not connect_only: + ray._private.services.write_node_ip_address( + self.get_session_dir_path(), node_ip_address + ) + + if ray_params.raylet_ip_address: + raylet_ip_address = ray_params.raylet_ip_address + else: + raylet_ip_address = node_ip_address + + if raylet_ip_address != node_ip_address and (not connect_only or head): + raise ValueError( + "The raylet IP address should only be different than the node " + "IP address when connecting to an existing raylet; i.e., when " + "head=False and connect_only=True." + ) + self._raylet_ip_address = raylet_ip_address + + # Validate and initialize the persistent storage API. + if head: + storage._init_storage(ray_params.storage, is_head=True) + else: + if not self._default_worker: + storage_uri = ray._private.services.get_storage_uri_from_internal_kv() + else: + storage_uri = ray_params.storage + storage._init_storage(storage_uri, is_head=False) + + # If it is a head node, try validating if + # external storage is configurable. + if head: + self.validate_external_storage() + + if connect_only: + # Get socket names from the configuration. + self._plasma_store_socket_name = ray_params.plasma_store_socket_name + self._raylet_socket_name = ray_params.raylet_socket_name + self._node_id = ray_params.node_id + + # If user does not provide the socket name, get it from Redis. + if ( + self._plasma_store_socket_name is None + or self._raylet_socket_name is None + or self._ray_params.node_manager_port is None + or self._node_id is None + ): + # Get the address info of the processes to connect to + # from Redis or GCS. 
+ node_info = ray._private.services.get_node_to_connect_for_driver( + self.gcs_address, + self._raylet_ip_address, + ) + self._plasma_store_socket_name = node_info["object_store_socket_name"] + self._raylet_socket_name = node_info["raylet_socket_name"] + self._ray_params.node_manager_port = node_info["node_manager_port"] + self._node_id = node_info["node_id"] + else: + # If the user specified a socket name, use it. + self._plasma_store_socket_name = self._prepare_socket_file( + self._ray_params.plasma_store_socket_name, default_prefix="plasma_store" + ) + self._raylet_socket_name = self._prepare_socket_file( + self._ray_params.raylet_socket_name, default_prefix="raylet" + ) + if ( + self._ray_params.env_vars is not None + and "RAY_OVERRIDE_NODE_ID_FOR_TESTING" in self._ray_params.env_vars + ): + node_id = self._ray_params.env_vars["RAY_OVERRIDE_NODE_ID_FOR_TESTING"] + logger.debug( + f"Setting node ID to {node_id} " + "based on ray_params.env_vars override" + ) + self._node_id = node_id + elif os.environ.get("RAY_OVERRIDE_NODE_ID_FOR_TESTING"): + node_id = os.environ["RAY_OVERRIDE_NODE_ID_FOR_TESTING"] + logger.debug(f"Setting node ID to {node_id} based on env override") + self._node_id = node_id + else: + node_id = ray.NodeID.from_random().hex() + logger.debug(f"Setting node ID to {node_id}") + self._node_id = node_id + + # The dashboard agent port is assigned first to avoid + # other processes accidentally taking its default port + self._dashboard_agent_listen_port = self._get_cached_port( + "dashboard_agent_listen_port", + default_port=ray_params.dashboard_agent_listen_port, + ) + + self.metrics_agent_port = self._get_cached_port( + "metrics_agent_port", default_port=ray_params.metrics_agent_port + ) + self._metrics_export_port = self._get_cached_port( + "metrics_export_port", default_port=ray_params.metrics_export_port + ) + self._runtime_env_agent_port = self._get_cached_port( + "runtime_env_agent_port", + default_port=ray_params.runtime_env_agent_port, + ) + 
+ ray_params.update_if_absent( + metrics_agent_port=self.metrics_agent_port, + metrics_export_port=self._metrics_export_port, + dashboard_agent_listen_port=self._dashboard_agent_listen_port, + runtime_env_agent_port=self._runtime_env_agent_port, + ) + + # Pick a GCS server port. + if head: + gcs_server_port = os.getenv(ray_constants.GCS_PORT_ENVIRONMENT_VARIABLE) + if gcs_server_port: + ray_params.update_if_absent(gcs_server_port=int(gcs_server_port)) + if ray_params.gcs_server_port is None or ray_params.gcs_server_port == 0: + ray_params.gcs_server_port = self._get_cached_port("gcs_server_port") + + if not connect_only and spawn_reaper and not self.kernel_fate_share: + self.start_reaper_process() + if not connect_only: + self._ray_params.update_pre_selected_port() + + # Start processes. + if head: + self.start_head_processes() + + if not connect_only: + self.start_ray_processes() + # we should update the address info after the node has been started + try: + ray._private.services.wait_for_node( + self.gcs_address, + self._plasma_store_socket_name, + ) + except TimeoutError as te: + raise Exception( + "The current node timed out during startup. This " + "could happen because some of the Ray processes " + "failed to startup." + ) from te + node_info = ray._private.services.get_node( + self.gcs_address, + self._node_id, + ) + if self._ray_params.node_manager_port == 0: + self._ray_params.node_manager_port = node_info["node_manager_port"] + + # Makes sure the Node object has valid addresses after setup. 
+ self.validate_ip_port(self.address) + self.validate_ip_port(self.gcs_address) + + if not connect_only: + self._record_stats() + + def check_persisted_session_name(self): + if self._ray_params.external_addresses is None: + return None + self._redis_address = self._ray_params.external_addresses[0] + redis_ip_address, redis_port, enable_redis_ssl = get_address( + self._redis_address, + ) + # Address is ip:port or redis://ip:port + if int(redis_port) < 0: + raise ValueError( + f"Invalid Redis port provided: {redis_port}." + "The port must be a non-negative integer." + ) + + return get_session_key_from_storage( + redis_ip_address, + int(redis_port), + self._ray_params.redis_username, + self._ray_params.redis_password, + enable_redis_ssl, + serialize_config(self._config), + b"session_name", + ) + + @staticmethod + def validate_ip_port(ip_port): + """Validates the address is in the ip:port format""" + _, _, port = ip_port.rpartition(":") + if port == ip_port: + raise ValueError(f"Port is not specified for address {ip_port}") + try: + _ = int(port) + except ValueError: + raise ValueError( + f"Unable to parse port number from {port} (full address = {ip_port})" + ) + + def check_version_info(self): + """Check if the Python and Ray version of this process matches that in GCS. + + This will be used to detect if workers or drivers are started using + different versions of Python, or Ray. + + Raises: + Exception: An exception is raised if there is a version mismatch. + """ + import ray._private.usage.usage_lib as ray_usage_lib + + cluster_metadata = ray_usage_lib.get_cluster_metadata(self.get_gcs_client()) + if cluster_metadata is None: + cluster_metadata = ray_usage_lib.get_cluster_metadata(self.get_gcs_client()) + + if not cluster_metadata: + return + node_ip_address = ray._private.services.get_node_ip_address() + ray._private.utils.check_version_info( + cluster_metadata, f"node {node_ip_address}" + ) + + def _register_shutdown_hooks(self): + # Register the atexit handler. 
In this case, we shouldn't call sys.exit + # as we're already in the exit procedure. + def atexit_handler(*args): + self.kill_all_processes(check_alive=False, allow_graceful=True) + + atexit.register(atexit_handler) + + # Register the handler to be called if we get a SIGTERM. + # In this case, we want to exit with an error code (1) after + # cleaning up child processes. + def sigterm_handler(signum, frame): + self.kill_all_processes(check_alive=False, allow_graceful=True) + sys.exit(1) + + ray._private.utils.set_sigterm_handler(sigterm_handler) + + def _init_temp(self): + # Create a dictionary to store temp file index. + self._incremental_dict = collections.defaultdict(lambda: 0) + + if self.head: + self._ray_params.update_if_absent( + temp_dir=ray._private.utils.get_ray_temp_dir() + ) + self._temp_dir = self._ray_params.temp_dir + else: + if self._ray_params.temp_dir is None: + assert not self._default_worker + temp_dir = ray._private.utils.internal_kv_get_with_retry( + self.get_gcs_client(), + "temp_dir", + ray_constants.KV_NAMESPACE_SESSION, + num_retries=ray_constants.NUM_REDIS_GET_RETRIES, + ) + self._temp_dir = ray._private.utils.decode(temp_dir) + else: + self._temp_dir = self._ray_params.temp_dir + + try_to_create_directory(self._temp_dir) + + if self.head: + self._session_dir = os.path.join(self._temp_dir, self._session_name) + else: + if self._temp_dir is None or self._session_name is None: + assert not self._default_worker + session_dir = ray._private.utils.internal_kv_get_with_retry( + self.get_gcs_client(), + "session_dir", + ray_constants.KV_NAMESPACE_SESSION, + num_retries=ray_constants.NUM_REDIS_GET_RETRIES, + ) + self._session_dir = ray._private.utils.decode(session_dir) + else: + self._session_dir = os.path.join(self._temp_dir, self._session_name) + session_symlink = os.path.join(self._temp_dir, ray_constants.SESSION_LATEST) + + # Send a warning message if the session exists. 
+ try_to_create_directory(self._session_dir) + try_to_symlink(session_symlink, self._session_dir) + # Create a directory to be used for socket files. + self._sockets_dir = os.path.join(self._session_dir, "sockets") + try_to_create_directory(self._sockets_dir) + # Create a directory to be used for process log files. + self._logs_dir = os.path.join(self._session_dir, "logs") + try_to_create_directory(self._logs_dir) + old_logs_dir = os.path.join(self._logs_dir, "old") + try_to_create_directory(old_logs_dir) + # Create a directory to be used for runtime environment. + self._runtime_env_dir = os.path.join( + self._session_dir, self._ray_params.runtime_env_dir_name + ) + try_to_create_directory(self._runtime_env_dir) + + def _get_node_labels(self): + def merge_labels(env_override_labels, params_labels): + """Merges two dictionaries, picking from the + first in the event of a conflict. Also emit a warning on every + conflict. + """ + + result = params_labels.copy() + result.update(env_override_labels) + + for key in set(env_override_labels.keys()).intersection( + set(params_labels.keys()) + ): + if params_labels[key] != env_override_labels[key]: + logger.warning( + "Autoscaler is overriding your label:" + f"{key}: {params_labels[key]} to " + f"{key}: {env_override_labels[key]}." 
+ ) + return result + + env_override_labels = {} + env_override_labels_string = os.getenv( + ray_constants.LABELS_ENVIRONMENT_VARIABLE + ) + if env_override_labels_string: + try: + env_override_labels = json.loads(env_override_labels_string) + except Exception: + logger.exception(f"Failed to load {env_override_labels_string}") + raise + logger.info(f"Autoscaler overriding labels: {env_override_labels}.") + + return merge_labels(env_override_labels, self._ray_params.labels or {}) + + def get_resource_spec(self): + """Resolve and return the current resource spec for the node.""" + + def merge_resources(env_dict, params_dict): + """Separates special case params and merges two dictionaries, picking from the + first in the event of a conflict. Also emit a warning on every + conflict. + """ + num_cpus = env_dict.pop("CPU", None) + num_gpus = env_dict.pop("GPU", None) + memory = env_dict.pop("memory", None) + object_store_memory = env_dict.pop("object_store_memory", None) + + result = params_dict.copy() + result.update(env_dict) + + for key in set(env_dict.keys()).intersection(set(params_dict.keys())): + if params_dict[key] != env_dict[key]: + logger.warning( + "Autoscaler is overriding your resource:" + f"{key}: {params_dict[key]} with {env_dict[key]}." 
+ ) + return num_cpus, num_gpus, memory, object_store_memory, result + + if not self._resource_spec: + env_resources = {} + env_string = os.getenv(ray_constants.RESOURCES_ENVIRONMENT_VARIABLE) + if env_string: + try: + env_resources = json.loads(env_string) + except Exception: + logger.exception(f"Failed to load {env_string}") + raise + logger.debug(f"Autoscaler overriding resources: {env_resources}.") + ( + num_cpus, + num_gpus, + memory, + object_store_memory, + resources, + ) = merge_resources(env_resources, self._ray_params.resources) + self._resource_spec = ResourceSpec( + self._ray_params.num_cpus if num_cpus is None else num_cpus, + self._ray_params.num_gpus if num_gpus is None else num_gpus, + self._ray_params.memory if memory is None else memory, + ( + self._ray_params.object_store_memory + if object_store_memory is None + else object_store_memory + ), + resources, + self._ray_params.redis_max_memory, + ).resolve(is_head=self.head, node_ip_address=self.node_ip_address) + return self._resource_spec + + @property + def node_id(self): + """Get the node ID.""" + return self._node_id + + @property + def session_name(self): + """Get the session name (cluster ID).""" + return self._session_name + + @property + def node_ip_address(self): + """Get the IP address of this node.""" + return self._node_ip_address + + @property + def raylet_ip_address(self): + """Get the IP address of the raylet that this node connects to.""" + return self._raylet_ip_address + + @property + def address(self): + """Get the address for bootstrapping, e.g. the address to pass to + `ray start` or `ray.init()` to start worker nodes, that has been + converted to ip:port format. 
+ """ + return self._gcs_address + + @property + def gcs_address(self): + """Get the gcs address.""" + assert self._gcs_address is not None, "Gcs address is not set" + return self._gcs_address + + @property + def redis_address(self): + """Get the cluster Redis address.""" + return self._redis_address + + @property + def redis_username(self): + """Get the cluster Redis username.""" + return self._ray_params.redis_username + + @property + def redis_password(self): + """Get the cluster Redis password.""" + return self._ray_params.redis_password + + @property + def object_ref_seed(self): + """Get the seed for deterministic generation of object refs""" + return self._ray_params.object_ref_seed + + @property + def plasma_store_socket_name(self): + """Get the node's plasma store socket name.""" + return self._plasma_store_socket_name + + @property + def unique_id(self): + """Get a unique identifier for this node.""" + return f"{self.node_ip_address}:{self._plasma_store_socket_name}" + + @property + def webui_url(self): + """Get the cluster's web UI url.""" + return self._webui_url + + @property + def raylet_socket_name(self): + """Get the node's raylet socket name.""" + return self._raylet_socket_name + + @property + def node_manager_port(self): + """Get the node manager's port.""" + return self._ray_params.node_manager_port + + @property + def metrics_export_port(self): + """Get the port that exposes metrics""" + return self._metrics_export_port + + @property + def runtime_env_agent_port(self): + """Get the port that exposes runtime env agent as http""" + return self._runtime_env_agent_port + + @property + def runtime_env_agent_address(self): + """Get the address that exposes runtime env agent as http""" + return f"http://{self._raylet_ip_address}:{self._runtime_env_agent_port}" + + @property + def dashboard_agent_listen_port(self): + """Get the dashboard agent's listen port""" + return self._dashboard_agent_listen_port + + @property + def dashboard_grpc_port(self): + 
"""Get the dashboard head grpc port""" + return self._dashboard_grpc_port + + @property + def logging_config(self): + """Get the logging config of the current node.""" + return { + "log_rotation_max_bytes": self.max_bytes, + "log_rotation_backup_count": self.backup_count, + } + + @property + def address_info(self): + """Get a dictionary of addresses.""" + return { + "node_ip_address": self._node_ip_address, + "raylet_ip_address": self._raylet_ip_address, + "redis_address": self.redis_address, + "object_store_address": self._plasma_store_socket_name, + "raylet_socket_name": self._raylet_socket_name, + "webui_url": self._webui_url, + "session_dir": self._session_dir, + "metrics_export_port": self._metrics_export_port, + "gcs_address": self.gcs_address, + "address": self.address, + "dashboard_agent_listen_port": self.dashboard_agent_listen_port, + } + + def is_head(self): + return self.head + + def get_gcs_client(self): + if self._gcs_client is None: + self._init_gcs_client() + return self._gcs_client + + def _init_gcs_client(self): + if self.head: + gcs_process = self.all_processes[ray_constants.PROCESS_TYPE_GCS_SERVER][ + 0 + ].process + else: + gcs_process = None + + # TODO(ryw) instead of create a new GcsClient, wrap the one from + # CoreWorkerProcess to save a grpc channel. + for _ in range(ray_constants.NUM_REDIS_GET_RETRIES): + gcs_address = None + last_ex = None + try: + gcs_address = self.gcs_address + client = GcsClient( + address=gcs_address, + cluster_id=self._ray_params.cluster_id, # Hex string + ) + self.cluster_id = client.cluster_id + if self.head: + # Send a simple request to make sure GCS is alive + # if it's a head node. + client.internal_kv_get(b"dummy", None) + self._gcs_client = client + break + except Exception: + if gcs_process is not None and gcs_process.poll() is not None: + # GCS has exited. 
                    break
                last_ex = traceback.format_exc()
                logger.debug(f"Connecting to GCS: {last_ex}")
                time.sleep(1)

        if self._gcs_client is None:
            # Retries exhausted (or the GCS process died): surface the most
            # useful diagnostics we can find from the GCS error log, if any.
            if hasattr(self, "_logs_dir"):
                with open(os.path.join(self._logs_dir, "gcs_server.err")) as err:
                    # Lines containing " C " or " E " are the critical/error
                    # records; matching on them excludes stacktrace lines.
                    # This should work for most cases, especially when GCS is
                    # starting. Only display the last 10 lines of logs.
                    errors = [e for e in err.readlines() if " C " in e or " E " in e][
                        -10:
                    ]
                    error_msg = "\n" + "".join(errors) + "\n"
                    raise RuntimeError(
                        f"Failed to {'start' if self.head else 'connect to'} GCS. "
                        f" Last {len(errors)} lines of error files:"
                        f"{error_msg}."
                        f"Please check {os.path.join(self._logs_dir, 'gcs_server.out')}"
                        f" for details. Last connection error: {last_ex}"
                    )
            else:
                # No logs dir yet; report only the last connection error.
                raise RuntimeError(
                    f"Failed to {'start' if self.head else 'connect to'} GCS. Last "
                    f"connection error: {last_ex}"
                )

        # Make the freshly connected client the backing store for internal KV.
        ray.experimental.internal_kv._initialize_internal_kv(self._gcs_client)

    def get_temp_dir_path(self):
        """Get the path of the temporary directory."""
        return self._temp_dir

    def get_runtime_env_dir_path(self):
        """Get the path of the runtime env."""
        return self._runtime_env_dir

    def get_session_dir_path(self):
        """Get the path of the session directory."""
        return self._session_dir

    def get_logs_dir_path(self):
        """Get the path of the log files directory."""
        return self._logs_dir

    def get_sockets_dir_path(self):
        """Get the path of the sockets directory."""
        return self._sockets_dir

    def _make_inc_temp(
        self, suffix: str = "", prefix: str = "", directory_name: Optional[str] = None
    ):
        """Return an incremental temporary file name. The file is not created.

        Args:
            suffix: The suffix of the temp file.
            prefix: The prefix of the temp file.
            directory_name (str) : The base directory of the temp file.

        Returns:
            A string of file name. If a file with the same name already
                exists, the returned name will look like
                "{directory_name}/{prefix}.{unique_index}{suffix}"

        Raises:
            FileExistsError: if every candidate name up to `tempfile.TMP_MAX`
                already exists on disk.
        """
        if directory_name is None:
            directory_name = ray._private.utils.get_ray_temp_dir()
        directory_name = os.path.expanduser(directory_name)
        # Resume counting from the last index handed out for this
        # (suffix, prefix, directory) combination.
        index = self._incremental_dict[suffix, prefix, directory_name]
        # `tempfile.TMP_MAX` could be extremely large,
        # so using `range` in Python2.x should be avoided.
        while index < tempfile.TMP_MAX:
            if index == 0:
                filename = os.path.join(directory_name, prefix + suffix)
            else:
                filename = os.path.join(
                    directory_name, prefix + "." + str(index) + suffix
                )
            index += 1
            if not os.path.exists(filename):
                # Save the index so the next call starts past this name.
                self._incremental_dict[suffix, prefix, directory_name] = index
                return filename

        raise FileExistsError(errno.EEXIST, "No usable temporary filename found")

    def should_redirect_logs(self):
        """Return whether component stdout/stderr should go to log files.

        Uses the explicit `redirect_output` ray param when set; otherwise
        redirection is on unless the stderr-redirect env var equals "1".
        """
        redirect_output = self._ray_params.redirect_output
        if redirect_output is None:
            # Fall back to stderr redirect environment variable.
            redirect_output = (
                os.environ.get(
                    ray_constants.LOGGING_REDIRECT_STDERR_ENVIRONMENT_VARIABLE
                )
                != "1"
            )
        return redirect_output

    def get_log_file_names(
        self,
        name: str,
        unique: bool = False,
        create_out: bool = True,
        create_err: bool = True,
    ) -> Tuple[Optional[str], Optional[str]]:
        """Get filename to dump logs for stdout and stderr, with no files opened.
        If output redirection has been disabled, no files will
        be opened and `(None, None)` will be returned.

        Args:
            name: descriptive string for this log file.
            unique: if true, a counter will be attached to `name` to
                ensure the returned filename is not already used.
            create_out: if True, create a .out file.
            create_err: if True, create a .err file.

        Returns:
            A tuple of two optional file names for (stdout, stderr),
            or `(None, None)` if output redirection is disabled.
        """
        if not self.should_redirect_logs():
            # Redirection disabled: callers leave stdout/stderr untouched.
            return None, None

        log_stdout = None
        log_stderr = None

        if create_out:
            log_stdout = self._get_log_file_name(name, "out", unique=unique)
        if create_err:
            log_stderr = self._get_log_file_name(name, "err", unique=unique)
        return log_stdout, log_stderr

    def get_log_file_handles(
        self,
        name: str,
        unique: bool = False,
        create_out: bool = True,
        create_err: bool = True,
    ) -> Tuple[Optional[IO[AnyStr]], Optional[IO[AnyStr]]]:
        """Open log files with partially randomized filenames, returning the
        file handles. If output redirection has been disabled, no files will
        be opened and `(None, None)` will be returned.

        Args:
            name: descriptive string for this log file.
            unique: if true, a counter will be attached to `name` to
                ensure the returned filename is not already used.
            create_out: if True, create a .out file.
            create_err: if True, create a .err file.

        Returns:
            A tuple of two file handles for redirecting optional (stdout, stderr),
            or `(None, None)` if output redirection is disabled.
        """
        log_stdout_fname, log_stderr_fname = self.get_log_file_names(
            name, unique=unique, create_out=create_out, create_err=create_err
        )
        # NOTE(review): `open_log` is defined elsewhere in this module;
        # presumably it opens in append mode -- confirm before relying on it.
        log_stdout = None if log_stdout_fname is None else open_log(log_stdout_fname)
        log_stderr = None if log_stderr_fname is None else open_log(log_stderr_fname)
        return log_stdout, log_stderr

    def _get_log_file_name(
        self,
        name: str,
        suffix: str,
        unique: bool = False,
    ) -> str:
        """Generate partially randomized filenames for log files.

        Args:
            name: descriptive string for this log file.
            suffix: suffix of the file. Usually it is .out or .err.
            unique: if true, a counter will be attached to `name` to
                ensure the returned filename is not already used.

        Returns:
            The full log file path, e.g. "<logs_dir>/<name>.<suffix>".
        """
        # Normalize a suffix passed as ".out"/".err" to "out"/"err".
        suffix = suffix.strip(".")

        if unique:
            filename = self._make_inc_temp(
                suffix=f".{suffix}", prefix=name, directory_name=self._logs_dir
            )
        else:
            filename = os.path.join(self._logs_dir, f"{name}.{suffix}")
        return filename

    def _get_unused_port(self, allocated_ports=None):
        """Find a TCP port that is currently unused on this host.

        Args:
            allocated_ports: ports already reserved by the caller; these are
                never returned even if they happen to be free right now.

        Returns:
            A free port number. NOTE: the probe socket is closed before the
            port is returned, so this is inherently best-effort -- another
            process may claim the port in the meantime.
        """
        if allocated_ports is None:
            allocated_ports = set()

        # Bind to port 0 to let the kernel pick the next available port.
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.bind(("", 0))
        port = s.getsockname()[1]

        # Try to generate a port that is far above the 'next available' one.
        # This solves issue #8254 where GRPC fails because the port assigned
        # from this method has been used by a different process.
        for _ in range(ray_constants.NUM_PORT_RETRIES):
            new_port = random.randint(port, 65535)
            if new_port in allocated_ports:
                # This port is allocated for other usage already,
                # so we shouldn't use it even if it's not in use right now.
                continue
            new_s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                new_s.bind(("", new_port))
            except OSError:
                new_s.close()
                continue
            s.close()
            new_s.close()
            return new_port
        # All random retries failed; fall back to the kernel-assigned port.
        logger.error("Unable to succeed in selecting a random port.")
        s.close()
        return port

    def _prepare_socket_file(self, socket_path: str, default_prefix: str):
        """Prepare the socket file for raylet and plasma.

        This method helps to prepare a socket file.
        1. Make the directory if the directory does not exist.
        2. If the socket file exists, do nothing (this just means we aren't the
        first worker on the node).

        Args:
            socket_path: the socket file to prepare.
+ """ + result = socket_path + is_mac = sys.platform.startswith("darwin") + if sys.platform == "win32": + if socket_path is None: + result = f"tcp://{self._localhost}" f":{self._get_unused_port()}" + else: + if socket_path is None: + result = self._make_inc_temp( + prefix=default_prefix, directory_name=self._sockets_dir + ) + else: + try_to_create_directory(os.path.dirname(socket_path)) + + # Check socket path length to make sure it's short enough + maxlen = (104 if is_mac else 108) - 1 # sockaddr_un->sun_path + if len(result.split("://", 1)[-1].encode("utf-8")) > maxlen: + raise OSError( + f"AF_UNIX path length cannot exceed {maxlen} bytes: {result!r}" + ) + return result + + def _get_cached_port( + self, port_name: str, default_port: Optional[int] = None + ) -> int: + """Get a port number from a cache on this node. + + Different driver processes on a node should use the same ports for + some purposes, e.g. exporting metrics. This method returns a port + number for the given port name and caches it in a file. If the + port isn't already cached, an unused port is generated and cached. + + Args: + port_name: the name of the port, e.g. metrics_export_port + default_port (Optional[int]): The port to return and cache if no + port has already been cached for the given port_name. If None, an + unused port is generated and cached. + Returns: + port: the port number. + """ + file_path = os.path.join(self.get_session_dir_path(), "ports_by_node.json") + + # Make sure only the ports in RAY_CACHED_PORTS are cached. + assert port_name in ray_constants.RAY_ALLOWED_CACHED_PORTS + + # Maps a Node.unique_id to a dict that maps port names to port numbers. 
+ ports_by_node: Dict[str, Dict[str, int]] = defaultdict(dict) + + with FileLock(file_path + ".lock"): + if not os.path.exists(file_path): + with open(file_path, "w") as f: + json.dump({}, f) + + with open(file_path, "r") as f: + ports_by_node.update(json.load(f)) + + if ( + self.unique_id in ports_by_node + and port_name in ports_by_node[self.unique_id] + ): + # The port has already been cached at this node, so use it. + port = int(ports_by_node[self.unique_id][port_name]) + else: + # Pick a new port to use and cache it at this node. + allocated_ports = set(ports_by_node[self.unique_id].values()) + + if default_port is not None and default_port in allocated_ports: + # The default port is already in use, so don't use it. + default_port = None + + port = default_port or self._get_unused_port(allocated_ports) + + ports_by_node[self.unique_id][port_name] = port + with open(file_path, "w") as f: + json.dump(ports_by_node, f) + + return port + + def _wait_and_get_for_node_address(self, timeout_s: int = 60) -> str: + """Wait until the RAY_NODE_IP_FILENAME file is avialable. + + RAY_NODE_IP_FILENAME is created when a ray instance is started. + + Args: + timeout_s: If the ip address is not found within this + timeout, it will raise ValueError. + Returns: + The node_ip_address of the current session if it finds it + within timeout_s. + """ + for i in range(timeout_s): + node_ip_address = ray._private.services.get_cached_node_ip_address( + self.get_session_dir_path() + ) + + if node_ip_address is not None: + return node_ip_address + + time.sleep(1) + if i % 10 == 0: + logger.info( + f"Can't find a `{ray_constants.RAY_NODE_IP_FILENAME}` " + f"file from {self.get_session_dir_path()}. " + "Have you started Ray instance using " + "`ray start` or `ray.init`?" + ) + + raise ValueError( + f"Can't find a `{ray_constants.RAY_NODE_IP_FILENAME}` " + f"file from {self.get_session_dir_path()}. " + f"for {timeout_s} seconds. " + "A ray instance hasn't started. 
" + "Did you do `ray start` or `ray.init` on this host?" + ) + + def start_reaper_process(self): + """ + Start the reaper process. + + This must be the first process spawned and should only be called when + ray processes should be cleaned up if this process dies. + """ + assert ( + not self.kernel_fate_share + ), "a reaper should not be used with kernel fate-sharing" + process_info = ray._private.services.start_reaper(fate_share=False) + assert ray_constants.PROCESS_TYPE_REAPER not in self.all_processes + if process_info is not None: + self.all_processes[ray_constants.PROCESS_TYPE_REAPER] = [ + process_info, + ] + + def start_log_monitor(self): + """Start the log monitor.""" + # Only redirect logs to .err. .err file is only useful when the + # component has an unexpected output to stdout/stderr. + _, stderr_file = self.get_log_file_handles( + "log_monitor", unique=True, create_out=False + ) + process_info = ray._private.services.start_log_monitor( + self.get_session_dir_path(), + self._logs_dir, + self.gcs_address, + fate_share=self.kernel_fate_share, + max_bytes=self.max_bytes, + backup_count=self.backup_count, + redirect_logging=self.should_redirect_logs(), + stdout_file=stderr_file, + stderr_file=stderr_file, + ) + assert ray_constants.PROCESS_TYPE_LOG_MONITOR not in self.all_processes + self.all_processes[ray_constants.PROCESS_TYPE_LOG_MONITOR] = [ + process_info, + ] + + def start_api_server( + self, *, include_dashboard: Optional[bool], raise_on_failure: bool + ): + """Start the dashboard. + + Args: + include_dashboard: If true, this will load all dashboard-related modules + when starting the API server. Otherwise, it will only + start the modules that are not relevant to the dashboard. + raise_on_failure: If true, this will raise an exception + if we fail to start the API server. Otherwise it will print + a warning if we fail to start the API server. + """ + # Only redirect logs to .err. 
.err file is only useful when the + # component has an unexpected output to stdout/stderr. + _, stderr_file = self.get_log_file_handles( + "dashboard", unique=True, create_out=False + ) + self._webui_url, process_info = ray._private.services.start_api_server( + include_dashboard, + raise_on_failure, + self._ray_params.dashboard_host, + self.gcs_address, + self.cluster_id.hex(), + self._node_ip_address, + self._temp_dir, + self._logs_dir, + self._session_dir, + port=self._ray_params.dashboard_port, + dashboard_grpc_port=self._ray_params.dashboard_grpc_port, + fate_share=self.kernel_fate_share, + max_bytes=self.max_bytes, + backup_count=self.backup_count, + redirect_logging=self.should_redirect_logs(), + stdout_file=stderr_file, + stderr_file=stderr_file, + ) + assert ray_constants.PROCESS_TYPE_DASHBOARD not in self.all_processes + if process_info is not None: + self.all_processes[ray_constants.PROCESS_TYPE_DASHBOARD] = [ + process_info, + ] + self.get_gcs_client().internal_kv_put( + b"webui:url", + self._webui_url.encode(), + True, + ray_constants.KV_NAMESPACE_DASHBOARD, + ) + + def start_gcs_server(self): + """Start the gcs server.""" + gcs_server_port = self._ray_params.gcs_server_port + assert gcs_server_port > 0 + assert self._gcs_address is None, "GCS server is already running." + assert self._gcs_client is None, "GCS client is already connected." + + # TODO(hjiang): Update stderr to pass filename and get spdlog to handle + # logging as well. 
+ stdout_log_fname, _ = self.get_log_file_names( + "gcs_server", unique=True, create_out=True, create_err=False + ) + _, stderr_file = self.get_log_file_handles( + "gcs_server", unique=True, create_out=False, create_err=True + ) + process_info = ray._private.services.start_gcs_server( + self.redis_address, + log_dir=self._logs_dir, + ray_log_filepath=stdout_log_fname, + stderr_file=stderr_file, + session_name=self.session_name, + redis_username=self._ray_params.redis_username, + redis_password=self._ray_params.redis_password, + config=self._config, + fate_share=self.kernel_fate_share, + gcs_server_port=gcs_server_port, + metrics_agent_port=self._ray_params.metrics_agent_port, + node_ip_address=self._node_ip_address, + ) + assert ray_constants.PROCESS_TYPE_GCS_SERVER not in self.all_processes + self.all_processes[ray_constants.PROCESS_TYPE_GCS_SERVER] = [ + process_info, + ] + # Connecting via non-localhost address may be blocked by firewall rule, + # e.g. https://github.com/ray-project/ray/issues/15780 + # TODO(mwtian): figure out a way to use 127.0.0.1 for local connection + # when possible. + self._gcs_address = f"{self._node_ip_address}:" f"{gcs_server_port}" + + def start_raylet( + self, + plasma_directory: str, + object_store_memory: int, + use_valgrind: bool = False, + use_profiler: bool = False, + enable_physical_mode: bool = False, + ): + """Start the raylet. + + Args: + use_valgrind: True if we should start the process in + valgrind. + use_profiler: True if we should start the process in the + valgrind profiler. 
+ """ + stdout_log_fname, _ = self.get_log_file_names( + "raylet", unique=True, create_out=True, create_err=False + ) + _, stderr_file = self.get_log_file_handles( + "raylet", unique=True, create_out=False, create_err=True + ) + process_info = ray._private.services.start_raylet( + self.redis_address, + self.gcs_address, + self._node_id, + self._node_ip_address, + self._ray_params.node_manager_port, + self._raylet_socket_name, + self._plasma_store_socket_name, + self.cluster_id.hex(), + self._ray_params.worker_path, + self._ray_params.setup_worker_path, + self._ray_params.storage, + self._temp_dir, + self._session_dir, + self._runtime_env_dir, + self._logs_dir, + self.get_resource_spec(), + plasma_directory, + object_store_memory, + self.session_name, + is_head_node=self.is_head(), + min_worker_port=self._ray_params.min_worker_port, + max_worker_port=self._ray_params.max_worker_port, + worker_port_list=self._ray_params.worker_port_list, + object_manager_port=self._ray_params.object_manager_port, + redis_username=self._ray_params.redis_username, + redis_password=self._ray_params.redis_password, + metrics_agent_port=self._ray_params.metrics_agent_port, + runtime_env_agent_port=self._ray_params.runtime_env_agent_port, + metrics_export_port=self._metrics_export_port, + dashboard_agent_listen_port=self._ray_params.dashboard_agent_listen_port, + use_valgrind=use_valgrind, + use_profiler=use_profiler, + ray_log_filepath=stdout_log_fname, + stderr_file=stderr_file, + huge_pages=self._ray_params.huge_pages, + fate_share=self.kernel_fate_share, + socket_to_use=None, + max_bytes=self.max_bytes, + backup_count=self.backup_count, + ray_debugger_external=self._ray_params.ray_debugger_external, + env_updates=self._ray_params.env_vars, + node_name=self._ray_params.node_name, + webui=self._webui_url, + labels=self._get_node_labels(), + enable_physical_mode=enable_physical_mode, + ) + assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes + 
self.all_processes[ray_constants.PROCESS_TYPE_RAYLET] = [process_info] + + def start_worker(self): + """Start a worker process.""" + raise NotImplementedError + + def start_monitor(self): + """Start the monitor. + + Autoscaling output goes to these monitor.err/out files, and + any modification to these files may break existing + cluster launching commands. + """ + from ray.autoscaler.v2.utils import is_autoscaler_v2 + + stdout_file, stderr_file = self.get_log_file_handles("monitor", unique=True) + process_info = ray._private.services.start_monitor( + self.gcs_address, + self._logs_dir, + stdout_file=stdout_file, + stderr_file=stderr_file, + autoscaling_config=self._ray_params.autoscaling_config, + fate_share=self.kernel_fate_share, + max_bytes=self.max_bytes, + backup_count=self.backup_count, + monitor_ip=self._node_ip_address, + autoscaler_v2=is_autoscaler_v2(fetch_from_server=True), + ) + assert ray_constants.PROCESS_TYPE_MONITOR not in self.all_processes + self.all_processes[ray_constants.PROCESS_TYPE_MONITOR] = [process_info] + + def start_ray_client_server(self): + """Start the ray client server process.""" + stdout_file, stderr_file = self.get_log_file_handles( + "ray_client_server", unique=True + ) + process_info = ray._private.services.start_ray_client_server( + self.address, + self._node_ip_address, + self._ray_params.ray_client_server_port, + stdout_file=stdout_file, + stderr_file=stderr_file, + redis_username=self._ray_params.redis_username, + redis_password=self._ray_params.redis_password, + fate_share=self.kernel_fate_share, + runtime_env_agent_address=self.runtime_env_agent_address, + ) + assert ray_constants.PROCESS_TYPE_RAY_CLIENT_SERVER not in self.all_processes + self.all_processes[ray_constants.PROCESS_TYPE_RAY_CLIENT_SERVER] = [ + process_info + ] + + def _write_cluster_info_to_kv(self): + """Write the cluster metadata to GCS. + Cluster metadata is always recorded, but they are + not reported unless usage report is enabled. 
+ Check `usage_stats_head.py` for more details. + """ + # Make sure the cluster metadata wasn't reported before. + import ray._private.usage.usage_lib as ray_usage_lib + + ray_usage_lib.put_cluster_metadata( + self.get_gcs_client(), ray_init_cluster=self.ray_init_cluster + ) + # Make sure GCS is up. + added = self.get_gcs_client().internal_kv_put( + b"session_name", + self._session_name.encode(), + False, + ray_constants.KV_NAMESPACE_SESSION, + ) + if not added: + curr_val = self.get_gcs_client().internal_kv_get( + b"session_name", ray_constants.KV_NAMESPACE_SESSION + ) + assert curr_val == self._session_name.encode("utf-8"), ( + f"Session name {self._session_name} does not match " + f"persisted value {curr_val}. Perhaps there was an " + f"error connecting to Redis." + ) + + self.get_gcs_client().internal_kv_put( + b"session_dir", + self._session_dir.encode(), + True, + ray_constants.KV_NAMESPACE_SESSION, + ) + self.get_gcs_client().internal_kv_put( + b"temp_dir", + self._temp_dir.encode(), + True, + ray_constants.KV_NAMESPACE_SESSION, + ) + if self._ray_params.storage is not None: + self.get_gcs_client().internal_kv_put( + b"storage", + self._ray_params.storage.encode(), + True, + ray_constants.KV_NAMESPACE_SESSION, + ) + # Add tracing_startup_hook to redis / internal kv manually + # since internal kv is not yet initialized. + if self._ray_params.tracing_startup_hook: + self.get_gcs_client().internal_kv_put( + b"tracing_startup_hook", + self._ray_params.tracing_startup_hook.encode(), + True, + ray_constants.KV_NAMESPACE_TRACING, + ) + + def start_head_processes(self): + """Start head processes on the node.""" + logger.debug( + f"Process STDOUT and STDERR is being " f"redirected to {self._logs_dir}." 
+ ) + assert self._gcs_address is None + assert self._gcs_client is None + + self.start_gcs_server() + assert self.get_gcs_client() is not None + self._write_cluster_info_to_kv() + + if not self._ray_params.no_monitor: + self.start_monitor() + + if self._ray_params.ray_client_server_port: + self.start_ray_client_server() + + if self._ray_params.include_dashboard is None: + # Default + raise_on_api_server_failure = False + else: + raise_on_api_server_failure = self._ray_params.include_dashboard + + self.start_api_server( + include_dashboard=self._ray_params.include_dashboard, + raise_on_failure=raise_on_api_server_failure, + ) + + def start_ray_processes(self): + """Start all of the processes on the node.""" + logger.debug( + f"Process STDOUT and STDERR is being " f"redirected to {self._logs_dir}." + ) + + if not self.head: + # Get the system config from GCS first if this is a non-head node. + gcs_options = ray._raylet.GcsClientOptions.create( + self.gcs_address, + self.cluster_id.hex(), + allow_cluster_id_nil=False, + fetch_cluster_id_if_nil=False, + ) + global_state = ray._private.state.GlobalState() + global_state._initialize_global_state(gcs_options) + new_config = global_state.get_system_config() + assert self._config.items() <= new_config.items(), ( + "The system config from GCS is not a superset of the local" + " system config. There might be a configuration inconsistency" + " issue between the head node and non-head nodes." + f" Local system config: {self._config}," + f" GCS system config: {new_config}" + ) + self._config = new_config + + # Make sure we don't call `determine_plasma_store_config` multiple + # times to avoid printing multiple warnings. 
+ resource_spec = self.get_resource_spec() + ( + plasma_directory, + object_store_memory, + ) = ray._private.services.determine_plasma_store_config( + resource_spec.object_store_memory, + plasma_directory=self._ray_params.plasma_directory, + huge_pages=self._ray_params.huge_pages, + ) + self.start_raylet(plasma_directory, object_store_memory) + if self._ray_params.include_log_monitor: + self.start_log_monitor() + + def _kill_process_type( + self, + process_type, + allow_graceful: bool = False, + check_alive: bool = True, + wait: bool = False, + ): + """Kill a process of a given type. + + If the process type is PROCESS_TYPE_REDIS_SERVER, then we will kill all + of the Redis servers. + + If the process was started in valgrind, then we will raise an exception + if the process has a non-zero exit code. + + Args: + process_type: The type of the process to kill. + allow_graceful: Send a SIGTERM first and give the process + time to exit gracefully. If that doesn't work, then use + SIGKILL. We usually want to do this outside of tests. + check_alive: If true, then we expect the process to be alive + and will raise an exception if the process is already dead. + wait: If true, then this method will not return until the + process in question has exited. + + Raises: + This process raises an exception in the following cases: + 1. The process had already died and check_alive is true. + 2. The process had been started in valgrind and had a non-zero + exit code. 
+ """ + + # Ensure thread safety + with self.removal_lock: + self._kill_process_impl( + process_type, + allow_graceful=allow_graceful, + check_alive=check_alive, + wait=wait, + ) + + def _kill_process_impl( + self, process_type, allow_graceful=False, check_alive=True, wait=False + ): + """See `_kill_process_type`.""" + if process_type not in self.all_processes: + return + process_infos = self.all_processes[process_type] + if process_type != ray_constants.PROCESS_TYPE_REDIS_SERVER: + assert len(process_infos) == 1 + for process_info in process_infos: + process = process_info.process + # Handle the case where the process has already exited. + if process.poll() is not None: + if check_alive: + raise RuntimeError( + "Attempting to kill a process of type " + f"'{process_type}', but this process is already dead." + ) + else: + continue + + if process_info.use_valgrind: + process.terminate() + process.wait() + if process.returncode != 0: + message = ( + "Valgrind detected some errors in process of " + f"type {process_type}. Error code {process.returncode}." + ) + if process_info.stdout_file is not None: + with open(process_info.stdout_file, "r") as f: + message += "\nPROCESS STDOUT:\n" + f.read() + if process_info.stderr_file is not None: + with open(process_info.stderr_file, "r") as f: + message += "\nPROCESS STDERR:\n" + f.read() + raise RuntimeError(message) + continue + + if process_info.use_valgrind_profiler: + # Give process signal to write profiler data. + os.kill(process.pid, signal.SIGINT) + # Wait for profiling data to be written. + time.sleep(0.1) + + if allow_graceful: + process.terminate() + # Allow the process one second to exit gracefully. + timeout_seconds = 1 + try: + process.wait(timeout_seconds) + except subprocess.TimeoutExpired: + pass + + # If the process did not exit, force kill it. 
+ if process.poll() is None: + process.kill() + # The reason we usually don't call process.wait() here is that + # there's some chance we'd end up waiting a really long time. + if wait: + process.wait() + + del self.all_processes[process_type] + + def kill_redis(self, check_alive: bool = True): + """Kill the Redis servers. + + Args: + check_alive: Raise an exception if any of the processes + were already dead. + """ + self._kill_process_type( + ray_constants.PROCESS_TYPE_REDIS_SERVER, check_alive=check_alive + ) + + def kill_raylet(self, check_alive: bool = True): + """Kill the raylet. + + Args: + check_alive: Raise an exception if the process was already + dead. + """ + self._kill_process_type( + ray_constants.PROCESS_TYPE_RAYLET, check_alive=check_alive + ) + + def kill_log_monitor(self, check_alive: bool = True): + """Kill the log monitor. + + Args: + check_alive: Raise an exception if the process was already + dead. + """ + self._kill_process_type( + ray_constants.PROCESS_TYPE_LOG_MONITOR, check_alive=check_alive + ) + + def kill_reporter(self, check_alive: bool = True): + """Kill the reporter. + + Args: + check_alive: Raise an exception if the process was already + dead. + """ + self._kill_process_type( + ray_constants.PROCESS_TYPE_REPORTER, check_alive=check_alive + ) + + def kill_dashboard(self, check_alive: bool = True): + """Kill the dashboard. + + Args: + check_alive: Raise an exception if the process was already + dead. + """ + self._kill_process_type( + ray_constants.PROCESS_TYPE_DASHBOARD, check_alive=check_alive + ) + + def kill_monitor(self, check_alive: bool = True): + """Kill the monitor. + + Args: + check_alive: Raise an exception if the process was already + dead. + """ + self._kill_process_type( + ray_constants.PROCESS_TYPE_MONITOR, check_alive=check_alive + ) + + def kill_gcs_server(self, check_alive: bool = True): + """Kill the gcs server. + + Args: + check_alive: Raise an exception if the process was already + dead. 
+ """ + self._kill_process_type( + ray_constants.PROCESS_TYPE_GCS_SERVER, check_alive=check_alive, wait=True + ) + # Clear GCS client and address to indicate no GCS server is running. + self._gcs_address = None + self._gcs_client = None + + def kill_reaper(self, check_alive: bool = True): + """Kill the reaper process. + + Args: + check_alive: Raise an exception if the process was already + dead. + """ + self._kill_process_type( + ray_constants.PROCESS_TYPE_REAPER, check_alive=check_alive + ) + + def kill_all_processes(self, check_alive=True, allow_graceful=False, wait=False): + """Kill all of the processes. + + Note that This is slower than necessary because it calls kill, wait, + kill, wait, ... instead of kill, kill, ..., wait, wait, ... + + Args: + check_alive: Raise an exception if any of the processes were + already dead. + wait: If true, then this method will not return until the + process in question has exited. + """ + # Kill the raylet first. This is important for suppressing errors at + # shutdown because we give the raylet a chance to exit gracefully and + # clean up its child worker processes. If we were to kill the plasma + # store (or Redis) first, that could cause the raylet to exit + # ungracefully, leading to more verbose output from the workers. + if ray_constants.PROCESS_TYPE_RAYLET in self.all_processes: + self._kill_process_type( + ray_constants.PROCESS_TYPE_RAYLET, + check_alive=check_alive, + allow_graceful=allow_graceful, + wait=wait, + ) + + if ray_constants.PROCESS_TYPE_GCS_SERVER in self.all_processes: + self._kill_process_type( + ray_constants.PROCESS_TYPE_GCS_SERVER, + check_alive=check_alive, + allow_graceful=allow_graceful, + wait=wait, + ) + + # We call "list" to copy the keys because we are modifying the + # dictionary while iterating over it. + for process_type in list(self.all_processes.keys()): + # Need to kill the reaper process last in case we die unexpectedly + # while cleaning up. 
+ if process_type != ray_constants.PROCESS_TYPE_REAPER: + self._kill_process_type( + process_type, + check_alive=check_alive, + allow_graceful=allow_graceful, + wait=wait, + ) + + if ray_constants.PROCESS_TYPE_REAPER in self.all_processes: + self._kill_process_type( + ray_constants.PROCESS_TYPE_REAPER, + check_alive=check_alive, + allow_graceful=allow_graceful, + wait=wait, + ) + + def live_processes(self): + """Return a list of the live processes. + + Returns: + A list of the live processes. + """ + result = [] + for process_type, process_infos in self.all_processes.items(): + for process_info in process_infos: + if process_info.process.poll() is None: + result.append((process_type, process_info.process)) + return result + + def dead_processes(self): + """Return a list of the dead processes. + + Note that this ignores processes that have been explicitly killed, + e.g., via a command like node.kill_raylet(). + + Returns: + A list of the dead processes ignoring the ones that have been + explicitly killed. + """ + result = [] + for process_type, process_infos in self.all_processes.items(): + for process_info in process_infos: + if process_info.process.poll() is not None: + result.append((process_type, process_info.process)) + return result + + def any_processes_alive(self): + """Return true if any processes are still alive. + + Returns: + True if any process is still alive. + """ + return any(self.live_processes()) + + def remaining_processes_alive(self): + """Return true if all remaining processes are still alive. + + Note that this ignores processes that have been explicitly killed, + e.g., via a command like node.kill_raylet(). + + Returns: + True if any process that wasn't explicitly killed is still alive. 
+ """ + return not any(self.dead_processes()) + + def destroy_external_storage(self): + object_spilling_config = self._config.get("object_spilling_config", {}) + if object_spilling_config: + object_spilling_config = json.loads(object_spilling_config) + from ray._private import external_storage + + storage = external_storage.setup_external_storage( + object_spilling_config, self._node_id, self._session_name + ) + storage.destroy_external_storage() + + def validate_external_storage(self): + """Make sure we can setup the object spilling external storage. + This will also fill up the default setting for object spilling + if not specified. + """ + object_spilling_config = self._config.get("object_spilling_config", {}) + automatic_spilling_enabled = self._config.get( + "automatic_object_spilling_enabled", True + ) + if not automatic_spilling_enabled: + return + + if not object_spilling_config: + object_spilling_config = os.environ.get("RAY_object_spilling_config", "") + + # If the config is not specified, we fill up the default. + if not object_spilling_config: + object_spilling_config = json.dumps( + {"type": "filesystem", "params": {"directory_path": self._session_dir}} + ) + + # Try setting up the storage. + # Configure the proper system config. + # We need to set both ray param's system config and self._config + # because they could've been diverged at this point. + deserialized_config = json.loads(object_spilling_config) + self._ray_params._system_config[ + "object_spilling_config" + ] = object_spilling_config + self._config["object_spilling_config"] = object_spilling_config + + is_external_storage_type_fs = deserialized_config["type"] == "filesystem" + self._ray_params._system_config[ + "is_external_storage_type_fs" + ] = is_external_storage_type_fs + self._config["is_external_storage_type_fs"] = is_external_storage_type_fs + + # Validate external storage usage. + from ray._private import external_storage + + # Node ID is available only after GCS is connected. 
However, + # validate_external_storage() needs to be called before it to + # be able to validate the configs early. Therefore, we use a + # dummy node ID here and make sure external storage can be set + # up based on the provided config. This storage is destroyed + # right after the validation. + dummy_node_id = ray.NodeID.from_random().hex() + storage = external_storage.setup_external_storage( + deserialized_config, dummy_node_id, self._session_name + ) + storage.destroy_external_storage() + external_storage.reset_external_storage() + + def _record_stats(self): + # This is only called when a new node is started. + # Initialize the internal kv so that the metrics can be put + from ray._private.usage.usage_lib import ( + TagKey, + record_extra_usage_tag, + record_hardware_usage, + ) + + if not ray.experimental.internal_kv._internal_kv_initialized(): + ray.experimental.internal_kv._initialize_internal_kv(self.get_gcs_client()) + assert ray.experimental.internal_kv._internal_kv_initialized() + if self.head: + # record head node stats + gcs_storage_type = ( + "redis" if os.environ.get("RAY_REDIS_ADDRESS") is not None else "memory" + ) + record_extra_usage_tag(TagKey.GCS_STORAGE, gcs_storage_type) + cpu_model_name = ray._private.utils.get_current_node_cpu_model_name() + if cpu_model_name: + # CPU model name can be an arbitrary long string + # so we truncate it to the first 50 characters + # to avoid any issues. 
+ record_hardware_usage(cpu_model_name[:50]) diff --git a/.venv/lib/python3.11/site-packages/ray/_private/parameter.py b/.venv/lib/python3.11/site-packages/ray/_private/parameter.py new file mode 100644 index 0000000000000000000000000000000000000000..1185df149f766851788d9b57e55040d3ca2d531f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/parameter.py @@ -0,0 +1,483 @@ +import logging +import os +from typing import Dict, List, Optional + +import ray._private.ray_constants as ray_constants +from ray._private.utils import ( + validate_node_labels, + check_ray_client_dependencies_installed, +) + + +logger = logging.getLogger(__name__) + + +class RayParams: + """A class used to store the parameters used by Ray. + + Attributes: + redis_address: The address of the Redis server to connect to. If + this address is not provided, then this command will start Redis, a + raylet, a plasma store, a plasma manager, and some workers. + It will also kill these processes when Python exits. + redis_port: The port that the primary Redis shard should listen + to. If None, then it will fall back to + ray._private.ray_constants.DEFAULT_PORT, or a random port if the default is + not available. + redis_shard_ports: A list of the ports to use for the non-primary Redis + shards. If None, then it will fall back to the ports right after + redis_port, or random ports if those are not available. + num_cpus: Number of CPUs to configure the raylet with. + num_gpus: Number of GPUs to configure the raylet with. + resources: A dictionary mapping the name of a resource to the quantity + of that resource available. + labels: The key-value labels of the node. + memory: Total available memory for workers requesting memory. + object_store_memory: The amount of memory (in bytes) to start the + object store with. + redis_max_memory: The max amount of memory (in bytes) to allow redis + to use, or None for no limit. Once the limit is exceeded, redis + will start LRU eviction of entries. 
This only applies to the + sharded redis tables (task and object tables). + object_manager_port int: The port to use for the object manager. + node_manager_port: The port to use for the node manager. + gcs_server_port: The port to use for the GCS server. + node_ip_address: The IP address of the node that we are on. + raylet_ip_address: The IP address of the raylet that this node + connects to. + min_worker_port: The lowest port number that workers will bind + on. If not set or set to 0, random ports will be chosen. + max_worker_port: The highest port number that workers will bind + on. If set, min_worker_port must also be set. + worker_port_list: An explicit list of ports to be used for + workers (comma-separated). Overrides min_worker_port and + max_worker_port. + ray_client_server_port: The port number the ray client server + will bind on. If not set, the ray client server will not + be started. + object_ref_seed: Used to seed the deterministic generation of + object refs. The same value can be used across multiple runs of the + same job in order to generate the object refs in a consistent + manner. However, the same ID should not be used for different jobs. + redirect_output: True if stdout and stderr for non-worker + processes should be redirected to files and false otherwise. + external_addresses: The address of external Redis server to + connect to, in format of "ip1:port1,ip2:port2,...". If this + address is provided, then ray won't start Redis instances in the + head node but use external Redis server(s) instead. + num_redis_shards: The number of Redis shards to start in addition to + the primary Redis shard. + redis_max_clients: If provided, attempt to configure Redis with this + maxclients number. + redis_username: Prevents external clients without the username + from connecting to Redis if provided. + redis_password: Prevents external clients without the password + from connecting to Redis if provided. 
+ plasma_directory: A directory where the Plasma memory mapped files will + be created. + worker_path: The path of the source code that will be run by the + worker. + setup_worker_path: The path of the Python file that will set up + the environment for the worker process. + huge_pages: Boolean flag indicating whether to start the Object + Store with hugetlbfs support. Requires plasma_directory. + include_dashboard: Boolean flag indicating whether to start the web + UI, which displays the status of the Ray cluster. If this value is + None, then the UI will be started if the relevant dependencies are + present. + dashboard_host: The host to bind the web UI server to. Can either be + localhost (127.0.0.1) or 0.0.0.0 (available from all interfaces). + By default, this is set to localhost to prevent access from + external machines. + dashboard_port: The port to bind the dashboard server to. + Defaults to 8265. + dashboard_agent_listen_port: The port for dashboard agents to listen on + for HTTP requests. + Defaults to 52365. + dashboard_grpc_port: The port for the dashboard head process to listen + for gRPC on. + Defaults to random available port. + runtime_env_agent_port: The port at which the runtime env agent + listens to for HTTP. + Defaults to random available port. + plasma_store_socket_name: If provided, it specifies the socket + name used by the plasma store. + raylet_socket_name: If provided, it specifies the socket path + used by the raylet process. + temp_dir: If provided, it will specify the root temporary + directory for the Ray process. Must be an absolute path. + storage: Specify a URI for persistent cluster-wide storage. This storage path + must be accessible by all nodes of the cluster, otherwise an error will be + raised. + runtime_env_dir_name: If provided, specifies the directory that + will be created in the session dir to hold runtime_env files. 
+ include_log_monitor: If True, then start a log monitor to + monitor the log files for all processes on this node and push their + contents to Redis. + autoscaling_config: path to autoscaling config file. + metrics_agent_port: The port to bind metrics agent. + metrics_export_port: The port at which metrics are exposed + through a Prometheus endpoint. + no_monitor: If True, the ray autoscaler monitor for this cluster + will not be started. + _system_config: Configuration for overriding RayConfig + defaults. Used to set system configuration and for experimental Ray + core feature flags. + enable_object_reconstruction: Enable plasma reconstruction on + failure. + ray_debugger_external: If true, make the Ray debugger for a + worker available externally to the node it is running on. This will + bind on 0.0.0.0 instead of localhost. + env_vars: Override environment variables for the raylet. + session_name: The name of the session of the ray cluster. + webui: The url of the UI. + cluster_id: The cluster ID in hex string. + enable_physical_mode: Whether physical mode is enabled, which applies + constraint to tasks' resource consumption. As of now, only memory resource + is supported. 
+ """ + + def __init__( + self, + redis_address: Optional[str] = None, + gcs_address: Optional[str] = None, + num_cpus: Optional[int] = None, + num_gpus: Optional[int] = None, + resources: Optional[Dict[str, float]] = None, + labels: Optional[Dict[str, str]] = None, + memory: Optional[float] = None, + object_store_memory: Optional[float] = None, + redis_max_memory: Optional[float] = None, + redis_port: Optional[int] = None, + redis_shard_ports: Optional[List[int]] = None, + object_manager_port: Optional[int] = None, + node_manager_port: int = 0, + gcs_server_port: Optional[int] = None, + node_ip_address: Optional[str] = None, + node_name: Optional[str] = None, + raylet_ip_address: Optional[str] = None, + min_worker_port: Optional[int] = None, + max_worker_port: Optional[int] = None, + worker_port_list: Optional[List[int]] = None, + ray_client_server_port: Optional[int] = None, + object_ref_seed: Optional[int] = None, + driver_mode=None, + redirect_output: Optional[bool] = None, + external_addresses: Optional[List[str]] = None, + num_redis_shards: Optional[int] = None, + redis_max_clients: Optional[int] = None, + redis_username: Optional[str] = ray_constants.REDIS_DEFAULT_USERNAME, + redis_password: Optional[str] = ray_constants.REDIS_DEFAULT_PASSWORD, + plasma_directory: Optional[str] = None, + worker_path: Optional[str] = None, + setup_worker_path: Optional[str] = None, + huge_pages: Optional[bool] = False, + include_dashboard: Optional[bool] = None, + dashboard_host: Optional[str] = ray_constants.DEFAULT_DASHBOARD_IP, + dashboard_port: Optional[bool] = ray_constants.DEFAULT_DASHBOARD_PORT, + dashboard_agent_listen_port: Optional[ + int + ] = ray_constants.DEFAULT_DASHBOARD_AGENT_LISTEN_PORT, + runtime_env_agent_port: Optional[int] = None, + dashboard_grpc_port: Optional[int] = None, + plasma_store_socket_name: Optional[str] = None, + raylet_socket_name: Optional[str] = None, + temp_dir: Optional[str] = None, + storage: Optional[str] = None, + 
runtime_env_dir_name: Optional[str] = None, + include_log_monitor: Optional[str] = None, + autoscaling_config: Optional[str] = None, + ray_debugger_external: bool = False, + _system_config: Optional[Dict[str, str]] = None, + enable_object_reconstruction: Optional[bool] = False, + metrics_agent_port: Optional[int] = None, + metrics_export_port: Optional[int] = None, + tracing_startup_hook=None, + no_monitor: Optional[bool] = False, + env_vars: Optional[Dict[str, str]] = None, + session_name: Optional[str] = None, + webui: Optional[str] = None, + cluster_id: Optional[str] = None, + node_id: Optional[str] = None, + enable_physical_mode: bool = False, + ): + self.redis_address = redis_address + self.gcs_address = gcs_address + self.num_cpus = num_cpus + self.num_gpus = num_gpus + self.memory = memory + self.object_store_memory = object_store_memory + self.resources = resources + self.redis_max_memory = redis_max_memory + self.redis_port = redis_port + self.redis_shard_ports = redis_shard_ports + self.object_manager_port = object_manager_port + self.node_manager_port = node_manager_port + self.gcs_server_port = gcs_server_port + self.node_ip_address = node_ip_address + self.node_name = node_name + self.raylet_ip_address = raylet_ip_address + self.min_worker_port = min_worker_port + self.max_worker_port = max_worker_port + self.worker_port_list = worker_port_list + self.ray_client_server_port = ray_client_server_port + self.driver_mode = driver_mode + self.redirect_output = redirect_output + self.external_addresses = external_addresses + self.num_redis_shards = num_redis_shards + self.redis_max_clients = redis_max_clients + self.redis_username = redis_username + self.redis_password = redis_password + self.plasma_directory = plasma_directory + self.worker_path = worker_path + self.setup_worker_path = setup_worker_path + self.huge_pages = huge_pages + self.include_dashboard = include_dashboard + self.dashboard_host = dashboard_host + self.dashboard_port = dashboard_port + 
self.dashboard_agent_listen_port = dashboard_agent_listen_port + self.dashboard_grpc_port = dashboard_grpc_port + self.runtime_env_agent_port = runtime_env_agent_port + self.plasma_store_socket_name = plasma_store_socket_name + self.raylet_socket_name = raylet_socket_name + self.temp_dir = temp_dir + self.storage = storage or os.environ.get( + ray_constants.RAY_STORAGE_ENVIRONMENT_VARIABLE + ) + self.runtime_env_dir_name = ( + runtime_env_dir_name or ray_constants.DEFAULT_RUNTIME_ENV_DIR_NAME + ) + self.include_log_monitor = include_log_monitor + self.autoscaling_config = autoscaling_config + self.metrics_agent_port = metrics_agent_port + self.metrics_export_port = metrics_export_port + self.tracing_startup_hook = tracing_startup_hook + self.no_monitor = no_monitor + self.object_ref_seed = object_ref_seed + self.ray_debugger_external = ray_debugger_external + self.env_vars = env_vars + self.session_name = session_name + self.webui = webui + self._system_config = _system_config or {} + self._enable_object_reconstruction = enable_object_reconstruction + self.labels = labels + self._check_usage() + self.cluster_id = cluster_id + self.node_id = node_id + self.enable_physical_mode = enable_physical_mode + + # Set the internal config options for object reconstruction. + if enable_object_reconstruction: + # Turn off object pinning. + if self._system_config is None: + self._system_config = dict() + print(self._system_config) + self._system_config["lineage_pinning_enabled"] = True + + def update(self, **kwargs): + """Update the settings according to the keyword arguments. + + Args: + kwargs: The keyword arguments to set corresponding fields. + """ + for arg in kwargs: + if hasattr(self, arg): + setattr(self, arg, kwargs[arg]) + else: + raise ValueError(f"Invalid RayParams parameter in update: {arg}") + + self._check_usage() + + def update_if_absent(self, **kwargs): + """Update the settings when the target fields are None. 
+ + Args: + kwargs: The keyword arguments to set corresponding fields. + """ + for arg in kwargs: + if hasattr(self, arg): + if getattr(self, arg) is None: + setattr(self, arg, kwargs[arg]) + else: + raise ValueError( + f"Invalid RayParams parameter in update_if_absent: {arg}" + ) + + self._check_usage() + + def update_pre_selected_port(self): + """Update the pre-selected port information + + Returns: + The dictionary mapping of component -> ports. + """ + + def wrap_port(port): + # 0 port means select a random port for the grpc server. + if port is None or port == 0: + return [] + else: + return [port] + + # Create a dictionary of the component -> port mapping. + pre_selected_ports = { + "gcs": wrap_port(self.redis_port), + "object_manager": wrap_port(self.object_manager_port), + "node_manager": wrap_port(self.node_manager_port), + "gcs_server": wrap_port(self.gcs_server_port), + "client_server": wrap_port(self.ray_client_server_port), + "dashboard": wrap_port(self.dashboard_port), + "dashboard_agent_grpc": wrap_port(self.metrics_agent_port), + "dashboard_agent_http": wrap_port(self.dashboard_agent_listen_port), + "dashboard_grpc": wrap_port(self.dashboard_grpc_port), + "runtime_env_agent": wrap_port(self.runtime_env_agent_port), + "metrics_export": wrap_port(self.metrics_export_port), + } + redis_shard_ports = self.redis_shard_ports + if redis_shard_ports is None: + redis_shard_ports = [] + pre_selected_ports["redis_shards"] = redis_shard_ports + if self.worker_port_list is None: + if self.min_worker_port is not None and self.max_worker_port is not None: + pre_selected_ports["worker_ports"] = list( + range(self.min_worker_port, self.max_worker_port + 1) + ) + else: + # The dict is not updated when it requires random ports. + pre_selected_ports["worker_ports"] = [] + else: + pre_selected_ports["worker_ports"] = [ + int(port) for port in self.worker_port_list.split(",") + ] + + # Update the pre selected port set. 
+ self.reserved_ports = set() + for comp, port_list in pre_selected_ports.items(): + for port in port_list: + if port in self.reserved_ports: + raise ValueError( + f"Ray component {comp} is trying to use " + f"a port number {port} that is used by other components.\n" + f"Port information: {self._format_ports(pre_selected_ports)}\n" + "If you allocate ports, please make sure the same port " + "is not used by multiple components." + ) + self.reserved_ports.add(port) + + def _check_usage(self): + if self.worker_port_list is not None: + for port_str in self.worker_port_list.split(","): + try: + port = int(port_str) + except ValueError as e: + raise ValueError( + "worker_port_list must be a comma-separated " + f"list of integers: {e}" + ) from None + + if port < 1024 or port > 65535: + raise ValueError( + "Ports in worker_port_list must be " + f"between 1024 and 65535. Got: {port}" + ) + + # Used primarily for testing. + if os.environ.get("RAY_USE_RANDOM_PORTS", False): + if self.min_worker_port is None and self.max_worker_port is None: + self.min_worker_port = 0 + self.max_worker_port = 0 + + if self.min_worker_port is not None: + if self.min_worker_port != 0 and ( + self.min_worker_port < 1024 or self.min_worker_port > 65535 + ): + raise ValueError( + "min_worker_port must be 0 or an integer between 1024 and 65535." + ) + + if self.max_worker_port is not None: + if self.min_worker_port is None: + raise ValueError( + "If max_worker_port is set, min_worker_port must also be set." + ) + elif self.max_worker_port != 0: + if self.max_worker_port < 1024 or self.max_worker_port > 65535: + raise ValueError( + "max_worker_port must be 0 or an integer between " + "1024 and 65535." + ) + elif self.max_worker_port <= self.min_worker_port: + raise ValueError( + "max_worker_port must be higher than min_worker_port." 
+ ) + if self.ray_client_server_port is not None: + if not check_ray_client_dependencies_installed(): + raise ValueError( + "Ray Client requires pip package `ray[client]`. " + "If you installed the minimal Ray (e.g. `pip install ray`), " + "please reinstall by executing `pip install ray[client]`." + ) + if ( + self.ray_client_server_port < 1024 + or self.ray_client_server_port > 65535 + ): + raise ValueError( + "ray_client_server_port must be an integer " + "between 1024 and 65535." + ) + if self.runtime_env_agent_port is not None: + if ( + self.runtime_env_agent_port < 1024 + or self.runtime_env_agent_port > 65535 + ): + raise ValueError( + "runtime_env_agent_port must be an integer " + "between 1024 and 65535." + ) + + if self.resources is not None: + + def build_error(resource, alternative): + return ( + f"{self.resources} -> `{resource}` cannot be a " + "custom resource because it is one of the default resources " + f"({ray_constants.DEFAULT_RESOURCES}). " + f"Use `{alternative}` instead. 
For example, use `ray start " + f"--{alternative.replace('_', '-')}=1` instead of " + f"`ray start --resources={{'{resource}': 1}}`" + ) + + assert "CPU" not in self.resources, build_error("CPU", "num_cpus") + assert "GPU" not in self.resources, build_error("GPU", "num_gpus") + assert "memory" not in self.resources, build_error("memory", "memory") + assert "object_store_memory" not in self.resources, build_error( + "object_store_memory", "object_store_memory" + ) + + if self.redirect_output is not None: + raise DeprecationWarning("The redirect_output argument is deprecated.") + + if self.temp_dir is not None and not os.path.isabs(self.temp_dir): + raise ValueError("temp_dir must be absolute path or None.") + + validate_node_labels(self.labels) + + def _format_ports(self, pre_selected_ports): + """Format the pre-selected ports information to be more human-readable.""" + ports = pre_selected_ports.copy() + + for comp, port_list in ports.items(): + if len(port_list) == 1: + ports[comp] = port_list[0] + elif len(port_list) == 0: + # Nothing is selected, meaning it will be randomly selected. 
+ ports[comp] = "random" + elif comp == "worker_ports": + min_port = port_list[0] + max_port = port_list[len(port_list) - 1] + if len(port_list) < 50: + port_range_str = str(port_list) + else: + port_range_str = f"from {min_port} to {max_port}" + ports[comp] = f"{len(port_list)} ports {port_range_str}" + return ports diff --git a/.venv/lib/python3.11/site-packages/ray/_private/process_watcher.py b/.venv/lib/python3.11/site-packages/ray/_private/process_watcher.py new file mode 100644 index 0000000000000000000000000000000000000000..f0ff3dd2a3aac43f9f290b5eb620445f9a3c5dba --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/process_watcher.py @@ -0,0 +1,198 @@ +import asyncio +import io +import logging +import sys +import os + +from concurrent.futures import ThreadPoolExecutor + +import ray +from ray.dashboard.consts import _PARENT_DEATH_THREASHOLD +import ray.dashboard.consts as dashboard_consts +import ray._private.ray_constants as ray_constants +from ray._private.utils import run_background_task + +# Import psutil after ray so the packaged version is used. +import psutil + + +logger = logging.getLogger(__name__) + +# TODO: move all consts from dashboard_consts to ray_constants and rename to remove +# DASHBOARD_ prefixes. + +# Publishes at most this number of lines of Raylet logs, when the Raylet dies +# unexpectedly. +_RAYLET_LOG_MAX_PUBLISH_LINES = 20 + +# Reads at most this amount of Raylet logs from the tail, for publishing and +# checking if the Raylet was terminated gracefully. +_RAYLET_LOG_MAX_TAIL_SIZE = 1 * 1024**2 + +try: + create_task = asyncio.create_task +except AttributeError: + create_task = asyncio.ensure_future + + +def get_raylet_pid(): + # TODO(edoakes): RAY_RAYLET_PID isn't properly set on Windows. This is + # only used for fate-sharing with the raylet and we need a different + # fate-sharing mechanism for Windows anyways. 
+ if sys.platform in ["win32", "cygwin"]: + return None + raylet_pid = int(os.environ["RAY_RAYLET_PID"]) + assert raylet_pid > 0 + logger.info("raylet pid is %s", raylet_pid) + return raylet_pid + + +def create_check_raylet_task(log_dir, gcs_address, parent_dead_callback, loop): + """ + Creates an asyncio task to periodically check if the raylet process is still + running. If raylet is dead for _PARENT_DEATH_THREASHOLD (5) times, prepare to exit + as follows: + + - Write logs about whether the raylet exit is graceful, by looking into the raylet + log and search for term "SIGTERM", + - Flush the logs via GcsPublisher, + - Exit. + """ + if sys.platform in ["win32", "cygwin"]: + raise RuntimeError("can't check raylet process in Windows.") + raylet_pid = get_raylet_pid() + + if dashboard_consts.PARENT_HEALTH_CHECK_BY_PIPE: + logger.info("check_parent_via_pipe") + check_parent_task = _check_parent_via_pipe( + log_dir, gcs_address, loop, parent_dead_callback + ) + else: + logger.info("_check_parent") + check_parent_task = _check_parent( + raylet_pid, log_dir, gcs_address, parent_dead_callback + ) + + return run_background_task(check_parent_task) + + +def report_raylet_error_logs(log_dir: str, gcs_address: str): + log_path = os.path.join(log_dir, "raylet.out") + error = False + msg = "Raylet is terminated. " + try: + with open(log_path, "r", encoding="utf-8") as f: + # Seek to _RAYLET_LOG_MAX_TAIL_SIZE from the end if the + # file is larger than that. + f.seek(0, io.SEEK_END) + pos = max(0, f.tell() - _RAYLET_LOG_MAX_TAIL_SIZE) + f.seek(pos, io.SEEK_SET) + # Read remaining logs by lines. + raylet_logs = f.readlines() + # Assume the SIGTERM message must exist within the last + # _RAYLET_LOG_MAX_TAIL_SIZE of the log file. + if any("Raylet received SIGTERM" in line for line in raylet_logs): + msg += "Termination is graceful." + logger.info(msg) + else: + msg += ( + "Termination is unexpected. 
Possible reasons " + "include: (1) SIGKILL by the user or system " + "OOM killer, (2) Invalid memory access from " + "Raylet causing SIGSEGV or SIGBUS, " + "(3) Other termination signals. " + f"Last {_RAYLET_LOG_MAX_PUBLISH_LINES} lines " + "of the Raylet logs:\n" + ) + msg += " " + " ".join( + raylet_logs[-_RAYLET_LOG_MAX_PUBLISH_LINES:] + ) + error = True + except Exception as e: + msg += f"Failed to read Raylet logs at {log_path}: {e}!" + logger.exception(msg) + error = True + if error: + logger.error(msg) + # TODO: switch to async if necessary. + ray._private.utils.publish_error_to_driver( + ray_constants.RAYLET_DIED_ERROR, + msg, + gcs_publisher=ray._raylet.GcsPublisher(address=gcs_address), + ) + else: + logger.info(msg) + + +async def _check_parent_via_pipe( + log_dir: str, gcs_address: str, loop, parent_dead_callback +): + while True: + try: + # Read input asynchronously. + # The parent (raylet) should have redirected its pipe + # to stdin. If we read 0 bytes from stdin, it means + # the process is dead. + with ThreadPoolExecutor(max_workers=1) as executor: + input_data = await loop.run_in_executor( + executor, lambda: sys.stdin.readline() + ) + if len(input_data) == 0: + # cannot read bytes from parent == parent is dead. + parent_dead_callback("_check_parent_via_pipe: The parent is dead.") + report_raylet_error_logs(log_dir, gcs_address) + sys.exit(0) + except Exception as e: + logger.exception( + "raylet health checking is failed. " + f"The agent process may leak. Exception: {e}" + ) + + +async def _check_parent(raylet_pid, log_dir, gcs_address, parent_dead_callback): + """Check if raylet is dead and fate-share if it is.""" + try: + curr_proc = psutil.Process() + parent_death_cnt = 0 + while True: + parent = curr_proc.parent() + # If the parent is dead, it is None. + parent_gone = parent is None + init_assigned_for_parent = False + parent_changed = False + + if parent: + # Sometimes, the parent is changed to the `init` process. 
+ # In this case, the parent.pid is 1. + init_assigned_for_parent = parent.pid == 1 + # Sometimes, the parent is dead, and the pid is reused + # by other processes. In this case, this condition is triggered. + parent_changed = raylet_pid != parent.pid + + if parent_gone or init_assigned_for_parent or parent_changed: + parent_death_cnt += 1 + logger.warning( + f"Raylet is considered dead {parent_death_cnt} X. " + f"If it reaches to {_PARENT_DEATH_THREASHOLD}, the agent " + f"will kill itself. Parent: {parent}, " + f"parent_gone: {parent_gone}, " + f"init_assigned_for_parent: {init_assigned_for_parent}, " + f"parent_changed: {parent_changed}." + ) + if parent_death_cnt < _PARENT_DEATH_THREASHOLD: + await asyncio.sleep( + dashboard_consts.DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_S + ) + continue + + parent_dead_callback("_check_parent: The parent is dead.") + report_raylet_error_logs(log_dir, gcs_address) + sys.exit(0) + else: + parent_death_cnt = 0 + await asyncio.sleep( + dashboard_consts.DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_S + ) + except Exception: + logger.exception("Failed to check parent PID, exiting.") + sys.exit(1) diff --git a/.venv/lib/python3.11/site-packages/ray/_private/profiling.py b/.venv/lib/python3.11/site-packages/ray/_private/profiling.py new file mode 100644 index 0000000000000000000000000000000000000000..0784fb323314345b9115cd2afd382f089885f4a2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/profiling.py @@ -0,0 +1,240 @@ +import os +import json +from collections import defaultdict +from dataclasses import dataclass, asdict +from typing import List, Dict, Union + +import ray + + +class _NullLogSpan: + """A log span context manager that does nothing""" + + def __enter__(self): + pass + + def __exit__(self, type, value, tb): + pass + + +PROFILING_ENABLED = "RAY_PROFILING" in os.environ +NULL_LOG_SPAN = _NullLogSpan() + +# Colors are specified at +# 
https://github.com/catapult-project/catapult/blob/master/tracing/tracing/base/color_scheme.html. # noqa: E501 +_default_color_mapping = defaultdict( + lambda: "generic_work", + { + "worker_idle": "cq_build_abandoned", + "task": "rail_response", + "task:deserialize_arguments": "rail_load", + "task:execute": "rail_animation", + "task:store_outputs": "rail_idle", + "wait_for_function": "detailed_memory_dump", + "ray.get": "good", + "ray.put": "terrible", + "ray.wait": "vsync_highlight_color", + "submit_task": "background_memory_dump", + "fetch_and_run_function": "detailed_memory_dump", + "register_remote_function": "detailed_memory_dump", + }, +) + + +@dataclass(init=True) +class ChromeTracingCompleteEvent: + # https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview#heading=h.lpfof2aylapb # noqa + # The event categories. This is a comma separated list of categories + # for the event. The categories can be used to hide events in + # the Trace Viewer UI. + cat: str + # The string displayed on the event. + name: str + # The identifier for the group of rows that the event + # appears in. + pid: int + # The identifier for the row that the event appears in. + tid: int + # The start time in microseconds. + ts: int + # The duration in microseconds. + dur: int + # This is the name of the color to display the box in. + cname: str + # The extra user-defined data. + args: Dict[str, Union[str, int]] + # The event type (X means the complete event). + ph: str = "X" + + +@dataclass(init=True) +class ChromeTracingMetadataEvent: + # https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview#bookmark=id.iycbnb4z7i9g # noqa + name: str + # Metadata arguments. E.g., name: + args: Dict[str, str] + # The process id of this event. In Ray, pid indicates the node. + pid: int + # The thread id of this event. In Ray, tid indicates each worker. + tid: int = None + # M means the metadata event. 
+ ph: str = "M" + + +def profile(event_type, extra_data=None): + """Profile a span of time so that it appears in the timeline visualization. + + Note that this only works in the raylet code path. + + This function can be used as follows (both on the driver or within a task). + + .. testcode:: + import ray._private.profiling as profiling + + with profiling.profile("custom event", extra_data={'key': 'val'}): + # Do some computation here. + x = 1 * 2 + + Optionally, a dictionary can be passed as the "extra_data" argument, and + it can have keys "name" and "cname" if you want to override the default + timeline display text and box color. Other values will appear at the bottom + of the chrome tracing GUI when you click on the box corresponding to this + profile span. + + Args: + event_type: A string describing the type of the event. + extra_data: This must be a dictionary mapping strings to strings. This + data will be added to the json objects that are used to populate + the timeline, so if you want to set a particular color, you can + simply set the "cname" attribute to an appropriate color. + Similarly, if you set the "name" attribute, then that will set the + text displayed on the box in the timeline. + + Returns: + An object that can profile a span of time via a "with" statement. + """ + if not PROFILING_ENABLED: + return NULL_LOG_SPAN + worker = ray._private.worker.global_worker + if worker.mode == ray._private.worker.LOCAL_MODE: + return NULL_LOG_SPAN + return worker.core_worker.profile_event(event_type.encode("ascii"), extra_data) + + +def chrome_tracing_dump( + tasks: List[dict], +) -> str: + """Generate a chrome/perfetto tracing dump using task events. + + Args: + tasks: List of tasks generated by a state API list_tasks(detail=True). + + Returns: + Json serialized dump to create a chrome/perfetto tracing. + """ + # All events from given tasks. + all_events = [] + + # Chrome tracing doesn't have a concept of "node". 
Instead, we use + # chrome tracing's pid == ray's node. + # chrome tracing's tid == ray's process. + # Note that pid or tid is usually integer, but ray's node/process has + # ids in string. + # Unfortunately, perfetto doesn't allow to have string as a value of pid/tid. + # To workaround it, we use Metadata event from chrome tracing schema + # (https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview#heading=h.xqopa5m0e28f) # noqa + # which allows pid/tid -> name mapping. In order to use this schema + # we build node_ip/(node_ip, worker_id) -> arbitrary index mapping. + + # node ip address -> node idx. + node_to_index = {} + # Arbitrary index mapped to the ip address. + node_idx = 0 + # (node index, worker id) -> worker idx + worker_to_index = {} + # Arbitrary index mapped to the (node index, worker id). + worker_idx = 0 + + for task in tasks: + profiling_data = task.get("profiling_data", []) + if profiling_data: + node_ip_address = profiling_data["node_ip_address"] + component_events = profiling_data["events"] + component_type = profiling_data["component_type"] + component_id = component_type + ":" + profiling_data["component_id"] + + if component_type not in ["worker", "driver"]: + continue + + for event in component_events: + extra_data = event["extra_data"] + # Propagate extra data. + extra_data["task_id"] = task["task_id"] + extra_data["job_id"] = task["job_id"] + extra_data["attempt_number"] = task["attempt_number"] + extra_data["func_or_class_name"] = task["func_or_class_name"] + extra_data["actor_id"] = task["actor_id"] + event_name = event["event_name"] + + # build a id -> arbitrary index mapping + if node_ip_address not in node_to_index: + node_to_index[node_ip_address] = node_idx + # Whenever new node ip is introduced, we increment the index. 
+ node_idx += 1 + + if ( + node_to_index[node_ip_address], + component_id, + ) not in worker_to_index: # noqa + worker_to_index[ + (node_to_index[node_ip_address], component_id) + ] = worker_idx # noqa + worker_idx += 1 + + # Modify the name with the additional user-defined extra data. + cname = _default_color_mapping[event["event_name"]] + name = event_name + + if "cname" in extra_data: + cname = _default_color_mapping[event["extra_data"]["cname"]] + if "name" in extra_data: + name = extra_data["name"] + + new_event = ChromeTracingCompleteEvent( + cat=event_name, + name=name, + pid=node_to_index[node_ip_address], + tid=worker_to_index[(node_to_index[node_ip_address], component_id)], + ts=event["start_time"] * 1e3, + dur=(event["end_time"] * 1e3) - (event["start_time"] * 1e3), + cname=cname, + args=extra_data, + ) + all_events.append(asdict(new_event)) + + for node, i in node_to_index.items(): + all_events.append( + asdict( + ChromeTracingMetadataEvent( + name="process_name", + pid=i, + args={"name": f"Node {node}"}, + ) + ) + ) + + for worker, i in worker_to_index.items(): + all_events.append( + asdict( + ChromeTracingMetadataEvent( + name="thread_name", + ph="M", + tid=i, + pid=worker[0], + args={"name": worker[1]}, + ) + ) + ) + + # Handle task event disabled. + return json.dumps(all_events) diff --git a/.venv/lib/python3.11/site-packages/ray/_private/prometheus_exporter.py b/.venv/lib/python3.11/site-packages/ray/_private/prometheus_exporter.py new file mode 100644 index 0000000000000000000000000000000000000000..28d09861bee5ae5c275e94dbfdfe8fad9b788c5f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/prometheus_exporter.py @@ -0,0 +1,365 @@ +# NOTE: This file has been copied from OpenCensus Python exporter. +# It is because OpenCensus Prometheus exporter hasn't released for a while +# and the latest version has a compatibility issue with the latest OpenCensus +# library. 
+ +import re + +from prometheus_client import start_http_server +from prometheus_client.core import ( + REGISTRY, + CounterMetricFamily, + GaugeMetricFamily, + HistogramMetricFamily, + UnknownMetricFamily, +) + +from opencensus.common.transports import sync +from opencensus.stats import aggregation_data as aggregation_data_module +from opencensus.stats import base_exporter +import logging + +logger = logging.getLogger(__name__) + + +class Options(object): + """Options contains options for configuring the exporter. + The address can be empty as the prometheus client will + assume it's localhost + :type namespace: str + :param namespace: The prometheus namespace to be used. Defaults to ''. + :type port: int + :param port: The Prometheus port to be used. Defaults to 8000. + :type address: str + :param address: The Prometheus address to be used. Defaults to ''. + :type registry: registry + :param registry: The Prometheus address to be used. Defaults to ''. + :type registry: :class:`~prometheus_client.core.CollectorRegistry` + :param registry: A Prometheus collector registry instance. 
+ """ + + def __init__(self, namespace="", port=8000, address="", registry=REGISTRY): + self._namespace = namespace + self._registry = registry + self._port = int(port) + self._address = address + + @property + def registry(self): + """Prometheus Collector Registry instance""" + return self._registry + + @property + def namespace(self): + """Prefix to be used with view name""" + return self._namespace + + @property + def port(self): + """Port number to listen""" + return self._port + + @property + def address(self): + """Endpoint address (default is localhost)""" + return self._address + + +class Collector(object): + """Collector represents the Prometheus Collector object""" + + def __init__(self, options=Options(), view_name_to_data_map=None): + if view_name_to_data_map is None: + view_name_to_data_map = {} + self._options = options + self._registry = options.registry + self._view_name_to_data_map = view_name_to_data_map + self._registered_views = {} + + @property + def options(self): + """Options to be used to configure the exporter""" + return self._options + + @property + def registry(self): + """Prometheus Collector Registry instance""" + return self._registry + + @property + def view_name_to_data_map(self): + """Map with all view data objects + that will be sent to Prometheus + """ + return self._view_name_to_data_map + + @property + def registered_views(self): + """Map with all registered views""" + return self._registered_views + + def register_view(self, view): + """register_view will create the needed structure + in order to be able to sent all data to Prometheus + """ + v_name = get_view_name(self.options.namespace, view) + + if v_name not in self.registered_views: + desc = { + "name": v_name, + "documentation": view.description, + "labels": list(map(sanitize, view.columns)), + "units": view.measure.unit, + } + self.registered_views[v_name] = desc + + def add_view_data(self, view_data): + """Add view data object to be sent to server""" + 
self.register_view(view_data.view) + v_name = get_view_name(self.options.namespace, view_data.view) + self.view_name_to_data_map[v_name] = view_data + + # TODO: add start and end timestamp + def to_metric(self, desc, tag_values, agg_data, metrics_map): + """to_metric translate the data that OpenCensus create + to Prometheus format, using Prometheus Metric object + :type desc: dict + :param desc: The map that describes view definition + :type tag_values: tuple of :class: + `~opencensus.tags.tag_value.TagValue` + :param object of opencensus.tags.tag_value.TagValue: + TagValue object used as label values + :type agg_data: object of :class: + `~opencensus.stats.aggregation_data.AggregationData` + :param object of opencensus.stats.aggregation_data.AggregationData: + Aggregated data that needs to be converted as Prometheus samples + :rtype: :class:`~prometheus_client.core.CounterMetricFamily` or + :class:`~prometheus_client.core.HistogramMetricFamily` or + :class:`~prometheus_client.core.UnknownMetricFamily` or + :class:`~prometheus_client.core.GaugeMetricFamily` + """ + metric_name = desc["name"] + metric_description = desc["documentation"] + label_keys = desc["labels"] + metric_units = desc["units"] + assert len(tag_values) == len(label_keys), (tag_values, label_keys) + # Prometheus requires that all tag values be strings hence + # the need to cast none to the empty string before exporting. 
See + # https://github.com/census-instrumentation/opencensus-python/issues/480 + tag_values = [tv if tv else "" for tv in tag_values] + + if isinstance(agg_data, aggregation_data_module.CountAggregationData): + metric = metrics_map.get(metric_name) + if not metric: + metric = CounterMetricFamily( + name=metric_name, + documentation=metric_description, + unit=metric_units, + labels=label_keys, + ) + metrics_map[metric_name] = metric + metric.add_metric(labels=tag_values, value=agg_data.count_data) + return + + elif isinstance(agg_data, aggregation_data_module.DistributionAggregationData): + + assert agg_data.bounds == sorted(agg_data.bounds) + # buckets are a list of buckets. Each bucket is another list with + # a pair of bucket name and value, or a triple of bucket name, + # value, and exemplar. buckets need to be in order. + buckets = [] + cum_count = 0 # Prometheus buckets expect cumulative count. + for ii, bound in enumerate(agg_data.bounds): + cum_count += agg_data.counts_per_bucket[ii] + bucket = [str(bound), cum_count] + buckets.append(bucket) + # Prometheus requires buckets to be sorted, and +Inf present. + # In OpenCensus we don't have +Inf in the bucket bonds so need to + # append it here. 
+ buckets.append(["+Inf", agg_data.count_data]) + metric = metrics_map.get(metric_name) + if not metric: + metric = HistogramMetricFamily( + name=metric_name, + documentation=metric_description, + labels=label_keys, + ) + metrics_map[metric_name] = metric + metric.add_metric( + labels=tag_values, + buckets=buckets, + sum_value=agg_data.sum, + ) + return + + elif isinstance(agg_data, aggregation_data_module.SumAggregationData): + metric = metrics_map.get(metric_name) + if not metric: + metric = UnknownMetricFamily( + name=metric_name, + documentation=metric_description, + labels=label_keys, + ) + metrics_map[metric_name] = metric + metric.add_metric(labels=tag_values, value=agg_data.sum_data) + return + + elif isinstance(agg_data, aggregation_data_module.LastValueAggregationData): + metric = metrics_map.get(metric_name) + if not metric: + metric = GaugeMetricFamily( + name=metric_name, + documentation=metric_description, + labels=label_keys, + ) + metrics_map[metric_name] = metric + metric.add_metric(labels=tag_values, value=agg_data.value) + return + + else: + raise ValueError(f"unsupported aggregation type {type(agg_data)}") + + def collect(self): # pragma: NO COVER + """Collect fetches the statistics from OpenCensus + and delivers them as Prometheus Metrics. + Collect is invoked every time a prometheus.Gatherer is run + for example when the HTTP endpoint is invoked by Prometheus. + """ + # Make a shallow copy of self._view_name_to_data_map, to avoid seeing + # concurrent modifications when iterating through the dictionary. 
+ metrics_map = {} + for v_name, view_data in self._view_name_to_data_map.copy().items(): + if v_name not in self.registered_views: + continue + desc = self.registered_views[v_name] + for tag_values in view_data.tag_value_aggregation_data_map: + agg_data = view_data.tag_value_aggregation_data_map[tag_values] + self.to_metric(desc, tag_values, agg_data, metrics_map) + + for metric in metrics_map.values(): + yield metric + + +class PrometheusStatsExporter(base_exporter.StatsExporter): + """Exporter exports stats to Prometheus, users need + to register the exporter as an HTTP Handler to be + able to export. + :type options: + :class:`~opencensus.ext.prometheus.stats_exporter.Options` + :param options: An options object with the parameters to instantiate the + prometheus exporter. + :type gatherer: :class:`~prometheus_client.core.CollectorRegistry` + :param gatherer: A Prometheus collector registry instance. + :type transport: + :class:`opencensus.common.transports.sync.SyncTransport` or + :class:`opencensus.common.transports.async_.AsyncTransport` + :param transport: An instance of a Transpor to send data with. + :type collector: + :class:`~opencensus.ext.prometheus.stats_exporter.Collector` + :param collector: An instance of the Prometheus Collector object. + """ + + def __init__( + self, options, gatherer, transport=sync.SyncTransport, collector=Collector() + ): + self._options = options + self._gatherer = gatherer + self._collector = collector + self._transport = transport(self) + self.serve_http() + REGISTRY.register(self._collector) + + @property + def transport(self): + """The transport way to be sent data to server + (default is sync). 
+ """ + return self._transport + + @property + def collector(self): + """Collector class instance to be used + to communicate with Prometheus + """ + return self._collector + + @property + def gatherer(self): + """Prometheus Collector Registry instance""" + return self._gatherer + + @property + def options(self): + """Options to be used to configure the exporter""" + return self._options + + def export(self, view_data): + """export send the data to the transport class + in order to be sent to Prometheus in a sync or async way. + """ + if view_data is not None: # pragma: NO COVER + self.transport.export(view_data) + + def on_register_view(self, view): + return NotImplementedError("Not supported by Prometheus") + + def emit(self, view_data): # pragma: NO COVER + """Emit exports to the Prometheus if view data has one or more rows. + Each OpenCensus AggregationData will be converted to + corresponding Prometheus Metric: SumData will be converted + to Untyped Metric, CountData will be a Counter Metric + DistributionData will be a Histogram Metric. + """ + + for v_data in view_data: + if v_data.tag_value_aggregation_data_map is None: + v_data.tag_value_aggregation_data_map = {} + + self.collector.add_view_data(v_data) + + def serve_http(self): + """serve_http serves the Prometheus endpoint.""" + address = str(self.options.address) + kwargs = {"addr": address} if address else {} + start_http_server(port=self.options.port, **kwargs) + + +def new_stats_exporter(option): + """new_stats_exporter returns an exporter + that exports stats to Prometheus. 
+ """ + if option.namespace == "": + raise ValueError("Namespace can not be empty string.") + + collector = new_collector(option) + + exporter = PrometheusStatsExporter( + options=option, gatherer=option.registry, collector=collector + ) + return exporter + + +def new_collector(options): + """new_collector should be used + to create instance of Collector class in order to + prevent the usage of constructor directly + """ + return Collector(options=options) + + +def get_view_name(namespace, view): + """create the name for the view""" + name = "" + if namespace != "": + name = namespace + "_" + return sanitize(name + view.name) + + +_NON_LETTERS_NOR_DIGITS_RE = re.compile(r"[^\w]", re.UNICODE | re.IGNORECASE) + + +def sanitize(key): + """sanitize the given metric name or label according to Prometheus rule. + Replace all characters other than [A-Za-z0-9_] with '_'. + """ + return _NON_LETTERS_NOR_DIGITS_RE.sub("_", key) diff --git a/.venv/lib/python3.11/site-packages/ray/_private/protobuf_compat.py b/.venv/lib/python3.11/site-packages/ray/_private/protobuf_compat.py new file mode 100644 index 0000000000000000000000000000000000000000..d394f3d0191251d3bc0836c18aacf6ed4313dc66 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/protobuf_compat.py @@ -0,0 +1,46 @@ +from google.protobuf.json_format import MessageToDict +import inspect + +""" +This module provides a compatibility layer for different versions of the protobuf +library. 
+""" + +_protobuf_has_old_arg_name_cached = None + + +def _protobuf_has_old_arg_name(): + """Cache the inspect result to avoid doing it for every single message.""" + global _protobuf_has_old_arg_name_cached + if _protobuf_has_old_arg_name_cached is None: + params = inspect.signature(MessageToDict).parameters + _protobuf_has_old_arg_name_cached = "including_default_value_fields" in params + return _protobuf_has_old_arg_name_cached + + +def rename_always_print_fields_with_no_presence(kwargs): + """ + Protobuf version 5.26.0rc2 renamed argument for `MessageToDict`: + `including_default_value_fields` -> `always_print_fields_with_no_presence`. + See https://github.com/protocolbuffers/protobuf/commit/06e7caba58ede0220b110b89d08f329e5f8a7537#diff-8de817c14d6a087981503c9aea38730b1b3e98f4e306db5ff9d525c7c304f234L129 # noqa: E501 + + We choose to always use the new argument name. If user used the old arg, we raise an + error. + + If protobuf does not have the new arg name but have the old arg name, we rename our + arg to the old one. 
+ """ + old_arg_name = "including_default_value_fields" + new_arg_name = "always_print_fields_with_no_presence" + if old_arg_name in kwargs: + raise ValueError(f"{old_arg_name} is deprecated, please use {new_arg_name}") + + if new_arg_name in kwargs and _protobuf_has_old_arg_name(): + kwargs[old_arg_name] = kwargs.pop(new_arg_name) + + return kwargs + + +def message_to_dict(*args, **kwargs): + kwargs = rename_always_print_fields_with_no_presence(kwargs) + return MessageToDict(*args, **kwargs) diff --git a/.venv/lib/python3.11/site-packages/ray/_private/pydantic_compat.py b/.venv/lib/python3.11/site-packages/ray/_private/pydantic_compat.py new file mode 100644 index 0000000000000000000000000000000000000000..b405e64ffa8ff77aa3eeb694961bbe141e3a6dd3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/pydantic_compat.py @@ -0,0 +1,108 @@ +# ruff: noqa +import packaging.version + +# Pydantic is a dependency of `ray["default"]` but not the minimal installation, +# so handle the case where it isn't installed. +try: + import pydantic + + PYDANTIC_INSTALLED = True +except ImportError: + pydantic = None + PYDANTIC_INSTALLED = False + + +if not PYDANTIC_INSTALLED: + IS_PYDANTIC_2 = False + BaseModel = None + Extra = None + Field = None + NonNegativeFloat = None + NonNegativeInt = None + PositiveFloat = None + PositiveInt = None + PrivateAttr = None + StrictInt = None + ValidationError = None + root_validator = None + validator = None + is_subclass_of_base_model = lambda obj: False +# In pydantic <1.9.0, __version__ attribute is missing, issue ref: +# https://github.com/pydantic/pydantic/issues/2572, so we need to check +# the existence prior to comparison. 
+elif not hasattr(pydantic, "__version__") or packaging.version.parse( + pydantic.__version__ +) < packaging.version.parse("2.0"): + IS_PYDANTIC_2 = False + from pydantic import ( + BaseModel, + Extra, + Field, + NonNegativeFloat, + NonNegativeInt, + PositiveFloat, + PositiveInt, + PrivateAttr, + StrictInt, + ValidationError, + root_validator, + validator, + ) + + def is_subclass_of_base_model(obj): + return issubclass(obj, BaseModel) + +else: + IS_PYDANTIC_2 = True + from pydantic.v1 import ( + BaseModel, + Extra, + Field, + NonNegativeFloat, + NonNegativeInt, + PositiveFloat, + PositiveInt, + PrivateAttr, + StrictInt, + ValidationError, + root_validator, + validator, + ) + + def is_subclass_of_base_model(obj): + from pydantic import BaseModel as BaseModelV2 + from pydantic.v1 import BaseModel as BaseModelV1 + + return issubclass(obj, BaseModelV1) or issubclass(obj, BaseModelV2) + + +def register_pydantic_serializers(serialization_context): + if not PYDANTIC_INSTALLED: + return + + if IS_PYDANTIC_2: + # TODO(edoakes): compare against the version that has the fixes. + from pydantic.v1.fields import ModelField + else: + from pydantic.fields import ModelField + + # Pydantic's Cython validators are not serializable. + # https://github.com/cloudpipe/cloudpickle/issues/408 + serialization_context._register_cloudpickle_serializer( + ModelField, + custom_serializer=lambda o: { + "name": o.name, + # outer_type_ is the original type for ModelFields, + # while type_ can be updated later with the nested type + # like int for List[int]. 
+ "type_": o.outer_type_, + "class_validators": o.class_validators, + "model_config": o.model_config, + "default": o.default, + "default_factory": o.default_factory, + "required": o.required, + "alias": o.alias, + "field_info": o.field_info, + }, + custom_deserializer=lambda kwargs: ModelField(**kwargs), + ) diff --git a/.venv/lib/python3.11/site-packages/ray/_private/ray_client_microbenchmark.py b/.venv/lib/python3.11/site-packages/ray/_private/ray_client_microbenchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..f414d8cfee32de932afd7d6c561149f5034fbfbb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/ray_client_microbenchmark.py @@ -0,0 +1,117 @@ +import inspect +import logging +import numpy as np +import sys + +from ray.util.client.ray_client_helpers import ray_start_client_server + +from ray._private.ray_microbenchmark_helpers import timeit + + +def benchmark_get_calls(ray, results): + value = ray.put(0) + + def get_small(): + ray.get(value) + + results += timeit("client: get calls", get_small) + + +def benchmark_tasks_and_get_batch(ray, results): + @ray.remote + def small_value(): + return b"ok" + + def small_value_batch(): + submitted = [small_value.remote() for _ in range(1000)] + ray.get(submitted) + return 0 + + results += timeit("client: tasks and get batch", small_value_batch) + + +def benchmark_put_calls(ray, results): + def put_small(): + ray.put(0) + + results += timeit("client: put calls", put_small) + + +def benchmark_remote_put_calls(ray, results): + @ray.remote + def do_put_small(): + for _ in range(100): + ray.put(0) + + def put_multi_small(): + ray.get([do_put_small.remote() for _ in range(10)]) + + results += timeit("client: tasks and put batch", put_multi_small, 1000) + + +def benchmark_put_large(ray, results): + arr = np.zeros(100 * 1024 * 1024, dtype=np.int64) + + def put_large(): + ray.put(arr) + + results += timeit("client: put gigabytes", put_large, 8 * 0.1) + + +def 
benchmark_simple_actor(ray, results): + @ray.remote(num_cpus=0) + class Actor: + def small_value(self): + return b"ok" + + def small_value_arg(self, x): + return b"ok" + + def small_value_batch(self, n): + ray.get([self.small_value.remote() for _ in range(n)]) + + a = Actor.remote() + + def actor_sync(): + ray.get(a.small_value.remote()) + + results += timeit("client: 1:1 actor calls sync", actor_sync) + + def actor_async(): + ray.get([a.small_value.remote() for _ in range(1000)]) + + results += timeit("client: 1:1 actor calls async", actor_async, 1000) + + a = Actor.options(max_concurrency=16).remote() + + def actor_concurrent(): + ray.get([a.small_value.remote() for _ in range(1000)]) + + results += timeit("client: 1:1 actor calls concurrent", actor_concurrent, 1000) + + +def main(results=None): + results = results or [] + + ray_config = {"logging_level": logging.WARNING} + + def ray_connect_handler(job_config=None, **ray_init_kwargs): + from ray._private.client_mode_hook import disable_client_hook + + with disable_client_hook(): + import ray as real_ray + + if not real_ray.is_initialized(): + real_ray.init(**ray_config) + + for name, obj in inspect.getmembers(sys.modules[__name__]): + if not name.startswith("benchmark_"): + continue + with ray_start_client_server(ray_connect_handler=ray_connect_handler) as ray: + obj(ray, results) + + return results + + +if __name__ == "__main__": + main() diff --git a/.venv/lib/python3.11/site-packages/ray/_private/ray_cluster_perf.py b/.venv/lib/python3.11/site-packages/ray/_private/ray_cluster_perf.py new file mode 100644 index 0000000000000000000000000000000000000000..82abcb4bbeadabc76dbb033d0eba29d3d155b0bc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/ray_cluster_perf.py @@ -0,0 +1,50 @@ +"""This is the script for `ray clusterbenchmark`.""" + +import time +import numpy as np +import ray + +from ray.cluster_utils import Cluster + + +def main(): + cluster = Cluster( + initialize_head=True, + 
connect=True, + head_node_args={"object_store_memory": 20 * 1024 * 1024 * 1024, "num_cpus": 16}, + ) + cluster.add_node( + object_store_memory=20 * 1024 * 1024 * 1024, num_gpus=1, num_cpus=16 + ) + + object_ref_list = [] + for i in range(0, 10): + object_ref = ray.put(np.random.rand(1024 * 128, 1024)) + object_ref_list.append(object_ref) + + @ray.remote(num_gpus=1) + def f(object_ref_list): + diffs = [] + for object_ref in object_ref_list: + before = time.time() + ray.get(object_ref) + after = time.time() + diffs.append(after - before) + time.sleep(1) + return np.mean(diffs), np.std(diffs) + + time_diff, time_diff_std = ray.get(f.remote(object_ref_list)) + + print( + "latency to get an 1G object over network", + round(time_diff, 2), + "+-", + round(time_diff_std, 2), + ) + + ray.shutdown() + cluster.shutdown() + + +if __name__ == "__main__": + main() diff --git a/.venv/lib/python3.11/site-packages/ray/_private/ray_constants.py b/.venv/lib/python3.11/site-packages/ray/_private/ray_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..166fc42791b77d2bc23907b8054f88804627e143 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/ray_constants.py @@ -0,0 +1,554 @@ +"""Ray constants used in the Python code.""" + +import logging +import os +import sys +import json + +logger = logging.getLogger(__name__) + + +def env_integer(key, default): + if key in os.environ: + value = os.environ[key] + if value.isdigit(): + return int(os.environ[key]) + + logger.debug( + f"Found {key} in environment, but value must " + f"be an integer. Got: {value}. Returning " + f"provided default {default}." + ) + return default + return default + + +def env_float(key, default): + if key in os.environ: + value = os.environ[key] + try: + return float(value) + except ValueError: + logger.debug( + f"Found {key} in environment, but value must " + f"be a float. Got: {value}. Returning " + f"provided default {default}." 
+ ) + return default + return default + + +def env_bool(key, default): + if key in os.environ: + return ( + True + if os.environ[key].lower() == "true" or os.environ[key] == "1" + else False + ) + return default + + +def env_set_by_user(key): + return key in os.environ + + +# Whether event logging to driver is enabled. Set to 0 to disable. +AUTOSCALER_EVENTS = env_integer("RAY_SCHEDULER_EVENTS", 1) + +RAY_LOG_TO_DRIVER = env_bool("RAY_LOG_TO_DRIVER", True) + +# Filter level under which events will be filtered out, i.e. not printing to driver +RAY_LOG_TO_DRIVER_EVENT_LEVEL = os.environ.get("RAY_LOG_TO_DRIVER_EVENT_LEVEL", "INFO") + +# Internal kv keys for storing monitor debug status. +DEBUG_AUTOSCALING_ERROR = "__autoscaling_error" +DEBUG_AUTOSCALING_STATUS = "__autoscaling_status" +DEBUG_AUTOSCALING_STATUS_LEGACY = "__autoscaling_status_legacy" + +ID_SIZE = 28 + +# The default maximum number of bytes to allocate to the object store unless +# overridden by the user. +DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES = env_integer( + "RAY_DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES", 200 * 10**9 # 200 GB +) +# The default proportion of available memory allocated to the object store +DEFAULT_OBJECT_STORE_MEMORY_PROPORTION = env_float( + "RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION", 0.3 +) +# The smallest cap on the memory used by the object store that we allow. +# This must be greater than MEMORY_RESOURCE_UNIT_BYTES +OBJECT_STORE_MINIMUM_MEMORY_BYTES = 75 * 1024 * 1024 +# Each ObjectRef currently uses about 3KB of caller memory. +CALLER_MEMORY_USAGE_PER_OBJECT_REF = 3000 +# Match max_direct_call_object_size in +# src/ray/common/ray_config_def.h. +# TODO(swang): Ideally this should be pulled directly from the +# config in case the user overrides it. +DEFAULT_MAX_DIRECT_CALL_OBJECT_SIZE = 100 * 1024 +# The default maximum number of bytes that the non-primary Redis shards are +# allowed to use unless overridden by the user. 
+DEFAULT_REDIS_MAX_MEMORY_BYTES = 10**10 +# The smallest cap on the memory used by Redis that we allow. +REDIS_MINIMUM_MEMORY_BYTES = 10**7 +# Above this number of bytes, raise an error by default unless the user sets +# RAY_ALLOW_SLOW_STORAGE=1. This avoids swapping with large object stores. +REQUIRE_SHM_SIZE_THRESHOLD = 10**10 +# Mac with 16GB memory has degraded performance when the object store size is +# greater than 2GB. +# (see https://github.com/ray-project/ray/issues/20388 for details) +# The workaround here is to limit capacity to 2GB for Mac by default, +# and raise error if the capacity is overwritten by user. +MAC_DEGRADED_PERF_MMAP_SIZE_LIMIT = 2 * 2**30 +# If a user does not specify a port for the primary Ray service, +# we attempt to start the service running at this port. +DEFAULT_PORT = 6379 + +RAY_ADDRESS_ENVIRONMENT_VARIABLE = "RAY_ADDRESS" +RAY_NAMESPACE_ENVIRONMENT_VARIABLE = "RAY_NAMESPACE" +RAY_RUNTIME_ENV_ENVIRONMENT_VARIABLE = "RAY_RUNTIME_ENV" +RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_ENV_VAR = ( + "RAY_RUNTIME_ENV_TEMPORARY_REFERENCE_EXPIRATION_S" +) +# Ray populates this env var to the working dir in the creation of a runtime env. +# For example, `pip` and `conda` users can use this environment variable to locate the +# `requirements.txt` file. +RAY_RUNTIME_ENV_CREATE_WORKING_DIR_ENV_VAR = "RAY_RUNTIME_ENV_CREATE_WORKING_DIR" +# Defaults to 10 minutes. This should be longer than the total time it takes for +# the local working_dir and py_modules to be uploaded, or these files might get +# garbage collected before the job starts. +RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_DEFAULT = 10 * 60 +# If set to 1, then `.gitignore` files will not be parsed and loaded into "excludes" +# when using a local working_dir or py_modules. +RAY_RUNTIME_ENV_IGNORE_GITIGNORE = "RAY_RUNTIME_ENV_IGNORE_GITIGNORE" +RAY_STORAGE_ENVIRONMENT_VARIABLE = "RAY_STORAGE" +# Hook for running a user-specified runtime-env hook. 
This hook will be called +# unconditionally given the runtime_env dict passed for ray.init. It must return +# a rewritten runtime_env dict. Example: "your.module.runtime_env_hook". +RAY_RUNTIME_ENV_HOOK = "RAY_RUNTIME_ENV_HOOK" +# Hook that is invoked on `ray start`. It will be given the cluster parameters and +# whether we are the head node as arguments. The function can modify the params class, +# but otherwise returns void. Example: "your.module.ray_start_hook". +RAY_START_HOOK = "RAY_START_HOOK" +# Hook that is invoked on `ray job submit`. It will be given all the same args as the +# job.cli.submit() function gets, passed as kwargs to this function. +RAY_JOB_SUBMIT_HOOK = "RAY_JOB_SUBMIT_HOOK" +# Headers to pass when using the Job CLI. It will be given to +# instantiate a Job SubmissionClient. +RAY_JOB_HEADERS = "RAY_JOB_HEADERS" + +DEFAULT_DASHBOARD_IP = "127.0.0.1" +DEFAULT_DASHBOARD_PORT = 8265 +DASHBOARD_ADDRESS = "dashboard" +PROMETHEUS_SERVICE_DISCOVERY_FILE = "prom_metrics_service_discovery.json" +DEFAULT_DASHBOARD_AGENT_LISTEN_PORT = 52365 +# Default resource requirements for actors when no resource requirements are +# specified. +DEFAULT_ACTOR_METHOD_CPU_SIMPLE = 1 +DEFAULT_ACTOR_CREATION_CPU_SIMPLE = 0 +# Default resource requirements for actors when some resource requirements are +# specified in . +DEFAULT_ACTOR_METHOD_CPU_SPECIFIED = 0 +DEFAULT_ACTOR_CREATION_CPU_SPECIFIED = 1 +# Default number of return values for each actor method. +DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS = 1 + +# Wait 30 seconds for client to reconnect after unexpected disconnection +DEFAULT_CLIENT_RECONNECT_GRACE_PERIOD = 30 + +# If a remote function or actor (or some other export) has serialized size +# greater than this quantity, print an warning. +FUNCTION_SIZE_WARN_THRESHOLD = 10**7 +FUNCTION_SIZE_ERROR_THRESHOLD = env_integer("FUNCTION_SIZE_ERROR_THRESHOLD", (10**8)) + +# If remote functions with the same source are imported this many times, then +# print a warning. 
DUPLICATE_REMOTE_FUNCTION_THRESHOLD = 100

# Hard ceiling on any single resource quantity. TODO(rkn): this could be
# relaxed, but today the node manager gets slower for very large quantities
# because it bookkeeps individual resource IDs.
MAX_RESOURCE_QUANTITY = 100e12

# Smallest fraction into which one unit of a resource may be subdivided.
MIN_RESOURCE_GRANULARITY = 0.0001

# Name of the env var that substitutes an externally hosted Ray dashboard URL
# (e.g. when the dashboard sits behind a proxy or load balancer). The override
# only affects URLs returned or printed to users through public APIs; the
# internal KV store keeps the real address.
RAY_OVERRIDE_DASHBOARD_URL = "RAY_OVERRIDE_DASHBOARD_URL"


# Tags for the different Ray error types that can be pushed to the driver.
# TODO(rkn): these should be defined in flatbuffers and must stay in sync
# with the existing C++ definitions.
PICKLING_LARGE_OBJECT_PUSH_ERROR = "pickling_large_object"
WAIT_FOR_FUNCTION_PUSH_ERROR = "wait_for_function"
VERSION_MISMATCH_PUSH_ERROR = "version_mismatch"
WORKER_CRASH_PUSH_ERROR = "worker_crash"
WORKER_DIED_PUSH_ERROR = "worker_died"
WORKER_POOL_LARGE_ERROR = "worker_pool_large"
PUT_RECONSTRUCTION_PUSH_ERROR = "put_reconstruction"
RESOURCE_DEADLOCK_ERROR = "resource_deadlock"
REMOVED_NODE_ERROR = "node_removed"
MONITOR_DIED_ERROR = "monitor_died"
LOG_MONITOR_DIED_ERROR = "log_monitor_died"
DASHBOARD_AGENT_DIED_ERROR = "dashboard_agent_died"
DASHBOARD_DIED_ERROR = "dashboard_died"
RAYLET_DIED_ERROR = "raylet_died"
DETACHED_ACTOR_ANONYMOUS_NAMESPACE_ERROR = "detached_actor_anonymous_namespace"
EXCESS_QUEUEING_WARNING = "excess_queueing_warning"

# Resource-name prefix used in GPU accelerator-type detection.
RESOURCE_CONSTRAINT_PREFIX = "accelerator_type:"

# Env vars used by the autoscaler to set node custom resources and labels
# from cluster.yaml.
+RESOURCES_ENVIRONMENT_VARIABLE = "RAY_OVERRIDE_RESOURCES" +LABELS_ENVIRONMENT_VARIABLE = "RAY_OVERRIDE_LABELS" + +# Temporary flag to disable log processing in the dashboard. This is useful +# if the dashboard is overloaded by logs and failing to process other +# dashboard API requests (e.g. Job Submission). +DISABLE_DASHBOARD_LOG_INFO = env_integer("RAY_DISABLE_DASHBOARD_LOG_INFO", 0) + +LOGGER_FORMAT = "%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s -- %(message)s" +LOGGER_FORMAT_ESCAPE = json.dumps(LOGGER_FORMAT.replace("%", "%%")) +LOGGER_FORMAT_HELP = f"The logging format. default={LOGGER_FORMAT_ESCAPE}" +# Configure the default logging levels for various Ray components. +# TODO (kevin85421): Currently, I don't encourage Ray users to configure +# `RAY_LOGGER_LEVEL` until its scope and expected behavior are clear and +# easy to understand. Now, only Ray developers should use it. +LOGGER_LEVEL = os.environ.get("RAY_LOGGER_LEVEL", "info") +LOGGER_LEVEL_CHOICES = ["debug", "info", "warning", "error", "critical"] +LOGGER_LEVEL_HELP = ( + "The logging level threshold, choices=['debug', 'info'," + " 'warning', 'error', 'critical'], default='info'" +) + +LOGGING_ROTATE_BYTES = 512 * 1024 * 1024 # 512MB. +LOGGING_ROTATE_BACKUP_COUNT = 5 # 5 Backup files at max. + +LOGGING_REDIRECT_STDERR_ENVIRONMENT_VARIABLE = "RAY_LOG_TO_STDERR" +# Logging format when logging stderr. This should be formatted with the +# component before setting the formatter, e.g. via +# format = LOGGER_FORMAT_STDERR.format(component="dashboard") +# handler.setFormatter(logging.Formatter(format)) +LOGGER_FORMAT_STDERR = ( + "%(asctime)s\t%(levelname)s ({component}) %(filename)s:%(lineno)s -- %(message)s" +) + +# Constants used to define the different process types. +PROCESS_TYPE_REAPER = "reaper" +PROCESS_TYPE_MONITOR = "monitor" +PROCESS_TYPE_RAY_CLIENT_SERVER = "ray_client_server" +PROCESS_TYPE_LOG_MONITOR = "log_monitor" +# TODO(sang): Delete it. 
+PROCESS_TYPE_REPORTER = "reporter" +PROCESS_TYPE_DASHBOARD = "dashboard" +PROCESS_TYPE_DASHBOARD_AGENT = "dashboard_agent" +PROCESS_TYPE_RUNTIME_ENV_AGENT = "runtime_env_agent" +PROCESS_TYPE_WORKER = "worker" +PROCESS_TYPE_RAYLET = "raylet" +PROCESS_TYPE_REDIS_SERVER = "redis_server" +PROCESS_TYPE_WEB_UI = "web_ui" +PROCESS_TYPE_GCS_SERVER = "gcs_server" +PROCESS_TYPE_PYTHON_CORE_WORKER_DRIVER = "python-core-driver" +PROCESS_TYPE_PYTHON_CORE_WORKER = "python-core-worker" + +# Log file names +MONITOR_LOG_FILE_NAME = f"{PROCESS_TYPE_MONITOR}.log" +LOG_MONITOR_LOG_FILE_NAME = f"{PROCESS_TYPE_LOG_MONITOR}.log" + +# Enable log deduplication. +RAY_DEDUP_LOGS = env_bool("RAY_DEDUP_LOGS", True) + +# How many seconds of messages to buffer for log deduplication. +RAY_DEDUP_LOGS_AGG_WINDOW_S = env_integer("RAY_DEDUP_LOGS_AGG_WINDOW_S", 5) + +# Regex for log messages to never deduplicate, or None. This takes precedence over +# the skip regex below. A default pattern is set for testing. +TESTING_NEVER_DEDUP_TOKEN = "__ray_testing_never_deduplicate__" +RAY_DEDUP_LOGS_ALLOW_REGEX = os.environ.get( + "RAY_DEDUP_LOGS_ALLOW_REGEX", TESTING_NEVER_DEDUP_TOKEN +) + +# Regex for log messages to always skip / suppress, or None. 
+RAY_DEDUP_LOGS_SKIP_REGEX = os.environ.get("RAY_DEDUP_LOGS_SKIP_REGEX") + +WORKER_PROCESS_TYPE_IDLE_WORKER = "ray::IDLE" +WORKER_PROCESS_TYPE_SPILL_WORKER_NAME = "SpillWorker" +WORKER_PROCESS_TYPE_RESTORE_WORKER_NAME = "RestoreWorker" +WORKER_PROCESS_TYPE_SPILL_WORKER_IDLE = ( + f"ray::IDLE_{WORKER_PROCESS_TYPE_SPILL_WORKER_NAME}" +) +WORKER_PROCESS_TYPE_RESTORE_WORKER_IDLE = ( + f"ray::IDLE_{WORKER_PROCESS_TYPE_RESTORE_WORKER_NAME}" +) +WORKER_PROCESS_TYPE_SPILL_WORKER = f"ray::SPILL_{WORKER_PROCESS_TYPE_SPILL_WORKER_NAME}" +WORKER_PROCESS_TYPE_RESTORE_WORKER = ( + f"ray::RESTORE_{WORKER_PROCESS_TYPE_RESTORE_WORKER_NAME}" +) +WORKER_PROCESS_TYPE_SPILL_WORKER_DELETE = ( + f"ray::DELETE_{WORKER_PROCESS_TYPE_SPILL_WORKER_NAME}" +) +WORKER_PROCESS_TYPE_RESTORE_WORKER_DELETE = ( + f"ray::DELETE_{WORKER_PROCESS_TYPE_RESTORE_WORKER_NAME}" +) + +# The number of files the log monitor will open. If more files exist, they will +# be ignored. +LOG_MONITOR_MAX_OPEN_FILES = int( + os.environ.get("RAY_LOG_MONITOR_MAX_OPEN_FILES", "200") +) + +# The maximum batch of lines to be read in a single iteration. We _always_ try +# to read this number of lines even if there aren't any new lines. +LOG_MONITOR_NUM_LINES_TO_READ = int( + os.environ.get("RAY_LOG_MONITOR_NUM_LINES_TO_READ", "1000") +) + +# Autoscaler events are denoted by the ":event_summary:" magic token. +LOG_PREFIX_EVENT_SUMMARY = ":event_summary:" +# Cluster-level info events are denoted by the ":info_message:" magic token. These may +# be emitted in the stderr of Ray components. +LOG_PREFIX_INFO_MESSAGE = ":info_message:" +# Actor names are recorded in the logs with this magic token as a prefix. +LOG_PREFIX_ACTOR_NAME = ":actor_name:" +# Task names are recorded in the logs with this magic token as a prefix. +LOG_PREFIX_TASK_NAME = ":task_name:" +# Job ids are recorded in the logs with this magic token as a prefix. 
+LOG_PREFIX_JOB_ID = ":job_id:" + +# The object metadata field uses the following format: It is a comma +# separated list of fields. The first field is mandatory and is the +# type of the object (see types below) or an integer, which is interpreted +# as an error value. The second part is optional and if present has the +# form DEBUG:, it is used for implementing the debugger. + +# A constant used as object metadata to indicate the object is cross language. +OBJECT_METADATA_TYPE_CROSS_LANGUAGE = b"XLANG" +# A constant used as object metadata to indicate the object is python specific. +OBJECT_METADATA_TYPE_PYTHON = b"PYTHON" +# A constant used as object metadata to indicate the object is raw bytes. +OBJECT_METADATA_TYPE_RAW = b"RAW" + +# A constant used as object metadata to indicate the object is an actor handle. +# This value should be synchronized with the Java definition in +# ObjectSerializer.java +# TODO(fyrestone): Serialize the ActorHandle via the custom type feature +# of XLANG. +OBJECT_METADATA_TYPE_ACTOR_HANDLE = b"ACTOR_HANDLE" + +# A constant indicating the debugging part of the metadata (see above). +OBJECT_METADATA_DEBUG_PREFIX = b"DEBUG:" + +AUTOSCALER_RESOURCE_REQUEST_CHANNEL = b"autoscaler_resource_request" + +REDIS_DEFAULT_USERNAME = "" + +REDIS_DEFAULT_PASSWORD = "" + +# The default ip address to bind to. +NODE_DEFAULT_IP = "127.0.0.1" + +# The Mach kernel page size in bytes. +MACH_PAGE_SIZE_BYTES = 4096 + +# The max number of bytes for task execution error message. +MAX_APPLICATION_ERROR_LEN = 500 + +# Max 64 bit integer value, which is needed to ensure against overflow +# in C++ when passing integer values cross-language. 
+MAX_INT64_VALUE = 9223372036854775807 + +# Object Spilling related constants +DEFAULT_OBJECT_PREFIX = "ray_spilled_objects" + +GCS_PORT_ENVIRONMENT_VARIABLE = "RAY_GCS_SERVER_PORT" + +HEALTHCHECK_EXPIRATION_S = os.environ.get("RAY_HEALTHCHECK_EXPIRATION_S", 10) + +# Filename of "shim process" that sets up Python worker environment. +# Should be kept in sync with kSetupWorkerFilename in +# src/ray/common/constants.h. +SETUP_WORKER_FILENAME = "setup_worker.py" + +# Directory name where runtime_env resources will be created & cached. +DEFAULT_RUNTIME_ENV_DIR_NAME = "runtime_resources" + +# The timeout seconds for the creation of runtime env, +# dafault timeout is 10 minutes +DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS = 600 + +# Used to separate lines when formatting the call stack where an ObjectRef was +# created. +CALL_STACK_LINE_DELIMITER = " | " + +# The default gRPC max message size is 4 MiB, we use a larger number of 250 MiB +# NOTE: This is equal to the C++ limit of (RAY_CONFIG::max_grpc_message_size) +GRPC_CPP_MAX_MESSAGE_SIZE = 250 * 1024 * 1024 + +# The gRPC send & receive max length for "dashboard agent" server. 
+# NOTE: This is equal to the C++ limit of RayConfig::max_grpc_message_size +# and HAVE TO STAY IN SYNC with it (ie, meaning that both of these values +# have to be set at the same time) +AGENT_GRPC_MAX_MESSAGE_LENGTH = env_integer( + "AGENT_GRPC_MAX_MESSAGE_LENGTH", 20 * 1024 * 1024 # 20MB +) + + +# GRPC options +GRPC_ENABLE_HTTP_PROXY = ( + 1 + if os.environ.get("RAY_grpc_enable_http_proxy", "0").lower() in ("1", "true") + else 0 +) +GLOBAL_GRPC_OPTIONS = (("grpc.enable_http_proxy", GRPC_ENABLE_HTTP_PROXY),) + +# Internal kv namespaces +KV_NAMESPACE_DASHBOARD = b"dashboard" +KV_NAMESPACE_SESSION = b"session" +KV_NAMESPACE_TRACING = b"tracing" +KV_NAMESPACE_PDB = b"ray_pdb" +KV_NAMESPACE_HEALTHCHECK = b"healthcheck" +KV_NAMESPACE_JOB = b"job" +KV_NAMESPACE_CLUSTER = b"cluster" +KV_HEAD_NODE_ID_KEY = b"head_node_id" +# TODO: Set package for runtime env +# We need to update ray client for this since runtime env use ray client +# This might introduce some compatibility issues so leave it here for now. +KV_NAMESPACE_PACKAGE = None +KV_NAMESPACE_SERVE = b"serve" +KV_NAMESPACE_FUNCTION_TABLE = b"fun" + +LANGUAGE_WORKER_TYPES = ["python", "java", "cpp"] + +# Accelerator constants +NOSET_CUDA_VISIBLE_DEVICES_ENV_VAR = "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES" + +CUDA_VISIBLE_DEVICES_ENV_VAR = "CUDA_VISIBLE_DEVICES" +ROCR_VISIBLE_DEVICES_ENV_VAR = "ROCR_VISIBLE_DEVICES" +NEURON_RT_VISIBLE_CORES_ENV_VAR = "NEURON_RT_VISIBLE_CORES" +TPU_VISIBLE_CHIPS_ENV_VAR = "TPU_VISIBLE_CHIPS" +NPU_RT_VISIBLE_DEVICES_ENV_VAR = "ASCEND_RT_VISIBLE_DEVICES" + +NEURON_CORES = "neuron_cores" +GPU = "GPU" +TPU = "TPU" +NPU = "NPU" +HPU = "HPU" + + +RAY_WORKER_NICENESS = "RAY_worker_niceness" + +# Default max_retries option in @ray.remote for non-actor +# tasks. +DEFAULT_TASK_MAX_RETRIES = 3 + +# Default max_concurrency option in @ray.remote for threaded actors. +DEFAULT_MAX_CONCURRENCY_THREADED = 1 + +# Default max_concurrency option in @ray.remote for async actors. 
+DEFAULT_MAX_CONCURRENCY_ASYNC = 1000 + +# Prefix for namespaces which are used internally by ray. +# Jobs within these namespaces should be hidden from users +# and should not be considered user activity. +# Please keep this in sync with the definition kRayInternalNamespacePrefix +# in /src/ray/gcs/gcs_server/gcs_job_manager.h. +RAY_INTERNAL_NAMESPACE_PREFIX = "_ray_internal_" +RAY_INTERNAL_DASHBOARD_NAMESPACE = f"{RAY_INTERNAL_NAMESPACE_PREFIX}dashboard" + +# Ray internal flags. These flags should not be set by users, and we strip them on job +# submission. +# This should be consistent with src/ray/common/ray_internal_flag_def.h +RAY_INTERNAL_FLAGS = [ + "RAY_JOB_ID", + "RAY_RAYLET_PID", + "RAY_OVERRIDE_NODE_ID_FOR_TESTING", +] + + +def gcs_actor_scheduling_enabled(): + return os.environ.get("RAY_gcs_actor_scheduling_enabled") == "true" + + +DEFAULT_RESOURCES = {"CPU", "GPU", "memory", "object_store_memory"} + +# Supported Python versions for runtime env's "conda" field. Ray downloads +# Ray wheels into the conda environment, so the Ray wheels for these Python +# versions must be available online. +RUNTIME_ENV_CONDA_PY_VERSIONS = [(3, 9), (3, 10), (3, 11), (3, 12)] + +# Whether to enable Ray clusters (in addition to local Ray). +# Ray clusters are not explicitly supported for Windows and OSX. +IS_WINDOWS_OR_OSX = sys.platform == "darwin" or sys.platform == "win32" +ENABLE_RAY_CLUSTERS_ENV_VAR = "RAY_ENABLE_WINDOWS_OR_OSX_CLUSTER" +ENABLE_RAY_CLUSTER = env_bool( + ENABLE_RAY_CLUSTERS_ENV_VAR, + not IS_WINDOWS_OR_OSX, +) + +SESSION_LATEST = "session_latest" +NUM_PORT_RETRIES = 40 +NUM_REDIS_GET_RETRIES = int(os.environ.get("RAY_NUM_REDIS_GET_RETRIES", "20")) + +# The allowed cached ports in Ray. 
Refer to Port configuration for more details: +# https://docs.ray.io/en/latest/ray-core/configure.html#ports-configurations +RAY_ALLOWED_CACHED_PORTS = { + "metrics_agent_port", + "metrics_export_port", + "dashboard_agent_listen_port", + "runtime_env_agent_port", + "gcs_server_port", # the `port` option for gcs port. +} + +# Turn this on if actor task log's offsets are expected to be recorded. +# With this enabled, actor tasks' log could be queried with task id. +RAY_ENABLE_RECORD_ACTOR_TASK_LOGGING = env_bool( + "RAY_ENABLE_RECORD_ACTOR_TASK_LOGGING", False +) + +# RuntimeEnv env var to indicate it exports a function +WORKER_PROCESS_SETUP_HOOK_ENV_VAR = "__RAY_WORKER_PROCESS_SETUP_HOOK_ENV_VAR" +RAY_WORKER_PROCESS_SETUP_HOOK_LOAD_TIMEOUT_ENV_VAR = ( + "RAY_WORKER_PROCESS_SETUP_HOOK_LOAD_TIMEOUT" # noqa +) + +RAY_DEFAULT_LABEL_KEYS_PREFIX = "ray.io/" + +RAY_TPU_MAX_CONCURRENT_CONNECTIONS_ENV_VAR = "RAY_TPU_MAX_CONCURRENT_ACTIVE_CONNECTIONS" + +RAY_NODE_IP_FILENAME = "node_ip_address.json" + +PLACEMENT_GROUP_BUNDLE_RESOURCE_NAME = "bundle" + +RAY_LOGGING_CONFIG_ENCODING = os.environ.get("RAY_LOGGING_CONFIG_ENCODING") + +RAY_BACKEND_LOG_JSON_ENV_VAR = "RAY_BACKEND_LOG_JSON" + +# Write export API event of all resource types to file if enabled. +# RAY_enable_export_api_write_config will not be considered if +# this is enabled. +RAY_ENABLE_EXPORT_API_WRITE = env_bool("RAY_enable_export_api_write", False) + +# Comma separated string containing individual resource +# to write export API events for. This configuration is only used if +# RAY_enable_export_api_write is not enabled. 
# Full list of valid
# resource types in ExportEvent.SourceType enum in
# src/ray/protobuf/export_api/export_event.proto
# Example config:
# `export RAY_enable_export_api_write_config='EXPORT_SUBMISSION_JOB,EXPORT_ACTOR'`
RAY_ENABLE_EXPORT_API_WRITE_CONFIG_STR = os.environ.get(
    "RAY_enable_export_api_write_config", ""
)
RAY_ENABLE_EXPORT_API_WRITE_CONFIG = RAY_ENABLE_EXPORT_API_WRITE_CONFIG_STR.split(",")

# NOTE(review): env_bool with a numeric default looks wrong for a byte-size;
# when the env var is unset the default passes through, but setting the env var
# yields a boolean. Probably should be env_integer — verify against callers.
RAY_EXPORT_EVENT_MAX_FILE_SIZE_BYTES = env_bool(
    "RAY_EXPORT_EVENT_MAX_FILE_SIZE_BYTES", 100 * 1e6
)

# NOTE(review): same env_bool-with-numeric-default concern as above.
RAY_EXPORT_EVENT_MAX_BACKUP_COUNT = env_bool("RAY_EXPORT_EVENT_MAX_BACKUP_COUNT", 20)

# ===== File: ray/_private/ray_experimental_perf.py =====
"""This is the script for `ray microbenchmark`."""

import asyncio
import logging
from ray._private.ray_microbenchmark_helpers import timeit, asyncio_timeit
import multiprocessing
import ray
from ray.dag.compiled_dag_node import CompiledDAG
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

import ray.experimental.channel as ray_channel
from ray.dag import InputNode, MultiOutputNode
from ray._private.utils import (
    get_or_create_event_loop,
)
from ray._private.test_utils import get_actor_node_id

logger = logging.getLogger(__name__)


@ray.remote
class DAGActor:
    """Trivial actor whose methods echo their inputs; used as DAG nodes."""

    def echo(self, x):
        # Identity: returns the single argument unchanged.
        return x

    def echo_multiple(self, *x):
        # Identity over varargs: returns the argument tuple unchanged.
        return x


def check_optimized_build():
    """Log a warning if this Ray build was not compiled in optimized mode."""
    if not ray._raylet.OPTIMIZED:
        msg = (
            "WARNING: Unoptimized build! "
            "To benchmark an optimized build, try:\n"
            "\tbazel build -c opt //:ray_pkg\n"
            "You can also make this permanent by adding\n"
            "\tbuild --compilation_mode=opt\n"
            "to your user-wide ~/.bazelrc file. "
            "(Do not add this to the project-level .bazelrc file.)"
        )
        logger.warning(msg)


def create_driver_actor():
    """Create the compiled-DAG driver proxy actor pinned to this driver's node."""
    return CompiledDAG.DAGDriverProxyActor.options(
        scheduling_strategy=NodeAffinitySchedulingStrategy(
            ray.get_runtime_context().get_node_id(), soft=False
        )
    ).remote()


def main(results=None):
    """Run the channel and compiled-DAG microbenchmarks.

    Args:
        results: Optional list to append results to; a new list is used if
            None (or any falsy value) is passed.

    Returns:
        The results list, extended with the (name, mean, sd) tuples produced
        by timeit()/asyncio_timeit() for each benchmark.
    """
    results = results or []
    loop = get_or_create_event_loop()

    check_optimized_build()

    print("Tip: set TESTS_TO_RUN='pattern' to run a subset of benchmarks")

    #################################################
    # Perf tests for channels, used in compiled DAGs.
    #################################################
    ray.init()

    def put_channel_small(chans, do_get=False):
        # Write a tiny payload to every channel; optionally read it back.
        for chan in chans:
            chan.write(b"0")
            if do_get:
                chan.read()

    @ray.remote
    class ChannelReader:
        def ready(self):
            # No-op used to block until the actor is scheduled.
            return

        def read(self, chans):
            # Busy loop draining all channels forever (killed via ray.kill).
            while True:
                for chan in chans:
                    chan.read()

    driver_actor = create_driver_actor()
    driver_node = get_actor_node_id(driver_actor)
    chans = [ray_channel.Channel(None, [(driver_actor, driver_node)], 1000)]
    results += timeit(
        "[unstable] local put:local get, single channel calls",
        lambda: put_channel_small(chans, do_get=True),
    )

    reader = ChannelReader.remote()
    reader_node = get_actor_node_id(reader)
    chans = [ray_channel.Channel(None, [(reader, reader_node)], 1000)]
    ray.get(reader.ready.remote())
    reader.read.remote(chans)
    results += timeit(
        "[unstable] local put:1 remote get, single channel calls",
        lambda: put_channel_small(chans),
    )
    ray.kill(reader)

    n_cpu = multiprocessing.cpu_count() // 2
    print(f"Testing multiple readers/channels, n={n_cpu}")

    # One channel fanned out to n readers.
    reader_and_node_list = []
    for _ in range(n_cpu):
        reader = ChannelReader.remote()
        reader_node = get_actor_node_id(reader)
        reader_and_node_list.append((reader, reader_node))
    chans = [ray_channel.Channel(None, reader_and_node_list, 1000)]
    ray.get([reader.ready.remote() for reader, _ in reader_and_node_list])
    for reader, _ in reader_and_node_list:
        reader.read.remote(chans)
    results += timeit(
        "[unstable] local put:n remote get, single channel calls",
        lambda: put_channel_small(chans),
    )
    for reader, _ in reader_and_node_list:
        ray.kill(reader)

    # n channels all drained by a single reader.
    reader = ChannelReader.remote()
    reader_node = get_actor_node_id(reader)
    chans = [
        ray_channel.Channel(None, [(reader, reader_node)], 1000) for _ in range(n_cpu)
    ]
    ray.get(reader.ready.remote())
    reader.read.remote(chans)
    results += timeit(
        "[unstable] local put:1 remote get, n channels calls",
        lambda: put_channel_small(chans),
    )
    ray.kill(reader)

    # n channels, each with its own dedicated reader.
    reader_and_node_list = []
    for _ in range(n_cpu):
        reader = ChannelReader.remote()
        reader_node = get_actor_node_id(reader)
        reader_and_node_list.append((reader, reader_node))
    chans = [
        ray_channel.Channel(None, [reader_and_node_list[i]], 1000) for i in range(n_cpu)
    ]
    ray.get([reader.ready.remote() for reader, _ in reader_and_node_list])
    for chan, reader_node_tuple in zip(chans, reader_and_node_list):
        reader = reader_node_tuple[0]
        reader.read.remote([chan])
    results += timeit(
        "[unstable] local put:n remote get, n channels calls",
        lambda: put_channel_small(chans),
    )
    for reader, _ in reader_and_node_list:
        ray.kill(reader)

    # Tests for compiled DAGs.

    def _exec(dag, num_args=1, payload_size=1):
        # Execute a compiled DAG once with num_args byte-payload args and wait.
        output_ref = dag.execute(*[b"x" * payload_size for _ in range(num_args)])
        ray.get(output_ref)

    async def exec_async(tag):
        # Benchmark async execution of the compiled_dag currently in scope
        # (note: binds compiled_dag late, at call time).
        async def _exec_async():
            fut = await compiled_dag.execute_async(b"x")
            if not isinstance(fut, list):
                await fut
            else:
                await asyncio.gather(*fut)

        return await asyncio_timeit(
            tag,
            _exec_async,
        )

    # Single-actor DAG calls

    a = DAGActor.remote()
    with InputNode() as inp:
        dag = a.echo.bind(inp)

    results += timeit(
        "[unstable] single-actor DAG calls", lambda: ray.get(dag.execute(b"x"))
    )
    compiled_dag = dag.experimental_compile()
    results += timeit(
        "[unstable] compiled single-actor DAG calls", lambda: _exec(compiled_dag)
    )
    del a

    # Single-actor asyncio DAG calls

    a = DAGActor.remote()
    with InputNode() as inp:
        dag = a.echo.bind(inp)
    compiled_dag = dag.experimental_compile(enable_asyncio=True)
    results += loop.run_until_complete(
        exec_async(
            "[unstable] compiled single-actor asyncio DAG calls",
        )
    )
    del a

    # Scatter-gather DAG calls

    n_cpu = multiprocessing.cpu_count() // 2
    actors = [DAGActor.remote() for _ in range(n_cpu)]
    with InputNode() as inp:
        dag = MultiOutputNode([a.echo.bind(inp) for a in actors])
    results += timeit(
        f"[unstable] scatter-gather DAG calls, n={n_cpu} actors",
        lambda: ray.get(dag.execute(b"x")),
    )
    compiled_dag = dag.experimental_compile()
    results += timeit(
        f"[unstable] compiled scatter-gather DAG calls, n={n_cpu} actors",
        lambda: _exec(compiled_dag),
    )

    # Scatter-gather asyncio DAG calls

    actors = [DAGActor.remote() for _ in range(n_cpu)]
    with InputNode() as inp:
        dag = MultiOutputNode([a.echo.bind(inp) for a in actors])
    compiled_dag = dag.experimental_compile(enable_asyncio=True)
    results += loop.run_until_complete(
        exec_async(
            f"[unstable] compiled scatter-gather asyncio DAG calls, n={n_cpu} actors",
        )
    )

    # Chain DAG calls

    actors = [DAGActor.remote() for _ in range(n_cpu)]
    with InputNode() as inp:
        dag = inp
        for a in actors:
            dag = a.echo.bind(dag)
    results += timeit(
        f"[unstable] chain DAG calls, n={n_cpu} actors",
        lambda: ray.get(dag.execute(b"x")),
    )
    compiled_dag = dag.experimental_compile()
    results += timeit(
        f"[unstable] compiled chain DAG calls, n={n_cpu} actors",
        lambda: _exec(compiled_dag),
    )

    # Chain asyncio DAG calls

    actors = [DAGActor.remote() for _ in range(n_cpu)]
    with InputNode() as inp:
        dag = inp
        for a in actors:
            dag = a.echo.bind(dag)
    compiled_dag = dag.experimental_compile(enable_asyncio=True)
    results += loop.run_until_complete(
        exec_async(f"[unstable] compiled chain asyncio DAG calls, n={n_cpu} actors")
    )

    # Multiple args with small payloads

    n_actors = 8
    assert (
        n_cpu > n_actors
    ), f"n_cpu ({n_cpu}) must be greater than n_actors ({n_actors})"

    actors = [DAGActor.remote() for _ in range(n_actors)]
    with InputNode() as inp:
        dag = MultiOutputNode([actors[i].echo.bind(inp[i]) for i in range(n_actors)])
    payload_size = 1
    results += timeit(
        f"[unstable] multiple args with small payloads DAG calls, n={n_actors} actors",
        lambda: ray.get(dag.execute(*[b"x" * payload_size for _ in range(n_actors)])),
    )
    compiled_dag = dag.experimental_compile()
    results += timeit(
        f"[unstable] compiled multiple args with small payloads DAG calls, "
        f"n={n_actors} actors",
        lambda: _exec(compiled_dag, num_args=n_actors, payload_size=payload_size),
    )

    # Multiple args with medium payloads

    actors = [DAGActor.remote() for _ in range(n_actors)]
    with InputNode() as inp:
        dag = MultiOutputNode([actors[i].echo.bind(inp[i]) for i in range(n_actors)])
    payload_size = 1024 * 1024
    results += timeit(
        f"[unstable] multiple args with medium payloads DAG calls, n={n_actors} actors",
        lambda: ray.get(dag.execute(*[b"x" * payload_size for _ in range(n_actors)])),
    )
    compiled_dag = dag.experimental_compile()
    results += timeit(
        "[unstable] compiled multiple args with medium payloads DAG calls, "
        f"n={n_actors} actors",
        lambda: _exec(compiled_dag, num_args=n_actors, payload_size=payload_size),
    )

    # Multiple args with large payloads

    actors = [DAGActor.remote() for _ in range(n_actors)]
    with InputNode() as inp:
        dag = MultiOutputNode([actors[i].echo.bind(inp[i]) for i in range(n_actors)])
    payload_size = 10 * 1024 * 1024
    results += timeit(
        f"[unstable] multiple args with large payloads DAG calls, n={n_actors} actors",
        lambda: ray.get(dag.execute(*[b"x" * payload_size for _ in range(n_actors)])),
    )
    compiled_dag = dag.experimental_compile()
    results += timeit(
        "[unstable] compiled multiple args with large payloads DAG calls, "
        f"n={n_actors} actors",
        lambda: _exec(compiled_dag, num_args=n_actors, payload_size=payload_size),
    )

    # Worst case for multiple arguments: a single actor takes all the arguments
    # with small payloads.

    actor = DAGActor.remote()
    n_args = 8
    with InputNode() as inp:
        dag = actor.echo_multiple.bind(*[inp[i] for i in range(n_args)])
    payload_size = 1
    results += timeit(
        "[unstable] single-actor with all args with small payloads DAG calls, "
        "n=1 actors",
        lambda: ray.get(dag.execute(*[b"x" * payload_size for _ in range(n_args)])),
    )
    compiled_dag = dag.experimental_compile()
    # NOTE(review): this label is identical to the uncompiled benchmark above;
    # it likely should say "compiled" to disambiguate the two result rows.
    results += timeit(
        "[unstable] single-actor with all args with small payloads DAG calls, "
        "n=1 actors",
        lambda: _exec(compiled_dag, num_args=n_args, payload_size=payload_size),
    )

    ray.shutdown()

    return results


if __name__ == "__main__":
    main()

# ===== File: ray/_private/ray_microbenchmark_helpers.py =====
import time
from typing import List, Optional, Tuple
import os
import ray
import numpy as np

from contextlib import contextmanager

# Only run tests matching this filter pattern.

filter_pattern = os.environ.get("TESTS_TO_RUN", "")
skip_pattern = os.environ.get("TESTS_TO_SKIP", "")


def timeit(
    name, fn, multiplier=1, warmup_time_sec=10
) -> List[Optional[Tuple[str, float, float]]]:
    """Benchmark fn's call rate and print/return the result.

    Args:
        name: Benchmark label; also matched against TESTS_TO_RUN/TESTS_TO_SKIP.
        fn: Zero-arg callable to benchmark.
        multiplier: Factor applied to the measured calls/sec (e.g. items per call).
        warmup_time_sec: Seconds to sleep before measuring.

    Returns:
        [(name, mean calls/sec, stddev)] or [None] if filtered out.
    """
    if filter_pattern and filter_pattern not in name:
        return [None]
    if skip_pattern and skip_pattern in name:
        return [None]
    # sleep for a while to avoid noisy neighbors.
    # related issue: https://github.com/ray-project/ray/issues/22045
    time.sleep(warmup_time_sec)
    # warmup: run for ~1s to count how many calls fit, sizing the batch below.
    start = time.perf_counter()
    count = 0
    while time.perf_counter() - start < 1:
        fn()
        count += 1
    # real run: 4 rounds of ~2s each, calling fn in batches of `step`.
    step = count // 10 + 1
    stats = []
    for _ in range(4):
        start = time.perf_counter()
        count = 0
        while time.perf_counter() - start < 2:
            for _ in range(step):
                fn()
            count += step
        end = time.perf_counter()
        stats.append(multiplier * count / (end - start))

    mean = np.mean(stats)
    sd = np.std(stats)
    print(name, "per second", round(mean, 2), "+-", round(sd, 2))
    return [(name, mean, sd)]


async def asyncio_timeit(
    name, async_fn, multiplier=1, warmup_time_sec=10
) -> List[Optional[Tuple[str, float, float]]]:
    """Async counterpart of timeit(): benchmarks `await async_fn()` call rate.

    Same filtering, warmup, measurement rounds, and return shape as timeit().
    """
    if filter_pattern and filter_pattern not in name:
        return [None]
    if skip_pattern and skip_pattern in name:
        return [None]
    # sleep for a while to avoid noisy neighbors.
    # related issue: https://github.com/ray-project/ray/issues/22045
    time.sleep(warmup_time_sec)
    # warmup: run for ~1s to count how many awaited calls fit.
    start = time.perf_counter()
    count = 0
    while time.perf_counter() - start < 1:
        await async_fn()
        count += 1
    # real run: 4 rounds of ~2s each, awaiting async_fn in batches of `step`.
    step = count // 10 + 1
    stats = []
    for _ in range(4):
        start = time.perf_counter()
        count = 0
        while time.perf_counter() - start < 2:
            for _ in range(step):
                await async_fn()
            count += step
        end = time.perf_counter()
        stats.append(multiplier * count / (end - start))

    mean = np.mean(stats)
    sd = np.std(stats)
    print(name, "per second", round(mean, 2), "+-", round(sd, 2))
    return [(name, mean, sd)]


@contextmanager
def ray_setup_and_teardown(**init_args):
    """Context manager: ray.init(**init_args) on entry, ray.shutdown() on exit."""
    ray.init(**init_args)
    try:
        yield None
    finally:
        ray.shutdown()

# ===== File: ray/_private/ray_option_utils.py =====
"""Manage, parse and validate options for Ray tasks, actors and actor methods."""
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Tuple, Union

import ray
from ray._private import ray_constants
from ray._private.utils import get_ray_doc_version
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import (
    NodeAffinitySchedulingStrategy,
    PlacementGroupSchedulingStrategy,
    NodeLabelSchedulingStrategy,
)


@dataclass
class Option:
    """Declarative spec for a single @ray.remote option: type check, value
    check, and default value."""

    # Type constraint of an option.
    type_constraint: Optional[Union[type, Tuple[type]]] = None
    # Value constraint of an option.
    # The callable should return None if there is no error.
    # Otherwise, return the error message.
    value_constraint: Optional[Callable[[Any], Optional[str]]] = None
    # Default value.
    default_value: Any = None

    def validate(self, keyword: str, value: Any):
        """Validate the option.

        Raises:
            TypeError: if value fails the type constraint.
            ValueError: if value fails the value constraint.
        """
        if self.type_constraint is not None:
            if not isinstance(value, self.type_constraint):
                raise TypeError(
                    f"The type of keyword '{keyword}' must be {self.type_constraint}, "
                    f"but received type {type(value)}"
                )
        if self.value_constraint is not None:
            possible_error_message = self.value_constraint(value)
            if possible_error_message:
                raise ValueError(possible_error_message)


def _counting_option(name: str, infinite: bool = True, default_value: Any = None):
    """This is used for positive and discrete options.

    Args:
        name: The name of the option keyword.
        infinite: If True, user could use -1 to represent infinity.
        default_value: The default value for this option.
    """
    if infinite:
        return Option(
            (int, type(None)),
            lambda x: None
            if (x is None or x >= -1)
            else f"The keyword '{name}' only accepts None, 0, -1"
            " or a positive integer, where -1 represents infinity.",
            default_value=default_value,
        )
    return Option(
        (int, type(None)),
        lambda x: None
        if (x is None or x >= 0)
        else f"The keyword '{name}' only accepts None, 0 or a positive integer.",
        default_value=default_value,
    )


def _validate_resource_quantity(name, quantity):
    """Return an error message for an invalid resource quantity, else None.

    Rejects negatives, fractional quantities finer than the raylet's resource
    unit scaling, and accelerator amounts that the accelerator manager for
    that resource rejects.
    """
    if quantity < 0:
        return f"The quantity of resource {name} cannot be negative"
    if (
        isinstance(quantity, float)
        and quantity != 0.0
        and int(quantity * ray._raylet.RESOURCE_UNIT_SCALING) == 0
    ):
        return (
            f"The precision of the fractional quantity of resource {name}"
            " cannot go beyond 0.0001"
        )
    # num_gpus maps onto the "GPU" resource for accelerator validation.
    resource_name = "GPU" if name == "num_gpus" else name
    if resource_name in ray._private.accelerators.get_all_accelerator_resource_names():
        (
            valid,
            error_message,
        ) = ray._private.accelerators.get_accelerator_manager_for_resource(
            resource_name
        ).validate_resource_request_quantity(
            quantity
        )
        if not valid:
            return error_message
    return None


def _resource_option(name: str, default_value: Any = None):
    """This is used for resource related options."""
    return Option(
        (float, int, type(None)),
        lambda x: None if (x is None) else _validate_resource_quantity(name, x),
        default_value=default_value,
    )


def _validate_resources(resources: Optional[Dict[str, float]]) -> Optional[str]:
    """Return an error message if the custom `resources` dict is invalid, else None."""
    if resources is None:
        return None

    # CPU/GPU must be requested via num_cpus/num_gpus, not the resources dict.
    if "CPU" in resources or "GPU" in resources:
        return (
            "Use the 'num_cpus' and 'num_gpus' keyword instead of 'CPU' and 'GPU' "
            "in 'resources' keyword"
        )

    for name, quantity in resources.items():
        possible_error_message = _validate_resource_quantity(name, quantity)
        if possible_error_message:
            return possible_error_message

    return None


# Options accepted by both tasks and actors.
_common_options = {
    "accelerator_type": Option((str, type(None))),
    "memory": _resource_option("memory"),
    "name": Option((str, type(None))),
    "num_cpus": _resource_option("num_cpus"),
    "num_gpus": _resource_option("num_gpus"),
    "object_store_memory": _counting_option("object_store_memory", False),
    # TODO(suquark): "placement_group", "placement_group_bundle_index"
    # and "placement_group_capture_child_tasks" are deprecated,
    # use "scheduling_strategy" instead.
    "placement_group": Option(
        (type(None), str, PlacementGroup), default_value="default"
    ),
    "placement_group_bundle_index": Option(int, default_value=-1),
    "placement_group_capture_child_tasks": Option((bool, type(None))),
    "resources": Option((dict, type(None)), lambda x: _validate_resources(x)),
    "runtime_env": Option((dict, type(None))),
    "scheduling_strategy": Option(
        (
            type(None),
            str,
            PlacementGroupSchedulingStrategy,
            NodeAffinitySchedulingStrategy,
            NodeLabelSchedulingStrategy,
        )
    ),
    "_metadata": Option((dict, type(None))),
    "enable_task_events": Option(bool, default_value=True),
    "_labels": Option((dict, type(None))),
}


def issubclass_safe(obj: Any, cls_: type) -> bool:
    """Like issubclass(), but returns False (instead of raising TypeError)
    when obj is not a class."""
    try:
        return issubclass(obj, cls_)
    except TypeError:
        return False


# Options accepted only by tasks (entries here override _common_options).
_task_only_options = {
    "max_calls": _counting_option("max_calls", False, default_value=0),
    # Normal tasks may be retried on failure this many times.
    # TODO(swang): Allow this to be set globally for an application.
    "max_retries": _counting_option(
        "max_retries", default_value=ray_constants.DEFAULT_TASK_MAX_RETRIES
    ),
    # override "_common_options"
    "num_cpus": _resource_option("num_cpus", default_value=1),
    "num_returns": Option(
        (int, str, type(None)),
        lambda x: None
        if (x is None or x == "dynamic" or x == "streaming" or x >= 0)
        else "Default None. When None is passed, "
        "The default value is 1 for a task and actor task, and "
        "'streaming' for generator tasks and generator actor tasks. "
        "The keyword 'num_returns' only accepts None, "
        "a non-negative integer, "
        "'streaming' (for generators), or 'dynamic'. 'dynamic' flag "
        "will be deprecated in the future, and it is recommended to use "
        "'streaming' instead.",
        default_value=None,
    ),
    "object_store_memory": Option(  # override "_common_options"
        (int, type(None)),
        lambda x: None
        if (x is None)
        else "Setting 'object_store_memory' is not implemented for tasks",
    ),
    "retry_exceptions": Option(
        (bool, list, tuple),
        lambda x: None
        if (
            isinstance(x, bool)
            or (
                isinstance(x, (list, tuple))
                and all(issubclass_safe(x_, Exception) for x_ in x)
            )
        )
        else "retry_exceptions must be either a boolean or a list of exceptions",
        default_value=False,
    ),
    "_generator_backpressure_num_objects": Option(
        (int, type(None)),
        lambda x: None
        if x != 0
        else (
            "_generator_backpressure_num_objects=0 is not allowed. "
            "Use a value > 0. If the value is equal to 1, the behavior "
            "is identical to Python generator (generator 1 object "
            "whenever `next` is called). Use -1 to disable this feature. "
        ),
    ),
}

# Options accepted only by actors (entries here override _common_options).
_actor_only_options = {
    "concurrency_groups": Option((list, dict, type(None))),
    "lifetime": Option(
        (str, type(None)),
        lambda x: None
        if x in (None, "detached", "non_detached")
        else "actor `lifetime` argument must be one of 'detached', "
        "'non_detached' and 'None'.",
    ),
    "max_concurrency": _counting_option("max_concurrency", False),
    "max_restarts": _counting_option("max_restarts", default_value=0),
    "max_task_retries": _counting_option("max_task_retries", default_value=0),
    "max_pending_calls": _counting_option("max_pending_calls", default_value=-1),
    "namespace": Option((str, type(None))),
    "get_if_exists": Option(bool, default_value=False),
}

# Priority is important here because during dictionary update, same key with higher
# priority overrides the same key with lower priority. We make use of priority
# to set the correct default value for tasks / actors.
+
+# priority: _common_options > _actor_only_options > _task_only_options
+valid_options: Dict[str, Option] = {
+    **_task_only_options,
+    **_actor_only_options,
+    **_common_options,
+}
+# priority: _task_only_options > _common_options
+task_options: Dict[str, Option] = {**_common_options, **_task_only_options}
+# priority: _actor_only_options > _common_options
+actor_options: Dict[str, Option] = {**_common_options, **_actor_only_options}
+
+remote_args_error_string = (
+    "The @ray.remote decorator must be applied either with no arguments and no "
+    "parentheses, for example '@ray.remote', or it must be applied using some of "
+    f"the arguments in the list {list(valid_options.keys())}, for example "
+    "'@ray.remote(num_returns=2, resources={\"CustomResource\": 1})'."
+)
+
+
+def _check_deprecate_placement_group(options: Dict[str, Any]):
+    """Check if deprecated placement group option exists."""
+    placement_group = options.get("placement_group", "default")
+    scheduling_strategy = options.get("scheduling_strategy")
+    # TODO(suquark): @ray.remote(placement_group=None) is used in
+    # "python/ray/data/_internal/remote_fn.py" and many other places,
+    # while "ray.data.read_api.read_datasource" sets "scheduling_strategy=SPREAD".
+    # This might be a bug, but it is also OK to allow them to co-exist.
+    if (placement_group not in ("default", None)) and (scheduling_strategy is not None):
+        raise ValueError(
+            "Placement groups should be specified via the "
+            "scheduling_strategy option. "
+            "The placement_group option is deprecated."
+        )
+
+
+def _warn_if_using_deprecated_placement_group(
+    options: Dict[str, Any], caller_stacklevel: int
+):
+    placement_group = options["placement_group"]
+    placement_group_bundle_index = options["placement_group_bundle_index"]
+    placement_group_capture_child_tasks = options["placement_group_capture_child_tasks"]
+    if placement_group != "default":
+        warnings.warn(
+            "placement_group parameter is deprecated. 
Use " + "scheduling_strategy=PlacementGroupSchedulingStrategy(...) " + "instead, see the usage at " + f"https://docs.ray.io/en/{get_ray_doc_version()}/ray-core/package-ref.html#ray-remote.", # noqa: E501 + DeprecationWarning, + stacklevel=caller_stacklevel + 1, + ) + if placement_group_bundle_index != -1: + warnings.warn( + "placement_group_bundle_index parameter is deprecated. Use " + "scheduling_strategy=PlacementGroupSchedulingStrategy(...) " + "instead, see the usage at " + f"https://docs.ray.io/en/{get_ray_doc_version()}/ray-core/package-ref.html#ray-remote.", # noqa: E501 + DeprecationWarning, + stacklevel=caller_stacklevel + 1, + ) + if placement_group_capture_child_tasks: + warnings.warn( + "placement_group_capture_child_tasks parameter is deprecated. Use " + "scheduling_strategy=PlacementGroupSchedulingStrategy(...) " + "instead, see the usage at " + f"https://docs.ray.io/en/{get_ray_doc_version()}/ray-core/package-ref.html#ray-remote.", # noqa: E501 + DeprecationWarning, + stacklevel=caller_stacklevel + 1, + ) + + +def validate_task_options(options: Dict[str, Any], in_options: bool): + """Options check for Ray tasks. + + Args: + options: Options for Ray tasks. + in_options: If True, we are checking the options under the context of + ".options()". + """ + for k, v in options.items(): + if k not in task_options: + raise ValueError( + f"Invalid option keyword {k} for remote functions. " + f"Valid ones are {list(task_options)}." + ) + task_options[k].validate(k, v) + if in_options and "max_calls" in options: + raise ValueError("Setting 'max_calls' is not supported in '.options()'.") + _check_deprecate_placement_group(options) + + +def validate_actor_options(options: Dict[str, Any], in_options: bool): + """Options check for Ray actors. + + Args: + options: Options for Ray actors. + in_options: If True, we are checking the options under the context of + ".options()". 
+ """ + for k, v in options.items(): + if k not in actor_options: + raise ValueError( + f"Invalid option keyword {k} for actors. " + f"Valid ones are {list(actor_options)}." + ) + actor_options[k].validate(k, v) + + if in_options and "concurrency_groups" in options: + raise ValueError( + "Setting 'concurrency_groups' is not supported in '.options()'." + ) + + if options.get("get_if_exists") and not options.get("name"): + raise ValueError("The actor name must be specified to use `get_if_exists`.") + + if "object_store_memory" in options: + warnings.warn( + "Setting 'object_store_memory'" + " for actors is deprecated since it doesn't actually" + " reserve the required object store memory." + f" Use object spilling that's enabled by default (https://docs.ray.io/en/{get_ray_doc_version()}/ray-core/objects/object-spilling.html) " # noqa: E501 + "instead to bypass the object store memory size limitation.", + DeprecationWarning, + stacklevel=1, + ) + + _check_deprecate_placement_group(options) + + +def update_options( + original_options: Dict[str, Any], new_options: Dict[str, Any] +) -> Dict[str, Any]: + """Update original options with new options and return. + The returned updated options contain shallow copy of original options. + """ + + updated_options = {**original_options, **new_options} + # Ensure we update each namespace in "_metadata" independently. + # "_metadata" is a dict like {namespace1: config1, namespace2: config2} + if ( + original_options.get("_metadata") is not None + and new_options.get("_metadata") is not None + ): + # make a shallow copy to avoid messing up the metadata dict in + # the original options. 
+ metadata = original_options["_metadata"].copy() + for namespace, config in new_options["_metadata"].items(): + metadata[namespace] = {**metadata.get(namespace, {}), **config} + + updated_options["_metadata"] = metadata + + return updated_options diff --git a/.venv/lib/python3.11/site-packages/ray/_private/ray_perf.py b/.venv/lib/python3.11/site-packages/ray/_private/ray_perf.py new file mode 100644 index 0000000000000000000000000000000000000000..5001cd9d070851f8ae17dbf9bdef840d8090ff35 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/ray_perf.py @@ -0,0 +1,328 @@ +"""This is the script for `ray microbenchmark`.""" + +import asyncio +import logging +from ray._private.ray_microbenchmark_helpers import timeit +from ray._private.ray_client_microbenchmark import main as client_microbenchmark_main +import numpy as np +import multiprocessing +import ray + +logger = logging.getLogger(__name__) + + +@ray.remote(num_cpus=0) +class Actor: + def small_value(self): + return b"ok" + + def small_value_arg(self, x): + return b"ok" + + def small_value_batch(self, n): + ray.get([small_value.remote() for _ in range(n)]) + + +@ray.remote +class AsyncActor: + async def small_value(self): + return b"ok" + + async def small_value_with_arg(self, x): + return b"ok" + + async def small_value_batch(self, n): + await asyncio.wait([small_value.remote() for _ in range(n)]) + + +@ray.remote(num_cpus=0) +class Client: + def __init__(self, servers): + if not isinstance(servers, list): + servers = [servers] + self.servers = servers + + def small_value_batch(self, n): + results = [] + for s in self.servers: + results.extend([s.small_value.remote() for _ in range(n)]) + ray.get(results) + + def small_value_batch_arg(self, n): + x = ray.put(0) + results = [] + for s in self.servers: + results.extend([s.small_value_arg.remote(x) for _ in range(n)]) + ray.get(results) + + +@ray.remote +def small_value(): + return b"ok" + + +@ray.remote +def small_value_batch(n): + submitted = 
[small_value.remote() for _ in range(n)] + ray.get(submitted) + return 0 + + +@ray.remote +def create_object_containing_ref(): + obj_refs = [] + for _ in range(10000): + obj_refs.append(ray.put(1)) + return obj_refs + + +def check_optimized_build(): + if not ray._raylet.OPTIMIZED: + msg = ( + "WARNING: Unoptimized build! " + "To benchmark an optimized build, try:\n" + "\tbazel build -c opt //:ray_pkg\n" + "You can also make this permanent by adding\n" + "\tbuild --compilation_mode=opt\n" + "to your user-wide ~/.bazelrc file. " + "(Do not add this to the project-level .bazelrc file.)" + ) + logger.warning(msg) + + +def main(results=None): + results = results or [] + + check_optimized_build() + + print("Tip: set TESTS_TO_RUN='pattern' to run a subset of benchmarks") + + ray.init() + + value = ray.put(0) + + def get_small(): + ray.get(value) + + def put_small(): + ray.put(0) + + @ray.remote + def do_put_small(): + for _ in range(100): + ray.put(0) + + def put_multi_small(): + ray.get([do_put_small.remote() for _ in range(10)]) + + arr = np.zeros(100 * 1024 * 1024, dtype=np.int64) + + results += timeit("single client get calls (Plasma Store)", get_small) + + results += timeit("single client put calls (Plasma Store)", put_small) + + results += timeit("multi client put calls (Plasma Store)", put_multi_small, 1000) + + def put_large(): + ray.put(arr) + + results += timeit("single client put gigabytes", put_large, 8 * 0.1) + + def small_value_batch(): + submitted = [small_value.remote() for _ in range(1000)] + ray.get(submitted) + return 0 + + results += timeit("single client tasks and get batch", small_value_batch) + + @ray.remote + def do_put(): + for _ in range(10): + ray.put(np.zeros(10 * 1024 * 1024, dtype=np.int64)) + + def put_multi(): + ray.get([do_put.remote() for _ in range(10)]) + + results += timeit("multi client put gigabytes", put_multi, 10 * 8 * 0.1) + + obj_containing_ref = create_object_containing_ref.remote() + + def get_containing_object_ref(): + 
ray.get(obj_containing_ref) + + results += timeit( + "single client get object containing 10k refs", get_containing_object_ref + ) + + def wait_multiple_refs(): + num_objs = 1000 + not_ready = [small_value.remote() for _ in range(num_objs)] + # We only need to trigger the fetch_local once for each object, + # raylet will persist these fetch requests even after ray.wait returns. + # See https://github.com/ray-project/ray/issues/30375. + fetch_local = True + for _ in range(num_objs): + _ready, not_ready = ray.wait(not_ready, fetch_local=fetch_local) + if fetch_local: + fetch_local = False + + results += timeit("single client wait 1k refs", wait_multiple_refs) + + def small_task(): + ray.get(small_value.remote()) + + results += timeit("single client tasks sync", small_task) + + def small_task_async(): + ray.get([small_value.remote() for _ in range(1000)]) + + results += timeit("single client tasks async", small_task_async, 1000) + + n = 10000 + m = 4 + actors = [Actor.remote() for _ in range(m)] + + def multi_task(): + submitted = [a.small_value_batch.remote(n) for a in actors] + ray.get(submitted) + + results += timeit("multi client tasks async", multi_task, n * m) + + a = Actor.remote() + + def actor_sync(): + ray.get(a.small_value.remote()) + + results += timeit("1:1 actor calls sync", actor_sync) + + a = Actor.remote() + + def actor_async(): + ray.get([a.small_value.remote() for _ in range(1000)]) + + results += timeit("1:1 actor calls async", actor_async, 1000) + + a = Actor.options(max_concurrency=16).remote() + + def actor_concurrent(): + ray.get([a.small_value.remote() for _ in range(1000)]) + + results += timeit("1:1 actor calls concurrent", actor_concurrent, 1000) + + n = 5000 + n_cpu = multiprocessing.cpu_count() // 2 + actors = [Actor._remote() for _ in range(n_cpu)] + client = Client.remote(actors) + + def actor_async_direct(): + ray.get(client.small_value_batch.remote(n)) + + results += timeit("1:n actor calls async", actor_async_direct, n * len(actors)) 
+ + n_cpu = multiprocessing.cpu_count() // 2 + a = [Actor.remote() for _ in range(n_cpu)] + + @ray.remote + def work(actors): + ray.get([actors[i % n_cpu].small_value.remote() for i in range(n)]) + + def actor_multi2(): + ray.get([work.remote(a) for _ in range(m)]) + + results += timeit("n:n actor calls async", actor_multi2, m * n) + + n = 1000 + actors = [Actor._remote() for _ in range(n_cpu)] + clients = [Client.remote(a) for a in actors] + + def actor_multi2_direct_arg(): + ray.get([c.small_value_batch_arg.remote(n) for c in clients]) + + results += timeit( + "n:n actor calls with arg async", actor_multi2_direct_arg, n * len(clients) + ) + + a = AsyncActor.remote() + + def actor_sync(): + ray.get(a.small_value.remote()) + + results += timeit("1:1 async-actor calls sync", actor_sync) + + a = AsyncActor.remote() + + def async_actor(): + ray.get([a.small_value.remote() for _ in range(1000)]) + + results += timeit("1:1 async-actor calls async", async_actor, 1000) + + a = AsyncActor.remote() + + def async_actor(): + ray.get([a.small_value_with_arg.remote(i) for i in range(1000)]) + + results += timeit("1:1 async-actor calls with args async", async_actor, 1000) + + n = 5000 + n_cpu = multiprocessing.cpu_count() // 2 + actors = [AsyncActor.remote() for _ in range(n_cpu)] + client = Client.remote(actors) + + def async_actor_async(): + ray.get(client.small_value_batch.remote(n)) + + results += timeit("1:n async-actor calls async", async_actor_async, n * len(actors)) + + n = 5000 + m = 4 + n_cpu = multiprocessing.cpu_count() // 2 + a = [AsyncActor.remote() for _ in range(n_cpu)] + + @ray.remote + def async_actor_work(actors): + ray.get([actors[i % n_cpu].small_value.remote() for i in range(n)]) + + def async_actor_multi(): + ray.get([async_actor_work.remote(a) for _ in range(m)]) + + results += timeit("n:n async-actor calls async", async_actor_multi, m * n) + ray.shutdown() + + ############################ + # End of channel perf tests. 
+ ############################ + + NUM_PGS = 100 + NUM_BUNDLES = 1 + ray.init(resources={"custom": 100}) + + def placement_group_create_removal(num_pgs): + pgs = [ + ray.util.placement_group( + bundles=[{"custom": 0.001} for _ in range(NUM_BUNDLES)] + ) + for _ in range(num_pgs) + ] + [pg.wait(timeout_seconds=30) for pg in pgs] + # Include placement group removal here to clean up. + # If we don't clean up placement groups, the whole performance + # gets slower as it runs more. + # Since timeit function runs multiple times without + # the cleaning logic, we should have this method here. + for pg in pgs: + ray.util.remove_placement_group(pg) + + results += timeit( + "placement group create/removal", + lambda: placement_group_create_removal(NUM_PGS), + NUM_PGS, + ) + ray.shutdown() + + client_microbenchmark_main(results) + + return results + + +if __name__ == "__main__": + main() diff --git a/.venv/lib/python3.11/site-packages/ray/_private/ray_process_reaper.py b/.venv/lib/python3.11/site-packages/ray/_private/ray_process_reaper.py new file mode 100644 index 0000000000000000000000000000000000000000..7cfbb37fffeb4efb0fb2c43fc86a583e63fa9291 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/ray_process_reaper.py @@ -0,0 +1,60 @@ +import atexit +import os +import signal +import sys +import time + +""" +This is a lightweight "reaper" process used to ensure that ray processes are +cleaned up properly when the main ray process dies unexpectedly (e.g., +segfaults or gets SIGKILLed). Note that processes may not be cleaned up +properly if this process is SIGTERMed or SIGKILLed. + +It detects that its parent has died by reading from stdin, which must be +inherited from the parent process so that the OS will deliver an EOF if the +parent dies. When this happens, the reaper process kills the rest of its +process group (first attempting graceful shutdown with SIGTERM, then escalating +to SIGKILL). 
+""" + +SIGTERM_GRACE_PERIOD_SECONDS = 1 + + +def reap_process_group(*args): + def sigterm_handler(*args): + # Give a one-second grace period for other processes to clean up. + time.sleep(SIGTERM_GRACE_PERIOD_SECONDS) + # SIGKILL the pgroup (including ourselves) as a last-resort. + if sys.platform == "win32": + atexit.unregister(sigterm_handler) + os.kill(0, signal.CTRL_BREAK_EVENT) + else: + os.killpg(0, signal.SIGKILL) + + # Set a SIGTERM handler to handle SIGTERMing ourselves with the group. + if sys.platform == "win32": + atexit.register(sigterm_handler) + else: + signal.signal(signal.SIGTERM, sigterm_handler) + + # Our parent must have died, SIGTERM the group (including ourselves). + if sys.platform == "win32": + os.kill(0, signal.CTRL_C_EVENT) + else: + os.killpg(0, signal.SIGTERM) + + +def main(): + # Read from stdout forever. Because stdout is a file descriptor + # inherited from our parent process, we will get an EOF if the parent + # dies, which is signaled by an empty return from read(). + # We intentionally don't set any signal handlers here, so a SIGTERM from + # the parent can be used to kill this process gracefully without it killing + # the rest of the process group. + while len(sys.stdin.read()) != 0: + pass + reap_process_group() + + +if __name__ == "__main__": + main() diff --git a/.venv/lib/python3.11/site-packages/ray/_private/resource_spec.py b/.venv/lib/python3.11/site-packages/ray/_private/resource_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..7cf13721d20cc9b2de914615dbcb9d5234d9945e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/resource_spec.py @@ -0,0 +1,317 @@ +import logging +import sys +from collections import namedtuple +from typing import Optional + +import ray +import ray._private.ray_constants as ray_constants + + +logger = logging.getLogger(__name__) + +# Prefix for the node id resource that is automatically added to each node. 
+# For example, a node may have id `node:172.23.42.1`. +NODE_ID_PREFIX = "node:" +# The system resource that head node has. +HEAD_NODE_RESOURCE_NAME = NODE_ID_PREFIX + "__internal_head__" + + +class ResourceSpec( + namedtuple( + "ResourceSpec", + [ + "num_cpus", + "num_gpus", + "memory", + "object_store_memory", + "resources", + "redis_max_memory", + ], + ) +): + """Represents the resource configuration passed to a raylet. + + All fields can be None. Before starting services, resolve() should be + called to return a ResourceSpec with unknown values filled in with + defaults based on the local machine specifications. + + Attributes: + num_cpus: The CPUs allocated for this raylet. + num_gpus: The GPUs allocated for this raylet. + memory: The memory allocated for this raylet. + object_store_memory: The object store memory allocated for this raylet. + Note that when calling to_resource_dict(), this will be scaled down + by 30% to account for the global plasma LRU reserve. + resources: The custom resources allocated for this raylet. + redis_max_memory: The max amount of memory (in bytes) to allow each + redis shard to use. Once the limit is exceeded, redis will start + LRU eviction of entries. This only applies to the sharded redis + tables (task, object, and profile tables). By default, this is + capped at 10GB but can be set higher. + """ + + def __new__( + cls, + num_cpus=None, + num_gpus=None, + memory=None, + object_store_memory=None, + resources=None, + redis_max_memory=None, + ): + return super(ResourceSpec, cls).__new__( + cls, + num_cpus, + num_gpus, + memory, + object_store_memory, + resources, + redis_max_memory, + ) + + def resolved(self): + """Returns if this ResourceSpec has default values filled out.""" + for v in self._asdict().values(): + if v is None: + return False + return True + + def to_resource_dict(self): + """Returns a dict suitable to pass to raylet initialization. 
+ + This renames num_cpus / num_gpus to "CPU" / "GPU", + translates memory from bytes into 100MB memory units, and checks types. + """ + assert self.resolved() + + resources = dict( + self.resources, + CPU=self.num_cpus, + GPU=self.num_gpus, + memory=int(self.memory), + object_store_memory=int(self.object_store_memory), + ) + + resources = { + resource_label: resource_quantity + for resource_label, resource_quantity in resources.items() + if resource_quantity != 0 + } + + # Check types. + for resource_label, resource_quantity in resources.items(): + assert isinstance(resource_quantity, int) or isinstance( + resource_quantity, float + ), ( + f"{resource_label} ({type(resource_quantity)}): " f"{resource_quantity}" + ) + if ( + isinstance(resource_quantity, float) + and not resource_quantity.is_integer() + ): + raise ValueError( + "Resource quantities must all be whole numbers. " + "Violated by resource '{}' in {}.".format(resource_label, resources) + ) + if resource_quantity < 0: + raise ValueError( + "Resource quantities must be nonnegative. " + "Violated by resource '{}' in {}.".format(resource_label, resources) + ) + if resource_quantity > ray_constants.MAX_RESOURCE_QUANTITY: + raise ValueError( + "Resource quantities must be at most {}. " + "Violated by resource '{}' in {}.".format( + ray_constants.MAX_RESOURCE_QUANTITY, resource_label, resources + ) + ) + + return resources + + def resolve(self, is_head: bool, node_ip_address: Optional[str] = None): + """Returns a copy with values filled out with system defaults. + + Args: + is_head: Whether this is the head node. + node_ip_address: The IP address of the node that we are on. + This is used to automatically create a node id resource. 
+ """ + + resources = (self.resources or {}).copy() + assert "CPU" not in resources, resources + assert "GPU" not in resources, resources + assert "memory" not in resources, resources + assert "object_store_memory" not in resources, resources + + if node_ip_address is None: + node_ip_address = ray.util.get_node_ip_address() + + # Automatically create a node id resource on each node. This is + # queryable with ray._private.state.node_ids() and + # ray._private.state.current_node_id(). + resources[NODE_ID_PREFIX + node_ip_address] = 1.0 + + # Automatically create a head node resource. + if HEAD_NODE_RESOURCE_NAME in resources: + raise ValueError( + f"{HEAD_NODE_RESOURCE_NAME}" + " is a reserved resource name, use another name instead." + ) + if is_head: + resources[HEAD_NODE_RESOURCE_NAME] = 1.0 + + num_cpus = self.num_cpus + if num_cpus is None: + num_cpus = ray._private.utils.get_num_cpus() + + num_gpus = 0 + for ( + accelerator_resource_name + ) in ray._private.accelerators.get_all_accelerator_resource_names(): + accelerator_manager = ( + ray._private.accelerators.get_accelerator_manager_for_resource( + accelerator_resource_name + ) + ) + num_accelerators = None + if accelerator_resource_name == "GPU": + num_accelerators = self.num_gpus + else: + num_accelerators = resources.get(accelerator_resource_name, None) + visible_accelerator_ids = ( + accelerator_manager.get_current_process_visible_accelerator_ids() + ) + # Check that the number of accelerators that the raylet wants doesn't + # exceed the amount allowed by visible accelerator ids. + if ( + num_accelerators is not None + and visible_accelerator_ids is not None + and num_accelerators > len(visible_accelerator_ids) + ): + raise ValueError( + f"Attempting to start raylet with {num_accelerators} " + f"{accelerator_resource_name}, " + f"but {accelerator_manager.get_visible_accelerator_ids_env_var()} " + f"contains {visible_accelerator_ids}." 
+ ) + if num_accelerators is None: + # Try to automatically detect the number of accelerators. + num_accelerators = ( + accelerator_manager.get_current_node_num_accelerators() + ) + # Don't use more accelerators than allowed by visible accelerator ids. + if visible_accelerator_ids is not None: + num_accelerators = min( + num_accelerators, len(visible_accelerator_ids) + ) + + if num_accelerators: + if accelerator_resource_name == "GPU": + num_gpus = num_accelerators + else: + resources[accelerator_resource_name] = num_accelerators + + accelerator_type = ( + accelerator_manager.get_current_node_accelerator_type() + ) + if accelerator_type: + resources[ + f"{ray_constants.RESOURCE_CONSTRAINT_PREFIX}{accelerator_type}" + ] = 1 + + from ray._private.usage import usage_lib + + usage_lib.record_hardware_usage(accelerator_type) + additional_resources = ( + accelerator_manager.get_current_node_additional_resources() + ) + if additional_resources: + resources.update(additional_resources) + # Choose a default object store size. + system_memory = ray._private.utils.get_system_memory() + avail_memory = ray._private.utils.estimate_available_memory() + object_store_memory = self.object_store_memory + if object_store_memory is None: + object_store_memory = int( + avail_memory * ray_constants.DEFAULT_OBJECT_STORE_MEMORY_PROPORTION + ) + + # Set the object_store_memory size to 2GB on Mac + # to avoid degraded performance. + # (https://github.com/ray-project/ray/issues/20388) + if sys.platform == "darwin": + object_store_memory = min( + object_store_memory, ray_constants.MAC_DEGRADED_PERF_MMAP_SIZE_LIMIT + ) + + object_store_memory_cap = ( + ray_constants.DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES + ) + + # Cap by shm size by default to avoid low performance, but don't + # go lower than REQUIRE_SHM_SIZE_THRESHOLD. + if sys.platform == "linux" or sys.platform == "linux2": + # Multiple by 0.95 to give a bit of wiggle-room. 
+ # https://github.com/ray-project/ray/pull/23034/files + shm_avail = ray._private.utils.get_shared_memory_bytes() * 0.95 + shm_cap = max(ray_constants.REQUIRE_SHM_SIZE_THRESHOLD, shm_avail) + + object_store_memory_cap = min(object_store_memory_cap, shm_cap) + + # Cap memory to avoid memory waste and perf issues on large nodes + if ( + object_store_memory_cap + and object_store_memory > object_store_memory_cap + ): + logger.debug( + "Warning: Capping object memory store to {}GB. ".format( + object_store_memory_cap // 1e9 + ) + + "To increase this further, specify `object_store_memory` " + "when calling ray.init() or ray start." + ) + object_store_memory = object_store_memory_cap + + redis_max_memory = self.redis_max_memory + if redis_max_memory is None: + redis_max_memory = min( + ray_constants.DEFAULT_REDIS_MAX_MEMORY_BYTES, + max(int(avail_memory * 0.1), ray_constants.REDIS_MINIMUM_MEMORY_BYTES), + ) + if redis_max_memory < ray_constants.REDIS_MINIMUM_MEMORY_BYTES: + raise ValueError( + "Attempting to cap Redis memory usage at {} bytes, " + "but the minimum allowed is {} bytes.".format( + redis_max_memory, ray_constants.REDIS_MINIMUM_MEMORY_BYTES + ) + ) + + memory = self.memory + if memory is None: + memory = ( + avail_memory + - object_store_memory + - (redis_max_memory if is_head else 0) + ) + if memory < 100e6 and memory < 0.05 * system_memory: + raise ValueError( + "After taking into account object store and redis memory " + "usage, the amount of memory on this node available for " + "tasks and actors ({} GB) is less than {}% of total. 
" + "You can adjust these settings with " + "ray.init(memory=, " + "object_store_memory=).".format( + round(memory / 1e9, 2), int(100 * (memory / system_memory)) + ) + ) + + spec = ResourceSpec( + num_cpus, + num_gpus, + memory, + object_store_memory, + resources, + redis_max_memory, + ) + assert spec.resolved() + return spec diff --git a/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/__init__.py b/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d8d40c33d67ca8374373a3258a0daa4445480efb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/__init__.py @@ -0,0 +1,3 @@ +# List of files to exclude from the Ray directory when using runtime_env for +# Ray development. These are not necessary in the Ray workers. +RAY_WORKER_DEV_EXCLUDES = ["raylet", "gcs_server", "cpp/", "tests/", "core/src"] diff --git a/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/_clonevirtualenv.py b/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/_clonevirtualenv.py new file mode 100644 index 0000000000000000000000000000000000000000..1f2eab3d1040352a0a673323838004f2f4cc7f2c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/_clonevirtualenv.py @@ -0,0 +1,334 @@ +#!/usr/bin/env python + +from __future__ import with_statement + +import logging +import optparse +import os +import os.path +import re +import shutil +import subprocess +import sys +import itertools + +__version__ = "0.5.7" + + +logger = logging.getLogger() + + +env_bin_dir = "bin" +if sys.platform == "win32": + env_bin_dir = "Scripts" + _WIN32 = True +else: + _WIN32 = False + + +class UserError(Exception): + pass + + +def _dirmatch(path, matchwith): + """Check if path is within matchwith's tree. 
+ >>> _dirmatch('/home/foo/bar', '/home/foo/bar') + True + >>> _dirmatch('/home/foo/bar/', '/home/foo/bar') + True + >>> _dirmatch('/home/foo/bar/etc', '/home/foo/bar') + True + >>> _dirmatch('/home/foo/bar2', '/home/foo/bar') + False + >>> _dirmatch('/home/foo/bar2/etc', '/home/foo/bar') + False + """ + matchlen = len(matchwith) + if path.startswith(matchwith) and path[matchlen : matchlen + 1] in [os.sep, ""]: + return True + return False + + +def _virtualenv_sys(venv_path): + """obtain version and path info from a virtualenv.""" + executable = os.path.join(venv_path, env_bin_dir, "python") + if _WIN32: + env = os.environ.copy() + else: + env = {} + # Must use "executable" as the first argument rather than as the + # keyword argument "executable" to get correct value from sys.path + p = subprocess.Popen( + [ + executable, + "-c", + "import sys;" + 'print ("%d.%d" % (sys.version_info.major, sys.version_info.minor));' + 'print ("\\n".join(sys.path));', + ], + env=env, + stdout=subprocess.PIPE, + ) + stdout, err = p.communicate() + assert not p.returncode and stdout + lines = stdout.decode("utf-8").splitlines() + return lines[0], list(filter(bool, lines[1:])) + + +def clone_virtualenv(src_dir, dst_dir): + if not os.path.exists(src_dir): + raise UserError("src dir %r does not exist" % src_dir) + if os.path.exists(dst_dir): + raise UserError("dest dir %r exists" % dst_dir) + # sys_path = _virtualenv_syspath(src_dir) + logger.info("cloning virtualenv '%s' => '%s'..." % (src_dir, dst_dir)) + shutil.copytree( + src_dir, dst_dir, symlinks=True, ignore=shutil.ignore_patterns("*.pyc") + ) + version, sys_path = _virtualenv_sys(dst_dir) + logger.info("fixing scripts in bin...") + fixup_scripts(src_dir, dst_dir, version) + + has_old = lambda s: any(i for i in s if _dirmatch(i, src_dir)) # noqa: E731 + + if has_old(sys_path): + # only need to fix stuff in sys.path if we have old + # paths in the sys.path of new python env. right? 
def fix_symlink_if_necessary(src_dir, dst_dir):
    """Repoint symlinks inside the cloned venv that still target the source venv.

    Sometimes the source virtual environment has symlinks that point to
    itself; one example is $OLD_VIRTUAL_ENV/local/lib pointing to
    $OLD_VIRTUAL_ENV/lib. This function makes sure
    $NEW_VIRTUAL_ENV/local/lib will point to $NEW_VIRTUAL_ENV/lib.
    Usually this goes unnoticed unless one tries to upgrade a package
    through pip, so this bug is hard to find.
    """
    logger.info("scanning for internal symlinks that point to the original virtual env")
    for dirpath, dirnames, filenames in os.walk(dst_dir):
        # Both files and directories can be symlinks, so inspect every entry.
        for a_file in itertools.chain(filenames, dirnames):
            full_file_path = os.path.join(dirpath, a_file)
            if os.path.islink(full_file_path):
                # realpath resolves the link chain to its final target.
                target = os.path.realpath(full_file_path)
                if target.startswith(src_dir):
                    new_target = target.replace(src_dir, dst_dir)
                    logger.debug("fixing symlink in %s" % (full_file_path,))
                    # os.symlink cannot overwrite an existing link, so
                    # remove the old one first.
                    os.remove(full_file_path)
                    os.symlink(new_target, full_file_path)
root, + file_, + old_dir, + new_dir, + version, + rewrite_env_python=rewrite_env_python, + ) + + +def fixup_script_(root, file_, old_dir, new_dir, version, rewrite_env_python=False): + old_shebang = "#!%s/bin/python" % os.path.normcase(os.path.abspath(old_dir)) + new_shebang = "#!%s/bin/python" % os.path.normcase(os.path.abspath(new_dir)) + env_shebang = "#!/usr/bin/env python" + + filename = os.path.join(root, file_) + with open(filename, "rb") as f: + if f.read(2) != b"#!": + # no shebang + return + f.seek(0) + lines = f.readlines() + + if not lines: + # warn: empty script + return + + def rewrite_shebang(version=None): + logger.debug("fixing %s" % filename) + shebang = new_shebang + if version: + shebang = shebang + version + shebang = (shebang + "\n").encode("utf-8") + with open(filename, "wb") as f: + f.write(shebang) + f.writelines(lines[1:]) + + try: + bang = lines[0].decode("utf-8").strip() + except UnicodeDecodeError: + # binary file + return + + # This takes care of the scheme in which shebang is of type + # '#!/venv/bin/python3' while the version of system python + # is of type 3.x e.g. 3.5. 
def fixup_activate(filename, old_dir, new_dir):
    """Rewrite every occurrence of old_dir in an activate script to new_dir."""
    logger.debug("fixing %s" % filename)
    with open(filename, "rb") as fh:
        contents = fh.read().decode("utf-8")
    updated = contents.replace(old_dir, new_dir).encode("utf-8")
    with open(filename, "wb") as fh:
        fh.write(updated)
+ _replace_symlink(filename, target) + + +def _replace_symlink(filename, newtarget): + tmpfn = "%s.new" % filename + os.symlink(newtarget, tmpfn) + os.rename(tmpfn, filename) + + +def fixup_syspath_items(syspath, old_dir, new_dir): + for path in syspath: + if not os.path.isdir(path): + continue + path = os.path.normcase(os.path.abspath(path)) + if _dirmatch(path, old_dir): + path = path.replace(old_dir, new_dir, 1) + if not os.path.exists(path): + continue + elif not _dirmatch(path, new_dir): + continue + root, dirs, files = next(os.walk(path)) + for file_ in files: + filename = os.path.join(root, file_) + if filename.endswith(".pth"): + fixup_pth_file(filename, old_dir, new_dir) + elif filename.endswith(".egg-link"): + fixup_egglink_file(filename, old_dir, new_dir) + + +def fixup_pth_file(filename, old_dir, new_dir): + logger.debug("fixup_pth_file %s" % filename) + + with open(filename, "r") as f: + lines = f.readlines() + + has_change = False + + for num, line in enumerate(lines): + line = (line.decode("utf-8") if hasattr(line, "decode") else line).strip() + + if not line or line.startswith("#") or line.startswith("import "): + continue + elif _dirmatch(line, old_dir): + lines[num] = line.replace(old_dir, new_dir, 1) + has_change = True + + if has_change: + with open(filename, "w") as f: + payload = os.linesep.join([line.strip() for line in lines]) + os.linesep + f.write(payload) + + +def fixup_egglink_file(filename, old_dir, new_dir): + logger.debug("fixing %s" % filename) + with open(filename, "rb") as f: + link = f.read().decode("utf-8").strip() + if _dirmatch(link, old_dir): + link = link.replace(old_dir, new_dir, 1) + with open(filename, "wb") as f: + link = (link + "\n").encode("utf-8") + f.write(link) + + +def main(): + parser = optparse.OptionParser( + "usage: %prog [options] /path/to/existing/venv /path/to/cloned/venv" + ) + parser.add_option( + "-v", action="count", dest="verbose", default=False, help="verbosity" + ) + options, args = 
parser.parse_args() + try: + old_dir, new_dir = args + except ValueError: + print("virtualenv-clone %s" % (__version__,)) + parser.error("not enough arguments given.") + old_dir = os.path.realpath(old_dir) + new_dir = os.path.realpath(new_dir) + loglevel = (logging.WARNING, logging.INFO, logging.DEBUG)[min(2, options.verbose)] + logging.basicConfig(level=loglevel, format="%(message)s") + try: + clone_virtualenv(old_dir, new_dir) + except UserError: + e = sys.exc_info()[1] + parser.error(str(e)) + + +if __name__ == "__main__": + main() diff --git a/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/conda.py b/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/conda.py new file mode 100644 index 0000000000000000000000000000000000000000..12022b14c858464b97297808a0024b2e1cd3cae6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/conda.py @@ -0,0 +1,407 @@ +import hashlib +import json +import logging +import os +import platform +import runpy +import shutil +import subprocess +import sys +from pathlib import Path +from typing import Any, Dict, List, Optional + +import yaml +from filelock import FileLock + +import ray +from ray._private.runtime_env.conda_utils import ( + create_conda_env_if_needed, + delete_conda_env, + get_conda_activate_commands, + get_conda_info_json, + get_conda_envs, +) +from ray._private.runtime_env.context import RuntimeEnvContext +from ray._private.runtime_env.packaging import Protocol, parse_uri +from ray._private.runtime_env.plugin import RuntimeEnvPlugin +from ray._private.runtime_env.validation import parse_and_validate_conda +from ray._private.utils import ( + get_directory_size_bytes, + get_master_wheel_url, + get_or_create_event_loop, + get_release_wheel_url, + get_wheel_filename, + try_to_create_directory, +) + +default_logger = logging.getLogger(__name__) + +_WIN32 = os.name == "nt" + + +def _resolve_current_ray_path() -> str: + # When ray is built from source with pip install -e, + # 
def _resolve_install_from_source_ray_dependencies():
    """Find the Ray dependencies when Ray is installed from source.

    Returns:
        A de-duplicated list of pip requirement strings: Ray's
        `install_requires` plus the "default" extras.
    """
    # Load the setup spec exactly once: _get_ray_setup_spec() executes
    # setup.py via runpy on every call, which is expensive (the original
    # called it twice to build this list).
    setup_spec = _get_ray_setup_spec()
    deps = setup_spec.install_requires + setup_spec.extras["default"]
    # Remove duplicates while preserving first-seen order, so the resulting
    # requirement list is deterministic (list(set(...)) was not).
    return list(dict.fromkeys(deps))
+ ) + + maybe_ray_dir = os.path.join(site_packages_path, "ray") + if os.path.isdir(maybe_ray_dir): + logger.warning(f"Replacing existing ray installation with {ray_path}") + shutil.rmtree(maybe_ray_dir) + + # See usage of *.pth file at + # https://docs.python.org/3/library/site.html + with open(os.path.join(site_packages_path, "ray_shared.pth"), "w") as f: + f.write(ray_path) + + +def _current_py_version(): + return ".".join(map(str, sys.version_info[:3])) # like 3.6.10 + + +def _is_m1_mac(): + return sys.platform == "darwin" and platform.machine() == "arm64" + + +def current_ray_pip_specifier( + logger: Optional[logging.Logger] = default_logger, +) -> Optional[str]: + """The pip requirement specifier for the running version of Ray. + + Returns: + A string which can be passed to `pip install` to install the + currently running Ray version, or None if running on a version + built from source locally (likely if you are developing Ray). + + Examples: + Returns "https://s3-us-west-2.amazonaws.com/ray-wheels/[..].whl" + if running a stable release, a nightly or a specific commit + """ + if os.environ.get("RAY_CI_POST_WHEEL_TESTS"): + # Running in Buildkite CI after the wheel has been built. + # Wheels are at in the ray/.whl directory, but use relative path to + # allow for testing locally if needed. + return os.path.join( + Path(ray.__file__).resolve().parents[2], ".whl", get_wheel_filename() + ) + elif ray.__commit__ == "{{RAY_COMMIT_SHA}}": + # Running on a version built from source locally. + if os.environ.get("RAY_RUNTIME_ENV_LOCAL_DEV_MODE") != "1": + logger.warning( + "Current Ray version could not be detected, most likely " + "because you have manually built Ray from source. To use " + "runtime_env in this case, set the environment variable " + "RAY_RUNTIME_ENV_LOCAL_DEV_MODE=1." + ) + return None + elif "dev" in ray.__version__: + # Running on a nightly wheel. 
def inject_dependencies(
    conda_dict: Dict[Any, Any],
    py_version: str,
    pip_dependencies: Optional[List[str]] = None,
) -> Dict[Any, Any]:
    """Add Ray, Python and (optionally) extra pip dependencies to a conda dict.

    Args:
        conda_dict: A dict representing the JSON-serialized conda
            environment YAML file. This dict will be modified and returned.
        py_version: A string representing a Python version to inject
            into the conda dependencies, e.g. "3.7.7"
        pip_dependencies (List[str]): A list of pip dependencies that
            will be prepended to the list of pip dependencies in
            the conda dict. If the conda dict does not already have a "pip"
            field, one will be created.
    Returns:
        The modified dict. (Note: the input argument conda_dict is modified
        and returned.)
    """
    if pip_dependencies is None:
        pip_dependencies = []
    if conda_dict.get("dependencies") is None:
        conda_dict["dependencies"] = []
    deps = conda_dict["dependencies"]

    # Pin the current interpreter version. If the user already pinned an
    # incompatible python, conda raises a readable conflict error, e.g.:
    #   ResolvePackageNotFound: - python[version='3.5.*,>=3.6']
    deps.append(f"python={py_version}")

    if "pip" not in deps:
        deps.append("pip")

    # Prepend the extra pip packages to the first {"pip": [...]} entry, or
    # create such an entry when none exists.
    pip_entry = next(
        (
            dep
            for dep in deps
            if isinstance(dep, dict) and dep.get("pip") and isinstance(dep["pip"], list)
        ),
        None,
    )
    if pip_entry is None:
        deps.append({"pip": pip_dependencies})
    else:
        pip_entry["pip"] = pip_dependencies + pip_entry["pip"]

    return conda_dict
def _get_conda_dict_with_ray_inserted(
    runtime_env: "RuntimeEnv",  # noqa: F821
    logger: Optional[logging.Logger] = default_logger,
) -> Dict[str, Any]:
    """Returns the conda spec with the Ray and `python` dependency inserted.

    Args:
        runtime_env: The runtime env whose serialized conda config is used
            as the base environment spec.
        logger: Logger passed through to current_ray_pip_specifier().

    Returns:
        The conda spec dict with the interpreter version pinned and the
        appropriate Ray pip packages injected.
    """
    conda_dict = json.loads(runtime_env.conda_config())
    assert conda_dict is not None

    # Prefer a pip-installable Ray matching the running version; otherwise,
    # when the _inject_current_ray extension is set (development mode), use
    # the from-source dependency list; else add no extra pip packages.
    ray_pip = current_ray_pip_specifier(logger=logger)
    if ray_pip:
        extra_pip_dependencies = [ray_pip, "ray[default]"]
    elif runtime_env.get_extension("_inject_current_ray"):
        extra_pip_dependencies = _resolve_install_from_source_ray_dependencies()
    else:
        extra_pip_dependencies = []
    # inject_dependencies also pins the current python version in the spec.
    conda_dict = inject_dependencies(
        conda_dict, _current_py_version(), extra_pip_dependencies
    )
    return conda_dict
    def delete_uri(
        self, uri: str, logger: Optional[logging.Logger] = default_logger
    ) -> int:
        """Delete URI and return the number of bytes deleted.

        Args:
            uri: A "conda://<hash>" URI for an environment this plugin created.
            logger: Logger for progress and error messages.

        Returns:
            The size in bytes of the deleted conda environment directory,
            or 0 if the deletion failed.

        Raises:
            ValueError: If the URI's protocol is not "conda".
        """
        logger.info(f"Got request to delete URI {uri}")
        protocol, hash = parse_uri(uri)
        if protocol != Protocol.CONDA:
            raise ValueError(
                "CondaPlugin can only delete URIs with protocol "
                f"conda. Received protocol {protocol}, URI {uri}"
            )

        conda_env_path = self._get_path_from_hash(hash)
        # Measure before deleting so we can report how much space was freed.
        local_dir_size = get_directory_size_bytes(conda_env_path)

        # Conda installs/removals are not safe to run concurrently, so all of
        # them are serialized behind one global file lock.
        with FileLock(self._installs_and_deletions_file_lock):
            successful = delete_conda_env(prefix=conda_env_path, logger=logger)
        if not successful:
            logger.warning(f"Error when deleting conda env {conda_env_path}. ")
            return 0

        return local_dir_size
    def modify_context(
        self,
        uris: List[str],
        runtime_env: "RuntimeEnv",  # noqa: F821
        context: RuntimeEnvContext,
        logger: Optional[logging.Logger] = default_logger,
    ):
        """Modify the worker startup context to activate the conda env.

        No-op when the runtime env has no conda field. Otherwise the context
        is pointed at a plain "python" executable (resolved inside the
        activated env) and the `conda activate` commands are appended to the
        worker command prefix.
        """
        if not runtime_env.has_conda():
            return

        if runtime_env.conda_env_name():
            # User supplied a pre-existing, named conda environment.
            conda_env_name = runtime_env.conda_env_name()
        else:
            # Environment created by this plugin: its name is the hashed
            # install path derived from the conda URI.
            protocol, hash = parse_uri(runtime_env.conda_uri())
            conda_env_name = self._get_path_from_hash(hash)
        context.py_executable = "python"
        context.command_prefix += get_conda_activate_commands(conda_env_name)
def get_conda_activate_commands(conda_env_name: str) -> List[str]:
    """
    Get a list of commands to run to silently activate the given conda env.

    The returned tokens are meant to be joined into a shell command line;
    they end with "1>&2 &&" so that activation output goes to stderr and the
    actual worker command can be appended after a successful activation.
    """
    # Checking for newer conda versions: when CONDA_EXE (set by conda's own
    # activation) or RAY_CONDA_HOME is present, source conda.sh relative to
    # the conda binary and use the `conda activate` subcommand.
    if not _WIN32 and ("CONDA_EXE" in os.environ or RAY_CONDA_HOME in os.environ):
        conda_path = get_conda_bin_executable("conda")
        activate_conda_env = [
            ".",
            f"{os.path.dirname(conda_path)}/../etc/profile.d/conda.sh",
            "&&",
        ]
        activate_conda_env += ["conda", "activate", conda_env_name]

    else:
        # Older conda: use the `activate` script directly on POSIX, or the
        # `conda activate` builtin on Windows.
        activate_path = get_conda_bin_executable("activate")
        if not _WIN32:
            # Use bash command syntax
            activate_conda_env = ["source", activate_path, conda_env_name]
        else:
            activate_conda_env = ["conda", "activate", conda_env_name]
    return activate_conda_env + ["1>&2", "&&"]
If neither is specified, this method + returns `executable_name`. + """ + conda_home = os.environ.get(RAY_CONDA_HOME) + if conda_home: + if _WIN32: + candidate = os.path.join(conda_home, "%s.exe" % executable_name) + if os.path.exists(candidate): + return candidate + candidate = os.path.join(conda_home, "%s.bat" % executable_name) + if os.path.exists(candidate): + return candidate + else: + return os.path.join(conda_home, "bin/%s" % executable_name) + else: + conda_home = "." + # Use CONDA_EXE as per https://github.com/conda/conda/issues/7126 + if "CONDA_EXE" in os.environ: + conda_bin_dir = os.path.dirname(os.environ["CONDA_EXE"]) + if _WIN32: + candidate = os.path.join(conda_home, "%s.exe" % executable_name) + if os.path.exists(candidate): + return candidate + candidate = os.path.join(conda_home, "%s.bat" % executable_name) + if os.path.exists(candidate): + return candidate + else: + return os.path.join(conda_bin_dir, executable_name) + if _WIN32: + return executable_name + ".bat" + return executable_name + + +def _get_conda_env_name(conda_env_path: str) -> str: + conda_env_contents = open(conda_env_path).read() + return "ray-%s" % hashlib.sha1(conda_env_contents.encode("utf-8")).hexdigest() + + +def create_conda_env_if_needed( + conda_yaml_file: str, prefix: str, logger: Optional[logging.Logger] = None +) -> None: + """ + Given a conda YAML, creates a conda environment containing the required + dependencies if such a conda environment doesn't already exist. + Args: + conda_yaml_file: The path to a conda `environment.yml` file. + prefix: Directory to install the environment into via + the `--prefix` option to conda create. This also becomes the name + of the conda env; i.e. 
def delete_conda_env(prefix: str, logger: Optional[logging.Logger] = None) -> bool:
    """Delete the conda environment installed at *prefix*.

    Args:
        prefix: Absolute path of the environment (the `--prefix` it was
            created with).
        logger: Optional logger; defaults to this module's logger.

    Returns:
        True if `conda remove` exited with code 0, False otherwise.
    """
    if logger is None:
        logger = logging.getLogger(__name__)

    logger.info(f"Deleting conda environment {prefix}")

    conda_path = get_conda_bin_executable("conda")
    # -p targets the env by path, --all removes every package in it,
    # -y skips interactive confirmation.
    delete_cmd = [conda_path, "remove", "-p", prefix, "--all", "-y"]
    exit_code, output = exec_cmd_stream_to_logger(delete_cmd, logger)

    if exit_code != 0:
        logger.debug(f"Failed to delete conda environment {prefix}:\n{output}")
        return False

    return True
def get_conda_envs(conda_info: dict) -> List[Tuple[str, str]]:
    """
    Gets the conda environments, as a list of (name, path) tuples.

    The root installation (whose path equals `conda_prefix`) is reported
    under the name "base"; every other env is named after its directory.
    """
    base_prefix = conda_info["conda_prefix"]
    return [
        ("base" if env_path == base_prefix else os.path.basename(env_path), env_path)
        for env_path in conda_info["envs"]
    ]
def exec_cmd_stream_to_logger(
    cmd: List[str], logger: logging.Logger, n_lines: int = 50, **kwargs
) -> Tuple[int, str]:
    """Runs a command as a child process, streaming output to the logger.

    Args:
        cmd: The command to run, as a list of strings.
        logger: Each non-empty output line is logged here at INFO level.
        n_lines: How many trailing output lines to keep and return.
        **kwargs: Extra keyword arguments forwarded to subprocess.Popen.

    Returns:
        A tuple of the child's exit code and the last n_lines lines of its
        combined stdout/stderr, joined with newlines.
    """
    # On Windows, Popen fails cryptically when a custom environment lacks
    # PATH, so inspect the *env mapping's* keys. (The original mistakenly
    # iterated `kwargs.keys` -- the unbound method -- which raised TypeError.)
    if (
        "env" in kwargs
        and _WIN32
        and "PATH" not in [x.upper() for x in kwargs["env"].keys()]
    ):
        raise ValueError("On windows, Popen requires 'PATH' in 'env'")
    child = subprocess.Popen(
        cmd,
        universal_newlines=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        **kwargs,
    )
    last_n_lines = []
    with child.stdout:
        # The stream is in text mode (universal_newlines=True), so EOF is ""
        # -- not b"" as the original sentinel had it. With the wrong sentinel
        # the loop relied on poll() and could drop output buffered after the
        # child exited; reading to EOF drains everything.
        for line in iter(child.stdout.readline, ""):
            line = line.strip()
            if not line:
                continue
            last_n_lines.append(line)
            last_n_lines = last_n_lines[-n_lines:]
            logger.info(line)

    exit_code = child.wait()
    return exit_code, "\n".join(last_n_lines)
@DeveloperAPI
class RuntimeEnvContext:
    """A context used to describe the created runtime env.

    Holds everything needed to launch a worker inside the prepared
    environment: a launch-command prefix, extra environment variables, the
    Python executable (which plugins may replace with a container/profiler
    launcher string), an optional replacement worker entrypoint, and Java
    jar paths. Instances round-trip through JSON via serialize/deserialize.
    """

    def __init__(
        self,
        command_prefix: List[str] = None,
        env_vars: Dict[str, str] = None,
        py_executable: Optional[str] = None,
        override_worker_entrypoint: Optional[str] = None,
        java_jars: List[str] = None,
    ):
        self.command_prefix = command_prefix or []
        self.env_vars = env_vars or {}
        # Falls back to the current interpreter when no executable is given.
        self.py_executable = py_executable or sys.executable
        self.override_worker_entrypoint: Optional[str] = override_worker_entrypoint
        self.java_jars = java_jars or []

    def serialize(self) -> str:
        # Safe because every attribute is a plain JSON-serializable type
        # (list/dict/str/None), as established in __init__ above.
        return json.dumps(self.__dict__)

    @staticmethod
    def deserialize(json_string):
        # Inverse of serialize(): the attribute names match the __init__
        # parameter names, so the decoded dict can be splatted directly.
        return RuntimeEnvContext(**json.loads(json_string))

    def exec_worker(self, passthrough_args: List[str], language: Language):
        """Replace the current process with the worker process.

        Applies env_vars to os.environ, builds the launch command from
        command_prefix + a language-specific executable + passthrough_args,
        then runs it via Popen (Windows) or `os.execvp("bash", ...)`
        (POSIX). On POSIX this call does not return.

        Args:
            passthrough_args: Worker argv; element 0 is the worker
                entrypoint and may be replaced by
                `override_worker_entrypoint`.
            language: The worker language (Language.PYTHON / Language.JAVA).
        """
        update_envs(self.env_vars)

        if language == Language.PYTHON and sys.platform == "win32":
            executable = [self.py_executable]
        elif language == Language.PYTHON:
            # `exec` makes the spawned bash replace itself with the worker,
            # so no intermediate shell process lingers.
            executable = ["exec", self.py_executable]
        elif language == Language.JAVA:
            executable = ["java"]
            ray_jars = os.path.join(get_ray_jars_dir(), "*")

            local_java_jars = []
            for java_jar in self.java_jars:
                # Add both the jar directory's contents (dir/*) and the
                # path itself to the classpath.
                local_java_jars.append(f"{java_jar}/*")
                local_java_jars.append(java_jar)

            class_path_args = ["-cp", ray_jars + ":" + str(":".join(local_java_jars))]
            passthrough_args = class_path_args + passthrough_args
        elif sys.platform == "win32":
            executable = []
        else:
            executable = ["exec"]

        # By default, raylet uses the path to default_worker.py on host.
        # However, the path to default_worker.py inside the container
        # can be different. We need the user to specify the path to
        # default_worker.py inside the container.
        if self.override_worker_entrypoint:
            logger.debug(
                f"Changing the worker entrypoint from {passthrough_args[0]} to "
                f"{self.override_worker_entrypoint}."
            )
            passthrough_args[0] = self.override_worker_entrypoint

        if sys.platform == "win32":

            def quote(s):
                # NOTE(review): only `&` is escaped here; other cmd.exe
                # metacharacters pass through unmodified — presumably
                # sufficient for the args Ray generates. TODO confirm.
                s = s.replace("&", "%26")
                return s

            passthrough_args = [quote(s) for s in passthrough_args]

            cmd = [*self.command_prefix, *executable, *passthrough_args]
            logger.debug(f"Exec'ing worker with command: {cmd}")
            subprocess.Popen(cmd, shell=True).wait()
        else:
            # We use shlex to do the necessary shell escape
            # of special characters in passthrough_args.
            passthrough_args = [shlex.quote(s) for s in passthrough_args]
            cmd = [*self.command_prefix, *executable, *passthrough_args]
            # TODO(SongGuyang): We add this env to command for macOS because it doesn't
            # work for the C++ process of `os.execvp`. We should find a better way to
            # fix it.
            MACOS_LIBRARY_PATH_ENV_NAME = "DYLD_LIBRARY_PATH"
            if MACOS_LIBRARY_PATH_ENV_NAME in os.environ:
                cmd.insert(
                    0,
                    f"{MACOS_LIBRARY_PATH_ENV_NAME}="
                    f"{os.environ[MACOS_LIBRARY_PATH_ENV_NAME]}",
                )
            logger.debug(f"Exec'ing worker with command: {cmd}")
            # PyCharm will monkey patch the os.execvp at
            # .pycharm_helpers/pydev/_pydev_bundle/pydev_monkey.py
            # The monkey patched os.execvp function has a different
            # signature. So, we use os.execvp("executable", args=[])
            # instead of os.execvp(file="executable", args=[])
            os.execvp("bash", args=["bash", "-c", " ".join(cmd)])
def get_requirements_file(target_dir: str, pip_list: Optional[List[str]]) -> str:
    """Returns the path to the requirements file to use for this runtime env.

    If pip_list is not None, we will check if the internal pip filename is in
    any of the entries of pip_list. If so, we will append numbers to the end
    of the filename until we find one that doesn't conflict. This prevents
    infinite recursion if the user specifies the internal pip filename in
    their pip list.

    Args:
        target_dir: The directory to store the requirements file in.
        pip_list: A list of pip requirements specified by the user.

    Returns:
        The path to the requirements file to use for this runtime env.

    Raises:
        RuntimeError: If no conflict-free filename is found within
            MAX_INTERNAL_PIP_FILENAME_TRIES attempts.
    """

    def filename_in_pip_list(filename: str) -> bool:
        # Substring match: entries like "-r <file>" embed the filename.
        return any(filename in pip_entry for pip_entry in pip_list)

    filename = INTERNAL_PIP_FILENAME
    if pip_list is not None:
        i = 1
        while filename_in_pip_list(filename) and i < MAX_INTERNAL_PIP_FILENAME_TRIES:
            filename = f"{INTERNAL_PIP_FILENAME}.{i}"
            i += 1
        # BUG FIX: the old check (`i == MAX_INTERNAL_PIP_FILENAME_TRIES`)
        # raised even when the final candidate was conflict-free, because
        # the loop can exit with i == MAX after renaming to a valid name.
        # Only raise if the chosen filename still conflicts.
        if filename_in_pip_list(filename):
            raise RuntimeError(
                "Could not find a valid filename for the internal "
                "pip requirements file. Please specify a different "
                "pip list in your runtime env."
            )
    return os.path.join(target_dir, filename)
def _modify_context_impl(
    image_uri: str,
    worker_path: str,
    run_options: Optional[List[str]],
    context: RuntimeEnvContext,
    logger: logging.Logger,
    ray_tmp_dir: str,
):
    """Mutate `context` so the worker is launched inside a podman container.

    Args:
        image_uri: The container image to run the worker in.
        worker_path: Path to default_worker.py *inside* the image; stored as
            the worker entrypoint override on the context.
        run_options: Extra arguments appended to the `podman run` command.
        context: The runtime env context, modified in place
            (override_worker_entrypoint and py_executable).
        logger: Logger for progress messages.
        ray_tmp_dir: Host ray temp dir, bind-mounted into the container at
            the same path.
    """
    context.override_worker_entrypoint = worker_path

    container_driver = "podman"
    container_command = [
        container_driver,
        "run",
        "-v",
        ray_tmp_dir + ":" + ray_tmp_dir,
        "--cgroup-manager=cgroupfs",
        "--network=host",
        "--pid=host",
        "--ipc=host",
        # NOTE(zcin): Mounted volumes in rootless containers are
        # owned by the user `root`. The user on host (which will
        # usually be `ray` if this is being run in a ray docker
        # image) who started the container is mapped using user
        # namespaces to the user `root` in a rootless container. In
        # order for the Ray Python worker to access the mounted ray
        # tmp dir, we need to use keep-id mode which maps the user
        # as itself (instead of as `root`) into the container.
        # https://www.redhat.com/sysadmin/rootless-podman-user-namespace-modes
        "--userns=keep-id",
    ]

    # Environment variables to set in container
    env_vars = dict()

    # Propagate all host environment variables that have the prefix "RAY_"
    # This should include RAY_RAYLET_PID
    for env_var_name, env_var_value in os.environ.items():
        if env_var_name.startswith("RAY_"):
            env_vars[env_var_name] = env_var_value

    # Support for runtime_env['env_vars']
    env_vars.update(context.env_vars)

    # Set environment variables
    for env_var_name, env_var_value in env_vars.items():
        container_command.append("--env")
        # NOTE(review): the value is wrapped in literal single quotes; this
        # only works because the final command string is run through a shell
        # (bash -c in RuntimeEnvContext.exec_worker), which strips them.
        # TODO confirm this if the launch path ever stops using a shell.
        container_command.append(f"{env_var_name}='{env_var_value}'")

    # The RAY_JOB_ID environment variable is needed for the default worker.
    # It won't be set at the time setup() is called, but it will be set
    # when worker command is executed, so we use RAY_JOB_ID=$RAY_JOB_ID
    # for the container start command
    container_command.append("--env")
    container_command.append("RAY_JOB_ID=$RAY_JOB_ID")

    if run_options:
        container_command.extend(run_options)
    # TODO(chenk008): add resource limit
    container_command.append("--entrypoint")
    container_command.append("python")
    container_command.append(image_uri)

    # Example:
    # podman run -v /tmp/ray:/tmp/ray
    # --cgroup-manager=cgroupfs --network=host --pid=host --ipc=host
    # --userns=keep-id --env RAY_RAYLET_PID=23478 --env RAY_JOB_ID=$RAY_JOB_ID
    # --entrypoint python rayproject/ray:nightly-py39
    container_command_str = " ".join(container_command)
    logger.info(f"Starting worker in container with prefix {container_command_str}")

    # The "python executable" becomes the whole podman invocation; the
    # raylet-provided worker args are appended after it at exec time.
    context.py_executable = container_command_str
class ContainerPlugin(RuntimeEnvPlugin):
    """Starts worker in container."""

    name = "container"

    def __init__(self, ray_tmp_dir: str):
        # Host ray temp dir; bind-mounted into the container by
        # _modify_context_impl so the worker can reach raylet sockets.
        self._ray_tmp_dir = ray_tmp_dir

    async def create(
        self,
        uri: Optional[str],
        runtime_env: "RuntimeEnv",  # noqa: F821
        context: RuntimeEnvContext,
        logger: logging.Logger,
    ) -> float:
        """Resolve the in-image path of default_worker.py (pulls the image).

        No-op unless the runtime env declares a py_container image. Stores
        the detected worker path on `self.worker_path` for modify_context.
        """
        if not runtime_env.has_py_container() or not runtime_env.py_container_image():
            return

        self.worker_path = await _create_impl(runtime_env.py_container_image(), logger)

    def modify_context(
        self,
        uris: List[str],
        runtime_env: "RuntimeEnv",  # noqa: F821
        context: RuntimeEnvContext,
        logger: Optional[logging.Logger] = default_logger,
    ):
        """Rewrite `context` so the worker launches via `podman run`.

        Honors the deprecated `container.worker_path` override if present,
        otherwise uses the worker path detected by create().
        """
        if not runtime_env.has_py_container() or not runtime_env.py_container_image():
            return

        if runtime_env.py_container_worker_path():
            logger.warning(
                "You are using `container.worker_path`, but the path to "
                "`default_worker.py` is now automatically detected from the image. "
                "`container.worker_path` is deprecated and will be removed in future "
                "versions."
            )

        _modify_context_impl(
            runtime_env.py_container_image(),
            runtime_env.py_container_worker_path() or self.worker_path,
            runtime_env.py_container_run_options(),
            context,
            logger,
            self._ray_tmp_dir,
        )
    async def create(
        self,
        uri: str,
        runtime_env: "RuntimeEnv",  # noqa: F821
        context: RuntimeEnvContext,
        logger: Optional[logging.Logger] = default_logger,
    ) -> int:
        """Download the package for `uri` into the plugin's resources dir.

        Jar URIs go through _download_jars (downloaded without unpacking);
        other package types are downloaded and unpacked.

        Args:
            uri: The package URI to fetch; an empty/falsy uri is a no-op.

        Returns:
            Size in bytes of the resulting local directory (0 for no uri).

        Raises:
            RuntimeEnvSetupError: If the download fails.
        """
        if not uri:
            return 0
        if is_jar_uri(uri):
            module_dir = await self._download_jars(uri=uri, logger=logger)
        else:
            try:
                module_dir = await download_and_unpack_package(
                    uri, self._resources_dir, self._gcs_aio_client, logger=logger
                )
            except Exception as e:
                raise RuntimeEnvSetupError(
                    "Failed to download jar file: {}".format(e)
                ) from e

        return get_directory_size_bytes(module_dir)
def mpi_init():
    """Initialize the MPI cluster. When using MPI cluster, this must be called first.

    Rank 0 collects its accelerator-visibility env vars (as reported by the
    accelerator managers) and broadcasts them; every other rank copies them
    into its own environment so all ranks see the same devices. Safe to call
    more than once: subsequent calls are no-ops.
    """

    # Idempotence guard: a function attribute marks completed initialization.
    if hasattr(mpi_init, "inited"):
        assert mpi_init.inited is True
        return

    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    if rank == 0:
        from ray._private.accelerators import get_all_accelerator_managers

        # Env var names like CUDA_VISIBLE_DEVICES, one per accelerator type;
        # only the ones actually set on rank 0 are broadcast.
        device_vars = [
            m.get_visible_accelerator_ids_env_var()
            for m in get_all_accelerator_managers()
        ]
        visible_devices = {
            n: os.environ.get(n) for n in device_vars if os.environ.get(n)
        }
        comm.bcast(visible_devices)
        # NOTE(review): writes a debug artifact under /tmp that is never
        # cleaned up — confirm whether this file is still needed.
        with open(f"/tmp/{os.getpid()}.{rank}", "w") as f:
            f.write(str(visible_devices))
    else:
        # Matching bcast: receive rank 0's device visibility and adopt it.
        visible_devices = comm.bcast(None)
        os.environ.update(visible_devices)
    mpi_init.inited = True
+ + In the mpi worker with rank==0, it'll be the normal ray function or actor. + For the worker with rank > 0, it'll just run `worker_func`. + + ray.runtime_env.mpi_init must be called in the ray actors/tasks before any MPI + communication. + """ + + priority = 90 + name = "mpi" + + def modify_context( + self, + uris: List[str], # noqa: ARG002 + runtime_env: "RuntimeEnv", # noqa: F821 ARG002 + context: RuntimeEnvContext, + logger: Optional[logging.Logger] = default_logger, # noqa: ARG002 + ) -> None: + mpi_config = runtime_env.mpi() + if mpi_config is None: + return + try: + proc = subprocess.run( + ["mpirun", "--version"], capture_output=True, check=True + ) + except subprocess.CalledProcessError: + logger.exception( + "Failed to run mpi run. Please make sure mpi has been installed" + ) + # The worker will fail to run and exception will be thrown in runtime + # env agent. + raise + + logger.info(f"Running MPI plugin\n {proc.stdout.decode()}") + + # worker_entry should be a file either in the working dir + # or visible inside the cluster. + worker_entry = mpi_config.get("worker_entry") + + assert ( + worker_entry is not None + ), "`worker_entry` must be setup in the runtime env." 
+ + cmds = ( + ["mpirun"] + + mpi_config.get("args", []) + + [ + context.py_executable, + "-m", + "ray._private.runtime_env.mpi_runner", + worker_entry, + ] + ) + # Construct the start cmd + context.py_executable = " ".join(cmds) diff --git a/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/mpi_runner.py b/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/mpi_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..bc419d01b5c4d4affc655e923e40cdcb2cfa0485 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/mpi_runner.py @@ -0,0 +1,32 @@ +import sys +import argparse +import importlib +from mpi4py import MPI + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Setup MPI worker") + parser.add_argument("worker_entry") + parser.add_argument("main_entry") + + args, remaining_args = parser.parse_known_args() + + comm = MPI.COMM_WORLD + + rank = comm.Get_rank() + + if rank == 0: + entry_file = args.main_entry + + sys.argv[1:] = remaining_args + spec = importlib.util.spec_from_file_location("__main__", entry_file) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + else: + from ray.runtime_env import mpi_init + + mpi_init() + module, func = args.worker_entry.rsplit(".", 1) + m = importlib.import_module(module) + f = getattr(m, func) + f() diff --git a/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/nsight.py b/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/nsight.py new file mode 100644 index 0000000000000000000000000000000000000000..c5770e109478e3d389de7e996e1120b7caa86bae --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/nsight.py @@ -0,0 +1,149 @@ +import os +import sys +import logging +import asyncio +import subprocess +import copy +from pathlib import Path +from typing import Tuple, List, Dict, Optional + +from ray._private.runtime_env.context import RuntimeEnvContext +from 
def parse_nsight_config(nsight_config: Dict[str, str]) -> List[str]:
    """Render a dictionary of nsight options as an `nsys profile` command line.

    Option formatting follows the GNU argument-syntax convention
    (https://www.gnu.org/software/libc/manual/html_node/Argument-Syntax.html):
    multi-character names become ``--name=value`` in a single argv entry,
    while single-character names become ``-n value`` as two entries.

    Returns:
        The `nsys profile` command line, split into a list of strings.
    """
    cmd: List[str] = ["nsys", "profile"]
    for name, value in nsight_config.items():
        if len(name) == 1:
            cmd.extend((f"-{name}", value))
        else:
            cmd.append(f"--{name}={value}")
    return cmd
'""'] + process = await asyncio.create_subprocess_exec( + *nsight_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + stdout, stderr = await process.communicate() + error_msg = stderr.strip() if stderr.strip() != "" else stdout.strip() + + # cleanup test.nsys-rep file + clean_up_cmd = ["rm", f"{nsight_config_copy['o']}.nsys-rep"] + cleanup_process = await asyncio.create_subprocess_exec( + *clean_up_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + _, _ = await cleanup_process.communicate() + if process.returncode == 0: + return True, None + else: + return False, error_msg + except FileNotFoundError: + return False, ("nsight is not installed") + + async def create( + self, + uri: Optional[str], + runtime_env: "RuntimeEnv", # noqa: F821 + context: RuntimeEnvContext, + logger: logging.Logger = default_logger, + ) -> int: + nsight_config = runtime_env.nsight() + if not nsight_config: + return 0 + + if nsight_config and sys.platform != "linux": + raise RuntimeEnvSetupError( + "Nsight CLI is only available in Linux.\n" + "More information can be found in " + "https://docs.nvidia.com/nsight-compute/NsightComputeCli/index.html" + ) + + if isinstance(nsight_config, str): + if nsight_config == "default": + nsight_config = NSIGHT_DEFAULT_CONFIG + else: + raise RuntimeEnvSetupError( + f"Unsupported nsight config: {nsight_config}. 
" + "The supported config is 'default' or " + "Dictionary of nsight options" + ) + + is_valid_nsight_cmd, error_msg = await self._check_nsight_script(nsight_config) + if not is_valid_nsight_cmd: + logger.warning(error_msg) + raise RuntimeEnvSetupError( + "nsight profile failed to run with the following " + f"error message:\n {error_msg}" + ) + # add set output path to logs dir + nsight_config["o"] = str( + Path(self._nsight_dir) / nsight_config.get("o", NSIGHT_DEFAULT_CONFIG["o"]) + ) + + self.nsight_cmd = parse_nsight_config(nsight_config) + return 0 + + def modify_context( + self, + uris: List[str], + runtime_env: "RuntimeEnv", # noqa: F821 + context: RuntimeEnvContext, + logger: Optional[logging.Logger] = default_logger, + ): + logger.info("Running nsight profiler") + context.py_executable = " ".join(self.nsight_cmd) + " python" diff --git a/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/packaging.py b/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/packaging.py new file mode 100644 index 0000000000000000000000000000000000000000..87412e06abf13d3bd26fbf7fd9f543dfad8609e1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/packaging.py @@ -0,0 +1,952 @@ +import time +import asyncio +import hashlib +import logging +import os +import shutil +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Callable, List, Optional, Tuple +from urllib.parse import urlparse +from zipfile import ZipFile + +from filelock import FileLock +from ray.util.annotations import DeveloperAPI + +from ray._private.ray_constants import ( + RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_DEFAULT, + RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_ENV_VAR, + RAY_RUNTIME_ENV_IGNORE_GITIGNORE, +) +from ray._private.runtime_env.conda_utils import exec_cmd_stream_to_logger +from ray._private.runtime_env.protocol import Protocol +from ray._private.thirdparty.pathspec import PathSpec +from ray.experimental.internal_kv import ( + 
_internal_kv_exists, + _internal_kv_put, + _pin_runtime_env_uri, +) + +default_logger = logging.getLogger(__name__) + +# If an individual file is beyond this size, print a warning. +FILE_SIZE_WARNING = 10 * 1024 * 1024 # 10MiB +# The size is bounded by the max gRPC message size. +# Keep in sync with max_grpc_message_size in ray_config_def.h. +GCS_STORAGE_MAX_SIZE = int( + os.environ.get("RAY_max_grpc_message_size", 500 * 1024 * 1024) +) +RAY_PKG_PREFIX = "_ray_pkg_" + +RAY_RUNTIME_ENV_FAIL_UPLOAD_FOR_TESTING_ENV_VAR = ( + "RAY_RUNTIME_ENV_FAIL_UPLOAD_FOR_TESTING" +) +RAY_RUNTIME_ENV_FAIL_DOWNLOAD_FOR_TESTING_ENV_VAR = ( + "RAY_RUNTIME_ENV_FAIL_DOWNLOAD_FOR_TESTING" +) + +# The name of the hidden top-level directory that appears when files are +# zipped on MacOS. +MAC_OS_ZIP_HIDDEN_DIR_NAME = "__MACOSX" + + +def _mib_string(num_bytes: float) -> str: + size_mib = float(num_bytes / 1024**2) + return f"{size_mib:.2f}MiB" + + +class _AsyncFileLock: + """Asyncio version used to prevent blocking event loop.""" + + def __init__(self, lock_file: str): + self.file = FileLock(lock_file) + + async def __aenter__(self): + while True: + try: + self.file.acquire(timeout=0) + return + except TimeoutError: + await asyncio.sleep(0.1) + + async def __aexit__(self, exc_type, exc, tb): + self.file.release() + + +def _xor_bytes(left: bytes, right: bytes) -> bytes: + if left and right: + return bytes(a ^ b for (a, b) in zip(left, right)) + return left or right + + +def _dir_travel( + path: Path, + excludes: List[Callable], + handler: Callable, + logger: Optional[logging.Logger] = default_logger, +): + """Travels the path recursively, calling the handler on each subpath. + + Respects excludes, which will be called to check if this path is skipped. 
def _hash_file_content_or_directory_name(
    filepath: Path,
    relative_path: Path,
    logger: Optional[logging.Logger] = default_logger,
) -> bytes:
    """Hash a single file or directory for package-hash purposes.

    The path (taken relative to ``relative_path``) always contributes to the
    digest; regular files additionally fold in their content. A file that
    cannot be opened contributes only its name, with a debug log entry.
    """

    CHUNK_SIZE = 4096 * 1024  # read in 4 MiB chunks to bound memory use

    digest = hashlib.sha1()
    digest.update(str(filepath.relative_to(relative_path)).encode())
    if filepath.is_dir():
        # Directories have no content to hash; the name alone suffices.
        return digest.digest()
    try:
        stream = filepath.open("rb")
    except Exception as e:
        logger.debug(
            f"Skipping contents of file {filepath} when calculating package hash "
            f"because the file couldn't be opened: {e}"
        )
    else:
        with stream:
            while True:
                chunk = stream.read(CHUNK_SIZE)
                if not chunk:
                    break
                digest.update(chunk)

    return digest.digest()
+ """ + hash_val = b"0" * 8 + + def handler(path: Path): + file_hash = _hash_file_content_or_directory_name( + path, relative_path, logger=logger + ) + nonlocal hash_val + hash_val = _xor_bytes(hash_val, file_hash) + + excludes = [] if excludes is None else [excludes] + _dir_travel(root, excludes, handler, logger=logger) + return hash_val + + +def parse_uri(pkg_uri: str) -> Tuple[Protocol, str]: + """ + Parse package uri into protocol and package name based on its format. + Note that the output of this function is not for handling actual IO, it's + only for setting up local directory folders by using package name as path. + + >>> parse_uri("https://test.com/file.zip") + (, 'https_test_com_file.zip') + + >>> parse_uri("https://test.com/file.whl") + (, 'file.whl') + + """ + uri = urlparse(pkg_uri) + try: + protocol = Protocol(uri.scheme) + except ValueError as e: + raise ValueError( + f'Invalid protocol for runtime_env URI "{pkg_uri}". ' + f"Supported protocols: {Protocol._member_names_}. Original error: {e}" + ) + + if protocol in Protocol.remote_protocols(): + if uri.path.endswith(".whl"): + # Don't modify the .whl filename. See + # https://peps.python.org/pep-0427/#file-name-convention + # for more information. 
+ package_name = uri.path.split("/")[-1] + else: + package_name = f"{protocol.value}_{uri.netloc}{uri.path}" + + disallowed_chars = ["/", ":", "@", "+", " "] + for disallowed_char in disallowed_chars: + package_name = package_name.replace(disallowed_char, "_") + + # Remove all periods except the last, which is part of the + # file extension + package_name = package_name.replace(".", "_", package_name.count(".") - 1) + else: + package_name = uri.netloc + + return (protocol, package_name) + + +def is_zip_uri(uri: str) -> bool: + try: + protocol, path = parse_uri(uri) + except ValueError: + return False + + return Path(path).suffix == ".zip" + + +def is_whl_uri(uri: str) -> bool: + try: + _, path = parse_uri(uri) + except ValueError: + return False + + return Path(path).suffix == ".whl" + + +def is_jar_uri(uri: str) -> bool: + try: + _, path = parse_uri(uri) + except ValueError: + return False + + return Path(path).suffix == ".jar" + + +def _get_excludes(path: Path, excludes: List[str]) -> Callable: + path = path.absolute() + pathspec = PathSpec.from_lines("gitwildmatch", excludes) + + def match(p: Path): + path_str = str(p.absolute().relative_to(path)) + return pathspec.match_file(path_str) + + return match + + +def _get_gitignore(path: Path) -> Optional[Callable]: + """Returns a function that returns True if the path should be excluded. + + Returns None if there is no .gitignore file in the path, or if the + RAY_RUNTIME_ENV_IGNORE_GITIGNORE environment variable is set to 1. + + Args: + path: The path to the directory to check for a .gitignore file. + + Returns: + A function that returns True if the path should be excluded. 
+ """ + ignore_gitignore = os.environ.get(RAY_RUNTIME_ENV_IGNORE_GITIGNORE, "0") == "1" + if ignore_gitignore: + return None + + path = path.absolute() + ignore_file = path / ".gitignore" + if ignore_file.is_file(): + with ignore_file.open("r") as f: + pathspec = PathSpec.from_lines("gitwildmatch", f.readlines()) + + def match(p: Path): + path_str = str(p.absolute().relative_to(path)) + return pathspec.match_file(path_str) + + return match + else: + return None + + +def pin_runtime_env_uri(uri: str, *, expiration_s: Optional[int] = None) -> None: + """Pin a reference to a runtime_env URI in the GCS on a timeout. + + This is used to avoid premature eviction in edge conditions for job + reference counting. See https://github.com/ray-project/ray/pull/24719. + + Packages are uploaded to GCS in order to be downloaded by a runtime env plugin + (e.g. working_dir, py_modules) after the job starts. + + This function adds a temporary reference to the package in the GCS to prevent + it from being deleted before the job starts. (See #23423 for the bug where + this happened.) + + If this reference didn't have an expiration, then if the script exited + (e.g. via Ctrl-C) before the job started, the reference would never be + removed, so the package would never be deleted. + """ + + if expiration_s is None: + expiration_s = int( + os.environ.get( + RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_ENV_VAR, + RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_DEFAULT, + ) + ) + elif not isinstance(expiration_s, int): + raise ValueError(f"expiration_s must be an int, got {type(expiration_s)}.") + + if expiration_s < 0: + raise ValueError(f"expiration_s must be >= 0, got {expiration_s}.") + elif expiration_s > 0: + _pin_runtime_env_uri(uri, expiration_s=expiration_s) + + +def _store_package_in_gcs( + pkg_uri: str, + data: bytes, + logger: Optional[logging.Logger] = default_logger, +) -> int: + """Stores package data in the Global Control Store (GCS). + + Args: + pkg_uri: The GCS key to store the data in. 
        data: The serialized package's bytes to store in the GCS.
        logger (Optional[logging.Logger]): The logger used by this function.

    Return:
        int: Size of data

    Raises:
        RuntimeError: If the upload to the GCS fails.
        ValueError: If the data's size exceeds GCS_STORAGE_MAX_SIZE.
    """

    file_size = len(data)
    size_str = _mib_string(file_size)
    if len(data) >= GCS_STORAGE_MAX_SIZE:
        raise ValueError(
            f"Package size ({size_str}) exceeds the maximum size of "
            f"{_mib_string(GCS_STORAGE_MAX_SIZE)}. You can exclude large "
            "files using the 'excludes' option to the runtime_env or provide "
            "a remote URI of a zip file using protocols such as 's3://', "
            "'https://' and so on, refer to "
            "https://docs.ray.io/en/latest/ray-core/handling-dependencies.html#api-reference."  # noqa
        )

    logger.info(f"Pushing file package '{pkg_uri}' ({size_str}) to Ray cluster...")
    try:
        # Test hook: simulate an upload failure when the env var is set.
        if os.environ.get(RAY_RUNTIME_ENV_FAIL_UPLOAD_FOR_TESTING_ENV_VAR):
            raise RuntimeError(
                "Simulating failure to upload package for testing purposes."
            )
        _internal_kv_put(pkg_uri, data)
    except Exception as e:
        raise RuntimeError(
            "Failed to store package in the GCS.\n"
            f" - GCS URI: {pkg_uri}\n"
            f" - Package data ({size_str}): {data[:15]}...\n"
        ) from e
    logger.info(f"Successfully pushed file package '{pkg_uri}'.")
    return len(data)


def _get_local_path(base_directory: str, pkg_uri: str) -> str:
    # Map a package URI to its on-disk location under base_directory.
    _, pkg_name = parse_uri(pkg_uri)
    return os.path.join(base_directory, pkg_name)


def _zip_files(
    path_str: str,
    excludes: List[str],
    output_path: str,
    include_parent_dir: bool = False,
    logger: Optional[logging.Logger] = default_logger,
) -> None:
    """Zip the target file or directory and write it to the output_path.

    path_str: The file or directory to zip.
    excludes (List(str)): The directories or file to be excluded.
    output_path: The output path for the zip file.
    include_parent_dir: If true, includes the top-level directory as a
        directory inside the zip file.
    """
    pkg_file = Path(output_path).absolute()
    with ZipFile(pkg_file, "w", strict_timestamps=False) as zip_handler:
        # Put all files in the directory into the zip file.
        file_path = Path(path_str).absolute()
        dir_path = file_path
        if file_path.is_file():
            dir_path = file_path.parent

        def handler(path: Path):
            # Pack this path if it's an empty directory or it's a file.
            if path.is_dir() and next(path.iterdir(), None) is None or path.is_file():
                file_size = path.stat().st_size
                if file_size >= FILE_SIZE_WARNING:
                    logger.warning(
                        f"File {path} is very large "
                        f"({_mib_string(file_size)}). Consider adding this "
                        "file to the 'excludes' list to skip uploading it: "
                        "`ray.init(..., "
                        f"runtime_env={{'excludes': ['{path}']}})`"
                    )
                to_path = path.relative_to(dir_path)
                if include_parent_dir:
                    to_path = dir_path.name / to_path
                zip_handler.write(path, to_path)

        excludes = [_get_excludes(file_path, excludes)]
        _dir_travel(file_path, excludes, handler, logger=logger)


def package_exists(pkg_uri: str) -> bool:
    """Check whether the package with given URI exists or not.

    Args:
        pkg_uri: The uri of the package

    Return:
        True for package existing and False for not.
    """
    protocol, pkg_name = parse_uri(pkg_uri)
    if protocol == Protocol.GCS:
        return _internal_kv_exists(pkg_uri)
    else:
        raise NotImplementedError(f"Protocol {protocol} is not supported")


def get_uri_for_package(package: Path) -> str:
    """Get a content-addressable URI from a package's contents."""

    if package.suffix == ".whl":
        # Wheel file names include the Python package name, version
        # and tags, so it is already effectively content-addressed.
+ return "{protocol}://{whl_filename}".format( + protocol=Protocol.GCS.value, whl_filename=package.name + ) + else: + hash_val = hashlib.sha1(package.read_bytes()).hexdigest() + return "{protocol}://{pkg_name}.zip".format( + protocol=Protocol.GCS.value, pkg_name=RAY_PKG_PREFIX + hash_val + ) + + +def get_uri_for_file(file: str) -> str: + """Get a content-addressable URI from a file's content. + + This function generates the name of the package by the file. + The final package name is _ray_pkg_.zip of this package, + where HASH_VAL is the hash value of the file. + For example: _ray_pkg_029f88d5ecc55e1e4d64fc6e388fd103.zip + + Examples: + + >>> get_uri_for_file("/my_file.py") # doctest: +SKIP + _ray_pkg_af2734982a741.zip + + Args: + file: The file. + + Returns: + URI (str) + + Raises: + ValueError if the file doesn't exist. + """ + filepath = Path(file).absolute() + if not filepath.exists() or not filepath.is_file(): + raise ValueError(f"File {filepath} must be an existing file") + + hash_val = _hash_file(filepath, filepath.parent) + + return "{protocol}://{pkg_name}.zip".format( + protocol=Protocol.GCS.value, pkg_name=RAY_PKG_PREFIX + hash_val.hex() + ) + + +def get_uri_for_directory(directory: str, excludes: Optional[List[str]] = None) -> str: + """Get a content-addressable URI from a directory's contents. + + This function generates the name of the package by the directory. + It'll go through all the files in the directory and hash the contents + of the files to get the hash value of the package. + The final package name is _ray_pkg_.zip of this package. + For example: _ray_pkg_029f88d5ecc55e1e4d64fc6e388fd103.zip + + Examples: + + >>> get_uri_for_directory("/my_directory") # doctest: +SKIP + _ray_pkg_af2734982a741.zip + + Args: + directory: The directory. + excludes (list[str]): The dir or files that should be excluded. + + Returns: + URI (str) + + Raises: + ValueError if the directory doesn't exist. 
+ """ + if excludes is None: + excludes = [] + + directory = Path(directory).absolute() + if not directory.exists() or not directory.is_dir(): + raise ValueError(f"directory {directory} must be an existing directory") + + hash_val = _hash_directory(directory, directory, _get_excludes(directory, excludes)) + + return "{protocol}://{pkg_name}.zip".format( + protocol=Protocol.GCS.value, pkg_name=RAY_PKG_PREFIX + hash_val.hex() + ) + + +def upload_package_to_gcs(pkg_uri: str, pkg_bytes: bytes) -> None: + """Upload a local package to GCS. + + Args: + pkg_uri: The URI of the package, e.g. gcs://my_package.zip + pkg_bytes: The data to be uploaded. + + Raises: + RuntimeError: If the upload fails. + ValueError: If the pkg_uri is a remote path or if the data's + size exceeds GCS_STORAGE_MAX_SIZE. + NotImplementedError: If the protocol of the URI is not supported. + + """ + protocol, pkg_name = parse_uri(pkg_uri) + if protocol == Protocol.GCS: + _store_package_in_gcs(pkg_uri, pkg_bytes) + elif protocol in Protocol.remote_protocols(): + raise ValueError( + "upload_package_to_gcs should not be called with a remote path." 
+ ) + else: + raise NotImplementedError(f"Protocol {protocol} is not supported") + + +def create_package( + module_path: str, + target_path: Path, + include_parent_dir: bool = False, + excludes: Optional[List[str]] = None, + logger: Optional[logging.Logger] = default_logger, +): + if excludes is None: + excludes = [] + + if logger is None: + logger = default_logger + + if not target_path.exists(): + logger.info(f"Creating a file package for local module '{module_path}'.") + _zip_files( + module_path, + excludes, + str(target_path), + include_parent_dir=include_parent_dir, + logger=logger, + ) + + +def upload_package_if_needed( + pkg_uri: str, + base_directory: str, + module_path: str, + include_parent_dir: bool = False, + excludes: Optional[List[str]] = None, + logger: Optional[logging.Logger] = default_logger, +) -> bool: + """Upload the contents of the directory under the given URI. + + This will first create a temporary zip file under the passed + base_directory. + + If the package already exists in storage, this is a no-op. + + Args: + pkg_uri: URI of the package to upload. + base_directory: Directory where package files are stored. + module_path: The module to be uploaded, either a single .py file or a directory. + include_parent_dir: If true, includes the top-level directory as a + directory inside the zip file. + excludes: List specifying files to exclude. + + Raises: + RuntimeError: If the upload fails. + ValueError: If the pkg_uri is a remote path or if the data's + size exceeds GCS_STORAGE_MAX_SIZE. + NotImplementedError: If the protocol of the URI is not supported. + """ + if excludes is None: + excludes = [] + + if logger is None: + logger = default_logger + + pin_runtime_env_uri(pkg_uri) + + if package_exists(pkg_uri): + return False + + package_file = Path(_get_local_path(base_directory, pkg_uri)) + # Make the temporary zip file name unique so that it doesn't conflict with + # concurrent upload_package_if_needed calls with the same pkg_uri. 
+ # See https://github.com/ray-project/ray/issues/47471. + package_file = package_file.with_name( + f"{time.time_ns()}_{os.getpid()}_{package_file.name}" + ) + + create_package( + module_path, + package_file, + include_parent_dir=include_parent_dir, + excludes=excludes, + ) + package_file_bytes = package_file.read_bytes() + # Remove the local file to avoid accumulating temporary zip files. + package_file.unlink() + + upload_package_to_gcs(pkg_uri, package_file_bytes) + + return True + + +def get_local_dir_from_uri(uri: str, base_directory: str) -> Path: + """Return the local directory corresponding to this URI.""" + pkg_file = Path(_get_local_path(base_directory, uri)) + local_dir = pkg_file.with_suffix("") + return local_dir + + +@DeveloperAPI +async def download_and_unpack_package( + pkg_uri: str, + base_directory: str, + gcs_aio_client: Optional["GcsAioClient"] = None, # noqa: F821 + logger: Optional[logging.Logger] = default_logger, + overwrite: bool = False, +) -> str: + """Download the package corresponding to this URI and unpack it if zipped. + + Will be written to a file or directory named {base_directory}/{uri}. + Returns the path to this file or directory. + + Args: + pkg_uri: URI of the package to download. + base_directory: Directory to use as the parent directory of the target + directory for the unpacked files. + gcs_aio_client: Client to use for downloading from the GCS. + logger: The logger to use. + overwrite: If True, overwrite the existing package. + + Returns: + Path to the local directory containing the unpacked package files. + + Raises: + IOError: If the download fails. + ImportError: If smart_open is not installed and a remote URI is used. + NotImplementedError: If the protocol of the URI is not supported. + ValueError: If the GCS client is not provided when downloading from GCS, + or if package URI is invalid. 
+ + """ + pkg_file = Path(_get_local_path(base_directory, pkg_uri)) + if pkg_file.suffix == "": + raise ValueError( + f"Invalid package URI: {pkg_uri}." + "URI must have a file extension and the URI must be valid." + ) + + async with _AsyncFileLock(str(pkg_file) + ".lock"): + if logger is None: + logger = default_logger + + logger.debug(f"Fetching package for URI: {pkg_uri}") + + local_dir = get_local_dir_from_uri(pkg_uri, base_directory) + assert local_dir != pkg_file, "Invalid pkg_file!" + + download_package: bool = True + if local_dir.exists() and not overwrite: + download_package = False + assert local_dir.is_dir(), f"{local_dir} is not a directory" + elif local_dir.exists(): + logger.info(f"Removing {local_dir} with pkg_file {pkg_file}") + shutil.rmtree(local_dir) + + if download_package: + protocol, _ = parse_uri(pkg_uri) + logger.info( + f"Downloading package from {pkg_uri} to {pkg_file} " + f"with protocol {protocol}" + ) + if protocol == Protocol.GCS: + if gcs_aio_client is None: + raise ValueError( + "GCS client must be provided to download from GCS." + ) + + # Download package from the GCS. + code = await gcs_aio_client.internal_kv_get( + pkg_uri.encode(), namespace=None, timeout=None + ) + if os.environ.get(RAY_RUNTIME_ENV_FAIL_DOWNLOAD_FOR_TESTING_ENV_VAR): + code = None + if code is None: + raise IOError( + f"Failed to download runtime_env file package {pkg_uri} " + "from the GCS to the Ray worker node. The package may " + "have prematurely been deleted from the GCS due to a " + "long upload time or a problem with Ray. Try setting the " + "environment variable " + f"{RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_ENV_VAR} " + " to a value larger than the upload time in seconds " + "(the default is " + f"{RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_DEFAULT}). " + "If this fails, try re-running " + "after making any change to a file in the file package." 
+ ) + code = code or b"" + pkg_file.write_bytes(code) + + if is_zip_uri(pkg_uri): + unzip_package( + package_path=pkg_file, + target_dir=local_dir, + remove_top_level_directory=False, + unlink_zip=True, + logger=logger, + ) + else: + return str(pkg_file) + elif protocol in Protocol.remote_protocols(): + protocol.download_remote_uri(source_uri=pkg_uri, dest_file=pkg_file) + + if pkg_file.suffix in [".zip", ".jar"]: + unzip_package( + package_path=pkg_file, + target_dir=local_dir, + remove_top_level_directory=True, + unlink_zip=True, + logger=logger, + ) + elif pkg_file.suffix == ".whl": + return str(pkg_file) + else: + raise NotImplementedError( + f"Package format {pkg_file.suffix} is ", + "not supported for remote protocols", + ) + else: + raise NotImplementedError(f"Protocol {protocol} is not supported") + + return str(local_dir) + + +def get_top_level_dir_from_compressed_package(package_path: str): + """ + If compressed package at package_path contains a single top-level + directory, returns the name of the top-level directory. Otherwise, + returns None. + + Ignores a second top-level directory if it is named __MACOSX. 
+ """ + + package_zip = ZipFile(package_path, "r") + top_level_directory = None + + def is_top_level_file(file_name): + return "/" not in file_name + + def base_dir_name(file_name): + return file_name.split("/")[0] + + for file_name in package_zip.namelist(): + if top_level_directory is None: + # Cache the top_level_directory name when checking + # the first file in the zipped package + if is_top_level_file(file_name): + return None + else: + # Top-level directory, or non-top-level file or directory + dir_name = base_dir_name(file_name) + if dir_name == MAC_OS_ZIP_HIDDEN_DIR_NAME: + continue + top_level_directory = dir_name + else: + # Confirm that all other files + # belong to the same top_level_directory + if is_top_level_file(file_name) or base_dir_name(file_name) not in [ + top_level_directory, + MAC_OS_ZIP_HIDDEN_DIR_NAME, + ]: + return None + + return top_level_directory + + +def remove_dir_from_filepaths(base_dir: str, rdir: str): + """ + base_dir: String path of the directory containing rdir + rdir: String path of directory relative to base_dir whose contents should + be moved to its base_dir, its parent directory + + Removes rdir from the filepaths of all files and directories inside it. + In other words, moves all the files inside rdir to the directory that + contains rdir. Assumes base_dir's contents and rdir's contents have no + name conflicts. 
+ """ + + # Move rdir to a temporary directory, so its contents can be moved to + # base_dir without any name conflicts + with TemporaryDirectory() as tmp_dir: + # shutil.move() is used instead of os.rename() in case rdir and tmp_dir + # are located on separate file systems + shutil.move(os.path.join(base_dir, rdir), os.path.join(tmp_dir, rdir)) + + # Shift children out of rdir and into base_dir + rdir_children = os.listdir(os.path.join(tmp_dir, rdir)) + for child in rdir_children: + shutil.move( + os.path.join(tmp_dir, rdir, child), os.path.join(base_dir, child) + ) + + +def unzip_package( + package_path: str, + target_dir: str, + remove_top_level_directory: bool, + unlink_zip: bool, + logger: Optional[logging.Logger] = default_logger, +) -> None: + """ + Unzip the compressed package contained at package_path to target_dir. + + If remove_top_level_directory is True and the top level consists of a + a single directory (or possibly also a second hidden directory named + __MACOSX at the top level arising from macOS's zip command), the function + will automatically remove the top-level directory and store the contents + directly in target_dir. + + Otherwise, if remove_top_level_directory is False or if the top level + consists of multiple files or directories (not counting __MACOS), + the zip contents will be stored in target_dir. + + Args: + package_path: String path of the compressed package to unzip. + target_dir: String path of the directory to store the unzipped contents. + remove_top_level_directory: Whether to remove the top-level directory + from the zip contents. + unlink_zip: Whether to unlink the zip file stored at package_path. + logger: Optional logger to use for logging. 
+ + """ + try: + os.mkdir(target_dir) + except FileExistsError: + logger.info(f"Directory at {target_dir} already exists") + + logger.debug(f"Unpacking {package_path} to {target_dir}") + + with ZipFile(str(package_path), "r") as zip_ref: + zip_ref.extractall(target_dir) + if remove_top_level_directory: + top_level_directory = get_top_level_dir_from_compressed_package(package_path) + if top_level_directory is not None: + # Remove __MACOSX directory if it exists + macos_dir = os.path.join(target_dir, MAC_OS_ZIP_HIDDEN_DIR_NAME) + if os.path.isdir(macos_dir): + shutil.rmtree(macos_dir) + + remove_dir_from_filepaths(target_dir, top_level_directory) + + if unlink_zip: + Path(package_path).unlink() + + +def delete_package(pkg_uri: str, base_directory: str) -> Tuple[bool, int]: + """Deletes a specific URI from the local filesystem. + + Args: + pkg_uri: URI to delete. + + Returns: + bool: True if the URI was successfully deleted, else False. + """ + + deleted = False + path = Path(_get_local_path(base_directory, pkg_uri)) + with FileLock(str(path) + ".lock"): + path = path.with_suffix("") + if path.exists(): + if path.is_dir() and not path.is_symlink(): + shutil.rmtree(str(path)) + else: + path.unlink() + deleted = True + + return deleted + + +async def install_wheel_package( + wheel_uri: str, + target_dir: str, + logger: Optional[logging.Logger] = default_logger, +) -> None: + """Install packages in the wheel URI, and then delete the local wheel file.""" + + pip_install_cmd = [ + "pip", + "install", + wheel_uri, + f"--target={target_dir}", + ] + + logger.info("Running py_modules wheel install command: %s", str(pip_install_cmd)) + try: + # TODO(architkulkarni): Use `await check_output_cmd` or similar. 
+ exit_code, output = exec_cmd_stream_to_logger(pip_install_cmd, logger) + finally: + if Path(wheel_uri).exists(): + Path(wheel_uri).unlink() + + if exit_code != 0: + if Path(target_dir).exists(): + Path(target_dir).unlink() + raise RuntimeError( + f"Failed to install py_modules wheel {wheel_uri}" + f"to {target_dir}:\n{output}" + ) diff --git a/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/pip.py b/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/pip.py new file mode 100644 index 0000000000000000000000000000000000000000..e3559721239aa159ea6820996d494fe68c8d6b15 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/pip.py @@ -0,0 +1,338 @@ +import asyncio +import hashlib +import json +import logging +import os +import shutil +import sys +from typing import Dict, List, Optional +from asyncio import create_task, get_running_loop + +from ray._private.runtime_env import virtualenv_utils +from ray._private.runtime_env import dependency_utils +from ray._private.runtime_env.packaging import Protocol, parse_uri +from ray._private.runtime_env.plugin import RuntimeEnvPlugin +from ray._private.runtime_env.utils import check_output_cmd +from ray._private.utils import get_directory_size_bytes, try_to_create_directory + +default_logger = logging.getLogger(__name__) + + +def _get_pip_hash(pip_dict: Dict) -> str: + serialized_pip_spec = json.dumps(pip_dict, sort_keys=True) + hash_val = hashlib.sha1(serialized_pip_spec.encode("utf-8")).hexdigest() + return hash_val + + +def get_uri(runtime_env: Dict) -> Optional[str]: + """Return `"pip://"`, or None if no GC required.""" + pip = runtime_env.get("pip") + if pip is not None: + if isinstance(pip, dict): + uri = "pip://" + _get_pip_hash(pip_dict=pip) + elif isinstance(pip, list): + uri = "pip://" + _get_pip_hash(pip_dict=dict(packages=pip)) + else: + raise TypeError( + "pip field received by RuntimeEnvAgent must be " + f"list or dict, not {type(pip).__name__}." 
+ ) + else: + uri = None + return uri + + +class PipProcessor: + def __init__( + self, + target_dir: str, + runtime_env: "RuntimeEnv", # noqa: F821 + logger: Optional[logging.Logger] = default_logger, + ): + try: + import virtualenv # noqa: F401 ensure virtualenv exists. + except ImportError: + raise RuntimeError( + f"Please install virtualenv " + f"`{sys.executable} -m pip install virtualenv`" + f"to enable pip runtime env." + ) + logger.debug("Setting up pip for runtime_env: %s", runtime_env) + self._target_dir = target_dir + self._runtime_env = runtime_env + self._logger = logger + + self._pip_config = self._runtime_env.pip_config() + self._pip_env = os.environ.copy() + self._pip_env.update(self._runtime_env.env_vars()) + + @classmethod + async def _ensure_pip_version( + cls, + path: str, + pip_version: Optional[str], + cwd: str, + pip_env: Dict, + logger: logging.Logger, + ): + """Run the pip command to reinstall pip to the specified version.""" + if not pip_version: + return + + python = virtualenv_utils.get_virtualenv_python(path) + # Ensure pip version. + pip_reinstall_cmd = [ + python, + "-m", + "pip", + "install", + "--disable-pip-version-check", + f"pip{pip_version}", + ] + logger.info("Installing pip with version %s", pip_version) + + await check_output_cmd(pip_reinstall_cmd, logger=logger, cwd=cwd, env=pip_env) + + async def _pip_check( + self, + path: str, + pip_check: bool, + cwd: str, + pip_env: Dict, + logger: logging.Logger, + ): + """Run the pip check command to check python dependency conflicts. + If exists conflicts, the exit code of pip check command will be non-zero. 
+ """ + if not pip_check: + logger.info("Skip pip check.") + return + python = virtualenv_utils.get_virtualenv_python(path) + + await check_output_cmd( + [python, "-m", "pip", "check", "--disable-pip-version-check"], + logger=logger, + cwd=cwd, + env=pip_env, + ) + + logger.info("Pip check on %s successfully.", path) + + @classmethod + async def _install_pip_packages( + cls, + path: str, + pip_packages: List[str], + cwd: str, + pip_env: Dict, + logger: logging.Logger, + ): + virtualenv_path = virtualenv_utils.get_virtualenv_path(path) + python = virtualenv_utils.get_virtualenv_python(path) + # TODO(fyrestone): Support -i, --no-deps, --no-cache-dir, ... + pip_requirements_file = dependency_utils.get_requirements_file( + path, pip_packages + ) + + # Avoid blocking the event loop. + loop = get_running_loop() + await loop.run_in_executor( + None, + dependency_utils.gen_requirements_txt, + pip_requirements_file, + pip_packages, + ) + + # pip options + # + # --disable-pip-version-check + # Don't periodically check PyPI to determine whether a new version + # of pip is available for download. + # + # --no-cache-dir + # Disable the cache, the pip runtime env is a one-time installation, + # and we don't need to handle the pip cache broken. + pip_install_cmd = [ + python, + "-m", + "pip", + "install", + "--disable-pip-version-check", + "--no-cache-dir", + "-r", + pip_requirements_file, + ] + logger.info("Installing python requirements to %s", virtualenv_path) + + await check_output_cmd(pip_install_cmd, logger=logger, cwd=cwd, env=pip_env) + + async def _run(self): + path = self._target_dir + logger = self._logger + pip_packages = self._pip_config["packages"] + # We create an empty directory for exec cmd so that the cmd will + # run more stable. e.g. if cwd has ray, then checking ray will + # look up ray in cwd instead of site packages. 
        exec_cwd = os.path.join(path, "exec_cwd")
        os.makedirs(exec_cwd, exist_ok=True)
        try:
            await virtualenv_utils.create_or_get_virtualenv(path, exec_cwd, logger)
            python = virtualenv_utils.get_virtualenv_python(path)
            # check_ray guards against the install clobbering the Ray
            # installation — presumably verified on context exit; TODO confirm.
            async with dependency_utils.check_ray(python, exec_cwd, logger):
                # Ensure pip version.
                await self._ensure_pip_version(
                    path,
                    self._pip_config.get("pip_version", None),
                    exec_cwd,
                    self._pip_env,
                    logger,
                )
                # Install pip packages.
                await self._install_pip_packages(
                    path,
                    pip_packages,
                    exec_cwd,
                    self._pip_env,
                    logger,
                )
                # Check python environment for conflicts.
                await self._pip_check(
                    path,
                    self._pip_config.get("pip_check", False),
                    exec_cwd,
                    self._pip_env,
                    logger,
                )
        except Exception:
            # Remove the partially-built virtualenv so a retry starts clean.
            logger.info("Delete incomplete virtualenv: %s", path)
            shutil.rmtree(path, ignore_errors=True)
            logger.exception("Failed to install pip packages.")
            raise

    def __await__(self):
        # Makes `await PipProcessor(...)` run the full installation.
        return self._run().__await__()


class PipPlugin(RuntimeEnvPlugin):
    name = "pip"

    def __init__(self, resources_dir: str):
        self._pip_resources_dir = os.path.join(resources_dir, "pip")
        # Maps a pip-spec hash to its in-flight creation task (if any).
        self._creating_task = {}
        # Maps a URI to a lock that is used to prevent multiple concurrent
        # installs of the same virtualenv, see #24513
        self._create_locks: Dict[str, asyncio.Lock] = {}
        # Key: created hashes. Value: size of the pip dir.
        self._created_hash_bytes: Dict[str, int] = {}
        try_to_create_directory(self._pip_resources_dir)

    def _get_path_from_hash(self, hash_val: str) -> str:
        """Generate a path from the hash of a pip spec.

        Example output:
            /tmp/ray/session_2021-11-03_16-33-59_356303_41018/runtime_resources
                /pip/ray-9a7972c3a75f55e976e620484f58410c920db091
        """
        return os.path.join(self._pip_resources_dir, hash_val)

    def get_uris(self, runtime_env: "RuntimeEnv") -> List[str]:  # noqa: F821
        """Return the pip URI from the RuntimeEnv if it exists, else return []."""
        pip_uri = runtime_env.pip_uri()
        if pip_uri:
            return [pip_uri]
        return []

    def delete_uri(
        self, uri: str, logger: Optional[logging.Logger] = default_logger
    ) -> int:
        """Delete URI and return the number of bytes deleted."""
        logger.info("Got request to delete pip URI %s", uri)
        protocol, hash_val = parse_uri(uri)
        if protocol != Protocol.PIP:
            raise ValueError(
                "PipPlugin can only delete URIs with protocol "
                f"pip. Received protocol {protocol}, URI {uri}"
            )

        # Cancel running create task.
        task = self._creating_task.pop(hash_val, None)
        if task is not None:
            task.cancel()

        # NOTE(review): these bare `del`s raise KeyError if the env was never
        # fully created / tracked — presumably deletion is only requested for
        # tracked URIs; confirm against the caller.
        del self._created_hash_bytes[hash_val]

        pip_env_path = self._get_path_from_hash(hash_val)
        local_dir_size = get_directory_size_bytes(pip_env_path)
        del self._create_locks[uri]
        try:
            shutil.rmtree(pip_env_path)
        except OSError as e:
            logger.warning(f"Error when deleting pip env {pip_env_path}: {str(e)}")
            return 0

        return local_dir_size

    async def create(
        self,
        uri: str,
        runtime_env: "RuntimeEnv",  # noqa: F821
        context: "RuntimeEnvContext",  # noqa: F821
        logger: Optional[logging.Logger] = default_logger,
    ) -> int:
        if not runtime_env.has_pip():
            return 0

        protocol, hash_val = parse_uri(uri)
        target_dir = self._get_path_from_hash(hash_val)

        async def _create_for_hash():
            # Run the full install, then measure the resulting directory
            # size off the event loop.
            await PipProcessor(
                target_dir,
                runtime_env,
                logger,
            )

            loop = get_running_loop()
            return await loop.run_in_executor(
                None, get_directory_size_bytes, target_dir
            )

        if uri not in self._create_locks:
            # async lock to prevent the same virtualenv being concurrently installed
            self._create_locks[uri] = asyncio.Lock()

        async with self._create_locks[uri]:
            if hash_val in self._created_hash_bytes:
                # Already installed by an earlier holder of the lock.
                return self._created_hash_bytes[hash_val]
            self._creating_task[hash_val] = task = create_task(_create_for_hash())
            # Drop the tracking entry once the task finishes (success or not).
            task.add_done_callback(lambda _: self._creating_task.pop(hash_val, None))
            pip_dir_bytes = await task
            self._created_hash_bytes[hash_val] = pip_dir_bytes
            return pip_dir_bytes

    def modify_context(
        self,
        uris: List[str],
        runtime_env: "RuntimeEnv",  # noqa: F821
        context: "RuntimeEnvContext",  # noqa: F821
        logger: logging.Logger = default_logger,
    ):
        if not runtime_env.has_pip():
            return
        # PipPlugin only uses a single URI.
        uri = uris[0]
        # Update py_executable.
        protocol, hash_val = parse_uri(uri)
        target_dir = self._get_path_from_hash(hash_val)
        virtualenv_python = virtualenv_utils.get_virtualenv_python(target_dir)

        if not os.path.exists(virtualenv_python):
            raise ValueError(
                f"Local directory {target_dir} for URI {uri} does "
                "not exist on the cluster. Something may have gone wrong while "
                "installing the runtime_env `pip` packages."
+ ) + context.py_executable = virtualenv_python + context.command_prefix += virtualenv_utils.get_virtualenv_activate_command( + target_dir + ) diff --git a/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/plugin.py b/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..a1e03a507b596429ba56d38657c1c3339cb002b2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/plugin.py @@ -0,0 +1,264 @@ +import logging +import os +import json +from abc import ABC +from typing import List, Dict, Optional, Any, Type + +from ray._private.runtime_env.context import RuntimeEnvContext +from ray._private.runtime_env.uri_cache import URICache +from ray._private.runtime_env.constants import ( + RAY_RUNTIME_ENV_PLUGINS_ENV_VAR, + RAY_RUNTIME_ENV_PLUGIN_DEFAULT_PRIORITY, + RAY_RUNTIME_ENV_CLASS_FIELD_NAME, + RAY_RUNTIME_ENV_PRIORITY_FIELD_NAME, + RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY, + RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY, +) +from ray.util.annotations import DeveloperAPI +from ray._private.utils import import_attr + +default_logger = logging.getLogger(__name__) + + +@DeveloperAPI +class RuntimeEnvPlugin(ABC): + """Abstract base class for runtime environment plugins.""" + + name: str = None + priority: int = RAY_RUNTIME_ENV_PLUGIN_DEFAULT_PRIORITY + + @staticmethod + def validate(runtime_env_dict: dict) -> None: + """Validate user entry for this plugin. + + The method is invoked upon installation of runtime env. + + Args: + runtime_env_dict: the user-supplied runtime environment dict. + + Raises: + ValueError: if the validation fails. + """ + pass + + def get_uris(self, runtime_env: "RuntimeEnv") -> List[str]: # noqa: F821 + return [] + + async def create( + self, + uri: Optional[str], + runtime_env, + context: RuntimeEnvContext, + logger: logging.Logger, + ) -> float: + """Create and install the runtime environment. 
+ + Gets called in the runtime env agent at install time. The URI can be + used as a caching mechanism. + + Args: + uri: A URI uniquely describing this resource. + runtime_env: The RuntimeEnv object. + context: auxiliary information supplied by Ray. + logger: A logger to log messages during the context modification. + + Returns: + the disk space taken up by this plugin installation for this + environment. e.g. for working_dir, this downloads the files to the + local node. + """ + return 0 + + def modify_context( + self, + uris: List[str], + runtime_env: "RuntimeEnv", # noqa: F821 + context: RuntimeEnvContext, + logger: logging.Logger, + ) -> None: + """Modify context to change worker startup behavior. + + For example, you can use this to preprend "cd " command to worker + startup, or add new environment variables. + + Args: + uris: The URIs used by this resource. + runtime_env: The RuntimeEnv object. + context: Auxiliary information supplied by Ray. + logger: A logger to log messages during the context modification. + """ + return + + def delete_uri(self, uri: str, logger: logging.Logger) -> float: + """Delete the the runtime environment given uri. + + Args: + uri: a URI uniquely describing this resource. + + Returns: + the amount of space reclaimed by the deletion. 
class PluginSetupContext:
    """Per-plugin bundle used during runtime env setup.

    Groups a plugin's name, its instantiated plugin object, its setup
    priority, and the URI cache tracking its installed resources.
    """

    def __init__(
        self,
        name: str,
        class_instance: RuntimeEnvPlugin,
        priority: int,
        uri_cache: URICache,
    ):
        # The plugin's declared name, e.g. "pip".
        self.name = name
        # The plugin instance itself.
        self.class_instance = class_instance
        # Smaller number == set up earlier.
        self.priority = priority
        # Cache of this plugin's installed URIs.
        self.uri_cache = uri_cache
+ ) + + def load_plugins(self, plugin_configs: List[Dict]) -> None: + """Load runtime env plugins and create URI caches for them.""" + for plugin_config in plugin_configs: + if ( + not isinstance(plugin_config, dict) + or RAY_RUNTIME_ENV_CLASS_FIELD_NAME not in plugin_config + ): + raise RuntimeError( + f"Invalid runtime env plugin config {plugin_config}, " + "it should be a object which contains the " + f"{RAY_RUNTIME_ENV_CLASS_FIELD_NAME} field." + ) + plugin_class = import_attr(plugin_config[RAY_RUNTIME_ENV_CLASS_FIELD_NAME]) + self.validate_plugin_class(plugin_class) + + # The priority should be an integer between 0 and 100. + # The default priority is 10. A smaller number indicates a + # higher priority and the plugin will be set up first. + if RAY_RUNTIME_ENV_PRIORITY_FIELD_NAME in plugin_config: + priority = plugin_config[RAY_RUNTIME_ENV_PRIORITY_FIELD_NAME] + else: + priority = plugin_class.priority + self.validate_priority(priority) + + class_instance = plugin_class() + self.plugins[plugin_class.name] = PluginSetupContext( + plugin_class.name, + class_instance, + priority, + self.create_uri_cache_for_plugin(class_instance), + ) + + def add_plugin(self, plugin: RuntimeEnvPlugin) -> None: + """Add a plugin to the manager and create a URI cache for it. + + Args: + plugin: The class instance of the plugin. + """ + plugin_class = type(plugin) + self.validate_plugin_class(plugin_class) + self.validate_priority(plugin_class.priority) + self.plugins[plugin_class.name] = PluginSetupContext( + plugin_class.name, + plugin, + plugin_class.priority, + self.create_uri_cache_for_plugin(plugin), + ) + + def create_uri_cache_for_plugin(self, plugin: RuntimeEnvPlugin) -> URICache: + """Create a URI cache for a plugin. + + Args: + plugin_name: The name of the plugin. + + Returns: + The created URI cache for the plugin. + """ + # Set the max size for the cache. Defaults to 10 GB. 
+ cache_size_env_var = f"RAY_RUNTIME_ENV_{plugin.name}_CACHE_SIZE_GB".upper() + cache_size_bytes = int( + (1024**3) * float(os.environ.get(cache_size_env_var, 10)) + ) + return URICache(plugin.delete_uri, cache_size_bytes) + + def sorted_plugin_setup_contexts(self) -> List[PluginSetupContext]: + """Get the sorted plugin setup contexts, sorted by increasing priority. + + Returns: + The sorted plugin setup contexts. + """ + return sorted(self.plugins.values(), key=lambda x: x.priority) + + +async def create_for_plugin_if_needed( + runtime_env: "RuntimeEnv", # noqa: F821 + plugin: RuntimeEnvPlugin, + uri_cache: URICache, + context: RuntimeEnvContext, + logger: logging.Logger = default_logger, +): + """Set up the environment using the plugin if not already set up and cached.""" + if plugin.name not in runtime_env or runtime_env[plugin.name] is None: + return + + plugin.validate(runtime_env) + + uris = plugin.get_uris(runtime_env) + + if not uris: + logger.debug( + f"No URIs for runtime env plugin {plugin.name}; " + "create always without checking the cache." + ) + await plugin.create(None, runtime_env, context, logger=logger) + + for uri in uris: + if uri not in uri_cache: + logger.debug(f"Cache miss for URI {uri}.") + size_bytes = await plugin.create(uri, runtime_env, context, logger=logger) + uri_cache.add(uri, size_bytes, logger=logger) + else: + logger.info( + f"Runtime env {plugin.name} {uri} is already installed " + "and will be reused. Search " + "all runtime_env_setup-*.log to find the corresponding setup log." 
class RuntimeEnvPluginSchemaManager:
    """Lazily loads JSON schemas used to validate runtime env plugin configs."""

    # Built-in schemas shipped with Ray.
    default_schema_path = os.path.join(
        os.path.dirname(__file__), "../../runtime_env/schemas"
    )
    # Maps schema "title" -> parsed schema dict.
    schemas = {}
    # Whether default + env-var schemas have been loaded yet.
    loaded = False

    @classmethod
    def _load_schemas(cls, schema_paths: List[str]):
        """Parse each schema file and register it under its "title" field."""
        for schema_path in schema_paths:
            try:
                with open(schema_path) as f:
                    schema = json.load(f)
            except json.decoder.JSONDecodeError:
                logger.error("Invalid runtime env schema %s, skip it.", schema_path)
                continue
            except OSError:
                logger.error("Cannot open runtime env schema %s, skip it.", schema_path)
                continue
            if "title" not in schema:
                logger.error(
                    "No valid title in runtime env schema %s, skip it.", schema_path
                )
                continue
            title = schema["title"]
            if title in cls.schemas:
                logger.error(
                    "The 'title' of runtime env schema %s conflicts with %s, skip it.",
                    schema_path,
                    cls.schemas[title],
                )
                continue
            cls.schemas[title] = schema

    @classmethod
    def _load_default_schemas(cls):
        """Collect and load all schema files under the default path."""
        schema_json_files = [
            os.path.join(root, name)
            for root, _, files in os.walk(cls.default_schema_path)
            for name in files
            if name.endswith(RAY_RUNTIME_ENV_PLUGIN_SCHEMA_SUFFIX)
        ]
        logger.debug(
            f"Loading the default runtime env schemas: {schema_json_files}."
        )
        cls._load_schemas(schema_json_files)

    @classmethod
    def _load_schemas_from_env_var(cls):
        """Load schemas from a comma-separated env var of files and/or dirs."""
        # Format:
        # "/path/to/env_1_schema.json,/path/to/env_2_schema.json,/path/to/schemas_dir/"
        schema_paths = os.environ.get(RAY_RUNTIME_ENV_PLUGIN_SCHEMAS_ENV_VAR)
        if not schema_paths:
            return
        schema_json_files = []
        for path in schema_paths.split(","):
            if path.endswith(RAY_RUNTIME_ENV_PLUGIN_SCHEMA_SUFFIX):
                schema_json_files.append(path)
            elif os.path.isdir(path):
                for root, _, files in os.walk(path):
                    schema_json_files.extend(
                        os.path.join(root, name)
                        for name in files
                        if name.endswith(RAY_RUNTIME_ENV_PLUGIN_SCHEMA_SUFFIX)
                    )
        logger.info(
            f"Loading the runtime env schemas from env var: {schema_json_files}."
        )
        cls._load_schemas(schema_json_files)

    @classmethod
    def validate(cls, name, instance):
        """Validate `instance` against the schema titled `name`, if present."""
        if not cls.loaded:
            # Load the schemas lazily on first use.
            cls._load_default_schemas()
            cls._load_schemas_from_env_var()
            cls.loaded = True
        # If no schema matches, skip the validation.
        if name in cls.schemas:
            jsonschema.validate(instance=instance, schema=cls.schemas[name])

    @classmethod
    def clear(cls):
        """Drop all loaded schemas and force a reload on next validate()."""
        cls.schemas.clear()
        cls.loaded = False
+ ) + + @classmethod + def get_protocols(cls): + return { + # For packages dynamically uploaded and managed by the GCS. + "gcs", + # For conda environments installed locally on each node. + "conda", + # For pip environments installed locally on each node. + "pip", + # For uv environments install locally on each node. + "uv", + # Remote https path, assumes everything packed in one zip file. + "https", + # Remote s3 path, assumes everything packed in one zip file. + "s3", + # Remote google storage path, assumes everything packed in one zip file. + "gs", + # File storage path, assumes everything packed in one zip file. + "file", + } + + @classmethod + def get_remote_protocols(cls): + return {"https", "s3", "gs", "file"} + + @classmethod + def download_remote_uri(cls, protocol: str, source_uri: str, dest_file: str): + """Download file from remote URI to dest file""" + assert protocol in cls.get_remote_protocols() + + tp = None + + if protocol == "file": + source_uri = source_uri[len("file://") :] + + def open_file(uri, mode, *, transport_params=None): + return open(uri, mode) + + elif protocol == "s3": + try: + import boto3 + from smart_open import open as open_file + except ImportError: + raise ImportError( + "You must `pip install smart_open[s3]` " + "to fetch URIs in s3 bucket. " + cls._MISSING_DEPENDENCIES_WARNING + ) + tp = {"client": boto3.client("s3")} + elif protocol == "gs": + try: + from google.cloud import storage # noqa: F401 + from smart_open import open as open_file + except ImportError: + raise ImportError( + "You must `pip install smart_open[gcs]` " + "to fetch URIs in Google Cloud Storage bucket." + + cls._MISSING_DEPENDENCIES_WARNING + ) + else: + try: + from smart_open import open as open_file + except ImportError: + raise ImportError( + "You must `pip install smart_open` " + f"to fetch {protocol.upper()} URIs. 
" + + cls._MISSING_DEPENDENCIES_WARNING + ) + + with open_file(source_uri, "rb", transport_params=tp) as fin: + with open_file(dest_file, "wb") as fout: + fout.write(fin.read()) + + +_protocols_provider = get_protocols_provider() + +Protocol = enum.Enum( + "Protocol", + {protocol.upper(): protocol for protocol in _protocols_provider.get_protocols()}, +) + + +@classmethod +def _remote_protocols(cls): + # Returns a list of protocols that support remote storage + # These protocols should only be used with paths that end in ".zip" or ".whl" + return [ + cls[protocol.upper()] for protocol in _protocols_provider.get_remote_protocols() + ] + + +Protocol.remote_protocols = _remote_protocols + + +def _download_remote_uri(self, source_uri, dest_file): + return _protocols_provider.download_remote_uri(self.value, source_uri, dest_file) + + +Protocol.download_remote_uri = _download_remote_uri diff --git a/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/py_modules.py b/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/py_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..1066cbe6126dfa146a2fc7aa4ec24cd6fec62bc5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/py_modules.py @@ -0,0 +1,233 @@ +import logging +import os +from pathlib import Path +from types import ModuleType +from typing import Any, Dict, List, Optional + +from ray._private.runtime_env.context import RuntimeEnvContext +from ray._private.runtime_env.packaging import ( + Protocol, + delete_package, + download_and_unpack_package, + get_local_dir_from_uri, + get_uri_for_file, + get_uri_for_directory, + get_uri_for_package, + install_wheel_package, + is_whl_uri, + package_exists, + parse_uri, + upload_package_if_needed, + upload_package_to_gcs, +) +from ray._private.runtime_env.plugin import RuntimeEnvPlugin +from ray._private.runtime_env.working_dir import set_pythonpath_in_context +from ray._private.utils import get_directory_size_bytes, 
def _check_is_uri(s: str) -> bool:
    """Return True if `s` parses as a URI, False for plain local paths.

    Raises:
        ValueError: if the URI uses a remote protocol but does not point
            at a .zip or .whl archive.
    """
    try:
        protocol, path = parse_uri(s)
    except ValueError:
        # Not a URI at all (e.g. a plain local path).
        protocol, path = None, None

    is_remote = protocol in Protocol.remote_protocols()
    if is_remote and not path.endswith((".zip", ".whl")):
        raise ValueError("Only .zip or .whl files supported for remote URIs.")

    return protocol is not None
+ ) + + if _check_is_uri(module_path): + module_uri = module_path + else: + # module_path is a local path. + if Path(module_path).is_dir() or Path(module_path).suffix == ".py": + is_dir = Path(module_path).is_dir() + excludes = runtime_env.get("excludes", None) + if is_dir: + module_uri = get_uri_for_directory(module_path, excludes=excludes) + else: + module_uri = get_uri_for_file(module_path) + if upload_fn is None: + try: + upload_package_if_needed( + module_uri, + scratch_dir, + module_path, + excludes=excludes, + include_parent_dir=is_dir, + logger=logger, + ) + except Exception as e: + from ray.util.spark.utils import is_in_databricks_runtime + + if is_in_databricks_runtime(): + raise RuntimeEnvSetupError( + f"Failed to upload module {module_path} to the Ray " + f"cluster, please ensure there are only files under " + f"the module path, notebooks under the path are " + f"not allowed, original exception: {e}" + ) from e + raise RuntimeEnvSetupError( + f"Failed to upload module {module_path} to the Ray " + f"cluster: {e}" + ) from e + else: + upload_fn(module_path, excludes=excludes) + elif Path(module_path).suffix == ".whl": + module_uri = get_uri_for_package(Path(module_path)) + if upload_fn is None: + if not package_exists(module_uri): + try: + upload_package_to_gcs( + module_uri, Path(module_path).read_bytes() + ) + except Exception as e: + raise RuntimeEnvSetupError( + f"Failed to upload {module_path} to the Ray " + f"cluster: {e}" + ) from e + else: + upload_fn(module_path, excludes=None, is_file=True) + else: + raise ValueError( + "py_modules entry must be a .py file, " + "a directory, or a .whl file; " + f"got {module_path}" + ) + + py_modules_uris.append(module_uri) + + # TODO(architkulkarni): Expose a single URI for py_modules. This plugin + # should internally handle the "sub-URIs", the individual modules. 
+ + runtime_env["py_modules"] = py_modules_uris + return runtime_env + + +class PyModulesPlugin(RuntimeEnvPlugin): + + name = "py_modules" + + def __init__( + self, resources_dir: str, gcs_aio_client: "GcsAioClient" # noqa: F821 + ): + self._resources_dir = os.path.join(resources_dir, "py_modules_files") + self._gcs_aio_client = gcs_aio_client + try_to_create_directory(self._resources_dir) + + def _get_local_dir_from_uri(self, uri: str): + return get_local_dir_from_uri(uri, self._resources_dir) + + def delete_uri( + self, uri: str, logger: Optional[logging.Logger] = default_logger + ) -> int: + """Delete URI and return the number of bytes deleted.""" + logger.info("Got request to delete pymodule URI %s", uri) + local_dir = get_local_dir_from_uri(uri, self._resources_dir) + local_dir_size = get_directory_size_bytes(local_dir) + + deleted = delete_package(uri, self._resources_dir) + if not deleted: + logger.warning(f"Tried to delete nonexistent URI: {uri}.") + return 0 + + return local_dir_size + + def get_uris(self, runtime_env) -> List[str]: + return runtime_env.py_modules() + + async def create( + self, + uri: str, + runtime_env: "RuntimeEnv", # noqa: F821 + context: RuntimeEnvContext, + logger: Optional[logging.Logger] = default_logger, + ) -> int: + + module_dir = await download_and_unpack_package( + uri, self._resources_dir, self._gcs_aio_client, logger=logger + ) + + if is_whl_uri(uri): + wheel_uri = module_dir + module_dir = self._get_local_dir_from_uri(uri) + await install_wheel_package( + wheel_uri=wheel_uri, target_dir=module_dir, logger=logger + ) + + return get_directory_size_bytes(module_dir) + + def modify_context( + self, + uris: List[str], + runtime_env_dict: Dict, + context: RuntimeEnvContext, + logger: Optional[logging.Logger] = default_logger, + ): + module_dirs = [] + for uri in uris: + module_dir = self._get_local_dir_from_uri(uri) + if not module_dir.exists(): + raise ValueError( + f"Local directory {module_dir} for URI {uri} does " + "not 
# Prefix marking env-var values that carry a GCS function key rather
# than an importable module path.
RUNTIME_ENV_FUNC_IDENTIFIER = "ray_runtime_env_func::"


def _decode_function_key(key: bytes) -> str:
    """Turn raw GCS function-key bytes into a printable, prefixed string."""
    # b64encode only includes A-Z, a-z, 0-9, + and / characters.
    return RUNTIME_ENV_FUNC_IDENTIFIER + base64.b64encode(key).decode()


def _encode_function_key(key: str) -> bytes:
    """Inverse of `_decode_function_key`: recover the raw GCS key bytes."""
    assert key.startswith(RUNTIME_ENV_FUNC_IDENTIFIER)
    return base64.b64decode(key.removeprefix(RUNTIME_ENV_FUNC_IDENTIFIER))
def export_setup_func_module(
    runtime_env: Union[Dict[str, Any], RuntimeEnv],
    setup_func_module: str,
) -> Union[Dict[str, Any], RuntimeEnv]:
    """Record a module-path worker setup hook in the env's env_vars.

    The module path is stored under the reserved setup-hook env var so
    workers can import and run it at startup.
    """
    assert isinstance(setup_func_module, str)
    hook_env_var = ray_constants.WORKER_PROCESS_SETUP_HOOK_ENV_VAR
    env_vars = runtime_env.get("env_vars", {})
    # The setup-hook env var is reserved for internal use only.
    assert hook_env_var not in env_vars, (
        f"The env var, {hook_env_var}, "
        "is not permitted because it is reserved for the internal use."
    )
    env_vars[hook_env_var] = setup_func_module
    runtime_env["env_vars"] = env_vars
    return runtime_env
def load_and_execute_setup_hook(
    worker_process_setup_hook_key: str,
) -> Optional[str]:
    """Resolve and run the worker setup hook identified by the given key.

    Args:
        worker_process_setup_hook_key: Either an importable module path or
            a RUNTIME_ENV_FUNC_IDENTIFIER-prefixed GCS function key.
    Returns:
        An error message if it fails. None if it succeeds.
    """
    assert worker_process_setup_hook_key is not None
    # Keys produced from callables carry the function-key prefix;
    # anything else is treated as an importable module path.
    if worker_process_setup_hook_key.startswith(RUNTIME_ENV_FUNC_IDENTIFIER):
        return load_and_execute_setup_hook_func(worker_process_setup_hook_key)
    return load_and_execute_setup_hook_module(worker_process_setup_hook_key)
default_logger = logging.getLogger(__name__)

DEFAULT_MAX_URI_CACHE_SIZE_BYTES = (1024**3) * 10  # 10 GB


class URICache:
    """Size-bounded cache of on-disk URIs.

    Each URI string has an associated on-disk size in bytes. A URI is
    "in use" from the moment it is added (or re-pinned via `mark_used`)
    until `mark_unused` is called; only unused URIs may be deleted.

    Deletion is lazy: whenever the total tracked size exceeds
    `max_total_size_bytes`, arbitrary unused URIs are deleted until the
    limit is satisfied or no unused URIs remain. The total size can
    therefore exceed the limit while every URI is in use.
    """

    def __init__(
        self,
        delete_fn: Optional[Callable[[str, logging.Logger], int]] = None,
        max_total_size_bytes: int = DEFAULT_MAX_URI_CACHE_SIZE_BYTES,
        debug_mode: bool = False,
    ):
        # URIs currently pinned by active users.
        self._used_uris: Set[str] = set()
        # URIs eligible for eviction.
        self._unused_uris: Set[str] = set()

        # delete_fn(uri, logger) removes the URI's disk contents and
        # returns the number of bytes reclaimed; defaults to a no-op.
        if delete_fn is None:
            self._delete_fn = lambda uri, logger: 0
        else:
            self._delete_fn = delete_fn

        # Combined size of used and unused URIs.
        self._total_size_bytes = 0
        self.max_total_size_bytes = max_total_size_bytes

        # When True, `_check_valid` asserts internal invariants (testing).
        self._debug_mode = debug_mode

    def mark_unused(self, uri: str, logger: logging.Logger = default_logger):
        """Mark a URI as unused and okay to be deleted."""
        if uri in self._used_uris:
            self._used_uris.remove(uri)
            self._unused_uris.add(uri)
            logger.info(f"Marked URI {uri} unused.")
        else:
            logger.info(f"URI {uri} is already unused.")
        # Releasing a pin may make eviction possible.
        self._evict_if_needed(logger)
        self._check_valid()

    def mark_used(self, uri: str, logger: logging.Logger = default_logger):
        """Mark a URI as in use. URIs in use will not be deleted."""
        if uri in self._used_uris:
            return
        if uri not in self._unused_uris:
            raise ValueError(
                f"Got request to mark URI {uri} used, but this "
                "URI is not present in the cache."
            )
        self._unused_uris.remove(uri)
        self._used_uris.add(uri)
        logger.info(f"Marked URI {uri} used.")
        self._check_valid()

    def add(self, uri: str, size_bytes: int, logger: logging.Logger = default_logger):
        """Add a URI to the cache and mark it as in use."""
        self._unused_uris.discard(uri)
        self._used_uris.add(uri)
        self._total_size_bytes += size_bytes

        self._evict_if_needed(logger)
        self._check_valid()
        logger.info(f"Added URI {uri} with size {size_bytes}")

    def get_total_size_bytes(self) -> int:
        return self._total_size_bytes

    def _evict_if_needed(self, logger: logging.Logger = default_logger):
        """Evict unused URIs (if they exist) until total size <= max size."""
        while (
            self._unused_uris
            and self.get_total_size_bytes() > self.max_total_size_bytes
        ):
            # TODO(architkulkarni): Evict least recently used URI instead.
            # set.pop() removes an arbitrary member, matching the original
            # "random" eviction behavior.
            victim = self._unused_uris.pop()
            reclaimed = self._delete_fn(victim, logger)
            self._total_size_bytes -= reclaimed
            logger.info(f"Deleted URI {victim} with size {reclaimed}.")

    def _check_valid(self):
        """(Debug mode only) Check "used" and "unused" sets are disjoint."""
        if self._debug_mode:
            assert not (self._used_uris & self._unused_uris)

    def __contains__(self, uri):
        return uri in self._used_uris or uri in self._unused_uris

    def __repr__(self):
        return str(self.__dict__)
"""Util class to install packages via uv."""

from typing import Dict, List, Optional
import os
import hashlib
from ray._private.runtime_env import virtualenv_utils
from ray._private.runtime_env import dependency_utils
from ray._private.runtime_env.utils import check_output_cmd
from ray._private.runtime_env.plugin import RuntimeEnvPlugin
from ray._private.runtime_env.packaging import Protocol, parse_uri
from asyncio import create_task, get_running_loop
import shutil
import logging
import json
import asyncio
import sys
from ray._private.utils import try_to_create_directory, get_directory_size_bytes

default_logger = logging.getLogger(__name__)


def _get_uv_hash(uv_dict: Dict) -> str:
    """Get a deterministic hash value for `uv` related runtime envs.

    `sort_keys=True` makes the JSON serialization (and therefore the hash)
    independent of dict insertion order.
    """
    serialized_uv_spec = json.dumps(uv_dict, sort_keys=True)
    hash_val = hashlib.sha1(serialized_uv_spec.encode("utf-8")).hexdigest()
    return hash_val


def get_uri(runtime_env: Dict) -> Optional[str]:
    """Return `"uv://<hash>"` for the env's uv spec, or None if no GC required.

    Raises:
        TypeError: if the `uv` field is neither a list nor a dict.
    """
    uv = runtime_env.get("uv")
    if uv is None:
        return None
    if isinstance(uv, dict):
        return "uv://" + _get_uv_hash(uv_dict=uv)
    if isinstance(uv, list):
        # A bare package list is hashed as if it were {"packages": [...]}.
        return "uv://" + _get_uv_hash(uv_dict=dict(packages=uv))
    raise TypeError(
        "uv field received by RuntimeEnvAgent must be "
        f"list or dict, not {type(uv).__name__}."
    )


class UvProcessor:
    def __init__(
        self,
        target_dir: str,
        runtime_env: "RuntimeEnv",  # noqa: F821
        logger: Optional[logging.Logger] = default_logger,
    ):
        try:
            import virtualenv  # noqa: F401 ensure virtualenv exists.
        except ImportError:
            # NOTE: fixed missing space between the backtick-quoted command and
            # "to enable" in the error message.
            raise RuntimeError(
                f"Please install virtualenv "
                f"`{sys.executable} -m pip install virtualenv` "
                f"to enable uv runtime env."
            )

        logger.debug("Setting up uv for runtime_env: %s", runtime_env)
        self._target_dir = target_dir
        # An empty directory is created to execute cmd.
        self._exec_cwd = os.path.join(self._target_dir, "exec_cwd")
        self._runtime_env = runtime_env
        self._logger = logger

        self._uv_config = self._runtime_env.uv_config()
        self._uv_env = os.environ.copy()
        self._uv_env.update(self._runtime_env.env_vars())

    async def _install_uv(
        self, path: str, cwd: str, pip_env: dict, logger: logging.Logger
    ):
        """Before package install, make sure the required version `uv`
        (if specified) is installed.

        Args:
            path: target dir of the virtual environment.
            cwd: working directory for the install command.
            pip_env: environment variables for the subprocess.
        """
        virtualenv_path = virtualenv_utils.get_virtualenv_path(path)
        python = virtualenv_utils.get_virtualenv_python(path)

        def _get_uv_exec_to_install() -> str:
            """Get `uv` executable with version to install."""
            uv_version = self._uv_config.get("uv_version", None)
            if uv_version:
                # e.g. "uv==0.4.0" when the user pins a version specifier.
                return f"uv{uv_version}"
            # Use default version.
            return "uv"

        uv_install_cmd = [
            python,
            "-m",
            "pip",
            "install",
            "--disable-pip-version-check",
            "--no-cache-dir",
            _get_uv_exec_to_install(),
        ]
        logger.info("Installing package uv to %s", virtualenv_path)
        await check_output_cmd(uv_install_cmd, logger=logger, cwd=cwd, env=pip_env)

    async def _check_uv_existence(
        self, path: str, cwd: str, env: dict, logger: logging.Logger
    ) -> bool:
        """Check and return the existence of `uv` in virtual env.

        Args:
            path: target dir of the virtual environment (the venv python is
                derived from it).
        """
        python = virtualenv_utils.get_virtualenv_python(path)

        check_existence_cmd = [
            python,
            "-m",
            "uv",
            "version",
        ]

        try:
            # If `uv` doesn't exist, exception will be thrown.
            await check_output_cmd(check_existence_cmd, logger=logger, cwd=cwd, env=env)
            return True
        except Exception:
            return False

    # BUG FIX: first parameter was misspelled `sef`, so `self` was bound to the
    # caller's `python` argument and the remaining arguments shifted by one.
    async def _uv_check(self, python: str, cwd: str, logger: logging.Logger) -> None:
        """Check virtual env dependency compatibility.
        If any incompatibility detected, exception will be thrown.

        param:
            python: the path for python executable within virtual environment.
        """
        cmd = [python, "-m", "uv", "pip", "check"]
        await check_output_cmd(
            cmd,
            logger=logger,
            cwd=cwd,
        )

    async def _install_uv_packages(
        self,
        path: str,
        uv_packages: List[str],
        cwd: str,
        pip_env: Dict,
        logger: logging.Logger,
    ):
        """Install required python packages via `uv`."""
        virtualenv_path = virtualenv_utils.get_virtualenv_path(path)
        python = virtualenv_utils.get_virtualenv_python(path)
        # TODO(fyrestone): Support -i, --no-deps, --no-cache-dir, ...
        requirements_file = dependency_utils.get_requirements_file(path, uv_packages)

        # Check existence for `uv` and see if we could skip `uv` installation.
        # BUG FIX: `_check_uv_existence` expects the virtualenv *target dir*
        # (it derives the venv python itself); previously the interpreter path
        # `python` was passed, so the probe always failed and uv was
        # reinstalled on every run.
        uv_exists = await self._check_uv_existence(path, cwd, pip_env, logger)

        # Install uv, which acts as the default package manager.
        if (not uv_exists) or (self._uv_config.get("uv_version", None) is not None):
            await self._install_uv(path, cwd, pip_env, logger)

        # Avoid blocking the event loop.
        loop = get_running_loop()
        await loop.run_in_executor(
            None, dependency_utils.gen_requirements_txt, requirements_file, uv_packages
        )

        # Install all dependencies.
        #
        # Difference with pip:
        # 1. `--disable-pip-version-check` has no effect for uv.
        # 2. Allow user to specify their own options to install packages via `uv`.
        uv_install_cmd = [
            python,
            "-m",
            "uv",
            "pip",
            "install",
            "-r",
            requirements_file,
        ]

        uv_opt_list = self._uv_config.get("uv_pip_install_options", ["--no-cache"])
        if uv_opt_list:
            uv_install_cmd += uv_opt_list

        logger.info("Installing python requirements to %s", virtualenv_path)
        await check_output_cmd(uv_install_cmd, logger=logger, cwd=cwd, env=pip_env)

        # Check python environment for conflicts.
        if self._uv_config.get("uv_check", False):
            await self._uv_check(python, cwd, logger)
+ if self._uv_config.get("uv_check", False): + await self._uv_check(python, cwd, logger) + + async def _run(self): + path = self._target_dir + logger = self._logger + uv_packages = self._uv_config["packages"] + # We create an empty directory for exec cmd so that the cmd will + # run more stable. e.g. if cwd has ray, then checking ray will + # look up ray in cwd instead of site packages. + os.makedirs(self._exec_cwd, exist_ok=True) + try: + await virtualenv_utils.create_or_get_virtualenv( + path, self._exec_cwd, logger + ) + python = virtualenv_utils.get_virtualenv_python(path) + async with dependency_utils.check_ray(python, self._exec_cwd, logger): + # Install packages with uv. + await self._install_uv_packages( + path, + uv_packages, + self._exec_cwd, + self._uv_env, + logger, + ) + except Exception: + logger.info("Delete incomplete virtualenv: %s", path) + shutil.rmtree(path, ignore_errors=True) + logger.exception("Failed to install uv packages.") + raise + + def __await__(self): + return self._run().__await__() + + +class UvPlugin(RuntimeEnvPlugin): + name = "uv" + + def __init__(self, resources_dir: str): + self._uv_resource_dir = os.path.join(resources_dir, "uv") + self._creating_task = {} + # Maps a URI to a lock that is used to prevent multiple concurrent + # installs of the same virtualenv, see #24513 + self._create_locks: Dict[str, asyncio.Lock] = {} + # Key: created hashes. Value: size of the uv dir. + self._created_hash_bytes: Dict[str, int] = {} + try_to_create_directory(self._uv_resource_dir) + + def _get_path_from_hash(self, hash_val: str) -> str: + """Generate a path from the hash of a uv spec. 
+ + Example output: + /tmp/ray/session_2021-11-03_16-33-59_356303_41018/runtime_resources + /uv/ray-9a7972c3a75f55e976e620484f58410c920db091 + """ + return os.path.join(self._uv_resource_dir, hash_val) + + def get_uris(self, runtime_env: "RuntimeEnv") -> List[str]: # noqa: F821 + """Return the uv URI from the RuntimeEnv if it exists, else return [].""" + uv_uri = runtime_env.uv_uri() + if uv_uri: + return [uv_uri] + return [] + + def delete_uri( + self, uri: str, logger: Optional[logging.Logger] = default_logger + ) -> int: + """Delete URI and return the number of bytes deleted.""" + logger.info("Got request to delete uv URI %s", uri) + protocol, hash_val = parse_uri(uri) + if protocol != Protocol.UV: + raise ValueError( + "UvPlugin can only delete URIs with protocol " + f"uv. Received protocol {protocol}, URI {uri}" + ) + + # Cancel running create task. + task = self._creating_task.pop(hash_val, None) + if task is not None: + task.cancel() + + del self._created_hash_bytes[hash_val] + + uv_env_path = self._get_path_from_hash(hash_val) + local_dir_size = get_directory_size_bytes(uv_env_path) + del self._create_locks[uri] + try: + shutil.rmtree(uv_env_path) + except OSError as e: + logger.warning(f"Error when deleting uv env {uv_env_path}: {str(e)}") + return 0 + + return local_dir_size + + async def create( + self, + uri: str, + runtime_env: "RuntimeEnv", # noqa: F821 + context: "RuntimeEnvContext", # noqa: F821 + logger: Optional[logging.Logger] = default_logger, + ) -> int: + if not runtime_env.has_uv(): + return 0 + + protocol, hash_val = parse_uri(uri) + target_dir = self._get_path_from_hash(hash_val) + + async def _create_for_hash(): + await UvProcessor( + target_dir, + runtime_env, + logger, + ) + + loop = get_running_loop() + return await loop.run_in_executor( + None, get_directory_size_bytes, target_dir + ) + + if uri not in self._create_locks: + # async lock to prevent the same virtualenv being concurrently installed + self._create_locks[uri] = 
asyncio.Lock() + + async with self._create_locks[uri]: + if hash_val in self._created_hash_bytes: + return self._created_hash_bytes[hash_val] + self._creating_task[hash_val] = task = create_task(_create_for_hash()) + task.add_done_callback(lambda _: self._creating_task.pop(hash_val, None)) + uv_dir_bytes = await task + self._created_hash_bytes[hash_val] = uv_dir_bytes + return uv_dir_bytes + + def modify_context( + self, + uris: List[str], + runtime_env: "RuntimeEnv", # noqa: F821 + context: "RuntimeEnvContext", # noqa: F821 + logger: logging.Logger = default_logger, + ): + if not runtime_env.has_uv(): + return + # UvPlugin only uses a single URI. + uri = uris[0] + # Update py_executable. + protocol, hash_val = parse_uri(uri) + target_dir = self._get_path_from_hash(hash_val) + virtualenv_python = virtualenv_utils.get_virtualenv_python(target_dir) + + if not os.path.exists(virtualenv_python): + raise ValueError( + f"Local directory {target_dir} for URI {uri} does " + "not exist on the cluster. Something may have gone wrong while " + "installing the runtime_env `uv` packages." + ) + context.py_executable = virtualenv_python + context.command_prefix += virtualenv_utils.get_virtualenv_activate_command( + target_dir + ) diff --git a/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/validation.py b/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/validation.py new file mode 100644 index 0000000000000000000000000000000000000000..fa6d9bb1635578dca943be47f7417b494a78a802 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/runtime_env/validation.py @@ -0,0 +1,386 @@ +import logging +from pathlib import Path +import sys +from typing import Dict, List, Optional, Union + +from collections import OrderedDict +import yaml + +logger = logging.getLogger(__name__) + + +def validate_uri(uri: str): + if not isinstance(uri, str): + raise TypeError( + "URIs for working_dir and py_modules must be " f"strings, got {type(uri)}." 
+ ) + + try: + from ray._private.runtime_env.packaging import parse_uri, Protocol + + protocol, path = parse_uri(uri) + except ValueError: + raise ValueError( + f"{uri} is not a valid URI. Passing directories or modules to " + "be dynamically uploaded is only supported at the job level " + "(i.e., passed to `ray.init`)." + ) + + if ( + protocol in Protocol.remote_protocols() + and not path.endswith(".zip") + and not path.endswith(".whl") + ): + raise ValueError("Only .zip or .whl files supported for remote URIs.") + + +def _handle_local_deps_requirement_file(requirements_file: str): + """Read the given [requirements_file], and return all required dependencies.""" + requirements_path = Path(requirements_file) + if not requirements_path.is_file(): + raise ValueError(f"{requirements_path} is not a valid file") + return requirements_path.read_text().strip().split("\n") + + +def parse_and_validate_py_modules(py_modules: List[str]) -> List[str]: + """Parses and validates a 'py_modules' option. + + This should be a list of URIs. + """ + if not isinstance(py_modules, list): + raise TypeError( + "`py_modules` must be a list of strings, got " f"{type(py_modules)}." + ) + + for uri in py_modules: + validate_uri(uri) + + return py_modules + + +def parse_and_validate_working_dir(working_dir: str) -> str: + """Parses and validates a 'working_dir' option. + + This should be a URI. + """ + assert working_dir is not None + + if not isinstance(working_dir, str): + raise TypeError("`working_dir` must be a string, got " f"{type(working_dir)}.") + + validate_uri(working_dir) + + return working_dir + + +def parse_and_validate_conda(conda: Union[str, dict]) -> Union[str, dict]: + """Parses and validates a user-provided 'conda' option. + + Conda can be one of three cases: + 1) A dictionary describing the env. This is passed through directly. + 2) A string referring to the name of a preinstalled conda env. + 3) A string pointing to a local conda YAML file. 
def parse_and_validate_uv(uv: Union[str, List[str], Dict]) -> Optional[Dict]:
    """Parses and validates a user-provided 'uv' option.

    The value of the input 'uv' field can be one of three cases:
    1) A List[str] describing the requirements. This is passed through.
       Example usage: ["tensorflow", "requests"]
    2) A string containing the path to a local pip "requirements.txt" file.
    3) A python dictionary that has one required and three optional fields:
        a) packages (required, List[str]): a list of uv packages, it same as 1).
        b) uv_check (optional, bool): whether to enable pip check at the end of uv
           install, default to False.
        c) uv_version (optional, str): user provides a specific uv to use; if
           unspecified, default version of uv will be used.
        d) uv_pip_install_options (optional, List[str]): user-provided options for
           `uv pip install` command, default to ["--no-cache"].

    The returned parsed value will be a list of packages. If a Ray library
    (e.g. "ray[serve]") is specified, it will be deleted and replaced by its
    dependencies (e.g. "uvicorn", "requests").

    Returns:
        A normalized dict with at least "packages" and "uv_check" keys, or
        None if the resulting package list is empty.
    """
    assert uv is not None
    if sys.platform == "win32":
        logger.warning(
            "runtime environment support is experimental on Windows. "
            "If you run into issues please file a report at "
            "https://github.com/ray-project/ray/issues."
        )

    # BUG FIX: was annotated `result: str = ""` although every branch assigns
    # a dict (or None at the end) -- the annotation now matches reality.
    result: Optional[Dict] = None
    if isinstance(uv, str):
        # Case 2: a path to a local requirements.txt file.
        uv_list = _handle_local_deps_requirement_file(uv)
        result = dict(packages=uv_list, uv_check=False)
    elif isinstance(uv, list) and all(isinstance(dep, str) for dep in uv):
        result = dict(packages=uv, uv_check=False)
    elif isinstance(uv, dict):
        if set(uv.keys()) - {
            "packages",
            "uv_check",
            "uv_version",
            "uv_pip_install_options",
        }:
            raise ValueError(
                "runtime_env['uv'] can only have these fields: "
                "packages, uv_check, uv_version and uv_pip_install_options, but got: "
                f"{list(uv.keys())}"
            )
        if "packages" not in uv:
            raise ValueError(
                f"runtime_env['uv'] must include field 'packages', but got {uv}"
            )
        if "uv_check" in uv and not isinstance(uv["uv_check"], bool):
            raise TypeError(
                "runtime_env['uv']['uv_check'] must be of type bool, "
                f"got {type(uv['uv_check'])}"
            )
        if "uv_version" in uv and not isinstance(uv["uv_version"], str):
            raise TypeError(
                "runtime_env['uv']['uv_version'] must be of type str, "
                f"got {type(uv['uv_version'])}"
            )
        if "uv_pip_install_options" in uv:
            if not isinstance(uv["uv_pip_install_options"], list):
                raise TypeError(
                    "runtime_env['uv']['uv_pip_install_options'] must be of type "
                    f"list[str] got {type(uv['uv_pip_install_options'])}"
                )
            # Check each item in installation option.
            for idx, cur_opt in enumerate(uv["uv_pip_install_options"]):
                if not isinstance(cur_opt, str):
                    raise TypeError(
                        "runtime_env['uv']['uv_pip_install_options'] must be of type "
                        f"list[str] got {type(cur_opt)} for {idx}-th item."
                    )

        result = uv.copy()
        result["uv_check"] = uv.get("uv_check", False)
        result["uv_pip_install_options"] = uv.get(
            "uv_pip_install_options", ["--no-cache"]
        )
        if not isinstance(uv["packages"], list):
            raise ValueError(
                "runtime_env['uv']['packages'] must be of type list, "
                f"got: {type(uv['packages'])}"
            )
    else:
        raise TypeError(
            "runtime_env['uv'] must be of type " f"List[str], or dict, got {type(uv)}"
        )

    # Deduplicate packages for package lists.
    result["packages"] = list(OrderedDict.fromkeys(result["packages"]))

    if len(result["packages"]) == 0:
        result = None
    logger.debug(f"Rewrote runtime_env `uv` field from {uv} to {result}.")
    return result


def parse_and_validate_pip(pip: Union[str, List[str], Dict]) -> Optional[Dict]:
    """Parses and validates a user-provided 'pip' option.

    The value of the input 'pip' field can be one of three cases:
    1) A List[str] describing the requirements. This is passed through.
    2) A string pointing to a local requirements file. In this case, the
       file contents will be read split into a list.
    3) A python dictionary that has three fields:
        a) packages (required, List[str]): a list of pip packages, it same as 1).
        b) pip_check (optional, bool): whether to enable pip check at the end of pip
           install, default to False.
        c) pip_version (optional, str): the version of pip, ray will spell
           the package name 'pip' in front of the `pip_version` to form the final
           requirement string, the syntax of a requirement specifier is defined in
           full in PEP 508.

    The returned parsed value will be a list of pip packages. If a Ray library
    (e.g. "ray[serve]") is specified, it will be deleted and replaced by its
    dependencies (e.g. "uvicorn", "requests").

    Returns:
        A normalized dict with "packages" and "pip_check" keys, or None if
        the resulting package list is empty.
    """
    assert pip is not None

    result: Optional[Dict] = None
    if sys.platform == "win32":
        logger.warning(
            "runtime environment support is experimental on Windows. "
            "If you run into issues please file a report at "
            "https://github.com/ray-project/ray/issues."
        )
    if isinstance(pip, str):
        # We have been given a path to a requirements.txt file.
        pip_list = _handle_local_deps_requirement_file(pip)
        result = dict(packages=pip_list, pip_check=False)
    elif isinstance(pip, list) and all(isinstance(dep, str) for dep in pip):
        result = dict(packages=pip, pip_check=False)
    elif isinstance(pip, dict):
        if set(pip.keys()) - {"packages", "pip_check", "pip_version"}:
            raise ValueError(
                "runtime_env['pip'] can only have these fields: "
                "packages, pip_check and pip_version, but got: "
                f"{list(pip.keys())}"
            )

        if "pip_check" in pip and not isinstance(pip["pip_check"], bool):
            raise TypeError(
                "runtime_env['pip']['pip_check'] must be of type bool, "
                f"got {type(pip['pip_check'])}"
            )
        if "pip_version" in pip:
            if not isinstance(pip["pip_version"], str):
                raise TypeError(
                    "runtime_env['pip']['pip_version'] must be of type str, "
                    f"got {type(pip['pip_version'])}"
                )
        result = pip.copy()
        result["pip_check"] = pip.get("pip_check", False)
        if "packages" not in pip:
            raise ValueError(
                f"runtime_env['pip'] must include field 'packages', but got {pip}"
            )
        elif isinstance(pip["packages"], str):
            # "packages" may itself point at a local requirements file.
            result["packages"] = _handle_local_deps_requirement_file(pip["packages"])
        elif not isinstance(pip["packages"], list):
            raise ValueError(
                "runtime_env['pip']['packages'] must be of type str of list, "
                f"got: {type(pip['packages'])}"
            )
    else:
        raise TypeError(
            "runtime_env['pip'] must be of type str or " f"List[str], got {type(pip)}"
        )

    # Eliminate duplicates to prevent `pip install` from erroring. Use
    # OrderedDict to preserve the order of the list. This makes the output
    # deterministic and easier to debug, because pip install can have
    # different behavior depending on the order of the input.
    result["packages"] = list(OrderedDict.fromkeys(result["packages"]))

    if len(result["packages"]) == 0:
        result = None

    logger.debug(f"Rewrote runtime_env `pip` field from {pip} to {result}.")

    return result
def parse_and_validate_container(container: List[str]) -> List[str]:
    """Parses and validates a user-provided 'container' option.

    This is passed through without validation (for now).
    """
    assert container is not None
    return container


def parse_and_validate_excludes(excludes: List[str]) -> List[str]:
    """Parses and validates a user-provided 'excludes' option.

    This is validated to verify that it is of type List[str].

    If an empty list is passed, we return `None` for consistency.
    """
    assert excludes is not None

    if not isinstance(excludes, list) or not all(
        isinstance(path, str) for path in excludes
    ):
        raise TypeError(
            "runtime_env['excludes'] must be of type "
            f"List[str], got {type(excludes)}"
        )
    # An empty list normalizes to None.
    return excludes or None


def parse_and_validate_env_vars(env_vars: Dict[str, str]) -> Optional[Dict[str, str]]:
    """Parses and validates a user-provided 'env_vars' option.

    This is validated to verify that all keys and vals are strings.

    If an empty dictionary is passed, we return `None` for consistency.

    Args:
        env_vars: A dictionary of environment variables to set in the
            runtime environment.

    Returns:
        The validated env_vars dictionary, or None if it was empty.

    Raises:
        TypeError: If the env_vars is not a dictionary of strings. The error
            message will include the type of the invalid value.
    """
    assert env_vars is not None
    if len(env_vars) == 0:
        return None

    if not isinstance(env_vars, dict):
        raise TypeError(
            "runtime_env['env_vars'] must be of type "
            f"Dict[str, str], got {type(env_vars)}"
        )

    for key, val in env_vars.items():
        if not isinstance(key, str):
            raise TypeError(
                "runtime_env['env_vars'] must be of type "
                f"Dict[str, str], but the key {key} is of type {type(key)}"
            )
        if not isinstance(val, str):
            raise TypeError(
                "runtime_env['env_vars'] must be of type "
                f"Dict[str, str], but the value {val} is of type {type(val)}"
            )

    return env_vars


# Dictionary mapping runtime_env options with the function to parse and
# validate them.
OPTION_TO_VALIDATION_FN = {
    "py_modules": parse_and_validate_py_modules,
    "working_dir": parse_and_validate_working_dir,
    "excludes": parse_and_validate_excludes,
    "conda": parse_and_validate_conda,
    "pip": parse_and_validate_pip,
    "uv": parse_and_validate_uv,
    "env_vars": parse_and_validate_env_vars,
    "container": parse_and_validate_container,
}


"""Utils to detect runtime environment."""

import sys
from ray._private.runtime_env.utils import check_output_cmd
import logging
import os
from typing import List

_WIN32 = os.name == "nt"


def is_in_virtualenv() -> bool:
    """Return True when the current interpreter runs inside a virtualenv/venv."""
    # virtualenv <= 16.7.9 sets the real_prefix,
    # virtualenv > 16.7.9 & venv set the base_prefix.
    # So, we check both of them here.
    # https://github.com/pypa/virtualenv/issues/1622#issuecomment-586186094
    return hasattr(sys, "real_prefix") or (
        hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix
    )
def get_virtualenv_path(target_dir: str) -> str:
    """Get virtual environment path."""
    return os.path.join(target_dir, "virtualenv")


def get_virtualenv_python(target_dir: str) -> str:
    """Return the path of the python executable inside the virtualenv."""
    virtualenv_path = get_virtualenv_path(target_dir)
    if _WIN32:
        return os.path.join(virtualenv_path, "Scripts", "python.exe")
    else:
        return os.path.join(virtualenv_path, "bin", "python")


def get_virtualenv_activate_command(target_dir: str) -> List[str]:
    """Get the command to activate virtual environment."""
    virtualenv_path = get_virtualenv_path(target_dir)
    if _WIN32:
        cmd = [os.path.join(virtualenv_path, "Scripts", "activate.bat")]
    else:
        cmd = ["source", os.path.join(virtualenv_path, "bin/activate")]
    # Redirect activation output to stderr so stdout stays clean.
    return cmd + ["1>&2", "&&"]


async def create_or_get_virtualenv(path: str, cwd: str, logger: logging.Logger):
    """Create or get a virtualenv from path.

    When already running inside a virtualenv, the current env is cloned via
    virtualenv-clone; otherwise a fresh virtualenv inheriting system site
    packages is created.
    """
    python = sys.executable
    virtualenv_path = os.path.join(path, "virtualenv")
    virtualenv_app_data_path = os.path.join(path, "virtualenv_app_data")

    if _WIN32:
        current_python_dir = sys.prefix
        env = os.environ.copy()
    else:
        current_python_dir = os.path.abspath(
            os.path.join(os.path.dirname(python), "..")
        )
        env = {}

    if is_in_virtualenv():
        # virtualenv-clone homepage:
        # https://github.com/edwardgeorge/virtualenv-clone
        # virtualenv-clone Usage:
        # virtualenv-clone /path/to/existing/venv /path/to/cloned/ven
        # or
        # python -m clonevirtualenv /path/to/existing/venv /path/to/cloned/ven
        clonevirtualenv = os.path.join(os.path.dirname(__file__), "_clonevirtualenv.py")
        create_venv_cmd = [
            python,
            clonevirtualenv,
            current_python_dir,
            virtualenv_path,
        ]
        logger.info("Cloning virtualenv %s to %s", current_python_dir, virtualenv_path)
    else:
        # virtualenv options:
        # https://virtualenv.pypa.io/en/latest/cli_interface.html
        #
        # --app-data
        # --reset-app-data
        #   Set an empty separated app data folder for current virtualenv.
        #
        # --no-periodic-update
        #   Disable the periodic (once every 14 days) update of the embedded
        #   wheels.
        #
        # --system-site-packages
        #   Inherit site packages.
        #
        # --no-download
        #   Never download the latest pip/setuptools/wheel from PyPI.
        create_venv_cmd = [
            python,
            "-m",
            "virtualenv",
            "--app-data",
            virtualenv_app_data_path,
            "--reset-app-data",
            "--no-periodic-update",
            "--system-site-packages",
            "--no-download",
            virtualenv_path,
        ]
        # BUG FIX: the second %s previously received `virtualenv_path` again,
        # so the log line never showed the actual current python dir.
        logger.info(
            "Creating virtualenv at %s, current python dir %s",
            virtualenv_path,
            current_python_dir,
        )
    await check_output_cmd(create_venv_cmd, logger=logger, cwd=cwd, env=env)


import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional
from contextlib import contextmanager

import ray._private.ray_constants as ray_constants
from ray._private.runtime_env.context import RuntimeEnvContext
from ray._private.runtime_env.packaging import (
    Protocol,
    delete_package,
    download_and_unpack_package,
    get_local_dir_from_uri,
    get_uri_for_directory,
    get_uri_for_package,
    parse_uri,
    upload_package_if_needed,
    upload_package_to_gcs,
)
from ray._private.runtime_env.plugin import RuntimeEnvPlugin
from ray._private.utils import get_directory_size_bytes, try_to_create_directory
from ray.exceptions import RuntimeEnvSetupError

default_logger = logging.getLogger(__name__)

_WIN32 = os.name == "nt"


def upload_working_dir_if_needed(
    runtime_env: Dict[str, Any],
    scratch_dir: Optional[str] = None,
    logger: Optional[logging.Logger] = default_logger,
    upload_fn=None,
) -> Dict[str, Any]:
    """Uploads the working_dir and replaces it with a URI.

    If the working_dir is already a URI, this is a no-op.

    Args:
        runtime_env: the runtime_env dict; its "working_dir" entry is
            rewritten in place to a URI.
        scratch_dir: directory for temporary packaging artifacts. Defaults to
            the *current* working directory at call time. (BUG FIX: the
            default was previously `os.getcwd()` evaluated once at import
            time, freezing whatever directory the module was imported from.)
        upload_fn: optional custom upload callable; when given it is used
            instead of `upload_package_if_needed`.
    """
    if scratch_dir is None:
        scratch_dir = os.getcwd()
    working_dir = runtime_env.get("working_dir")
    if working_dir is None:
        return runtime_env

    if not isinstance(working_dir, str) and not isinstance(working_dir, Path):
        raise TypeError(
            "working_dir must be a string or Path (either a local path "
            f"or remote URI), got {type(working_dir)}."
        )

    if isinstance(working_dir, Path):
        working_dir = str(working_dir)

    # working_dir is already a URI -- just pass it through.
    try:
        protocol, path = parse_uri(working_dir)
    except ValueError:
        protocol, path = None, None

    if protocol is not None:
        if protocol in Protocol.remote_protocols() and not path.endswith(".zip"):
            raise ValueError("Only .zip files supported for remote URIs.")
        return runtime_env

    excludes = runtime_env.get("excludes", None)
    try:
        working_dir_uri = get_uri_for_directory(working_dir, excludes=excludes)
    except ValueError:  # working_dir is not a directory
        package_path = Path(working_dir)
        if not package_path.exists() or package_path.suffix != ".zip":
            raise ValueError(
                f"directory {package_path} must be an existing "
                "directory or a zip package"
            )

        pkg_uri = get_uri_for_package(package_path)
        try:
            upload_package_to_gcs(pkg_uri, package_path.read_bytes())
        except Exception as e:
            raise RuntimeEnvSetupError(
                f"Failed to upload package {package_path} to the Ray cluster: {e}"
            ) from e
        runtime_env["working_dir"] = pkg_uri
        return runtime_env
    if upload_fn is None:
        try:
            upload_package_if_needed(
                working_dir_uri,
                scratch_dir,
                working_dir,
                include_parent_dir=False,
                excludes=excludes,
                logger=logger,
            )
        except Exception as e:
            raise RuntimeEnvSetupError(
                f"Failed to upload working_dir {working_dir} to the Ray cluster: {e}"
            ) from e
    else:
        upload_fn(working_dir, excludes=excludes)

    runtime_env["working_dir"] = working_dir_uri
    return runtime_env
def set_pythonpath_in_context(python_path: str, context: RuntimeEnvContext):
    """Insert the path as the first entry in PYTHONPATH in the runtime env.

    This is compatible with users providing their own PYTHONPATH in env_vars,
    and is also compatible with the existing PYTHONPATH in the cluster.

    The import priority is as follows:
    this python_path arg > env_vars PYTHONPATH > existing cluster env PYTHONPATH.
    """
    # Append the lower-priority sources, in order, behind the new entry.
    if "PYTHONPATH" in context.env_vars:
        python_path += os.pathsep + context.env_vars["PYTHONPATH"]
    if "PYTHONPATH" in os.environ:
        python_path += os.pathsep + os.environ["PYTHONPATH"]
    context.env_vars["PYTHONPATH"] = python_path


class WorkingDirPlugin(RuntimeEnvPlugin):

    name = "working_dir"

    # Note working_dir is not following the priority order of other plugins. Instead
    # it's specially treated to happen before all other plugins.
    priority = 5

    def __init__(
        self, resources_dir: str, gcs_aio_client: "GcsAioClient"  # noqa: F821
    ):
        self._resources_dir = os.path.join(resources_dir, "working_dir_files")
        self._gcs_aio_client = gcs_aio_client
        try_to_create_directory(self._resources_dir)

    def delete_uri(
        self, uri: str, logger: Optional[logging.Logger] = default_logger
    ) -> int:
        """Delete URI and return the number of bytes deleted."""
        logger.info("Got request to delete working dir URI %s", uri)
        local_dir = get_local_dir_from_uri(uri, self._resources_dir)
        # Measure before deleting so the reclaimed size can be reported.
        local_dir_size = get_directory_size_bytes(local_dir)

        if not delete_package(uri, self._resources_dir):
            logger.warning(f"Tried to delete nonexistent URI: {uri}.")
            return 0

        return local_dir_size

    def get_uris(self, runtime_env: "RuntimeEnv") -> List[str]:  # noqa: F821
        """Return the working_dir URI as a one-element list, or []."""
        working_dir_uri = runtime_env.working_dir()
        return [working_dir_uri] if working_dir_uri != "" else []

    async def create(
        self,
        uri: Optional[str],
        runtime_env: dict,
        context: RuntimeEnvContext,
        logger: logging.Logger = default_logger,
    ) -> int:
        """Download and unpack the working_dir package; return its size."""
        local_dir = await download_and_unpack_package(
            uri,
            self._resources_dir,
            self._gcs_aio_client,
            logger=logger,
            overwrite=True,
        )
        return get_directory_size_bytes(local_dir)

    def modify_context(
        self,
        uris: List[str],
        runtime_env_dict: Dict,
        context: RuntimeEnvContext,
        logger: Optional[logging.Logger] = default_logger,
    ):
        """Make workers start inside the unpacked working_dir."""
        if not uris:
            return

        # WorkingDirPlugin uses a single URI.
        uri = uris[0]
        local_dir = get_local_dir_from_uri(uri, self._resources_dir)
        if not local_dir.exists():
            raise ValueError(
                f"Local directory {local_dir} for URI {uri} does "
                "not exist on the cluster. Something may have gone wrong while "
                "downloading or unpacking the working_dir."
            )

        if _WIN32:
            # Include '/d' in case the temp folder is on a different drive
            # than the Ray install.
            context.command_prefix += ["cd", "/d", f"{local_dir}", "&&"]
        else:
            context.command_prefix += ["cd", str(local_dir), "&&"]
        set_pythonpath_in_context(python_path=str(local_dir), context=context)

    @contextmanager
    def with_working_dir_env(self, uri):
        """
        If uri is not None, add the local working directory to the environment variable
        as "RAY_RUNTIME_ENV_CREATE_WORKING_DIR". This is useful for other plugins to
        create their environment with reference to the working directory. For example
        `pip -r ${RAY_RUNTIME_ENV_CREATE_WORKING_DIR}/requirements.txt`

        The environment variable is removed after the context manager exits.
        """
        if uri is None:
            yield
            return

        local_dir = get_local_dir_from_uri(uri, self._resources_dir)
        if not local_dir.exists():
            raise ValueError(
                f"Local directory {local_dir} for URI {uri} does "
                "not exist on the cluster. Something may have gone wrong while "
                "downloading or unpacking the working_dir."
            )
        key = ray_constants.RAY_RUNTIME_ENV_CREATE_WORKING_DIR_ENV_VAR
        prev = os.environ.get(key)
        # Windows backslash paths are weird. When it's passed to the env var, and
        # when Pip expands it, the backslashes are interpreted as escape characters
        # and messes up the whole path. So we convert it to forward slashes.
        # This works at least for all Python applications, including pip.
        os.environ[key] = local_dir.as_posix()
        try:
            yield
        finally:
            if prev is None:
                del os.environ[key]
            else:
                os.environ[key] = prev
+ os.environ[key] = local_dir.as_posix() + try: + yield + finally: + if prev is None: + del os.environ[key] + else: + os.environ[key] = prev diff --git a/.venv/lib/python3.11/site-packages/ray/_private/serialization.py b/.venv/lib/python3.11/site-packages/ray/_private/serialization.py new file mode 100644 index 0000000000000000000000000000000000000000..88e6648d012a5be94afd58d9764f67919de1264f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/serialization.py @@ -0,0 +1,556 @@ +import io +import logging +import threading +import traceback +from typing import Any, Optional + + +import google.protobuf.message + +import ray._private.utils +import ray.cloudpickle as pickle +from ray._private import ray_constants +from ray._raylet import ( + MessagePackSerializedObject, + MessagePackSerializer, + DynamicObjectRefGenerator, + Pickle5SerializedObject, + Pickle5Writer, + RawSerializedObject, + split_buffer, + unpack_pickle5_buffers, +) +from ray.core.generated.common_pb2 import ErrorType, RayErrorInfo +from ray.exceptions import ( + ActorDiedError, + ActorPlacementGroupRemoved, + ActorUnavailableError, + ActorUnschedulableError, + LocalRayletDiedError, + NodeDiedError, + ObjectFetchTimedOutError, + ObjectLostError, + ObjectReconstructionFailedError, + ObjectReconstructionFailedLineageEvictedError, + ObjectReconstructionFailedMaxAttemptsExceededError, + OutOfDiskError, + OwnerDiedError, + PlasmaObjectNotAvailable, + RayError, + RaySystemError, + RayTaskError, + ReferenceCountingAssertionError, + ObjectFreedError, + RuntimeEnvSetupError, + TaskCancelledError, + TaskPlacementGroupRemoved, + TaskUnschedulableError, + WorkerCrashedError, + OutOfMemoryError, + ObjectRefStreamEndOfStreamError, +) +import ray.exceptions +from ray.experimental.compiled_dag_ref import CompiledDAGRef +from ray.util import serialization_addons +from ray.util import inspect_serializability + +logger = logging.getLogger(__name__) +ALLOW_OUT_OF_BAND_OBJECT_REF_SERIALIZATION = 
ray_constants.env_bool( + "RAY_allow_out_of_band_object_ref_serialization", True +) + + +class DeserializationError(Exception): + pass + + +def pickle_dumps(obj: Any, error_msg: str): + """Wrap cloudpickle.dumps to provide better error message + when the object is not serializable. + """ + try: + return pickle.dumps(obj) + except (TypeError, ray.exceptions.OufOfBandObjectRefSerializationException) as e: + sio = io.StringIO() + inspect_serializability(obj, print_file=sio) + msg = f"{error_msg}:\n{sio.getvalue()}" + if isinstance(e, TypeError): + raise TypeError(msg) from e + else: + raise ray.exceptions.OufOfBandObjectRefSerializationException(msg) + + +def _object_ref_deserializer(binary, call_site, owner_address, object_status): + # NOTE(suquark): This function should be a global function so + # cloudpickle can access it directly. Otherwise cloudpickle + # has to dump the whole function definition, which is inefficient. + + # NOTE(swang): Must deserialize the object first before asking + # the core worker to resolve the value. This is to make sure + # that the ref count for the ObjectRef is greater than 0 by the + # time the core worker resolves the value of the object. + obj_ref = ray.ObjectRef(binary, owner_address, call_site) + + # TODO(edoakes): we should be able to just capture a reference + # to 'self' here instead, but this function is itself pickled + # somewhere, which causes an error. + if owner_address: + worker = ray._private.worker.global_worker + worker.check_connected() + context = worker.get_serialization_context() + outer_id = context.get_outer_object_ref() + # outer_id is None in the case that this ObjectRef was closed + # over in a function or pickled directly using pickle.dumps(). 
+ if outer_id is None: + outer_id = ray.ObjectRef.nil() + worker.core_worker.deserialize_and_register_object_ref( + obj_ref.binary(), outer_id, owner_address, object_status + ) + return obj_ref + + +def _actor_handle_deserializer(serialized_obj, weak_ref): + # If this actor handle was stored in another object, then tell the + # core worker. + context = ray._private.worker.global_worker.get_serialization_context() + outer_id = context.get_outer_object_ref() + return ray.actor.ActorHandle._deserialization_helper( + serialized_obj, weak_ref, outer_id + ) + + +class SerializationContext: + """Initialize the serialization library. + + This defines a custom serializer for object refs and also tells ray to + serialize several exception classes that we define for error handling. + """ + + def __init__(self, worker): + self.worker = worker + self._thread_local = threading.local() + + def actor_handle_reducer(obj): + ray._private.worker.global_worker.check_connected() + serialized, actor_handle_id, weak_ref = obj._serialization_helper() + # Update ref counting for the actor handle + if not weak_ref: + self.add_contained_object_ref( + actor_handle_id, + # Right now, so many tests are failing when this is set. + # Allow it for now, but we should eventually disallow it here. 
+ allow_out_of_band_serialization=True, + ) + return _actor_handle_deserializer, (serialized, weak_ref) + + self._register_cloudpickle_reducer(ray.actor.ActorHandle, actor_handle_reducer) + + def compiled_dag_ref_reducer(obj): + raise TypeError("Serialization of CompiledDAGRef is not supported.") + + self._register_cloudpickle_reducer(CompiledDAGRef, compiled_dag_ref_reducer) + + def object_ref_reducer(obj): + worker = ray._private.worker.global_worker + worker.check_connected() + self.add_contained_object_ref( + obj, + allow_out_of_band_serialization=( + ALLOW_OUT_OF_BAND_OBJECT_REF_SERIALIZATION + ), + call_site=obj.call_site(), + ) + obj, owner_address, object_status = worker.core_worker.serialize_object_ref( + obj + ) + return _object_ref_deserializer, ( + obj.binary(), + obj.call_site(), + owner_address, + object_status, + ) + + self._register_cloudpickle_reducer(ray.ObjectRef, object_ref_reducer) + + def object_ref_generator_reducer(obj): + return DynamicObjectRefGenerator, (obj._refs,) + + self._register_cloudpickle_reducer( + DynamicObjectRefGenerator, object_ref_generator_reducer + ) + + serialization_addons.apply(self) + + def _register_cloudpickle_reducer(self, cls, reducer): + pickle.CloudPickler.dispatch[cls] = reducer + + def _unregister_cloudpickle_reducer(self, cls): + pickle.CloudPickler.dispatch.pop(cls, None) + + def _register_cloudpickle_serializer( + self, cls, custom_serializer, custom_deserializer + ): + def _CloudPicklerReducer(obj): + return custom_deserializer, (custom_serializer(obj),) + + # construct a reducer + pickle.CloudPickler.dispatch[cls] = _CloudPicklerReducer + + def is_in_band_serialization(self): + return getattr(self._thread_local, "in_band", False) + + def set_in_band_serialization(self): + self._thread_local.in_band = True + + def set_out_of_band_serialization(self): + self._thread_local.in_band = False + + def get_outer_object_ref(self): + stack = getattr(self._thread_local, "object_ref_stack", []) + return stack[-1] if 
stack else None + + def get_and_clear_contained_object_refs(self): + if not hasattr(self._thread_local, "object_refs"): + self._thread_local.object_refs = set() + return set() + + object_refs = self._thread_local.object_refs + self._thread_local.object_refs = set() + return object_refs + + def add_contained_object_ref( + self, + object_ref: "ray.ObjectRef", + *, + allow_out_of_band_serialization: bool, + call_site: Optional[str] = None, + ): + if self.is_in_band_serialization(): + # This object ref is being stored in an object. Add the ID to the + # list of IDs contained in the object so that we keep the inner + # object value alive as long as the outer object is in scope. + if not hasattr(self._thread_local, "object_refs"): + self._thread_local.object_refs = set() + self._thread_local.object_refs.add(object_ref) + else: + if not allow_out_of_band_serialization: + raise ray.exceptions.OufOfBandObjectRefSerializationException( + f"It is not allowed to serialize ray.ObjectRef {object_ref.hex()}. " + "If you want to allow serialization, " + "set `RAY_allow_out_of_band_object_ref_serialization=1.` " + "If you set the env var, the object is pinned forever in the " + "lifetime of the worker process and can cause Ray object leaks. " + "See the callsite and trace to find where the serialization " + "occurs.\nCallsite: " + f"{call_site or 'Disabled. Set RAY_record_ref_creation_sites=1'}" + ) + else: + # If this serialization is out-of-band (e.g., from a call to + # cloudpickle directly or captured in a remote function/actor), + # then pin the object for the lifetime of this worker by adding + # a local reference that won't ever be removed. 
+ ray._private.worker.global_worker.core_worker.add_object_ref_reference( + object_ref + ) + + def _deserialize_pickle5_data(self, data): + try: + in_band, buffers = unpack_pickle5_buffers(data) + if len(buffers) > 0: + obj = pickle.loads(in_band, buffers=buffers) + else: + obj = pickle.loads(in_band) + # cloudpickle does not provide error types + except pickle.pickle.PicklingError: + raise DeserializationError() + return obj + + def _deserialize_msgpack_data(self, data, metadata_fields): + msgpack_data, pickle5_data = split_buffer(data) + + if metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_PYTHON: + python_objects = self._deserialize_pickle5_data(pickle5_data) + else: + python_objects = [] + + try: + + def _python_deserializer(index): + return python_objects[index] + + obj = MessagePackSerializer.loads(msgpack_data, _python_deserializer) + except Exception: + raise DeserializationError() + return obj + + def _deserialize_error_info(self, data, metadata_fields): + assert data + pb_bytes = self._deserialize_msgpack_data(data, metadata_fields) + assert pb_bytes + + ray_error_info = RayErrorInfo() + ray_error_info.ParseFromString(pb_bytes) + return ray_error_info + + def _deserialize_actor_died_error(self, data, metadata_fields): + if not data: + return ActorDiedError() + ray_error_info = self._deserialize_error_info(data, metadata_fields) + assert ray_error_info.HasField("actor_died_error") + if ray_error_info.actor_died_error.HasField("creation_task_failure_context"): + return RayError.from_ray_exception( + ray_error_info.actor_died_error.creation_task_failure_context + ) + else: + assert ray_error_info.actor_died_error.HasField("actor_died_error_context") + return ActorDiedError( + cause=ray_error_info.actor_died_error.actor_died_error_context + ) + + def _deserialize_object(self, data, metadata, object_ref): + if metadata: + metadata_fields = metadata.split(b",") + if metadata_fields[0] in [ + ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE, + 
ray_constants.OBJECT_METADATA_TYPE_PYTHON, + ]: + return self._deserialize_msgpack_data(data, metadata_fields) + # Check if the object should be returned as raw bytes. + if metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_RAW: + if data is None: + return b"" + return data.to_pybytes() + elif metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_ACTOR_HANDLE: + obj = self._deserialize_msgpack_data(data, metadata_fields) + # The last character is a 1 if weak_ref=True and 0 else. + serialized, weak_ref = obj[:-1], obj[-1:] == b"1" + return _actor_handle_deserializer(serialized, weak_ref) + # Otherwise, return an exception object based on + # the error type. + try: + error_type = int(metadata_fields[0]) + except Exception: + raise Exception( + f"Can't deserialize object: {object_ref}, " f"metadata: {metadata}" + ) + + # RayTaskError is serialized with pickle5 in the data field. + # TODO (kfstorm): exception serialization should be language + # independent. + if error_type == ErrorType.Value("TASK_EXECUTION_EXCEPTION"): + obj = self._deserialize_msgpack_data(data, metadata_fields) + return RayError.from_bytes(obj) + elif error_type == ErrorType.Value("WORKER_DIED"): + return WorkerCrashedError() + elif error_type == ErrorType.Value("ACTOR_DIED"): + return self._deserialize_actor_died_error(data, metadata_fields) + elif error_type == ErrorType.Value("LOCAL_RAYLET_DIED"): + return LocalRayletDiedError() + elif error_type == ErrorType.Value("TASK_CANCELLED"): + # Task cancellations are serialized in two ways, so check both + # deserialization paths. + # TODO(swang): We should only have one serialization path. + try: + # Deserialization from C++ (the CoreWorker task submitter). + # The error info will be stored as a RayErrorInfo. 
+ error_message = "" + if data: + error_info = self._deserialize_error_info(data, metadata_fields) + error_message = error_info.error_message + return TaskCancelledError(error_message=error_message) + except google.protobuf.message.DecodeError: + # Deserialization from Python. The TaskCancelledError is + # serialized and returned directly. + obj = self._deserialize_msgpack_data(data, metadata_fields) + return RayError.from_bytes(obj) + elif error_type == ErrorType.Value("OBJECT_LOST"): + return ObjectLostError( + object_ref.hex(), object_ref.owner_address(), object_ref.call_site() + ) + elif error_type == ErrorType.Value("OBJECT_FETCH_TIMED_OUT"): + return ObjectFetchTimedOutError( + object_ref.hex(), object_ref.owner_address(), object_ref.call_site() + ) + elif error_type == ErrorType.Value("OUT_OF_DISK_ERROR"): + return OutOfDiskError( + object_ref.hex(), object_ref.owner_address(), object_ref.call_site() + ) + elif error_type == ErrorType.Value("OUT_OF_MEMORY"): + error_info = self._deserialize_error_info(data, metadata_fields) + return OutOfMemoryError(error_info.error_message) + elif error_type == ErrorType.Value("NODE_DIED"): + error_info = self._deserialize_error_info(data, metadata_fields) + return NodeDiedError(error_info.error_message) + elif error_type == ErrorType.Value("OBJECT_DELETED"): + return ReferenceCountingAssertionError( + object_ref.hex(), object_ref.owner_address(), object_ref.call_site() + ) + elif error_type == ErrorType.Value("OBJECT_FREED"): + return ObjectFreedError( + object_ref.hex(), object_ref.owner_address(), object_ref.call_site() + ) + elif error_type == ErrorType.Value("OWNER_DIED"): + return OwnerDiedError( + object_ref.hex(), object_ref.owner_address(), object_ref.call_site() + ) + elif error_type == ErrorType.Value("OBJECT_UNRECONSTRUCTABLE"): + return ObjectReconstructionFailedError( + object_ref.hex(), object_ref.owner_address(), object_ref.call_site() + ) + elif error_type == ErrorType.Value( + 
"OBJECT_UNRECONSTRUCTABLE_MAX_ATTEMPTS_EXCEEDED" + ): + return ObjectReconstructionFailedMaxAttemptsExceededError( + object_ref.hex(), object_ref.owner_address(), object_ref.call_site() + ) + elif error_type == ErrorType.Value( + "OBJECT_UNRECONSTRUCTABLE_LINEAGE_EVICTED" + ): + return ObjectReconstructionFailedLineageEvictedError( + object_ref.hex(), object_ref.owner_address(), object_ref.call_site() + ) + elif error_type == ErrorType.Value("RUNTIME_ENV_SETUP_FAILED"): + error_info = self._deserialize_error_info(data, metadata_fields) + # TODO(sang): Assert instead once actor also reports error messages. + error_msg = "" + if error_info.HasField("runtime_env_setup_failed_error"): + error_msg = error_info.runtime_env_setup_failed_error.error_message + return RuntimeEnvSetupError(error_message=error_msg) + elif error_type == ErrorType.Value("TASK_PLACEMENT_GROUP_REMOVED"): + return TaskPlacementGroupRemoved() + elif error_type == ErrorType.Value("ACTOR_PLACEMENT_GROUP_REMOVED"): + return ActorPlacementGroupRemoved() + elif error_type == ErrorType.Value("TASK_UNSCHEDULABLE_ERROR"): + error_info = self._deserialize_error_info(data, metadata_fields) + return TaskUnschedulableError(error_info.error_message) + elif error_type == ErrorType.Value("ACTOR_UNSCHEDULABLE_ERROR"): + error_info = self._deserialize_error_info(data, metadata_fields) + return ActorUnschedulableError(error_info.error_message) + elif error_type == ErrorType.Value("END_OF_STREAMING_GENERATOR"): + return ObjectRefStreamEndOfStreamError() + elif error_type == ErrorType.Value("ACTOR_UNAVAILABLE"): + error_info = self._deserialize_error_info(data, metadata_fields) + if error_info.HasField("actor_unavailable_error"): + actor_id = error_info.actor_unavailable_error.actor_id + else: + actor_id = None + return ActorUnavailableError(error_info.error_message, actor_id) + else: + return RaySystemError("Unrecognized error type " + str(error_type)) + elif data: + raise ValueError("non-null object should always 
have metadata") + else: + # Object isn't available in plasma. This should never be returned + # to the user. We should only reach this line if this object was + # deserialized as part of a list, and another object in the list + # throws an exception. + return PlasmaObjectNotAvailable + + def deserialize_objects(self, data_metadata_pairs, object_refs): + assert len(data_metadata_pairs) == len(object_refs) + # initialize the thread-local field + if not hasattr(self._thread_local, "object_ref_stack"): + self._thread_local.object_ref_stack = [] + results = [] + for object_ref, (data, metadata) in zip(object_refs, data_metadata_pairs): + try: + # Push the object ref to the stack, so the object under + # the object ref knows where it comes from. + self._thread_local.object_ref_stack.append(object_ref) + obj = self._deserialize_object(data, metadata, object_ref) + except Exception as e: + logger.exception(e) + obj = RaySystemError(e, traceback.format_exc()) + finally: + # Must clear ObjectRef to not hold a reference. + if self._thread_local.object_ref_stack: + self._thread_local.object_ref_stack.pop() + results.append(obj) + return results + + def _serialize_to_pickle5(self, metadata, value): + writer = Pickle5Writer() + # TODO(swang): Check that contained_object_refs is empty. + try: + self.set_in_band_serialization() + inband = pickle.dumps( + value, protocol=5, buffer_callback=writer.buffer_callback + ) + except Exception as e: + self.get_and_clear_contained_object_refs() + raise e + finally: + self.set_out_of_band_serialization() + + return Pickle5SerializedObject( + metadata, inband, writer, self.get_and_clear_contained_object_refs() + ) + + def _serialize_to_msgpack(self, value): + # Only RayTaskError is possible to be serialized here. We don't + # need to deal with other exception types here. 
+ contained_object_refs = [] + + if isinstance(value, RayTaskError): + if issubclass(value.cause.__class__, TaskCancelledError): + # Handle task cancellation errors separately because we never + # want to warn about tasks that were intentionally cancelled by + # the user. + metadata = str(ErrorType.Value("TASK_CANCELLED")).encode("ascii") + value = value.to_bytes() + else: + metadata = str(ErrorType.Value("TASK_EXECUTION_EXCEPTION")).encode( + "ascii" + ) + value = value.to_bytes() + elif isinstance(value, ray.actor.ActorHandle): + # TODO(fyresone): ActorHandle should be serialized via the + # custom type feature of cross-language. + serialized, actor_handle_id, weak_ref = value._serialization_helper() + if not weak_ref: + contained_object_refs.append(actor_handle_id) + # Update ref counting for the actor handle + metadata = ray_constants.OBJECT_METADATA_TYPE_ACTOR_HANDLE + # Append a 1 to mean weak ref or 0 for strong ref. + # We do this here instead of in the main serialization helper + # because msgpack expects a bytes object. We cannot serialize + # `weak_ref` in the C++ code because the weak_ref property is only + # available in the Python ActorHandle instance. + value = serialized + (b"1" if weak_ref else b"0") + else: + metadata = ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE + + python_objects = [] + + def _python_serializer(o): + index = len(python_objects) + python_objects.append(o) + return index + + msgpack_data = MessagePackSerializer.dumps(value, _python_serializer) + + if python_objects: + metadata = ray_constants.OBJECT_METADATA_TYPE_PYTHON + pickle5_serialized_object = self._serialize_to_pickle5( + metadata, python_objects + ) + else: + pickle5_serialized_object = None + + return MessagePackSerializedObject( + metadata, msgpack_data, contained_object_refs, pickle5_serialized_object + ) + + def serialize(self, value): + """Serialize an object. + + Args: + value: The value to serialize. 
+ """ + if isinstance(value, bytes): + # If the object is a byte array, skip serializing it and + # use a special metadata to indicate it's raw binary. So + # that this object can also be read by Java. + return RawSerializedObject(value) + else: + return self._serialize_to_msgpack(value) diff --git a/.venv/lib/python3.11/site-packages/ray/_private/services.py b/.venv/lib/python3.11/site-packages/ray/_private/services.py new file mode 100644 index 0000000000000000000000000000000000000000..ba83e47bd44042212b856c5993bf002126040956 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/services.py @@ -0,0 +1,2282 @@ +import base64 +import collections +import errno +import io +import json +import logging +import mmap +import multiprocessing +import os +import random +import shutil +import signal +import socket +import subprocess +import sys +import time +from pathlib import Path +from typing import List, Optional, IO, AnyStr + +# Import psutil after ray so the packaged version is used. +import psutil +from filelock import FileLock + +# Ray modules +import ray +import ray._private.ray_constants as ray_constants +from ray._raylet import GcsClient, GcsClientOptions +from ray.core.generated.common_pb2 import Language +from ray._private.ray_constants import RAY_NODE_IP_FILENAME + +resource = None +if sys.platform != "win32": + _timeout = 30 +else: + _timeout = 60 + +EXE_SUFFIX = ".exe" if sys.platform == "win32" else "" + +# True if processes are run in the valgrind profiler. +RUN_RAYLET_PROFILER = False + +# Location of the redis server. +RAY_HOME = os.path.join(os.path.dirname(os.path.dirname(__file__)), "..", "..") +RAY_PATH = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) +RAY_PRIVATE_DIR = "_private" +AUTOSCALER_PRIVATE_DIR = os.path.join("autoscaler", "_private") +AUTOSCALER_V2_DIR = os.path.join("autoscaler", "v2") + +# Location of the raylet executables. 
+RAYLET_EXECUTABLE = os.path.join( + RAY_PATH, "core", "src", "ray", "raylet", "raylet" + EXE_SUFFIX +) +GCS_SERVER_EXECUTABLE = os.path.join( + RAY_PATH, "core", "src", "ray", "gcs", "gcs_server" + EXE_SUFFIX +) + +JEMALLOC_SO = os.path.join(RAY_PATH, "core", "libjemalloc.so") + +JEMALLOC_SO = JEMALLOC_SO if os.path.exists(JEMALLOC_SO) else None + +# Location of the cpp default worker executables. +DEFAULT_WORKER_EXECUTABLE = os.path.join(RAY_PATH, "cpp", "default_worker" + EXE_SUFFIX) + +# Location of the native libraries. +DEFAULT_NATIVE_LIBRARY_PATH = os.path.join(RAY_PATH, "cpp", "lib") + +DASHBOARD_DEPENDENCY_ERROR_MESSAGE = ( + "Not all Ray Dashboard dependencies were " + "found. To use the dashboard please " + "install Ray using `pip install " + "ray[default]`." +) + +RAY_JEMALLOC_LIB_PATH = "RAY_JEMALLOC_LIB_PATH" +RAY_JEMALLOC_CONF = "RAY_JEMALLOC_CONF" +RAY_JEMALLOC_PROFILE = "RAY_JEMALLOC_PROFILE" + +# Comma separated name of components that will run memory profiler. +# Ray uses `memray` to memory profile internal components. +# The name of the component must be one of ray_constants.PROCESS_TYPE*. +RAY_MEMRAY_PROFILE_COMPONENT_ENV = "RAY_INTERNAL_MEM_PROFILE_COMPONENTS" +# Options to specify for `memray run` command. See +# `memray run --help` for more details. +# Example: +# RAY_INTERNAL_MEM_PROFILE_OPTIONS="--live,--live-port,3456,-q," +# -> `memray run --live --live-port 3456 -q` +RAY_MEMRAY_PROFILE_OPTIONS_ENV = "RAY_INTERNAL_MEM_PROFILE_OPTIONS" + +# Logger for this module. It should be configured at the entry point +# into the program using Ray. Ray provides a default configuration at +# entry/init points. 
+logger = logging.getLogger(__name__) + +ProcessInfo = collections.namedtuple( + "ProcessInfo", + [ + "process", + "stdout_file", + "stderr_file", + "use_valgrind", + "use_gdb", + "use_valgrind_profiler", + "use_perftools_profiler", + "use_tmux", + ], +) + + +def _site_flags() -> List[str]: + """Detect whether flags related to site packages are enabled for the current + interpreter. To run Ray in hermetic build environments, it helps to pass these flags + down to Python workers. + """ + flags = [] + # sys.flags hidden behind helper methods for unit testing. + if _no_site(): + flags.append("-S") + if _no_user_site(): + flags.append("-s") + return flags + + +# sys.flags hidden behind helper methods for unit testing. +def _no_site(): + return sys.flags.no_site + + +# sys.flags hidden behind helper methods for unit testing. +def _no_user_site(): + return sys.flags.no_user_site + + +def _build_python_executable_command_memory_profileable( + component: str, session_dir: str, unbuffered: bool = True +): + """Build the Python executable command. + + It runs a memory profiler if env var is configured. + + Args: + component: Name of the component. It must be one of + ray_constants.PROCESS_TYPE*. + session_dir: The directory name of the Ray session. + unbuffered: If true, Python executable is started with unbuffered option. + e.g., `-u`. + It means the logs are flushed immediately (good when there's a failure), + but writing to a log file can be slower. + """ + command = [ + sys.executable, + ] + if unbuffered: + command.append("-u") + components_to_memory_profile = os.getenv(RAY_MEMRAY_PROFILE_COMPONENT_ENV, "") + if not components_to_memory_profile: + return command + + components_to_memory_profile = set(components_to_memory_profile.split(",")) + try: + import memray # noqa: F401 + except ImportError: + raise ImportError( + "Memray is required to memory profiler on components " + f"{components_to_memory_profile}. Run `pip install memray`." 
+ ) + if component in components_to_memory_profile: + session_dir = Path(session_dir) + session_name = session_dir.name + profile_dir = session_dir / "profile" + profile_dir.mkdir(exist_ok=True) + output_file_path = profile_dir / f"{session_name}_memory_{component}.bin" + options = os.getenv(RAY_MEMRAY_PROFILE_OPTIONS_ENV, None) + options = options.split(",") if options else [] + command.extend(["-m", "memray", "run", "-o", str(output_file_path), *options]) + + return command + + +def _get_gcs_client_options(gcs_server_address): + return GcsClientOptions.create( + gcs_server_address, + None, + allow_cluster_id_nil=True, + fetch_cluster_id_if_nil=False, + ) + + +def serialize_config(config): + return base64.b64encode(json.dumps(config).encode("utf-8")).decode("utf-8") + + +def propagate_jemalloc_env_var( + *, + jemalloc_path: str, + jemalloc_conf: str, + jemalloc_comps: List[str], + process_type: str, +): + """Read the jemalloc memory profiling related + env var and return the dictionary that translates + them to proper jemalloc related env vars. + + For example, if users specify `RAY_JEMALLOC_LIB_PATH`, + it is translated into `LD_PRELOAD` which is needed to + run Jemalloc as a shared library. + + Params: + jemalloc_path: The path to the jemalloc shared library. + jemalloc_conf: `,` separated string of jemalloc config. + jemalloc_comps: The list of Ray components + that we will profile. + process_type: The process type that needs jemalloc + env var for memory profiling. If it doesn't match one of + jemalloc_comps, the function will return an empty dict. + + Returns: + dictionary of {env_var: value} + that are needed to jemalloc profiling. The caller can + call `dict.update(return_value_of_this_func)` to + update the dict of env vars. If the process_type doesn't + match jemalloc_comps, it will return an empty dict. 
+ """ + assert isinstance(jemalloc_comps, list) + assert process_type is not None + process_type = process_type.lower() + if not jemalloc_path: + return {} + + env_vars = {"LD_PRELOAD": jemalloc_path, "RAY_LD_PRELOAD": "1"} + if process_type in jemalloc_comps and jemalloc_conf: + env_vars.update({"MALLOC_CONF": jemalloc_conf}) + return env_vars + + +class ConsolePopen(subprocess.Popen): + if sys.platform == "win32": + + def terminate(self): + if isinstance(self.stdin, io.IOBase): + self.stdin.close() + if self._use_signals: + self.send_signal(signal.CTRL_BREAK_EVENT) + else: + super(ConsolePopen, self).terminate() + + def __init__(self, *args, **kwargs): + # CREATE_NEW_PROCESS_GROUP is used to send Ctrl+C on Windows: + # https://docs.python.org/3/library/subprocess.html#subprocess.Popen.send_signal + new_pgroup = subprocess.CREATE_NEW_PROCESS_GROUP + flags_to_add = 0 + if ray._private.utils.detect_fate_sharing_support(): + # If we don't have kernel-mode fate-sharing, then don't do this + # because our children need to be in out process group for + # the process reaper to properly terminate them. + flags_to_add = new_pgroup + flags_key = "creationflags" + if flags_to_add: + kwargs[flags_key] = (kwargs.get(flags_key) or 0) | flags_to_add + self._use_signals = kwargs[flags_key] & new_pgroup + super(ConsolePopen, self).__init__(*args, **kwargs) + + +def address(ip_address, port): + return ip_address + ":" + str(port) + + +def new_port(lower_bound=10000, upper_bound=65535, denylist=None): + if not denylist: + denylist = set() + port = random.randint(lower_bound, upper_bound) + retry = 0 + while port in denylist: + if retry > 100: + break + port = random.randint(lower_bound, upper_bound) + retry += 1 + if retry > 100: + raise ValueError( + "Failed to find a new port from the range " + f"{lower_bound}-{upper_bound}. 
Denylist: {denylist}" + ) + return port + + +def _find_address_from_flag(flag: str): + """ + Attempts to find all valid Ray addresses on this node, specified by the + flag. + + Params: + flag: `--redis-address` or `--gcs-address` + Returns: + Set of detected addresses. + """ + # Using Redis address `--redis-address` as an example: + # Currently, this extracts the deprecated --redis-address from the command + # that launched the raylet running on this node, if any. Anyone looking to + # edit this function should be warned that these commands look like, for + # example: + # /usr/local/lib/python3.8/dist-packages/ray/core/src/ray/raylet/raylet + # --redis_address=123.456.78.910 --node_ip_address=123.456.78.910 + # --raylet_socket_name=... --store_socket_name=... --object_manager_port=0 + # --min_worker_port=10000 --max_worker_port=19999 + # --node_manager_port=58578 --redis_port=6379 + # --maximum_startup_concurrency=8 + # --static_resource_list=node:123.456.78.910,1.0,object_store_memory,66 + # --config_list=plasma_store_as_thread,True + # --python_worker_command=/usr/bin/python + # /usr/local/lib/python3.8/dist-packages/ray/workers/default_worker.py + # --redis-address=123.456.78.910:6379 + # --node-ip-address=123.456.78.910 --node-manager-port=58578 + # --object-store-name=... --raylet-name=... + # --temp-dir=/tmp/ray + # --metrics-agent-port=41856 --redis-password=[MASKED] + # --java_worker_command= --cpp_worker_command= + # --redis_password=[MASKED] --temp_dir=/tmp/ray --session_dir=... + # --metrics-agent-port=41856 --metrics_export_port=64229 + # --dashboard_agent_command=/usr/bin/python + # -u /usr/local/lib/python3.8/dist-packages/ray/dashboard/agent.py + # --redis-address=123.456.78.910:6379 --metrics-export-port=64229 + # --dashboard-agent-port=41856 --node-manager-port=58578 + # --object-store-name=... --raylet-name=... 
--temp-dir=/tmp/ray + # --log-dir=/tmp/ray/session_2020-11-08_14-29-07_199128_278000/logs + # --redis-password=[MASKED] --object_store_memory=5037192806 + # --plasma_directory=/tmp + # Longer arguments are elided with ... but all arguments from this instance + # are included, to provide a sense of what is in these. + # Indeed, we had to pull --redis-address to the front of each call to make + # this readable. + # As you can see, this is very long and complex, which is why we can't + # simply extract all the the arguments using regular expressions and + # present a dict as if we never lost track of these arguments, for + # example. Picking out --redis-address below looks like it might grab the + # wrong thing, but double-checking that we're finding the correct process + # by checking that the contents look like we expect would probably be prone + # to choking in unexpected ways. + # Notice that --redis-address appears twice. This is not a copy-paste + # error; this is the reason why the for loop below attempts to pick out + # every appearance of --redis-address. + + # The --redis-address here is what is now called the --address, but it + # appears in the default_worker.py and agent.py calls as --redis-address. + pids = psutil.pids() + addresses = set() + for pid in pids: + try: + proc = psutil.Process(pid) + # HACK: Workaround for UNIX idiosyncrasy + # Normally, cmdline() is supposed to return the argument list. + # But it in some cases (such as when setproctitle is called), + # an arbitrary string resembling a command-line is stored in + # the first argument. + # Explanation: https://unix.stackexchange.com/a/432681 + # More info: https://github.com/giampaolo/psutil/issues/1179 + cmdline = proc.cmdline() + # NOTE(kfstorm): To support Windows, we can't use + # `os.path.basename(cmdline[0]) == "raylet"` here. 
+ + if len(cmdline) > 0 and "raylet" in os.path.basename(cmdline[0]): + for arglist in cmdline: + # Given we're merely seeking --redis-address, we just split + # every argument on spaces for now. + for arg in arglist.split(" "): + # TODO(ekl): Find a robust solution for locating Redis. + if arg.startswith(flag): + proc_addr = arg.split("=")[1] + # TODO(mwtian): remove this workaround after Ray + # no longer sets --redis-address to None. + if proc_addr != "" and proc_addr != "None": + addresses.add(proc_addr) + except psutil.AccessDenied: + pass + except psutil.NoSuchProcess: + pass + return addresses + + +def find_gcs_addresses(): + """Finds any local GCS processes based on grepping ps.""" + return _find_address_from_flag("--gcs-address") + + +def find_bootstrap_address(temp_dir: Optional[str]): + """Finds the latest Ray cluster address to connect to, if any. This is the + GCS address connected to by the last successful `ray start`.""" + return ray._private.utils.read_ray_address(temp_dir) + + +def get_ray_address_from_environment(addr: str, temp_dir: Optional[str]): + """Attempts to find the address of Ray cluster to use, in this order: + + 1. Use RAY_ADDRESS if defined and nonempty. + 2. If no address is provided or the provided address is "auto", use the + address in /tmp/ray/ray_current_cluster if available. This will error if + the specified address is None and there is no address found. For "auto", + we will fallback to connecting to any detected Ray cluster (legacy). + 3. Otherwise, use the provided address. + + Returns: + A string to pass into `ray.init(address=...)`, e.g. ip:port, `auto`. + """ + env_addr = os.environ.get(ray_constants.RAY_ADDRESS_ENVIRONMENT_VARIABLE) + if env_addr is not None and env_addr != "": + addr = env_addr + + if addr is not None and addr != "auto": + return addr + # We should try to automatically find an active local instance. 
+ gcs_addrs = find_gcs_addresses() + bootstrap_addr = find_bootstrap_address(temp_dir) + + if len(gcs_addrs) > 1 and bootstrap_addr is not None: + logger.warning( + f"Found multiple active Ray instances: {gcs_addrs}. " + f"Connecting to latest cluster at {bootstrap_addr}. " + "You can override this by setting the `--address` flag " + "or `RAY_ADDRESS` environment variable." + ) + elif len(gcs_addrs) > 0 and addr == "auto": + # Preserve legacy "auto" behavior of connecting to any cluster, even if not + # started with ray start. However if addr is None, we will raise an error. + bootstrap_addr = list(gcs_addrs).pop() + + if bootstrap_addr is None: + if addr is None: + # Caller should start a new instance. + return None + else: + raise ConnectionError( + "Could not find any running Ray instance. " + "Please specify the one to connect to by setting `--address` flag " + "or `RAY_ADDRESS` environment variable." + ) + + return bootstrap_addr + + +def wait_for_node( + gcs_address: str, + node_plasma_store_socket_name: str, + timeout: int = _timeout, +): + """Wait until this node has appeared in the client table. + + Args: + gcs_address: The gcs address + node_plasma_store_socket_name: The + plasma_store_socket_name for the given node which we wait for. + timeout: The amount of time in seconds to wait before raising an + exception. + + Raises: + TimeoutError: An exception is raised if the timeout expires before + the node appears in the client table. 
+ """ + gcs_options = GcsClientOptions.create( + gcs_address, None, allow_cluster_id_nil=True, fetch_cluster_id_if_nil=False + ) + global_state = ray._private.state.GlobalState() + global_state._initialize_global_state(gcs_options) + start_time = time.time() + while time.time() - start_time < timeout: + clients = global_state.node_table() + object_store_socket_names = [ + client["ObjectStoreSocketName"] for client in clients + ] + if node_plasma_store_socket_name in object_store_socket_names: + return + else: + time.sleep(0.1) + raise TimeoutError( + f"Timed out after {timeout} seconds while waiting for node to startup. " + f"Did not find socket name {node_plasma_store_socket_name} in the list " + "of object store socket names." + ) + + +def get_node_to_connect_for_driver(gcs_address, node_ip_address): + # Get node table from global state accessor. + global_state = ray._private.state.GlobalState() + gcs_options = _get_gcs_client_options(gcs_address) + global_state._initialize_global_state(gcs_options) + return global_state.get_node_to_connect_for_driver(node_ip_address) + + +def get_node(gcs_address, node_id): + """ + Get the node information from the global state accessor. 
+ """ + global_state = ray._private.state.GlobalState() + gcs_options = _get_gcs_client_options(gcs_address) + global_state._initialize_global_state(gcs_options) + return global_state.get_node(node_id) + + +def get_webui_url_from_internal_kv(): + assert ray.experimental.internal_kv._internal_kv_initialized() + webui_url = ray.experimental.internal_kv._internal_kv_get( + "webui:url", namespace=ray_constants.KV_NAMESPACE_DASHBOARD + ) + return ray._private.utils.decode(webui_url) if webui_url is not None else None + + +def get_storage_uri_from_internal_kv(): + assert ray.experimental.internal_kv._internal_kv_initialized() + storage_uri = ray.experimental.internal_kv._internal_kv_get( + "storage", namespace=ray_constants.KV_NAMESPACE_SESSION + ) + return ray._private.utils.decode(storage_uri) if storage_uri is not None else None + + +def remaining_processes_alive(): + """See if the remaining processes are alive or not. + + Note that this ignores processes that have been explicitly killed, + e.g., via a command like node.kill_raylet(). + + Returns: + True if the remaining processes started by ray.init() are alive and + False otherwise. + + Raises: + Exception: An exception is raised if the processes were not started by + ray.init(). + """ + if ray._private.worker._global_node is None: + raise RuntimeError( + "This process is not in a position to determine " + "whether all processes are alive or not." + ) + return ray._private.worker._global_node.remaining_processes_alive() + + +def canonicalize_bootstrap_address( + addr: str, temp_dir: Optional[str] = None +) -> Optional[str]: + """Canonicalizes Ray cluster bootstrap address to host:port. + Reads address from the environment if needed. + + This function should be used to process user supplied Ray cluster address, + via ray.init() or `--address` flags, before using the address to connect. + + Returns: + Ray cluster address string in format or None if the caller + should start a local Ray instance. 
+ """ + if addr is None or addr == "auto": + addr = get_ray_address_from_environment(addr, temp_dir) + if addr is None or addr == "local": + return None + try: + bootstrap_address = resolve_ip_for_localhost(addr) + except Exception: + logger.exception(f"Failed to convert {addr} to host:port") + raise + return bootstrap_address + + +def canonicalize_bootstrap_address_or_die( + addr: str, temp_dir: Optional[str] = None +) -> str: + """Canonicalizes Ray cluster bootstrap address to host:port. + + This function should be used when the caller expects there to be an active + and local Ray instance. If no address is provided or address="auto", this + will autodetect the latest Ray instance created with `ray start`. + + For convenience, if no address can be autodetected, this function will also + look for any running local GCS processes, based on pgrep output. This is to + allow easier use of Ray CLIs when debugging a local Ray instance (whose GCS + addresses are not recorded). + + Returns: + Ray cluster address string in format. Throws a + ConnectionError if zero or multiple active Ray instances are + autodetected. + """ + bootstrap_addr = canonicalize_bootstrap_address(addr, temp_dir=temp_dir) + if bootstrap_addr is not None: + return bootstrap_addr + + running_gcs_addresses = find_gcs_addresses() + if len(running_gcs_addresses) == 0: + raise ConnectionError( + "Could not find any running Ray instance. " + "Please specify the one to connect to by setting the `--address` " + "flag or `RAY_ADDRESS` environment variable." + ) + if len(running_gcs_addresses) > 1: + raise ConnectionError( + f"Found multiple active Ray instances: {running_gcs_addresses}. " + "Please specify the one to connect to by setting the `--address` " + "flag or `RAY_ADDRESS` environment variable." + ) + return running_gcs_addresses.pop() + + +def extract_ip_port(bootstrap_address: str): + if ":" not in bootstrap_address: + raise ValueError( + f"Malformed address {bootstrap_address}. " f"Expected ':'." 
+ ) + ip, _, port = bootstrap_address.rpartition(":") + try: + port = int(port) + except ValueError: + raise ValueError(f"Malformed address port {port}. Must be an integer.") + if port < 1024 or port > 65535: + raise ValueError( + f"Invalid address port {port}. Must be between 1024 " + "and 65535 (inclusive)." + ) + return ip, port + + +def resolve_ip_for_localhost(address: str): + """Convert to a remotely reachable IP if the address is "localhost" + or "127.0.0.1". Otherwise do nothing. + + Args: + address: This can be either a string containing a hostname (or an IP + address) and a port or it can be just an IP address. + + Returns: + The same address but with the local host replaced by remotely + reachable IP. + """ + if not address: + raise ValueError(f"Malformed address: {address}") + address_parts = address.split(":") + if address_parts[0] == "127.0.0.1" or address_parts[0] == "localhost": + # Make sure localhost isn't resolved to the loopback ip + ip_address = get_node_ip_address() + return ":".join([ip_address] + address_parts[1:]) + else: + return address + + +def node_ip_address_from_perspective(address: str): + """IP address by which the local node can be reached *from* the `address`. + + Args: + address: The IP address and port of any known live service on the + network you care about. + + Returns: + The IP address by which the local node can be reached from the address. + """ + ip_address, port = address.split(":") + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + # This command will raise an exception if there is no internet + # connection. 
+ s.connect((ip_address, int(port))) + node_ip_address = s.getsockname()[0] + except OSError as e: + node_ip_address = "127.0.0.1" + # [Errno 101] Network is unreachable + if e.errno == errno.ENETUNREACH: + try: + # try get node ip address from host name + host_name = socket.getfqdn(socket.gethostname()) + node_ip_address = socket.gethostbyname(host_name) + except Exception: + pass + finally: + s.close() + + return node_ip_address + + +# NOTE: This API should not be used when you obtain the +# IP address when ray.init is not called because +# it cannot find the IP address if it is specified by +# ray start --node-ip-address. You should instead use +# get_cached_node_ip_address. +def get_node_ip_address(address="8.8.8.8:53"): + if ray._private.worker._global_node is not None: + return ray._private.worker._global_node.node_ip_address + if not ray_constants.ENABLE_RAY_CLUSTER: + # Use loopback IP as the local IP address to prevent bothersome + # firewall popups on OSX and Windows. + # https://github.com/ray-project/ray/issues/18730. + return "127.0.0.1" + return node_ip_address_from_perspective(address) + + +def get_cached_node_ip_address(session_dir: str) -> str: + """Get a node address cached on this session. + + If a ray instance is started by `ray start --node-ip-address`, + the node ip address is cached to a file RAY_NODE_IP_FILENAME. + Otherwise, the file exists, but it is emptyl. + + This API is process-safe, meaning the file access is protected by + a file lock. + + Args: + session_dir: Path to the Ray session directory. + + Returns: + node_ip_address cached on the current node. None if the node + the file doesn't exist, meaning ray instance hasn't been + started on a current node. If node_ip_address is not written + to a file, it means --node-ip-address is not given, and in this + case, we find the IP address ourselves. 
+ """ + file_path = Path(os.path.join(session_dir, RAY_NODE_IP_FILENAME)) + cached_node_ip_address = {} + + with FileLock(str(file_path.absolute()) + ".lock"): + if not file_path.exists(): + return None + + with file_path.open() as f: + cached_node_ip_address.update(json.load(f)) + + if "node_ip_address" in cached_node_ip_address: + return cached_node_ip_address["node_ip_address"] + else: + return ray.util.get_node_ip_address() + + +def write_node_ip_address(session_dir: str, node_ip_address: Optional[str]) -> None: + """Write a node ip address of the current session to + RAY_NODE_IP_FILENAME. + + If a ray instance is started by `ray start --node-ip-address`, + the node ip address is cached to a file RAY_NODE_IP_FILENAME. + + This API is process-safe, meaning the file access is protected by + a file lock. + + The file contains a single string node_ip_address. If nothing + is written, it means --node-ip-address was not given, and Ray + resolves the IP address on its own. It assumes in a single node, + you can have only 1 IP address (which is the assumption ray + has in general). + + node_ip_address is the ip address of the current node. + + Args: + session_dir: The path to Ray session directory. + node_ip_address: The node IP address of the current node. + If None, it means the node ip address is not given + by --node-ip-address. In this case, we don't write + anything to a file. + """ + file_path = Path(os.path.join(session_dir, RAY_NODE_IP_FILENAME)) + cached_node_ip_address = {} + + with FileLock(str(file_path.absolute()) + ".lock"): + if not file_path.exists(): + with file_path.open(mode="w") as f: + json.dump({}, f) + + with file_path.open() as f: + cached_node_ip_address.update(json.load(f)) + + cached_node_ip = cached_node_ip_address.get("node_ip_address") + + if node_ip_address is not None: + if cached_node_ip: + if cached_node_ip == node_ip_address: + # Nothing to do. 
+ return + else: + logger.warning( + "The node IP address of the current host recorded " + f"in {RAY_NODE_IP_FILENAME} ({cached_node_ip}) " + "is different from the current IP address: " + f"{node_ip_address}. Ray will use {node_ip_address} " + "as the current node's IP address. " + "Creating 2 instances in the same host with different " + "IP address is not supported. " + "Please create an enhnacement request to" + "https://github.com/ray-project/ray/issues." + ) + + cached_node_ip_address["node_ip_address"] = node_ip_address + with file_path.open(mode="w") as f: + json.dump(cached_node_ip_address, f) + + +def create_redis_client(redis_address, password=None, username=None): + """Create a Redis client. + + Args: + The IP address, port, username, and password of the Redis server. + + Returns: + A Redis client. + """ + import redis + + if not hasattr(create_redis_client, "instances"): + create_redis_client.instances = {} + + num_retries = ray_constants.START_REDIS_WAIT_RETRIES + delay = 0.001 + for i in range(num_retries): + cli = create_redis_client.instances.get(redis_address) + if cli is None: + redis_ip_address, redis_port = extract_ip_port( + canonicalize_bootstrap_address_or_die(redis_address) + ) + cli = redis.StrictRedis( + host=redis_ip_address, + port=int(redis_port), + username=username, + password=password, + ) + create_redis_client.instances[redis_address] = cli + try: + cli.ping() + return cli + except Exception as e: + create_redis_client.instances.pop(redis_address) + if i >= num_retries - 1: + raise RuntimeError( + f"Unable to connect to Redis at {redis_address}: {e}" + ) + # Wait a little bit. + time.sleep(delay) + # Make sure the retry interval doesn't increase too large. 
+ delay = min(1, delay * 2) + + +def start_ray_process( + command: List[str], + process_type: str, + fate_share: bool, + env_updates: Optional[dict] = None, + cwd: Optional[str] = None, + use_valgrind: bool = False, + use_gdb: bool = False, + use_valgrind_profiler: bool = False, + use_perftools_profiler: bool = False, + use_tmux: bool = False, + stdout_file: Optional[IO[AnyStr]] = None, + stderr_file: Optional[IO[AnyStr]] = None, + pipe_stdin: bool = False, +): + """Start one of the Ray processes. + + TODO(rkn): We need to figure out how these commands interact. For example, + it may only make sense to start a process in gdb if we also start it in + tmux. Similarly, certain combinations probably don't make sense, like + simultaneously running the process in valgrind and the profiler. + + Args: + command: The command to use to start the Ray process. + process_type: The type of the process that is being started + (e.g., "raylet"). + fate_share: If true, the child will be killed if its parent (us) dies. + True must only be passed after detection of this functionality. + env_updates: A dictionary of additional environment variables to + run the command with (in addition to the caller's environment + variables). + cwd: The directory to run the process in. + use_valgrind: True if we should start the process in valgrind. + use_gdb: True if we should start the process in gdb. + use_valgrind_profiler: True if we should start the process in + the valgrind profiler. + use_perftools_profiler: True if we should profile the process + using perftools. + use_tmux: True if we should start the process in tmux. + stdout_file: A file handle opened for writing to redirect stdout to. If + no redirection should happen, then this should be None. + stderr_file: A file handle opened for writing to redirect stderr to. If + no redirection should happen, then this should be None. + pipe_stdin: If true, subprocess.PIPE will be passed to the process as + stdin. 
+ + Returns: + Information about the process that was started including a handle to + the process that was started. + """ + # Detect which flags are set through environment variables. + valgrind_env_var = f"RAY_{process_type.upper()}_VALGRIND" + if os.environ.get(valgrind_env_var) == "1": + logger.info("Detected environment variable '%s'.", valgrind_env_var) + use_valgrind = True + valgrind_profiler_env_var = f"RAY_{process_type.upper()}_VALGRIND_PROFILER" + if os.environ.get(valgrind_profiler_env_var) == "1": + logger.info("Detected environment variable '%s'.", valgrind_profiler_env_var) + use_valgrind_profiler = True + perftools_profiler_env_var = f"RAY_{process_type.upper()}_PERFTOOLS_PROFILER" + if os.environ.get(perftools_profiler_env_var) == "1": + logger.info("Detected environment variable '%s'.", perftools_profiler_env_var) + use_perftools_profiler = True + tmux_env_var = f"RAY_{process_type.upper()}_TMUX" + if os.environ.get(tmux_env_var) == "1": + logger.info("Detected environment variable '%s'.", tmux_env_var) + use_tmux = True + gdb_env_var = f"RAY_{process_type.upper()}_GDB" + if os.environ.get(gdb_env_var) == "1": + logger.info("Detected environment variable '%s'.", gdb_env_var) + use_gdb = True + # Jemalloc memory profiling. 
+ if os.environ.get("LD_PRELOAD") is None: + jemalloc_lib_path = os.environ.get(RAY_JEMALLOC_LIB_PATH, JEMALLOC_SO) + jemalloc_conf = os.environ.get(RAY_JEMALLOC_CONF) + jemalloc_comps = os.environ.get(RAY_JEMALLOC_PROFILE) + jemalloc_comps = [] if not jemalloc_comps else jemalloc_comps.split(",") + jemalloc_env_vars = propagate_jemalloc_env_var( + jemalloc_path=jemalloc_lib_path, + jemalloc_conf=jemalloc_conf, + jemalloc_comps=jemalloc_comps, + process_type=process_type, + ) + else: + jemalloc_env_vars = {} + + use_jemalloc_mem_profiler = "MALLOC_CONF" in jemalloc_env_vars + + if ( + sum( + [ + use_gdb, + use_valgrind, + use_valgrind_profiler, + use_perftools_profiler, + use_jemalloc_mem_profiler, + ] + ) + > 1 + ): + raise ValueError( + "At most one of the 'use_gdb', 'use_valgrind', " + "'use_valgrind_profiler', 'use_perftools_profiler', " + "and 'use_jemalloc_mem_profiler' flags can " + "be used at a time." + ) + if env_updates is None: + env_updates = {} + if not isinstance(env_updates, dict): + raise ValueError("The 'env_updates' argument must be a dictionary.") + + modified_env = os.environ.copy() + modified_env.update(env_updates) + + if use_gdb: + if not use_tmux: + raise ValueError( + "If 'use_gdb' is true, then 'use_tmux' must be true as well." + ) + + # TODO(suquark): Any better temp file creation here? 
+ gdb_init_path = os.path.join( + ray._private.utils.get_ray_temp_dir(), + f"gdb_init_{process_type}_{time.time()}", + ) + ray_process_path = command[0] + ray_process_args = command[1:] + run_args = " ".join(["'{}'".format(arg) for arg in ray_process_args]) + with open(gdb_init_path, "w") as gdb_init_file: + gdb_init_file.write(f"run {run_args}") + command = ["gdb", ray_process_path, "-x", gdb_init_path] + + if use_valgrind: + command = [ + "valgrind", + "--track-origins=yes", + "--leak-check=full", + "--show-leak-kinds=all", + "--leak-check-heuristics=stdstring", + "--error-exitcode=1", + ] + command + + if use_valgrind_profiler: + command = ["valgrind", "--tool=callgrind"] + command + + if use_perftools_profiler: + modified_env["LD_PRELOAD"] = os.environ["PERFTOOLS_PATH"] + modified_env["CPUPROFILE"] = os.environ["PERFTOOLS_LOGFILE"] + + modified_env.update(jemalloc_env_vars) + + if use_tmux: + # The command has to be created exactly as below to ensure that it + # works on all versions of tmux. (Tested with tmux 1.8-5, travis' + # version, and tmux 2.1) + command = ["tmux", "new-session", "-d", f"{' '.join(command)}"] + + if fate_share: + assert ray._private.utils.detect_fate_sharing_support(), ( + "kernel-level fate-sharing must only be specified if " + "detect_fate_sharing_support() has returned True" + ) + + def preexec_fn(): + import signal + + signal.pthread_sigmask(signal.SIG_BLOCK, {signal.SIGINT}) + if fate_share and sys.platform.startswith("linux"): + ray._private.utils.set_kill_on_parent_death_linux() + + win32_fate_sharing = fate_share and sys.platform == "win32" + # With Windows fate-sharing, we need special care: + # The process must be added to the job before it is allowed to execute. + # Otherwise, there's a race condition: the process might spawn children + # before the process itself is assigned to the job. + # After that point, its children will not be added to the job anymore. 
+ CREATE_SUSPENDED = 0x00000004 # from Windows headers + if sys.platform == "win32": + # CreateProcess, which underlies Popen, is limited to + # 32,767 characters, including the Unicode terminating null + # character + total_chrs = sum([len(x) for x in command]) + if total_chrs > 31766: + raise ValueError( + f"command is limited to a total of 31767 characters, " + f"got {total_chrs}" + ) + + process = ConsolePopen( + command, + env=modified_env, + cwd=cwd, + stdout=stdout_file, + stderr=stderr_file, + stdin=subprocess.PIPE if pipe_stdin else None, + preexec_fn=preexec_fn if sys.platform != "win32" else None, + creationflags=CREATE_SUSPENDED if win32_fate_sharing else 0, + ) + + if win32_fate_sharing: + try: + ray._private.utils.set_kill_child_on_death_win32(process) + psutil.Process(process.pid).resume() + except (psutil.Error, OSError): + process.kill() + raise + + def _get_stream_name(stream): + if stream is not None: + try: + return stream.name + except AttributeError: + return str(stream) + return None + + return ProcessInfo( + process=process, + stdout_file=_get_stream_name(stdout_file), + stderr_file=_get_stream_name(stderr_file), + use_valgrind=use_valgrind, + use_gdb=use_gdb, + use_valgrind_profiler=use_valgrind_profiler, + use_perftools_profiler=use_perftools_profiler, + use_tmux=use_tmux, + ) + + +def start_reaper(fate_share=None): + """Start the reaper process. + + This is a lightweight process that simply + waits for its parent process to die and then terminates its own + process group. This allows us to ensure that ray processes are always + terminated properly so long as that process itself isn't SIGKILLed. + + Returns: + ProcessInfo for the process that was started. + """ + # Make ourselves a process group leader so that the reaper can clean + # up other ray processes without killing the process group of the + # process that started us. 
+ try: + if sys.platform != "win32": + os.setpgrp() + except OSError as e: + errcode = e.errno + if errcode == errno.EPERM and os.getpgrp() == os.getpid(): + # Nothing to do; we're already a session leader. + pass + else: + logger.warning( + f"setpgrp failed, processes may not be cleaned up properly: {e}." + ) + # Don't start the reaper in this case as it could result in killing + # other user processes. + return None + + reaper_filepath = os.path.join(RAY_PATH, RAY_PRIVATE_DIR, "ray_process_reaper.py") + command = [sys.executable, "-u", reaper_filepath] + process_info = start_ray_process( + command, + ray_constants.PROCESS_TYPE_REAPER, + pipe_stdin=True, + fate_share=fate_share, + ) + return process_info + + +def start_log_monitor( + session_dir: str, + logs_dir: str, + gcs_address: str, + fate_share: Optional[bool] = None, + max_bytes: int = 0, + backup_count: int = 0, + redirect_logging: bool = True, + stdout_file: Optional[IO[AnyStr]] = subprocess.DEVNULL, + stderr_file: Optional[IO[AnyStr]] = subprocess.DEVNULL, +): + """Start a log monitor process. + + Args: + session_dir: The session directory. + logs_dir: The directory of logging files. + gcs_address: GCS address for pubsub. + fate_share: Whether to share fate between log_monitor + and this process. + max_bytes: Log rotation parameter. Corresponding to + RotatingFileHandler's maxBytes. + backup_count: Log rotation parameter. Corresponding to + RotatingFileHandler's backupCount. + redirect_logging: Whether we should redirect logging to + the provided log directory. + stdout_file: A file handle opened for writing to redirect stdout to. If + no redirection should happen, then this should be None. + stderr_file: A file handle opened for writing to redirect stderr to. If + no redirection should happen, then this should be None. + + Returns: + ProcessInfo for the process that was started. 
+ """ + log_monitor_filepath = os.path.join(RAY_PATH, RAY_PRIVATE_DIR, "log_monitor.py") + + command = [ + sys.executable, + "-u", + log_monitor_filepath, + f"--session-dir={session_dir}", + f"--logs-dir={logs_dir}", + f"--gcs-address={gcs_address}", + f"--logging-rotate-bytes={max_bytes}", + f"--logging-rotate-backup-count={backup_count}", + ] + + if not redirect_logging: + # If not redirecting logging to files, unset log filename. + # This will cause log records to go to stderr. + command.append("--logging-filename=") + # Use stderr log format with the component name as a message prefix. + logging_format = ray_constants.LOGGER_FORMAT_STDERR.format( + component=ray_constants.PROCESS_TYPE_LOG_MONITOR + ) + command.append(f"--logging-format={logging_format}") + # Inherit stdout/stderr streams. + stdout_file = None + stderr_file = None + process_info = start_ray_process( + command, + ray_constants.PROCESS_TYPE_LOG_MONITOR, + stdout_file=stdout_file, + stderr_file=stderr_file, + fate_share=fate_share, + ) + return process_info + + +def start_api_server( + include_dashboard: Optional[bool], + raise_on_failure: bool, + host: str, + gcs_address: str, + cluster_id_hex: str, + node_ip_address: str, + temp_dir: str, + logdir: str, + session_dir: str, + port: Optional[int] = None, + dashboard_grpc_port: Optional[int] = None, + fate_share: Optional[bool] = None, + max_bytes: int = 0, + backup_count: int = 0, + redirect_logging: bool = True, + stdout_file: Optional[IO[AnyStr]] = subprocess.DEVNULL, + stderr_file: Optional[IO[AnyStr]] = subprocess.DEVNULL, +): + """Start a API server process. + + Args: + include_dashboard: If true, this will load all dashboard-related modules + when starting the API server, or fail. If None, it will load all + dashboard-related modules conditioned on dependencies being present. + Otherwise, it will only start the modules that are not relevant to + the dashboard. 
+ raise_on_failure: If true, this will raise an exception + if we fail to start the API server. Otherwise it will print + a warning if we fail to start the API server. + host: The host to bind the dashboard web server to. + gcs_address: The gcs address the dashboard should connect to + cluster_id_hex: Cluster ID in hex. + node_ip_address: The IP address where this is running. + temp_dir: The temporary directory used for log files and + information for this Ray session. + session_dir: The session directory under temp_dir. + It is used as a identifier of individual cluster. + logdir: The log directory used to generate dashboard log. + port: The port to bind the dashboard web server to. + Defaults to 8265. + dashboard_grpc_port: The port which the dashboard listens for + gRPC on. Defaults to a random, available port. + max_bytes: Log rotation parameter. Corresponding to + RotatingFileHandler's maxBytes. + backup_count: Log rotation parameter. Corresponding to + RotatingFileHandler's backupCount. + redirect_logging: Whether we should redirect logging to + the provided log directory. + stdout_file: A file handle opened for writing to redirect stdout to. If + no redirection should happen, then this should be None. + stderr_file: A file handle opened for writing to redirect stderr to. If + no redirection should happen, then this should be None. + + Returns: + A tuple of : + - Dashboard URL if dashboard enabled and started. + - ProcessInfo for the process that was started. + """ + try: + # Make sure port is available. + if port is None: + port_retries = 50 + port = ray_constants.DEFAULT_DASHBOARD_PORT + else: + port_retries = 0 + port_test_socket = socket.socket() + port_test_socket.setsockopt( + socket.SOL_SOCKET, + socket.SO_REUSEADDR, + 1, + ) + try: + port_test_socket.bind((host, port)) + port_test_socket.close() + except socket.error as e: + # 10013 on windows is a bit more broad than just + # "address in use": it can also indicate "permission denied". 
+ # TODO: improve the error message? + if e.errno in {48, 98, 10013}: # address already in use. + raise ValueError( + f"Failed to bind to {host}:{port} because it's " + "already occupied. You can use `ray start " + "--dashboard-port ...` or `ray.init(dashboard_port=..." + ")` to select a different port." + ) + else: + raise e + # Make sure the process can start. + minimal: bool = not ray._private.utils.check_dashboard_dependencies_installed() + + # Explicitly check here that when the user explicitly specifies + # dashboard inclusion, the install is not minimal. + if include_dashboard and minimal: + logger.error( + "--include-dashboard is not supported when minimal ray is used. " + "Download ray[default] to use the dashboard." + ) + raise Exception("Cannot include dashboard with missing packages.") + + include_dash: bool = True if include_dashboard is None else include_dashboard + + # Start the dashboard process. + dashboard_dir = "dashboard" + dashboard_filepath = os.path.join(RAY_PATH, dashboard_dir, "dashboard.py") + + command = [ + *_build_python_executable_command_memory_profileable( + ray_constants.PROCESS_TYPE_DASHBOARD, + session_dir, + unbuffered=False, + ), + dashboard_filepath, + f"--host={host}", + f"--port={port}", + f"--port-retries={port_retries}", + f"--temp-dir={temp_dir}", + f"--log-dir={logdir}", + f"--session-dir={session_dir}", + f"--logging-rotate-bytes={max_bytes}", + f"--logging-rotate-backup-count={backup_count}", + f"--gcs-address={gcs_address}", + f"--cluster-id-hex={cluster_id_hex}", + f"--node-ip-address={node_ip_address}", + ] + + if not redirect_logging: + # If not redirecting logging to files, unset log filename. + # This will cause log records to go to stderr. + command.append("--logging-filename=") + # Use stderr log format with the component name as a message prefix. 
+ logging_format = ray_constants.LOGGER_FORMAT_STDERR.format( + component=ray_constants.PROCESS_TYPE_DASHBOARD + ) + command.append(f"--logging-format={logging_format}") + # Inherit stdout/stderr streams so that + # logs are redirected to stderr. + stdout_file = None + stderr_file = None + if minimal: + command.append("--minimal") + + if not include_dash: + # If dashboard is not included, load modules + # that are irrelevant to the dashboard. + # TODO(sang): Modules like job or state APIs should be + # loaded although dashboard is disabled. Fix it. + command.append("--modules-to-load=UsageStatsHead") + command.append("--disable-frontend") + + if dashboard_grpc_port is not None: + command.append(f"--grpc-port={dashboard_grpc_port}") + + process_info = start_ray_process( + command, + ray_constants.PROCESS_TYPE_DASHBOARD, + stdout_file=stdout_file, + stderr_file=stderr_file, + fate_share=fate_share, + ) + + # Retrieve the dashboard url + gcs_client = GcsClient(address=gcs_address, cluster_id=cluster_id_hex) + ray.experimental.internal_kv._initialize_internal_kv(gcs_client) + dashboard_url = None + dashboard_returncode = None + for _ in range(200): + dashboard_url = ray.experimental.internal_kv._internal_kv_get( + ray_constants.DASHBOARD_ADDRESS, + namespace=ray_constants.KV_NAMESPACE_DASHBOARD, + ) + if dashboard_url is not None: + dashboard_url = dashboard_url.decode("utf-8") + break + dashboard_returncode = process_info.process.poll() + if dashboard_returncode is not None: + break + # This is often on the critical path of ray.init() and ray start, + # so we need to poll often. + time.sleep(0.1) + + # Dashboard couldn't be started. 
+ if dashboard_url is None: + returncode_str = ( + f", return code {dashboard_returncode}" + if dashboard_returncode is not None + else "" + ) + logger.error(f"Failed to start the dashboard {returncode_str}") + + def read_log(filename, lines_to_read): + """Read a log file and return the last 20 lines.""" + dashboard_log = os.path.join(logdir, filename) + # Read last n lines of dashboard log. The log file may be large. + lines_to_read = 20 + lines = [] + with open(dashboard_log, "rb") as f: + with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm: + end = mm.size() + for _ in range(lines_to_read): + sep = mm.rfind(b"\n", 0, end - 1) + if sep == -1: + break + lines.append(mm[sep + 1 : end].decode("utf-8")) + end = sep + lines.append( + f"The last {lines_to_read} lines of {dashboard_log} " + "(it contains the error message from the dashboard): " + ) + return lines + + if logdir: + lines_to_read = 20 + logger.error( + "Error should be written to 'dashboard.log' or " + "'dashboard.err'. We are printing the last " + f"{lines_to_read} lines for you. See " + "'https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#logging-directory-structure' " # noqa + "to find where the log file is." + ) + try: + lines = read_log("dashboard.log", lines_to_read=lines_to_read) + except Exception as e: + logger.error( + f"Couldn't read dashboard.log file. Error: {e}. " + "It means the dashboard is broken even before it " + "initializes the logger (mostly dependency issues). " + "Reading the dashboard.err file which contains stdout/stderr." + ) + # If we cannot read the .log file, we fallback to .err file. + # This is the case where dashboard couldn't be started at all + # and couldn't even initialize the logger to write logs to .log + # file. + try: + lines = read_log("dashboard.err", lines_to_read=lines_to_read) + except Exception as e: + raise Exception( + f"Failed to read dashboard.err file: {e}. " + "It is unexpected. 
def get_address(redis_address):
    """Parse a Redis address into its host, port, and SSL components.

    Args:
        redis_address: Address in one of the forms "ip:port",
            "redis://ip:port", or "rediss://ip:port" (SSL).

    Returns:
        A tuple of (ip_address, port, enable_ssl). Note that the port is
        returned as a string, not an int. ``enable_ssl`` is True only for
        the "rediss://" scheme.

    Raises:
        ValueError: If a scheme is present but is not "redis" or "rediss".
    """
    parts = redis_address.split("://", 1)
    enable_redis_ssl = False
    if len(parts) == 1:
        # No scheme: plain "ip:port". rsplit keeps IPv6-style hosts intact.
        redis_ip_address, redis_port = parts[0].rsplit(":", 1)
    else:
        # A scheme was given; only redis:// and rediss:// are accepted
        # ("rediss" enables SSL).
        if len(parts) != 2 or parts[0] not in ("redis", "rediss"):
            # NOTE: added the missing space after the period so the two
            # concatenated sentences don't run together in the message.
            raise ValueError(
                f"Invalid redis address {redis_address}. "
                "Expected format is ip:port or redis://ip:port, "
                "or rediss://ip:port for SSL."
            )
        redis_ip_address, redis_port = parts[1].rsplit(":", 1)
        if parts[0] == "rediss":
            enable_redis_ssl = True
    return redis_ip_address, redis_port, enable_redis_ssl
+ stderr_file: A file handle opened for writing to redirect stderr to. If + no redirection should happen, then this should be None. + session_name: The session name (cluster id) of this cluster. + redis_username: The username of the Redis server. + redis_password: The password of the Redis server. + config: Optional configuration that will + override defaults in RayConfig. + gcs_server_port: Port number of the gcs server. + metrics_agent_port: The port where metrics agent is bound to. + node_ip_address: IP Address of a node where gcs server starts. + + Returns: + ProcessInfo for the process that was started. + """ + assert gcs_server_port > 0 + + command = [ + GCS_SERVER_EXECUTABLE, + f"--log_dir={log_dir}", + f"--config_list={serialize_config(config)}", + f"--gcs_server_port={gcs_server_port}", + f"--metrics-agent-port={metrics_agent_port}", + f"--node-ip-address={node_ip_address}", + f"--session-name={session_name}", + f"--ray-commit={ray.__commit__}", + ] + + if ray_log_filepath: + command += [f"--ray_log_filepath={ray_log_filepath}"] + + if redis_address: + redis_ip_address, redis_port, enable_redis_ssl = get_address(redis_address) + + command += [ + f"--redis_address={redis_ip_address}", + f"--redis_port={redis_port}", + f"--redis_enable_ssl={'true' if enable_redis_ssl else 'false'}", + ] + if redis_username: + command += [f"--redis_username={redis_username}"] + if redis_password: + command += [f"--redis_password={redis_password}"] + + stdout_file = None + if ray_log_filepath: + stdout_file = open(os.devnull, "w") + + process_info = start_ray_process( + command, + ray_constants.PROCESS_TYPE_GCS_SERVER, + stdout_file=stdout_file, + stderr_file=stderr_file, + fate_share=fate_share, + ) + return process_info + + +def start_raylet( + redis_address: str, + gcs_address: str, + node_id: str, + node_ip_address: str, + node_manager_port: int, + raylet_name: str, + plasma_store_name: str, + cluster_id: str, + worker_path: str, + setup_worker_path: str, + storage: str, + 
temp_dir: str, + session_dir: str, + resource_dir: str, + log_dir: str, + resource_spec, + plasma_directory: str, + object_store_memory: int, + session_name: str, + is_head_node: bool, + min_worker_port: Optional[int] = None, + max_worker_port: Optional[int] = None, + worker_port_list: Optional[List[int]] = None, + object_manager_port: Optional[int] = None, + redis_username: Optional[str] = None, + redis_password: Optional[str] = None, + metrics_agent_port: Optional[int] = None, + metrics_export_port: Optional[int] = None, + dashboard_agent_listen_port: Optional[int] = None, + runtime_env_agent_port: Optional[int] = None, + use_valgrind: bool = False, + use_profiler: bool = False, + ray_log_filepath: Optional[str] = None, + stderr_file: Optional[IO[AnyStr]] = None, + huge_pages: bool = False, + fate_share: Optional[bool] = None, + socket_to_use: Optional[int] = None, + max_bytes: int = 0, + backup_count: int = 0, + ray_debugger_external: bool = False, + env_updates: Optional[dict] = None, + node_name: Optional[str] = None, + webui: Optional[str] = None, + labels: Optional[dict] = None, + enable_physical_mode: bool = False, +): + """Start a raylet, which is a combined local scheduler and object manager. + + Args: + redis_address: The address of the primary Redis server. + gcs_address: The address of GCS server. + node_id: The hex ID of this node. + node_ip_address: The IP address of this node. + node_manager_port: The port to use for the node manager. If it's + 0, a random port will be used. + raylet_name: The name of the raylet socket to create. + plasma_store_name: The name of the plasma store socket to connect + to. + worker_path: The path of the Python file that new worker + processes will execute. + setup_worker_path: The path of the Python file that will set up + the environment for the worker process. + storage: The persistent storage URI. + temp_dir: The path of the temporary directory Ray will use. + session_dir: The path of this session. 
+ resource_dir: The path of resource of this session . + log_dir: The path of the dir where log files are created. + resource_spec: Resources for this raylet. + session_name: The session name (cluster id) of this cluster. + object_manager_port: The port to use for the object manager. If this is + None, then the object manager will choose its own port. + min_worker_port: The lowest port number that workers will bind + on. If not set, random ports will be chosen. + max_worker_port: The highest port number that workers will bind + on. If set, min_worker_port must also be set. + redis_username: The username to use when connecting to Redis. + redis_password: The password to use when connecting to Redis. + metrics_agent_port: The port where metrics agent is bound to. + metrics_export_port: The port at which metrics are exposed to. + dashboard_agent_listen_port: The port at which the dashboard agent + listens to for HTTP. + runtime_env_agent_port: The port at which the runtime env agent + listens to for HTTP. + use_valgrind: True if the raylet should be started inside + of valgrind. If this is True, use_profiler must be False. + use_profiler: True if the raylet should be started inside + a profiler. If this is True, use_valgrind must be False. + ray_log_filepath: The file path to dump raylet log, which is + written via `RAY_LOG`. If None, logs will be sent to stdout. + stderr_file: A file handle opened for writing to redirect stderr to. If + no redirection should happen, then this should be None. + tracing_startup_hook: Tracing startup hook. + max_bytes: Log rotation parameter. Corresponding to + RotatingFileHandler's maxBytes. + backup_count: Log rotation parameter. Corresponding to + RotatingFileHandler's backupCount. + ray_debugger_external: True if the Ray debugger should be made + available externally to this node. + env_updates: Environment variable overrides. + labels: The key-value labels of the node. 
+ enable_physical_mode: Whether physical mode is enabled, which applies + constraint to tasks' resource consumption. As of now only memory + resource is supported. + Returns: + ProcessInfo for the process that was started. + """ + assert node_manager_port is not None and type(node_manager_port) is int + + if use_valgrind and use_profiler: + raise ValueError("Cannot use valgrind and profiler at the same time.") + + assert resource_spec.resolved() + static_resources = resource_spec.to_resource_dict() + + # Limit the number of workers that can be started in parallel by the + # raylet. However, make sure it is at least 1. + num_cpus_static = static_resources.get("CPU", 0) + maximum_startup_concurrency = max( + 1, min(multiprocessing.cpu_count(), num_cpus_static) + ) + + # Format the resource argument in a form like 'CPU,1.0,GPU,0,Custom,3'. + resource_argument = ",".join( + ["{},{}".format(*kv) for kv in static_resources.items()] + ) + + has_java_command = False + if shutil.which("java") is not None: + has_java_command = True + + ray_java_installed = False + try: + jars_dir = get_ray_jars_dir() + if os.path.exists(jars_dir): + ray_java_installed = True + except Exception: + pass + + include_java = has_java_command and ray_java_installed + if include_java is True: + java_worker_command = build_java_worker_command( + gcs_address, + plasma_store_name, + raylet_name, + redis_username, + redis_password, + session_dir, + node_ip_address, + setup_worker_path, + ) + else: + java_worker_command = [] + + if os.path.exists(DEFAULT_WORKER_EXECUTABLE): + cpp_worker_command = build_cpp_worker_command( + gcs_address, + plasma_store_name, + raylet_name, + redis_username, + redis_password, + session_dir, + log_dir, + node_ip_address, + setup_worker_path, + ) + else: + cpp_worker_command = [] + + # Create the command that the Raylet will use to start workers. 
+ # TODO(architkulkarni): Pipe in setup worker args separately instead of + # inserting them into start_worker_command and later erasing them if + # needed. + start_worker_command = ( + [ + sys.executable, + setup_worker_path, + ] + + _site_flags() # Inherit "-S" and "-s" flags from current Python interpreter. + + [ + worker_path, + f"--node-ip-address={node_ip_address}", + "--node-manager-port=RAY_NODE_MANAGER_PORT_PLACEHOLDER", + f"--object-store-name={plasma_store_name}", + f"--raylet-name={raylet_name}", + f"--redis-address={redis_address}", + f"--metrics-agent-port={metrics_agent_port}", + f"--runtime-env-agent-port={runtime_env_agent_port}", + f"--logging-rotate-bytes={max_bytes}", + f"--logging-rotate-backup-count={backup_count}", + f"--runtime-env-agent-port={runtime_env_agent_port}", + f"--gcs-address={gcs_address}", + f"--session-name={session_name}", + f"--temp-dir={temp_dir}", + f"--webui={webui}", + f"--cluster-id={cluster_id}", + ] + ) + + if storage is not None: + start_worker_command.append(f"--storage={storage}") + + start_worker_command.append("RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER") + + if redis_username: + start_worker_command += [f"--redis-username={redis_username}"] + + if redis_password: + start_worker_command += [f"--redis-password={redis_password}"] + + # If the object manager port is None, then use 0 to cause the object + # manager to choose its own port. 
+ if object_manager_port is None: + object_manager_port = 0 + + if min_worker_port is None: + min_worker_port = 0 + + if max_worker_port is None: + max_worker_port = 0 + + labels_json_str = "" + if labels: + labels_json_str = json.dumps(labels) + + dashboard_agent_command = [ + *_build_python_executable_command_memory_profileable( + ray_constants.PROCESS_TYPE_DASHBOARD_AGENT, session_dir + ), + os.path.join(RAY_PATH, "dashboard", "agent.py"), + f"--node-ip-address={node_ip_address}", + f"--metrics-export-port={metrics_export_port}", + f"--dashboard-agent-port={metrics_agent_port}", + f"--listen-port={dashboard_agent_listen_port}", + "--node-manager-port=RAY_NODE_MANAGER_PORT_PLACEHOLDER", + f"--object-store-name={plasma_store_name}", + f"--raylet-name={raylet_name}", + f"--temp-dir={temp_dir}", + f"--session-dir={session_dir}", + f"--log-dir={log_dir}", + f"--logging-rotate-bytes={max_bytes}", + f"--logging-rotate-backup-count={backup_count}", + f"--session-name={session_name}", + f"--gcs-address={gcs_address}", + f"--cluster-id-hex={cluster_id}", + ] + if ray_log_filepath is None and stderr_file is None: + # If not redirecting logging to files, unset log filename. + # This will cause log records to go to stderr. + dashboard_agent_command.append("--logging-filename=") + # Use stderr log format with the component name as a message prefix. + logging_format = ray_constants.LOGGER_FORMAT_STDERR.format( + component=ray_constants.PROCESS_TYPE_DASHBOARD_AGENT + ) + dashboard_agent_command.append(f"--logging-format={logging_format}") + + if not ray._private.utils.check_dashboard_dependencies_installed(): + # If dependencies are not installed, it is the minimally packaged + # ray. We should restrict the features within dashboard agent + # that requires additional dependencies to be downloaded. 
+ dashboard_agent_command.append("--minimal") + + runtime_env_agent_command = [ + *_build_python_executable_command_memory_profileable( + ray_constants.PROCESS_TYPE_RUNTIME_ENV_AGENT, session_dir + ), + os.path.join(RAY_PATH, "_private", "runtime_env", "agent", "main.py"), + f"--node-ip-address={node_ip_address}", + f"--runtime-env-agent-port={runtime_env_agent_port}", + f"--gcs-address={gcs_address}", + f"--cluster-id-hex={cluster_id}", + f"--runtime-env-dir={resource_dir}", + f"--logging-rotate-bytes={max_bytes}", + f"--logging-rotate-backup-count={backup_count}", + f"--log-dir={log_dir}", + f"--temp-dir={temp_dir}", + ] + + command = [ + RAYLET_EXECUTABLE, + f"--raylet_socket_name={raylet_name}", + f"--store_socket_name={plasma_store_name}", + f"--object_manager_port={object_manager_port}", + f"--min_worker_port={min_worker_port}", + f"--max_worker_port={max_worker_port}", + f"--node_manager_port={node_manager_port}", + f"--node_id={node_id}", + f"--node_ip_address={node_ip_address}", + f"--maximum_startup_concurrency={maximum_startup_concurrency}", + f"--static_resource_list={resource_argument}", + f"--python_worker_command={subprocess.list2cmdline(start_worker_command)}", # noqa + f"--java_worker_command={subprocess.list2cmdline(java_worker_command)}", # noqa + f"--cpp_worker_command={subprocess.list2cmdline(cpp_worker_command)}", # noqa + f"--native_library_path={DEFAULT_NATIVE_LIBRARY_PATH}", + f"--temp_dir={temp_dir}", + f"--session_dir={session_dir}", + f"--log_dir={log_dir}", + f"--resource_dir={resource_dir}", + f"--metrics-agent-port={metrics_agent_port}", + f"--metrics_export_port={metrics_export_port}", + f"--runtime_env_agent_port={runtime_env_agent_port}", + f"--object_store_memory={object_store_memory}", + f"--plasma_directory={plasma_directory}", + f"--ray-debugger-external={1 if ray_debugger_external else 0}", + f"--gcs-address={gcs_address}", + f"--session-name={session_name}", + f"--labels={labels_json_str}", + f"--cluster-id={cluster_id}", + ] 
def get_ray_jars_dir():
    """Return the directory where all ray-related jars and
    their dependencies are located.

    Returns:
        The absolute path of the packaged "jars" directory under RAY_PATH.

    Raises:
        RuntimeError: If ray was built without Java support, i.e. the
            jars directory was never packaged into the wheel.
    """
    jars_dir = os.path.abspath(os.path.join(RAY_PATH, "jars"))
    if not os.path.exists(jars_dir):
        raise RuntimeError(
            "Ray jars is not packaged into ray. "
            "Please build ray with java enabled "
            "(set env var RAY_INSTALL_JAVA=1)"
        )
    # Return the path computed above instead of re-joining it (the original
    # recomputed os.path.abspath(os.path.join(...)) on the return line and
    # kept a redundant `current_dir` alias of RAY_PATH).
    return jars_dir
def build_cpp_worker_command(
    bootstrap_address: str,
    plasma_store_name: str,
    raylet_name: str,
    redis_username: str,
    redis_password: str,
    session_dir: str,
    log_dir: str,
    node_ip_address: str,
    setup_worker_path: str,
):
    """Assemble the command used to start a CPP worker.

    Args:
        bootstrap_address: The bootstrap address of the cluster.
        plasma_store_name: The name of the plasma store socket to connect
            to.
        raylet_name: The name of the raylet socket to create.
        redis_username: The username to connect to Redis.
        redis_password: The password to connect to Redis.
        session_dir: The path of this session.
        log_dir: The path of logs.
        node_ip_address: The ip address for this node.
        setup_worker_path: The path of the Python file that will set up
            the environment for the worker process.

    Returns:
        The command string for starting CPP worker.
    """
    # The setup-worker script wraps the native worker executable so the
    # runtime environment is prepared before the worker starts; the raylet
    # later substitutes the PORT / DYNAMIC_OPTION placeholders.
    return [
        sys.executable,
        setup_worker_path,
        DEFAULT_WORKER_EXECUTABLE,
        f"--ray_plasma_store_socket_name={plasma_store_name}",
        f"--ray_raylet_socket_name={raylet_name}",
        "--ray_node_manager_port=RAY_NODE_MANAGER_PORT_PLACEHOLDER",
        f"--ray_address={bootstrap_address}",
        f"--ray_redis_username={redis_username}",
        f"--ray_redis_password={redis_password}",
        f"--ray_session_dir={session_dir}",
        f"--ray_logs_dir={log_dir}",
        f"--ray_node_ip_address={node_ip_address}",
        "RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER",
    ]
+ """ + if not isinstance(object_store_memory, int): + object_store_memory = int(object_store_memory) + + if huge_pages and not (sys.platform == "linux" or sys.platform == "linux2"): + raise ValueError("The huge_pages argument is only supported on Linux.") + + system_memory = ray._private.utils.get_system_memory() + + # Determine which directory to use. By default, use /tmp on MacOS and + # /dev/shm on Linux, unless the shared-memory file system is too small, + # in which case we default to /tmp on Linux. + if plasma_directory is None: + if sys.platform == "linux" or sys.platform == "linux2": + shm_avail = ray._private.utils.get_shared_memory_bytes() + # Compare the requested memory size to the memory available in + # /dev/shm. + if shm_avail >= object_store_memory: + plasma_directory = "/dev/shm" + elif ( + not os.environ.get("RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE") + and object_store_memory > ray_constants.REQUIRE_SHM_SIZE_THRESHOLD + ): + raise ValueError( + "The configured object store size ({} GB) exceeds " + "/dev/shm size ({} GB). This will harm performance. " + "Consider deleting files in /dev/shm or increasing its " + "size with " + "--shm-size in Docker. To ignore this warning, " + "set RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE=1.".format( + object_store_memory / 1e9, shm_avail / 1e9 + ) + ) + else: + plasma_directory = ray._private.utils.get_user_temp_dir() + logger.warning( + "WARNING: The object store is using {} instead of " + "/dev/shm because /dev/shm has only {} bytes available. " + "This will harm performance! You may be able to free up " + "space by deleting files in /dev/shm. If you are inside a " + "Docker container, you can increase /dev/shm size by " + "passing '--shm-size={:.2f}gb' to 'docker run' (or add it " + "to the run_options list in a Ray cluster config). 
Make " + "sure to set this to more than 30% of available RAM.".format( + ray._private.utils.get_user_temp_dir(), + shm_avail, + object_store_memory * (1.1) / (2**30), + ) + ) + else: + plasma_directory = ray._private.utils.get_user_temp_dir() + + # Do some sanity checks. + if object_store_memory > system_memory: + raise ValueError( + "The requested object store memory size is greater " + "than the total available memory." + ) + else: + plasma_directory = os.path.abspath(plasma_directory) + logger.info("object_store_memory is not verified when plasma_directory is set.") + + if not os.path.isdir(plasma_directory): + raise ValueError( + f"The file {plasma_directory} does not exist or is not a directory." + ) + + if huge_pages and plasma_directory is None: + raise ValueError( + "If huge_pages is True, then the " + "plasma_directory argument must be provided." + ) + + if object_store_memory < ray_constants.OBJECT_STORE_MINIMUM_MEMORY_BYTES: + raise ValueError( + "Attempting to cap object store memory usage at {} " + "bytes, but the minimum allowed is {} bytes.".format( + object_store_memory, ray_constants.OBJECT_STORE_MINIMUM_MEMORY_BYTES + ) + ) + + if ( + sys.platform == "darwin" + and object_store_memory > ray_constants.MAC_DEGRADED_PERF_MMAP_SIZE_LIMIT + and os.environ.get("RAY_ENABLE_MAC_LARGE_OBJECT_STORE") != "1" + ): + raise ValueError( + "The configured object store size ({:.4}GiB) exceeds " + "the optimal size on Mac ({:.4}GiB). " + "This will harm performance! There is a known issue where " + "Ray's performance degrades with object store size greater" + " than {:.4}GB on a Mac." + "To reduce the object store capacity, specify" + "`object_store_memory` when calling ray.init() or ray start." 
+ "To ignore this warning, " + "set RAY_ENABLE_MAC_LARGE_OBJECT_STORE=1.".format( + object_store_memory / 2**30, + ray_constants.MAC_DEGRADED_PERF_MMAP_SIZE_LIMIT / 2**30, + ray_constants.MAC_DEGRADED_PERF_MMAP_SIZE_LIMIT / 2**30, + ) + ) + + # Print the object store memory using two decimal places. + logger.debug( + "Determine to start the Plasma object store with {} GB memory " + "using {}.".format(round(object_store_memory / 10**9, 2), plasma_directory) + ) + + return plasma_directory, object_store_memory + + +def start_monitor( + gcs_address: str, + logs_dir: str, + stdout_file: Optional[str] = None, + stderr_file: Optional[str] = None, + autoscaling_config: Optional[str] = None, + fate_share: Optional[bool] = None, + max_bytes: int = 0, + backup_count: int = 0, + monitor_ip: Optional[str] = None, + autoscaler_v2: bool = False, +): + """Run a process to monitor the other processes. + + Args: + gcs_address: The address of GCS server. + logs_dir: The path to the log directory. + stdout_file: A file handle opened for writing to redirect stdout to. If + no redirection should happen, then this should be None. + stderr_file: A file handle opened for writing to redirect stderr to. If + no redirection should happen, then this should be None. + autoscaling_config: path to autoscaling config file. + max_bytes: Log rotation parameter. Corresponding to + RotatingFileHandler's maxBytes. + backup_count: Log rotation parameter. Corresponding to + RotatingFileHandler's backupCount. + monitor_ip: IP address of the machine that the monitor will be + run on. Can be excluded, but required for autoscaler metrics. + Returns: + ProcessInfo for the process that was started. 
+ """ + if autoscaler_v2: + entrypoint = os.path.join(RAY_PATH, AUTOSCALER_V2_DIR, "monitor.py") + else: + entrypoint = os.path.join(RAY_PATH, AUTOSCALER_PRIVATE_DIR, "monitor.py") + + command = [ + sys.executable, + "-u", + entrypoint, + f"--logs-dir={logs_dir}", + f"--logging-rotate-bytes={max_bytes}", + f"--logging-rotate-backup-count={backup_count}", + ] + assert gcs_address is not None + command.append(f"--gcs-address={gcs_address}") + + if stdout_file is None and stderr_file is None: + # If not redirecting logging to files, unset log filename. + # This will cause log records to go to stderr. + command.append("--logging-filename=") + # Use stderr log format with the component name as a message prefix. + logging_format = ray_constants.LOGGER_FORMAT_STDERR.format( + component=ray_constants.PROCESS_TYPE_MONITOR + ) + command.append(f"--logging-format={logging_format}") + if autoscaling_config: + command.append("--autoscaling-config=" + str(autoscaling_config)) + if monitor_ip: + command.append("--monitor-ip=" + monitor_ip) + process_info = start_ray_process( + command, + ray_constants.PROCESS_TYPE_MONITOR, + stdout_file=stdout_file, + stderr_file=stderr_file, + fate_share=fate_share, + ) + return process_info + + +def start_ray_client_server( + address: str, + ray_client_server_ip: str, + ray_client_server_port: int, + stdout_file: Optional[int] = None, + stderr_file: Optional[int] = None, + redis_username: Optional[int] = None, + redis_password: Optional[int] = None, + fate_share: Optional[bool] = None, + runtime_env_agent_address: Optional[str] = None, + server_type: str = "proxy", + serialized_runtime_env_context: Optional[str] = None, +): + """Run the server process of the Ray client. + + Args: + address: The address of the cluster. + ray_client_server_ip: Host IP the Ray client server listens on. + ray_client_server_port: Port the Ray client server listens on. + stdout_file: A file handle opened for writing to redirect stdout to. 
If + no redirection should happen, then this should be None. + stderr_file: A file handle opened for writing to redirect stderr to. If + no redirection should happen, then this should be None. + redis_username: The username of the Redis server. + redis_password: The password of the Redis server. + runtime_env_agent_address: Address to the Runtime Env Agent listens on via HTTP. + Only needed when server_type == "proxy". + server_type: Whether to start the proxy version of Ray Client. + serialized_runtime_env_context (str|None): If specified, the serialized + runtime_env_context to start the client server in. + + Returns: + ProcessInfo for the process that was started. + """ + root_ray_dir = Path(__file__).resolve().parents[1] + setup_worker_path = os.path.join( + root_ray_dir, "_private", "workers", ray_constants.SETUP_WORKER_FILENAME + ) + + ray_client_server_host = ( + "127.0.0.1" if ray_client_server_ip == "127.0.0.1" else "0.0.0.0" + ) + command = [ + sys.executable, + setup_worker_path, + "-m", + "ray.util.client.server", + f"--address={address}", + f"--host={ray_client_server_host}", + f"--port={ray_client_server_port}", + f"--mode={server_type}", + f"--language={Language.Name(Language.PYTHON)}", + ] + if redis_username: + command.append(f"--redis-username={redis_username}") + if redis_password: + command.append(f"--redis-password={redis_password}") + if serialized_runtime_env_context: + command.append( + f"--serialized-runtime-env-context={serialized_runtime_env_context}" # noqa: E501 + ) + if server_type == "proxy": + assert len(runtime_env_agent_address) > 0 + if runtime_env_agent_address: + command.append(f"--runtime-env-agent-address={runtime_env_agent_address}") + + process_info = start_ray_process( + command, + ray_constants.PROCESS_TYPE_RAY_CLIENT_SERVER, + stdout_file=stdout_file, + stderr_file=stderr_file, + fate_share=fate_share, + ) + return process_info diff --git a/.venv/lib/python3.11/site-packages/ray/_private/signature.py 
b/.venv/lib/python3.11/site-packages/ray/_private/signature.py new file mode 100644 index 0000000000000000000000000000000000000000..00f4e90c29ff4202cc35baccb9c24d3026cb6311 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/signature.py @@ -0,0 +1,185 @@ +import inspect +import logging +from inspect import Parameter +from typing import List + +from ray._private.inspect_util import is_cython + +# Logger for this module. It should be configured at the entry point +# into the program using Ray. Ray provides a default configuration at +# entry/init points. +logger = logging.getLogger(__name__) + +# This dummy type is also defined in ArgumentsBuilder.java. Please keep it +# synced. +DUMMY_TYPE = b"__RAY_DUMMY__" + + +def get_signature(func): + """Get signature parameters. + + Support Cython functions by grabbing relevant attributes from the Cython + function and attaching to a no-op function. This is somewhat brittle, since + inspect may change, but given that inspect is written to a PEP, we hope + it is relatively stable. Future versions of Python may allow overloading + the inspect 'isfunction' and 'ismethod' functions / create ABC for Python + functions. Until then, it appears that Cython won't do anything about + compatibility with the inspect module. + + Args: + func: The function whose signature should be checked. + + Returns: + A function signature object, which includes the names of the keyword + arguments as well as their default values.
+ + Raises: + TypeError: A type error if the signature is not supported + """ + # The first condition for Cython functions, the latter for Cython instance + # methods + if is_cython(func): + attrs = ["__code__", "__annotations__", "__defaults__", "__kwdefaults__"] + + if all(hasattr(func, attr) for attr in attrs): + original_func = func + + def func(): + return + + for attr in attrs: + setattr(func, attr, getattr(original_func, attr)) + else: + raise TypeError(f"{func!r} is not a Python function we can process") + + return inspect.signature(func) + + +def extract_signature(func, ignore_first=False): + """Extract the function signature from the function. + + Args: + func: The function whose signature should be extracted. + ignore_first: True if the first argument should be ignored. This should + be used when func is a method of a class. + + Returns: + List of Parameter objects representing the function signature. + """ + signature_parameters = list(get_signature(func).parameters.values()) + + if ignore_first: + if len(signature_parameters) == 0: + raise ValueError( + "Methods must take a 'self' argument, but the " + f"method '{func.__name__}' does not have one." + ) + signature_parameters = signature_parameters[1:] + + return signature_parameters + + +def validate_args(signature_parameters: List[Parameter], args, kwargs): + """Validates the arguments against the signature. + + Args: + signature_parameters: The list of Parameter objects + representing the function signature, obtained from + `extract_signature`. + args: The positional arguments passed into the function. + kwargs: The keyword arguments passed into the function. + + Raises: + TypeError: Raised if arguments do not fit in the function signature. 
+ """ + reconstructed_signature = inspect.Signature(parameters=signature_parameters) + try: + reconstructed_signature.bind(*args, **kwargs) + except TypeError as exc: # capture a friendlier stacktrace + raise TypeError(str(exc)) from None + + +def flatten_args(signature_parameters: List[Parameter], args, kwargs): + """Validates the arguments against the signature and flattens them. + + The flat list representation is a serializable format for arguments. + Since the flatbuffer representation of function arguments is a list, we + combine both keyword arguments and positional arguments. We represent + this with two entries per argument value - [DUMMY_TYPE, x] for positional + arguments and [KEY, VALUE] for keyword arguments. See the below example. + See `recover_args` for logic restoring the flat list back to args/kwargs. + + Args: + signature_parameters: The list of Parameter objects + representing the function signature, obtained from + `extract_signature`. + args: The positional arguments passed into the function. + kwargs: The keyword arguments passed into the function. + + Returns: + List of args and kwargs. Non-keyword arguments are prefixed + by internal enum DUMMY_TYPE. + + Raises: + TypeError: Raised if arguments do not fit in the function signature. + """ + validate_args(signature_parameters, args, kwargs) + list_args = [] + for arg in args: + list_args += [DUMMY_TYPE, arg] + + for keyword, arg in kwargs.items(): + list_args += [keyword, arg] + return list_args + + +def recover_args(flattened_args): + """Recreates `args` and `kwargs` from the flattened arg list. + + Args: + flattened_args: List of args and kwargs. This should be the output of + `flatten_args`. + + Returns: + args: The non-keyword arguments passed into the function. + kwargs: The keyword arguments passed into the function. + """ + assert ( + len(flattened_args) % 2 == 0 + ), "Flattened arguments need to be even-numbered. See `flatten_args`." 
+ args = [] + kwargs = {} + for name_index in range(0, len(flattened_args), 2): + name, arg = flattened_args[name_index], flattened_args[name_index + 1] + if name == DUMMY_TYPE: + args.append(arg) + else: + kwargs[name] = arg + + return args, kwargs + + +def _convert_from_parameter_kind(kind): + if kind == Parameter.POSITIONAL_ONLY: + return 0 + if kind == Parameter.POSITIONAL_OR_KEYWORD: + return 1 + if kind == Parameter.VAR_POSITIONAL: + return 2 + if kind == Parameter.KEYWORD_ONLY: + return 3 + if kind == Parameter.VAR_KEYWORD: + return 4 + + +def _convert_to_parameter_kind(value): + if value == 0: + return Parameter.POSITIONAL_ONLY + if value == 1: + return Parameter.POSITIONAL_OR_KEYWORD + if value == 2: + return Parameter.VAR_POSITIONAL + if value == 3: + return Parameter.KEYWORD_ONLY + if value == 4: + return Parameter.VAR_KEYWORD diff --git a/.venv/lib/python3.11/site-packages/ray/_private/state.py b/.venv/lib/python3.11/site-packages/ray/_private/state.py new file mode 100644 index 0000000000000000000000000000000000000000..ccdc709a4f1ba643fce111f70b46b8b00b18eb1f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/state.py @@ -0,0 +1,1095 @@ +import json +import logging +from collections import defaultdict +from typing import Dict + +from ray._private.protobuf_compat import message_to_dict + +import ray +from ray._private.client_mode_hook import client_mode_hook +from ray._private.resource_spec import NODE_ID_PREFIX, HEAD_NODE_RESOURCE_NAME +from ray._private.utils import ( + binary_to_hex, + decode, + hex_to_binary, + validate_actor_state_name, +) +from ray._raylet import GlobalStateAccessor +from ray.core.generated import common_pb2 +from ray.core.generated import gcs_pb2 +from ray.core.generated import autoscaler_pb2 +from ray.util.annotations import DeveloperAPI + +logger = logging.getLogger(__name__) + + +class GlobalState: + """A class used to interface with the Ray control state. 
+ + Attributes: + global_state_accessor: The client used to query gcs table from gcs + server. + """ + + def __init__(self): + """Create a GlobalState object.""" + # Args used for lazy init of this object. + self.gcs_options = None + self.global_state_accessor = None + + def _check_connected(self): + """Ensure that the object has been initialized before it is used. + + This lazily initializes clients needed for state accessors. + + Raises: + RuntimeError: An exception is raised if ray.init() has not been + called yet. + """ + if self.gcs_options is not None and self.global_state_accessor is None: + self._really_init_global_state() + + # _really_init_global_state should have set self.global_state_accessor + if self.global_state_accessor is None: + raise ray.exceptions.RaySystemError( + "Ray has not been started yet. You can start Ray with 'ray.init()'." + ) + + def disconnect(self): + """Disconnect global state from GCS.""" + self.gcs_options = None + if self.global_state_accessor is not None: + self.global_state_accessor.disconnect() + self.global_state_accessor = None + + def _initialize_global_state(self, gcs_options): + """Set args for lazily initialization of the GlobalState object. + + It's possible that certain keys in gcs kv may not have been fully + populated yet. In this case, we will retry this method until they have + been populated or we exceed a timeout. + + Args: + gcs_options: The client options for gcs + """ + + # Save args for lazy init of global state. This avoids opening extra + # gcs connections from each worker until needed. + self.gcs_options = gcs_options + + def _really_init_global_state(self): + self.global_state_accessor = GlobalStateAccessor(self.gcs_options) + self.global_state_accessor.connect() + + def actor_table( + self, actor_id: str, job_id: ray.JobID = None, actor_state_name: str = None + ): + """Fetch and parse the actor table information for a single actor ID. 
+ + Args: + actor_id: A hex string of the actor ID to fetch information about. + If this is None, then the actor table is fetched. + If this is not None, `job_id` and `actor_state_name` + will not take effect. + job_id: To filter actors by job_id, which is of type `ray.JobID`. + You can use the `ray.get_runtime_context().job_id` function + to get the current job ID + actor_state_name: To filter actors based on actor state, + which can be one of the following: "DEPENDENCIES_UNREADY", + "PENDING_CREATION", "ALIVE", "RESTARTING", or "DEAD". + Returns: + Information from the actor table. + """ + self._check_connected() + + if actor_id is not None: + actor_id = ray.ActorID(hex_to_binary(actor_id)) + actor_info = self.global_state_accessor.get_actor_info(actor_id) + if actor_info is None: + return {} + else: + actor_table_data = gcs_pb2.ActorTableData.FromString(actor_info) + return self._gen_actor_info(actor_table_data) + else: + validate_actor_state_name(actor_state_name) + actor_table = self.global_state_accessor.get_actor_table( + job_id, actor_state_name + ) + results = {} + for i in range(len(actor_table)): + actor_table_data = gcs_pb2.ActorTableData.FromString(actor_table[i]) + results[ + binary_to_hex(actor_table_data.actor_id) + ] = self._gen_actor_info(actor_table_data) + + return results + + def _gen_actor_info(self, actor_table_data): + """Parse actor table data. + + Returns: + Information from actor table. 
+ """ + actor_info = { + "ActorID": binary_to_hex(actor_table_data.actor_id), + "ActorClassName": actor_table_data.class_name, + "IsDetached": actor_table_data.is_detached, + "Name": actor_table_data.name, + "JobID": binary_to_hex(actor_table_data.job_id), + "Address": { + "IPAddress": actor_table_data.address.ip_address, + "Port": actor_table_data.address.port, + "NodeID": binary_to_hex(actor_table_data.address.raylet_id), + }, + "OwnerAddress": { + "IPAddress": actor_table_data.owner_address.ip_address, + "Port": actor_table_data.owner_address.port, + "NodeID": binary_to_hex(actor_table_data.owner_address.raylet_id), + }, + "State": gcs_pb2.ActorTableData.ActorState.DESCRIPTOR.values_by_number[ + actor_table_data.state + ].name, + "NumRestarts": actor_table_data.num_restarts, + "Timestamp": actor_table_data.timestamp, + "StartTime": actor_table_data.start_time, + "EndTime": actor_table_data.end_time, + "DeathCause": actor_table_data.death_cause, + "Pid": actor_table_data.pid, + } + return actor_info + + def node_table(self): + """Fetch and parse the Gcs node info table. + + Returns: + Information about the node in the cluster. + """ + self._check_connected() + + return self.global_state_accessor.get_node_table() + + def job_table(self): + """Fetch and parse the gcs job table. 
+ + Returns: + Information about the Ray jobs in the cluster, + namely a list of dicts with keys: + - "JobID" (identifier for the job), + - "DriverIPAddress" (IP address of the driver for this job), + - "DriverPid" (process ID of the driver for this job), + - "StartTime" (UNIX timestamp of the start time of this job), + - "StopTime" (UNIX timestamp of the stop time of this job, if any) + """ + self._check_connected() + + job_table = self.global_state_accessor.get_job_table( + skip_submission_job_info_field=True, skip_is_running_tasks_field=True + ) + + results = [] + for i in range(len(job_table)): + entry = gcs_pb2.JobTableData.FromString(job_table[i]) + job_info = {} + job_info["JobID"] = entry.job_id.hex() + job_info["DriverIPAddress"] = entry.driver_address.ip_address + job_info["DriverPid"] = entry.driver_pid + job_info["Timestamp"] = entry.timestamp + job_info["StartTime"] = entry.start_time + job_info["EndTime"] = entry.end_time + job_info["IsDead"] = entry.is_dead + job_info["Entrypoint"] = entry.entrypoint + results.append(job_info) + + return results + + def next_job_id(self): + """Get next job id from GCS. + + Returns: + Next job id in the cluster. + """ + self._check_connected() + + return ray.JobID.from_int(self.global_state_accessor.get_next_job_id()) + + def profile_events(self): + """Retrieve and return task profiling events from GCS. + + Return: + Profiling events by component id (e.g. worker id). 
+ { + : [ + { + event_type: , + component_id: , + node_ip_address: , + component_type: , + start_time: , + end_time: , + extra_data: , + } + ] + } + """ + self._check_connected() + + result = defaultdict(list) + task_events = self.global_state_accessor.get_task_events() + for i in range(len(task_events)): + event = gcs_pb2.TaskEvents.FromString(task_events[i]) + profile = event.profile_events + if not profile: + continue + + component_type = profile.component_type + component_id = binary_to_hex(profile.component_id) + node_ip_address = profile.node_ip_address + + for event in profile.events: + try: + extra_data = json.loads(event.extra_data) + except ValueError: + extra_data = {} + profile_event = { + "event_type": event.event_name, + "component_id": component_id, + "node_ip_address": node_ip_address, + "component_type": component_type, + "start_time": event.start_time, + "end_time": event.end_time, + "extra_data": extra_data, + } + + result[component_id].append(profile_event) + + return dict(result) + + def get_placement_group_by_name(self, placement_group_name, ray_namespace): + self._check_connected() + + placement_group_info = self.global_state_accessor.get_placement_group_by_name( + placement_group_name, ray_namespace + ) + if placement_group_info is None: + return None + else: + placement_group_table_data = gcs_pb2.PlacementGroupTableData.FromString( + placement_group_info + ) + return self._gen_placement_group_info(placement_group_table_data) + + def placement_group_table(self, placement_group_id=None): + self._check_connected() + + if placement_group_id is not None: + placement_group_id = ray.PlacementGroupID( + hex_to_binary(placement_group_id.hex()) + ) + placement_group_info = self.global_state_accessor.get_placement_group_info( + placement_group_id + ) + if placement_group_info is None: + return {} + else: + placement_group_info = gcs_pb2.PlacementGroupTableData.FromString( + placement_group_info + ) + return 
self._gen_placement_group_info(placement_group_info) + else: + placement_group_table = ( + self.global_state_accessor.get_placement_group_table() + ) + results = {} + for placement_group_info in placement_group_table: + placement_group_table_data = gcs_pb2.PlacementGroupTableData.FromString( + placement_group_info + ) + placement_group_id = binary_to_hex( + placement_group_table_data.placement_group_id + ) + results[placement_group_id] = self._gen_placement_group_info( + placement_group_table_data + ) + + return results + + def _gen_placement_group_info(self, placement_group_info): + # This should be imported here, otherwise, it will error doc build. + from ray.core.generated.common_pb2 import PlacementStrategy + + def get_state(state): + if state == gcs_pb2.PlacementGroupTableData.PENDING: + return "PENDING" + elif state == gcs_pb2.PlacementGroupTableData.PREPARED: + return "PREPARED" + elif state == gcs_pb2.PlacementGroupTableData.CREATED: + return "CREATED" + elif state == gcs_pb2.PlacementGroupTableData.RESCHEDULING: + return "RESCHEDULING" + else: + return "REMOVED" + + def get_strategy(strategy): + if strategy == PlacementStrategy.PACK: + return "PACK" + elif strategy == PlacementStrategy.STRICT_PACK: + return "STRICT_PACK" + elif strategy == PlacementStrategy.STRICT_SPREAD: + return "STRICT_SPREAD" + elif strategy == PlacementStrategy.SPREAD: + return "SPREAD" + else: + raise ValueError(f"Invalid strategy returned: {PlacementStrategy}") + + stats = placement_group_info.stats + assert placement_group_info is not None + return { + "placement_group_id": binary_to_hex( + placement_group_info.placement_group_id + ), + "name": placement_group_info.name, + "bundles": { + # The value here is needs to be dictionarified + # otherwise, the payload becomes unserializable. 
+ bundle.bundle_id.bundle_index: message_to_dict(bundle)["unitResources"] + for bundle in placement_group_info.bundles + }, + "bundles_to_node_id": { + bundle.bundle_id.bundle_index: binary_to_hex(bundle.node_id) + for bundle in placement_group_info.bundles + }, + "strategy": get_strategy(placement_group_info.strategy), + "state": get_state(placement_group_info.state), + "stats": { + "end_to_end_creation_latency_ms": ( + stats.end_to_end_creation_latency_us / 1000.0 + ), + "scheduling_latency_ms": (stats.scheduling_latency_us / 1000.0), + "scheduling_attempt": stats.scheduling_attempt, + "highest_retry_delay_ms": stats.highest_retry_delay_ms, + "scheduling_state": gcs_pb2.PlacementGroupStats.SchedulingState.DESCRIPTOR.values_by_number[ # noqa: E501 + stats.scheduling_state + ].name, + }, + } + + def _nanoseconds_to_microseconds(self, time_in_nanoseconds): + """A helper function for converting nanoseconds to microseconds.""" + time_in_microseconds = time_in_nanoseconds / 1000 + return time_in_microseconds + + # Colors are specified at + # https://github.com/catapult-project/catapult/blob/master/tracing/tracing/base/color_scheme.html. # noqa: E501 + _default_color_mapping = defaultdict( + lambda: "generic_work", + { + "worker_idle": "cq_build_abandoned", + "task": "rail_response", + "task:deserialize_arguments": "rail_load", + "task:execute": "rail_animation", + "task:store_outputs": "rail_idle", + "wait_for_function": "detailed_memory_dump", + "ray.get": "good", + "ray.put": "terrible", + "ray.wait": "vsync_highlight_color", + "submit_task": "background_memory_dump", + "fetch_and_run_function": "detailed_memory_dump", + "register_remote_function": "detailed_memory_dump", + }, + ) + + # These colors are for use in Chrome tracing. 
+ _chrome_tracing_colors = [ + "thread_state_uninterruptible", + "thread_state_iowait", + "thread_state_running", + "thread_state_runnable", + "thread_state_sleeping", + "thread_state_unknown", + "background_memory_dump", + "light_memory_dump", + "detailed_memory_dump", + "vsync_highlight_color", + "generic_work", + "good", + "bad", + "terrible", + # "black", + # "grey", + # "white", + "yellow", + "olive", + "rail_response", + "rail_animation", + "rail_idle", + "rail_load", + "startup", + "heap_dump_stack_frame", + "heap_dump_object_type", + "heap_dump_child_node_arrow", + "cq_build_running", + "cq_build_passed", + "cq_build_failed", + "cq_build_abandoned", + "cq_build_attempt_runnig", + "cq_build_attempt_passed", + "cq_build_attempt_failed", + ] + + def chrome_tracing_dump(self, filename=None): + """Return a list of profiling events that can be viewed as a timeline. + + To view this information as a timeline, simply dump it as a json file + by passing in "filename" or using json.dump, and then go to + chrome://tracing in the Chrome web browser and load the dumped file. + Make sure to enable "Flow events" in the "View Options" menu. + + Args: + filename: If a filename is provided, the timeline is dumped to that + file. + + Returns: + If filename is not provided, this returns a list of profiling + events. Each profile event is a dictionary. + """ + # TODO(rkn): Support including the task specification data in the + # timeline. + # TODO(rkn): This should support viewing just a window of time or a + # limited number of events. + + self._check_connected() + + # Add a small delay to account for propagation delay of events to the GCS. + # This should be harmless enough but prevents calls to timeline() from + # missing recent timeline data. + import time + + time.sleep(1) + + profile_events = self.profile_events() + all_events = [] + + for component_id_hex, component_events in profile_events.items(): + # Only consider workers and drivers.
+ component_type = component_events[0]["component_type"] + if component_type not in ["worker", "driver"]: + continue + + for event in component_events: + new_event = { + # The category of the event. + "cat": event["event_type"], + # The string displayed on the event. + "name": event["event_type"], + # The identifier for the group of rows that the event + # appears in. + "pid": event["node_ip_address"], + # The identifier for the row that the event appears in. + "tid": event["component_type"] + ":" + event["component_id"], + # The start time in microseconds. + "ts": self._nanoseconds_to_microseconds(event["start_time"]), + # The duration in microseconds. + "dur": self._nanoseconds_to_microseconds( + event["end_time"] - event["start_time"] + ), + # What is this? + "ph": "X", + # This is the name of the color to display the box in. + "cname": self._default_color_mapping[event["event_type"]], + # The extra user-defined data. + "args": event["extra_data"], + } + + # Modify the json with the additional user-defined extra data. + # This can be used to add fields or override existing fields. + if "cname" in event["extra_data"]: + new_event["cname"] = event["extra_data"]["cname"] + if "name" in event["extra_data"]: + new_event["name"] = event["extra_data"]["name"] + + all_events.append(new_event) + + if not all_events: + logger.warning( + "No profiling events found. Ray profiling must be enabled " + "by setting RAY_PROFILING=1, and make sure " + "RAY_task_events_report_interval_ms=0." + ) + + if filename is not None: + with open(filename, "w") as outfile: + json.dump(all_events, outfile) + else: + return all_events + + def chrome_tracing_object_transfer_dump(self, filename=None): + """Return a list of transfer events that can viewed as a timeline. + + To view this information as a timeline, simply dump it as a json file + by passing in "filename" or using json.dump, and then load go to + chrome://tracing in the Chrome web browser and load the dumped file. 
+ Make sure to enable "Flow events" in the "View Options" menu. + + Args: + filename: If a filename is provided, the timeline is dumped to that + file. + + Returns: + If filename is not provided, this returns a list of profiling + events. Each profile event is a dictionary. + """ + self._check_connected() + + node_id_to_address = {} + for node_info in self.node_table(): + node_id_to_address[node_info["NodeID"]] = "{}:{}".format( + node_info["NodeManagerAddress"], node_info["ObjectManagerPort"] + ) + + all_events = [] + + for key, items in self.profile_events().items(): + # Only consider object manager events. + if items[0]["component_type"] != "object_manager": + continue + + for event in items: + if event["event_type"] == "transfer_send": + object_ref, remote_node_id, _, _ = event["extra_data"] + + elif event["event_type"] == "transfer_receive": + object_ref, remote_node_id, _ = event["extra_data"] + + elif event["event_type"] == "receive_pull_request": + object_ref, remote_node_id = event["extra_data"] + + else: + assert False, "This should be unreachable." + + # Choose a color by reading the first couple of hex digits of + # the object ref as an integer and turning that into a color. + object_ref_int = int(object_ref[:2], 16) + color = self._chrome_tracing_colors[ + object_ref_int % len(self._chrome_tracing_colors) + ] + + new_event = { + # The category of the event. + "cat": event["event_type"], + # The string displayed on the event. + "name": event["event_type"], + # The identifier for the group of rows that the event + # appears in. + "pid": node_id_to_address[key], + # The identifier for the row that the event appears in. + "tid": node_id_to_address[remote_node_id], + # The start time in microseconds. + "ts": self._nanoseconds_to_microseconds(event["start_time"]), + # The duration in microseconds. + "dur": self._nanoseconds_to_microseconds( + event["end_time"] - event["start_time"] + ), + # What is this? 
+ "ph": "X", + # This is the name of the color to display the box in. + "cname": color, + # The extra user-defined data. + "args": event["extra_data"], + } + all_events.append(new_event) + + # Add another box with a color indicating whether it was a send + # or a receive event. + if event["event_type"] == "transfer_send": + additional_event = new_event.copy() + additional_event["cname"] = "black" + all_events.append(additional_event) + elif event["event_type"] == "transfer_receive": + additional_event = new_event.copy() + additional_event["cname"] = "grey" + all_events.append(additional_event) + else: + pass + + if filename is not None: + with open(filename, "w") as outfile: + json.dump(all_events, outfile) + else: + return all_events + + def workers(self): + """Get a dictionary mapping worker ID to worker information.""" + self._check_connected() + + # Get all data in worker table + worker_table = self.global_state_accessor.get_worker_table() + workers_data = {} + for i in range(len(worker_table)): + worker_table_data = gcs_pb2.WorkerTableData.FromString(worker_table[i]) + if ( + worker_table_data.is_alive + and worker_table_data.worker_type == common_pb2.WORKER + ): + worker_id = binary_to_hex(worker_table_data.worker_address.worker_id) + worker_info = worker_table_data.worker_info + + workers_data[worker_id] = { + "node_ip_address": decode(worker_info[b"node_ip_address"]), + "plasma_store_socket": decode(worker_info[b"plasma_store_socket"]), + } + if b"stderr_file" in worker_info: + workers_data[worker_id]["stderr_file"] = decode( + worker_info[b"stderr_file"] + ) + if b"stdout_file" in worker_info: + workers_data[worker_id]["stdout_file"] = decode( + worker_info[b"stdout_file"] + ) + return workers_data + + def add_worker(self, worker_id, worker_type, worker_info): + """Add a worker to the cluster. + + Args: + worker_id: ID of this worker. Type is bytes. + worker_type: Type of this worker. Value is common_pb2.DRIVER or + common_pb2.WORKER. 
+ worker_info: Info of this worker. Type is dict{str: str}. + + Returns: + Is operation success + """ + worker_data = gcs_pb2.WorkerTableData() + worker_data.is_alive = True + worker_data.worker_address.worker_id = worker_id + worker_data.worker_type = worker_type + for k, v in worker_info.items(): + worker_data.worker_info[k] = bytes(v, encoding="utf-8") + return self.global_state_accessor.add_worker_info( + worker_data.SerializeToString() + ) + + def update_worker_debugger_port(self, worker_id, debugger_port): + """Update the debugger port of a worker. + + Args: + worker_id: ID of this worker. Type is bytes. + debugger_port: Port of the debugger. Type is int. + + Returns: + Is operation success + """ + self._check_connected() + + assert worker_id is not None, "worker_id is not valid" + assert ( + debugger_port is not None and debugger_port > 0 + ), "debugger_port is not valid" + + return self.global_state_accessor.update_worker_debugger_port( + worker_id, debugger_port + ) + + def get_worker_debugger_port(self, worker_id): + """Get the debugger port of a worker. + + Args: + worker_id: ID of this worker. Type is bytes. + + Returns: + Debugger port of the worker. + """ + self._check_connected() + + assert worker_id is not None, "worker_id is not valid" + + return self.global_state_accessor.get_worker_debugger_port(worker_id) + + def update_worker_num_paused_threads(self, worker_id, num_paused_threads_delta): + """Updates the number of paused threads of a worker. + + Args: + worker_id: ID of this worker. Type is bytes. + num_paused_threads_delta: The delta of the number of paused threads. 
+ + Returns: + Is operation success + """ + self._check_connected() + + assert worker_id is not None, "worker_id is not valid" + assert num_paused_threads_delta is not None, "worker_id is not valid" + + return self.global_state_accessor.update_worker_num_paused_threads( + worker_id, num_paused_threads_delta + ) + + def cluster_resources(self): + """Get the current total cluster resources. + + Note that this information can grow stale as nodes are added to or + removed from the cluster. + + Returns: + A dictionary mapping resource name to the total quantity of that + resource in the cluster. + """ + self._check_connected() + + # Calculate total resources. + total_resources = defaultdict(int) + for node_total_resources in self.total_resources_per_node().values(): + for resource_id, value in node_total_resources.items(): + total_resources[resource_id] += value + + return dict(total_resources) + + def _live_node_ids(self): + """Returns a set of node IDs corresponding to nodes still alive.""" + return set(self.total_resources_per_node().keys()) + + def available_resources_per_node(self): + """Returns a dictionary mapping node id to available resources.""" + self._check_connected() + available_resources_by_id = {} + + all_available_resources = ( + self.global_state_accessor.get_all_available_resources() + ) + for available_resource in all_available_resources: + message = gcs_pb2.AvailableResources.FromString(available_resource) + # Calculate available resources for this node. + dynamic_resources = {} + for resource_id, capacity in message.resources_available.items(): + dynamic_resources[resource_id] = capacity + # Update available resources for this node. 
+ node_id = ray._private.utils.binary_to_hex(message.node_id) + available_resources_by_id[node_id] = dynamic_resources + + return available_resources_by_id + + # returns a dict that maps node_id(hex string) to a dict of {resource_id: capacity} + def total_resources_per_node(self) -> Dict[str, Dict[str, int]]: + self._check_connected() + total_resources_by_node = {} + + all_total_resources = self.global_state_accessor.get_all_total_resources() + for node_total_resources in all_total_resources: + message = gcs_pb2.TotalResources.FromString(node_total_resources) + # Calculate total resources for this node. + node_resources = {} + for resource_id, capacity in message.resources_total.items(): + node_resources[resource_id] = capacity + # Update total resources for this node. + node_id = ray._private.utils.binary_to_hex(message.node_id) + total_resources_by_node[node_id] = node_resources + + return total_resources_by_node + + def available_resources(self): + """Get the current available cluster resources. + + This is different from `cluster_resources` in that this will return + idle (available) resources rather than total resources. + + Note that this information can grow stale as tasks start and finish. + + Returns: + A dictionary mapping resource name to the total quantity of that + resource in the cluster. Note that if a resource (e.g., "CPU") + is currently not available (i.e., quantity is 0), it will not + be included in this dictionary. + """ + self._check_connected() + + available_resources_by_id = self.available_resources_per_node() + + # Calculate total available resources. 
+ total_available_resources = defaultdict(int) + for available_resources in available_resources_by_id.values(): + for resource_id, num_available in available_resources.items(): + total_available_resources[resource_id] += num_available + + return dict(total_available_resources) + + def get_system_config(self): + """Get the system config of the cluster.""" + self._check_connected() + return json.loads(self.global_state_accessor.get_system_config()) + + def get_node_to_connect_for_driver(self, node_ip_address): + """Get the node to connect for a Ray driver.""" + self._check_connected() + return self.global_state_accessor.get_node_to_connect_for_driver( + node_ip_address + ) + + def get_node(self, node_id: str): + """Get the node information for a node id.""" + self._check_connected() + return self.global_state_accessor.get_node(node_id) + + def get_draining_nodes(self) -> Dict[str, int]: + """Get all the hex ids of nodes that are being drained + and the corresponding draining deadline timestamps in ms. + + There is no deadline if the timestamp is 0. + """ + self._check_connected() + return self.global_state_accessor.get_draining_nodes() + + def get_cluster_config(self) -> autoscaler_pb2.ClusterConfig: + """Get the cluster config of the current cluster.""" + self._check_connected() + serialized_cluster_config = self.global_state_accessor.get_internal_kv( + ray._raylet.GCS_AUTOSCALER_STATE_NAMESPACE.encode(), + ray._raylet.GCS_AUTOSCALER_CLUSTER_CONFIG_KEY.encode(), + ) + if serialized_cluster_config: + return autoscaler_pb2.ClusterConfig.FromString(serialized_cluster_config) + return None + + +state = GlobalState() +"""A global object used to access the cluster's global state.""" + + +def jobs(): + """Get a list of the jobs in the cluster (for debugging only). 
+ + Returns: + Information from the job table, namely a list of dicts with keys: + - "JobID" (identifier for the job), + - "DriverIPAddress" (IP address of the driver for this job), + - "DriverPid" (process ID of the driver for this job), + - "StartTime" (UNIX timestamp of the start time of this job), + - "StopTime" (UNIX timestamp of the stop time of this job, if any) + """ + return state.job_table() + + +def next_job_id(): + """Get next job id from GCS. + + Returns: + Next job id in integer representation in the cluster. + """ + return state.next_job_id() + + +@DeveloperAPI +@client_mode_hook +def nodes(): + """Get a list of the nodes in the cluster (for debugging only). + + Returns: + Information about the Ray clients in the cluster. + """ + return state.node_table() + + +def workers(): + """Get a list of the workers in the cluster. + + Returns: + Information about the Ray workers in the cluster. + """ + return state.workers() + + +def current_node_id(): + """Return the node id of the current node. + + For example, "node:172.10.5.34". This can be used as a custom resource, + e.g., {node_id: 1} to reserve the whole node, or {node_id: 0.001} to + just force placement on the node. + + Returns: + Id of the current node. + """ + return NODE_ID_PREFIX + ray.util.get_node_ip_address() + + +def node_ids(): + """Get a list of the node ids in the cluster. + + For example, ["node:172.10.5.34", "node:172.42.3.77"]. These can be used + as custom resources, e.g., {node_id: 1} to reserve the whole node, or + {node_id: 0.001} to just force placement on the node. + + Returns: + List of the node resource ids. 
+ """ + node_ids = [] + for node_total_resources in state.total_resources_per_node().values(): + for resource_id in node_total_resources.keys(): + if ( + resource_id.startswith(NODE_ID_PREFIX) + and resource_id != HEAD_NODE_RESOURCE_NAME + ): + node_ids.append(resource_id) + return node_ids + + +def actors( + actor_id: str = None, job_id: ray.JobID = None, actor_state_name: str = None +): + """Fetch actor info for one or more actor IDs (for debugging only). + + Args: + actor_id: A hex string of the actor ID to fetch information about. If + this is None, then all actor information is fetched. + If this is not None, `job_id` and `actor_state_name` + will not take effect. + job_id: To filter actors by job_id, which is of type `ray.JobID`. + You can use the `ray.get_runtime_context().job_id` function + to get the current job ID + actor_state_name: To filter actors based on actor state, + which can be one of the following: "DEPENDENCIES_UNREADY", + "PENDING_CREATION", "ALIVE", "RESTARTING", or "DEAD". + Returns: + Information about the actors. + """ + return state.actor_table( + actor_id=actor_id, job_id=job_id, actor_state_name=actor_state_name + ) + + +@DeveloperAPI +@client_mode_hook +def timeline(filename=None): + """Return a list of profiling events that can viewed as a timeline. + + Ray profiling must be enabled by setting the RAY_PROFILING=1 environment + variable prior to starting Ray, and set RAY_task_events_report_interval_ms=0 + + To view this information as a timeline, simply dump it as a json file by + passing in "filename" or using json.dump, and then load go to + chrome://tracing in the Chrome web browser and load the dumped file. + + Args: + filename: If a filename is provided, the timeline is dumped to that + file. + + Returns: + If filename is not provided, this returns a list of profiling events. + Each profile event is a dictionary. 
+ """ + return state.chrome_tracing_dump(filename=filename) + + +def object_transfer_timeline(filename=None): + """Return a list of transfer events that can viewed as a timeline. + + To view this information as a timeline, simply dump it as a json file by + passing in "filename" or using json.dump, and then load go to + chrome://tracing in the Chrome web browser and load the dumped file. Make + sure to enable "Flow events" in the "View Options" menu. + + Args: + filename: If a filename is provided, the timeline is dumped to that + file. + + Returns: + If filename is not provided, this returns a list of profiling events. + Each profile event is a dictionary. + """ + return state.chrome_tracing_object_transfer_dump(filename=filename) + + +@DeveloperAPI +@client_mode_hook +def cluster_resources(): + """Get the current total cluster resources. + + Note that this information can grow stale as nodes are added to or removed + from the cluster. + + Returns: + A dictionary mapping resource name to the total quantity of that + resource in the cluster. + """ + return state.cluster_resources() + + +@DeveloperAPI +@client_mode_hook +def available_resources(): + """Get the current available cluster resources. + + This is different from `cluster_resources` in that this will return idle + (available) resources rather than total resources. + + Note that this information can grow stale as tasks start and finish. + + Returns: + A dictionary mapping resource name to the total quantity of that + resource in the cluster. Note that if a resource (e.g., "CPU") + is currently not available (i.e., quantity is 0), it will not + be included in this dictionary. + """ + return state.available_resources() + + +@DeveloperAPI +def available_resources_per_node(): + """Get the current available resources of each live node. + + Note that this information can grow stale as tasks start and finish. + + Returns: + A dictionary mapping node hex id to available resources dictionary. 
+ """ + + return state.available_resources_per_node() + + +@DeveloperAPI +def total_resources_per_node(): + """Get the current total resources of each live node. + + Note that this information can grow stale as tasks start and finish. + + Returns: + A dictionary mapping node hex id to total resources dictionary. + """ + + return state.total_resources_per_node() + + +def update_worker_debugger_port(worker_id, debugger_port): + """Update the debugger port of a worker. + + Args: + worker_id: ID of this worker. Type is bytes. + debugger_port: Port of the debugger. Type is int. + + Returns: + Is operation success + """ + return state.update_worker_debugger_port(worker_id, debugger_port) + + +def update_worker_num_paused_threads(worker_id, num_paused_threads_delta): + """Update the number of paused threads of a worker. + + Args: + worker_id: ID of this worker. Type is bytes. + num_paused_threads_delta: The delta of the number of paused threads. + + Returns: + Is operation success + """ + return state.update_worker_num_paused_threads(worker_id, num_paused_threads_delta) + + +def get_worker_debugger_port(worker_id): + """Get the debugger port of a worker. + + Args: + worker_id: ID of this worker. Type is bytes. + + Returns: + Debugger port of the worker. 
+ """ + return state.get_worker_debugger_port(worker_id) diff --git a/.venv/lib/python3.11/site-packages/ray/_private/state_api_test_utils.py b/.venv/lib/python3.11/site-packages/ray/_private/state_api_test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..14f494c0e3ad889f41e483824c86a2885cbda4d4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/state_api_test_utils.py @@ -0,0 +1,499 @@ +import asyncio +import sys +from copy import deepcopy +from collections import defaultdict +import concurrent.futures +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +import logging +import numpy as np +import pprint +import time +import traceback +from typing import Callable, Dict, List, Optional, Tuple, Union +from ray.util.state import list_tasks +import ray +from ray.actor import ActorHandle +from ray.util.state import list_workers + +from ray._private.gcs_utils import GcsAioClient, GcsChannel +from ray.util.state.state_manager import StateDataSourceClient +from ray.dashboard.state_aggregator import ( + StateAPIManager, +) +from ray.util.state.common import ( + DEFAULT_LIMIT, + DEFAULT_RPC_TIMEOUT, + ListApiOptions, + PredicateType, + SupportedFilterType, +) + + +@dataclass +class StateAPIMetric: + latency_sec: float + result_size: int + + +@dataclass +class StateAPICallSpec: + api: Callable + verify_cb: Callable + kwargs: Dict = field(default_factory=dict) + + +@dataclass +class StateAPIStats: + pending_calls: int = 0 + total_calls: int = 0 + calls: Dict = field(default_factory=lambda: defaultdict(list)) + + +GLOBAL_STATE_STATS = StateAPIStats() + +STATE_LIST_LIMIT = int(1e6) # 1m +STATE_LIST_TIMEOUT = 600 # 10min + + +def invoke_state_api( + verify_cb: Callable, + state_api_fn: Callable, + state_stats: StateAPIStats = GLOBAL_STATE_STATS, + key_suffix: Optional[str] = None, + print_result: Optional[bool] = False, + err_msg: Optional[str] = None, + **kwargs, +): + """Invoke a State API + 
+ Args: + - verify_cb: Callback that takes in the response from `state_api_fn` and + returns a boolean, indicating the correctness of the results. + - state_api_fn: Function of the state API + - state_stats: Stats + - kwargs: Keyword arguments to be forwarded to the `state_api_fn` + """ + if "timeout" not in kwargs: + kwargs["timeout"] = STATE_LIST_TIMEOUT + + # Suppress missing output warning + kwargs["raise_on_missing_output"] = False + + res = None + try: + state_stats.total_calls += 1 + state_stats.pending_calls += 1 + + t_start = time.perf_counter() + res = state_api_fn(**kwargs) + t_end = time.perf_counter() + + if print_result: + pprint.pprint(res) + + metric = StateAPIMetric(t_end - t_start, len(res)) + if key_suffix: + key = f"{state_api_fn.__name__}_{key_suffix}" + else: + key = state_api_fn.__name__ + state_stats.calls[key].append(metric) + assert verify_cb( + res + ), f"Calling State API failed. len(res)=({len(res)}): {err_msg}" + except Exception as e: + traceback.print_exc() + assert ( + False + ), f"Calling {state_api_fn.__name__}({kwargs}) failed with {repr(e)}." 
+ finally: + state_stats.pending_calls -= 1 + + return res + + +def aggregate_perf_results(state_stats: StateAPIStats = GLOBAL_STATE_STATS): + """Aggregate stats of state API calls + + Return: + This returns a dict of below fields: + - max_{api_key_name}_latency_sec: + Max latency of call to {api_key_name} + - {api_key_name}_result_size_with_max_latency: + The size of the result (or the number of bytes for get_log API) + for the max latency invocation + - avg/p99/p95/p50_{api_key_name}_latency_sec: + The percentile latency stats + - avg_state_api_latency_sec: + The average latency of all the state apis tracked + """ + # Prevent iteration when modifying error + state_stats = deepcopy(state_stats) + perf_result = {} + for api_key_name, metrics in state_stats.calls.items(): + # Per api aggregation + # Max latency + latency_key = f"max_{api_key_name}_latency_sec" + size_key = f"{api_key_name}_result_size_with_max_latency" + metric = max(metrics, key=lambda metric: metric.latency_sec) + + perf_result[latency_key] = metric.latency_sec + perf_result[size_key] = metric.result_size + + latency_list = np.array([metric.latency_sec for metric in metrics]) + # avg latency + key = f"avg_{api_key_name}_latency_sec" + perf_result[key] = np.average(latency_list) + + # p99 latency + key = f"p99_{api_key_name}_latency_sec" + perf_result[key] = np.percentile(latency_list, 99) + + # p95 latency + key = f"p95_{api_key_name}_latency_sec" + perf_result[key] = np.percentile(latency_list, 95) + + # p50 latency + key = f"p50_{api_key_name}_latency_sec" + perf_result[key] = np.percentile(latency_list, 50) + + all_state_api_latency = sum( + metric.latency_sec + for metric_samples in state_stats.calls.values() + for metric in metric_samples + ) + + perf_result["avg_state_api_latency_sec"] = ( + (all_state_api_latency / state_stats.total_calls) + if state_stats.total_calls != 0 + else -1 + ) + + return perf_result + + +@ray.remote(num_cpus=0) +class StateAPIGeneratorActor: + def __init__( + 
self, + apis: List[StateAPICallSpec], + call_interval_s: float = 5.0, + print_interval_s: float = 20.0, + wait_after_stop: bool = True, + print_result: bool = False, + ) -> None: + """An actor that periodically issues state API + + Args: + - apis: List of StateAPICallSpec + - call_interval_s: State apis in the `apis` will be issued + every `call_interval_s` seconds. + - print_interval_s: How frequent state api stats will be dumped. + - wait_after_stop: When true, call to `ray.get(actor.stop.remote())` + will wait for all pending state APIs to return. + Setting it to `False` might miss some long-running state apis calls. + - print_result: True if result of each API call is printed. Default False. + """ + # Configs + self._apis = apis + self._call_interval_s = call_interval_s + self._print_interval_s = print_interval_s + self._wait_after_cancel = wait_after_stop + self._logger = logging.getLogger(self.__class__.__name__) + self._print_result = print_result + + # States + self._tasks = None + self._fut_queue = None + self._executor = None + self._loop = None + self._stopping = False + self._stopped = False + self._stats = StateAPIStats() + + async def start(self): + # Run the periodic api generator + self._fut_queue = asyncio.Queue() + self._executor = concurrent.futures.ThreadPoolExecutor() + + self._tasks = [ + asyncio.ensure_future(awt) + for awt in [ + self._run_generator(), + self._run_result_waiter(), + self._run_stats_reporter(), + ] + ] + await asyncio.gather(*self._tasks) + + def call(self, fn, verify_cb, **kwargs): + def run_fn(): + try: + self._logger.debug(f"calling {fn.__name__}({kwargs})") + return invoke_state_api( + verify_cb, + fn, + state_stats=self._stats, + print_result=self._print_result, + **kwargs, + ) + except Exception as e: + self._logger.warning(f"{fn.__name__}({kwargs}) failed with: {repr(e)}") + return None + + fut = asyncio.get_running_loop().run_in_executor(self._executor, run_fn) + return fut + + async def _run_stats_reporter(self): + 
while not self._stopped: + # Keep the reporter running until all pending apis finish and the bool + # `self._stopped` is then True + self._logger.info(pprint.pprint(aggregate_perf_results(self._stats))) + try: + await asyncio.sleep(self._print_interval_s) + except asyncio.CancelledError: + self._logger.info( + "_run_stats_reporter cancelled, " + f"waiting for all api {self._stats.pending_calls}calls to return..." + ) + + async def _run_generator(self): + try: + while not self._stopping: + # Run the state API in another thread + for api_spec in self._apis: + fut = self.call(api_spec.api, api_spec.verify_cb, **api_spec.kwargs) + self._fut_queue.put_nowait(fut) + + await asyncio.sleep(self._call_interval_s) + except asyncio.CancelledError: + # Stop running + self._logger.info("_run_generator cancelled, now stopping...") + return + + async def _run_result_waiter(self): + try: + while not self._stopping: + fut = await self._fut_queue.get() + await fut + except asyncio.CancelledError: + self._logger.info( + f"_run_result_waiter cancelled, cancelling {self._fut_queue.qsize()} " + "pending futures..." + ) + while not self._fut_queue.empty(): + fut = self._fut_queue.get_nowait() + if self._wait_after_cancel: + await fut + else: + # Ignore the queue futures if we are not + # waiting on them after stop() called + fut.cancel() + return + + def get_stats(self): + # deep copy to prevent race between reporting and modifying stats + return aggregate_perf_results(self._stats) + + def ready(self): + pass + + def stop(self): + self._stopping = True + self._logger.debug(f"calling stop, canceling {len(self._tasks)} tasks") + for task in self._tasks: + task.cancel() + + # This will block the stop() function until all futures are cancelled + # if _wait_after_cancel=True. When _wait_after_cancel=False, it will still + # wait for any in-progress futures. 
+ # See: https://docs.python.org/3.8/library/concurrent.futures.html + self._executor.shutdown(wait=self._wait_after_cancel) + self._stopped = True + + +def periodic_invoke_state_apis_with_actor(*args, **kwargs) -> ActorHandle: + current_node_ip = ray._private.worker.global_worker.node_ip_address + # Schedule the actor on the current node. + actor = StateAPIGeneratorActor.options( + resources={f"node:{current_node_ip}": 0.001} + ).remote(*args, **kwargs) + print("Waiting for state api actor to be ready...") + ray.get(actor.ready.remote()) + print("State api actor is ready now.") + actor.start.remote() + return actor + + +def get_state_api_manager(gcs_address: str) -> StateAPIManager: + gcs_aio_client = GcsAioClient(address=gcs_address) + gcs_channel = GcsChannel(gcs_address=gcs_address, aio=True) + gcs_channel.connect() + state_api_data_source_client = StateDataSourceClient( + gcs_channel.channel(), gcs_aio_client + ) + return StateAPIManager( + state_api_data_source_client, + thread_pool_executor=ThreadPoolExecutor( + thread_name_prefix="state_api_test_utils" + ), + ) + + +def summarize_worker_startup_time(): + workers = list_workers( + detail=True, + filters=[("worker_type", "=", "WORKER")], + limit=10000, + raise_on_missing_output=False, + ) + time_to_launch = [] + time_to_initialize = [] + for worker in workers: + launch_time = worker.get("worker_launch_time_ms") + launched_time = worker.get("worker_launched_time_ms") + start_time = worker.get("start_time_ms") + + if launched_time > 0: + time_to_launch.append(launched_time - launch_time) + if start_time: + time_to_initialize.append(start_time - launched_time) + time_to_launch.sort() + time_to_initialize.sort() + + def print_latencies(latencies): + print(f"Avg: {round(sum(latencies) / len(latencies), 2)} ms") + print(f"P25: {round(latencies[int(len(latencies) * 0.25)], 2)} ms") + print(f"P50: {round(latencies[int(len(latencies) * 0.5)], 2)} ms") + print(f"P95: {round(latencies[int(len(latencies) * 0.95)], 2)} 
ms") + print(f"P99: {round(latencies[int(len(latencies) * 0.99)], 2)} ms") + + print("Time to launch workers") + print_latencies(time_to_launch) + print("=======================") + print("Time to initialize workers") + print_latencies(time_to_initialize) + + +def verify_failed_task( + name: str, error_type: str, error_message: Union[str, List[str]] +) -> bool: + """ + Check if a task with 'name' has failed with the exact error type 'error_type' + and 'error_message' in the error message. + """ + tasks = list_tasks(filters=[("name", "=", name)], detail=True) + assert len(tasks) == 1, tasks + t = tasks[0] + assert t["state"] == "FAILED", t + assert t["error_type"] == error_type, t + if isinstance(error_message, str): + error_message = [error_message] + for msg in error_message: + assert msg in t.get("error_message", None), t + return True + + +@ray.remote +class PidActor: + def __init__(self): + self.name_to_pid = {} + + def get_pids(self): + return self.name_to_pid + + def report_pid(self, name, pid, state=None): + self.name_to_pid[name] = (pid, state) + + +def verify_tasks_running_or_terminated( + task_pids: Dict[str, Tuple[int, Optional[str]]], expect_num_tasks: int +): + """ + Check if the tasks in task_pids are in RUNNING state if pid exists + and running the task. + If the pid is missing or the task is not running the task, check if the task + is marked FAILED or FINISHED. + + Args: + task_pids: A dict of task name to (pid, expected terminal state). + + """ + import psutil + + assert len(task_pids) == expect_num_tasks, task_pids + for task_name, pid_and_state in task_pids.items(): + tasks = list_tasks(detail=True, filters=[("name", "=", task_name)]) + assert len(tasks) == 1, ( + f"One unique task with {task_name} should be found. " + "Use `options(name=)` when creating the task." 
+ ) + task = tasks[0] + pid, expected_state = pid_and_state + + # If it's windows/macos, we don't have a way to check if the process + # is actually running the task since the process name is just python, + # rather than the actual task name. + if sys.platform in ["win32", "darwin"]: + if expected_state is not None: + assert task["state"] == expected_state, task + continue + if psutil.pid_exists(pid) and task_name in psutil.Process(pid).name(): + assert ( + "ray::IDLE" not in task["name"] + ), "One should not name it 'IDLE' since it's reserved in Ray" + assert task["state"] == "RUNNING", task + if expected_state is not None: + assert task["state"] == expected_state, task + else: + # Tasks no longer running. + if expected_state is None: + assert task["state"] in [ + "FAILED", + "FINISHED", + ], f"{task_name}: {task['task_id']} = {task['state']}" + else: + assert ( + task["state"] == expected_state + ), f"expect {expected_state} but {task['state']} for {task}" + + return True + + +def verify_schema(state, result_dict: dict, detail: bool = False): + """ + Verify the schema of the result_dict is the same as the state. 
+ """ + state_fields_columns = set() + if detail: + state_fields_columns = state.columns() + else: + state_fields_columns = state.base_columns() + + for k in state_fields_columns: + assert k in result_dict + + for k in result_dict: + assert k in state_fields_columns + + # Make the field values can be converted without error as well + state(**result_dict) + + +def create_api_options( + timeout: int = DEFAULT_RPC_TIMEOUT, + limit: int = DEFAULT_LIMIT, + filters: List[Tuple[str, PredicateType, SupportedFilterType]] = None, + detail: bool = False, + exclude_driver: bool = True, +): + if not filters: + filters = [] + return ListApiOptions( + limit=limit, + timeout=timeout, + filters=filters, + server_timeout_multiplier=1.0, + detail=detail, + exclude_driver=exclude_driver, + ) diff --git a/.venv/lib/python3.11/site-packages/ray/_private/storage.py b/.venv/lib/python3.11/site-packages/ray/_private/storage.py new file mode 100644 index 0000000000000000000000000000000000000000..56f6987b88e946024f7e1bd5490709e6587e53ad --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/storage.py @@ -0,0 +1,491 @@ +import os +import re +import urllib +from pathlib import Path +from typing import TYPE_CHECKING, List, Optional + +from ray._private.client_mode_hook import client_mode_hook +from ray._private.utils import _add_creatable_buckets_param_if_s3_uri, load_class +from ray._private.auto_init_hook import wrap_auto_init + +if TYPE_CHECKING: + import pyarrow.fs + + +# The full storage argument specified, e.g., in ``ray.init(storage="s3://foo/bar")`` +# This is set immediately on Ray worker init. +_storage_uri = None + +# The storage prefix, e.g., "foo/bar" under which files should be written. +# This is set lazily the first time storage is accessed on a worker. +_storage_prefix = None + +# The pyarrow.fs.FileSystem instantiated for the storage. +# This is set lazily the first time storage is accessed on a worker. 
+_filesystem = None + + +@wrap_auto_init +@client_mode_hook +def get_filesystem() -> ("pyarrow.fs.FileSystem", str): + """Initialize and get the configured storage filesystem, if possible. + + This method can be called from any Ray worker to get a reference to the configured + storage filesystem. + + Examples: + .. testcode:: + + import ray + from ray._private import storage + + ray.shutdown() + + ray.init(storage="/tmp/storage/cluster_1/storage") + fs, path = storage.get_filesystem() + print(fs) + print(path) + + .. testoutput:: + + + /tmp/storage/cluster_1/storage + + Returns: + Tuple of pyarrow filesystem instance and the path under which files should + be created for this cluster. + + Raises: + RuntimeError if storage has not been configured or init failed. + """ + return _get_filesystem_internal() + + +# TODO(suquark): There is no implementation of 'get_client' in client hook. +@wrap_auto_init +@client_mode_hook +def get_client(prefix: str) -> "KVClient": + """Returns a KV-client (convenience wrapper around underlying filesystem). + + Args: + prefix: Path prefix (e.g., "foo", "foo/bar") that defines the sub-directory + data will be stored under. All writes will be scoped to this sub-dir. + + Examples: + .. testcode:: + + import ray + from ray._private import storage + + ray.shutdown() + + ray.init(storage="/tmp/storage/cluster_1/storage") + client = storage.get_client("foo") + client.put("foo", b"bar") + + Returns: + KVClient. + """ + if not prefix: + raise ValueError("A directory prefix must be specified.") + fs, base_prefix = get_filesystem() + combined_prefix = os.path.join(base_prefix, prefix) + return KVClient(fs, combined_prefix) + + +def _is_os_error_file_not_found(err: OSError) -> bool: + """Instead of "FileNotFoundError", pyarrow S3 filesystem raises + OSError starts with "Path does not exist" for some of its APIs. + + # TODO(suquark): Delete this function after pyarrow handles missing files + in a consistent way. 
+ """ + return ( + len(err.args) > 0 + and isinstance(err.args[0], str) + and err.args[0].startswith("Path does not exist") + ) + + +class KVClient: + """Simple KV API built on the underlying filesystem. + + This is a convenience wrapper around get_filesystem() and working with files. + Slashes in the path are interpreted as directory delimiters. + """ + + def __init__(self, fs: "pyarrow.fs.FileSystem", prefix: str): + """Use storage.get_client() to construct KVClient.""" + self.fs = fs + self.root = Path(prefix) + + def put(self, path: str, value: bytes) -> None: + """Save a blob in persistent storage at the given path, if possible. + + Examples: + .. testcode:: + + import ray + from ray._private import storage + + ray.shutdown() + + ray.init(storage="/tmp/storage/cluster_1/storage") + client = storage.get_client("my_app") + client.put("path/foo.txt", b"bar") + + Args: + path: Relative directory of the blobs. + value: String value to save. + """ + full_path = self._resolve_path(path) + parent_dir = os.path.dirname(full_path) + try: + with self.fs.open_output_stream(full_path) as f: + f.write(value) + except FileNotFoundError: + # Directory likely doesn't exist; retry after creating it. + self.fs.create_dir(parent_dir) + with self.fs.open_output_stream(full_path) as f: + f.write(value) + + def get(self, path: str) -> bytes: + """Load a blob from persistent storage at the given path, if possible. + + Examples: + .. testcode:: + + import ray + from ray._private import storage + + ray.shutdown() + + ray.init(storage="/tmp/storage/cluster_1/storage") + + client = storage.get_client("my_app") + client.put("path/foo.txt", b"bar") + assert client.get("path/foo.txt") == b"bar" + assert client.get("invalid") is None + + Args: + path: Relative directory of the blobs. + + Returns: + String content of the blob, or None if not found. 
+ """ + full_path = self._resolve_path(path) + try: + with self.fs.open_input_stream(full_path) as f: + return f.read() + except FileNotFoundError: + return None + except OSError as e: + if _is_os_error_file_not_found(e): + return None + raise e + + def delete(self, path: str) -> bool: + """Load the blob from persistent storage at the given path, if possible. + + Examples: + .. testcode:: + + import ray + from ray._private import storage + + ray.shutdown() + + ray.init(storage="/tmp/storage/cluster_1/storage") + + client = storage.get_client("my_app") + client.put("path/foo.txt", b"bar") + assert client.delete("path/foo.txt") + + Args: + path: Relative directory of the blob. + + Returns: + Whether the blob was deleted. + """ + full_path = self._resolve_path(path) + try: + self.fs.delete_file(full_path) + return True + except FileNotFoundError: + return False + except OSError as e: + if _is_os_error_file_not_found(e): + return False + raise e + + def delete_dir(self, path: str) -> bool: + """Delete a directory and its contents, recursively. + + Examples: + .. testcode:: + + import ray + from ray._private import storage + + ray.shutdown() + + ray.init(storage="/tmp/storage/cluster_1/storage") + + client = storage.get_client("my_app") + client.put("path/foo.txt", b"bar") + assert client.delete_dir("path") + + Args: + path: Relative directory of the blob. + + Returns: + Whether the dir was deleted. + """ + full_path = self._resolve_path(path) + try: + self.fs.delete_dir(full_path) + return True + except FileNotFoundError: + return False + except OSError as e: + if _is_os_error_file_not_found(e): + return False + raise e + + def get_info(self, path: str) -> Optional["pyarrow.fs.FileInfo"]: + """Get info about the persistent blob at the given path, if possible. + + Examples: + .. 
testcode:: + + import ray + from ray._private import storage + + ray.shutdown() + + ray.init(storage="/tmp/storage/cluster_1/storage") + + client = storage.get_client("my_app") + client.put("path/foo.txt", b"bar") + + print(client.get_info("path/foo.txt")) + + print(client.get_info("path/does_not_exist.txt")) + + .. testoutput:: + + + None + + Args: + path: Relative directory of the blob. + + Returns: + Info about the blob, or None if it doesn't exist. + """ + import pyarrow.fs + + full_path = self._resolve_path(path) + info = self.fs.get_file_info([full_path])[0] + if info.type == pyarrow.fs.FileType.NotFound: + return None + return info + + def list( + self, + path: str, + ) -> List["pyarrow.fs.FileInfo"]: + """List blobs and sub-dirs in the given path, if possible. + + Examples: + + >>> import ray + >>> from ray._private import storage + >>> ray.shutdown() + + Normal usage. + + >>> ray.init(storage="/tmp/storage/cluster_1/storage") + RayContext(...) + >>> client = storage.get_client("my_app") + >>> client.put("path/foo.txt", b"bar") + >>> client.list("path") + [] + + Non-existent path. + + >>> client.list("does_not_exist") + Traceback (most recent call last): + ... + FileNotFoundError: ... No such file or directory + + Not a directory. + + >>> client.list("path/foo.txt") + Traceback (most recent call last): + ... + NotADirectoryError: ... Not a directory + + Args: + path: Relative directory to list from. + + Returns: + List of file-info objects for the directory contents. + + Raises: + FileNotFoundError if the given path is not found. + NotADirectoryError if the given path isn't a valid directory. 
+ """ + from pyarrow.fs import FileSelector, FileType, LocalFileSystem + + full_path = self._resolve_path(path) + selector = FileSelector(full_path, recursive=False) + try: + files = self.fs.get_file_info(selector) + except FileNotFoundError as e: + raise e + except OSError as e: + if _is_os_error_file_not_found(e): + raise FileNotFoundError(*e.args) + raise e + if self.fs is not LocalFileSystem and not files: + # TODO(suquark): pyarrow does not raise "NotADirectoryError" + # for non-local filesystems like S3. Check and raise it here. + info = self.fs.get_file_info([full_path])[0] + if info.type == FileType.File: + raise NotADirectoryError( + f"Cannot list directory '{full_path}'. " + f"Detail: [errno 20] Not a directory" + ) + return files + + def _resolve_path(self, path: str) -> str: + from pyarrow.fs import LocalFileSystem + + if isinstance(self.fs, LocalFileSystem): + joined = self.root.joinpath(path).resolve() + # Raises an error if the path is above the root (e.g., "../data" attack). + joined.relative_to(self.root.resolve()) + return str(joined) + + # In this case, we are not a local file system. However, pathlib would + # still add prefix to the path as if it is a local path when resolving + # the path, even when the path does not exist at all. If the path exists + # locally and is a symlink, then pathlib resolves it to the unwanted + # physical path. This could leak to an attack. Third, if the path was + # under Windows, "/" becomes "\", which is invalid for non-local stores. + # So we decide to resolve it mannually. 
+ def _normalize_path(p: str) -> str: + # "////bucket//go/./foo///..//.././/bar/./" becomes "bucket/bar" + segments = [] + for s in p.replace("\\", "/").split("/"): + if s == "..": + if not segments: + raise ValueError("Path goes beyond root.") + segments.pop() + elif s not in (".", ""): + segments.append(s) + return "/".join(segments) + + root = _normalize_path(str(self.root)) + joined = _normalize_path(str(self.root.joinpath(path))) + if not joined.startswith(root): + raise ValueError(f"{joined!r} does not start with {root!r}") + return joined + + +def _init_storage(storage_uri: str, is_head: bool): + """Init global storage. + + On the head (ray start) process, this also creates a _valid file under the given + storage path to validate the storage is writable. This file is also checked on each + worker process to validate the storage is readable. This catches common errors + like using a non-NFS filesystem path on a multi-node cluster. + + On worker nodes, the actual filesystem is lazily initialized on first use. + """ + global _storage_uri + + if storage_uri: + _storage_uri = storage_uri + if is_head: + _init_filesystem(create_valid_file=True) + + +def _get_storage_uri() -> Optional[str]: + """Get storage API, if configured.""" + global _storage_uri + return _storage_uri + + +def _get_filesystem_internal() -> ("pyarrow.fs.FileSystem", str): + """Internal version of get_filesystem() that doesn't hit Ray client hooks. + + This forces full (non-lazy) init of the filesystem. + """ + global _filesystem, _storage_prefix + if _filesystem is None: + _init_filesystem() + return _filesystem, _storage_prefix + + +def _init_filesystem(create_valid_file: bool = False, check_valid_file: bool = True): + """Fully initialize the filesystem at the given storage URI.""" + global _filesystem, _storage_prefix, _storage_uri + assert _filesystem is None, "Init can only be called once." + + if not _storage_uri: + raise RuntimeError( + "No storage URI has been configured for the cluster. 
" + "Specify a storage URI via `ray.init(storage=)` or " + "`ray start --head --storage=`" + ) + + import pyarrow.fs + + # TODO(suquark): This is a temporary patch for windows - the backslash + # could not be understood by pyarrow. We replace it with slash here. + parsed_uri = urllib.parse.urlparse(_storage_uri.replace("\\", "/")) + if parsed_uri.scheme == "custom": + fs_creator = _load_class(parsed_uri.netloc) + _filesystem, _storage_prefix = fs_creator(parsed_uri.path) + else: + # Arrow's S3FileSystem doesn't allow creating buckets by default, so we add a + # query arg enabling bucket creation if an S3 URI is provided. + _storage_uri = _add_creatable_buckets_param_if_s3_uri(_storage_uri) + _filesystem, _storage_prefix = pyarrow.fs.FileSystem.from_uri(_storage_uri) + + if os.name == "nt": + # Special care for windows. "//C/windows/system32" is a valid network + # name many applications support, but unfortunately not by pyarrow. + # This formats "//C/windows/system32" to "C:/windows/system32". + if re.match("^//[A-Za-z]/.*", _storage_prefix): + _storage_prefix = _storage_prefix[2] + ":" + _storage_prefix[4:] + + # enforce use of "/" + valid_file = _storage_prefix + "/_valid" + if create_valid_file: + _filesystem.create_dir(_storage_prefix) + with _filesystem.open_output_stream(valid_file): + pass + if check_valid_file: + valid = _filesystem.get_file_info([valid_file])[0] + if valid.type == pyarrow.fs.FileType.NotFound: + raise RuntimeError( + "Unable to initialize storage: {} file created during init not found. " + "Check that configured cluster storage path is readable from all " + "worker nodes of the cluster.".format(valid_file) + ) + + return _filesystem, _storage_prefix + + +def _reset() -> None: + """Resets all initialized state to None.""" + global _storage_uri, _filesystem, _storage_prefix + _storage_uri = _filesystem = _storage_prefix = None + + +# TODO(ekl): remove this indirection. 
def _load_class(path):
    """Import and return the object named by the fully-qualified dotted ``path``.

    Thin indirection over the shared ``load_class`` utility.
    """
    return load_class(path)
def redis_replicas():
    """Return the number of external Redis replicas to use in tests.

    Controlled by the ``TEST_EXTERNAL_REDIS_REPLICAS`` environment variable;
    defaults to 1 (no Redis clustering). The redundant function-local
    ``import os`` was removed — ``os`` is imported at module scope.
    """
    return int(os.environ.get("TEST_EXTERNAL_REDIS_REPLICAS", "1"))
def get_redis_cli(port, enable_tls):
    """Build a ``redis.Redis`` client for localhost on ``port``.

    Returns ``True`` (a sentinel meaning "skip the check") when the ``redis``
    package is not installed, e.g. in minimal-dependency test environments.

    Args:
        port: Redis server port (int or numeric string).
        enable_tls: If True, configure client TLS from Ray's ``Config``
            certificate settings.
    """
    try:
        # If there is no redis libs installed, skip the check.
        # This could happen in minimal tests, where we don't have redis.
        import redis
    except Exception:
        return True

    params = {}
    if enable_tls:
        # ``Config`` is already imported at module scope; the previous local
        # re-import was redundant.
        params = {"ssl": True, "ssl_cert_reqs": "required"}
        if Config.REDIS_CA_CERT():
            params["ssl_ca_certs"] = Config.REDIS_CA_CERT()
        if Config.REDIS_CLIENT_CERT():
            params["ssl_certfile"] = Config.REDIS_CLIENT_CERT()
        if Config.REDIS_CLIENT_KEY():
            params["ssl_keyfile"] = Config.REDIS_CLIENT_KEY()

    # Pass the port as an int; redis-py expects a numeric port.
    return redis.Redis("localhost", int(port), **params)
config_lines.append(f"tls-key-file {Config.REDIS_CLIENT_KEY()}") + config_lines.append("tls-auth-clients no") + config_lines.append("sentinel tls-auth-clients redis-test no") + if db_dir: + config_lines.append(f"dir {db_dir}") + + with open(config_file, "w") as f: + f.write("\n".join(config_lines)) + + command = [REDIS_EXECUTABLE, config_file, "--sentinel"] + process_info = ray._private.services.start_ray_process( + command, + ray_constants.PROCESS_TYPE_REDIS_SERVER, + fate_share=False, + ) + return process_info + + +def start_redis_instance( + session_dir_path: str, + port: int, + redis_max_clients: Optional[int] = None, + num_retries: int = 20, + stdout_file: Optional[str] = None, + stderr_file: Optional[str] = None, + password: Optional[str] = None, + redis_max_memory: Optional[int] = None, + fate_share: Optional[bool] = None, + port_denylist: Optional[List[int]] = None, + listen_to_localhost_only: bool = False, + enable_tls: bool = False, + replica_of=None, + leader_id=None, + db_dir=None, + free_port=0, +): + """Start a single Redis server. + + Notes: + We will initially try to start the Redis instance at the given port, + and then try at most `num_retries - 1` times to start the Redis + instance at successive random ports. + + Args: + session_dir_path: Path to the session directory of + this Ray cluster. + port: Try to start a Redis server at this port. + redis_max_clients: If this is provided, Ray will attempt to configure + Redis with this maxclients number. + num_retries: The number of times to attempt to start Redis at + successive ports. + stdout_file: A file handle opened for writing to redirect stdout to. If + no redirection should happen, then this should be None. + stderr_file: A file handle opened for writing to redirect stderr to. If + no redirection should happen, then this should be None. + password: Prevents external clients without the password + from connecting to Redis if provided. 
+ redis_max_memory: The max amount of memory (in bytes) to allow redis + to use, or None for no limit. Once the limit is exceeded, redis + will start LRU eviction of entries. + port_denylist: A set of denylist ports that shouldn't + be used when allocating a new port. + listen_to_localhost_only: Redis server only listens to + localhost (127.0.0.1) if it's true, + otherwise it listens to all network interfaces. + enable_tls: Enable the TLS/SSL in Redis or not + + Returns: + A tuple of the port used by Redis and ProcessInfo for the process that + was started. If a port is passed in, then the returned port value + is the same. + + Raises: + Exception: An exception is raised if Redis could not be started. + """ + + assert os.path.isfile(REDIS_EXECUTABLE) + + # Construct the command to start the Redis server. + command = [REDIS_EXECUTABLE] + if password: + if " " in password: + raise ValueError("Spaces not permitted in redis password.") + command += ["--requirepass", password] + if redis_replicas() > 1: + command += ["--cluster-enabled", "yes", "--cluster-config-file", f"node-{port}"] + if enable_tls: + command += [ + "--tls-port", + str(port), + "--loglevel", + "warning", + "--port", + str(free_port), + ] + else: + command += ["--port", str(port), "--loglevel", "warning"] + + if listen_to_localhost_only: + command += ["--bind", "127.0.0.1"] + pidfile = os.path.join(session_dir_path, "redis-" + uuid.uuid4().hex + ".pid") + command += ["--pidfile", pidfile] + if enable_tls: + if Config.REDIS_CA_CERT(): + command += ["--tls-ca-cert-file", Config.REDIS_CA_CERT()] + if Config.REDIS_CLIENT_CERT(): + command += ["--tls-cert-file", Config.REDIS_CLIENT_CERT()] + if Config.REDIS_CLIENT_KEY(): + command += ["--tls-key-file", Config.REDIS_CLIENT_KEY()] + if replica_of is not None: + command += ["--tls-replication", "yes"] + command += ["--tls-auth-clients", "no", "--tls-cluster", "yes"] + if sys.platform != "win32": + command += ["--save", "", "--appendonly", "no"] + if db_dir is 
not None: + command += ["--dir", str(db_dir)] + + process_info = ray._private.services.start_ray_process( + command, + ray_constants.PROCESS_TYPE_REDIS_SERVER, + stdout_file=stdout_file, + stderr_file=stderr_file, + fate_share=fate_share, + ) + node_id = None + if redis_replicas() > 1: + # Setup redis cluster + import redis + + while True: + try: + redis_cli = get_redis_cli(port, enable_tls) + if replica_of is None: + slots = [str(i) for i in range(16384)] + redis_cli.cluster("addslots", *slots) + else: + logger.info(redis_cli.cluster("meet", "127.0.0.1", str(replica_of))) + logger.info(redis_cli.cluster("replicate", leader_id)) + node_id = redis_cli.cluster("myid") + break + except ( + redis.exceptions.ConnectionError, + redis.exceptions.ResponseError, + ) as e: + from time import sleep + + logger.info( + f"Waiting for redis to be up. Check failed with error: {e}. " + "Will retry in 0.1s" + ) + + if process_info.process.poll() is not None: + raise Exception( + f"Redis process exited unexpectedly: {process_info}. " + f"Exit code: {process_info.process.returncode}" + ) + + sleep(0.1) + + logger.info( + f"Redis started with node_id {node_id} and pid {process_info.process.pid}" + ) + + return node_id, process_info + + +def _pid_alive(pid): + """Check if the process with this PID is alive or not. + + Args: + pid: The pid to check. + + Returns: + This returns false if the process is dead. Otherwise, it returns true. 
+ """ + alive = True + try: + proc = psutil.Process(pid) + if proc.status() == psutil.STATUS_ZOMBIE: + alive = False + except psutil.NoSuchProcess: + alive = False + return alive + + +def check_call_module(main, argv, capture_stdout=False, capture_stderr=False): + # We use this function instead of calling the "ray" command to work around + # some deadlocks that occur when piping ray's output on Windows + stream = io.TextIOWrapper(io.BytesIO(), encoding=sys.stdout.encoding) + old_argv = sys.argv[:] + try: + sys.argv = argv[:] + try: + with redirect_stderr(stream if capture_stderr else sys.stderr): + with redirect_stdout(stream if capture_stdout else sys.stdout): + main() + finally: + stream.flush() + except SystemExit as ex: + if ex.code: + output = stream.buffer.getvalue() + raise subprocess.CalledProcessError(ex.code, argv, output) + except Exception as ex: + output = stream.buffer.getvalue() + raise subprocess.CalledProcessError(1, argv, output, ex.args[0]) + finally: + sys.argv = old_argv + if capture_stdout: + sys.stdout.buffer.write(stream.buffer.getvalue()) + elif capture_stderr: + sys.stderr.buffer.write(stream.buffer.getvalue()) + return stream.buffer.getvalue() + + +def check_call_subprocess(argv, capture_stdout=False, capture_stderr=False): + # We use this function instead of calling the "ray" command to work around + # some deadlocks that occur when piping ray's output on Windows + from ray.scripts.scripts import main as ray_main + + if sys.platform == "win32": + result = check_call_module( + ray_main, argv, capture_stdout=capture_stdout, capture_stderr=capture_stderr + ) + else: + stdout_redir = None + stderr_redir = None + if capture_stdout: + stdout_redir = subprocess.PIPE + if capture_stderr and capture_stdout: + stderr_redir = subprocess.STDOUT + elif capture_stderr: + stderr_redir = subprocess.PIPE + proc = subprocess.Popen(argv, stdout=stdout_redir, stderr=stderr_redir) + (stdout, stderr) = proc.communicate() + if proc.returncode: + raise 
def wait_for_pid_to_exit(pid, timeout=20):
    """Block until the process with ``pid`` has exited.

    Raises:
        RayTestTimeoutException: If the process is still alive after
            ``timeout`` seconds.
    """
    deadline = time.time() + timeout
    while time.time() < deadline:
        if not _pid_alive(pid):
            return
        time.sleep(0.1)
    raise RayTestTimeoutException(f"Timed out while waiting for process {pid} to exit.")
+ " Children still alive: {}.".format([p.name() for p in alive]) + ) + + +def kill_process_by_name(name, SIGKILL=False): + for p in psutil.process_iter(attrs=["name"]): + if p.info["name"] == name + ray._private.services.EXE_SUFFIX: + if SIGKILL: + p.kill() + else: + p.terminate() + + +def run_string_as_driver(driver_script: str, env: Dict = None, encode: str = "utf-8"): + """Run a driver as a separate process. + + Args: + driver_script: A string to run as a Python script. + env: The environment variables for the driver. + + Returns: + The script's output. + """ + proc = subprocess.Popen( + [sys.executable, "-"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + ) + with proc: + output = proc.communicate(driver_script.encode(encoding=encode))[0] + if proc.returncode: + print(ray._private.utils.decode(output, encode_type=encode)) + logger.error(proc.stderr) + raise subprocess.CalledProcessError( + proc.returncode, proc.args, output, proc.stderr + ) + out = ray._private.utils.decode(output, encode_type=encode) + return out + + +def run_string_as_driver_stdout_stderr( + driver_script: str, env: Dict = None, encode: str = "utf-8" +) -> Tuple[str, str]: + """Run a driver as a separate process. + + Args: + driver_script: A string to run as a Python script. + env: The environment variables for the driver. + + Returns: + The script's stdout and stderr. 
+ """ + proc = subprocess.Popen( + [sys.executable, "-"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + ) + with proc: + outputs_bytes = proc.communicate(driver_script.encode(encoding=encode)) + out_str, err_str = [ + ray._private.utils.decode(output, encode_type=encode) + for output in outputs_bytes + ] + if proc.returncode: + print(out_str) + print(err_str) + raise subprocess.CalledProcessError( + proc.returncode, proc.args, out_str, err_str + ) + return out_str, err_str + + +def run_string_as_driver_nonblocking(driver_script, env: Dict = None): + """Start a driver as a separate process and return immediately. + + Args: + driver_script: A string to run as a Python script. + + Returns: + A handle to the driver process. + """ + script = "; ".join( + [ + "import sys", + "script = sys.stdin.read()", + "sys.stdin.close()", + "del sys", + 'exec("del script\\n" + script)', + ] + ) + proc = subprocess.Popen( + [sys.executable, "-c", script], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + ) + proc.stdin.write(driver_script.encode("ascii")) + proc.stdin.close() + return proc + + +def convert_actor_state(state): + if not state: + return None + return gcs_pb2.ActorTableData.ActorState.DESCRIPTOR.values_by_number[state].name + + +def wait_for_num_actors(num_actors, state=None, timeout=10): + state = convert_actor_state(state) + start_time = time.time() + while time.time() - start_time < timeout: + if ( + len( + [ + _ + for _ in ray._private.state.actors().values() + if state is None or _["State"] == state + ] + ) + >= num_actors + ): + return + time.sleep(0.1) + raise RayTestTimeoutException("Timed out while waiting for global state.") + + +def wait_for_num_nodes(num_nodes: int, timeout_s: int): + curr_nodes = 0 + start = time.time() + next_feedback = start + max_time = start + timeout_s + while not curr_nodes >= num_nodes: + now = time.time() + + if now >= max_time: + raise RuntimeError( + 
f"Maximum wait time reached, but only " + f"{curr_nodes}/{num_nodes} nodes came up. Aborting." + ) + + if now >= next_feedback: + passed = now - start + print( + f"Waiting for more nodes to come up: " + f"{curr_nodes}/{num_nodes} " + f"({passed:.0f} seconds passed)" + ) + next_feedback = now + 10 + + time.sleep(5) + curr_nodes = len(ray.nodes()) + + passed = time.time() - start + print( + f"Cluster is up: {curr_nodes}/{num_nodes} nodes online after " + f"{passed:.0f} seconds" + ) + + +def kill_actor_and_wait_for_failure(actor, timeout=10, retry_interval_ms=100): + actor_id = actor._actor_id.hex() + current_num_restarts = ray._private.state.actors(actor_id)["NumRestarts"] + ray.kill(actor) + start = time.time() + while time.time() - start <= timeout: + actor_status = ray._private.state.actors(actor_id) + if ( + actor_status["State"] == convert_actor_state(gcs_utils.ActorTableData.DEAD) + or actor_status["NumRestarts"] > current_num_restarts + ): + return + time.sleep(retry_interval_ms / 1000.0) + raise RuntimeError("It took too much time to kill an actor: {}".format(actor_id)) + + +def wait_for_condition( + condition_predictor, + timeout=10, + retry_interval_ms=100, + raise_exceptions=False, + **kwargs: Any, +): + """Wait until a condition is met or time out with an exception. + + Args: + condition_predictor: A function that predicts the condition. + timeout: Maximum timeout in seconds. + retry_interval_ms: Retry interval in milliseconds. + raise_exceptions: If true, exceptions that occur while executing + condition_predictor won't be caught and instead will be raised. + + Raises: + RuntimeError: If the condition is not met before the timeout expires. 
+ """ + start = time.time() + last_ex = None + while time.time() - start <= timeout: + try: + if condition_predictor(**kwargs): + return + except Exception: + if raise_exceptions: + raise + last_ex = ray._private.utils.format_error_message(traceback.format_exc()) + time.sleep(retry_interval_ms / 1000.0) + message = "The condition wasn't met before the timeout expired." + if last_ex is not None: + message += f" Last exception: {last_ex}" + raise RuntimeError(message) + + +async def async_wait_for_condition( + condition_predictor, timeout=10, retry_interval_ms=100, **kwargs: Any +): + """Wait until a condition is met or time out with an exception. + + Args: + condition_predictor: A function that predicts the condition. + timeout: Maximum timeout in seconds. + retry_interval_ms: Retry interval in milliseconds. + + Raises: + RuntimeError: If the condition is not met before the timeout expires. + """ + start = time.time() + last_ex = None + while time.time() - start <= timeout: + try: + if inspect.iscoroutinefunction(condition_predictor): + if await condition_predictor(**kwargs): + return + else: + if condition_predictor(**kwargs): + return + except Exception as ex: + last_ex = ex + await asyncio.sleep(retry_interval_ms / 1000.0) + message = "The condition wasn't met before the timeout expired." + if last_ex is not None: + message += f" Last exception: {last_ex}" + raise RuntimeError(message) + + +async def async_wait_for_condition_async_predicate( + async_condition_predictor, timeout=10, retry_interval_ms=100, **kwargs: Any +): + """Wait until a condition is met or time out with an exception. + + Args: + condition_predictor: A function that predicts the condition. + timeout: Maximum timeout in seconds. + retry_interval_ms: Retry interval in milliseconds. + + Raises: + RuntimeError: If the condition is not met before the timeout expires. 
+ """ + start = time.time() + last_ex = None + while time.time() - start <= timeout: + try: + if await async_condition_predictor(**kwargs): + return + except Exception as ex: + last_ex = ex + await asyncio.sleep(retry_interval_ms / 1000.0) + message = "The condition wasn't met before the timeout expired." + if last_ex is not None: + message += f" Last exception: {last_ex}" + raise RuntimeError(message) + + +@dataclass +class MetricSamplePattern: + name: Optional[str] = None + value: Optional[str] = None + partial_label_match: Optional[Dict[str, str]] = None + + def matches(self, sample: Sample): + if self.name is not None: + if self.name != sample.name: + return False + + if self.value is not None: + if self.value != sample.value: + return False + + if self.partial_label_match is not None: + for label, value in self.partial_label_match.items(): + if sample.labels.get(label) != value: + return False + + return True + + +def get_metric_check_condition( + metrics_to_check: List[MetricSamplePattern], export_addr: Optional[str] = None +) -> Callable[[], bool]: + """A condition to check if a prometheus metrics reach a certain value. + This is a blocking check that can be passed into a `wait_for_condition` + style function. + + Args: + metrics_to_check: A list of MetricSamplePattern. The fields that + aren't `None` will be matched. + + Returns: + A function that returns True if all the metrics are emitted. 
+ + """ + node_info = ray.nodes()[0] + metrics_export_port = node_info["MetricsExportPort"] + addr = node_info["NodeManagerAddress"] + prom_addr = export_addr or f"{addr}:{metrics_export_port}" + + def f(): + for metric_pattern in metrics_to_check: + _, _, metric_samples = fetch_prometheus([prom_addr]) + for metric_sample in metric_samples: + if metric_pattern.matches(metric_sample): + break + else: + print( + f"Didn't find {metric_pattern}", + "all samples", + metric_samples, + ) + return False + return True + + return f + + +def wait_for_stdout(strings_to_match: List[str], timeout_s: int): + """Returns a decorator which waits until the stdout emitted + by a function contains the provided list of strings. + Raises an exception if the stdout doesn't have the expected output in time. + + Note: The decorated function should not block! + (It should return soon after being called.) + + Args: + strings_to_match: Wait until stdout contains all of these string. + timeout_s: Max time to wait, in seconds, before raising a RuntimeError. + """ + + def decorator(func): + @functools.wraps(func) + def decorated_func(*args, **kwargs): + success = False + try: + # Redirect stdout to an in-memory stream. + out_stream = io.StringIO() + sys.stdout = out_stream + # Execute the func. (Make sure the function doesn't block!) + out = func(*args, **kwargs) + # Check out_stream once a second until the timeout. + # Raise a RuntimeError if we timeout. + wait_for_condition( + # Does redirected stdout contain all of the expected strings? + lambda: all( + string in out_stream.getvalue() for string in strings_to_match + ), + timeout=timeout_s, + retry_interval_ms=1000, + ) + # out_stream has the expected strings + success = True + return out + # Exception raised on failure. + finally: + sys.stdout = sys.__stdout__ + if success: + print("Confirmed expected function stdout. Stdout follows:") + else: + print("Did not confirm expected function stdout. 
Stdout follows:") + print(out_stream.getvalue()) + out_stream.close() + + return decorated_func + + return decorator + + +def wait_until_succeeded_without_exception( + func, exceptions, *args, timeout_ms=1000, retry_interval_ms=100, raise_last_ex=False +): + """A helper function that waits until a given function + completes without exceptions. + + Args: + func: A function to run. + exceptions: Exceptions that are supposed to occur. + args: arguments to pass for a given func + timeout_ms: Maximum timeout in milliseconds. + retry_interval_ms: Retry interval in milliseconds. + raise_last_ex: Raise the last exception when timeout. + + Return: + Whether exception occurs within a timeout. + """ + if isinstance(type(exceptions), tuple): + raise Exception("exceptions arguments should be given as a tuple") + + time_elapsed = 0 + start = time.time() + last_ex = None + while time_elapsed <= timeout_ms: + try: + func(*args) + return True + except exceptions as ex: + last_ex = ex + time_elapsed = (time.time() - start) * 1000 + time.sleep(retry_interval_ms / 1000.0) + if raise_last_ex: + ex_stack = ( + traceback.format_exception(type(last_ex), last_ex, last_ex.__traceback__) + if last_ex + else [] + ) + ex_stack = "".join(ex_stack) + raise Exception(f"Timed out while testing, {ex_stack}") + return False + + +def recursive_fnmatch(dirpath, pattern): + """Looks at a file directory subtree for a filename pattern. 
def dicts_equal(dict1, dict2, abs_tol=1e-4):
    """Compares to dicts whose values may be floating point numbers."""
    if dict1.keys() != dict2.keys():
        return False

    def _values_match(a, b):
        # Floats compare approximately within abs_tol; everything else exactly.
        if isinstance(a, float) and isinstance(b, float):
            if math.isclose(a, b, abs_tol=abs_tol):
                return True
        return a == b

    return all(_values_match(val, dict2[key]) for key, val in dict1.items())
+ """ + a = list(elems_a) + b = list(elems_b) + + for x in a: + if x not in b: + return False + + for x in b: + if x not in a: + return False + + return True + + +@ray.remote +def _put(obj): + return obj + + +def put_object(obj, use_ray_put): + if use_ray_put: + return ray.put(obj) + else: + return _put.remote(obj) + + +def wait_until_server_available(address, timeout_ms=5000, retry_interval_ms=100): + ip_port = address.split(":") + ip = ip_port[0] + port = int(ip_port[1]) + time_elapsed = 0 + start = time.time() + while time_elapsed <= timeout_ms: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.settimeout(1) + try: + s.connect((ip, port)) + except Exception: + time_elapsed = (time.time() - start) * 1000 + time.sleep(retry_interval_ms / 1000.0) + s.close() + continue + s.close() + return True + return False + + +def get_other_nodes(cluster, exclude_head=False): + """Get all nodes except the one that we're connected to.""" + return [ + node + for node in cluster.list_all_nodes() + if node._raylet_socket_name + != ray._private.worker._global_node._raylet_socket_name + and (exclude_head is False or node.head is False) + ] + + +def get_non_head_nodes(cluster): + """Get all non-head nodes.""" + return list(filter(lambda x: x.head is False, cluster.list_all_nodes())) + + +def init_error_pubsub(): + """Initialize error info pub/sub""" + s = ray._raylet.GcsErrorSubscriber( + address=ray._private.worker.global_worker.gcs_client.address + ) + s.subscribe() + return s + + +def get_error_message(subscriber, num=1e6, error_type=None, timeout=20): + """Gets errors from GCS subscriber. + + Returns maximum `num` error strings within `timeout`. + Only returns errors of `error_type` if specified. + """ + deadline = time.time() + timeout + msgs = [] + while time.time() < deadline and len(msgs) < num: + _, error_data = subscriber.poll(timeout=deadline - time.time()) + if not error_data: + # Timed out before any data is received. 
def get_log_data(
    subscriber,
    num: int = 1e6,
    timeout: float = 20,
    job_id: Optional[str] = None,
    matcher=None,
) -> List[dict]:
    """Collect up to ``num`` log batches from a GCS log subscriber.

    Stops early when ``timeout`` seconds elapse or a poll returns nothing.
    Batches are filtered by ``job_id`` and/or ``matcher`` (a predicate over
    individual log lines) when provided.
    """
    deadline = time.time() + timeout
    collected = []
    while len(collected) < num and time.time() < deadline:
        batch = subscriber.poll(timeout=deadline - time.time())
        if not batch:
            # Timed out before any data is received.
            break
        if job_id and batch["job"] != job_id:
            continue
        # Keep the batch only if at least one line satisfies the matcher.
        if matcher and not any(matcher(line) for line in batch["lines"]):
            continue
        collected.append(batch)
    return collected
def format_web_url(url):
    """Normalize a dashboard URL so it always carries an http:// scheme."""
    normalized = url.replace("localhost", "http://127.0.0.1")
    if normalized.startswith("http://"):
        return normalized
    return "http://" + normalized


def client_test_enabled() -> bool:
    """Whether Ray client mode is enabled for this test run."""
    return ray._private.client_mode_hook.is_client_mode_enabled


def object_memory_usage() -> bool:
    """Returns the number of bytes used in the object store."""
    cluster_total = ray.cluster_resources().get("object_store_memory", 0)
    still_free = ray.available_resources().get("object_store_memory", 0)
    return cluster_total - still_free


def fetch_raw_prometheus(prom_addresses):
    """Yield (address, raw_metrics_text) for each reachable endpoint."""
    # Local import so minimal dependency tests can run without requests.
    import requests

    for addr in prom_addresses:
        try:
            response = requests.get(f"http://{addr}/metrics")
            yield addr, response.text
        except requests.exceptions.ConnectionError:
            continue


def fetch_prometheus(prom_addresses):
    """Fetch and parse metrics from every address.

    Returns a tuple of (components seen per address, metric family per
    sample name, flat list of all samples).
    """
    components_by_address = {addr: set() for addr in prom_addresses}
    descriptors = {}
    all_samples = []

    for addr, body in fetch_raw_prometheus(prom_addresses):
        for family in text_string_to_metric_families(body):
            for sample in family.samples:
                descriptors[sample.name] = family
                all_samples.append(sample)
                if "Component" in sample.labels:
                    components_by_address[addr].add(sample.labels["Component"])
    return components_by_address, descriptors, all_samples
def fetch_prometheus_metrics(prom_addresses: List[str]) -> Dict[str, List[Any]]:
    """Return prometheus metrics from the given addresses.

    Args:
        prom_addresses: List of metrics_agent addresses to collect metrics from.

    Returns:
        Dict mapping from metric name to list of samples for the metric.
    """
    _, _, samples = fetch_prometheus(prom_addresses)
    samples_by_name = defaultdict(list)
    for sample in samples:
        samples_by_name[sample.name].append(sample)
    return samples_by_name


def raw_metrics(info: RayContext) -> Dict[str, List[Any]]:
    """Return prometheus metrics from a RayContext.

    Args:
        info: Ray context returned from ray.init()

    Returns:
        Dict from metric name to a list of samples for the metrics.
    """
    metrics_page = "localhost:{}".format(info.address_info["metrics_export_port"])
    print("Fetch metrics from", metrics_page)
    return fetch_prometheus_metrics([metrics_page])


def get_test_config_path(config_file_name):
    """Resolve the test config path from the config file dir."""
    here = os.path.realpath(__file__)
    path = pathlib.Path(here)
    grandparent = path.parent.parent
    return os.path.join(grandparent, "tests/test_cli_patterns", config_file_name)


def load_test_config(config_file_name):
    """Loads a config yaml from tests/test_cli_patterns."""
    config_path = get_test_config_path(config_file_name)
    # Fix: close the file deterministically instead of leaking the handle
    # (the previous `open(...).read()` relied on GC to close it).
    with open(config_path) as f:
        config = yaml.safe_load(f.read())
    return config


def set_setup_func():
    import ray._private.runtime_env as runtime_env

    runtime_env.VAR = "hello world"


class BatchQueue(Queue):
    """A Queue whose backing actor supports batched gets."""

    def __init__(self, maxsize: int = 0, actor_options: Optional[Dict] = None) -> None:
        # NOTE: intentionally does not call super().__init__ — it replaces the
        # backing actor with _BatchQueueActor instead of the base queue actor.
        actor_options = actor_options or {}
        self.maxsize = maxsize
        self.actor = (
            ray.remote(_BatchQueueActor).options(**actor_options).remote(self.maxsize)
        )

    def get_batch(
        self,
        batch_size: int = None,
        total_timeout: Optional[float] = None,
        first_timeout: Optional[float] = None,
    ) -> List[Any]:
        """Gets batch of items from the queue and returns them in a
        list in order.

        Raises:
            Empty: if the queue does not contain the desired number of items
        """
        return ray.get(
            self.actor.get_batch.remote(batch_size, total_timeout, first_timeout)
        )


class _BatchQueueActor(_QueueActor):
    async def _timed_get(self, timeout, remaining):
        """Await one item with `timeout`; return (item, remaining) where
        `remaining` is decremented by the wait time (floored at 0).

        Propagates asyncio.TimeoutError from the underlying wait.
        """
        start = timeit.default_timer()
        item = await asyncio.wait_for(self.queue.get(), timeout)
        if remaining:
            elapsed = timeit.default_timer() - start
            remaining = max(remaining - elapsed, 0)
        return item, remaining

    async def get_batch(self, batch_size=None, total_timeout=None, first_timeout=None):
        """Collect up to `batch_size` items (or until timeout when unbounded).

        The first item waits up to `first_timeout`; subsequent items share the
        (decaying) `total_timeout` budget. Raises Empty if no first item
        arrives in time.
        """
        try:
            first, total_timeout = await self._timed_get(first_timeout, total_timeout)
        except asyncio.TimeoutError:
            raise Empty
        batch = [first]
        if batch_size is None:
            # Unbounded batch: drain until the remaining budget runs out.
            if total_timeout is None:
                total_timeout = 0
            while True:
                try:
                    item, total_timeout = await self._timed_get(
                        total_timeout, total_timeout
                    )
                except asyncio.TimeoutError:
                    break
                batch.append(item)
        else:
            for _ in range(batch_size - 1):
                try:
                    item, total_timeout = await self._timed_get(
                        total_timeout, total_timeout
                    )
                except asyncio.TimeoutError:
                    break
                batch.append(item)
        return batch


def is_placement_group_removed(pg):
    """True once the placement group's table entry reaches REMOVED state."""
    table = ray.util.placement_group_table(pg)
    if "state" not in table:
        return False
    return table["state"] == "REMOVED"


def placement_group_assert_no_leak(pgs_created):
    """Remove the given placement groups and assert cluster resources recover."""
    for pg in pgs_created:
        ray.util.remove_placement_group(pg)

    def wait_for_pg_removed():
        for pg_entry in ray.util.placement_group_table().values():
            if pg_entry["state"] != "REMOVED":
                return False
        return True

    wait_for_condition(wait_for_pg_removed)

    cluster_resources = ray.cluster_resources()
    # memory/object_store_memory fluctuate; exclude them from the comparison.
    cluster_resources.pop("memory")
    cluster_resources.pop("object_store_memory")

    def wait_for_resource_recovered():
        for resource, val in ray.available_resources().items():
            if resource in cluster_resources and cluster_resources[resource] != val:
                return False
            # Any lingering "_group_" resource means the PG was not cleaned up.
            if "_group_" in resource:
                return False
        return True

    wait_for_condition(wait_for_resource_recovered)
def monitor_memory_usage(
    print_interval_s: int = 30,
    record_interval_s: int = 5,
    warning_threshold: float = 0.9,
):
    """Run the memory monitor actor that prints the memory usage.

    The monitor will run on the same node as this function is called.

    Params:
        print_interval_s: The interval memory usage information is printed.
        record_interval_s: The interval memory usage is sampled.
        warning_threshold: The used/total ratio above which a warning is
            printed.

    Returns:
        The memory monitor actor.
    """
    assert ray.is_initialized(), "The API is only available when Ray is initialized."

    @ray.remote(num_cpus=0)
    class MemoryMonitorActor:
        def __init__(
            self,
            print_interval_s: float = 20,
            record_interval_s: float = 5,
            warning_threshold: float = 0.9,
            n: int = 10,
        ):
            """The actor that monitors the memory usage of the cluster.

            Params:
                print_interval_s: Interval at which memory usage is printed.
                record_interval_s: Interval at which memory usage is recorded.
                warning_threshold: used/total ratio that triggers a warning.
                n: Number of top memory-using processes to report.
            """
            self.print_interval_s = print_interval_s
            self.record_interval_s = record_interval_s
            # Whether the sampling loop is active.
            self.is_running = False
            self.warning_threshold = warning_threshold
            self.monitor = memory_monitor.MemoryMonitor()
            self.n = n
            # Peak memory usage (GB) observed over the monitor's lifetime,
            # and the top-n process report captured at that peak.
            self.peak_memory_usage = 0
            self.peak_top_n_memory_usage = ""
            # Timestamp of the last printed report.
            self._last_print_time = 0
            logging.basicConfig(level=logging.INFO)

        def ready(self):
            pass

        async def run(self):
            """Sample memory usage until stop_run() is called."""
            self.is_running = True
            while self.is_running:
                now = time.time()
                used_gb, total_gb = self.monitor.get_memory_usage()
                top_n_memory_usage = memory_monitor.get_top_n_memory_usage(n=self.n)
                if used_gb > self.peak_memory_usage:
                    self.peak_memory_usage = used_gb
                    self.peak_top_n_memory_usage = top_n_memory_usage

                if used_gb > total_gb * self.warning_threshold:
                    logging.warning(
                        "The memory usage is high: " f"{used_gb / total_gb * 100}%"
                    )
                if now - self._last_print_time > self.print_interval_s:
                    logging.info(f"Memory usage: {used_gb} / {total_gb}")
                    logging.info(f"Top {self.n} process memory usage:")
                    logging.info(top_n_memory_usage)
                    self._last_print_time = now
                await asyncio.sleep(self.record_interval_s)

        async def stop_run(self):
            """Stop running the monitor.

            Returns:
                True if the monitor is stopped. False otherwise.
            """
            was_running = self.is_running
            self.is_running = False
            return was_running

        async def get_peak_memory_info(self):
            """Return the tuple of the peak memory usage and the
            top n process information during the peak memory usage.
            """
            return self.peak_memory_usage, self.peak_top_n_memory_usage

    current_node_ip = ray._private.worker.global_worker.node_ip_address
    # Pin the actor to the current node through its node resource.
    memory_monitor_actor = MemoryMonitorActor.options(
        resources={f"node:{current_node_ip}": 0.001}
    ).remote(
        print_interval_s=print_interval_s,
        record_interval_s=record_interval_s,
        warning_threshold=warning_threshold,
    )
    print("Waiting for memory monitor actor to be ready...")
    ray.get(memory_monitor_actor.ready.remote())
    print("Memory monitor actor is ready now.")
    memory_monitor_actor.run.remote()
    return memory_monitor_actor


def setup_tls():
    """Sets up required environment variables for tls."""
    import pytest

    if sys.platform == "darwin":
        pytest.skip("Cryptography doesn't install in Mac build pipeline")
    cert, key = generate_self_signed_tls_certs()
    temp_dir = tempfile.mkdtemp("ray-test-certs")
    cert_filepath = os.path.join(temp_dir, "server.crt")
    key_filepath = os.path.join(temp_dir, "server.key")
    with open(cert_filepath, "w") as fh:
        fh.write(cert)
    with open(key_filepath, "w") as fh:
        fh.write(key)

    # The same self-signed cert doubles as the CA cert.
    os.environ["RAY_USE_TLS"] = "1"
    os.environ["RAY_TLS_SERVER_CERT"] = cert_filepath
    os.environ["RAY_TLS_SERVER_KEY"] = key_filepath
    os.environ["RAY_TLS_CA_CERT"] = cert_filepath

    return key_filepath, cert_filepath, temp_dir


def teardown_tls(key_filepath, cert_filepath, temp_dir):
    """Undo setup_tls: delete the cert files/dir and unset the TLS env vars."""
    os.remove(key_filepath)
    os.remove(cert_filepath)
    os.removedirs(temp_dir)
    for env_var in (
        "RAY_USE_TLS",
        "RAY_TLS_SERVER_CERT",
        "RAY_TLS_SERVER_KEY",
        "RAY_TLS_CA_CERT",
    ):
        del os.environ[env_var]
class ResourceKillerActor:
    """Abstract base class used to implement resource killers for chaos testing.

    Subclasses should implement _find_resources_to_kill, which should find
    resources to kill. That method returns the args for _kill_resource, the
    other abstract method, which kills one resource and records it in
    `self.killed`.
    """

    def __init__(
        self,
        head_node_id,
        kill_interval_s: float = 60,
        max_to_kill: int = 2,
        batch_size_to_kill: int = 1,
        kill_filter_fn: Optional[Callable] = None,
    ):
        self.kill_interval_s = kill_interval_s
        self.is_running = False
        self.head_node_id = head_node_id
        # Identifiers of resources killed so far.
        self.killed = set()
        # Resolved once the kill loop finishes.
        self.done = ray._private.utils.get_or_create_event_loop().create_future()
        self.max_to_kill = max_to_kill
        self.batch_size_to_kill = batch_size_to_kill
        self.kill_filter_fn = kill_filter_fn
        # Subclasses may flip this to skip the random pre-kill delay.
        self.kill_immediately_after_found = False
        logging.basicConfig(level=logging.INFO)

    def ready(self):
        pass

    async def run(self):
        """Repeatedly find and kill resources until stopped or quota reached."""
        self.is_running = True
        while self.is_running:
            targets = await self._find_resources_to_kill()

            if not self.is_running:
                break

            if self.kill_immediately_after_found:
                sleep_interval = 0
            else:
                sleep_interval = random.random() * self.kill_interval_s
            # NOTE(review): this is a blocking sleep inside an async method,
            # which stalls the actor's event loop for the duration — presumably
            # intentional for chaos timing, but worth confirming.
            time.sleep(sleep_interval)

            for target in targets:
                self._kill_resource(*target)
                if len(self.killed) >= self.max_to_kill:
                    break
            await asyncio.sleep(self.kill_interval_s - sleep_interval)

        self.done.set_result(True)

    async def _find_resources_to_kill(self):
        raise NotImplementedError

    def _kill_resource(self, *args):
        raise NotImplementedError

    async def stop_run(self):
        """Stop the kill loop; returns whether it was running."""
        was_running = self.is_running
        self.is_running = False
        return was_running

    async def get_total_killed(self):
        """Get the total number of killed resources"""
        await self.done
        return self.killed


class NodeKillerBase(ResourceKillerActor):
    async def _find_resources_to_kill(self):
        """Pick up to batch_size_to_kill live worker nodes (never the head)."""
        nodes_to_kill = []
        while not nodes_to_kill and self.is_running:
            worker_nodes = [
                node
                for node in ray.nodes()
                if node["Alive"]
                and (node["NodeID"] != self.head_node_id)
                and (node["NodeID"] not in self.killed)
            ]
            if self.kill_filter_fn:
                candidates = list(filter(self.kill_filter_fn(), worker_nodes))
            else:
                candidates = worker_nodes

            # Ensure at least one worker node remains alive.
            if len(worker_nodes) < self.batch_size_to_kill + 1:
                # Give the cluster some time to start.
                await asyncio.sleep(1)
                continue

            for candidate in candidates[: self.batch_size_to_kill]:
                nodes_to_kill.append(
                    (
                        candidate["NodeID"],
                        candidate["NodeManagerAddress"],
                        candidate["NodeManagerPort"],
                    )
                )

        return nodes_to_kill


@ray.remote(num_cpus=0)
class RayletKiller(NodeKillerBase):
    def _kill_resource(self, node_id, node_to_kill_ip, node_to_kill_port):
        """Hard-kill the raylet at the given address and record the node."""
        if node_to_kill_port is not None:
            try:
                self._kill_raylet(node_to_kill_ip, node_to_kill_port, graceful=False)
            except Exception:
                pass
        logging.info(
            f"Killed node {node_id} at address: "
            f"{node_to_kill_ip}, port: {node_to_kill_port}"
        )
        self.killed.add(node_id)

    def _kill_raylet(self, ip, port, graceful=False):
        """Ask the raylet to shut itself down over gRPC."""
        import grpc
        from grpc._channel import _InactiveRpcError
        from ray.core.generated import node_manager_pb2_grpc

        raylet_address = f"{ip}:{port}"
        channel = grpc.insecure_channel(raylet_address)
        stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
        try:
            stub.ShutdownRaylet(
                node_manager_pb2.ShutdownRayletRequest(graceful=graceful)
            )
        except _InactiveRpcError:
            # A non-graceful shutdown drops the connection mid-RPC.
            assert not graceful
@ray.remote(num_cpus=0)
class EC2InstanceTerminator(NodeKillerBase):
    def _kill_resource(self, node_id, node_to_kill_ip, _):
        """Terminate the EC2 instance hosting the node and record it."""
        if node_to_kill_ip is not None:
            try:
                self._terminate_ec2_instance(node_to_kill_ip)
            except Exception:
                pass
        logging.info(f"Terminated instance, {node_id=}, address={node_to_kill_ip}")
        self.killed.add(node_id)

    def _terminate_ec2_instance(self, ip):
        """SSH into the node and have it terminate its own EC2 instance.

        Uses IMDSv2 to resolve the instance id and region, then runs
        `aws ec2 terminate-instances` from the node itself.
        """
        multi_line_command = (
            'TOKEN=$(curl -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 21600");'  # noqa: E501
            'instanceId=$(curl -H "X-aws-ec2-metadata-token: $TOKEN" http://169.254.169.254/latest/meta-data/instance-id/);'  # noqa: E501
            'region=$(curl -H "X-aws-ec2-metadata-token: $TOKEN" http://169.254.169.254/latest/meta-data/placement/region);'  # noqa: E501
            "aws ec2 terminate-instances --region $region --instance-ids $instanceId"  # noqa: E501
        )
        # Port 2222 is an Anyscale-platform feature enabling easy ssh access
        # to worker nodes.
        # NOTE(review): shell=True with an interpolated ip — acceptable for a
        # test-only helper, but keep the input trusted.
        ssh_command = f"ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -p 2222 ray@{ip} '{multi_line_command}'"  # noqa: E501

        result = subprocess.run(
            ssh_command, shell=True, capture_output=True, text=True, check=True
        )
        print(f"STDOUT:\n{result.stdout}\n")
        print(f"STDERR:\n{result.stderr}\n")


@ray.remote(num_cpus=0)
class WorkerKillerActor(ResourceKillerActor):
    def __init__(
        self,
        head_node_id,
        kill_interval_s: float = 60,
        max_to_kill: int = 2,
        batch_size_to_kill: int = 1,
        kill_filter_fn: Optional[Callable] = None,
    ):
        super().__init__(
            head_node_id,
            kill_interval_s,
            max_to_kill,
            batch_size_to_kill,
            kill_filter_fn,
        )

        # Kill the worker immediately so the task cannot finish successfully
        # on its own first.
        self.kill_immediately_after_found = True

        from ray.util.state.common import ListApiOptions
        from ray.util.state.api import StateApiClient

        self.client = StateApiClient()
        # NOTE(review): the excluded task name "WorkerKillActor.run" does not
        # match this class's name (WorkerKillerActor) — possible typo, kept
        # byte-for-byte; confirm against the state API task names.
        self.task_options = ListApiOptions(
            filters=[
                ("state", "=", "RUNNING"),
                ("name", "!=", "WorkerKillActor.run"),
            ]
        )

    async def _find_resources_to_kill(self):
        """Find one running task whose worker process can be killed."""
        from ray.util.state.common import StateResource

        target_task_id = None
        target_pid = None
        target_node_id = None
        while target_pid is None and self.is_running:
            tasks = self.client.list(
                StateResource.TASKS,
                options=self.task_options,
                raise_on_missing_output=False,
            )
            if self.kill_filter_fn is not None:
                tasks = list(filter(self.kill_filter_fn(), tasks))

            for task in tasks:
                if task.worker_id is not None and task.node_id is not None:
                    target_task_id = task.task_id
                    target_pid = task.worker_pid
                    target_node_id = task.node_id
                    break

            # Give the cluster some time to start.
            await asyncio.sleep(0.1)

        return [(target_task_id, target_pid, target_node_id)]

    def _kill_resource(
        self, process_to_kill_task_id, process_to_kill_pid, process_to_kill_node_id
    ):
        """Kill the target worker process via a task pinned to its node."""
        if process_to_kill_pid is not None:

            @ray.remote
            def kill_process(pid):
                import psutil

                proc = psutil.Process(pid)
                proc.kill()

            scheduling_strategy = (
                ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
                    node_id=process_to_kill_node_id,
                    soft=False,
                )
            )
            kill_process.options(scheduling_strategy=scheduling_strategy).remote(
                process_to_kill_pid
            )
            logging.info(
                f"Killing pid {process_to_kill_pid} on node {process_to_kill_node_id}"
            )
            # Store both task_id and pid because retried tasks have same task_id.
            self.killed.add((process_to_kill_task_id, process_to_kill_pid))
+ self.killed.add((process_to_kill_task_id, process_to_kill_pid)) + + +def get_and_run_resource_killer( + resource_killer_cls, + kill_interval_s, + namespace=None, + lifetime=None, + no_start=False, + max_to_kill=2, + batch_size_to_kill=1, + kill_delay_s=0, + kill_filter_fn=None, +): + assert ray.is_initialized(), "The API is only available when Ray is initialized." + + head_node_id = ray.get_runtime_context().get_node_id() + # Schedule the actor on the current node. + resource_killer = resource_killer_cls.options( + scheduling_strategy=NodeAffinitySchedulingStrategy( + node_id=head_node_id, soft=False + ), + namespace=namespace, + name="ResourceKiller", + lifetime=lifetime, + ).remote( + head_node_id, + kill_interval_s=kill_interval_s, + max_to_kill=max_to_kill, + batch_size_to_kill=batch_size_to_kill, + kill_filter_fn=kill_filter_fn, + ) + print("Waiting for ResourceKiller to be ready...") + ray.get(resource_killer.ready.remote()) + print("ResourceKiller is ready now.") + if not no_start: + time.sleep(kill_delay_s) + resource_killer.run.remote() + return resource_killer + + +def get_actor_node_id(actor_handle: "ray.actor.ActorHandle") -> str: + return ray.get( + actor_handle.__ray_call__.remote( + lambda self: ray.get_runtime_context().get_node_id() + ) + ) + + +@contextmanager +def chdir(d: str): + old_dir = os.getcwd() + os.chdir(d) + try: + yield + finally: + os.chdir(old_dir) + + +def test_get_directory_size_bytes(): + with tempfile.TemporaryDirectory() as tmp_dir, chdir(tmp_dir): + assert ray._private.utils.get_directory_size_bytes(tmp_dir) == 0 + with open("test_file", "wb") as f: + f.write(os.urandom(100)) + assert ray._private.utils.get_directory_size_bytes(tmp_dir) == 100 + with open("test_file_2", "wb") as f: + f.write(os.urandom(50)) + assert ray._private.utils.get_directory_size_bytes(tmp_dir) == 150 + os.mkdir("subdir") + with open("subdir/subdir_file", "wb") as f: + f.write(os.urandom(2)) + assert ray._private.utils.get_directory_size_bytes(tmp_dir) 
== 152 + + +def check_local_files_gced(cluster, whitelist=None): + for node in cluster.list_all_nodes(): + for subdir in ["conda", "pip", "working_dir_files", "py_modules_files"]: + all_files = os.listdir( + os.path.join(node.get_runtime_env_dir_path(), subdir) + ) + # Check that there are no files remaining except for .lock files + # and generated requirements.txt files. + # Note: On Windows the top folder is not deleted as it is in use. + # TODO(architkulkarni): these files should get cleaned up too! + items = list(filter(lambda f: not f.endswith((".lock", ".txt")), all_files)) + if whitelist and set(items).issubset(whitelist): + continue + if len(items) > 0: + return False + return True + + +def generate_runtime_env_dict(field, spec_format, tmp_path, pip_list=None): + if pip_list is None: + pip_list = ["pip-install-test==0.5"] + if field == "conda": + conda_dict = {"dependencies": ["pip", {"pip": pip_list}]} + if spec_format == "file": + conda_file = tmp_path / f"environment-{hash(str(pip_list))}.yml" + conda_file.write_text(yaml.dump(conda_dict)) + conda = str(conda_file) + elif spec_format == "python_object": + conda = conda_dict + runtime_env = {"conda": conda} + elif field == "pip": + if spec_format == "file": + pip_file = tmp_path / f"requirements-{hash(str(pip_list))}.txt" + pip_file.write_text("\n".join(pip_list)) + pip = str(pip_file) + elif spec_format == "python_object": + pip = pip_list + runtime_env = {"pip": pip} + return runtime_env + + +def check_spilled_mb(address, spilled=None, restored=None, fallback=None): + def ok(): + s = memory_summary(address=address["address"], stats_only=True) + print(s) + if restored: + if "Restored {} MiB".format(restored) not in s: + return False + else: + if "Restored" in s: + return False + if spilled: + if not isinstance(spilled, list): + spilled_lst = [spilled] + else: + spilled_lst = spilled + found = False + for n in spilled_lst: + if "Spilled {} MiB".format(n) in s: + found = True + if not found: + return False 
+ else: + if "Spilled" in s: + return False + if fallback: + if "Plasma filesystem mmap usage: {} MiB".format(fallback) not in s: + return False + else: + if "Plasma filesystem mmap usage:" in s: + return False + return True + + wait_for_condition(ok, timeout=3, retry_interval_ms=1000) + + +def no_resource_leaks_excluding_node_resources(): + cluster_resources = ray.cluster_resources() + available_resources = ray.available_resources() + for r in ray.cluster_resources(): + if "node" in r: + del cluster_resources[r] + del available_resources[r] + + return cluster_resources == available_resources + + +@contextmanager +def simulate_storage( + storage_type: str, + root: Optional[str] = None, + port: int = 5002, + region: str = "us-west-2", +): + """Context that simulates a given storage type and yields the URI. + + Args: + storage_type: The storage type to simiulate ("fs" or "s3") + root: Root directory of the URI to return (e.g., s3 bucket name) + port: The port of the localhost endpoint where s3 is being served (s3 only) + region: The s3 region (s3 only) + """ + if storage_type == "fs": + if root is None: + with tempfile.TemporaryDirectory() as d: + yield "file://" + d + else: + yield "file://" + root + elif storage_type == "s3": + from moto.server import ThreadedMotoServer + + old_env = os.environ + os.environ["AWS_ACCESS_KEY_ID"] = "testing" + os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" + os.environ["AWS_SECURITY_TOKEN"] = "testing" + os.environ["AWS_SESSION_TOKEN"] = "testing" + + root = root or uuid.uuid4().hex + s3_server = f"http://localhost:{port}" + server = ThreadedMotoServer(port=port) + server.start() + url = f"s3://{root}?region={region}&endpoint_override={s3_server}" + yield url + server.stop() + + os.environ = old_env + + else: + raise NotImplementedError(f"Unknown storage type: {storage_type}") + + +def job_hook(**kwargs): + """Function called by reflection by test_cli_integration.""" + cmd = " ".join(kwargs["entrypoint"]) + print(f"hook intercepted: 
def find_free_port():
    """Return a port number that was free at the time of the call."""
    sock = socket.socket()
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    return port


def wandb_setup_api_key_hook():
    """
    Example external hook to set up W&B API key in
    WandbIntegrationTest.testWandbLoggerConfig
    """
    return "abcd"


# Get node stats from node manager.
def get_node_stats(raylet, num_retry=5, timeout=2):
    """Fetch GetNodeStats from the raylet over gRPC, retrying on RPC errors.

    Raises AssertionError when no attempt succeeds.
    """
    import grpc
    from ray.core.generated import node_manager_pb2_grpc

    raylet_address = f'{raylet["NodeManagerAddress"]}:{raylet["NodeManagerPort"]}'
    channel = ray._private.utils.init_grpc_channel(raylet_address)
    stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
    # Fix: initialize `reply` so an all-retries-failed run trips the assert
    # below instead of raising UnboundLocalError.
    reply = None
    for _ in range(num_retry):
        try:
            reply = stub.GetNodeStats(
                node_manager_pb2.GetNodeStatsRequest(), timeout=timeout
            )
            break
        except grpc.RpcError:
            continue
    assert reply is not None
    return reply


# Gets resource usage assuming gcs is local.
def get_resource_usage(gcs_address, timeout=10):
    """Fetch the cluster resource-usage batch from the (local) GCS."""
    from ray.core.generated import gcs_service_pb2_grpc

    if not gcs_address:
        # Fix: use the ray._private.worker path consistently with the rest of
        # this module (ray.worker is the legacy location).
        gcs_address = ray._private.worker._global_node.gcs_address

    gcs_channel = ray._private.utils.init_grpc_channel(
        gcs_address, ray_constants.GLOBAL_GRPC_OPTIONS, asynchronous=False
    )

    gcs_node_resources_stub = gcs_service_pb2_grpc.NodeResourceInfoGcsServiceStub(
        gcs_channel
    )

    request = gcs_service_pb2.GetAllResourceUsageRequest()
    response = gcs_node_resources_stub.GetAllResourceUsage(request, timeout=timeout)
    return response.resource_usage_data


# Gets the load metrics report assuming gcs is local.
def get_load_metrics_report(webui_url):
    """Fetch loadMetricsReport from the dashboard's cluster_status endpoint."""
    webui_url = format_web_url(webui_url)
    response = requests.get(f"{webui_url}/api/cluster_status")
    response.raise_for_status()
    return response.json()["data"]["clusterStatus"]["loadMetricsReport"]


# Send a RPC to the raylet to have it self-destruct its process.
def kill_raylet(raylet, graceful=False):
    """Ask a raylet to shut itself down over gRPC."""
    import grpc
    from grpc._channel import _InactiveRpcError
    from ray.core.generated import node_manager_pb2_grpc

    raylet_address = f'{raylet["NodeManagerAddress"]}:{raylet["NodeManagerPort"]}'
    channel = grpc.insecure_channel(raylet_address)
    stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
    try:
        stub.ShutdownRaylet(node_manager_pb2.ShutdownRayletRequest(graceful=graceful))
    except _InactiveRpcError:
        # A non-graceful shutdown drops the connection mid-RPC.
        assert not graceful


# Global counters to test different return values
# for external_ray_cluster_activity_hook1 / hook5.
ray_cluster_activity_hook_counter = 0
ray_cluster_activity_hook_5_counter = 0


def external_ray_cluster_activity_hook1():
    """
    Example external hook for test_component_activities_hook.

    Returns valid response and increments counter in `reason`
    field on each call.
    """
    global ray_cluster_activity_hook_counter
    ray_cluster_activity_hook_counter += 1

    from pydantic import BaseModel, Extra

    class TestRayActivityResponse(BaseModel, extra=Extra.allow):
        """
        Redefinition of dashboard.modules.snapshot.snapshot_head.RayActivityResponse
        used in test_component_activities_hook to mimic typical
        usage of redefining or extending response type.
        """

        is_active: str
        reason: Optional[str] = None
        timestamp: float

    return {
        "test_component1": TestRayActivityResponse(
            is_active="ACTIVE",
            reason=f"Counter: {ray_cluster_activity_hook_counter}",
            timestamp=datetime.now().timestamp(),
        )
    }


def external_ray_cluster_activity_hook2():
    """
    Example external hook for test_component_activities_hook.

    Returns invalid output because the value of `test_component2`
    should be of type RayActivityResponse.
    """
    return {"test_component2": "bad_output"}


def external_ray_cluster_activity_hook3():
    """
    Example external hook for test_component_activities_hook.

    Returns invalid output because return type is not
    Dict[str, RayActivityResponse]
    """
    return "bad_output"
def external_ray_cluster_activity_hook4():
    """
    Example external hook for test_component_activities_hook.

    Errors during execution.
    """
    raise Exception("Error in external cluster activity hook")


def external_ray_cluster_activity_hook5():
    """
    Example external hook for test_component_activities_hook.

    Returns valid response and increments counter in `reason`
    field on each call.
    """
    global ray_cluster_activity_hook_5_counter
    ray_cluster_activity_hook_5_counter += 1
    return {
        "test_component5": {
            "is_active": "ACTIVE",
            "reason": f"Counter: {ray_cluster_activity_hook_5_counter}",
            "timestamp": datetime.now().timestamp(),
        }
    }


def get_gcs_memory_used():
    """Sum RSS of the gcs_server and redis-server processes.

    NOTE(review): the dict is keyed by process name, so multiple processes
    with the same name collapse to one entry — confirm a single instance of
    each is expected here.
    """
    import psutil

    usage_by_name = {
        proc.name(): proc.memory_info().rss
        for proc in psutil.process_iter()
        if (
            proc.status() not in (psutil.STATUS_ZOMBIE, psutil.STATUS_DEAD)
            and proc.name() in ("gcs_server", "redis-server")
        )
    }
    assert "gcs_server" in usage_by_name
    return sum(usage_by_name.values())


def wandb_populate_run_location_hook():
    """
    Example external hook to populate W&B project and group env vars in
    WandbIntegrationTest.testWandbLoggerConfig
    """
    from ray.air.integrations.wandb import WANDB_GROUP_ENV_VAR, WANDB_PROJECT_ENV_VAR

    os.environ[WANDB_PROJECT_ENV_VAR] = "test_project"
    os.environ[WANDB_GROUP_ENV_VAR] = "test_group"


def safe_write_to_results_json(
    result: dict,
    default_file_name: str = "/tmp/release_test_output.json",
    env_var: Optional[str] = "TEST_OUTPUT_JSON",
):
    """
    Safe (atomic) write to file to guard against malforming the json
    if the job gets interrupted in the middle of writing.
    """
    test_output_json = os.environ.get(env_var, default_file_name)
    test_output_json_tmp = test_output_json + ".tmp"
    # Write to a sibling temp file, then atomically swap it into place.
    with open(test_output_json_tmp, "wt") as f:
        json.dump(result, f)
    os.replace(test_output_json_tmp, test_output_json)
    logger.info(f"Wrote results to {test_output_json}")
    logger.info(json.dumps(result))


def get_current_unused_port():
    """
    Returns a port number that is not currently in use.

    This is useful for testing when we need to bind to a port but don't
    care which one.

    Returns:
        A port number that is not currently in use. (Note that this port
        might become used by the time you try to bind to it.)
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 makes the OS pick a free port for us.
    sock.bind(("localhost", 0))
    port = sock.getsockname()[1]
    sock.close()
    return port


def search_words(string: str, words: str):
    """Check whether each word is in the given string.

    Args:
        string: String to search
        words: Space-separated string of words to search for
    """
    return [word in string for word in words.split(" ")]


def has_all_words(string: str, words: str):
    """Check that string has all of the given words.

    Args:
        string: String to search
        words: Space-separated string of words to search for
    """
    return all(search_words(string, words))


def has_no_words(string, words):
    """Check that string has none of the given words.

    Args:
        string: String to search
        words: Space-separated string of words to search for
    """
    return not any(search_words(string, words))
+ + Args: + string: String to search + words: Space-separated string of words to search for + """ + return not any(search_words(string, words)) + + +def find_available_port(start, end, port_num=1): + ports = [] + for _ in range(port_num): + random_port = 0 + with socket.socket() as s: + s.bind(("", 0)) + random_port = s.getsockname()[1] + if random_port >= start and random_port <= end and random_port not in ports: + ports.append(random_port) + continue + + for port in range(start, end + 1): + if port in ports: + continue + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", port)) + ports.append(port) + break + except OSError: + pass + + if len(ports) != port_num: + raise RuntimeError( + f"Can't find {port_num} available port from {start} to {end}." + ) + return ports + + +# TODO(rickyx): We could remove this once we unify the autoscaler v1 and v2 +# code path for ray status +def reset_autoscaler_v2_enabled_cache(): + import ray.autoscaler.v2.utils as u + + u.cached_is_autoscaler_v2 = None + + +def skip_flaky_core_test_premerge(reason: str): + """ + Decorator to skip a test if it is flaky (e.g. in premerge) + + Default we will skip the flaky test if not specified otherwise in + CI with CI_SKIP_FLAKY_TEST="0" + """ + import pytest + + def wrapper(func): + return pytest.mark.skipif( + os.environ.get("CI_SKIP_FLAKY_TEST", "1") == "1", reason=reason + )(func) + + return wrapper + + +def close_common_connections(pid): + """ + Closes ipv4 connections between the current process and another process specified by + its PID. + """ + current_process = psutil.Process() + current_connections = current_process.connections(kind="inet") + try: + other_process = psutil.Process(pid) + other_connections = other_process.connections(kind="inet") + except psutil.NoSuchProcess: + print(f"No process with PID {pid} found.") + return + # Finding common connections based on matching addresses and ports. 
+ common_connections = [] + for conn1 in current_connections: + for conn2 in other_connections: + if conn1.laddr == conn2.raddr and conn1.raddr == conn2.laddr: + common_connections.append((conn1.fd, conn1.laddr, conn1.raddr)) + # Closing the FDs. + for fd, laddr, raddr in common_connections: + if fd != -1: # FD is -1 if it's not accessible or if it's a pseudo FD. + os.close(fd) + print(f"Closed FD: {fd}, laddr: {laddr}, raddr: {raddr}") diff --git a/.venv/lib/python3.11/site-packages/ray/_private/tls_utils.py b/.venv/lib/python3.11/site-packages/ray/_private/tls_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..22b6f050ee604b977da184917e7b685145f69315 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/tls_utils.py @@ -0,0 +1,99 @@ +import datetime +import os +import socket + + +def generate_self_signed_tls_certs(): + """Create self-signed key/cert pair for testing. + + This method requires the library ``cryptography`` be installed. + """ + try: + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import hashes, serialization + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.x509.oid import NameOID + except ImportError: + raise ImportError( + "Using `Security.temporary` requires `cryptography`, please " + "install it using either pip or conda" + ) + key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + key_contents = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + + ray_interal = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "ray-internal")]) + # This is the same logic used by the GCS server to acquire a + # private/interal IP address to listen on. 
If we just use localhost + + # 127.0.0.1 then we won't be able to connect to the GCS and will get + # an error like "No match found for server name: 192.168.X.Y" + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(("8.8.8.8", 80)) + private_ip_address = s.getsockname()[0] + s.close() + altnames = x509.SubjectAlternativeName( + [ + x509.DNSName( + socket.gethostbyname(socket.gethostname()) + ), # Probably 127.0.0.1 + x509.DNSName("127.0.0.1"), + x509.DNSName(private_ip_address), # 192.168.*.* + x509.DNSName("localhost"), + ] + ) + now = datetime.datetime.utcnow() + cert = ( + x509.CertificateBuilder() + .subject_name(ray_interal) + .issuer_name(ray_interal) + .add_extension(altnames, critical=False) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now) + .not_valid_after(now + datetime.timedelta(days=365)) + .sign(key, hashes.SHA256(), default_backend()) + ) + + cert_contents = cert.public_bytes(serialization.Encoding.PEM).decode() + + return cert_contents, key_contents + + +def add_port_to_grpc_server(server, address): + import grpc + + if os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"): + server_cert_chain, private_key, ca_cert = load_certs_from_env() + credentials = grpc.ssl_server_credentials( + [(private_key, server_cert_chain)], + root_certificates=ca_cert, + require_client_auth=ca_cert is not None, + ) + return server.add_secure_port(address, credentials) + else: + return server.add_insecure_port(address) + + +def load_certs_from_env(): + tls_env_vars = ["RAY_TLS_SERVER_CERT", "RAY_TLS_SERVER_KEY", "RAY_TLS_CA_CERT"] + if any(v not in os.environ for v in tls_env_vars): + raise RuntimeError( + "If the environment variable RAY_USE_TLS is set to true " + "then RAY_TLS_SERVER_CERT, RAY_TLS_SERVER_KEY and " + "RAY_TLS_CA_CERT must also be set." 
+ ) + + with open(os.environ["RAY_TLS_SERVER_CERT"], "rb") as f: + server_cert_chain = f.read() + with open(os.environ["RAY_TLS_SERVER_KEY"], "rb") as f: + private_key = f.read() + with open(os.environ["RAY_TLS_CA_CERT"], "rb") as f: + ca_cert = f.read() + + return server_cert_chain, private_key, ca_cert diff --git a/.venv/lib/python3.11/site-packages/ray/_private/utils.py b/.venv/lib/python3.11/site-packages/ray/_private/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bd23131bebdf5165330e17a3915d2335a8d0a8d4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/utils.py @@ -0,0 +1,2100 @@ +import asyncio +import binascii +from collections import defaultdict +import contextlib +import errno +import functools +import importlib +import inspect +import json +import logging +import multiprocessing +import os +import platform +import re +import signal +import subprocess +import sys +import tempfile +import threading +import time +from urllib.parse import urlencode, unquote, urlparse, parse_qsl, urlunparse +import warnings +from inspect import signature +from pathlib import Path +from subprocess import list2cmdline +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Optional, + Sequence, + Tuple, + Union, + Coroutine, + List, + Mapping, +) + +# Import psutil after ray so the packaged version is used. +import psutil +from google.protobuf import json_format + +import ray +import ray._private.ray_constants as ray_constants +from ray.core.generated.runtime_env_common_pb2 import ( + RuntimeEnvInfo as ProtoRuntimeEnvInfo, +) + +if TYPE_CHECKING: + from ray.runtime_env import RuntimeEnv + +pwd = None +if sys.platform != "win32": + import pwd + +logger = logging.getLogger(__name__) + +# Linux can bind child processes' lifetimes to that of their parents via prctl. +# prctl support is detected dynamically once, and assumed thereafter. 
+linux_prctl = None + +# Windows can bind processes' lifetimes to that of kernel-level "job objects". +# We keep a global job object to tie its lifetime to that of our own process. +win32_job = None +win32_AssignProcessToJobObject = None + +ENV_DISABLE_DOCKER_CPU_WARNING = "RAY_DISABLE_DOCKER_CPU_WARNING" in os.environ +_PYARROW_VERSION = None + +# This global variable is used for testing only +_CALLED_FREQ = defaultdict(lambda: 0) +_CALLED_FREQ_LOCK = threading.Lock() + +PLACEMENT_GROUP_INDEXED_BUNDLED_RESOURCE_PATTERN = re.compile( + r"(.+)_group_(\d+)_([0-9a-zA-Z]+)" +) +PLACEMENT_GROUP_WILDCARD_RESOURCE_PATTERN = re.compile(r"(.+)_group_([0-9a-zA-Z]+)") + + +def get_user_temp_dir(): + if "RAY_TMPDIR" in os.environ: + return os.environ["RAY_TMPDIR"] + elif sys.platform.startswith("linux") and "TMPDIR" in os.environ: + return os.environ["TMPDIR"] + elif sys.platform.startswith("darwin") or sys.platform.startswith("linux"): + # Ideally we wouldn't need this fallback, but keep it for now for + # for compatibility + tempdir = os.path.join(os.sep, "tmp") + else: + tempdir = tempfile.gettempdir() + return tempdir + + +def get_ray_temp_dir(): + return os.path.join(get_user_temp_dir(), "ray") + + +def get_ray_address_file(temp_dir: Optional[str]): + if temp_dir is None: + temp_dir = get_ray_temp_dir() + return os.path.join(temp_dir, "ray_current_cluster") + + +def write_ray_address(ray_address: str, temp_dir: Optional[str] = None): + address_file = get_ray_address_file(temp_dir) + if os.path.exists(address_file): + with open(address_file, "r") as f: + prev_address = f.read() + if prev_address == ray_address: + return + + logger.info( + f"Overwriting previous Ray address ({prev_address}). " + "Running ray.init() on this node will now connect to the new " + f"instance at {ray_address}. To override this behavior, pass " + f"address={prev_address} to ray.init()." 
+ ) + + with open(address_file, "w+") as f: + f.write(ray_address) + + +def reset_ray_address(temp_dir: Optional[str] = None): + address_file = get_ray_address_file(temp_dir) + if os.path.exists(address_file): + try: + os.remove(address_file) + except OSError: + pass + + +def read_ray_address(temp_dir: Optional[str] = None) -> str: + address_file = get_ray_address_file(temp_dir) + if not os.path.exists(address_file): + return None + with open(address_file, "r") as f: + return f.read().strip() + + +def format_error_message(exception_message: str, task_exception: bool = False): + """Improve the formatting of an exception thrown by a remote function. + + This method takes a traceback from an exception and makes it nicer by + removing a few uninformative lines and adding some space to indent the + remaining lines nicely. + + Args: + exception_message: A message generated by traceback.format_exc(). + + Returns: + A string of the formatted exception message. + """ + lines = exception_message.split("\n") + if task_exception: + # For errors that occur inside of tasks, remove lines 1 and 2 which are + # always the same, they just contain information about the worker code. + lines = lines[0:1] + lines[3:] + pass + return "\n".join(lines) + + +def push_error_to_driver( + worker, error_type: str, message: str, job_id: Optional[str] = None +): + """Push an error message to the driver to be printed in the background. + + Args: + worker: The worker to use. + error_type: The type of the error. + message: The message that will be printed in the background + on the driver. + job_id: The ID of the driver to push the error message to. If this + is None, then the message will be pushed to all drivers. 
+ """ + if job_id is None: + job_id = ray.JobID.nil() + assert isinstance(job_id, ray.JobID) + worker.core_worker.push_error(job_id, error_type, message, time.time()) + + +def publish_error_to_driver( + error_type: str, + message: str, + gcs_publisher, + job_id=None, + num_retries=None, +): + """Push an error message to the driver to be printed in the background. + + Normally the push_error_to_driver function should be used. However, in some + instances, the raylet client is not available, e.g., because the + error happens in Python before the driver or worker has connected to the + backend processes. + + Args: + error_type: The type of the error. + message: The message that will be printed in the background + on the driver. + gcs_publisher: The GCS publisher to use. + job_id: The ID of the driver to push the error message to. If this + is None, then the message will be pushed to all drivers. + """ + if job_id is None: + job_id = ray.JobID.nil() + assert isinstance(job_id, ray.JobID) + try: + gcs_publisher.publish_error( + job_id.hex().encode(), error_type, message, job_id, num_retries + ) + except Exception: + logger.exception(f"Failed to publish error: {message} [type {error_type}]") + + +def decode(byte_str: str, allow_none: bool = False, encode_type: str = "utf-8"): + """Make this unicode in Python 3, otherwise leave it as bytes. + + Args: + byte_str: The byte string to decode. + allow_none: If true, then we will allow byte_str to be None in which + case we will return an empty string. TODO(rkn): Remove this flag. + This is only here to simplify upgrading to flatbuffers 1.10.0. + + Returns: + A byte string in Python 2 and a unicode string in Python 3. + """ + if byte_str is None and allow_none: + return "" + + if not isinstance(byte_str, bytes): + raise ValueError(f"The argument {byte_str} must be a bytes object.") + return byte_str.decode(encode_type) + + +def ensure_str(s, encoding="utf-8", errors="strict"): + """Coerce *s* to `str`. 
+ + - `str` -> `str` + - `bytes` -> decoded to `str` + """ + if isinstance(s, str): + return s + else: + assert isinstance(s, bytes), f"Expected str or bytes, got {type(s)}" + return s.decode(encoding, errors) + + +def binary_to_object_ref(binary_object_ref): + return ray.ObjectRef(binary_object_ref) + + +def binary_to_task_id(binary_task_id): + return ray.TaskID(binary_task_id) + + +def binary_to_hex(identifier): + hex_identifier = binascii.hexlify(identifier) + hex_identifier = hex_identifier.decode() + return hex_identifier + + +def hex_to_binary(hex_identifier): + return binascii.unhexlify(hex_identifier) + + +# TODO(qwang): Remove these hepler functions +# once we separate `WorkerID` from `UniqueID`. +def compute_job_id_from_driver(driver_id): + assert isinstance(driver_id, ray.WorkerID) + return ray.JobID(driver_id.binary()[0 : ray.JobID.size()]) + + +def compute_driver_id_from_job(job_id): + assert isinstance(job_id, ray.JobID) + rest_length = ray_constants.ID_SIZE - job_id.size() + driver_id_str = job_id.binary() + (rest_length * b"\xff") + return ray.WorkerID(driver_id_str) + + +def get_visible_accelerator_ids() -> Mapping[str, Optional[List[str]]]: + """Get the mapping from accelerator resource name + to the visible ids.""" + + from ray._private.accelerators import ( + get_all_accelerator_resource_names, + get_accelerator_manager_for_resource, + ) + + return { + accelerator_resource_name: get_accelerator_manager_for_resource( + accelerator_resource_name + ).get_current_process_visible_accelerator_ids() + for accelerator_resource_name in get_all_accelerator_resource_names() + } + + +def set_omp_num_threads_if_unset() -> bool: + """Set the OMP_NUM_THREADS to default to num cpus assigned to the worker + + This function sets the environment variable OMP_NUM_THREADS for the worker, + if the env is not previously set and it's running in worker (WORKER_MODE). + + Returns True if OMP_NUM_THREADS is set in this function. 
+ + """ + num_threads_from_env = os.environ.get("OMP_NUM_THREADS") + if num_threads_from_env is not None: + # No ops if it's set + return False + + # If unset, try setting the correct CPU count assigned. + runtime_ctx = ray.get_runtime_context() + if runtime_ctx.worker.mode != ray._private.worker.WORKER_MODE: + # Non worker mode, no ops. + return False + + num_assigned_cpus = runtime_ctx.get_assigned_resources().get("CPU") + + if num_assigned_cpus is None: + # This is an actor task w/o any num_cpus specified, set it to 1 + logger.debug( + "[ray] Forcing OMP_NUM_THREADS=1 to avoid performance " + "degradation with many workers (issue #6998). You can override this " + "by explicitly setting OMP_NUM_THREADS, or changing num_cpus." + ) + num_assigned_cpus = 1 + + import math + + # For num_cpu < 1: Set to 1. + # For num_cpus >= 1: Set to the floor of the actual assigned cpus. + omp_num_threads = max(math.floor(num_assigned_cpus), 1) + os.environ["OMP_NUM_THREADS"] = str(omp_num_threads) + return True + + +def set_visible_accelerator_ids() -> None: + """Set (CUDA_VISIBLE_DEVICES, ONEAPI_DEVICE_SELECTOR, ROCR_VISIBLE_DEVICES, + NEURON_RT_VISIBLE_CORES, TPU_VISIBLE_CHIPS , HABANA_VISIBLE_MODULES ,...) + environment variables based on the accelerator runtime. + """ + for resource_name, accelerator_ids in ( + ray.get_runtime_context().get_accelerator_ids().items() + ): + ray._private.accelerators.get_accelerator_manager_for_resource( + resource_name + ).set_current_process_visible_accelerator_ids(accelerator_ids) + + +def resources_from_ray_options(options_dict: Dict[str, Any]) -> Dict[str, Any]: + """Determine a task's resource requirements. + + Args: + options_dict: The dictionary that contains resources requirements. + + Returns: + A dictionary of the resource requirements for the task. 
+ """ + resources = (options_dict.get("resources") or {}).copy() + + if "CPU" in resources or "GPU" in resources: + raise ValueError( + "The resources dictionary must not contain the key 'CPU' or 'GPU'" + ) + elif "memory" in resources or "object_store_memory" in resources: + raise ValueError( + "The resources dictionary must not " + "contain the key 'memory' or 'object_store_memory'" + ) + elif ray_constants.PLACEMENT_GROUP_BUNDLE_RESOURCE_NAME in resources: + raise ValueError( + "The resource should not include `bundle` which " + f"is reserved for Ray. resources: {resources}" + ) + + num_cpus = options_dict.get("num_cpus") + num_gpus = options_dict.get("num_gpus") + memory = options_dict.get("memory") + object_store_memory = options_dict.get("object_store_memory") + accelerator_type = options_dict.get("accelerator_type") + + if num_cpus is not None: + resources["CPU"] = num_cpus + if num_gpus is not None: + resources["GPU"] = num_gpus + if memory is not None: + resources["memory"] = int(memory) + if object_store_memory is not None: + resources["object_store_memory"] = object_store_memory + if accelerator_type is not None: + resources[ + f"{ray_constants.RESOURCE_CONSTRAINT_PREFIX}{accelerator_type}" + ] = 0.001 + + return resources + + +class Unbuffered(object): + """There's no "built-in" solution to programatically disabling buffering of + text files. Ray expects stdout/err to be text files, so creating an + unbuffered binary file is unacceptable. + + See + https://mail.python.org/pipermail/tutor/2003-November/026645.html. 
+ https://docs.python.org/3/library/functions.html#open + + """ + + def __init__(self, stream): + self.stream = stream + + def write(self, data): + self.stream.write(data) + self.stream.flush() + + def writelines(self, datas): + self.stream.writelines(datas) + self.stream.flush() + + def __getattr__(self, attr): + return getattr(self.stream, attr) + + +def open_log(path, unbuffered=False, **kwargs): + """ + Opens the log file at `path`, with the provided kwargs being given to + `open`. + """ + # Disable buffering, see test_advanced_3.py::test_logging_to_driver + kwargs.setdefault("buffering", 1) + kwargs.setdefault("mode", "a") + kwargs.setdefault("encoding", "utf-8") + stream = open(path, **kwargs) + if unbuffered: + return Unbuffered(stream) + else: + return stream + + +def get_system_memory( + # For cgroups v1: + memory_limit_filename="/sys/fs/cgroup/memory/memory.limit_in_bytes", + # For cgroups v2: + memory_limit_filename_v2="/sys/fs/cgroup/memory.max", +): + """Return the total amount of system memory in bytes. + + Returns: + The total amount of system memory in bytes. + """ + # Try to accurately figure out the memory limit if we are in a docker + # container. Note that this file is not specific to Docker and its value is + # often much larger than the actual amount of memory. + docker_limit = None + if os.path.exists(memory_limit_filename): + with open(memory_limit_filename, "r") as f: + docker_limit = int(f.read().strip()) + elif os.path.exists(memory_limit_filename_v2): + with open(memory_limit_filename_v2, "r") as f: + # Don't forget to strip() the newline: + max_file = f.read().strip() + if max_file.isnumeric(): + docker_limit = int(max_file) + else: + # max_file is "max", i.e. is unset. + docker_limit = None + + # Use psutil if it is available. + psutil_memory_in_bytes = psutil.virtual_memory().total + + if docker_limit is not None: + # We take the min because the cgroup limit is very large if we aren't + # in Docker. 
+ return min(docker_limit, psutil_memory_in_bytes) + + return psutil_memory_in_bytes + + +def _get_docker_cpus( + cpu_quota_file_name="/sys/fs/cgroup/cpu/cpu.cfs_quota_us", + cpu_period_file_name="/sys/fs/cgroup/cpu/cpu.cfs_period_us", + cpuset_file_name="/sys/fs/cgroup/cpuset/cpuset.cpus", + cpu_max_file_name="/sys/fs/cgroup/cpu.max", +) -> Optional[float]: + # TODO (Alex): Don't implement this logic oursleves. + # Docker has 2 underyling ways of implementing CPU limits: + # https://docs.docker.com/config/containers/resource_constraints/#configure-the-default-cfs-scheduler + # 1. --cpuset-cpus 2. --cpus or --cpu-quota/--cpu-period (--cpu-shares is a + # soft limit so we don't worry about it). For Ray's purposes, if we use + # docker, the number of vCPUs on a machine is whichever is set (ties broken + # by smaller value). + + cpu_quota = None + # See: https://bugs.openjdk.java.net/browse/JDK-8146115 + if os.path.exists(cpu_quota_file_name) and os.path.exists(cpu_period_file_name): + try: + with open(cpu_quota_file_name, "r") as quota_file, open( + cpu_period_file_name, "r" + ) as period_file: + cpu_quota = float(quota_file.read()) / float(period_file.read()) + except Exception: + logger.exception("Unexpected error calculating docker cpu quota.") + # Look at cpu.max for cgroups v2 + elif os.path.exists(cpu_max_file_name): + try: + max_file = open(cpu_max_file_name).read() + quota_str, period_str = max_file.split() + if quota_str.isnumeric() and period_str.isnumeric(): + cpu_quota = float(quota_str) / float(period_str) + else: + # quota_str is "max" meaning the cpu quota is unset + cpu_quota = None + except Exception: + logger.exception("Unexpected error calculating docker cpu quota.") + if (cpu_quota is not None) and (cpu_quota < 0): + cpu_quota = None + elif cpu_quota == 0: + # Round up in case the cpu limit is less than 1. 
+ cpu_quota = 1 + + cpuset_num = None + if os.path.exists(cpuset_file_name): + try: + with open(cpuset_file_name) as cpuset_file: + ranges_as_string = cpuset_file.read() + ranges = ranges_as_string.split(",") + cpu_ids = [] + for num_or_range in ranges: + if "-" in num_or_range: + start, end = num_or_range.split("-") + cpu_ids.extend(list(range(int(start), int(end) + 1))) + else: + cpu_ids.append(int(num_or_range)) + cpuset_num = len(cpu_ids) + except Exception: + logger.exception("Unexpected error calculating docker cpuset ids.") + # Possible to-do: Parse cgroups v2's cpuset.cpus.effective for the number + # of accessible CPUs. + + if cpu_quota and cpuset_num: + return min(cpu_quota, cpuset_num) + return cpu_quota or cpuset_num + + +def get_num_cpus( + override_docker_cpu_warning: bool = ENV_DISABLE_DOCKER_CPU_WARNING, +) -> int: + """ + Get the number of CPUs available on this node. + Depending on the situation, use multiprocessing.cpu_count() or cgroups. + + Args: + override_docker_cpu_warning: An extra flag to explicitly turn off the Docker + warning. Setting this flag True has the same effect as setting the env + RAY_DISABLE_DOCKER_CPU_WARNING. By default, whether or not to log + the warning is determined by the env variable + RAY_DISABLE_DOCKER_CPU_WARNING. + """ + cpu_count = multiprocessing.cpu_count() + if os.environ.get("RAY_USE_MULTIPROCESSING_CPU_COUNT"): + logger.info( + "Detected RAY_USE_MULTIPROCESSING_CPU_COUNT=1: Using " + "multiprocessing.cpu_count() to detect the number of CPUs. " + "This may be inconsistent when used inside docker. " + "To correctly detect CPUs, unset the env var: " + "`RAY_USE_MULTIPROCESSING_CPU_COUNT`." + ) + return cpu_count + try: + # Not easy to get cpu count in docker, see: + # https://bugs.python.org/issue36054 + docker_count = _get_docker_cpus() + if docker_count is not None and docker_count != cpu_count: + # Don't log this warning if we're on K8s or if the warning is + # explicitly disabled. 
+ if ( + "KUBERNETES_SERVICE_HOST" not in os.environ + and not ENV_DISABLE_DOCKER_CPU_WARNING + and not override_docker_cpu_warning + ): + logger.warning( + "Detecting docker specified CPUs. In " + "previous versions of Ray, CPU detection in containers " + "was incorrect. Please ensure that Ray has enough CPUs " + "allocated. As a temporary workaround to revert to the " + "prior behavior, set " + "`RAY_USE_MULTIPROCESSING_CPU_COUNT=1` as an env var " + "before starting Ray. Set the env var: " + "`RAY_DISABLE_DOCKER_CPU_WARNING=1` to mute this warning." + ) + # TODO (Alex): We should probably add support for fractional cpus. + if int(docker_count) != float(docker_count): + logger.warning( + f"Ray currently does not support initializing Ray " + f"with fractional cpus. Your num_cpus will be " + f"truncated from {docker_count} to " + f"{int(docker_count)}." + ) + docker_count = int(docker_count) + cpu_count = docker_count + + except Exception: + # `nproc` and cgroup are linux-only. If docker only works on linux + # (will run in a linux VM on other platforms), so this is fine. + pass + + return cpu_count + + +# TODO(clarng): merge code with c++ +def get_cgroup_used_memory( + memory_stat_filename: str, + memory_usage_filename: str, + inactive_file_key: str, + active_file_key: str, +): + """ + The calculation logic is the same with `GetCGroupMemoryUsedBytes` + in `memory_monitor.cc` file. 
+ """ + inactive_file_bytes = -1 + active_file_bytes = -1 + with open(memory_stat_filename, "r") as f: + lines = f.readlines() + for line in lines: + if f"{inactive_file_key} " in line: + inactive_file_bytes = int(line.split()[1]) + elif f"{active_file_key} " in line: + active_file_bytes = int(line.split()[1]) + + with open(memory_usage_filename, "r") as f: + lines = f.readlines() + cgroup_usage_in_bytes = int(lines[0].strip()) + + if ( + inactive_file_bytes == -1 + or cgroup_usage_in_bytes == -1 + or active_file_bytes == -1 + ): + return None + + return cgroup_usage_in_bytes - inactive_file_bytes - active_file_bytes + + +def get_used_memory(): + """Return the currently used system memory in bytes + + Returns: + The total amount of used memory + """ + # Try to accurately figure out the memory usage if we are in a docker + # container. + docker_usage = None + # For cgroups v1: + memory_usage_filename_v1 = "/sys/fs/cgroup/memory/memory.usage_in_bytes" + memory_stat_filename_v1 = "/sys/fs/cgroup/memory/memory.stat" + # For cgroups v2: + memory_usage_filename_v2 = "/sys/fs/cgroup/memory.current" + memory_stat_filename_v2 = "/sys/fs/cgroup/memory.stat" + if os.path.exists(memory_usage_filename_v1) and os.path.exists( + memory_stat_filename_v1 + ): + docker_usage = get_cgroup_used_memory( + memory_stat_filename_v1, + memory_usage_filename_v1, + "total_inactive_file", + "total_active_file", + ) + elif os.path.exists(memory_usage_filename_v2) and os.path.exists( + memory_stat_filename_v2 + ): + docker_usage = get_cgroup_used_memory( + memory_stat_filename_v2, + memory_usage_filename_v2, + "inactive_file", + "active_file", + ) + + if docker_usage is not None: + return docker_usage + return psutil.virtual_memory().used + + +def estimate_available_memory(): + """Return the currently available amount of system memory in bytes. + + Returns: + The total amount of available memory in bytes. Based on the used + and total memory. 
+ + """ + return get_system_memory() - get_used_memory() + + +def get_shared_memory_bytes(): + """Get the size of the shared memory file system. + + Returns: + The size of the shared memory file system in bytes. + """ + # Make sure this is only called on Linux. + assert sys.platform == "linux" or sys.platform == "linux2" + + shm_fd = os.open("/dev/shm", os.O_RDONLY) + try: + shm_fs_stats = os.fstatvfs(shm_fd) + # The value shm_fs_stats.f_bsize is the block size and the + # value shm_fs_stats.f_bavail is the number of available + # blocks. + shm_avail = shm_fs_stats.f_bsize * shm_fs_stats.f_bavail + finally: + os.close(shm_fd) + + return shm_avail + + +def check_oversized_function( + pickled: bytes, name: str, obj_type: str, worker: "ray.Worker" +) -> None: + """Send a warning message if the pickled function is too large. + + Args: + pickled: the pickled function. + name: name of the pickled object. + obj_type: type of the pickled object, can be 'function', + 'remote function', or 'actor'. + worker: the worker used to send warning message. message will be logged + locally if None. + """ + length = len(pickled) + if length <= ray_constants.FUNCTION_SIZE_WARN_THRESHOLD: + return + elif length < ray_constants.FUNCTION_SIZE_ERROR_THRESHOLD: + warning_message = ( + "The {} {} is very large ({} MiB). " + "Check that its definition is not implicitly capturing a large " + "array or other object in scope. Tip: use ray.put() to put large " + "objects in the Ray object store." + ).format(obj_type, name, length // (1024 * 1024)) + if worker: + push_error_to_driver( + worker, + ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR, + "Warning: " + warning_message, + job_id=worker.current_job_id, + ) + else: + error = ( + "The {} {} is too large ({} MiB > FUNCTION_SIZE_ERROR_THRESHOLD={}" + " MiB). Check that its definition is not implicitly capturing a " + "large array or other object in scope. Tip: use ray.put() to " + "put large objects in the Ray object store." 
+ ).format( + obj_type, + name, + length // (1024 * 1024), + ray_constants.FUNCTION_SIZE_ERROR_THRESHOLD // (1024 * 1024), + ) + raise ValueError(error) + + +def is_main_thread(): + return threading.current_thread().getName() == "MainThread" + + +def detect_fate_sharing_support_win32(): + global win32_job, win32_AssignProcessToJobObject + if win32_job is None and sys.platform == "win32": + import ctypes + + try: + from ctypes.wintypes import BOOL, DWORD, HANDLE, LPCWSTR, LPVOID + + kernel32 = ctypes.WinDLL("kernel32") + kernel32.CreateJobObjectW.argtypes = (LPVOID, LPCWSTR) + kernel32.CreateJobObjectW.restype = HANDLE + sijo_argtypes = (HANDLE, ctypes.c_int, LPVOID, DWORD) + kernel32.SetInformationJobObject.argtypes = sijo_argtypes + kernel32.SetInformationJobObject.restype = BOOL + kernel32.AssignProcessToJobObject.argtypes = (HANDLE, HANDLE) + kernel32.AssignProcessToJobObject.restype = BOOL + kernel32.IsDebuggerPresent.argtypes = () + kernel32.IsDebuggerPresent.restype = BOOL + except (AttributeError, TypeError, ImportError): + kernel32 = None + job = kernel32.CreateJobObjectW(None, None) if kernel32 else None + job = subprocess.Handle(job) if job else job + if job: + from ctypes.wintypes import DWORD, LARGE_INTEGER, ULARGE_INTEGER + + class JOBOBJECT_BASIC_LIMIT_INFORMATION(ctypes.Structure): + _fields_ = [ + ("PerProcessUserTimeLimit", LARGE_INTEGER), + ("PerJobUserTimeLimit", LARGE_INTEGER), + ("LimitFlags", DWORD), + ("MinimumWorkingSetSize", ctypes.c_size_t), + ("MaximumWorkingSetSize", ctypes.c_size_t), + ("ActiveProcessLimit", DWORD), + ("Affinity", ctypes.c_size_t), + ("PriorityClass", DWORD), + ("SchedulingClass", DWORD), + ] + + class IO_COUNTERS(ctypes.Structure): + _fields_ = [ + ("ReadOperationCount", ULARGE_INTEGER), + ("WriteOperationCount", ULARGE_INTEGER), + ("OtherOperationCount", ULARGE_INTEGER), + ("ReadTransferCount", ULARGE_INTEGER), + ("WriteTransferCount", ULARGE_INTEGER), + ("OtherTransferCount", ULARGE_INTEGER), + ] + + class 
JOBOBJECT_EXTENDED_LIMIT_INFORMATION(ctypes.Structure): + _fields_ = [ + ("BasicLimitInformation", JOBOBJECT_BASIC_LIMIT_INFORMATION), + ("IoInfo", IO_COUNTERS), + ("ProcessMemoryLimit", ctypes.c_size_t), + ("JobMemoryLimit", ctypes.c_size_t), + ("PeakProcessMemoryUsed", ctypes.c_size_t), + ("PeakJobMemoryUsed", ctypes.c_size_t), + ] + + debug = kernel32.IsDebuggerPresent() + + # Defined in ; also available here: + # https://docs.microsoft.com/en-us/windows/win32/api/jobapi2/nf-jobapi2-setinformationjobobject + JobObjectExtendedLimitInformation = 9 + JOB_OBJECT_LIMIT_BREAKAWAY_OK = 0x00000800 + JOB_OBJECT_LIMIT_DIE_ON_UNHANDLED_EXCEPTION = 0x00000400 + JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE = 0x00002000 + buf = JOBOBJECT_EXTENDED_LIMIT_INFORMATION() + buf.BasicLimitInformation.LimitFlags = ( + (0 if debug else JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE) + | JOB_OBJECT_LIMIT_DIE_ON_UNHANDLED_EXCEPTION + | JOB_OBJECT_LIMIT_BREAKAWAY_OK + ) + infoclass = JobObjectExtendedLimitInformation + if not kernel32.SetInformationJobObject( + job, infoclass, ctypes.byref(buf), ctypes.sizeof(buf) + ): + job = None + win32_AssignProcessToJobObject = ( + kernel32.AssignProcessToJobObject if kernel32 is not None else False + ) + win32_job = job if job else False + return bool(win32_job) + + +def detect_fate_sharing_support_linux(): + global linux_prctl + if linux_prctl is None and sys.platform.startswith("linux"): + try: + from ctypes import CDLL, c_int, c_ulong + + prctl = CDLL(None).prctl + prctl.restype = c_int + prctl.argtypes = [c_int, c_ulong, c_ulong, c_ulong, c_ulong] + except (AttributeError, TypeError): + prctl = None + linux_prctl = prctl if prctl else False + return bool(linux_prctl) + + +def detect_fate_sharing_support(): + result = None + if sys.platform == "win32": + result = detect_fate_sharing_support_win32() + elif sys.platform.startswith("linux"): + result = detect_fate_sharing_support_linux() + return result + + +def set_kill_on_parent_death_linux(): + """Ensures this 
process dies if its parent dies (fate-sharing). + + Linux-only. Must be called in preexec_fn (i.e. by the child). + """ + if detect_fate_sharing_support_linux(): + import signal + + PR_SET_PDEATHSIG = 1 + if linux_prctl(PR_SET_PDEATHSIG, signal.SIGKILL, 0, 0, 0) != 0: + import ctypes + + raise OSError(ctypes.get_errno(), "prctl(PR_SET_PDEATHSIG) failed") + else: + assert False, "PR_SET_PDEATHSIG used despite being unavailable" + + +def set_kill_child_on_death_win32(child_proc): + """Ensures the child process dies if this process dies (fate-sharing). + + Windows-only. Must be called by the parent, after spawning the child. + + Args: + child_proc: The subprocess.Popen or subprocess.Handle object. + """ + + if isinstance(child_proc, subprocess.Popen): + child_proc = child_proc._handle + assert isinstance(child_proc, subprocess.Handle) + + if detect_fate_sharing_support_win32(): + if not win32_AssignProcessToJobObject(win32_job, int(child_proc)): + import ctypes + + raise OSError(ctypes.get_last_error(), "AssignProcessToJobObject() failed") + else: + assert False, "AssignProcessToJobObject used despite being unavailable" + + +def set_sigterm_handler(sigterm_handler): + """Registers a handler for SIGTERM in a platform-compatible manner.""" + if sys.platform == "win32": + # Note that these signal handlers only work for console applications. + # TODO(mehrdadn): implement graceful process termination mechanism + # SIGINT is Ctrl+C, SIGBREAK is Ctrl+Break. + signal.signal(signal.SIGBREAK, sigterm_handler) + else: + signal.signal(signal.SIGTERM, sigterm_handler) + + +def try_make_directory_shared(directory_path): + try: + os.chmod(directory_path, 0o0777) + except OSError as e: + # Silently suppress the PermissionError that is thrown by the chmod. + # This is done because the user attempting to change the permissions + # on a directory may not own it. The chmod is attempted whether the + # directory is new or not to avoid race conditions. 
def get_function_args(callable):
    """Return the parameter names of ``callable``.

    Fix: the names were previously passed through a ``frozenset``, which made
    the order of the returned list nondeterministic across runs.
    ``Signature.parameters`` is an ordered mapping, so iterating it directly
    returns the same names, in declaration order.

    Args:
        callable: Any object accepted by ``inspect.signature``.

    Returns:
        List of parameter names in declaration order.
    """
    return list(signature(callable).parameters)
+ """ + + # Use CONDA_EXE as per https://github.com/conda/conda/issues/7126 + if "CONDA_EXE" in os.environ: + conda_bin_dir = os.path.dirname(os.environ["CONDA_EXE"]) + return os.path.join(conda_bin_dir, executable_name) + return executable_name + + +def get_conda_env_dir(env_name): + """Find and validate the conda directory for a given conda environment. + + For example, given the environment name `tf1`, this function checks + the existence of the corresponding conda directory, e.g. + `/Users/scaly/anaconda3/envs/tf1`, and returns it. + """ + conda_prefix = os.environ.get("CONDA_PREFIX") + if conda_prefix is None: + # The caller is neither in a conda env or in (base) env. This is rare + # because by default, new terminals start in (base), but we can still + # support this case. + conda_exe = os.environ.get("CONDA_EXE") + if conda_exe is None: + raise ValueError( + "Cannot find environment variables set by conda. " + "Please verify conda is installed." + ) + # Example: CONDA_EXE=$HOME/anaconda3/bin/python + # Strip out /bin/python by going up two parent directories. + conda_prefix = str(Path(conda_exe).parent.parent) + + # There are two cases: + # 1. We are in a conda (base) env: CONDA_DEFAULT_ENV=base and + # CONDA_PREFIX=$HOME/anaconda3 + # 2. We are in a user-created conda env: CONDA_DEFAULT_ENV=$env_name and + # CONDA_PREFIX=$HOME/anaconda3/envs/$current_env_name + if os.environ.get("CONDA_DEFAULT_ENV") == "base": + # Caller's curent environment is (base). + # Not recommended by conda, but we can still support it. + if env_name == "base": + # Desired environment is (base), located at e.g. $HOME/anaconda3 + env_dir = conda_prefix + else: + # Desired environment is user-created, e.g. + # $HOME/anaconda3/envs/$env_name + env_dir = os.path.join(conda_prefix, "envs", env_name) + else: + # Now `conda_prefix` should be something like + # $HOME/anaconda3/envs/$current_env_name + # We want to replace the last component with the desired env name. 
+ conda_envs_dir = os.path.split(conda_prefix)[0] + env_dir = os.path.join(conda_envs_dir, env_name) + if not os.path.isdir(env_dir): + raise ValueError( + "conda env " + + env_name + + " not found in conda envs directory. Run `conda env list` to " + + "verify the name is correct." + ) + return env_dir + + +def get_call_location(back: int = 1): + """ + Get the location (filename and line number) of a function caller, `back` + frames up the stack. + + Args: + back: The number of frames to go up the stack, not including this + function. + """ + stack = inspect.stack() + try: + frame = stack[back + 1] + return f"{frame.filename}:{frame.lineno}" + except IndexError: + return "UNKNOWN" + + +def get_ray_doc_version(): + """Get the docs.ray.io version corresponding to the ray.__version__.""" + # The ray.__version__ can be official Ray release (such as 1.12.0), or + # dev (3.0.0dev0) or release candidate (2.0.0rc0). For the later we map + # to the master doc version at docs.ray.io. + if re.match(r"^\d+\.\d+\.\d+$", ray.__version__) is None: + return "master" + # For the former (official Ray release), we have corresponding doc version + # released as well. + return f"releases-{ray.__version__}" + + +# Used to only print a deprecation warning once for a given function if we +# don't wish to spam the caller. +_PRINTED_WARNING = set() + + +# The following is inspired by +# https://github.com/tensorflow/tensorflow/blob/dec8e0b11f4f87693b67e125e67dfbc68d26c205/tensorflow/python/util/deprecation.py#L274-L329 +def deprecated( + instructions: Optional[str] = None, + removal_release: Optional[str] = None, + removal_date: Optional[str] = None, + warn_once: bool = True, + stacklevel=2, +): + """ + Creates a decorator for marking functions as deprecated. The decorator + will log a deprecation warning on the first (or all, see `warn_once` arg) + invocations, and will otherwise leave the wrapped function unchanged. 
def import_attr(full_path: str, *, reload_module: bool = False):
    """Import and return the attribute named by a full import path.

    If `reload_module` is set, the module will be reloaded using
    ``importlib.reload`` before the attribute lookup.

    For example, the following are equivalent:
        MyClass = import_attr("module.submodule:MyClass")
        MyClass = import_attr("module.submodule.MyClass")
        from module.submodule import MyClass

    Returns:
        Imported attr
    """
    if full_path is None:
        raise TypeError("import path cannot be None")

    colons = full_path.count(":")
    if colons > 1:
        raise ValueError(
            f'Got invalid import path "{full_path}". An '
            "import path may have at most one colon."
        )
    if colons == 1:
        module_name, _, attr_name = full_path.partition(":")
    else:
        # No colon: treat the final dotted component as the attribute name.
        dot = full_path.rfind(".")
        module_name, attr_name = full_path[:dot], full_path[dot + 1 :]

    module = importlib.import_module(module_name)
    if reload_module:
        importlib.reload(module)
    return getattr(module, attr_name)
def get_master_wheel_url(
    ray_commit: str = ray.__commit__,
    sys_platform: str = sys.platform,
    ray_version: str = ray.__version__,
    py_version: Tuple[int, int] = sys.version_info[:2],
) -> str:
    """Return the URL for the wheel built from a specific master commit.

    Args:
        ray_commit: Full git commit hash the wheel was built from.
        sys_platform: Platform string as returned by ``sys.platform``.
        ray_version: Ray version string, e.g. "3.0.0.dev0".
        py_version: (major, minor) Python version tuple.

    Returns:
        The S3 URL of the nightly wheel for that commit.
    """
    filename = get_wheel_filename(
        sys_platform=sys_platform, ray_version=ray_version, py_version=py_version
    )
    # Bug fix: `filename` was computed but never interpolated into the URL,
    # which ended in a literal placeholder instead of the wheel filename.
    return (
        f"https://s3-us-west-2.amazonaws.com/ray-wheels/master/"
        f"{ray_commit}/{filename}"
    )
def validate_namespace(namespace: str):
    """Validate a Ray namespace value, raising on invalid input.

    Raises:
        TypeError: if ``namespace`` is not a string.
        ValueError: if ``namespace`` is the empty string.
    """
    if not isinstance(namespace, str):
        raise TypeError("namespace must be None or a string.")
    if not namespace:
        raise ValueError(
            '"" is not a valid namespace. ' "Pass None to not specify a namespace."
        )
def check_ray_client_dependencies_installed() -> bool:
    """Return whether the Ray Client dependency (grpc) is importable.

    See documents for check_dashboard_dependencies_installed: probing via
    import is the only reliable way to detect the installed extras.
    """
    try:
        importlib.import_module("grpc")
    except ImportError:
        return False
    return True
def parse_metadata_json(
    metadata: str, cli_logger, cf, command_arg="--metadata-json"
) -> Dict[str, str]:
    """Deserialize a JSON metadata CLI argument into a dict.

    Malformed input is reported through ``cli_logger`` (whose ``abort`` is
    expected to terminate) rather than raised to the caller.
    """
    try:
        metadata = json.loads(metadata)
        if not isinstance(metadata, dict):
            raise ValueError("The format after deserialization is not a dict")
    except Exception as err:
        detail = str(err)
        cli_logger.error(
            "`{}` is not a valid JSON string, detail error:{}",
            cf.bold(f"{command_arg}={metadata}"),
            detail,
        )
        cli_logger.abort(
            "Valid values look like this: `{}`",
            cf.bold(f'{command_arg}=\'{{"key1": "value1", "key2": "value2"}}\''),
        )
    return metadata
def get_directory_size_bytes(path: Union[str, Path] = ".") -> int:
    """Get the total size of a directory in bytes, including subdirectories.

    Symbolic links and compiled ``.pyc`` files are excluded from the total.
    """
    total = 0
    for dirpath, _dirnames, filenames in os.walk(path):
        for name in filenames:
            full = os.path.join(dirpath, name)
            # Skip symbolic links and .pyc files.
            if name.endswith(".pyc") or os.path.islink(full):
                continue
            total += os.path.getsize(full)
    return total
To which python version level we + try to match. Note if "minor" and the patch is different, we will still log + a warning. + + Behavior: + - We raise or log a warning, based on raise_on_mismatch, if: + - Ray versions do not match; OR + - Python (major, minor) versions do not match, + if python_version_match_level == 'minor'; OR + - Python (major, minor, patch) versions do not match, + if python_version_match_level == 'patch'. + - We also log a warning if: + - Python (major, minor) versions match, AND + - Python patch versions do not match, AND + - python_version_match_level == 'minor' AND + - raise_on_mismatch == False. + Raises: + Exception: An exception is raised if there is a version mismatch. + """ + cluster_version_info = ( + cluster_metadata["ray_version"], + cluster_metadata["python_version"], + ) + my_version_info = compute_version_info() + + # Calculate: ray_matches, python_matches, python_full_matches + ray_matches = cluster_version_info[0] == my_version_info[0] + python_full_matches = cluster_version_info[1] == my_version_info[1] + if python_version_match_level == "patch": + python_matches = cluster_version_info[1] == my_version_info[1] + elif python_version_match_level == "minor": + my_python_versions = my_version_info[1].split(".") + cluster_python_versions = cluster_version_info[1].split(".") + python_matches = my_python_versions[:2] == cluster_python_versions[:2] + else: + raise ValueError( + f"Invalid python_version_match_level: {python_version_match_level}, " + "want: 'minor' or 'patch'" + ) + + mismatch_msg = ( + "The cluster was started with:\n" + f" Ray: {cluster_version_info[0]}\n" + f" Python: {cluster_version_info[1]}\n" + f"This process on {this_process_address} was started with:\n" + f" Ray: {my_version_info[0]}\n" + f" Python: {my_version_info[1]}\n" + ) + + if ray_matches and python_matches: + if not python_full_matches: + logger.warning(f"Python patch version mismatch: {mismatch_msg}") + else: + error_message = f"Version mismatch: 
{mismatch_msg}" + if raise_on_mismatch: + raise RuntimeError(error_message) + else: + logger.warning(error_message) + + +def get_runtime_env_info( + runtime_env: "RuntimeEnv", + *, + is_job_runtime_env: bool = False, + serialize: bool = False, +): + """Create runtime env info from runtime env. + + In the user interface, the argument `runtime_env` contains some fields + which not contained in `ProtoRuntimeEnv` but in `ProtoRuntimeEnvInfo`, + such as `eager_install`. This function will extract those fields from + `RuntimeEnv` and create a new `ProtoRuntimeEnvInfo`, and serialize it + into json format. + """ + from ray.runtime_env import RuntimeEnvConfig + + proto_runtime_env_info = ProtoRuntimeEnvInfo() + + if runtime_env.working_dir_uri(): + proto_runtime_env_info.uris.working_dir_uri = runtime_env.working_dir_uri() + if len(runtime_env.py_modules_uris()) > 0: + proto_runtime_env_info.uris.py_modules_uris[:] = runtime_env.py_modules_uris() + + # TODO(Catch-Bull): overload `__setitem__` for `RuntimeEnv`, change the + # runtime_env of all internal code from dict to RuntimeEnv. + + runtime_env_config = runtime_env.get("config") + if runtime_env_config is None: + runtime_env_config = RuntimeEnvConfig.default_config() + else: + runtime_env_config = RuntimeEnvConfig.parse_and_validate_runtime_env_config( + runtime_env_config + ) + + proto_runtime_env_info.runtime_env_config.CopyFrom( + runtime_env_config.build_proto_runtime_env_config() + ) + + # Normally, `RuntimeEnv` should guarantee the accuracy of field eager_install, + # but so far, the internal code has not completely prohibited direct + # modification of fields in RuntimeEnv, so we should check it for insurance. 
def parse_runtime_env(runtime_env: Optional[Union[Dict, "RuntimeEnv"]]):
    """Normalize a user-provided runtime_env into a ``RuntimeEnv`` object.

    Parse local pip/conda config files here. If we instead did it in
    .remote(), it would get run in the Ray Client server, which runs on
    a remote node where the files aren't available.

    Args:
        runtime_env: A runtime environment dict (``RuntimeEnv`` is a dict
            subclass, so instances pass through), or None/empty.

    Returns:
        A ``RuntimeEnv``, or None if no runtime env was given. The None is
        significant: .remote() uses it to know whether to fall back to the
        runtime_env specified in the @ray.remote decorator.

    Raises:
        TypeError: if ``runtime_env`` is a truthy non-dict value.
    """
    if not runtime_env:
        # Keep the result as None so .remote() can distinguish "unset".
        return None

    # Imported lazily so the no-op path above doesn't pay the import cost.
    from ray.runtime_env import RuntimeEnv

    if isinstance(runtime_env, dict):
        return RuntimeEnv(**runtime_env)

    # Bug fix: TypeError was previously given two comma-separated string
    # arguments, which rendered the message as a tuple; use a single string.
    raise TypeError(
        f"runtime_env must be dict or RuntimeEnv, but got: {type(runtime_env)}"
    )
def get_or_create_event_loop() -> asyncio.AbstractEventLoop:
    """Get a running async event loop if one exists, otherwise create one.

    This function serves as a proxy for the deprecating get_event_loop().
    It tries to get the running loop first, and if no running loop
    could be retrieved:
    - For python version < 3.10: it falls back to the get_event_loop call.
    - For python version >= 3.10: it uses the same python implementation
      of _get_event_loop() at asyncio/events.py.

    Ideally, one should use high level APIs like asyncio.run() with python
    version >= 3.7, if not possible, one should create and manage the event
    loops explicitly.

    Returns:
        The running loop if inside one, else the policy's current loop.
    """
    # Bug fix: the old check (`major >= 3 and minor >= 10`) would wrongly
    # take the legacy branch on a hypothetical Python 4.0; comparing the
    # version tuple handles all future versions correctly.
    if sys.version_info >= (3, 10):
        # This follows the implementation of the deprecating `get_event_loop`
        # in python3.10's asyncio. See python3.10/asyncio/events.py
        # _get_event_loop().
        try:
            loop = asyncio.get_running_loop()
            assert loop is not None
            return loop
        except RuntimeError as e:
            # No running loop; relying on the error message as for now to
            # differentiate runtime errors.
            assert "no running event loop" in str(e)
            return asyncio.get_event_loop_policy().get_event_loop()

    return asyncio.get_event_loop()
+ # https://stackoverflow.com/questions/2356399/tell-if-python-is-in-interactive-mode # noqa + if hasattr(sys, "ps1"): + prefix = "(interactive_shell) " + + return prefix + list2cmdline(curr.cmdline()) + except Exception: + return "unknown" + + +def _add_url_query_params(url: str, params: Dict[str, str]) -> str: + """Add params to the provided url as query parameters. + + If url already contains query parameters, they will be merged with params, with the + existing query parameters overriding any in params with the same parameter name. + + Args: + url: The URL to add query parameters to. + params: The query parameters to add. + + Returns: + URL with params added as query parameters. + """ + # Unquote URL first so we don't lose existing args. + url = unquote(url) + # Parse URL. + parsed_url = urlparse(url) + # Merge URL query string arguments dict with new params. + base_params = params + params = dict(parse_qsl(parsed_url.query)) + base_params.update(params) + # bool and dict values should be converted to json-friendly values. + base_params.update( + { + k: json.dumps(v) + for k, v in base_params.items() + if isinstance(v, (bool, dict)) + } + ) + + # Convert URL arguments to proper query string. + encoded_params = urlencode(base_params, doseq=True) + # Replace query string in parsed URL with updated query string. + parsed_url = parsed_url._replace(query=encoded_params) + # Convert back to URL. + return urlunparse(parsed_url) + + +def _add_creatable_buckets_param_if_s3_uri(uri: str) -> str: + """If the provided URI is an S3 URL, add allow_bucket_creation=true as a query + parameter. For pyarrow >= 9.0.0, this is required in order to allow + ``S3FileSystem.create_dir()`` to create S3 buckets. + + If the provided URI is not an S3 URL or if pyarrow < 9.0.0 is installed, we return + the URI unchanged. + + Args: + uri: The URI that we'll add the query parameter to, if it's an S3 URL. 
def _get_pyarrow_version() -> Optional[str]:
    """Return the installed pyarrow version string, caching the result.

    Returns None (and re-probes on the next call) if pyarrow is not
    installed or exposes no ``__version__``.
    """
    global _PYARROW_VERSION
    if _PYARROW_VERSION is not None:
        return _PYARROW_VERSION
    try:
        import pyarrow
    except ModuleNotFoundError:
        # pyarrow not installed, short-circuit.
        return None
    # Cache the version for subsequent calls (stays None if the attribute
    # is missing, matching the uncached behavior).
    _PYARROW_VERSION = getattr(pyarrow, "__version__", None)
    return _PYARROW_VERSION
+ """ + if threading.current_thread() == threading.main_thread(): + return cls() + else: + return contextlib.nullcontext() + + def _set_signal_received(self, signum, frame): + """SIGINT handler that defers the signal.""" + self.signal_received = True + + def _signal_monkey_patch(self, signum, handler): + """Monkey patch for signal.signal that defers the setting of new signal + handler after the DeferSigint context exits.""" + # Only handle it in the main thread because if setting a handler in a non-main + # thread, signal.signal will raise an error because Python doesn't allow it. + if ( + threading.current_thread() == threading.main_thread() + and signum == signal.SIGINT + ): + orig_sigint_handler = self.overridden_sigint_handler + self.overridden_sigint_handler = handler + return orig_sigint_handler + return self.orig_signal(signum, handler) + + def __enter__(self): + # Save original SIGINT handler for later restoration. + self.overridden_sigint_handler = signal.getsignal(signal.SIGINT) + # Set SIGINT signal handler that defers the signal. + signal.signal(signal.SIGINT, self._set_signal_received) + # Monkey patch signal.signal to raise an error if a SIGINT handler is registered + # within the context. + self.orig_signal = signal.signal + signal.signal = self._signal_monkey_patch + return self + + def __exit__(self, exc_type, exc, exc_tb): + assert self.overridden_sigint_handler is not None + assert self.orig_signal is not None + # Restore original signal.signal function. + signal.signal = self.orig_signal + # Restore overridden SIGINT handler. + signal.signal(signal.SIGINT, self.overridden_sigint_handler) + if exc_type is None and self.signal_received: + # No exception raised in context, call the original SIGINT handler. + # By default, this means raising KeyboardInterrupt. + self.overridden_sigint_handler(signal.SIGINT, None) + else: + # If exception was raised in context, returning False will cause it to be + # reraised. 
def try_import_each_module(module_names_to_import: List[str]) -> None:
    """Best-effort import of each named module, logging (not raising) failures.

    This is used by the Python default_worker.py to preload modules.
    """
    for name in module_names_to_import:
        try:
            importlib.import_module(name)
        except ImportError:
            # Preloading is best-effort: record the traceback and move on.
            logger.exception(f'Failed to preload the module "{name}"')
+ """ + if not env_vars: + return + + for key, value in env_vars.items(): + expanded = os.path.expandvars(value) + # Replace non-existant env vars with an empty string. + result = re.sub(r"\$\{[A-Z0-9_]+\}", "", expanded) + os.environ[key] = result + + +def parse_node_labels_json( + labels_json: str, cli_logger, cf, command_arg="--labels" +) -> Dict[str, str]: + try: + labels = json.loads(labels_json) + if not isinstance(labels, dict): + raise ValueError( + "The format after deserialization is not a key-value pair map" + ) + for key, value in labels.items(): + if not isinstance(key, str): + raise ValueError("The key is not string type.") + if not isinstance(value, str): + raise ValueError(f'The value of the "{key}" is not string type') + except Exception as e: + cli_logger.abort( + "`{}` is not a valid JSON string, detail error:{}" + "Valid values look like this: `{}`", + cf.bold(f"{command_arg}={labels_json}"), + str(e), + cf.bold(f'{command_arg}=\'{{"gpu_type": "A100", "region": "us"}}\''), + ) + return labels + + +def validate_node_labels(labels: Dict[str, str]): + if labels is None: + return + for key in labels.keys(): + if key.startswith(ray_constants.RAY_DEFAULT_LABEL_KEYS_PREFIX): + raise ValueError( + f"Custom label keys `{key}` cannot start with the prefix " + f"`{ray_constants.RAY_DEFAULT_LABEL_KEYS_PREFIX}`. " + f"This is reserved for Ray defined labels." + ) + + +def parse_pg_formatted_resources_to_original( + pg_formatted_resources: Dict[str, float] +) -> Dict[str, float]: + original_resources = {} + + for key, value in pg_formatted_resources.items(): + result = PLACEMENT_GROUP_INDEXED_BUNDLED_RESOURCE_PATTERN.match(key) + if result and len(result.groups()) == 3: + # Filter out resources that have bundle_group_[pg_id] since + # it is an implementation detail. + # This resource is automatically added to the resource + # request for all tasks that require placement groups. 
def load_class(path):
    """Load a class (or any module attribute) at runtime given its full path.

    Example of the path: mypkg.mysubpkg.myclass

    Everything before the last dot is imported as a module and the final
    component is looked up as an attribute on it.

    Args:
        path: Dotted path of the form ``<module>.<attribute>``.

    Returns:
        The attribute named by the last path component.

    Raises:
        ValueError: If ``path`` contains no dot separator.
    """
    module_name, separator, attribute_name = path.rpartition(".")
    if not separator:
        raise ValueError("You need to pass a valid path like mymodule.provider_class")
    return getattr(importlib.import_module(module_name), attribute_name)
new file mode 100644 index 0000000000000000000000000000000000000000..225e904ae4754411afb12f05ebfcf95fdfadf127 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/worker.py @@ -0,0 +1,3576 @@ +import atexit +import faulthandler +import functools +import inspect +import io +import json +import logging +import os +import sys +import threading +import time +import traceback +import urllib +import warnings +from abc import ABCMeta, abstractmethod +from collections.abc import Mapping +from contextlib import contextmanager +from dataclasses import dataclass +from typing import ( + IO, + Any, + AnyStr, + Callable, + Dict, + Generic, + Iterator, + List, + Optional, + Sequence, + Tuple, + TypeVar, + Union, + overload, +) +from urllib.parse import urlparse + +import colorama +import setproctitle + +from typing import Literal, Protocol + +import ray +import ray._private.worker +import ray._private.node +import ray._private.parameter +import ray._private.profiling as profiling +import ray._private.ray_constants as ray_constants +import ray._private.serialization as serialization +import ray._private.services as services +import ray._private.state +import ray._private.storage as storage + +from ray._private.ray_logging.logging_config import LoggingConfig + +# Ray modules +import ray.actor +import ray.cloudpickle as pickle # noqa +import ray.job_config +import ray.remote_function +from ray import ActorID, JobID, Language, ObjectRef +from ray._raylet import raise_sys_exit_with_custom_error_message +from ray._raylet import ObjectRefGenerator, TaskID +from ray.runtime_env.runtime_env import _merge_runtime_env +from ray._private import ray_option_utils +from ray._private.client_mode_hook import client_mode_hook +from ray._private.function_manager import FunctionActorManager + +from ray._private.inspect_util import is_cython +from ray._private.ray_logging import ( + global_worker_stdstream_dispatcher, + stdout_deduplicator, + stderr_deduplicator, + setup_logger, +) 
+from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR +from ray._private.runtime_env.py_modules import upload_py_modules_if_needed +from ray._private.runtime_env.working_dir import upload_working_dir_if_needed +from ray._private.runtime_env.setup_hook import ( + upload_worker_process_setup_hook_if_needed, +) +from ray._private.storage import _load_class +from ray._private.utils import get_ray_doc_version +from ray.exceptions import ObjectStoreFullError, RayError, RaySystemError, RayTaskError +from ray.experimental.internal_kv import ( + _initialize_internal_kv, + _internal_kv_get, + _internal_kv_initialized, + _internal_kv_reset, +) +from ray.experimental import tqdm_ray +from ray.experimental.compiled_dag_ref import CompiledDAGRef +from ray.experimental.tqdm_ray import RAY_TQDM_MAGIC +from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI +from ray.util.debug import log_once +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from ray.util.tracing.tracing_helper import _import_from_string +from ray.widgets import Template +from ray.widgets.util import repr_with_fallback + +SCRIPT_MODE = 0 +WORKER_MODE = 1 +LOCAL_MODE = 2 +SPILL_WORKER_MODE = 3 +RESTORE_WORKER_MODE = 4 + +# Logger for this module. It should be configured at the entry point +# into the program using Ray. Ray provides a default configuration at +# entry/init points. +logger = logging.getLogger(__name__) + + +T = TypeVar("T") +T0 = TypeVar("T0") +T1 = TypeVar("T1") +T2 = TypeVar("T2") +T3 = TypeVar("T3") +T4 = TypeVar("T4") +T5 = TypeVar("T5") +T6 = TypeVar("T6") +T7 = TypeVar("T7") +T8 = TypeVar("T8") +T9 = TypeVar("T9") +R = TypeVar("R") + +DAGNode = TypeVar("DAGNode") + + +# Only used for type annotations as a placeholder +Undefined: Any = object() + + +# TypeVar for self-referential generics in `RemoteFunction[N]`. 
+RF = TypeVar("RF", bound="HasOptions") + + +class HasOptions(Protocol): + def options(self: RF, **task_options) -> RF: + ... + + +class RemoteFunctionNoArgs(HasOptions, Generic[R]): + def __init__(self, function: Callable[[], R]) -> None: + pass + + def remote( + self, + ) -> "ObjectRef[R]": + ... + + def bind( + self, + ) -> "DAGNode[R]": + ... + + +class RemoteFunction0(HasOptions, Generic[R, T0]): + def __init__(self, function: Callable[[T0], R]) -> None: + pass + + def remote( + self, + __arg0: "Union[T0, ObjectRef[T0]]", + ) -> "ObjectRef[R]": + ... + + def bind( + self, + __arg0: "Union[T0, DAGNode[T0]]", + ) -> "DAGNode[R]": + ... + + +class RemoteFunction1(HasOptions, Generic[R, T0, T1]): + def __init__(self, function: Callable[[T0, T1], R]) -> None: + pass + + def remote( + self, + __arg0: "Union[T0, ObjectRef[T0]]", + __arg1: "Union[T1, ObjectRef[T1]]", + ) -> "ObjectRef[R]": + ... + + def bind( + self, + __arg0: "Union[T0, DAGNode[T0]]", + __arg1: "Union[T1, DAGNode[T1]]", + ) -> "DAGNode[R]": + ... + + +class RemoteFunction2(HasOptions, Generic[R, T0, T1, T2]): + def __init__(self, function: Callable[[T0, T1, T2], R]) -> None: + pass + + def remote( + self, + __arg0: "Union[T0, ObjectRef[T0]]", + __arg1: "Union[T1, ObjectRef[T1]]", + __arg2: "Union[T2, ObjectRef[T2]]", + ) -> "ObjectRef[R]": + ... + + def bind( + self, + __arg0: "Union[T0, DAGNode[T0]]", + __arg1: "Union[T1, DAGNode[T1]]", + __arg2: "Union[T2, DAGNode[T2]]", + ) -> "DAGNode[R]": + ... + + +class RemoteFunction3(HasOptions, Generic[R, T0, T1, T2, T3]): + def __init__(self, function: Callable[[T0, T1, T2, T3], R]) -> None: + pass + + def remote( + self, + __arg0: "Union[T0, ObjectRef[T0]]", + __arg1: "Union[T1, ObjectRef[T1]]", + __arg2: "Union[T2, ObjectRef[T2]]", + __arg3: "Union[T3, ObjectRef[T3]]", + ) -> "ObjectRef[R]": + ... 
+ + def bind( + self, + __arg0: "Union[T0, DAGNode[T0]]", + __arg1: "Union[T1, DAGNode[T1]]", + __arg2: "Union[T2, DAGNode[T2]]", + __arg3: "Union[T3, DAGNode[T3]]", + ) -> "DAGNode[R]": + ... + + +class RemoteFunction4(HasOptions, Generic[R, T0, T1, T2, T3, T4]): + def __init__(self, function: Callable[[T0, T1, T2, T3, T4], R]) -> None: + pass + + def remote( + self, + __arg0: "Union[T0, ObjectRef[T0]]", + __arg1: "Union[T1, ObjectRef[T1]]", + __arg2: "Union[T2, ObjectRef[T2]]", + __arg3: "Union[T3, ObjectRef[T3]]", + __arg4: "Union[T4, ObjectRef[T4]]", + ) -> "ObjectRef[R]": + ... + + def bind( + self, + __arg0: "Union[T0, DAGNode[T0]]", + __arg1: "Union[T1, DAGNode[T1]]", + __arg2: "Union[T2, DAGNode[T2]]", + __arg3: "Union[T3, DAGNode[T3]]", + __arg4: "Union[T4, DAGNode[T4]]", + ) -> "DAGNode[R]": + ... + + +class RemoteFunction5(HasOptions, Generic[R, T0, T1, T2, T3, T4, T5]): + def __init__(self, function: Callable[[T0, T1, T2, T3, T4, T5], R]) -> None: + pass + + def remote( + self, + __arg0: "Union[T0, ObjectRef[T0]]", + __arg1: "Union[T1, ObjectRef[T1]]", + __arg2: "Union[T2, ObjectRef[T2]]", + __arg3: "Union[T3, ObjectRef[T3]]", + __arg4: "Union[T4, ObjectRef[T4]]", + __arg5: "Union[T5, ObjectRef[T5]]", + ) -> "ObjectRef[R]": + ... + + def bind( + self, + __arg0: "Union[T0, DAGNode[T0]]", + __arg1: "Union[T1, DAGNode[T1]]", + __arg2: "Union[T2, DAGNode[T2]]", + __arg3: "Union[T3, DAGNode[T3]]", + __arg4: "Union[T4, DAGNode[T4]]", + __arg5: "Union[T5, DAGNode[T5]]", + ) -> "DAGNode[R]": + ... 
class RemoteFunction6(HasOptions, Generic[R, T0, T1, T2, T3, T4, T5, T6]):
    """Typing stub for a remote function of seven positional arguments.

    ``R`` is the return type and ``T0``-``T6`` the argument types. All method
    bodies are stubs (``pass`` / ``...``): this class only exists so static
    type checkers can give ``.remote(...)`` and ``.bind(...)`` precise
    signatures.
    """

    def __init__(self, function: Callable[[T0, T1, T2, T3, T4, T5, T6], R]) -> None:
        pass

    def remote(
        self,
        __arg0: "Union[T0, ObjectRef[T0]]",
        __arg1: "Union[T1, ObjectRef[T1]]",
        __arg2: "Union[T2, ObjectRef[T2]]",
        __arg3: "Union[T3, ObjectRef[T3]]",
        __arg4: "Union[T4, ObjectRef[T4]]",
        __arg5: "Union[T5, ObjectRef[T5]]",
        __arg6: "Union[T6, ObjectRef[T6]]",
    ) -> "ObjectRef[R]":
        """Invoke the remote function; each argument may be a plain value or an ObjectRef."""
        ...

    def bind(
        self,
        __arg0: "Union[T0, DAGNode[T0]]",
        __arg1: "Union[T1, DAGNode[T1]]",
        __arg2: "Union[T2, DAGNode[T2]]",
        __arg3: "Union[T3, DAGNode[T3]]",
        __arg4: "Union[T4, DAGNode[T4]]",
        __arg5: "Union[T5, DAGNode[T5]]",
        __arg6: "Union[T6, DAGNode[T6]]",
    ) -> "DAGNode[R]":
        """Bind the function into a DAG; each argument may be a plain value or a DAGNode."""
        ...


class RemoteFunction7(HasOptions, Generic[R, T0, T1, T2, T3, T4, T5, T6, T7]):
    """Typing stub for a remote function of eight positional arguments.

    ``R`` is the return type and ``T0``-``T7`` the argument types. All method
    bodies are stubs (``pass`` / ``...``); see :class:`RemoteFunction6`.
    """

    def __init__(self, function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7], R]) -> None:
        pass

    def remote(
        self,
        __arg0: "Union[T0, ObjectRef[T0]]",
        __arg1: "Union[T1, ObjectRef[T1]]",
        __arg2: "Union[T2, ObjectRef[T2]]",
        __arg3: "Union[T3, ObjectRef[T3]]",
        __arg4: "Union[T4, ObjectRef[T4]]",
        __arg5: "Union[T5, ObjectRef[T5]]",
        __arg6: "Union[T6, ObjectRef[T6]]",
        __arg7: "Union[T7, ObjectRef[T7]]",
    ) -> "ObjectRef[R]":
        """Invoke the remote function; each argument may be a plain value or an ObjectRef."""
        ...

    def bind(
        self,
        __arg0: "Union[T0, DAGNode[T0]]",
        __arg1: "Union[T1, DAGNode[T1]]",
        __arg2: "Union[T2, DAGNode[T2]]",
        __arg3: "Union[T3, DAGNode[T3]]",
        __arg4: "Union[T4, DAGNode[T4]]",
        __arg5: "Union[T5, DAGNode[T5]]",
        __arg6: "Union[T6, DAGNode[T6]]",
        __arg7: "Union[T7, DAGNode[T7]]",
    ) -> "DAGNode[R]":
        """Bind the function into a DAG; each argument may be a plain value or a DAGNode."""
        ...
+ + +class RemoteFunction8(HasOptions, Generic[R, T0, T1, T2, T3, T4, T5, T6, T7, T8]): + def __init__( + self, function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8], R] + ) -> None: + pass + + def remote( + self, + __arg0: "Union[T0, ObjectRef[T0]]", + __arg1: "Union[T1, ObjectRef[T1]]", + __arg2: "Union[T2, ObjectRef[T2]]", + __arg3: "Union[T3, ObjectRef[T3]]", + __arg4: "Union[T4, ObjectRef[T4]]", + __arg5: "Union[T5, ObjectRef[T5]]", + __arg6: "Union[T6, ObjectRef[T6]]", + __arg7: "Union[T7, ObjectRef[T7]]", + __arg8: "Union[T8, ObjectRef[T8]]", + ) -> "ObjectRef[R]": + ... + + def bind( + self, + __arg0: "Union[T0, DAGNode[T0]]", + __arg1: "Union[T1, DAGNode[T1]]", + __arg2: "Union[T2, DAGNode[T2]]", + __arg3: "Union[T3, DAGNode[T3]]", + __arg4: "Union[T4, DAGNode[T4]]", + __arg5: "Union[T5, DAGNode[T5]]", + __arg6: "Union[T6, DAGNode[T6]]", + __arg7: "Union[T7, DAGNode[T7]]", + __arg8: "Union[T8, DAGNode[T8]]", + ) -> "DAGNode[R]": + ... + + +class RemoteFunction9(HasOptions, Generic[R, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]): + def __init__( + self, function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], R] + ) -> None: + pass + + def remote( + self, + __arg0: "Union[T0, ObjectRef[T0]]", + __arg1: "Union[T1, ObjectRef[T1]]", + __arg2: "Union[T2, ObjectRef[T2]]", + __arg3: "Union[T3, ObjectRef[T3]]", + __arg4: "Union[T4, ObjectRef[T4]]", + __arg5: "Union[T5, ObjectRef[T5]]", + __arg6: "Union[T6, ObjectRef[T6]]", + __arg7: "Union[T7, ObjectRef[T7]]", + __arg8: "Union[T8, ObjectRef[T8]]", + __arg9: "Union[T9, ObjectRef[T9]]", + ) -> "ObjectRef[R]": + ... 
+ + def bind( + self, + __arg0: "Union[T0, DAGNode[T0]]", + __arg1: "Union[T1, DAGNode[T1]]", + __arg2: "Union[T2, DAGNode[T2]]", + __arg3: "Union[T3, DAGNode[T3]]", + __arg4: "Union[T4, DAGNode[T4]]", + __arg5: "Union[T5, DAGNode[T5]]", + __arg6: "Union[T6, DAGNode[T6]]", + __arg7: "Union[T7, DAGNode[T7]]", + __arg8: "Union[T8, DAGNode[T8]]", + __arg9: "Union[T9, DAGNode[T9]]", + ) -> "DAGNode[R]": + ... + + +# Visible for testing. +def _unhandled_error_handler(e: Exception): + logger.error( + f"Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): {e}" + ) + + +class Worker: + """A class used to define the control flow of a worker process. + + Note: + The methods in this class are considered unexposed to the user. The + functions outside of this class are considered exposed. + + Attributes: + node (ray._private.node.Node): The node this worker is attached to. + mode: The mode of the worker. One of SCRIPT_MODE, LOCAL_MODE, and + WORKER_MODE. + """ + + def __init__(self): + """Initialize a Worker object.""" + self.node = None + self.mode = None + self.actors = {} + # When the worker is constructed. Record the original value of the + # (CUDA_VISIBLE_DEVICES, ONEAPI_DEVICE_SELECTOR, ROCR_VISIBLE_DEVICES, + # NEURON_RT_VISIBLE_CORES, TPU_VISIBLE_CHIPS, ..) environment variables. + self.original_visible_accelerator_ids = ( + ray._private.utils.get_visible_accelerator_ids() + ) + # A dictionary that maps from driver id to SerializationContext + # TODO: clean up the SerializationContext once the job finished. + self.serialization_context_map = {} + self.function_actor_manager = FunctionActorManager(self) + # This event is checked regularly by all of the threads so that they + # know when to exit. + self.threads_stopped = threading.Event() + # If this is set, the next .remote call should drop into the + # debugger, at the specified breakpoint ID. 
+ self.debugger_breakpoint = b"" + # If this is set, ray.get calls invoked on the object ID returned + # by the worker should drop into the debugger at the specified + # breakpoint ID. + self.debugger_get_breakpoint = b"" + # If True, make the debugger external to the node this worker is + # running on. + self.ray_debugger_external = False + self._load_code_from_local = False + # Opened file descriptor to stdout/stderr for this python worker. + self._enable_record_actor_task_log = ( + ray_constants.RAY_ENABLE_RECORD_ACTOR_TASK_LOGGING + ) + self._out_file = None + self._err_file = None + # Create the lock here because the serializer will use it before + # initializing Ray. + self.lock = threading.RLock() + # By default, don't show logs from other drivers. This is set to true by Serve + # in order to stream logs from the controller and replica actors across + # different drivers that connect to the same Serve instance. + # See https://github.com/ray-project/ray/pull/35070. + self._filter_logs_by_job = True + # the debugger port for this worker + self._debugger_port = None + # Cache the job id from initialize_job_config() to optimize lookups. + # This is on the critical path of ray.get()/put() calls. + self._cached_job_id = None + # Indicates whether the worker is connected to the Ray cluster. + # It should be set to True in `connect` and False in `disconnect`. 
+ self._is_connected: bool = False + + @property + def connected(self): + """bool: True if Ray has been started and False otherwise.""" + return self._is_connected + + def set_is_connected(self, is_connected: bool): + self._is_connected = is_connected + + @property + def node_ip_address(self): + self.check_connected() + return self.node.node_ip_address + + @property + def load_code_from_local(self): + self.check_connected() + return self._load_code_from_local + + @property + def current_job_id(self): + if self._cached_job_id is not None: + return self._cached_job_id + elif hasattr(self, "core_worker"): + return self.core_worker.get_current_job_id() + return JobID.nil() + + @property + def actor_id(self): + if hasattr(self, "core_worker"): + return self.core_worker.get_actor_id() + return ActorID.nil() + + @property + def actor_name(self): + if hasattr(self, "core_worker"): + return self.core_worker.get_actor_name().decode("utf-8") + return None + + @property + def current_task_id(self): + return self.core_worker.get_current_task_id() + + @property + def current_task_name(self): + return self.core_worker.get_current_task_name() + + @property + def current_task_function_name(self): + return self.core_worker.get_current_task_function_name() + + @property + def current_node_id(self): + return self.core_worker.get_current_node_id() + + @property + def task_depth(self): + return self.core_worker.get_task_depth() + + @property + def namespace(self): + return self.core_worker.get_job_config().ray_namespace + + @property + def placement_group_id(self): + return self.core_worker.get_placement_group_id() + + @property + def worker_id(self): + return self.core_worker.get_worker_id().binary() + + @property + def should_capture_child_tasks_in_placement_group(self): + return self.core_worker.should_capture_child_tasks_in_placement_group() + + @property + def current_cluster_and_job(self): + """Get the current session index and job id as pair.""" + assert 
isinstance(self.node.cluster_id, ray.ClusterID) + assert isinstance(self.current_job_id, ray.JobID) + return self.node.cluster_id, self.current_job_id + + @property + def runtime_env(self): + """Get the runtime env in json format""" + return self.core_worker.get_current_runtime_env() + + @property + def debugger_port(self): + """Get the debugger port for this worker""" + worker_id = self.core_worker.get_worker_id() + return ray._private.state.get_worker_debugger_port(worker_id) + + @property + def job_logging_config(self): + """Get the job's logging config for this worker""" + if not hasattr(self, "core_worker"): + return None + job_config = self.core_worker.get_job_config() + if not job_config.serialized_py_logging_config: + return None + logging_config = pickle.loads(job_config.serialized_py_logging_config) + return logging_config + + def set_debugger_port(self, port): + worker_id = self.core_worker.get_worker_id() + ray._private.state.update_worker_debugger_port(worker_id, port) + + def set_cached_job_id(self, job_id): + """Set the cached job id to speed `current_job_id()`.""" + self._cached_job_id = job_id + + @contextmanager + def task_paused_by_debugger(self): + """Use while the task is paused by debugger""" + try: + self.core_worker.update_task_is_debugger_paused( + ray.get_runtime_context()._get_current_task_id(), True + ) + yield + finally: + self.core_worker.update_task_is_debugger_paused( + ray.get_runtime_context()._get_current_task_id(), False + ) + + @contextmanager + def worker_paused_by_debugger(self): + """ + Updates the worker num paused threads when the worker is paused by debugger + """ + try: + worker_id = self.core_worker.get_worker_id() + ray._private.state.update_worker_num_paused_threads(worker_id, 1) + yield + finally: + ray._private.state.update_worker_num_paused_threads(worker_id, -1) + + def set_err_file(self, err_file=Optional[IO[AnyStr]]) -> None: + """Set the worker's err file where stderr is redirected to""" + self._err_file = 
err_file + + def set_out_file(self, out_file=Optional[IO[AnyStr]]) -> None: + """Set the worker's out file where stdout is redirected to""" + self._out_file = out_file + + def record_task_log_start(self, task_id: TaskID, attempt_number: int): + """Record the task log info when task starts executing for + non concurrent actor tasks.""" + if not self._enable_record_actor_task_log and not self.actor_id.is_nil(): + # We are not recording actor task log if not enabled explicitly. + # Recording actor task log is expensive and should be enabled only + # when needed. + # https://github.com/ray-project/ray/issues/35598 + return + + if not hasattr(self, "core_worker"): + return + + self.core_worker.record_task_log_start( + task_id, + attempt_number, + self.get_out_file_path(), + self.get_err_file_path(), + self.get_current_out_offset(), + self.get_current_err_offset(), + ) + + def record_task_log_end(self, task_id: TaskID, attempt_number: int): + """Record the task log info when task finishes executing for + non concurrent actor tasks.""" + if not self._enable_record_actor_task_log and not self.actor_id.is_nil(): + # We are not recording actor task log if not enabled explicitly. + # Recording actor task log is expensive and should be enabled only + # when needed. 
+ # https://github.com/ray-project/ray/issues/35598 + return + + if not hasattr(self, "core_worker"): + return + + self.core_worker.record_task_log_end( + task_id, + attempt_number, + self.get_current_out_offset(), + self.get_current_err_offset(), + ) + + def get_err_file_path(self) -> str: + """Get the err log file path""" + return self._err_file.name if self._err_file is not None else "" + + def get_out_file_path(self) -> str: + """Get the out log file path""" + return self._out_file.name if self._out_file is not None else "" + + def get_current_out_offset(self) -> int: + """Get the current offset of the out file if seekable, else 0""" + if self._out_file is not None and self._out_file.seekable(): + return self._out_file.tell() + return 0 + + def get_current_err_offset(self) -> int: + """Get the current offset of the err file if seekable, else 0""" + if self._err_file is not None and self._err_file.seekable(): + return self._err_file.tell() + return 0 + + def get_serialization_context(self): + """Get the SerializationContext of the job that this worker is processing. + + Returns: + The serialization context of the given job. + """ + # This function needs to be protected by a lock, because it will be + # called by`register_class_for_serialization`, as well as the import + # thread, from different threads. Also, this function will recursively + # call itself, so we use RLock here. + job_id = self.current_job_id + context_map = self.serialization_context_map + with self.lock: + if job_id not in context_map: + # The job ID is nil before initializing Ray. + if JobID.nil() in context_map: + # Transfer the serializer context used before initializing Ray. + context_map[job_id] = context_map.pop(JobID.nil()) + else: + context_map[job_id] = serialization.SerializationContext(self) + return context_map[job_id] + + def check_connected(self): + """Check if the worker is connected. + + Raises: + Exception: An exception is raised if the worker is not connected. 
+ """ + if not self.connected: + raise RaySystemError( + "Ray has not been started yet. You can start Ray with 'ray.init()'." + ) + + def set_mode(self, mode): + """Set the mode of the worker. + + The mode SCRIPT_MODE should be used if this Worker is a driver that is + being run as a Python script or interactively in a shell. It will print + information about task failures. + + The mode WORKER_MODE should be used if this Worker is not a driver. It + will not print information about tasks. + + The mode LOCAL_MODE should be used if this Worker is a driver and if + you want to run the driver in a manner equivalent to serial Python for + debugging purposes. It will not send remote function calls to the + scheduler and will instead execute them in a blocking fashion. + + Args: + mode: One of SCRIPT_MODE, WORKER_MODE, and LOCAL_MODE. + """ + self.mode = mode + + def set_load_code_from_local(self, load_code_from_local): + self._load_code_from_local = load_code_from_local + + def put_object( + self, + value: Any, + object_ref: Optional["ray.ObjectRef"] = None, + owner_address: Optional[str] = None, + _is_experimental_channel: bool = False, + ): + """Put value in the local object store with object reference `object_ref`. + + This assumes that the value for `object_ref` has not yet been placed in + the local object store. If the plasma store is full, the worker will + automatically retry up to DEFAULT_PUT_OBJECT_RETRIES times. Each + retry will delay for an exponentially doubling amount of time, + starting with DEFAULT_PUT_OBJECT_DELAY. After this, exception + will be raised. + + Args: + value: The value to put in the object store. + object_ref: The object ref of the value to be + put. If None, one will be generated. + owner_address: The serialized address of object's owner. + _is_experimental_channel: An experimental flag for mutable + objects. If True, then the returned object will not have a + valid value. 
The object must be written to using the + ray.experimental.channel API before readers can read. + + Returns: + ObjectRef: The object ref the object was put under. + + Raises: + ray.exceptions.ObjectStoreFullError: This is raised if the attempt + to store the object fails because the object store is full even + after multiple retries. + """ + # Make sure that the value is not an object ref. + if isinstance(value, ObjectRef): + raise TypeError( + "Calling 'put' on an ray.ObjectRef is not allowed. " + "If you really want to do this, you can wrap the " + "ray.ObjectRef in a list and call 'put' on it." + ) + + if self.mode == LOCAL_MODE: + assert ( + object_ref is None + ), "Local Mode does not support inserting with an ObjectRef" + + try: + serialized_value = self.get_serialization_context().serialize(value) + except TypeError as e: + sio = io.StringIO() + ray.util.inspect_serializability(value, print_file=sio) + msg = ( + "Could not serialize the put value " + f"{repr(value)}:\n" + f"{sio.getvalue()}" + ) + raise TypeError(msg) from e + + # If the object is mutable, then the raylet should never read the + # object. Instead, clients will keep the object pinned. + pin_object = not _is_experimental_channel + + # This *must* be the first place that we construct this python + # ObjectRef because an entry with 0 local references is created when + # the object is Put() in the core worker, expecting that this python + # reference will be created. If another reference is created and + # removed before this one, it will corrupt the state in the + # reference counter. + return ray.ObjectRef( + self.core_worker.put_serialized_object_and_increment_local_ref( + serialized_value, + object_ref=object_ref, + pin_object=pin_object, + owner_address=owner_address, + _is_experimental_channel=_is_experimental_channel, + ), + # The initial local reference is already acquired internally. 
+ skip_adding_local_ref=True, + ) + + def raise_errors(self, data_metadata_pairs, object_refs): + out = self.deserialize_objects(data_metadata_pairs, object_refs) + if "RAY_IGNORE_UNHANDLED_ERRORS" in os.environ: + return + for e in out: + _unhandled_error_handler(e) + + def deserialize_objects(self, data_metadata_pairs, object_refs): + # Function actor manager or the import thread may call pickle.loads + # at the same time which can lead to failed imports + # TODO: We may be better off locking on all imports or injecting a lock + # into pickle.loads (https://github.com/ray-project/ray/issues/16304) + with self.function_actor_manager.lock: + context = self.get_serialization_context() + return context.deserialize_objects(data_metadata_pairs, object_refs) + + def get_objects( + self, + object_refs: list, + timeout: Optional[float] = None, + return_exceptions: bool = False, + skip_deserialization: bool = False, + ): + """Get the values in the object store associated with the IDs. + + Return the values from the local object store for object_refs. This + will block until all the values for object_refs have been written to + the local object store. + + Args: + object_refs: A list of the object refs + whose values should be retrieved. + timeout: The maximum amount of time in + seconds to wait before returning. + return_exceptions: If any of the objects deserialize to an + Exception object, whether to return them as values in the + returned list. If False, then the first found exception will be + raised. + skip_deserialization: If true, only the buffer will be released and + the object associated with the buffer will not be deserailized. + Returns: + list: List of deserialized objects or None if skip_deserialization is True. + bytes: UUID of the debugger breakpoint we should drop + into or b"" if there is no breakpoint. + """ + # Make sure that the values are object refs. 
+ for object_ref in object_refs: + if not isinstance(object_ref, ObjectRef): + raise TypeError( + f"Attempting to call `get` on the value {object_ref}, " + "which is not an ray.ObjectRef." + ) + + timeout_ms = ( + int(timeout * 1000) if timeout is not None and timeout != -1 else -1 + ) + data_metadata_pairs: List[ + Tuple[ray._raylet.Buffer, bytes] + ] = self.core_worker.get_objects( + object_refs, + timeout_ms, + ) + + debugger_breakpoint = b"" + for data, metadata in data_metadata_pairs: + if metadata: + metadata_fields = metadata.split(b",") + if len(metadata_fields) >= 2 and metadata_fields[1].startswith( + ray_constants.OBJECT_METADATA_DEBUG_PREFIX + ): + debugger_breakpoint = metadata_fields[1][ + len(ray_constants.OBJECT_METADATA_DEBUG_PREFIX) : + ] + if skip_deserialization: + return None, debugger_breakpoint + + values = self.deserialize_objects(data_metadata_pairs, object_refs) + if not return_exceptions: + # Raise exceptions instead of returning them to the user. + for i, value in enumerate(values): + if isinstance(value, RayError): + if isinstance(value, ray.exceptions.ObjectLostError): + global_worker.core_worker.dump_object_store_memory_usage() + if isinstance(value, RayTaskError): + raise value.as_instanceof_cause() + else: + raise value + + return values, debugger_breakpoint + + def main_loop(self): + """The main loop a worker runs to receive and execute tasks.""" + + def sigterm_handler(signum, frame): + raise_sys_exit_with_custom_error_message( + "The process receives a SIGTERM.", exit_code=1 + ) + # Note: shutdown() function is called from atexit handler. 
+ + ray._private.utils.set_sigterm_handler(sigterm_handler) + self.core_worker.run_task_loop() + sys.exit(0) + + def print_logs(self): + """Prints log messages from workers on all nodes in the same job.""" + subscriber = self.gcs_log_subscriber + subscriber.subscribe() + exception_type = ray.exceptions.RpcError + localhost = services.get_node_ip_address() + try: + # Number of messages received from the last polling. When the batch + # size exceeds 100 and keeps increasing, the worker and the user + # probably will not be able to consume the log messages as rapidly + # as they are coming in. + # This is meaningful only for GCS subscriber. + last_polling_batch_size = 0 + job_id_hex = self.current_job_id.hex() + while True: + # Exit if we received a signal that we should stop. + if self.threads_stopped.is_set(): + return + + data = subscriber.poll() + # GCS subscriber only returns None on unavailability. + if data is None: + last_polling_batch_size = 0 + continue + + if ( + self._filter_logs_by_job + and data["job"] + and data["job"] != job_id_hex + ): + last_polling_batch_size = 0 + continue + + data["localhost"] = localhost + global_worker_stdstream_dispatcher.emit(data) + + lagging = 100 <= last_polling_batch_size < subscriber.last_batch_size + if lagging: + logger.warning( + "The driver may not be able to keep up with the " + "stdout/stderr of the workers. To avoid forwarding " + "logs to the driver, use " + "'ray.init(log_to_driver=False)'." + ) + + last_polling_batch_size = subscriber.last_batch_size + + except (OSError, exception_type) as e: + logger.error(f"print_logs: {e}") + finally: + # Close the pubsub client to avoid leaking file descriptors. + subscriber.close() + + def get_accelerator_ids_for_accelerator_resource( + self, resource_name: str, resource_regex: str + ) -> Union[List[str], List[int]]: + """Get the accelerator IDs that are assigned to the given accelerator resource. + + Args: + resource_name: The name of the resource. 
+ resource_regex: The regex of the resource. + + Returns: + (List[str]) The IDs that are assigned to the given resource pre-configured. + (List[int]) The IDs that are assigned to the given resource. + """ + resource_ids = self.core_worker.resource_ids() + assigned_ids = set() + # Handle both normal and placement group accelerator resources. + # Note: We should only get the accelerator ids from the placement + # group resource that does not contain the bundle index! + import re + + for resource, assignment in resource_ids.items(): + if resource == resource_name or re.match(resource_regex, resource): + for resource_id, _ in assignment: + assigned_ids.add(resource_id) + + # If the user had already set the environment variables + # (CUDA_VISIBLE_DEVICES, ONEAPI_DEVICE_SELECTOR, NEURON_RT_VISIBLE_CORES, + # TPU_VISIBLE_CHIPS, ..) then respect that in the sense that only IDs + # that appear in (CUDA_VISIBLE_DEVICES, ONEAPI_DEVICE_SELECTOR, + # ROCR_VISIBLE_DEVICES, NEURON_RT_VISIBLE_CORES, TPU_VISIBLE_CHIPS, ..) + # should be returned. + if self.original_visible_accelerator_ids.get(resource_name, None) is not None: + original_ids = self.original_visible_accelerator_ids[resource_name] + assigned_ids = {str(original_ids[i]) for i in assigned_ids} + # Give all accelerator ids in local_mode. + if self.mode == LOCAL_MODE: + if resource_name == ray_constants.GPU: + max_accelerators = self.node.get_resource_spec().num_gpus + else: + max_accelerators = self.node.get_resource_spec().resources.get( + resource_name, None + ) + if max_accelerators: + assigned_ids = original_ids[:max_accelerators] + return list(assigned_ids) + + +@PublicAPI +@client_mode_hook +def get_gpu_ids() -> Union[List[int], List[str]]: + """Get the IDs of the GPUs that are available to the worker. + + This method should only be called inside of a task or actor, and not a driver. 
+ + If the CUDA_VISIBLE_DEVICES environment variable was set when the worker + started up, then the IDs returned by this method will be a subset of the + IDs in CUDA_VISIBLE_DEVICES. If not, the IDs will fall in the range + [0, NUM_GPUS - 1], where NUM_GPUS is the number of GPUs that the node has. + + Returns: + A list of GPU IDs. + """ + worker = global_worker + worker.check_connected() + return worker.get_accelerator_ids_for_accelerator_resource( + ray_constants.GPU, f"^{ray_constants.GPU}_group_[0-9A-Za-z]+$" + ) + + +@Deprecated( + message="Use ray.get_runtime_context().get_assigned_resources() instead.", + warning=True, +) +def get_resource_ids(): + """Get the IDs of the resources that are available to the worker. + + Returns: + A dictionary mapping the name of a resource to a list of pairs, where + each pair consists of the ID of a resource and the fraction of that + resource reserved for this worker. + """ + worker = global_worker + worker.check_connected() + + if _mode() == LOCAL_MODE: + raise RuntimeError( + "ray._private.worker.get_resource_ids() does not work in local_mode." + ) + + return global_worker.core_worker.resource_ids() + + +@Deprecated(message="Use ray.init().address_info['webui_url'] instead.") +def get_dashboard_url(): + """Get the URL to access the Ray dashboard. + + Note that the URL does not specify which node the dashboard is on. + + Returns: + The URL of the dashboard as a string. + """ + if ray_constants.RAY_OVERRIDE_DASHBOARD_URL in os.environ: + return _remove_protocol_from_url( + os.environ.get(ray_constants.RAY_OVERRIDE_DASHBOARD_URL) + ) + else: + worker = global_worker + worker.check_connected() + return _global_node.webui_url + + +def _remove_protocol_from_url(url: Optional[str]) -> str: + """ + Helper function to remove protocol from URL if it exists. 
+ """ + if not url: + return url + parsed_url = urllib.parse.urlparse(url) + if parsed_url.scheme: + # Construct URL without protocol + scheme = f"{parsed_url.scheme}://" + return parsed_url.geturl().replace(scheme, "", 1) + return url + + +class BaseContext(metaclass=ABCMeta): + """ + Base class for RayContext and ClientContext + """ + + dashboard_url: Optional[str] + python_version: str + ray_version: str + + @abstractmethod + def disconnect(self): + """ + If this context is for directly attaching to a cluster, disconnect + will call ray.shutdown(). Otherwise, if the context is for a ray + client connection, the client will be disconnected. + """ + pass + + @abstractmethod + def __enter__(self): + pass + + @abstractmethod + def __exit__(self): + pass + + def _context_table_template(self): + if self.dashboard_url: + dashboard_row = Template("context_dashrow.html.j2").render( + dashboard_url="http://" + self.dashboard_url + ) + else: + dashboard_row = None + + return Template("context_table.html.j2").render( + python_version=self.python_version, + ray_version=self.ray_version, + dashboard_row=dashboard_row, + ) + + def _repr_html_(self): + return Template("context.html.j2").render( + context_logo=Template("context_logo.html.j2").render(), + context_table=self._context_table_template(), + ) + + @repr_with_fallback(["ipywidgets", "8"]) + def _get_widget_bundle(self, **kwargs) -> Dict[str, Any]: + """Get the mimebundle for the widget representation of the context. + + Args: + **kwargs: Passed to the _repr_mimebundle_() function for the widget + + Returns: + Dictionary ("mimebundle") of the widget representation of the context. + """ + import ipywidgets + + disconnect_button = ipywidgets.Button( + description="Disconnect", + disabled=False, + button_style="", + tooltip="Disconnect from the Ray cluster", + layout=ipywidgets.Layout(margin="auto 0px 0px 0px"), + ) + + def disconnect_callback(button): + button.disabled = True + button.description = "Disconnecting..." 
+ self.disconnect() + button.description = "Disconnected" + + disconnect_button.on_click(disconnect_callback) + left_content = ipywidgets.VBox( + [ + ipywidgets.HTML(Template("context_logo.html.j2").render()), + disconnect_button, + ], + layout=ipywidgets.Layout(), + ) + right_content = ipywidgets.HTML(self._context_table_template()) + widget = ipywidgets.HBox( + [left_content, right_content], layout=ipywidgets.Layout(width="100%") + ) + return widget._repr_mimebundle_(**kwargs) + + def _repr_mimebundle_(self, **kwargs): + bundle = self._get_widget_bundle(**kwargs) + + # Overwrite the widget html repr and default repr with those of the BaseContext + bundle.update({"text/html": self._repr_html_(), "text/plain": repr(self)}) + return bundle + + +@dataclass +class RayContext(BaseContext, Mapping): + """ + Context manager for attached drivers. + """ + + dashboard_url: Optional[str] + python_version: str + ray_version: str + ray_commit: str + + def __init__(self, address_info: Dict[str, Optional[str]]): + super().__init__() + self.dashboard_url = get_dashboard_url() + self.python_version = "{}.{}.{}".format(*sys.version_info[:3]) + self.ray_version = ray.__version__ + self.ray_commit = ray.__commit__ + self.address_info = address_info + + def __getitem__(self, key): + if log_once("ray_context_getitem"): + warnings.warn( + f'Accessing values through ctx["{key}"] is deprecated. ' + f'Use ctx.address_info["{key}"] instead.', + DeprecationWarning, + stacklevel=2, + ) + return self.address_info[key] + + def __len__(self): + if log_once("ray_context_len"): + warnings.warn("len(ctx) is deprecated. Use len(ctx.address_info) instead.") + return len(self.address_info) + + def __iter__(self): + if log_once("ray_context_len"): + warnings.warn( + "iter(ctx) is deprecated. Use iter(ctx.address_info) instead." 
            )
        return iter(self.address_info)

    def __enter__(self) -> "RayContext":
        return self

    def __exit__(self, *exc):
        ray.shutdown()

    def disconnect(self):
        # Include disconnect() to stay consistent with ClientContext
        ray.shutdown()


global_worker = Worker()
"""Worker: The global Worker object for this worker process.

We use a global Worker object to ensure that there is a single worker object
per worker process.
"""

_global_node = None
"""ray._private.node.Node: The global node object that is created by ray.init()."""


@PublicAPI
@client_mode_hook
def init(
    address: Optional[str] = None,
    *,
    num_cpus: Optional[int] = None,
    num_gpus: Optional[int] = None,
    resources: Optional[Dict[str, float]] = None,
    labels: Optional[Dict[str, str]] = None,
    object_store_memory: Optional[int] = None,
    local_mode: bool = False,
    ignore_reinit_error: bool = False,
    include_dashboard: Optional[bool] = None,
    dashboard_host: str = ray_constants.DEFAULT_DASHBOARD_IP,
    dashboard_port: Optional[int] = None,
    job_config: "ray.job_config.JobConfig" = None,
    configure_logging: bool = True,
    logging_level: int = ray_constants.LOGGER_LEVEL,
    logging_format: Optional[str] = None,
    logging_config: Optional[LoggingConfig] = None,
    log_to_driver: Optional[bool] = None,
    namespace: Optional[str] = None,
    runtime_env: Optional[Union[Dict[str, Any], "RuntimeEnv"]] = None,  # noqa: F821
    storage: Optional[str] = None,
    **kwargs,
) -> BaseContext:
    """
    Connect to an existing Ray cluster or start one and connect to it.

    This method handles two cases: either a Ray cluster already exists and we
    just attach this driver to it or we start all of the processes associated
    with a Ray cluster and attach to the newly started cluster.
    Note: This method overwrites the SIGTERM handler of the driver process.

    In most cases, it is enough to just call this method with no arguments.
+ This will autodetect an existing Ray cluster or start a new Ray instance if + no existing cluster is found: + + .. testcode:: + + ray.init() + + To explicitly connect to an existing local cluster, use this as follows. A + ConnectionError will be thrown if no existing local cluster is found. + + .. testcode:: + :skipif: True + + ray.init(address="auto") + + To connect to an existing remote cluster, use this as follows (substituting + in the appropriate address). Note the addition of "ray://" at the beginning + of the address. This requires `ray[client]`. + + .. testcode:: + :skipif: True + + ray.init(address="ray://123.45.67.89:10001") + + More details for starting and connecting to a remote cluster can be found + here: https://docs.ray.io/en/master/cluster/getting-started.html + + You can also define an environment variable called `RAY_ADDRESS` in + the same format as the `address` parameter to connect to an existing + cluster with ray.init() or ray.init(address="auto"). + + Args: + address: The address of the Ray cluster to connect to. The provided + address is resolved as follows: + 1. If a concrete address (e.g., localhost:) is provided, try to + connect to it. Concrete addresses can be prefixed with "ray://" to + connect to a remote cluster. For example, passing in the address + "ray://123.45.67.89:50005" will connect to the cluster at the given + address. + 2. If no address is provided, try to find an existing Ray instance + to connect to. This is done by first checking the environment + variable `RAY_ADDRESS`. If this is not defined, check the address + of the latest cluster started (found in + /tmp/ray/ray_current_cluster) if available. If this is also empty, + then start a new local Ray instance. + 3. If the provided address is "auto", then follow the same process + as above. However, if there is no existing cluster found, this will + throw a ConnectionError instead of starting a new local Ray + instance. + 4. 
If the provided address is "local", start a new local Ray + instance, even if there is already an existing local Ray instance. + num_cpus: Number of CPUs the user wishes to assign to each + raylet. By default, this is set based on virtual cores. + num_gpus: Number of GPUs the user wishes to assign to each + raylet. By default, this is set based on detected GPUs. + resources: A dictionary mapping the names of custom resources to the + quantities for them available. + labels: [Experimental] The key-value labels of the node. + object_store_memory: The amount of memory (in bytes) to start the + object store with. + By default, this is 30% + (ray_constants.DEFAULT_OBJECT_STORE_MEMORY_PROPORTION) + of available system memory capped by + the shm size and 200G (ray_constants.DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES) + but can be set higher. + local_mode: Deprecated: consider using the Ray Debugger instead. + ignore_reinit_error: If true, Ray suppresses errors from calling + ray.init() a second time. Ray won't be restarted. + include_dashboard: Boolean flag indicating whether or not to start the + Ray dashboard, which displays the status of the Ray + cluster. If this argument is None, then the UI will be started if + the relevant dependencies are present. + dashboard_host: The host to bind the dashboard server to. Can either be + localhost (127.0.0.1) or 0.0.0.0 (available from all interfaces). + By default, this is set to localhost to prevent access from + external machines. + dashboard_port(int, None): The port to bind the dashboard server to. + Defaults to 8265 and Ray will automatically find a free port if + 8265 is not available. + job_config (ray.job_config.JobConfig): The job configuration. + configure_logging: True (default) if configuration of logging is + allowed here. Otherwise, the user may want to configure it + separately. + logging_level: Logging level for the "ray" logger of the driver process, + defaults to logging.INFO. 
Ignored unless "configure_logging" is true. + logging_format: Logging format for the "ray" logger of the driver process, + defaults to a string containing a timestamp, filename, line number, and + message. See the source file ray_constants.py for details. Ignored unless + "configure_logging" is true. + logging_config: [Experimental] Logging configuration will be applied to the + root loggers for both the driver process and all worker processes belonging + to the current job. See :class:`~ray.LoggingConfig` for details. + log_to_driver: If true, the output from all of the worker + processes on all nodes will be directed to the driver. + namespace: A namespace is a logical grouping of jobs and named actors. + runtime_env: The runtime environment to use + for this job (see :ref:`runtime-environments` for details). + storage: [Experimental] Specify a URI for persistent cluster-wide storage. + This storage path must be accessible by all nodes of the cluster, otherwise + an error will be raised. This option can also be specified as the + RAY_STORAGE env var. + _enable_object_reconstruction: If True, when an object stored in + the distributed plasma store is lost due to node failure, Ray will + attempt to reconstruct the object by re-executing the task that + created the object. Arguments to the task will be recursively + reconstructed. If False, then ray.ObjectLostError will be + thrown. + _redis_max_memory: Redis max memory. + _plasma_directory: Override the plasma mmap file directory. + _node_ip_address: The IP address of the node that we are on. + _driver_object_store_memory: Deprecated. + _memory: Amount of reservable memory resource in bytes rounded + down to the nearest integer. + _redis_username: Prevents external clients without the username + from connecting to Redis if provided. + _redis_password: Prevents external clients without the password + from connecting to Redis if provided. 
+ _temp_dir: If provided, specifies the root temporary + directory for the Ray process. Must be an absolute path. Defaults to an + OS-specific conventional location, e.g., "/tmp/ray". + _metrics_export_port: Port number Ray exposes system metrics + through a Prometheus endpoint. It is currently under active + development, and the API is subject to change. + _system_config: Configuration for overriding + RayConfig defaults. For testing purposes ONLY. + _tracing_startup_hook: If provided, turns on and sets up tracing + for Ray. Must be the name of a function that takes no arguments and + sets up a Tracer Provider, Remote Span Processors, and + (optional) additional instruments. See more at + docs.ray.io/tracing.html. It is currently under active development, + and the API is subject to change. + _node_name: User-provided node name or identifier. Defaults to + the node IP address. + + Returns: + If the provided address includes a protocol, for example by prepending + "ray://" to the address to get "ray://1.2.3.4:10001", then a + ClientContext is returned with information such as settings, server + versions for ray and python, and the dashboard_url. Otherwise, + a RayContext is returned with ray and python versions, and address + information about the started processes. + + Raises: + Exception: An exception is raised if an inappropriate combination of + arguments is passed in. + """ + if log_to_driver is None: + log_to_driver = ray_constants.RAY_LOG_TO_DRIVER + + # Configure the "ray" logger for the driver process. + if configure_logging: + setup_logger(logging_level, logging_format or ray_constants.LOGGER_FORMAT) + else: + logging.getLogger("ray").handlers.clear() + + # Configure the logging settings for the driver process. 
+ if logging_config or ray_constants.RAY_LOGGING_CONFIG_ENCODING: + logging_config = logging_config or LoggingConfig( + encoding=ray_constants.RAY_LOGGING_CONFIG_ENCODING + ) + logging_config._apply() + + # Parse the hidden options: + _enable_object_reconstruction: bool = kwargs.pop( + "_enable_object_reconstruction", False + ) + _redis_max_memory: Optional[int] = kwargs.pop("_redis_max_memory", None) + _plasma_directory: Optional[str] = kwargs.pop("_plasma_directory", None) + _node_ip_address: str = kwargs.pop("_node_ip_address", None) + _driver_object_store_memory: Optional[int] = kwargs.pop( + "_driver_object_store_memory", None + ) + _memory: Optional[int] = kwargs.pop("_memory", None) + _redis_username: str = kwargs.pop( + "_redis_username", ray_constants.REDIS_DEFAULT_USERNAME + ) + _redis_password: str = kwargs.pop( + "_redis_password", ray_constants.REDIS_DEFAULT_PASSWORD + ) + _temp_dir: Optional[str] = kwargs.pop("_temp_dir", None) + _metrics_export_port: Optional[int] = kwargs.pop("_metrics_export_port", None) + _system_config: Optional[Dict[str, str]] = kwargs.pop("_system_config", None) + _tracing_startup_hook: Optional[Callable] = kwargs.pop( + "_tracing_startup_hook", None + ) + _node_name: str = kwargs.pop("_node_name", None) + # Fix for https://github.com/ray-project/ray/issues/26729 + _skip_env_hook: bool = kwargs.pop("_skip_env_hook", False) + + # terminate any signal before connecting driver + def sigterm_handler(signum, frame): + sys.exit(signum) + + if threading.current_thread() is threading.main_thread(): + ray._private.utils.set_sigterm_handler(sigterm_handler) + else: + logger.warning( + "SIGTERM handler is not set because current thread " + "is not the main thread." 
+ ) + + # If available, use RAY_ADDRESS to override if the address was left + # unspecified, or set to "auto" in the call to init + address_env_var = os.environ.get(ray_constants.RAY_ADDRESS_ENVIRONMENT_VARIABLE) + if address_env_var and (address is None or address == "auto"): + address = address_env_var + logger.info( + f"Using address {address_env_var} set in the environment " + f"variable {ray_constants.RAY_ADDRESS_ENVIRONMENT_VARIABLE}" + ) + + if address is not None and "://" in address: + # Address specified a protocol, use ray client + builder = ray.client(address, _deprecation_warn_enabled=False) + + # Forward any keyword arguments that were changed from their default + # values to the builder + init_sig = inspect.signature(init) + passed_kwargs = {} + for argument_name, param_obj in init_sig.parameters.items(): + if argument_name in {"kwargs", "address"}: + # kwargs and address are handled separately + continue + default_value = param_obj.default + passed_value = locals()[argument_name] + if passed_value != default_value: + # passed value is different than default, pass to the client + # builder + passed_kwargs[argument_name] = passed_value + passed_kwargs.update(kwargs) + builder._init_args(**passed_kwargs) + ctx = builder.connect() + from ray._private.usage import usage_lib + + if passed_kwargs.get("allow_multiple") is True: + with ctx: + usage_lib.put_pre_init_usage_stats() + else: + usage_lib.put_pre_init_usage_stats() + + usage_lib.record_library_usage("client") + return ctx + + if kwargs.get("allow_multiple"): + raise RuntimeError( + "`allow_multiple` argument is passed to `ray.init` when the " + "ray client is not used (" + f"https://docs.ray.io/en/{get_ray_doc_version()}/cluster" + "/running-applications/job-submission" + "/ray-client.html#connect-to-multiple-ray-clusters-experimental). " + "Do not pass the `allow_multiple` to `ray.init` to fix the issue." 
+ ) + + if kwargs: + # User passed in extra keyword arguments but isn't connecting through + # ray client. Raise an error, since most likely a typo in keyword + unknown = ", ".join(kwargs) + raise RuntimeError(f"Unknown keyword argument(s): {unknown}") + + # Try to increase the file descriptor limit, which is too low by + # default for Ray: https://github.com/ray-project/ray/issues/11239 + try: + import resource + + soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) + if soft < hard: + # https://github.com/ray-project/ray/issues/12059 + soft = max(soft, min(hard, 65536)) + logger.debug( + f"Automatically increasing RLIMIT_NOFILE to max value of {hard}" + ) + try: + resource.setrlimit(resource.RLIMIT_NOFILE, (soft, hard)) + except ValueError: + logger.debug("Failed to raise limit.") + soft, _ = resource.getrlimit(resource.RLIMIT_NOFILE) + if soft < 4096: + logger.warning( + "File descriptor limit {} is too low for production " + "servers and may result in connection errors. " + "At least 8192 is recommended. --- " + "Fix with 'ulimit -n 8192'".format(soft) + ) + except ImportError: + logger.debug("Could not import resource module (on Windows)") + pass + + if job_config is None: + job_config = ray.job_config.JobConfig() + + if RAY_JOB_CONFIG_JSON_ENV_VAR in os.environ: + injected_job_config_json = json.loads( + os.environ.get(RAY_JOB_CONFIG_JSON_ENV_VAR) + ) + injected_job_config: ray.job_config.JobConfig = ( + ray.job_config.JobConfig.from_json(injected_job_config_json) + ) + driver_runtime_env = runtime_env + runtime_env = _merge_runtime_env( + injected_job_config.runtime_env, + driver_runtime_env, + override=os.getenv("RAY_OVERRIDE_JOB_RUNTIME_ENV") == "1", + ) + if runtime_env is None: + # None means there was a conflict. + raise ValueError( + "Failed to merge the Job's runtime env " + f"{injected_job_config.runtime_env} with " + f"a ray.init's runtime env {driver_runtime_env} because " + "of a conflict. 
Specifying the same runtime_env fields " + "or the same environment variable keys is not allowed. " + "Use RAY_OVERRIDE_JOB_RUNTIME_ENV=1 to instruct Ray to " + "combine Job and Driver's runtime environment in the event of " + "a conflict." + ) + + if ray_constants.RAY_RUNTIME_ENV_HOOK in os.environ and not _skip_env_hook: + runtime_env = _load_class(os.environ[ray_constants.RAY_RUNTIME_ENV_HOOK])( + runtime_env + ) + job_config.set_runtime_env(runtime_env) + # Similarly, we prefer metadata provided via job submission API + for key, value in injected_job_config.metadata.items(): + job_config.set_metadata(key, value) + + # RAY_JOB_CONFIG_JSON_ENV_VAR is only set at ray job manager level and has + # higher priority in case user also provided runtime_env for ray.init() + else: + if ray_constants.RAY_RUNTIME_ENV_HOOK in os.environ and not _skip_env_hook: + runtime_env = _load_class(os.environ[ray_constants.RAY_RUNTIME_ENV_HOOK])( + runtime_env + ) + + if runtime_env: + # Set runtime_env in job_config if passed in as part of ray.init() + job_config.set_runtime_env(runtime_env) + + # Pass the logging_config to job_config to configure loggers of all worker + # processes belonging to the job. + if logging_config is not None: + job_config.set_py_logging_config(logging_config) + + redis_address, gcs_address = None, None + bootstrap_address = services.canonicalize_bootstrap_address(address, _temp_dir) + if bootstrap_address is not None: + gcs_address = bootstrap_address + logger.info("Connecting to existing Ray cluster at address: %s...", gcs_address) + + if local_mode: + driver_mode = LOCAL_MODE + warnings.warn( + "DeprecationWarning: local mode is an experimental feature that is no " + "longer maintained and will be removed in the future." + "For debugging consider using Ray debugger. 
", + DeprecationWarning, + stacklevel=2, + ) + else: + driver_mode = SCRIPT_MODE + + global _global_node + + if global_worker.connected: + if ignore_reinit_error: + logger.info("Calling ray.init() again after it has already been called.") + node_id = global_worker.core_worker.get_current_node_id() + return RayContext(dict(_global_node.address_info, node_id=node_id.hex())) + else: + raise RuntimeError( + "Maybe you called ray.init twice by accident? " + "This error can be suppressed by passing in " + "'ignore_reinit_error=True' or by calling " + "'ray.shutdown()' prior to 'ray.init()'." + ) + + _system_config = _system_config or {} + if not isinstance(_system_config, dict): + raise TypeError("The _system_config must be a dict.") + + if bootstrap_address is None: + # In this case, we need to start a new cluster. + + # Don't collect usage stats in ray.init() unless it's a nightly wheel. + from ray._private.usage import usage_lib + + if usage_lib.is_nightly_wheel(): + usage_lib.show_usage_stats_prompt(cli=False) + else: + usage_lib.set_usage_stats_enabled_via_env_var(False) + + # Use a random port by not specifying Redis port / GCS server port. 
+ ray_params = ray._private.parameter.RayParams( + node_ip_address=_node_ip_address, + object_ref_seed=None, + driver_mode=driver_mode, + redirect_output=None, + num_cpus=num_cpus, + num_gpus=num_gpus, + resources=resources, + labels=labels, + num_redis_shards=None, + redis_max_clients=None, + redis_username=_redis_username, + redis_password=_redis_password, + plasma_directory=_plasma_directory, + huge_pages=None, + include_dashboard=include_dashboard, + dashboard_host=dashboard_host, + dashboard_port=dashboard_port, + memory=_memory, + object_store_memory=object_store_memory, + redis_max_memory=_redis_max_memory, + plasma_store_socket_name=None, + temp_dir=_temp_dir, + storage=storage, + _system_config=_system_config, + enable_object_reconstruction=_enable_object_reconstruction, + metrics_export_port=_metrics_export_port, + tracing_startup_hook=_tracing_startup_hook, + node_name=_node_name, + ) + # Start the Ray processes. We set shutdown_at_exit=False because we + # shutdown the node in the ray.shutdown call that happens in the atexit + # handler. We still spawn a reaper process in case the atexit handler + # isn't called. + _global_node = ray._private.node.Node( + ray_params=ray_params, + head=True, + shutdown_at_exit=False, + spawn_reaper=True, + ray_init_cluster=True, + ) + else: + # In this case, we are connecting to an existing cluster. + if num_cpus is not None or num_gpus is not None: + raise ValueError( + "When connecting to an existing cluster, num_cpus " + "and num_gpus must not be provided." + ) + if resources is not None: + raise ValueError( + "When connecting to an existing cluster, " + "resources must not be provided." + ) + if labels is not None: + raise ValueError( + "When connecting to an existing cluster, " + "labels must not be provided." + ) + if object_store_memory is not None: + raise ValueError( + "When connecting to an existing cluster, " + "object_store_memory must not be provided." 
+ ) + if storage is not None: + raise ValueError( + "When connecting to an existing cluster, " + "storage must not be provided." + ) + if _system_config is not None and len(_system_config) != 0: + raise ValueError( + "When connecting to an existing cluster, " + "_system_config must not be provided." + ) + if _enable_object_reconstruction: + raise ValueError( + "When connecting to an existing cluster, " + "_enable_object_reconstruction must not be provided." + ) + if _node_name is not None: + raise ValueError( + "_node_name cannot be configured when connecting to " + "an existing cluster." + ) + + # In this case, we only need to connect the node. + ray_params = ray._private.parameter.RayParams( + node_ip_address=_node_ip_address, + gcs_address=gcs_address, + redis_address=redis_address, + redis_username=_redis_username, + redis_password=_redis_password, + object_ref_seed=None, + temp_dir=_temp_dir, + _system_config=_system_config, + enable_object_reconstruction=_enable_object_reconstruction, + metrics_export_port=_metrics_export_port, + ) + try: + _global_node = ray._private.node.Node( + ray_params, + head=False, + shutdown_at_exit=False, + spawn_reaper=False, + connect_only=True, + ) + except (ConnectionError, RuntimeError): + if gcs_address == ray._private.utils.read_ray_address(_temp_dir): + logger.info( + "Failed to connect to the default Ray cluster address at " + f"{gcs_address}. This is most likely due to a previous Ray " + "instance that has since crashed. To reset the default " + "address to connect to, run `ray stop` or restart Ray with " + "`ray start`." + ) + raise ConnectionError + + # Log a message to find the Ray address that we connected to and the + # dashboard URL. + if ray_constants.RAY_OVERRIDE_DASHBOARD_URL in os.environ: + dashboard_url = os.environ.get(ray_constants.RAY_OVERRIDE_DASHBOARD_URL) + else: + dashboard_url = _global_node.webui_url + # Add http protocol to dashboard URL if it doesn't + # already contain a protocol. 
    if dashboard_url and not urlparse(dashboard_url).scheme:
        dashboard_url = "http://" + dashboard_url

    # We logged the address before attempting the connection, so we don't need
    # to log it again.
    info_str = "Connected to Ray cluster."
    if gcs_address is None:
        info_str = "Started a local Ray instance."
    if dashboard_url:
        logger.info(
            info_str + " View the dashboard at %s%s%s %s%s",
            colorama.Style.BRIGHT,
            colorama.Fore.GREEN,
            dashboard_url,
            colorama.Fore.RESET,
            colorama.Style.NORMAL,
        )
    else:
        logger.info(info_str)

    connect(
        _global_node,
        _global_node.session_name,
        mode=driver_mode,
        log_to_driver=log_to_driver,
        worker=global_worker,
        driver_object_store_memory=_driver_object_store_memory,
        job_id=None,
        namespace=namespace,
        job_config=job_config,
        entrypoint=ray._private.utils.get_entrypoint_name(),
    )
    if job_config and job_config.code_search_path:
        global_worker.set_load_code_from_local(True)
    else:
        # Because `ray.shutdown()` doesn't reset this flag, for multiple
        # sessions in one process, the 2nd `ray.init()` will reuse the
        # flag of last session. For example:
        # ray.init(load_code_from_local=True)
        # ray.shutdown()
        # ray.init()
        # # Here the flag `load_code_from_local` is still True if we
        # # didn't have this `else` branch.
        # ray.shutdown()
        global_worker.set_load_code_from_local(False)

    for hook in _post_init_hooks:
        hook()

    node_id = global_worker.core_worker.get_current_node_id()
    global_node_address_info = _global_node.address_info.copy()
    global_node_address_info["webui_url"] = _remove_protocol_from_url(dashboard_url)
    return RayContext(dict(global_node_address_info, node_id=node_id.hex()))


# Functions to run as callback after a successful ray init.
_post_init_hooks = []


@PublicAPI
@client_mode_hook
def shutdown(_exiting_interpreter: bool = False):
    """Disconnect the worker, and terminate processes started by ray.init().
+ + This will automatically run at the end when a Python process that uses Ray + exits. It is ok to run this twice in a row. The primary use case for this + function is to cleanup state between tests. + + Note that this will clear any remote function definitions, actor + definitions, and existing actors, so if you wish to use any previously + defined remote functions or actors after calling ray.shutdown(), then you + need to redefine them. If they were defined in an imported module, then you + will need to reload the module. + + Args: + _exiting_interpreter: True if this is called by the atexit hook + and false otherwise. If we are exiting the interpreter, we will + wait a little while to print any extra error messages. + """ + # Make sure to clean up compiled dag node if exists. + from ray.dag.compiled_dag_node import _shutdown_all_compiled_dags + + _shutdown_all_compiled_dags() + + if _exiting_interpreter and global_worker.mode == SCRIPT_MODE: + # This is a duration to sleep before shutting down everything in order + # to make sure that log messages finish printing. + time.sleep(0.5) + disconnect(_exiting_interpreter) + + # disconnect internal kv + if hasattr(global_worker, "gcs_client"): + del global_worker.gcs_client + _internal_kv_reset() + + # We need to destruct the core worker here because after this function, + # we will tear down any processes spawned by ray.init() and the background + # IO thread in the core worker doesn't currently handle that gracefully. + if hasattr(global_worker, "core_worker"): + if global_worker.mode == SCRIPT_MODE or global_worker.mode == LOCAL_MODE: + global_worker.core_worker.shutdown_driver() + del global_worker.core_worker + # We need to reset function actor manager to clear the context + global_worker.function_actor_manager = FunctionActorManager(global_worker) + # Disconnect global state from GCS. + ray._private.state.state.disconnect() + + # Shut down the Ray processes. 
+ global _global_node + if _global_node is not None: + if _global_node.is_head(): + _global_node.destroy_external_storage() + _global_node.kill_all_processes(check_alive=False, allow_graceful=True) + _global_node = None + storage._reset() + + # TODO(rkn): Instead of manually resetting some of the worker fields, we + # should simply set "global_worker" to equal "None" or something like that. + global_worker.set_mode(None) + global_worker.set_cached_job_id(None) + + +atexit.register(shutdown, True) + +# Define a custom excepthook so that if the driver exits with an exception, we +# can push that exception to Redis. +normal_excepthook = sys.excepthook + + +def custom_excepthook(type, value, tb): + import ray.core.generated.common_pb2 as common_pb2 + + # If this is a driver, push the exception to GCS worker table. + if global_worker.mode == SCRIPT_MODE and hasattr(global_worker, "worker_id"): + error_message = "".join(traceback.format_tb(tb)) + worker_id = global_worker.worker_id + worker_type = common_pb2.DRIVER + worker_info = {"exception": error_message} + + ray._private.state.state._check_connected() + ray._private.state.state.add_worker(worker_id, worker_type, worker_info) + # Call the normal excepthook. + normal_excepthook(type, value, tb) + + +sys.excepthook = custom_excepthook + + +def print_to_stdstream(data, ignore_prefix: bool): + should_dedup = data.get("pid") not in ["autoscaler"] + + if data["is_err"]: + if should_dedup: + batches = stderr_deduplicator.deduplicate(data) + else: + batches = [data] + sink = sys.stderr + else: + if should_dedup: + batches = stdout_deduplicator.deduplicate(data) + else: + batches = [data] + sink = sys.stdout + + for batch in batches: + print_worker_logs(batch, sink, ignore_prefix) + + +# Start time of this process, used for relative time logs. 
t0 = time.time()
autoscaler_log_fyi_printed = False


def filter_autoscaler_events(lines: List[str]) -> Iterator[str]:
    """Given raw log lines from the monitor, return only autoscaler events.

    For Autoscaler V1:
        Autoscaler events are denoted by the ":event_summary:" magic token.
    For Autoscaler V2:
        Autoscaler events are published from log_monitor.py which read
        them from the `event_AUTOSCALER.log`.
    """

    if not ray_constants.AUTOSCALER_EVENTS:
        return

    AUTOSCALER_LOG_FYI = (
        "Tip: use `ray status` to view detailed "
        "cluster status. To disable these "
        "messages, set RAY_SCHEDULER_EVENTS=0."
    )

    def autoscaler_log_fyi_needed() -> bool:
        # One-shot: flips the module-level flag so the tip prints only once
        # per process.
        global autoscaler_log_fyi_printed
        if not autoscaler_log_fyi_printed:
            autoscaler_log_fyi_printed = True
            return True
        return False

    from ray.autoscaler.v2.utils import is_autoscaler_v2

    if is_autoscaler_v2():
        from ray._private.event.event_logger import parse_event, filter_event_by_level

        for event_line in lines:
            if autoscaler_log_fyi_needed():
                yield AUTOSCALER_LOG_FYI

            event = parse_event(event_line)
            if not event or not event.message:
                continue

            if filter_event_by_level(
                event, ray_constants.RAY_LOG_TO_DRIVER_EVENT_LEVEL
            ):
                continue

            yield event.message
    else:
        # Print out autoscaler events only, ignoring other messages.
        for line in lines:
            if ray_constants.LOG_PREFIX_EVENT_SUMMARY in line:
                if autoscaler_log_fyi_needed():
                    yield AUTOSCALER_LOG_FYI
                # The event text immediately follows the ":event_summary:"
                # magic token.
                yield line.split(ray_constants.LOG_PREFIX_EVENT_SUMMARY)[1]


def time_string() -> str:
    """Return the relative time from the start of this job.

    For example, 15m30s.
    """
    delta = time.time() - t0
    # Repeated subtraction computes whole hours/minutes (equivalent to divmod).
    hours = 0
    minutes = 0
    while delta > 3600:
        hours += 1
        delta -= 3600
    while delta > 60:
        minutes += 1
        delta -= 60
    output = ""
    if hours:
        output += f"{hours}h"
    if minutes:
        output += f"{minutes}m"
    output += f"{int(delta)}s"
    return output


# When we enter a breakpoint, worker logs are automatically disabled via this.
_worker_logs_enabled = True


def print_worker_logs(
    data: Dict[str, str], print_file: Any, ignore_prefix: bool = False
):
    """Print one batch of worker/autoscaler log lines to `print_file`,
    optionally prefixed with "(name pid=..., ip=...)" and colorized."""
    if not _worker_logs_enabled:
        return

    def prefix_for(data: Dict[str, str]) -> str:
        """The PID prefix for this log line."""
        if data.get("pid") in ["autoscaler", "raylet"]:
            return ""
        else:
            res = "pid="
            if data.get("actor_name"):
                res = f"{data['actor_name']} {res}"
            elif data.get("task_name"):
                res = f"{data['task_name']} {res}"
            return res

    def message_for(data: Dict[str, str], line: str) -> str:
        """The printed message of this log line."""
        if ray_constants.LOG_PREFIX_INFO_MESSAGE in line:
            return line.split(ray_constants.LOG_PREFIX_INFO_MESSAGE)[1]
        return line

    def color_for(data: Dict[str, str], line: str) -> str:
        """The color for this log line."""
        if (
            data.get("pid") == "raylet"
            and ray_constants.LOG_PREFIX_INFO_MESSAGE not in line
        ):
            return colorama.Fore.YELLOW
        elif data.get("pid") == "autoscaler":
            if "Error:" in line or "Warning:" in line:
                return colorama.Fore.YELLOW
            else:
                return colorama.Fore.CYAN
        elif os.getenv("RAY_COLOR_PREFIX") == "1":
            colors = [
                # colorama.Fore.BLUE, # Too dark
                colorama.Fore.MAGENTA,
                colorama.Fore.CYAN,
                colorama.Fore.GREEN,
                # colorama.Fore.WHITE, # Too light
                # colorama.Fore.RED,
                colorama.Fore.LIGHTBLACK_EX,
                colorama.Fore.LIGHTBLUE_EX,
                # colorama.Fore.LIGHTCYAN_EX, # Too light
                # colorama.Fore.LIGHTGREEN_EX, # Too light
                colorama.Fore.LIGHTMAGENTA_EX,
                # colorama.Fore.LIGHTWHITE_EX, # Too light
                # colorama.Fore.LIGHTYELLOW_EX, # Too light
            ]
            # Pick a stable per-pid color; non-numeric pids fall back to 0.
            pid = data.get("pid", 0)
            try:
                i = int(pid)
            except ValueError:
                i = 0
            return colors[i % len(colors)]
        else:
            return colorama.Fore.CYAN

    if data.get("pid") == "autoscaler":
        pid = "autoscaler +{}".format(time_string())
        lines = filter_autoscaler_events(data.get("lines", []))
    else:
        pid = data.get("pid")
        lines = data.get("lines", [])

    ip = data.get("ip")
    ip_prefix = "" if ip == data.get("localhost") else f", ip={ip}"
    for line in lines:
        if RAY_TQDM_MAGIC in line:
            process_tqdm(line)
        else:
            hide_tqdm()
            # If RAY_COLOR_PREFIX=0, do not wrap with any color codes
            if os.getenv("RAY_COLOR_PREFIX") == "0":
                color_pre = ""
                color_post = ""
            else:
                color_pre = color_for(data, line)
                color_post = colorama.Style.RESET_ALL

            if ignore_prefix:
                print(
                    f"{message_for(data, line)}",
                    file=print_file,
                )
            else:
                print(
                    f"{color_pre}({prefix_for(data)}{pid}{ip_prefix}){color_post} "
                    f"{message_for(data, line)}",
                    file=print_file,
                )

    # Restore once at end of batch to avoid excess hiding/unhiding of tqdm.
    restore_tqdm()


def process_tqdm(line):
    """Experimental distributed tqdm: see ray.experimental.tqdm_ray."""
    try:
        data = json.loads(line)
        tqdm_ray.instance().process_state_update(data)
    except Exception:
        if log_once("tqdm_corruption"):
            logger.warning(
                f"[tqdm_ray] Failed to decode {line}, this may be due to "
                "logging too fast. This warning will not be printed again."
            )


def hide_tqdm():
    """Hide distributed tqdm bars temporarily to avoid conflicts with other logs."""
    tqdm_ray.instance().hide_bars()


def restore_tqdm():
    """Undo hide_tqdm()."""
    tqdm_ray.instance().unhide_bars()


def listen_error_messages(worker, threads_stopped):
    """Listen to error messages in the background on the driver.

    This runs in a separate thread on the driver and pushes (error, time)
    tuples to be published.

    Args:
        worker: The worker class that this thread belongs to.
        threads_stopped (threading.Event): A threading event used to signal to
            the thread that it should exit.
    """

    # TODO: we should just subscribe to the errors for this specific job.
    worker.gcs_error_subscriber.subscribe()

    try:
        if _internal_kv_initialized():
            # Get any autoscaler errors that occurred before the call to
            # subscribe.
            error_message = _internal_kv_get(ray_constants.DEBUG_AUTOSCALING_ERROR)
            if error_message is not None:
                logger.warning(error_message.decode())
        while True:
            # Exit if received a signal that the thread should stop.
            if threads_stopped.is_set():
                return

            _, error_data = worker.gcs_error_subscriber.poll()
            if error_data is None:
                continue
            # Only surface errors for this job, or job-agnostic (nil job)
            # errors.
            if error_data["job_id"] not in [
                worker.current_job_id.binary(),
                JobID.nil().binary(),
            ]:
                continue

            error_message = error_data["error_message"]
            # NOTE(review): error messages are routed with "is_err": False,
            # i.e. to the stdout path — confirm this is intentional.
            print_to_stdstream(
                {
                    "lines": [error_message],
                    "pid": "raylet",
                    "is_err": False,
                },
                ignore_prefix=False,
            )
    except (OSError, ConnectionError) as e:
        logger.error(f"listen_error_messages: {e}")


@PublicAPI
@client_mode_hook
def is_initialized() -> bool:
    """Check if ray.init has been called yet.

    Returns:
        True if ray.init has already been called and false otherwise.
    """
    return ray._private.worker.global_worker.connected


def connect(
    node,
    session_name: str,
    mode=WORKER_MODE,
    log_to_driver: bool = False,
    worker=global_worker,
    driver_object_store_memory: Optional[int] = None,
    job_id=None,
    namespace: Optional[str] = None,
    job_config=None,
    runtime_env_hash: int = 0,
    startup_token: int = 0,
    ray_debugger_external: bool = False,
    entrypoint: str = "",
    worker_launch_time_ms: int = -1,
    worker_launched_time_ms: int = -1,
):
    """Connect this worker to the raylet, to Plasma, and to GCS.

    Args:
        node (ray._private.node.Node): The node to connect.
        session_name: The session name (cluster id) of this cluster.
        mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, and
            LOCAL_MODE.
        log_to_driver: If true, then output from all of the worker
            processes on all nodes will be directed to the driver.
        worker: The ray.Worker instance.
        driver_object_store_memory: Deprecated.
        job_id: The ID of job. If it's None, then we will generate one.
        namespace: Namespace to use.
        job_config (ray.job_config.JobConfig): The job configuration.
        runtime_env_hash: The hash of the runtime env for this worker.
        startup_token: The startup token of the process assigned to
            it during startup as a command line argument.
        ray_debugger_external: If True, make the debugger external to the
            node this worker is running on.
        entrypoint: The name of the entrypoint script. Ignored if the
            mode != SCRIPT_MODE
        worker_launch_time_ms: The time when the worker process for this worker
            is launched. If the worker is not launched by raylet (e.g.,
            driver), this must be -1 (default value).
        worker_launched_time_ms: The time when the worker process for this worker
            finishes launching. If the worker is not launched by raylet (e.g.,
            driver), this must be -1 (default value).
    """
    # Do some basic checking to make sure we didn't call ray.init twice.
    error_message = "Perhaps you called ray.init twice by accident?"
    assert not worker.connected, error_message

    # Enable nice stack traces on SIGSEGV etc.
    try:
        if not faulthandler.is_enabled():
            faulthandler.enable(all_threads=False)
    except io.UnsupportedOperation:
        pass  # ignore

    worker.gcs_client = node.get_gcs_client()
    assert worker.gcs_client is not None
    _initialize_internal_kv(worker.gcs_client)
    ray._private.state.state._initialize_global_state(
        ray._raylet.GcsClientOptions.create(
            node.gcs_address,
            node.cluster_id.hex(),
            allow_cluster_id_nil=False,
            fetch_cluster_id_if_nil=False,
        )
    )
    worker.gcs_publisher = ray._raylet.GcsPublisher(address=worker.gcs_client.address)
    # Initialize some fields.
    if mode in (WORKER_MODE, RESTORE_WORKER_MODE, SPILL_WORKER_MODE):
        # We should not specify the job_id if it's `WORKER_MODE`.
        assert job_id is None
        job_id = JobID.nil()
    else:
        # This is the code path of driver mode.
        if job_id is None:
            job_id = ray._private.state.next_job_id()

    if mode is not SCRIPT_MODE and mode is not LOCAL_MODE and setproctitle:
        process_name = ray_constants.WORKER_PROCESS_TYPE_IDLE_WORKER
        if mode is SPILL_WORKER_MODE:
            process_name = ray_constants.WORKER_PROCESS_TYPE_SPILL_WORKER_IDLE
        elif mode is RESTORE_WORKER_MODE:
            process_name = ray_constants.WORKER_PROCESS_TYPE_RESTORE_WORKER_IDLE
        setproctitle.setproctitle(process_name)

    if not isinstance(job_id, JobID):
        raise TypeError("The type of given job id must be JobID.")

    # All workers start out as non-actors. A worker can be turned into an actor
    # after it is created.
    worker.node = node
    worker.set_mode(mode)

    # For driver's check that the version information matches the version
    # information that the Ray cluster was started with.
    try:
        node.check_version_info()
    except Exception as e:
        if mode == SCRIPT_MODE:
            raise e
        elif mode == WORKER_MODE:
            traceback_str = traceback.format_exc()
            ray._private.utils.publish_error_to_driver(
                ray_constants.VERSION_MISMATCH_PUSH_ERROR,
                traceback_str,
                gcs_publisher=worker.gcs_publisher,
                num_retries=1,
            )

    driver_name = ""
    log_stdout_file_path = ""
    log_stderr_file_path = ""
    interactive_mode = False
    if mode == SCRIPT_MODE:
        import __main__ as main

        if hasattr(main, "__file__"):
            driver_name = main.__file__
        else:
            interactive_mode = True
            driver_name = "INTERACTIVE MODE"
    # NOTE(review): `not LOCAL_MODE` tests the constant's truthiness rather
    # than comparing `mode`; if LOCAL_MODE is a nonzero constant this branch
    # can never fire (presumably `mode != LOCAL_MODE` was intended) — confirm
    # against upstream before changing, since workers also pass through here.
    elif not LOCAL_MODE:
        raise ValueError("Invalid worker mode. Expected DRIVER, WORKER or LOCAL.")

    gcs_options = ray._raylet.GcsClientOptions.create(
        node.gcs_address,
        node.cluster_id.hex(),
        allow_cluster_id_nil=False,
        fetch_cluster_id_if_nil=False,
    )
    if job_config is None:
        job_config = ray.job_config.JobConfig()

    if namespace is not None:
        ray._private.utils.validate_namespace(namespace)

        # The namespace field of job config may have already been set in code
        # paths such as the client.
        job_config.set_ray_namespace(namespace)

    # Make sure breakpoint() in the user's code will
    # invoke the Ray debugger if we are in a worker or actor process
    # (but not on the driver).
    if mode == WORKER_MODE:
        os.environ["PYTHONBREAKPOINT"] = "ray.util.rpdb.set_trace"
    else:
        # Add hook to suppress worker logs during breakpoint.
        os.environ["PYTHONBREAKPOINT"] = "ray.util.rpdb._driver_set_trace"

    worker.ray_debugger_external = ray_debugger_external

    # If it's a driver and it's not coming from ray client, we'll prepare the
    # environment here. If it's ray client, the environment will be prepared
    # at the server side.
    if mode == SCRIPT_MODE and not job_config._client_job and job_config.runtime_env:
        scratch_dir: str = worker.node.get_runtime_env_dir_path()
        runtime_env = job_config.runtime_env or {}
        runtime_env = upload_py_modules_if_needed(
            runtime_env, scratch_dir, logger=logger
        )
        runtime_env = upload_working_dir_if_needed(
            runtime_env, scratch_dir, logger=logger
        )
        runtime_env = upload_worker_process_setup_hook_if_needed(
            runtime_env,
            worker,
        )
        # Remove excludes, it isn't relevant after the upload step.
        runtime_env.pop("excludes", None)
        job_config.set_runtime_env(runtime_env)

    if mode == SCRIPT_MODE:
        # Add the directory containing the script that is running to the Python
        # paths of the workers. Also add the current directory. Note that this
        # assumes that the directory structures on the machines in the clusters
        # are the same.
        # When using an interactive shell, there is no script directory.
        # We also want to skip adding script directory when running from dashboard.
        code_paths = []
        if not interactive_mode and not (
            namespace and namespace == ray_constants.RAY_INTERNAL_DASHBOARD_NAMESPACE
        ):
            script_directory = os.path.dirname(os.path.realpath(sys.argv[0]))
            # If driver's sys.path doesn't include the script directory
            # (e.g driver is started via `python -m`,
            # see https://peps.python.org/pep-0338/),
            # then we shouldn't add it to the workers.
            if script_directory in sys.path:
                code_paths.append(script_directory)
        # In client mode, if we use runtime envs with "working_dir", then
        # it'll be handled automatically. Otherwise, add the current dir.
        if not job_config._client_job and not job_config._runtime_env_has_working_dir():
            current_directory = os.path.abspath(os.path.curdir)
            code_paths.append(current_directory)
        if len(code_paths) != 0:
            job_config._py_driver_sys_path.extend(code_paths)

    serialized_job_config = job_config._serialize()
    if not node.should_redirect_logs():
        # Logging to stderr, so give core worker empty logs directory.
        logs_dir = ""
    else:
        logs_dir = node.get_logs_dir_path()

    worker.core_worker = ray._raylet.CoreWorker(
        mode,
        node.plasma_store_socket_name,
        node.raylet_socket_name,
        job_id,
        gcs_options,
        logs_dir,
        node.node_ip_address,
        node.node_manager_port,
        node.raylet_ip_address,
        (mode == LOCAL_MODE),
        driver_name,
        log_stdout_file_path,
        log_stderr_file_path,
        serialized_job_config,
        node.metrics_agent_port,
        runtime_env_hash,
        startup_token,
        session_name,
        node.cluster_id.hex(),
        "" if mode != SCRIPT_MODE else entrypoint,
        worker_launch_time_ms,
        worker_launched_time_ms,
    )

    if mode == SCRIPT_MODE:
        worker_id = worker.worker_id
        worker.gcs_error_subscriber = ray._raylet.GcsErrorSubscriber(
            worker_id=worker_id, address=worker.gcs_client.address
        )
        worker.gcs_log_subscriber = ray._raylet.GcsLogSubscriber(
            worker_id=worker_id, address=worker.gcs_client.address
        )

    if driver_object_store_memory is not None:
        logger.warning(
            "`driver_object_store_memory` is deprecated"
            " and will be removed in the future."
        )

    # If this is a driver running in SCRIPT_MODE, start a thread to print error
    # messages asynchronously in the background. Ideally the scheduler would
    # push messages to the driver's worker service, but we ran into bugs when
    # trying to properly shutdown the driver's worker service, so we are
    # temporarily using this implementation which constantly queries the
    # scheduler for new error messages.
    if mode == SCRIPT_MODE:
        worker.listener_thread = threading.Thread(
            target=listen_error_messages,
            name="ray_listen_error_messages",
            args=(worker, worker.threads_stopped),
        )
        worker.listener_thread.daemon = True
        worker.listener_thread.start()
    # If the job's logging config is set, don't add the prefix
    # (task/actor's name and its PID) to the logs.
    ignore_prefix = global_worker.job_logging_config is not None

    if log_to_driver:
        global_worker_stdstream_dispatcher.add_handler(
            "ray_print_logs",
            functools.partial(print_to_stdstream, ignore_prefix=ignore_prefix),
        )
        worker.logger_thread = threading.Thread(
            target=worker.print_logs, name="ray_print_logs"
        )
        worker.logger_thread.daemon = True
        worker.logger_thread.start()

    # Setup tracing here
    tracing_hook_val = worker.gcs_client.internal_kv_get(
        b"tracing_startup_hook", ray_constants.KV_NAMESPACE_TRACING
    )
    if tracing_hook_val is not None:
        ray.util.tracing.tracing_helper._enable_tracing()
        if not getattr(ray, "__traced__", False):
            _setup_tracing = _import_from_string(tracing_hook_val.decode("utf-8"))
            _setup_tracing()
            ray.__traced__ = True

    # Mark the worker as connected.
    worker.set_is_connected(True)


def disconnect(exiting_interpreter=False):
    """Disconnect this worker from the raylet and object store."""
    # Reset the list of cached remote functions and actors so that if more
    # remote functions or actors are defined and then connect is called again,
    # the remote functions will be exported. This is mostly relevant for the
    # tests.
    worker = global_worker
    if worker.connected:
        # Shutdown all of the threads that we've started. TODO(rkn): This
        # should be handled cleanly in the worker object's destructor and not
        # in this disconnect method.
        worker.threads_stopped.set()
        if hasattr(worker, "gcs_error_subscriber"):
            worker.gcs_error_subscriber.close()
        if hasattr(worker, "gcs_log_subscriber"):
            worker.gcs_log_subscriber.close()
        if hasattr(worker, "listener_thread"):
            worker.listener_thread.join()
        if hasattr(worker, "logger_thread"):
            worker.logger_thread.join()
        worker.threads_stopped.clear()

        # Ignore the prefix if the logging config is set.
        ignore_prefix = worker.job_logging_config is not None
        # Flush any buffered, deduplicated log lines before tearing down the
        # dispatcher handler.
        for leftover in stdout_deduplicator.flush():
            print_worker_logs(leftover, sys.stdout, ignore_prefix)
        for leftover in stderr_deduplicator.flush():
            print_worker_logs(leftover, sys.stderr, ignore_prefix)
        global_worker_stdstream_dispatcher.remove_handler("ray_print_logs")

    worker.node = None  # Disconnect the worker from the node.
    worker.serialization_context_map.clear()
    try:
        ray_actor = ray.actor
    except AttributeError:
        ray_actor = None  # This can occur during program termination
    if ray_actor is not None:
        ray_actor._ActorClassMethodMetadata.reset_cache()

    # Mark the worker as disconnected.
    worker.set_is_connected(False)


@contextmanager
def _changeproctitle(title, next_title):
    # Temporarily set the process title to `title`, restoring to `next_title`
    # on exit; skipped entirely in LOCAL_MODE.
    if _mode() is not LOCAL_MODE:
        setproctitle.setproctitle(title)
    try:
        yield
    finally:
        if _mode() is not LOCAL_MODE:
            setproctitle.setproctitle(next_title)


@DeveloperAPI
def show_in_dashboard(message: str, key: str = "", dtype: str = "text"):
    """Display message in dashboard.

    Display message for the current task or actor in the dashboard.
    For example, this can be used to display the status of a long-running
    computation.

    Args:
        message: Message to be displayed.
        key: The key name for the message. Multiple message under
            different keys will be displayed at the same time. Messages
            under the same key will be overridden.
        dtype: The type of message for rendering. One of the
            following: text, html.
    """
    worker = global_worker
    worker.check_connected()

    acceptable_dtypes = {"text", "html"}
    assert dtype in acceptable_dtypes, f"dtype accepts only: {acceptable_dtypes}"

    message_wrapped = {"message": message, "dtype": dtype}
    message_encoded = json.dumps(message_wrapped).encode()

    worker.core_worker.set_webui_display(key.encode(), message_encoded)


# Global variable to make sure we only send out the warning once.
+blocking_get_inside_async_warned = False + + +@overload +def get( + object_refs: "Sequence[ObjectRef[Any]]", *, timeout: Optional[float] = None +) -> List[Any]: + ... + + +@overload +def get( + object_refs: "Sequence[ObjectRef[R]]", *, timeout: Optional[float] = None +) -> List[R]: + ... + + +@overload +def get(object_refs: "ObjectRef[R]", *, timeout: Optional[float] = None) -> R: + ... + + +@overload +def get( + object_refs: Sequence[CompiledDAGRef], *, timeout: Optional[float] = None +) -> List[Any]: + ... + + +@overload +def get(object_refs: CompiledDAGRef, *, timeout: Optional[float] = None) -> Any: + ... + + +@PublicAPI +@client_mode_hook +def get( + object_refs: Union[ + "ObjectRef[Any]", + Sequence["ObjectRef[Any]"], + CompiledDAGRef, + Sequence[CompiledDAGRef], + ], + *, + timeout: Optional[float] = None, +) -> Union[Any, List[Any]]: + """Get a remote object or a list of remote objects from the object store. + + This method blocks until the object corresponding to the object ref is + available in the local object store. If this object is not in the local + object store, it will be shipped from an object store that has it (once the + object has been created). If object_refs is a list, then the objects + corresponding to each object in the list will be returned. + + Ordering for an input list of object refs is preserved for each object + returned. That is, if an object ref to A precedes an object ref to B in the + input list, then A will precede B in the returned list. + + This method will issue a warning if it's running inside async context, + you can use ``await object_ref`` instead of ``ray.get(object_ref)``. For + a list of object refs, you can use ``await asyncio.gather(*object_refs)``. + + Passing :class:`~ObjectRefGenerator` is not allowed. 
+ + Related patterns and anti-patterns: + + - :doc:`/ray-core/patterns/ray-get-loop` + - :doc:`/ray-core/patterns/unnecessary-ray-get` + - :doc:`/ray-core/patterns/ray-get-submission-order` + - :doc:`/ray-core/patterns/ray-get-too-many-objects` + + + Args: + object_refs: Object ref of the object to get or a list of object refs + to get. + timeout (Optional[float]): The maximum amount of time in seconds to + wait before returning. Set this to None will block until the + corresponding object becomes available. Setting ``timeout=0`` will + return the object immediately if it's available, else raise + GetTimeoutError in accordance with the above docstring. + + Returns: + A Python object or a list of Python objects. + + Raises: + GetTimeoutError: A GetTimeoutError is raised if a timeout is set and + the get takes longer than timeout to return. + Exception: An exception is raised immediately if any task that created + the object or that created one of the objects raised an exception, + without waiting for the remaining ones to finish. + """ + worker = global_worker + worker.check_connected() + + if hasattr(worker, "core_worker") and worker.core_worker.current_actor_is_asyncio(): + global blocking_get_inside_async_warned + if not blocking_get_inside_async_warned: + logger.warning( + "Using blocking ray.get inside async actor. " + "This blocks the event loop. Please use `await` " + "on object ref with asyncio.gather if you want to " + "yield execution to the event loop instead." + ) + blocking_get_inside_async_warned = True + + with profiling.profile("ray.get"): + # TODO(sang): Should make ObjectRefGenerator + # compatible to ray.get for dataset. 
+ if isinstance(object_refs, ObjectRefGenerator): + return object_refs + + if isinstance(object_refs, CompiledDAGRef): + return object_refs.get(timeout=timeout) + + if isinstance(object_refs, list): + all_compiled_dag_refs = True + any_compiled_dag_refs = False + for object_ref in object_refs: + is_dag_ref = isinstance(object_ref, CompiledDAGRef) + all_compiled_dag_refs = all_compiled_dag_refs and is_dag_ref + any_compiled_dag_refs = any_compiled_dag_refs or is_dag_ref + if all_compiled_dag_refs: + return [object_ref.get(timeout=timeout) for object_ref in object_refs] + elif any_compiled_dag_refs: + raise ValueError( + "Invalid type of object refs. 'object_refs' must be a list of " + "CompiledDAGRefs if there is any CompiledDAGRef within it. " + ) + + is_individual_id = isinstance(object_refs, ray.ObjectRef) + if is_individual_id: + object_refs = [object_refs] + + if not isinstance(object_refs, list): + raise ValueError( + f"Invalid type of object refs, {type(object_refs)}, is given. " + "'object_refs' must either be an ObjectRef or a list of ObjectRefs. " + ) + + # TODO(ujvl): Consider how to allow user to retrieve the ready objects. 
+ values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout) + for i, value in enumerate(values): + if isinstance(value, RayError): + if isinstance(value, ray.exceptions.ObjectLostError): + worker.core_worker.dump_object_store_memory_usage() + if isinstance(value, RayTaskError): + raise value.as_instanceof_cause() + else: + raise value + + if is_individual_id: + values = values[0] + + if debugger_breakpoint != b"": + frame = sys._getframe().f_back + rdb = ray.util.pdb._connect_ray_pdb( + host=None, + port=None, + patch_stdstreams=False, + quiet=None, + breakpoint_uuid=( + debugger_breakpoint.decode() if debugger_breakpoint else None + ), + debugger_external=worker.ray_debugger_external, + ) + rdb.set_trace(frame=frame) + + return values + + +@PublicAPI +@client_mode_hook +def put( + value: Any, + *, + _owner: Optional["ray.actor.ActorHandle"] = None, +) -> "ray.ObjectRef": + """Store an object in the object store. + + The object may not be evicted while a reference to the returned ID exists. + + Related patterns and anti-patterns: + + - :doc:`/ray-core/patterns/return-ray-put` + - :doc:`/ray-core/patterns/pass-large-arg-by-value` + - :doc:`/ray-core/patterns/closure-capture-large-objects` + + Args: + value: The Python object to be stored. + _owner [Experimental]: The actor that should own this object. This + allows creating objects with lifetimes decoupled from that of the + creating process. The owner actor must be passed a reference to the + object prior to the object creator exiting, otherwise the reference + will still be lost. *Note that this argument is an experimental API + and should be avoided if possible.* + + Returns: + The object ref assigned to this value. 
+ """ + worker = global_worker + worker.check_connected() + + if _owner is None: + serialize_owner_address = None + elif isinstance(_owner, ray.actor.ActorHandle): + # Ensure `ray._private.state.state.global_state_accessor` is not None + ray._private.state.state._check_connected() + serialize_owner_address = ( + ray._raylet._get_actor_serialized_owner_address_or_none( + ray._private.state.state.global_state_accessor.get_actor_info( + _owner._actor_id + ) + ) + ) + if not serialize_owner_address: + raise RuntimeError(f"{_owner} is not alive, it's worker_id is empty!") + else: + raise TypeError(f"Expect an `ray.actor.ActorHandle`, but got: {type(_owner)}") + + with profiling.profile("ray.put"): + try: + object_ref = worker.put_object(value, owner_address=serialize_owner_address) + except ObjectStoreFullError: + logger.info( + "Put failed since the value was either too large or the " + "store was full of pinned objects." + ) + raise + return object_ref + + +# Global variable to make sure we only send out the warning once. +blocking_wait_inside_async_warned = False + + +@PublicAPI +@client_mode_hook +def wait( + ray_waitables: List[Union[ObjectRef, ObjectRefGenerator]], + *, + num_returns: int = 1, + timeout: Optional[float] = None, + fetch_local: bool = True, +) -> Tuple[ + List[Union[ObjectRef, ObjectRefGenerator]], + List[Union[ObjectRef, ObjectRefGenerator]], +]: + """Return a list of IDs that are ready and a list of IDs that are not. + + If timeout is set, the function returns either when the requested number of + IDs are ready or when the timeout is reached, whichever occurs first. If it + is not set, the function simply waits until that number of objects is ready + and returns that exact number of object refs. + + `ray_waitables` is a list of :class:`~ray.ObjectRef` and + :class:`~ray.ObjectRefGenerator`. + + The method returns two lists, ready and unready `ray_waitables`. 
+ + ObjectRef: + object refs that correspond to objects that are available + in the object store are in the first list. + The rest of the object refs are in the second list. + + ObjectRefGenerator: + Generators whose next reference (that will be obtained + via `next(generator)`) has a corresponding object available + in the object store are in the first list. + All other generators are placed in the second list. + + Ordering of the input list of ray_waitables is preserved. That is, if A + precedes B in the input list, and both are in the ready list, then A will + precede B in the ready list. This also holds true if A and B are both in + the remaining list. + + This method will issue a warning if it's running inside an async context. + Instead of ``ray.wait(ray_waitables)``, you can use + ``await asyncio.wait(ray_waitables)``. + + Related patterns and anti-patterns: + + - :doc:`/ray-core/patterns/limit-pending-tasks` + - :doc:`/ray-core/patterns/ray-get-submission-order` + + Args: + ray_waitables: List of :class:`~ObjectRef` or + :class:`~ObjectRefGenerator` for objects that may or may + not be ready. Note that these must be unique. + num_returns: The number of ray_waitables that should be returned. + timeout: The maximum amount of time in seconds to wait before + returning. + fetch_local: If True, wait for the object to be downloaded onto + the local node before returning it as ready. If the `ray_waitable` + is a generator, it will wait until the next object in the generator + is downloaed. If False, ray.wait() will not trigger fetching of + objects to the local node and will return immediately once the + object is available anywhere in the cluster. + + Returns: + A list of object refs that are ready and a list of the remaining object + IDs. 
+ """ + worker = global_worker + worker.check_connected() + + if ( + hasattr(worker, "core_worker") + and worker.core_worker.current_actor_is_asyncio() + and timeout != 0 + ): + global blocking_wait_inside_async_warned + if not blocking_wait_inside_async_warned: + logger.debug( + "Using blocking ray.wait inside async method. " + "This blocks the event loop. Please use `await` " + "on object ref with asyncio.wait. " + ) + blocking_wait_inside_async_warned = True + + if isinstance(ray_waitables, ObjectRef) or isinstance( + ray_waitables, ObjectRefGenerator + ): + raise TypeError( + "wait() expected a list of ray.ObjectRef or ray.ObjectRefGenerator" + ", got a single ray.ObjectRef or ray.ObjectRefGenerator " + f"{ray_waitables}" + ) + + if not isinstance(ray_waitables, list): + raise TypeError( + "wait() expected a list of ray.ObjectRef or " + "ray.ObjectRefGenerator, " + f"got {type(ray_waitables)}" + ) + + if timeout is not None and timeout < 0: + raise ValueError( + "The 'timeout' argument must be nonnegative. " f"Received {timeout}" + ) + + for ray_waitable in ray_waitables: + if not isinstance(ray_waitable, ObjectRef) and not isinstance( + ray_waitable, ObjectRefGenerator + ): + raise TypeError( + "wait() expected a list of ray.ObjectRef or " + "ray.ObjectRefGenerator, " + f"got list containing {type(ray_waitable)}" + ) + worker.check_connected() + + # TODO(swang): Check main thread. + with profiling.profile("ray.wait"): + # TODO(rkn): This is a temporary workaround for + # https://github.com/ray-project/ray/issues/997. However, it should be + # fixed in Arrow instead of here. + if len(ray_waitables) == 0: + return [], [] + + if len(ray_waitables) != len(set(ray_waitables)): + raise ValueError("Wait requires a list of unique ray_waitables.") + if num_returns <= 0: + raise ValueError("Invalid number of objects to return %d." 
% num_returns) + if num_returns > len(ray_waitables): + raise ValueError( + "num_returns cannot be greater than the number " + "of ray_waitables provided to ray.wait." + ) + + timeout = timeout if timeout is not None else 10**6 + timeout_milliseconds = int(timeout * 1000) + ready_ids, remaining_ids = worker.core_worker.wait( + ray_waitables, + num_returns, + timeout_milliseconds, + fetch_local, + ) + return ready_ids, remaining_ids + + +@PublicAPI +@client_mode_hook +def get_actor(name: str, namespace: Optional[str] = None) -> "ray.actor.ActorHandle": + """Get a handle to a named actor. + + Gets a handle to an actor with the given name. The actor must + have been created with Actor.options(name="name").remote(). This + works for both detached & non-detached actors. + + This method is a sync call and it'll timeout after 60s. This can be modified + by setting OS env RAY_gcs_server_request_timeout_seconds before starting + the cluster. + + Args: + name: The name of the actor. + namespace: The namespace of the actor, or None to specify the current + namespace. + + Returns: + ActorHandle to the actor. + + Raises: + ValueError: if the named actor does not exist. + """ + if not name: + raise ValueError("Please supply a non-empty value to get_actor") + + if namespace is not None: + ray._private.utils.validate_namespace(namespace) + + worker = global_worker + worker.check_connected() + return worker.core_worker.get_named_actor_handle(name, namespace or "") + + +@PublicAPI +@client_mode_hook +def kill(actor: "ray.actor.ActorHandle", *, no_restart: bool = True): + """Kill an actor forcefully. + + This will interrupt any running tasks on the actor, causing them to fail + immediately. ``atexit`` handlers installed in the actor will not be run. + + If you want to kill the actor but let pending tasks finish, + you can call ``actor.__ray_terminate__.remote()`` instead to queue a + termination task. Any ``atexit`` handlers installed in the actor *will* + be run in this case. 
+ + If the actor is a detached actor, subsequent calls to get its handle via + ray.get_actor will fail. + + Args: + actor: Handle to the actor to kill. + no_restart: Whether or not this actor should be restarted if + it's a restartable actor. + """ + worker = global_worker + worker.check_connected() + if not isinstance(actor, ray.actor.ActorHandle): + raise ValueError( + "ray.kill() only supported for actors. For tasks, try ray.cancel(). " + "Got: {}.".format(type(actor)) + ) + worker.core_worker.kill_actor(actor._ray_actor_id, no_restart) + + +@PublicAPI +@client_mode_hook +def cancel( + ray_waitable: Union["ObjectRef[R]", "ObjectRefGenerator[R]"], + *, + force: bool = False, + recursive: bool = True, +) -> None: + """Cancels a task. + + Cancel API has a different behavior depending on if it is a remote function + (Task) or a remote Actor method (Actor Task). + + Task: + If the specified Task is pending execution, it is cancelled and not + executed. If the Task is currently executing, the behavior depends + on the `force` flag. When `force=False`, a KeyboardInterrupt is + raised in Python and when `force=True`, the executing Task + immediately exits. If the Task is already finished, nothing happens. + + Cancelled Tasks aren't retried. `max_task_retries` aren't respected. + + Calling ray.get on a cancelled Task raises a TaskCancelledError + if the Task has been scheduled or interrupted. + It raises a WorkerCrashedError if `force=True`. + + If `recursive=True`, all the child Tasks and Actor Tasks + are cancelled. If `force=True` and `recursive=True`, `force=True` + is ignored for child Actor Tasks. + + Actor Task: + If the specified Task is pending execution, it is cancelled and not + executed. If the Task is currently executing, the behavior depends + on the execution model of an Actor. If it is a regular Actor + or a threaded Actor, the execution isn't cancelled. + Actor Tasks cannot be interrupted because Actors have + states. 
If it is an async Actor, Ray cancels an `asyncio.Task`.
+ """ + return worker.mode + + +def _make_remote(function_or_class, options): + if not function_or_class.__module__: + function_or_class.__module__ = "global" + + if inspect.isfunction(function_or_class) or is_cython(function_or_class): + ray_option_utils.validate_task_options(options, in_options=False) + return ray.remote_function.RemoteFunction( + Language.PYTHON, + function_or_class, + None, + options, + ) + + if inspect.isclass(function_or_class): + ray_option_utils.validate_actor_options(options, in_options=False) + return ray.actor._make_actor(function_or_class, options) + + raise TypeError( + "The @ray.remote decorator must be applied to either a function or a class." + ) + + +class RemoteDecorator(Protocol): + @overload + def __call__(self, __function: Callable[[], R]) -> RemoteFunctionNoArgs[R]: + ... + + @overload + def __call__(self, __function: Callable[[T0], R]) -> RemoteFunction0[R, T0]: + ... + + @overload + def __call__(self, __function: Callable[[T0, T1], R]) -> RemoteFunction1[R, T0, T1]: + ... + + @overload + def __call__( + self, __function: Callable[[T0, T1, T2], R] + ) -> RemoteFunction2[R, T0, T1, T2]: + ... + + @overload + def __call__( + self, __function: Callable[[T0, T1, T2, T3], R] + ) -> RemoteFunction3[R, T0, T1, T2, T3]: + ... + + @overload + def __call__( + self, __function: Callable[[T0, T1, T2, T3, T4], R] + ) -> RemoteFunction4[R, T0, T1, T2, T3, T4]: + ... + + @overload + def __call__( + self, __function: Callable[[T0, T1, T2, T3, T4, T5], R] + ) -> RemoteFunction5[R, T0, T1, T2, T3, T4, T5]: + ... + + @overload + def __call__( + self, __function: Callable[[T0, T1, T2, T3, T4, T5, T6], R] + ) -> RemoteFunction6[R, T0, T1, T2, T3, T4, T5, T6]: + ... + + @overload + def __call__( + self, __function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7], R] + ) -> RemoteFunction7[R, T0, T1, T2, T3, T4, T5, T6, T7]: + ... 
+ + @overload + def __call__( + self, __function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8], R] + ) -> RemoteFunction8[R, T0, T1, T2, T3, T4, T5, T6, T7, T8]: + ... + + @overload + def __call__( + self, __function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], R] + ) -> RemoteFunction9[R, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]: + ... + + # Pass on typing actors for now. The following makes it so no type errors + # are generated for actors. + @overload + def __call__(self, __t: type) -> Any: + ... + + +@overload +def remote(__function: Callable[[], R]) -> RemoteFunctionNoArgs[R]: + ... + + +@overload +def remote(__function: Callable[[T0], R]) -> RemoteFunction0[R, T0]: + ... + + +@overload +def remote(__function: Callable[[T0, T1], R]) -> RemoteFunction1[R, T0, T1]: + ... + + +@overload +def remote(__function: Callable[[T0, T1, T2], R]) -> RemoteFunction2[R, T0, T1, T2]: + ... + + +@overload +def remote( + __function: Callable[[T0, T1, T2, T3], R] +) -> RemoteFunction3[R, T0, T1, T2, T3]: + ... + + +@overload +def remote( + __function: Callable[[T0, T1, T2, T3, T4], R] +) -> RemoteFunction4[R, T0, T1, T2, T3, T4]: + ... + + +@overload +def remote( + __function: Callable[[T0, T1, T2, T3, T4, T5], R] +) -> RemoteFunction5[R, T0, T1, T2, T3, T4, T5]: + ... + + +@overload +def remote( + __function: Callable[[T0, T1, T2, T3, T4, T5, T6], R] +) -> RemoteFunction6[R, T0, T1, T2, T3, T4, T5, T6]: + ... + + +@overload +def remote( + __function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7], R] +) -> RemoteFunction7[R, T0, T1, T2, T3, T4, T5, T6, T7]: + ... + + +@overload +def remote( + __function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8], R] +) -> RemoteFunction8[R, T0, T1, T2, T3, T4, T5, T6, T7, T8]: + ... + + +@overload +def remote( + __function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], R] +) -> RemoteFunction9[R, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]: + ... + + +# Pass on typing actors for now. 
The following makes it so no type errors +# are generated for actors. +@overload +def remote(__t: type) -> Any: + ... + + +# Passing options +@overload +def remote( + *, + num_returns: Union[int, Literal["streaming"]] = Undefined, + num_cpus: Union[int, float] = Undefined, + num_gpus: Union[int, float] = Undefined, + resources: Dict[str, float] = Undefined, + accelerator_type: str = Undefined, + memory: Union[int, float] = Undefined, + max_calls: int = Undefined, + max_restarts: int = Undefined, + max_task_retries: int = Undefined, + max_retries: int = Undefined, + runtime_env: Dict[str, Any] = Undefined, + retry_exceptions: bool = Undefined, + scheduling_strategy: Union[ + None, Literal["DEFAULT"], Literal["SPREAD"], PlacementGroupSchedulingStrategy + ] = Undefined, +) -> RemoteDecorator: + ... + + +@PublicAPI +def remote( + *args, **kwargs +) -> Union[ray.remote_function.RemoteFunction, ray.actor.ActorClass]: + """Defines a remote function or an actor class. + + This function can be used as a decorator with no arguments + to define a remote function or actor as follows: + + .. testcode:: + + import ray + + @ray.remote + def f(a, b, c): + return a + b + c + + object_ref = f.remote(1, 2, 3) + result = ray.get(object_ref) + assert result == (1 + 2 + 3) + + @ray.remote + class Foo: + def __init__(self, arg): + self.x = arg + + def method(self, a): + return self.x + a + + actor_handle = Foo.remote(123) + object_ref = actor_handle.method.remote(321) + result = ray.get(object_ref) + assert result == (123 + 321) + + Equivalently, use a function call to create a remote function or actor. + + .. 
testcode:: + + def g(a, b, c): + return a + b + c + + remote_g = ray.remote(g) + object_ref = remote_g.remote(1, 2, 3) + assert ray.get(object_ref) == (1 + 2 + 3) + + class Bar: + def __init__(self, arg): + self.x = arg + + def method(self, a): + return self.x + a + + RemoteBar = ray.remote(Bar) + actor_handle = RemoteBar.remote(123) + object_ref = actor_handle.method.remote(321) + result = ray.get(object_ref) + assert result == (123 + 321) + + + It can also be used with specific keyword arguments as follows: + + .. testcode:: + + @ray.remote(num_gpus=1, max_calls=1, num_returns=2) + def f(): + return 1, 2 + + @ray.remote(num_cpus=2, resources={"CustomResource": 1}) + class Foo: + def method(self): + return 1 + + Remote task and actor objects returned by @ray.remote can also be + dynamically modified with the same arguments as above using + ``.options()`` as follows: + + .. testcode:: + :hide: + + ray.shutdown() + + ray.init(num_cpus=5, num_gpus=5) + + .. testcode:: + + @ray.remote(num_gpus=1, max_calls=1, num_returns=2) + def f(): + return 1, 2 + + f_with_2_gpus = f.options(num_gpus=2) + object_refs = f_with_2_gpus.remote() + assert ray.get(object_refs) == [1, 2] + + @ray.remote(num_cpus=2, resources={"CustomResource": 1}) + class Foo: + def method(self): + return 1 + + Foo_with_no_resources = Foo.options(num_cpus=1, resources=None) + foo_actor = Foo_with_no_resources.remote() + assert ray.get(foo_actor.method.remote()) == 1 + + + A remote actor will be terminated when all actor handle to it + in Python is deleted, which will cause them to complete any outstanding + work and then shut down. If you only have 1 reference to an actor handle, + calling ``del actor`` *could* trigger actor deletion. Note that your program + may have multiple references to the same ActorHandle, and actor termination + will not occur until the reference count goes to 0. See the Python + documentation for more context about object deletion. 
+ https://docs.python.org/3.9/reference/datamodel.html#object.__del__ + + If you want to kill actors immediately, you can also call ``ray.kill(actor)``. + + .. tip:: + Avoid repeatedly passing in large arguments to remote task or method calls. + + Instead, use ray.put to create a copy of the object in the object store. + + See :ref:`more info here `. + + Args: + num_returns: This is only for *remote functions*. It specifies + the number of object refs returned by the remote function + invocation. The default value is 1. + Pass "dynamic" to allow the task to decide how many + return values to return during execution, and the caller will + receive an ObjectRef[DynamicObjectRefGenerator]. + See :ref:`dynamic generators ` for more details. + num_cpus: The quantity of CPU resources to reserve + for this task or for the lifetime of the actor. + By default, tasks use 1 CPU resource and actors use 1 CPU + for scheduling and 0 CPU for running + (This means, by default, actors cannot get scheduled on a zero-cpu node, + but an infinite number of them can run on any non-zero cpu node. + The default value for actors was chosen for historical reasons. + It’s recommended to always explicitly set num_cpus for actors + to avoid any surprises. + If resources are specified explicitly, + they are required for both scheduling and running.) + See :ref:`specifying resource requirements ` + for more details. + num_gpus: The quantity of GPU resources to reserve + for this task or for the lifetime of the actor. + The default value is 0. + See :ref:`Ray GPU support ` for more details. + resources (Dict[str, float]): The quantity of various + :ref:`custom resources ` + to reserve for this task or for the lifetime of the actor. + This is a dictionary mapping strings (resource names) to floats. + By default it is empty. + accelerator_type: If specified, requires that the task or actor run + on a node with the specified type of accelerator. + See :ref:`accelerator types `. 
+ memory: The heap memory request in bytes for this task/actor, + rounded down to the nearest integer. + max_calls: Only for *remote functions*. This specifies the + maximum number of times that a given worker can execute + the given remote function before it must exit + (this can be used to address :ref:`memory leaks ` in third-party + libraries or to reclaim resources that cannot easily be + released, e.g., GPU memory that was acquired by TensorFlow). + By default this is infinite for CPU tasks and 1 for GPU tasks + (to force GPU tasks to release resources after finishing). + max_restarts: Only for *actors*. This specifies the maximum + number of times that the actor should be restarted when it dies + unexpectedly. The minimum valid value is 0 (default), + which indicates that the actor doesn't need to be restarted. + A value of -1 indicates that an actor should be restarted + indefinitely. + See :ref:`actor fault tolerance ` for more details. + max_task_retries: Only for *actors*. How many times to + retry an actor task if the task fails due to a system error, + e.g., the actor has died. If set to -1, the system will + retry the failed task until the task succeeds, or the actor + has reached its max_restarts limit. If set to `n > 0`, the + system will retry the failed task up to n times, after which the + task will throw a `RayActorError` exception upon :obj:`ray.get`. + Note that Python exceptions are not considered system errors + and will not trigger retries. + The default value is 0. + See :ref:`actor fault tolerance ` for more details. + max_retries: Only for *remote functions*. This specifies + the maximum number of times that the remote function + should be rerun when the worker process executing it + crashes unexpectedly. The minimum valid value is 0, + the default value is 3, and a value of -1 indicates + infinite retries. + See :ref:`task fault tolerance ` for more details. 
+ runtime_env (Dict[str, Any]): Specifies the runtime environment for + this actor or task and its children. See + :ref:`runtime-environments` for detailed documentation. + retry_exceptions: Only for *remote functions*. This specifies whether + application-level errors should be retried up to max_retries times. + This can be a boolean or a list of exceptions that should be retried. + See :ref:`task fault tolerance ` for more details. + scheduling_strategy: Strategy about how to + schedule a remote function or actor. Possible values are + None: ray will figure out the scheduling strategy to use, it + will either be the PlacementGroupSchedulingStrategy using parent's + placement group if parent has one and has + placement_group_capture_child_tasks set to true, + or "DEFAULT"; + "DEFAULT": default hybrid scheduling; + "SPREAD": best effort spread scheduling; + `PlacementGroupSchedulingStrategy`: + placement group based scheduling; + `NodeAffinitySchedulingStrategy`: + node id based affinity scheduling. + See :ref:`Ray scheduling strategies ` + for more details. + _metadata: Extended options for Ray libraries. For example, + _metadata={"workflows.io/options": } for Ray workflows. + _labels: The key-value labels of a task or actor. + """ + # "callable" returns true for both function and class. + if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): + # This is the case where the decorator is just @ray.remote. + # "args[0]" is the class or function under the decorator. 
+ return _make_remote(args[0], {}) + assert len(args) == 0 and len(kwargs) > 0, ray_option_utils.remote_args_error_string + return functools.partial(_make_remote, options=kwargs) diff --git a/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/dag_operation_future.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/dag_operation_future.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..635a3983c9441d62bffc4ebfc09f40b6af886fe8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/dag_operation_future.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/input_node.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/input_node.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0b7085da556cd7418de38ad065e3c59876f32eb Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dag/__pycache__/input_node.cpython-311.pyc differ