diff --git a/.gitattributes b/.gitattributes index 8ca6d23a54b6f255a9d4b53441db1226771b6152..834fddae6e7a4964c1b611810d924164907770dd 100644 --- a/.gitattributes +++ b/.gitattributes @@ -150,3 +150,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_ .venv/lib/python3.11/site-packages/nvidia/cusolver/lib/libcusolver.so.11 filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/nvidia/cusparse/lib/libcusparse.so.12 filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/nvidia/cufft/lib/libcufft.so.11 filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/torchgen/__pycache__/model.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text diff --git a/.venv/lib/python3.11/site-packages/annotated_types-0.7.0.dist-info/INSTALLER b/.venv/lib/python3.11/site-packages/annotated_types-0.7.0.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/annotated_types-0.7.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/.venv/lib/python3.11/site-packages/annotated_types-0.7.0.dist-info/METADATA b/.venv/lib/python3.11/site-packages/annotated_types-0.7.0.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..3ac05cfd1077ba5664e98ecd1342f7c54360b936 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/annotated_types-0.7.0.dist-info/METADATA @@ -0,0 +1,295 @@ +Metadata-Version: 2.3 +Name: annotated-types +Version: 0.7.0 +Summary: Reusable constraint types to use with typing.Annotated +Project-URL: Homepage, https://github.com/annotated-types/annotated-types +Project-URL: Source, https://github.com/annotated-types/annotated-types +Project-URL: Changelog, https://github.com/annotated-types/annotated-types/releases +Author-email: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>, Samuel Colvin , Zac Hatfield-Dodds +License-File: 
LICENSE +Classifier: Development Status :: 4 - Beta +Classifier: Environment :: Console +Classifier: Environment :: MacOS X +Classifier: Intended Audience :: Developers +Classifier: Intended Audience :: Information Technology +Classifier: License :: OSI Approved :: MIT License +Classifier: Operating System :: POSIX :: Linux +Classifier: Operating System :: Unix +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Topic :: Software Development :: Libraries :: Python Modules +Classifier: Typing :: Typed +Requires-Python: >=3.8 +Requires-Dist: typing-extensions>=4.0.0; python_version < '3.9' +Description-Content-Type: text/markdown + +# annotated-types + +[![CI](https://github.com/annotated-types/annotated-types/workflows/CI/badge.svg?event=push)](https://github.com/annotated-types/annotated-types/actions?query=event%3Apush+branch%3Amain+workflow%3ACI) +[![pypi](https://img.shields.io/pypi/v/annotated-types.svg)](https://pypi.python.org/pypi/annotated-types) +[![versions](https://img.shields.io/pypi/pyversions/annotated-types.svg)](https://github.com/annotated-types/annotated-types) +[![license](https://img.shields.io/github/license/annotated-types/annotated-types.svg)](https://github.com/annotated-types/annotated-types/blob/main/LICENSE) + +[PEP-593](https://peps.python.org/pep-0593/) added `typing.Annotated` as a way of +adding context-specific metadata to existing types, and specifies that +`Annotated[T, x]` _should_ be treated as `T` by any tool or library without special +logic for `x`. 
+ +This package provides metadata objects which can be used to represent common +constraints such as upper and lower bounds on scalar values and collection sizes, +a `Predicate` marker for runtime checks, and +descriptions of how we intend these metadata to be interpreted. In some cases, +we also note alternative representations which do not require this package. + +## Install + +```bash +pip install annotated-types +``` + +## Examples + +```python +from typing import Annotated +from annotated_types import Gt, Len, Predicate + +class MyClass: + age: Annotated[int, Gt(18)] # Valid: 19, 20, ... + # Invalid: 17, 18, "19", 19.0, ... + factors: list[Annotated[int, Predicate(is_prime)]] # Valid: 2, 3, 5, 7, 11, ... + # Invalid: 4, 8, -2, 5.0, "prime", ... + + my_list: Annotated[list[int], Len(0, 10)] # Valid: [], [10, 20, 30, 40, 50] + # Invalid: (1, 2), ["abc"], [0] * 20 +``` + +## Documentation + +_While `annotated-types` avoids runtime checks for performance, users should not +construct invalid combinations such as `MultipleOf("non-numeric")` or `Annotated[int, Len(3)]`. +Downstream implementors may choose to raise an error, emit a warning, silently ignore +a metadata item, etc., if the metadata objects described below are used with an +incompatible type - or for any other reason!_ + +### Gt, Ge, Lt, Le + +Express inclusive and/or exclusive bounds on orderable values - which may be numbers, +dates, times, strings, sets, etc. Note that the boundary value need not be of the +same type that was annotated, so long as they can be compared: `Annotated[int, Gt(1.5)]` +is fine, for example, and implies that the value is an integer x such that `x > 1.5`. + +We suggest that implementors may also interpret `functools.partial(operator.le, 1.5)` +as being equivalent to `Gt(1.5)`, for users who wish to avoid a runtime dependency on +the `annotated-types` package. 
+ +To be explicit, these types have the following meanings: + +* `Gt(x)` - value must be "Greater Than" `x` - equivalent to exclusive minimum +* `Ge(x)` - value must be "Greater than or Equal" to `x` - equivalent to inclusive minimum +* `Lt(x)` - value must be "Less Than" `x` - equivalent to exclusive maximum +* `Le(x)` - value must be "Less than or Equal" to `x` - equivalent to inclusive maximum + +### Interval + +`Interval(gt, ge, lt, le)` allows you to specify an upper and lower bound with a single +metadata object. `None` attributes should be ignored, and non-`None` attributes +treated as per the single bounds above. + +### MultipleOf + +`MultipleOf(multiple_of=x)` might be interpreted in two ways: + +1. Python semantics, implying `value % multiple_of == 0`, or +2. [JSONschema semantics](https://json-schema.org/draft/2020-12/json-schema-validation.html#rfc.section.6.2.1), + where `int(value / multiple_of) == value / multiple_of`. + +We encourage users to be aware of these two common interpretations and their +distinct behaviours, especially since very large or non-integer numbers make +it easy to cause silent data corruption due to floating-point imprecision. + +We encourage libraries to carefully document which interpretation they implement. + +### MinLen, MaxLen, Len + +`Len()` implies that `min_length <= len(value) <= max_length` - lower and upper bounds are inclusive. + +As well as `Len()` which can optionally include upper and lower bounds, we also +provide `MinLen(x)` and `MaxLen(y)` which are equivalent to `Len(min_length=x)` +and `Len(max_length=y)` respectively. + +`Len`, `MinLen`, and `MaxLen` may be used with any type which supports `len(value)`. 
+ +Examples of usage: + +* `Annotated[list, MaxLen(10)]` (or `Annotated[list, Len(max_length=10))`) - list must have a length of 10 or less +* `Annotated[str, MaxLen(10)]` - string must have a length of 10 or less +* `Annotated[list, MinLen(3))` (or `Annotated[list, Len(min_length=3))`) - list must have a length of 3 or more +* `Annotated[list, Len(4, 6)]` - list must have a length of 4, 5, or 6 +* `Annotated[list, Len(8, 8)]` - list must have a length of exactly 8 + +#### Changed in v0.4.0 + +* `min_inclusive` has been renamed to `min_length`, no change in meaning +* `max_exclusive` has been renamed to `max_length`, upper bound is now **inclusive** instead of **exclusive** +* The recommendation that slices are interpreted as `Len` has been removed due to ambiguity and different semantic + meaning of the upper bound in slices vs. `Len` + +See [issue #23](https://github.com/annotated-types/annotated-types/issues/23) for discussion. + +### Timezone + +`Timezone` can be used with a `datetime` or a `time` to express which timezones +are allowed. `Annotated[datetime, Timezone(None)]` must be a naive datetime. +`Timezone[...]` ([literal ellipsis](https://docs.python.org/3/library/constants.html#Ellipsis)) +expresses that any timezone-aware datetime is allowed. You may also pass a specific +timezone string or [`tzinfo`](https://docs.python.org/3/library/datetime.html#tzinfo-objects) +object such as `Timezone(timezone.utc)` or `Timezone("Africa/Abidjan")` to express that you only +allow a specific timezone, though we note that this is often a symptom of fragile design. + +#### Changed in v0.x.x + +* `Timezone` accepts [`tzinfo`](https://docs.python.org/3/library/datetime.html#tzinfo-objects) objects instead of + `timezone`, extending compatibility to [`zoneinfo`](https://docs.python.org/3/library/zoneinfo.html) and third party libraries. + +### Unit + +`Unit(unit: str)` expresses that the annotated numeric value is the magnitude of +a quantity with the specified unit. 
For example, `Annotated[float, Unit("m/s")]` +would be a float representing a velocity in meters per second. + +Please note that `annotated_types` itself makes no attempt to parse or validate +the unit string in any way. That is left entirely to downstream libraries, +such as [`pint`](https://pint.readthedocs.io) or +[`astropy.units`](https://docs.astropy.org/en/stable/units/). + +An example of how a library might use this metadata: + +```python +from annotated_types import Unit +from typing import Annotated, TypeVar, Callable, Any, get_origin, get_args + +# given a type annotated with a unit: +Meters = Annotated[float, Unit("m")] + + +# you can cast the annotation to a specific unit type with any +# callable that accepts a string and returns the desired type +T = TypeVar("T") +def cast_unit(tp: Any, unit_cls: Callable[[str], T]) -> T | None: + if get_origin(tp) is Annotated: + for arg in get_args(tp): + if isinstance(arg, Unit): + return unit_cls(arg.unit) + return None + + +# using `pint` +import pint +pint_unit = cast_unit(Meters, pint.Unit) + + +# using `astropy.units` +import astropy.units as u +astropy_unit = cast_unit(Meters, u.Unit) +``` + +### Predicate + +`Predicate(func: Callable)` expresses that `func(value)` is truthy for valid values. +Users should prefer the statically inspectable metadata above, but if you need +the full power and flexibility of arbitrary runtime predicates... here it is. 
+ +For some common constraints, we provide generic types: + +* `IsLower = Annotated[T, Predicate(str.islower)]` +* `IsUpper = Annotated[T, Predicate(str.isupper)]` +* `IsDigit = Annotated[T, Predicate(str.isdigit)]` +* `IsFinite = Annotated[T, Predicate(math.isfinite)]` +* `IsNotFinite = Annotated[T, Predicate(Not(math.isfinite))]` +* `IsNan = Annotated[T, Predicate(math.isnan)]` +* `IsNotNan = Annotated[T, Predicate(Not(math.isnan))]` +* `IsInfinite = Annotated[T, Predicate(math.isinf)]` +* `IsNotInfinite = Annotated[T, Predicate(Not(math.isinf))]` + +so that you can write e.g. `x: IsFinite[float] = 2.0` instead of the longer +(but exactly equivalent) `x: Annotated[float, Predicate(math.isfinite)] = 2.0`. + +Some libraries might have special logic to handle known or understandable predicates, +for example by checking for `str.isdigit` and using its presence to both call custom +logic to enforce digit-only strings, and customise some generated external schema. +Users are therefore encouraged to avoid indirection like `lambda s: s.lower()`, in +favor of introspectable methods such as `str.lower` or `re.compile("pattern").search`. + +To enable basic negation of commonly used predicates like `math.isnan` without introducing introspection that makes it impossible for implementers to introspect the predicate we provide a `Not` wrapper that simply negates the predicate in an introspectable manner. Several of the predicates listed above are created in this manner. + +We do not specify what behaviour should be expected for predicates that raise +an exception. For example `Annotated[int, Predicate(str.isdigit)]` might silently +skip invalid constraints, or statically raise an error; or it might try calling it +and then propagate or discard the resulting +`TypeError: descriptor 'isdigit' for 'str' objects doesn't apply to a 'int' object` +exception. We encourage libraries to document the behaviour they choose. 
+ +### Doc + +`doc()` can be used to add documentation information in `Annotated`, for function and method parameters, variables, class attributes, return types, and any place where `Annotated` can be used. + +It expects a value that can be statically analyzed, as the main use case is for static analysis, editors, documentation generators, and similar tools. + +It returns a `DocInfo` class with a single attribute `documentation` containing the value passed to `doc()`. + +This is the early adopter's alternative form of the [`typing-doc` proposal](https://github.com/tiangolo/fastapi/blob/typing-doc/typing_doc.md). + +### Integrating downstream types with `GroupedMetadata` + +Implementers may choose to provide a convenience wrapper that groups multiple pieces of metadata. +This can help reduce verbosity and cognitive overhead for users. +For example, an implementer like Pydantic might provide a `Field` or `Meta` type that accepts keyword arguments and transforms these into low-level metadata: + +```python +from dataclasses import dataclass +from typing import Iterator +from annotated_types import GroupedMetadata, Ge + +@dataclass +class Field(GroupedMetadata): + ge: int | None = None + description: str | None = None + + def __iter__(self) -> Iterator[object]: + # Iterating over a GroupedMetadata object should yield annotated-types + # constraint metadata objects which describe it as fully as possible, + # and may include other unknown objects too. + if self.ge is not None: + yield Ge(self.ge) + if self.description is not None: + yield Description(self.description) +``` + +Libraries consuming annotated-types constraints should check for `GroupedMetadata` and unpack it by iterating over the object and treating the results as if they had been "unpacked" in the `Annotated` type. 
The same logic should be applied to the [PEP 646 `Unpack` type](https://peps.python.org/pep-0646/), so that `Annotated[T, Field(...)]`, `Annotated[T, Unpack[Field(...)]]` and `Annotated[T, *Field(...)]` are all treated consistently. + +Libraries consuming annotated-types should also ignore any metadata they do not recongize that came from unpacking a `GroupedMetadata`, just like they ignore unrecognized metadata in `Annotated` itself. + +Our own `annotated_types.Interval` class is a `GroupedMetadata` which unpacks itself into `Gt`, `Lt`, etc., so this is not an abstract concern. Similarly, `annotated_types.Len` is a `GroupedMetadata` which unpacks itself into `MinLen` (optionally) and `MaxLen`. + +### Consuming metadata + +We intend to not be prescriptive as to _how_ the metadata and constraints are used, but as an example of how one might parse constraints from types annotations see our [implementation in `test_main.py`](https://github.com/annotated-types/annotated-types/blob/f59cf6d1b5255a0fe359b93896759a180bec30ae/tests/test_main.py#L94-L103). + +It is up to the implementer to determine how this metadata is used. +You could use the metadata for runtime type checking, for generating schemas or to generate example data, amongst other use cases. + +## Design & History + +This package was designed at the PyCon 2022 sprints by the maintainers of Pydantic +and Hypothesis, with the goal of making it as easy as possible for end-users to +provide more informative annotations for use by runtime libraries. + +It is deliberately minimal, and following PEP-593 allows considerable downstream +discretion in what (if anything!) they choose to support. Nonetheless, we expect +that staying simple and covering _only_ the most common use-cases will give users +and maintainers the best experience we can. If you'd like more constraints for your +types - follow our lead, by defining them and documenting them downstream! 
diff --git a/.venv/lib/python3.11/site-packages/annotated_types-0.7.0.dist-info/RECORD b/.venv/lib/python3.11/site-packages/annotated_types-0.7.0.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..cb983b08665edc0419dd7b70bc9639baeeb25d80 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/annotated_types-0.7.0.dist-info/RECORD @@ -0,0 +1,10 @@ +annotated_types-0.7.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +annotated_types-0.7.0.dist-info/METADATA,sha256=7ltqxksJJ0wCYFGBNIQCWTlWQGeAH0hRFdnK3CB895E,15046 +annotated_types-0.7.0.dist-info/RECORD,, +annotated_types-0.7.0.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87 +annotated_types-0.7.0.dist-info/licenses/LICENSE,sha256=_hBJiEsaDZNCkB6I4H8ykl0ksxIdmXK2poBfuYJLCV0,1083 +annotated_types/__init__.py,sha256=RynLsRKUEGI0KimXydlD1fZEfEzWwDo0Uon3zOKhG1Q,13819 +annotated_types/__pycache__/__init__.cpython-311.pyc,, +annotated_types/__pycache__/test_cases.cpython-311.pyc,, +annotated_types/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +annotated_types/test_cases.py,sha256=zHFX6EpcMbGJ8FzBYDbO56bPwx_DYIVSKbZM-4B3_lg,6421 diff --git a/.venv/lib/python3.11/site-packages/annotated_types-0.7.0.dist-info/WHEEL b/.venv/lib/python3.11/site-packages/annotated_types-0.7.0.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..516596c76787b10928cbab24f22c0ea00433b15d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/annotated_types-0.7.0.dist-info/WHEEL @@ -0,0 +1,4 @@ +Wheel-Version: 1.0 +Generator: hatchling 1.24.2 +Root-Is-Purelib: true +Tag: py3-none-any diff --git a/.venv/lib/python3.11/site-packages/annotated_types-0.7.0.dist-info/licenses/LICENSE b/.venv/lib/python3.11/site-packages/annotated_types-0.7.0.dist-info/licenses/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..d99323a9965f146d5b0888c4ca1bf0727e12b04f --- /dev/null +++ 
b/.venv/lib/python3.11/site-packages/annotated_types-0.7.0.dist-info/licenses/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2022 the contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/.venv/lib/python3.11/site-packages/ray/__init__.py b/.venv/lib/python3.11/site-packages/ray/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ef737883c55996e81215adc49dbbc79113c1cb11 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/__init__.py @@ -0,0 +1,294 @@ +# isort: skip_file +from ray._private import log # isort: skip # noqa: F401 +import logging +import os +import sys + +log.generate_logging_config() +logger = logging.getLogger(__name__) + + +def _configure_system(): + import os + import platform + import sys + + """Wraps system configuration to avoid 'leaking' variables into ray.""" + + # Sanity check pickle5 if it has been installed. 
+ if "pickle5" in sys.modules: + if sys.version_info >= (3, 8): + logger.warning( + "Package pickle5 becomes unnecessary in Python 3.8 and above. " + "Its presence may confuse libraries including Ray. " + "Please uninstall the package." + ) + + import importlib.metadata + + try: + version_str = importlib.metadata.version("pickle5") + version = tuple(int(n) for n in version_str.split(".")) + if version < (0, 0, 10): + logger.warning( + "Although not used by Ray, a version of pickle5 that leaks memory " + "is found in the environment. Please run 'pip install pickle5 -U' " + "to upgrade." + ) + except importlib.metadata.PackageNotFoundError: + logger.warning( + "You are using the 'pickle5' module, but " + "the exact version is unknown (possibly carried as " + "an internal component by another module). Please " + "make sure you are using pickle5 >= 0.0.10 because " + "previous versions may leak memory." + ) + + # Importing psutil & setproctitle. Must be before ray._raylet is + # initialized. + thirdparty_files = os.path.join( + os.path.abspath(os.path.dirname(__file__)), "thirdparty_files" + ) + sys.path.insert(0, thirdparty_files) + + if ( + platform.system() == "Linux" + and "Microsoft".lower() in platform.release().lower() + ): + from ray._private import compat # noqa: E402 + + compat.patch_psutil() + + # Expose ray ABI symbols which may be dependent by other shared + # libraries such as _streaming.so. See BUILD.bazel:_raylet + python_shared_lib_suffix = ".so" if sys.platform != "win32" else ".pyd" + so_path = os.path.join( + os.path.dirname(__file__), "_raylet" + python_shared_lib_suffix + ) + if os.path.exists(so_path): + import ctypes + from ctypes import CDLL + + CDLL(so_path, ctypes.RTLD_GLOBAL) + + +_configure_system() +# Delete configuration function. 
+del _configure_system + +from ray import _version # noqa: E402 + +__commit__ = _version.commit +__version__ = _version.version + +import ray._raylet # noqa: E402 + +from ray._raylet import ( # noqa: E402,F401 + ActorClassID, + ActorID, + NodeID, + Config as _Config, + JobID, + WorkerID, + FunctionID, + ObjectID, + ObjectRef, + ObjectRefGenerator, + DynamicObjectRefGenerator, + TaskID, + UniqueID, + Language, + PlacementGroupID, + ClusterID, +) + +_config = _Config() + +from ray._private.state import ( # noqa: E402,F401 + nodes, + timeline, + cluster_resources, + available_resources, +) +from ray._private.worker import ( # noqa: E402,F401 + LOCAL_MODE, + SCRIPT_MODE, + WORKER_MODE, + RESTORE_WORKER_MODE, + SPILL_WORKER_MODE, + cancel, + get, + get_actor, + get_gpu_ids, + init, + is_initialized, + put, + kill, + remote, + shutdown, + wait, +) + +from ray._private.ray_logging.logging_config import LoggingConfig # noqa: E402 + +# We import ray.actor because some code is run in actor.py which initializes +# some functions in the worker. +import ray.actor # noqa: E402,F401 +from ray.actor import method # noqa: E402,F401 + +# TODO(qwang): We should remove this exporting in Ray2.0. +from ray.cross_language import java_function, java_actor_class # noqa: E402,F401 +from ray.runtime_context import get_runtime_context # noqa: E402,F401 +from ray import internal # noqa: E402,F401 +from ray import util # noqa: E402,F401 +from ray import _private # noqa: E402,F401 + +# We import ClientBuilder so that modules can inherit from `ray.ClientBuilder`. 
+from ray.client_builder import client, ClientBuilder # noqa: E402,F401 + + +class _DeprecationWrapper: + def __init__(self, name, real_worker): + self._name = name + self._real_worker = real_worker + self._warned = set() + + def __getattr__(self, attr): + value = getattr(self._real_worker, attr) + if attr not in self._warned: + self._warned.add(attr) + logger.warning( + f"DeprecationWarning: `ray.{self._name}.{attr}` is a private " + "attribute and access will be removed in a future Ray version." + ) + return value + + +# TODO(ekl) remove this entirely after 3rd party libraries are all migrated. +worker = _DeprecationWrapper("worker", ray._private.worker) +ray_constants = _DeprecationWrapper("ray_constants", ray._private.ray_constants) +serialization = _DeprecationWrapper("serialization", ray._private.serialization) +state = _DeprecationWrapper("state", ray._private.state) + + +# Pulic Ray APIs +__all__ = [ + "__version__", + "_config", + "get_runtime_context", + "autoscaler", + "available_resources", + "cancel", + "client", + "ClientBuilder", + "cluster_resources", + "get", + "get_actor", + "get_gpu_ids", + "init", + "is_initialized", + "java_actor_class", + "java_function", + "cpp_function", + "kill", + "Language", + "method", + "nodes", + "put", + "remote", + "shutdown", + "show_in_dashboard", + "timeline", + "wait", + "LOCAL_MODE", + "SCRIPT_MODE", + "WORKER_MODE", + "LoggingConfig", +] + +# Public APIs that should automatically trigger ray.init(). +AUTO_INIT_APIS = { + "cancel", + "get", + "get_actor", + "get_gpu_ids", + "kill", + "put", + "wait", + "get_runtime_context", +} + +# Public APIs that should not automatically trigger ray.init(). 
+NON_AUTO_INIT_APIS = { + "ClientBuilder", + "LOCAL_MODE", + "Language", + "SCRIPT_MODE", + "WORKER_MODE", + "__version__", + "_config", + "autoscaler", + "available_resources", + "client", + "cluster_resources", + "cpp_function", + "init", + "is_initialized", + "java_actor_class", + "java_function", + "method", + "nodes", + "remote", + "show_in_dashboard", + "shutdown", + "timeline", + "LoggingConfig", +} + +assert set(__all__) == AUTO_INIT_APIS | NON_AUTO_INIT_APIS +from ray._private.auto_init_hook import wrap_auto_init_for_all_apis # noqa: E402 + +wrap_auto_init_for_all_apis(AUTO_INIT_APIS) +del wrap_auto_init_for_all_apis + +# Subpackages +__all__ += [ + "actor", + "autoscaler", + "data", + "internal", + "util", + "widgets", + "workflow", +] + +# ID types +__all__ += [ + "ActorClassID", + "ActorID", + "NodeID", + "JobID", + "WorkerID", + "FunctionID", + "ObjectID", + "ObjectRef", + "ObjectRefGenerator", + "DynamicObjectRefGenerator", + "TaskID", + "UniqueID", + "PlacementGroupID", +] + + +# Delay importing of expensive, isolated subpackages. +def __getattr__(name: str): + import importlib + + if name in ["data", "workflow", "autoscaler"]: + return importlib.import_module("." 
+ name, __name__) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +del os +del logging +del sys diff --git a/.venv/lib/python3.11/site-packages/ray/_raylet.pxd b/.venv/lib/python3.11/site-packages/ray/_raylet.pxd new file mode 100644 index 0000000000000000000000000000000000000000..acb025e43cc02af7fc00ede0f51df184b44f8ace --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_raylet.pxd @@ -0,0 +1,176 @@ +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +from cpython.pystate cimport PyThreadState_Get + +from libc.stdint cimport ( + int64_t, +) +from libcpp cimport bool as c_bool +from libcpp.string cimport string as c_string +from libcpp.vector cimport vector as c_vector +from libcpp.unordered_map cimport unordered_map +from libcpp.memory cimport ( + shared_ptr, + unique_ptr +) +from libcpp.pair cimport pair as c_pair +from libcpp.utility cimport pair +from ray.includes.optional cimport ( + optional, + nullopt, + make_optional, +) + +from ray.includes.common cimport ( + CBuffer, + CRayObject, + CAddress, + CConcurrencyGroup, + CSchedulingStrategy, + CLabelMatchExpressions, +) +from ray.includes.libcoreworker cimport ( + ActorHandleSharedPtr, + CActorHandle, + CFiberEvent, +) + +from ray.includes.unique_ids cimport ( + CObjectID, + CActorID, + CTaskID, +) +from ray.includes.function_descriptor cimport ( + CFunctionDescriptor, +) + +cdef extern from *: + """ + #if __OPTIMIZE__ && __OPTIMIZE__ == 1 + #undef __OPTIMIZE__ + int __OPTIMIZE__ = 1; + #define __OPTIMIZE__ 1 + #elif defined(BAZEL_OPT) + // For compilers that don't define __OPTIMIZE__ + int __OPTIMIZE__ = 1; + #else + int __OPTIMIZE__ = 0; + #endif + """ + int __OPTIMIZE__ + +cdef extern from "Python.h": + # Note(simon): This is used to configure asyncio actor stack size. + # Cython made PyThreadState an opaque types. Saying that if the user wants + # specific attributes, they can be declared manually. 
+ + # You can find the cpython definition in Include/cpython/pystate.h#L59 + ctypedef struct CPyThreadState "PyThreadState": + int recursion_limit + int recursion_remaining + + # From Include/ceveal.h#67 + int Py_GetRecursionLimit() + void Py_SetRecursionLimit(int) + +cdef class Buffer: + cdef: + shared_ptr[CBuffer] buffer + Py_ssize_t shape + Py_ssize_t strides + + @staticmethod + cdef make(const shared_ptr[CBuffer]& buffer) + +cdef class BaseID: + # To avoid the error of "Python int too large to convert to C ssize_t", + # here `cdef size_t` is required. + cdef size_t hash(self) + +cdef class ObjectRef(BaseID): + cdef: + CObjectID data + c_string owner_addr + # Flag indicating whether or not this object ref was added to the set + # of active IDs in the core worker so we know whether we should clean + # it up. + c_bool in_core_worker + c_string call_site_data + + cdef CObjectID native(self) + +cdef class ActorID(BaseID): + cdef CActorID data + + cdef CActorID native(self) + + cdef size_t hash(self) + + +cdef class CoreWorker: + cdef: + c_bool is_driver + object async_thread + object async_event_loop + object plasma_event_handler + object job_config + object current_runtime_env + c_bool is_local_mode + + object cgname_to_eventloop_dict + object eventloop_for_default_cg + object thread_for_default_cg + object fd_to_cgname_dict + object _task_id_to_future_lock + dict _task_id_to_future + object event_loop_executor + + cdef _create_put_buffer(self, shared_ptr[CBuffer] &metadata, + size_t data_size, ObjectRef object_ref, + c_vector[CObjectID] contained_ids, + CObjectID *c_object_id, shared_ptr[CBuffer] *data, + c_bool created_by_worker, + owner_address=*, + c_bool inline_small_object=*, + c_bool is_experimental_channel=*) + cdef unique_ptr[CAddress] _convert_python_address(self, address=*) + cdef store_task_output( + self, serialized_object, + const CObjectID &return_id, + const CObjectID &generator_id, + size_t data_size, shared_ptr[CBuffer] &metadata, const 
c_vector[CObjectID] + &contained_id, const CAddress &caller_address, + int64_t *task_output_inlined_bytes, + shared_ptr[CRayObject] *return_ptr) + cdef store_task_outputs( + self, + worker, outputs, + const CAddress &caller_address, + c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *returns, + CObjectID ref_generator_id=*) + cdef make_actor_handle(self, ActorHandleSharedPtr c_actor_handle, + c_bool weak_ref) + cdef c_function_descriptors_to_python( + self, const c_vector[CFunctionDescriptor] &c_function_descriptors) + cdef initialize_eventloops_for_actor_concurrency_group( + self, const c_vector[CConcurrencyGroup] &c_defined_concurrency_groups) + cdef python_scheduling_strategy_to_c( + self, python_scheduling_strategy, + CSchedulingStrategy *c_scheduling_strategy) + cdef python_label_match_expressions_to_c( + self, python_expressions, + CLabelMatchExpressions *c_expressions) + cdef CObjectID allocate_dynamic_return_id_for_generator( + self, + const CAddress &owner_address, + const CTaskID &task_id, + return_size, + generator_index, + is_async_actor) + +cdef class FunctionDescriptor: + cdef: + CFunctionDescriptor descriptor diff --git a/.venv/lib/python3.11/site-packages/ray/_raylet.pyi b/.venv/lib/python3.11/site-packages/ray/_raylet.pyi new file mode 100644 index 0000000000000000000000000000000000000000..c2897640957876ff93365696198b22a8b0740e12 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_raylet.pyi @@ -0,0 +1,11 @@ +from typing import Awaitable, TypeVar + +R = TypeVar("R") + + +class ObjectRef(Awaitable[R]): # type: ignore + pass + + +class ObjectID(Awaitable[R]): # type: ignore + pass diff --git a/.venv/lib/python3.11/site-packages/ray/_version.py b/.venv/lib/python3.11/site-packages/ray/_version.py new file mode 100644 index 0000000000000000000000000000000000000000..5eb927fffd678baeb2d4cfbee5aeef9a096140f7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_version.py @@ -0,0 +1,6 @@ +# Replaced with the current commit when building 
the wheels. +commit = "637116a090c052d061af5ba3bef8a467c8c3fc25" +version = "2.42.0" + +if __name__ == "__main__": + print("%s %s" % (version, commit)) diff --git a/.venv/lib/python3.11/site-packages/ray/actor.py b/.venv/lib/python3.11/site-packages/ray/actor.py new file mode 100644 index 0000000000000000000000000000000000000000..450b1c620a1e3b09fb764b22c01e6b75e381ce3d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/actor.py @@ -0,0 +1,1790 @@ +import inspect +import logging +import weakref +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +import ray._private.ray_constants as ray_constants +import ray._private.signature as signature +import ray._private.worker +import ray._raylet +from ray import ActorClassID, Language, cross_language +from ray._private import ray_option_utils +from ray._private.async_compat import has_async_methods +from ray._private.auto_init_hook import wrap_auto_init +from ray._private.client_mode_hook import ( + client_mode_convert_actor, + client_mode_hook, + client_mode_should_convert, +) +from ray._private.inspect_util import ( + is_class_method, + is_function_or_method, + is_static_method, +) +from ray._private.ray_option_utils import _warn_if_using_deprecated_placement_group +from ray._private.utils import get_runtime_env_info, parse_runtime_env +from ray._raylet import ( + STREAMING_GENERATOR_RETURN, + ObjectRefGenerator, + PythonFunctionDescriptor, + raise_sys_exit_with_custom_error_message, +) +from ray.exceptions import AsyncioActorExit +from ray.util.annotations import DeveloperAPI, PublicAPI +from ray.util.placement_group import _configure_placement_group_based_on_context +from ray.util.scheduling_strategies import ( + PlacementGroupSchedulingStrategy, + SchedulingStrategyT, +) +from ray.util.tracing.tracing_helper import ( + _inject_tracing_into_class, + _tracing_actor_creation, + _tracing_actor_method_invocation, +) + +logger = logging.getLogger(__name__) + +# Hook to call with (actor, resources, 
# strategy) on each local actor creation.
_actor_launch_hook = None


@PublicAPI
@client_mode_hook
def method(*args, **kwargs):
    """Annotate an actor method.

    .. code-block:: python

        @ray.remote
        class Foo:
            @ray.method(num_returns=2)
            def bar(self):
                return 1, 2

        f = Foo.remote()

        _, _ = f.bar.remote()

    Args:
        num_returns: The number of object refs that should be returned by
            invocations of this actor method.
    """
    valid_kwargs = [
        "num_returns",
        "concurrency_group",
        "max_task_retries",
        "retry_exceptions",
        "_generator_backpressure_num_objects",
        "enable_task_events",
    ]
    error_string = (
        "The @ray.method decorator must be applied using at least one of "
        f"the arguments in the list {valid_kwargs}, for example "
        "'@ray.method(num_returns=2)'."
    )
    assert len(args) == 0 and len(kwargs) > 0, error_string
    for key in kwargs:
        key_error_string = (
            f"Unexpected keyword argument to @ray.method: '{key}'. The "
            f"supported keyword arguments are {valid_kwargs}"
        )
        assert key in valid_kwargs, key_error_string

    # Each supported kwarg maps onto a dunder attribute stamped on the
    # decorated method; ActorMethod/metadata creation reads them back later.
    attr_for_kwarg = {
        "num_returns": "__ray_num_returns__",
        "max_task_retries": "__ray_max_task_retries__",
        "retry_exceptions": "__ray_retry_exceptions__",
        "concurrency_group": "__ray_concurrency_group__",
        "_generator_backpressure_num_objects": (
            "__ray_generator_backpressure_num_objects__"
        ),
    }

    def annotate_method(method):
        for kwarg, attr in attr_for_kwarg.items():
            if kwarg in kwargs:
                setattr(method, attr, kwargs[kwarg])
        # `enable_task_events` is only recorded when explicitly non-None,
        # so passing None keeps the default behavior.
        if kwargs.get("enable_task_events") is not None:
            method.__ray_enable_task_events__ = kwargs["enable_task_events"]
        return method

    return annotate_method


# Create objects to wrap method invocations. This is done so that we can
# invoke methods with actor.method.remote() instead of actor.method().
@PublicAPI
class ActorMethod:
    """A class used to invoke an actor method.

    Note: This class only keeps a weak ref to the actor, unless it has been
    passed to a remote function. This avoids delays in GC of the actor.

    Attributes:
        _actor_ref: A weakref handle to the actor.
        _method_name: The name of the actor method.
        _num_returns: The default number of return values that the method
            invocation should return. If None is given, it uses
            DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS for a normal actor task
            and "streaming" for a generator task (when `is_generator` is True).
        _max_task_retries: Number of retries on method failure.
        _retry_exceptions: Boolean of whether you want to retry all user-raised
            exceptions, or a list of allowlist exceptions to retry.
        _is_generator: True if a given method is a Python generator.
        _generator_backpressure_num_objects: Generator-only config.
            If the number of unconsumed objects reaches this threshold,
            the actor task is paused until some objects are consumed
            (backpressure).
        _enable_task_events: True if task events are enabled, i.e., task
            events from the actor should be reported. Defaults to True.
        _signature: The signature of the actor method. It is None only when
            the cross-language feature is used.
        _decorator: An optional decorator that should be applied to the actor
            method invocation (as opposed to the actor method execution) before
            invoking the method. The decorator must return a function that
            takes in two arguments ("args" and "kwargs"). In most cases, it
            should call the function that was passed into the decorator and
            return the resulting ObjectRefs. For an example, see
            "test_decorated_method" in "python/ray/tests/test_actor.py".
+ """ + + def __init__( + self, + actor, + method_name, + num_returns: Optional[Union[int, Literal["streaming"]]], + max_task_retries: int, + retry_exceptions: Union[bool, list, tuple], + is_generator: bool, + generator_backpressure_num_objects: int, + enable_task_events: bool, + decorator=None, + signature: Optional[List[inspect.Parameter]] = None, + hardref=False, + ): + self._actor_ref = weakref.ref(actor) + self._method_name = method_name + self._num_returns = num_returns + + # Default case. + if self._num_returns is None: + if is_generator: + self._num_returns = "streaming" + else: + self._num_returns = ray_constants.DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS + + self._max_task_retries = max_task_retries + self._retry_exceptions = retry_exceptions + self._is_generator = is_generator + self._generator_backpressure_num_objects = generator_backpressure_num_objects + self._enable_task_events = enable_task_events + self._signature = signature + # This is a decorator that is used to wrap the function invocation (as + # opposed to the function execution). The decorator must return a + # function that takes in two arguments ("args" and "kwargs"). In most + # cases, it should call the function that was passed into the decorator + # and return the resulting ObjectRefs. + self._decorator = decorator + + # Acquire a hard ref to the actor, this is useful mainly when passing + # actor method handles to remote functions. + if hardref: + self._actor_hard_ref = actor + else: + self._actor_hard_ref = None + + def __call__(self, *args, **kwargs): + raise TypeError( + "Actor methods cannot be called directly. Instead " + f"of running 'object.{self._method_name}()', try " + f"'object.{self._method_name}.remote()'." + ) + + @DeveloperAPI + def bind(self, *args, **kwargs): + return self._bind(args, kwargs) + + def remote(self, *args, **kwargs): + return self._remote(args, kwargs) + + def options(self, **options): + """Convenience method for executing an actor method call with options. 
+ + Same arguments as func._remote(), but returns a wrapped function + that a non-underscore .remote() can be called on. + + Examples: + # The following two calls are equivalent. + >>> actor.my_method._remote(args=[x, y], name="foo", num_returns=2) + >>> actor.my_method.options(name="foo", num_returns=2).remote(x, y) + """ + + func_cls = self + + class FuncWrapper: + def remote(self, *args, **kwargs): + return func_cls._remote(args=args, kwargs=kwargs, **options) + + @DeveloperAPI + def bind(self, *args, **kwargs): + return func_cls._bind(args=args, kwargs=kwargs, **options) + + return FuncWrapper() + + @wrap_auto_init + @_tracing_actor_method_invocation + def _bind( + self, + args=None, + kwargs=None, + name="", + num_returns=None, + concurrency_group=None, + _generator_backpressure_num_objects=None, + ) -> Union["ray.dag.ClassMethodNode", Tuple["ray.dag.ClassMethodNode", ...]]: + from ray.dag.class_node import ( + BIND_INDEX_KEY, + IS_CLASS_METHOD_OUTPUT_KEY, + PARENT_CLASS_NODE_KEY, + PREV_CLASS_METHOD_CALL_KEY, + ClassMethodNode, + ) + + # TODO(sang): unify option passing + options = { + "name": name, + "num_returns": num_returns, + "concurrency_group": concurrency_group, + "_generator_backpressure_num_objects": _generator_backpressure_num_objects, + } + + actor = self._actor_ref() + if actor is None: + # Ref is GC'ed. It happens when the actor handle is GC'ed + # when bind is called. + raise RuntimeError("Lost reference to actor") + + other_args_to_resolve = { + PARENT_CLASS_NODE_KEY: actor, + PREV_CLASS_METHOD_CALL_KEY: None, + BIND_INDEX_KEY: actor._ray_dag_bind_index, + } + actor._ray_dag_bind_index += 1 + + assert ( + self._signature is not None + ), "self._signature should be set for .bind API." + try: + signature.validate_args(self._signature, args, kwargs) + except TypeError as e: + signature_copy = self._signature.copy() + if len(signature_copy) > 0 and signature_copy[-1].name == "_ray_trace_ctx": + # Remove the trace context arg for readability. 
+ signature_copy.pop(-1) + signature_copy = inspect.Signature(parameters=signature_copy) + raise TypeError( + f"{str(e)}. The function `{self._method_name}` has a signature " + f"`{signature_copy}`, but the given arguments to `bind` doesn't " + f"match. args: {args}. kwargs: {kwargs}." + ) from None + + node = ClassMethodNode( + self._method_name, + args, + kwargs, + options, + other_args_to_resolve=other_args_to_resolve, + ) + + if node.num_returns > 1: + output_nodes: List[ClassMethodNode] = [] + for i in range(node.num_returns): + output_node = ClassMethodNode( + f"return_idx_{i}", + (node, i), + dict(), + dict(), + {IS_CLASS_METHOD_OUTPUT_KEY: True, PARENT_CLASS_NODE_KEY: actor}, + ) + output_nodes.append(output_node) + return tuple(output_nodes) + else: + return node + + @wrap_auto_init + @_tracing_actor_method_invocation + def _remote( + self, + args=None, + kwargs=None, + name="", + num_returns=None, + max_task_retries=None, + retry_exceptions=None, + concurrency_group=None, + _generator_backpressure_num_objects=None, + enable_task_events=None, + ): + if num_returns is None: + num_returns = self._num_returns + if max_task_retries is None: + max_task_retries = self._max_task_retries + if max_task_retries is None: + max_task_retries = 0 + if retry_exceptions is None: + retry_exceptions = self._retry_exceptions + if enable_task_events is None: + enable_task_events = self._enable_task_events + if _generator_backpressure_num_objects is None: + _generator_backpressure_num_objects = ( + self._generator_backpressure_num_objects + ) + + def invocation(args, kwargs): + actor = self._actor_hard_ref or self._actor_ref() + + if actor is None: + raise RuntimeError("Lost reference to actor") + + return actor._actor_method_call( + self._method_name, + args=args, + kwargs=kwargs, + name=name, + num_returns=num_returns, + max_task_retries=max_task_retries, + retry_exceptions=retry_exceptions, + concurrency_group_name=concurrency_group, + generator_backpressure_num_objects=( + 
_generator_backpressure_num_objects + ), + enable_task_events=enable_task_events, + ) + + # Apply the decorator if there is one. + if self._decorator is not None: + invocation = self._decorator(invocation) + + return invocation(args, kwargs) + + def __getstate__(self): + return { + "actor": self._actor_ref(), + "method_name": self._method_name, + "num_returns": self._num_returns, + "max_task_retries": self._max_task_retries, + "retry_exceptions": self._retry_exceptions, + "decorator": self._decorator, + "is_generator": self._is_generator, + "generator_backpressure_num_objects": self._generator_backpressure_num_objects, # noqa + "enable_task_events": self._enable_task_events, + } + + def __setstate__(self, state): + self.__init__( + state["actor"], + state["method_name"], + state["num_returns"], + state["max_task_retries"], + state["retry_exceptions"], + state["is_generator"], + state["generator_backpressure_num_objects"], + state["enable_task_events"], + state["decorator"], + hardref=True, + ) + + +class _ActorClassMethodMetadata(object): + """Metadata for all methods in an actor class. This data can be cached. + + Attributes: + methods: The actor methods. + decorators: Optional decorators that should be applied to the + method invocation function before invoking the actor methods. These + can be set by attaching the attribute + "__ray_invocation_decorator__" to the actor method. + signatures: The signatures of the methods. + num_returns: The default number of return values for + each actor method. + max_task_retries: Number of retries on method failure. + retry_exceptions: Boolean of whether you want to retry all user-raised + exceptions, or a list of allowlist exceptions to retry, for each method. + enable_task_events: True if tracing is enabled, i.e., task events from + the actor should be reported. Defaults to True. 
+ """ + + _cache = {} # This cache will be cleared in ray._private.worker.disconnect() + + def __init__(self): + class_name = type(self).__name__ + raise TypeError( + f"{class_name} can not be constructed directly, " + f"instead of running '{class_name}()', " + f"try '{class_name}.create()'" + ) + + @classmethod + def reset_cache(cls): + cls._cache.clear() + + @classmethod + def create(cls, modified_class, actor_creation_function_descriptor): + # Try to create an instance from cache. + cached_meta = cls._cache.get(actor_creation_function_descriptor) + if cached_meta is not None: + return cached_meta + + # Create an instance without __init__ called. + self = cls.__new__(cls) + + actor_methods = inspect.getmembers(modified_class, is_function_or_method) + self.methods = dict(actor_methods) + + # Extract the signatures of each of the methods. This will be used + # to catch some errors if the methods are called with inappropriate + # arguments. + self.decorators = {} + self.signatures = {} + self.num_returns = {} + self.max_task_retries = {} + self.retry_exceptions = {} + self.method_is_generator = {} + self.enable_task_events = {} + self.generator_backpressure_num_objects = {} + self.concurrency_group_for_methods = {} + + for method_name, method in actor_methods: + # Whether or not this method requires binding of its first + # argument. For class and static methods, we do not want to bind + # the first argument, but we do for instance methods + method = inspect.unwrap(method) + is_bound = is_class_method(method) or is_static_method( + modified_class, method_name + ) + + # Print a warning message if the method signature is not + # supported. We don't raise an exception because if the actor + # inherits from a class that has a method whose signature we + # don't support, there may not be much the user can do about it. 
+ self.signatures[method_name] = signature.extract_signature( + method, ignore_first=not is_bound + ) + # Set the default number of return values for this method. + if hasattr(method, "__ray_num_returns__"): + self.num_returns[method_name] = method.__ray_num_returns__ + else: + self.num_returns[method_name] = None + + # Only contains entries from `@ray.method(max_task_retries=...)` + # Ray may not populate the others with max_task_retries here because you may + # have set in `actor.method.options(max_task_retries=...)`. So Ray always + # stores max_task_retries both from the method and from the actor, and + # favors the former. + if hasattr(method, "__ray_max_task_retries__"): + self.max_task_retries[method_name] = method.__ray_max_task_retries__ + + if hasattr(method, "__ray_retry_exceptions__"): + self.retry_exceptions[method_name] = method.__ray_retry_exceptions__ + + if hasattr(method, "__ray_invocation_decorator__"): + self.decorators[method_name] = method.__ray_invocation_decorator__ + + if hasattr(method, "__ray_concurrency_group__"): + self.concurrency_group_for_methods[ + method_name + ] = method.__ray_concurrency_group__ + + if hasattr(method, "__ray_enable_task_events__"): + self.enable_task_events[method_name] = method.__ray_enable_task_events__ + + is_generator = inspect.isgeneratorfunction( + method + ) or inspect.isasyncgenfunction(method) + self.method_is_generator[method_name] = is_generator + + if hasattr(method, "__ray_generator_backpressure_num_objects__"): + self.generator_backpressure_num_objects[ + method_name + ] = method.__ray_generator_backpressure_num_objects__ + + # Update cache. + cls._cache[actor_creation_function_descriptor] = self + return self + + +class _ActorClassMetadata: + """Metadata for an actor class. + + Attributes: + language: The actor language, e.g. Python, Java. + modified_class: The original class that was decorated (with some + additional methods added like __ray_terminate__). 
+ actor_creation_function_descriptor: The function descriptor for + the actor creation task. + class_id: The ID of this actor class. + class_name: The name of this class. + num_cpus: The default number of CPUs required by the actor creation + task. + num_gpus: The default number of GPUs required by the actor creation + task. + memory: The heap memory quota for this actor. + resources: The default resources required by the actor creation task. + accelerator_type: The specified type of accelerator required for the + node on which this actor runs. + See :ref:`accelerator types `. + runtime_env: The runtime environment for this actor. + scheduling_strategy: Strategy about how to schedule this actor. + last_export_cluster_and_job: A pair of the last exported cluster + and job to help us to know whether this function was exported. + This is an imperfect mechanism used to determine if we need to + export the remote function again. It is imperfect in the sense that + the actor class definition could be exported multiple times by + different workers. + method_meta: The actor method metadata. 
+ """ + + def __init__( + self, + language, + modified_class, + actor_creation_function_descriptor, + class_id, + max_restarts, + max_task_retries, + num_cpus, + num_gpus, + memory, + object_store_memory, + resources, + accelerator_type, + runtime_env, + concurrency_groups, + scheduling_strategy: SchedulingStrategyT, + ): + self.language = language + self.modified_class = modified_class + self.actor_creation_function_descriptor = actor_creation_function_descriptor + self.class_name = actor_creation_function_descriptor.class_name + self.is_cross_language = language != Language.PYTHON + self.class_id = class_id + self.max_restarts = max_restarts + self.max_task_retries = max_task_retries + self.num_cpus = num_cpus + self.num_gpus = num_gpus + self.memory = memory + self.object_store_memory = object_store_memory + self.resources = resources + self.accelerator_type = accelerator_type + self.runtime_env = runtime_env + self.concurrency_groups = concurrency_groups + self.scheduling_strategy = scheduling_strategy + self.last_export_cluster_and_job = None + self.method_meta = _ActorClassMethodMetadata.create( + modified_class, actor_creation_function_descriptor + ) + + +@PublicAPI +class ActorClassInheritanceException(TypeError): + pass + + +def _process_option_dict(actor_options): + _filled_options = {} + arg_names = set(inspect.getfullargspec(_ActorClassMetadata.__init__)[0]) + for k, v in ray_option_utils.actor_options.items(): + if k in arg_names: + _filled_options[k] = actor_options.get(k, v.default_value) + _filled_options["runtime_env"] = parse_runtime_env(_filled_options["runtime_env"]) + return _filled_options + + +@PublicAPI +class ActorClass: + """An actor class. + + This is a decorated class. It can be used to create actors. + + Attributes: + __ray_metadata__: Contains metadata for the actor. + """ + + def __init__(cls, name, bases, attr): + """Prevents users from directly inheriting from an ActorClass. 
+ + This will be called when a class is defined with an ActorClass object + as one of its base classes. To intentionally construct an ActorClass, + use the '_ray_from_modified_class' classmethod. + + Raises: + ActorClassInheritanceException: When ActorClass is inherited. + AssertionError: If ActorClassInheritanceException is not raised i.e., + conditions for raising it are not met in any + iteration of the loop. + TypeError: In all other cases. + """ + for base in bases: + if isinstance(base, ActorClass): + raise ActorClassInheritanceException( + f"Attempted to define subclass '{name}' of actor " + f"class '{base.__ray_metadata__.class_name}'. " + "Inheriting from actor classes is " + "not currently supported. You can instead " + "inherit from a non-actor base class and make " + "the derived class an actor class (with " + "@ray.remote)." + ) + + # This shouldn't be reached because one of the base classes must be + # an actor class if this was meant to be subclassed. + assert False, ( + "ActorClass.__init__ should not be called. Please use " + "the @ray.remote decorator instead." + ) + + def __call__(self, *args, **kwargs): + """Prevents users from directly instantiating an ActorClass. + + This will be called instead of __init__ when 'ActorClass()' is executed + because an is an object rather than a metaobject. To properly + instantiated a remote actor, use 'ActorClass.remote()'. + + Raises: + Exception: Always. + """ + raise TypeError( + "Actors cannot be instantiated directly. " + f"Instead of '{self.__ray_metadata__.class_name}()', " + f"use '{self.__ray_metadata__.class_name}.remote()'." 
+ ) + + @classmethod + def _ray_from_modified_class( + cls, + modified_class, + class_id, + actor_options, + ): + for attribute in [ + "remote", + "_remote", + "_ray_from_modified_class", + "_ray_from_function_descriptor", + ]: + if hasattr(modified_class, attribute): + logger.warning( + "Creating an actor from class " + f"{modified_class.__name__} overwrites " + f"attribute {attribute} of that class" + ) + + # Make sure the actor class we are constructing inherits from the + # original class so it retains all class properties. + class DerivedActorClass(cls, modified_class): + def __init__(self, *args, **kwargs): + try: + cls.__init__(self, *args, **kwargs) + except Exception as e: + # Delegate call to modified_class.__init__ only + # if the exception raised by cls.__init__ is + # TypeError and not ActorClassInheritanceException(TypeError). + # In all other cases proceed with raise e. + if isinstance(e, TypeError) and not isinstance( + e, ActorClassInheritanceException + ): + modified_class.__init__(self, *args, **kwargs) + else: + raise e + + name = f"ActorClass({modified_class.__name__})" + DerivedActorClass.__module__ = modified_class.__module__ + DerivedActorClass.__name__ = name + DerivedActorClass.__qualname__ = name + # Construct the base object. + self = DerivedActorClass.__new__(DerivedActorClass) + # Actor creation function descriptor. 
+ actor_creation_function_descriptor = PythonFunctionDescriptor.from_class( + modified_class.__ray_actor_class__ + ) + + self.__ray_metadata__ = _ActorClassMetadata( + Language.PYTHON, + modified_class, + actor_creation_function_descriptor, + class_id, + **_process_option_dict(actor_options), + ) + self._default_options = actor_options + if "runtime_env" in self._default_options: + self._default_options["runtime_env"] = self.__ray_metadata__.runtime_env + + return self + + @classmethod + def _ray_from_function_descriptor( + cls, + language, + actor_creation_function_descriptor, + actor_options, + ): + self = ActorClass.__new__(ActorClass) + self.__ray_metadata__ = _ActorClassMetadata( + language, + None, + actor_creation_function_descriptor, + None, + **_process_option_dict(actor_options), + ) + self._default_options = actor_options + if "runtime_env" in self._default_options: + self._default_options["runtime_env"] = self.__ray_metadata__.runtime_env + return self + + def remote(self, *args, **kwargs): + """Create an actor. + + Args: + args: These arguments are forwarded directly to the actor + constructor. + kwargs: These arguments are forwarded directly to the actor + constructor. + + Returns: + A handle to the newly created actor. + """ + return self._remote(args=args, kwargs=kwargs, **self._default_options) + + def options(self, **actor_options): + """Configures and overrides the actor instantiation parameters. + + The arguments are the same as those that can be passed + to :obj:`ray.remote`. + + Args: + num_cpus: The quantity of CPU cores to reserve + for this task or for the lifetime of the actor. + num_gpus: The quantity of GPUs to reserve + for this task or for the lifetime of the actor. + resources (Dict[str, float]): The quantity of various custom resources + to reserve for this task or for the lifetime of the actor. + This is a dictionary mapping strings (resource names) to floats. 
+ accelerator_type: If specified, requires that the task or actor run + on a node with the specified type of accelerator. + See :ref:`accelerator types `. + memory: The heap memory request in bytes for this task/actor, + rounded down to the nearest integer. + object_store_memory: The object store memory request for actors only. + max_restarts: This specifies the maximum + number of times that the actor should be restarted when it dies + unexpectedly. The minimum valid value is 0 (default), + which indicates that the actor doesn't need to be restarted. + A value of -1 indicates that an actor should be restarted + indefinitely. + max_task_retries: How many times to + retry an actor task if the task fails due to a runtime error, + e.g., the actor has died. If set to -1, the system will + retry the failed task until the task succeeds, or the actor + has reached its max_restarts limit. If set to `n > 0`, the + system will retry the failed task up to n times, after which the + task will throw a `RayActorError` exception upon :obj:`ray.get`. + Note that Python exceptions may trigger retries *only if* + `retry_exceptions` is set for the method, in that case when + `max_task_retries` runs out the task will rethrow the exception from + the task. You can override this number with the method's + `max_task_retries` option in `@ray.method` decorator or in `.option()`. + max_pending_calls: Set the max number of pending calls + allowed on the actor handle. When this value is exceeded, + PendingCallsLimitExceeded will be raised for further tasks. + Note that this limit is counted per handle. -1 means that the + number of pending calls is unlimited. + max_concurrency: The max number of concurrent calls to allow for + this actor. This only works with direct actor calls. The max + concurrency defaults to 1 for threaded execution, and 1000 for + asyncio execution. Note that the execution order is not + guaranteed when max_concurrency > 1. 
+ name: The globally unique name for the actor, which can be used + to retrieve the actor via ray.get_actor(name) as long as the + actor is still alive. + namespace: Override the namespace to use for the actor. By default, + actors are created in an anonymous namespace. The actor can + be retrieved via ray.get_actor(name=name, namespace=namespace). + lifetime: Either `None`, which defaults to the actor will fate + share with its creator and will be deleted once its refcount + drops to zero, or "detached", which means the actor will live + as a global object independent of the creator. + runtime_env (Dict[str, Any]): Specifies the runtime environment for + this actor or task and its children. See + :ref:`runtime-environments` for detailed documentation. + scheduling_strategy: Strategy about how to + schedule a remote function or actor. Possible values are + None: ray will figure out the scheduling strategy to use, it + will either be the PlacementGroupSchedulingStrategy using parent's + placement group if parent has one and has + placement_group_capture_child_tasks set to true, + or "DEFAULT"; + "DEFAULT": default hybrid scheduling; + "SPREAD": best effort spread scheduling; + `PlacementGroupSchedulingStrategy`: + placement group based scheduling; + `NodeAffinitySchedulingStrategy`: + node id based affinity scheduling. + _metadata: Extended options for Ray libraries. For example, + _metadata={"workflows.io/options": } for + Ray workflows. + enable_task_events: True if tracing is enabled, i.e., task events from + the actor should be reported. Defaults to True. + + Examples: + + .. code-block:: python + + @ray.remote(num_cpus=2, resources={"CustomResource": 1}) + class Foo: + def method(self): + return 1 + # Class Bar will require 1 cpu instead of 2. + # It will also require no custom resources. 
+ Bar = Foo.options(num_cpus=1, resources=None) + """ + + actor_cls = self + + # override original options + default_options = self._default_options.copy() + # "concurrency_groups" could not be used in ".options()", + # we should remove it before merging options from '@ray.remote'. + default_options.pop("concurrency_groups", None) + updated_options = ray_option_utils.update_options( + default_options, actor_options + ) + ray_option_utils.validate_actor_options(updated_options, in_options=True) + + # only update runtime_env when ".options()" specifies new runtime_env + if "runtime_env" in actor_options: + updated_options["runtime_env"] = parse_runtime_env( + updated_options["runtime_env"] + ) + + class ActorOptionWrapper: + def remote(self, *args, **kwargs): + return actor_cls._remote(args=args, kwargs=kwargs, **updated_options) + + @DeveloperAPI + def bind(self, *args, **kwargs): + """ + For Ray DAG building that creates static graph from decorated + class or functions. + """ + from ray.dag.class_node import ClassNode + + return ClassNode( + actor_cls.__ray_metadata__.modified_class, + args, + kwargs, + updated_options, + ) + + return ActorOptionWrapper() + + @wrap_auto_init + @_tracing_actor_creation + def _remote(self, args=None, kwargs=None, **actor_options): + """Create an actor. + + This method allows more flexibility than the remote method because + resource requirements can be specified and override the defaults in the + decorator. + + Args: + args: The arguments to forward to the actor constructor. + kwargs: The keyword arguments to forward to the actor constructor. + num_cpus: The number of CPUs required by the actor creation task. + num_gpus: The number of GPUs required by the actor creation task. + memory: Restrict the heap memory usage of this actor. + resources: The custom resources required by the actor creation + task. + max_concurrency: The max number of concurrent calls to allow for + this actor. This only works with direct actor calls. 
The max + concurrency defaults to 1 for threaded execution, and 1000 for + asyncio execution. Note that the execution order is not + guaranteed when max_concurrency > 1. + name: The globally unique name for the actor, which can be used + to retrieve the actor via ray.get_actor(name) as long as the + actor is still alive. + namespace: Override the namespace to use for the actor. By default, + actors are created in an anonymous namespace. The actor can + be retrieved via ray.get_actor(name=name, namespace=namespace). + lifetime: Either `None`, which defaults to the actor will fate + share with its creator and will be deleted once its refcount + drops to zero, or "detached", which means the actor will live + as a global object independent of the creator. + placement_group: (This has been deprecated, please use + `PlacementGroupSchedulingStrategy` scheduling_strategy) + the placement group this actor belongs to, + or None if it doesn't belong to any group. Setting to "default" + autodetects the placement group based on the current setting of + placement_group_capture_child_tasks. + placement_group_bundle_index: (This has been deprecated, please use + `PlacementGroupSchedulingStrategy` scheduling_strategy) + the index of the bundle + if the actor belongs to a placement group, which may be -1 to + specify any available bundle. + placement_group_capture_child_tasks: (This has been deprecated, + please use `PlacementGroupSchedulingStrategy` + scheduling_strategy) + Whether or not children tasks + of this actor should implicitly use the same placement group + as its parent. It is False by default. + runtime_env (Dict[str, Any]): Specifies the runtime environment for + this actor or task and its children (see + :ref:`runtime-environments` for details). + max_pending_calls: Set the max number of pending calls + allowed on the actor handle. When this value is exceeded, + PendingCallsLimitExceeded will be raised for further tasks. + Note that this limit is counted per handle. 
-1 means that the + number of pending calls is unlimited. + scheduling_strategy: Strategy about how to schedule this actor. + enable_task_events: True if tracing is enabled, i.e., task events from + the actor should be reported. Defaults to True. + _labels: The key-value labels of the actor. + + Returns: + A handle to the newly created actor. + """ + name = actor_options.get("name") + namespace = actor_options.get("namespace") + if name is not None: + if not isinstance(name, str): + raise TypeError(f"name must be None or a string, got: '{type(name)}'.") + elif name == "": + raise ValueError("Actor name cannot be an empty string.") + if namespace is not None: + ray._private.utils.validate_namespace(namespace) + + # Handle the get-or-create case. + if actor_options.get("get_if_exists"): + try: + return ray.get_actor(name, namespace=namespace) + except ValueError: + # Attempt to create it (may race with other attempts). + updated_options = actor_options.copy() + updated_options["get_if_exists"] = False # prevent infinite loop + try: + return self._remote(args, kwargs, **updated_options) + except ValueError: + # We lost the creation race, ignore. + pass + return ray.get_actor(name, namespace=namespace) + + # We pop the "concurrency_groups" coming from "@ray.remote" here. We no longer + # need it in "_remote()". 
+ actor_options.pop("concurrency_groups", None) + + if args is None: + args = [] + if kwargs is None: + kwargs = {} + meta = self.__ray_metadata__ + is_asyncio = has_async_methods(meta.modified_class) + + if actor_options.get("max_concurrency") is None: + actor_options["max_concurrency"] = ( + ray_constants.DEFAULT_MAX_CONCURRENCY_ASYNC + if is_asyncio + else ray_constants.DEFAULT_MAX_CONCURRENCY_THREADED + ) + + if client_mode_should_convert(): + return client_mode_convert_actor(self, args, kwargs, **actor_options) + + # fill actor required options + for k, v in ray_option_utils.actor_options.items(): + actor_options[k] = actor_options.get(k, v.default_value) + # "concurrency_groups" already takes effects and should not apply again. + # Remove the default value here. + actor_options.pop("concurrency_groups", None) + + # TODO(suquark): cleanup these fields + max_concurrency = actor_options["max_concurrency"] + lifetime = actor_options["lifetime"] + runtime_env = actor_options["runtime_env"] + placement_group = actor_options["placement_group"] + placement_group_bundle_index = actor_options["placement_group_bundle_index"] + placement_group_capture_child_tasks = actor_options[ + "placement_group_capture_child_tasks" + ] + scheduling_strategy = actor_options["scheduling_strategy"] + max_restarts = actor_options["max_restarts"] + max_task_retries = actor_options["max_task_retries"] + max_pending_calls = actor_options["max_pending_calls"] + + # Override enable_task_events to default for actor if not specified (i.e. None) + enable_task_events = actor_options.get("enable_task_events") + + if scheduling_strategy is None or not isinstance( + scheduling_strategy, PlacementGroupSchedulingStrategy + ): + _warn_if_using_deprecated_placement_group(actor_options, 3) + + worker = ray._private.worker.global_worker + worker.check_connected() + + # Check whether the name is already taken. 
+ # TODO(edoakes): this check has a race condition because two drivers + # could pass the check and then create the same named actor. We should + # instead check this when we create the actor, but that's currently an + # async call. + if name is not None: + try: + ray.get_actor(name, namespace=namespace) + except ValueError: # Name is not taken. + pass + else: + raise ValueError( + f"The name {name} (namespace={namespace}) is already " + "taken. Please use " + "a different name or get the existing actor using " + f"ray.get_actor('{name}', namespace='{namespace}')" + ) + + if lifetime is None: + detached = None + elif lifetime == "detached": + detached = True + elif lifetime == "non_detached": + detached = False + else: + raise ValueError( + "actor `lifetime` argument must be one of 'detached', " + "'non_detached' and 'None'." + ) + + # LOCAL_MODE cannot handle cross_language + if worker.mode == ray.LOCAL_MODE: + assert ( + not meta.is_cross_language + ), "Cross language ActorClass cannot be executed locally." + + # Export the actor. + if not meta.is_cross_language and ( + meta.last_export_cluster_and_job != worker.current_cluster_and_job + ): + # If this actor class was not exported in this cluster and job, + # we need to export this function again, because current GCS + # doesn't have it. + + # After serialize / deserialize modified class, the __module__ + # of modified class will be ray.cloudpickle.cloudpickle. + # So, here pass actor_creation_function_descriptor to make + # sure export actor class correct. + worker.function_actor_manager.export_actor_class( + meta.modified_class, + meta.actor_creation_function_descriptor, + meta.method_meta.methods.keys(), + ) + meta.last_export_cluster_and_job = worker.current_cluster_and_job + + resources = ray._private.utils.resources_from_ray_options(actor_options) + # Set the actor's default resources if not already set. First three + # conditions are to check that no resources were specified in the + # decorator. 
Last three conditions are to check that no resources were + # specified when _remote() was called. + # TODO(suquark): In the original code, memory is not considered as resources, + # when deciding the default CPUs. It is strange, but we keep the original + # semantics in case that it breaks user applications & tests. + if not set(resources.keys()).difference({"memory", "object_store_memory"}): + # In the default case, actors acquire no resources for + # their lifetime, and actor methods will require 1 CPU. + resources.setdefault("CPU", ray_constants.DEFAULT_ACTOR_CREATION_CPU_SIMPLE) + actor_method_cpu = ray_constants.DEFAULT_ACTOR_METHOD_CPU_SIMPLE + else: + # If any resources are specified (here or in decorator), then + # all resources are acquired for the actor's lifetime and no + # resources are associated with methods. + resources.setdefault( + "CPU", ray_constants.DEFAULT_ACTOR_CREATION_CPU_SPECIFIED + ) + actor_method_cpu = ray_constants.DEFAULT_ACTOR_METHOD_CPU_SPECIFIED + + # If the actor methods require CPU resources, then set the required + # placement resources. If actor_placement_resources is empty, then + # the required placement resources will be the same as resources. + actor_placement_resources = {} + assert actor_method_cpu in [0, 1] + if actor_method_cpu == 1: + actor_placement_resources = resources.copy() + actor_placement_resources["CPU"] += 1 + if meta.is_cross_language: + creation_args = cross_language._format_args(worker, args, kwargs) + else: + function_signature = meta.method_meta.signatures["__init__"] + creation_args = signature.flatten_args(function_signature, args, kwargs) + + if scheduling_strategy is None or isinstance( + scheduling_strategy, PlacementGroupSchedulingStrategy + ): + # TODO(jjyao) Clean this up once the + # placement_group option is removed. + # We should also consider pushing this logic down to c++ + # so that it can be reused by all languages. 
+ if isinstance(scheduling_strategy, PlacementGroupSchedulingStrategy): + placement_group = scheduling_strategy.placement_group + placement_group_bundle_index = ( + scheduling_strategy.placement_group_bundle_index + ) + placement_group_capture_child_tasks = ( + scheduling_strategy.placement_group_capture_child_tasks + ) + + if placement_group_capture_child_tasks is None: + placement_group_capture_child_tasks = ( + worker.should_capture_child_tasks_in_placement_group + ) + placement_group = _configure_placement_group_based_on_context( + placement_group_capture_child_tasks, + placement_group_bundle_index, + resources, + actor_placement_resources, + meta.class_name, + placement_group=placement_group, + ) + if not placement_group.is_empty: + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group, + placement_group_bundle_index, + placement_group_capture_child_tasks, + ) + else: + scheduling_strategy = "DEFAULT" + + serialized_runtime_env_info = None + if runtime_env is not None: + serialized_runtime_env_info = get_runtime_env_info( + runtime_env, + is_job_runtime_env=False, + serialize=True, + ) + + concurrency_groups_dict = {} + if meta.concurrency_groups is None: + meta.concurrency_groups = [] + for cg_name in meta.concurrency_groups: + concurrency_groups_dict[cg_name] = { + "name": cg_name, + "max_concurrency": meta.concurrency_groups[cg_name], + "function_descriptors": [], + } + + # Update methods + for method_name in meta.method_meta.concurrency_group_for_methods: + cg_name = meta.method_meta.concurrency_group_for_methods[method_name] + assert cg_name in concurrency_groups_dict + + module_name = meta.actor_creation_function_descriptor.module_name + class_name = meta.actor_creation_function_descriptor.class_name + concurrency_groups_dict[cg_name]["function_descriptors"].append( + PythonFunctionDescriptor(module_name, method_name, class_name) + ) + + # Update the creation descriptor based on number of arguments + if meta.is_cross_language: + 
func_name = "" + if meta.language == Language.CPP: + func_name = meta.actor_creation_function_descriptor.function_name + meta.actor_creation_function_descriptor = ( + cross_language._get_function_descriptor_for_actor_method( + meta.language, + meta.actor_creation_function_descriptor, + func_name, + str(len(args) + len(kwargs)), + ) + ) + + actor_id = worker.core_worker.create_actor( + meta.language, + meta.actor_creation_function_descriptor, + creation_args, + max_restarts, + max_task_retries, + resources, + actor_placement_resources, + max_concurrency, + detached, + name if name is not None else "", + namespace if namespace is not None else "", + is_asyncio, + # Store actor_method_cpu in actor handle's extension data. + extension_data=str(actor_method_cpu), + serialized_runtime_env_info=serialized_runtime_env_info or "{}", + concurrency_groups_dict=concurrency_groups_dict or dict(), + max_pending_calls=max_pending_calls, + scheduling_strategy=scheduling_strategy, + enable_task_events=enable_task_events, + labels=actor_options.get("_labels"), + ) + + if _actor_launch_hook: + _actor_launch_hook( + meta.actor_creation_function_descriptor, resources, scheduling_strategy + ) + + actor_handle = ActorHandle( + meta.language, + actor_id, + max_task_retries, + enable_task_events, + meta.method_meta.method_is_generator, + meta.method_meta.decorators, + meta.method_meta.signatures, + meta.method_meta.num_returns, + meta.method_meta.max_task_retries, + meta.method_meta.retry_exceptions, + meta.method_meta.generator_backpressure_num_objects, + meta.method_meta.enable_task_events, + actor_method_cpu, + meta.actor_creation_function_descriptor, + worker.current_cluster_and_job, + original_handle=True, + ) + + return actor_handle + + @DeveloperAPI + def bind(self, *args, **kwargs): + """ + For Ray DAG building that creates static graph from decorated + class or functions. 
+ """ + from ray.dag.class_node import ClassNode + + return ClassNode( + self.__ray_metadata__.modified_class, args, kwargs, self._default_options + ) + + +@PublicAPI +class ActorHandle: + """A handle to an actor. + + The fields in this class are prefixed with _ray_ to hide them from the user + and to avoid collision with actor method names. + + An ActorHandle can be created in three ways. First, by calling .remote() on + an ActorClass. Second, by passing an actor handle into a task (forking the + ActorHandle). Third, by directly serializing the ActorHandle (e.g., with + cloudpickle). + + Attributes: + _ray_actor_language: The actor language. + _ray_actor_id: Actor ID. + _ray_enable_task_events: The default value of whether task events is + enabled, i.e., task events from the actor should be reported. + _ray_method_is_generator: Map of method name -> if it is a generator + method. + _ray_method_decorators: Optional decorators for the function + invocation. This can be used to change the behavior on the + invocation side, whereas a regular decorator can be used to change + the behavior on the execution side. + _ray_method_signatures: The signatures of the actor methods. + _ray_method_max_task_retries: Max number of retries on method failure. + _ray_method_num_returns: The default number of return values for + each method. + _ray_method_retry_exceptions: The default value of boolean of whether you want + to retry all user-raised exceptions, or a list of allowlist exceptions to + retry. + _ray_method_generator_backpressure_num_objects: Generator-only + config. The max number of objects to generate before it + starts pausing a generator. + _ray_method_enable_task_events: The value of whether task + tracing is enabled for the actor methods. This overrides the + actor's default value (`_ray_enable_task_events`). + _ray_actor_method_cpus: The number of CPUs required by actor methods. + _ray_original_handle: True if this is the original actor handle for a + given actor. 
If this is true, then the actor will be destroyed when + this handle goes out of scope. + _ray_weak_ref: True means that this handle does not count towards the + distributed ref count for the actor, i.e. the actor may be GCed + while this handle is still in scope. This is set to True if the + handle was created by getting an actor by name or by getting the + self handle. It is set to False if this is the original handle or + if it was created by passing the original handle through task args + and returns. + _ray_is_cross_language: Whether this actor is cross language. + _ray_actor_creation_function_descriptor: The function descriptor + of the actor creation task. + """ + + def __init__( + self, + language, + actor_id, + max_task_retries: Optional[int], + enable_task_events: bool, + method_is_generator: Dict[str, bool], + method_decorators, + method_signatures, + method_num_returns: Dict[str, Union[int, Literal["streaming"]]], + method_max_task_retries: Dict[str, int], + method_retry_exceptions: Dict[str, Union[bool, list, tuple]], + method_generator_backpressure_num_objects: Dict[str, int], + method_enable_task_events: Dict[str, bool], + actor_method_cpus: int, + actor_creation_function_descriptor, + cluster_and_job, + original_handle=False, + weak_ref: bool = False, + ): + self._ray_actor_language = language + self._ray_actor_id = actor_id + self._ray_max_task_retries = max_task_retries + self._ray_original_handle = original_handle + self._ray_weak_ref = weak_ref + self._ray_enable_task_events = enable_task_events + + self._ray_method_is_generator = method_is_generator + self._ray_method_decorators = method_decorators + self._ray_method_signatures = method_signatures + self._ray_method_num_returns = method_num_returns + self._ray_method_max_task_retries = method_max_task_retries + self._ray_method_retry_exceptions = method_retry_exceptions + self._ray_method_generator_backpressure_num_objects = ( + method_generator_backpressure_num_objects + ) + 
self._ray_method_enable_task_events = method_enable_task_events + self._ray_actor_method_cpus = actor_method_cpus + self._ray_cluster_and_job = cluster_and_job + self._ray_is_cross_language = language != Language.PYTHON + self._ray_actor_creation_function_descriptor = ( + actor_creation_function_descriptor + ) + self._ray_function_descriptor = {} + # This is incremented each time `bind()` is called on an actor handle + # (in Ray DAGs), therefore capturing the bind order of the actor methods. + # TODO: this does not work properly if the caller has two copies of the + # same actor handle, and needs to be fixed. + self._ray_dag_bind_index = 0 + + if not self._ray_is_cross_language: + assert isinstance( + actor_creation_function_descriptor, PythonFunctionDescriptor + ) + module_name = actor_creation_function_descriptor.module_name + class_name = actor_creation_function_descriptor.class_name + for method_name in self._ray_method_signatures.keys(): + function_descriptor = PythonFunctionDescriptor( + module_name, method_name, class_name + ) + self._ray_function_descriptor[method_name] = function_descriptor + method = ActorMethod( + self, + method_name, + self._ray_method_num_returns[method_name], + self._ray_method_max_task_retries.get( + method_name, self._ray_max_task_retries + ) + or 0, # never None + self._ray_method_retry_exceptions.get(method_name), + self._ray_method_is_generator[method_name], + self._ray_method_generator_backpressure_num_objects.get( + method_name + ), # noqa + self._ray_method_enable_task_events.get( + method_name, + self._ray_enable_task_events, # Use actor's default value + ), + decorator=self._ray_method_decorators.get(method_name), + signature=self._ray_method_signatures[method_name], + ) + setattr(self, method_name, method) + + def __del__(self): + # Weak references don't count towards the distributed ref count, so no + # need to decrement the ref count. 
+ if self._ray_weak_ref: + return + + try: + # Mark that this actor handle has gone out of scope. Once all actor + # handles are out of scope, the actor will exit. + if ray._private.worker: + worker = ray._private.worker.global_worker + if worker.connected and hasattr(worker, "core_worker"): + worker.core_worker.remove_actor_handle_reference(self._ray_actor_id) + except AttributeError: + # Suppress the attribtue error which is caused by + # python destruction ordering issue. + # It only happen when python exits. + pass + + def _actor_method_call( + self, + method_name: str, + args: List[Any] = None, + kwargs: Dict[str, Any] = None, + name: str = "", + num_returns: Optional[Union[int, Literal["streaming"]]] = None, + max_task_retries: int = None, + retry_exceptions: Union[bool, list, tuple] = None, + concurrency_group_name: Optional[str] = None, + generator_backpressure_num_objects: Optional[int] = None, + enable_task_events: Optional[bool] = None, + ): + """Method execution stub for an actor handle. + + This is the function that executes when + `actor.method_name.remote(*args, **kwargs)` is called. Instead of + executing locally, the method is packaged as a task and scheduled + to the remote actor instance. + + Args: + method_name: The name of the actor method to execute. + args: A list of arguments for the actor method. + kwargs: A dictionary of keyword arguments for the actor method. + name: The name to give the actor method call task. + num_returns: The number of return values for the method. + max_task_retries: Number of retries when method fails. + retry_exceptions: Boolean of whether you want to retry all user-raised + exceptions, or a list of allowlist exceptions to retry. + enable_task_events: True if tracing is enabled, i.e., task events from + the actor should be reported. + + Returns: + object_refs: A list of object refs returned by the remote actor + method. 
+ """ + worker = ray._private.worker.global_worker + + args = args or [] + kwargs = kwargs or {} + if self._ray_is_cross_language: + list_args = cross_language._format_args(worker, args, kwargs) + function_descriptor = cross_language._get_function_descriptor_for_actor_method( # noqa: E501 + self._ray_actor_language, + self._ray_actor_creation_function_descriptor, + method_name, + # The signature for xlang should be "{length_of_arguments}" to handle + # overloaded methods. + signature=str(len(args) + len(kwargs)), + ) + else: + function_signature = self._ray_method_signatures[method_name] + + if not args and not kwargs and not function_signature: + list_args = [] + else: + list_args = signature.flatten_args(function_signature, args, kwargs) + function_descriptor = self._ray_function_descriptor[method_name] + + if worker.mode == ray.LOCAL_MODE: + assert ( + not self._ray_is_cross_language + ), "Cross language remote actor method cannot be executed locally." + + if num_returns == "dynamic": + num_returns = -1 + elif num_returns == "streaming": + # TODO(sang): This is a temporary private API. + # Remove it when we migrate to the streaming generator. + num_returns = ray._raylet.STREAMING_GENERATOR_RETURN + + retry_exception_allowlist = None + if retry_exceptions is None: + retry_exceptions = False + elif isinstance(retry_exceptions, (list, tuple)): + retry_exception_allowlist = tuple(retry_exceptions) + retry_exceptions = True + assert isinstance( + retry_exceptions, bool + ), "retry_exceptions can either be \ + boolean or list/tuple of exception types." 
+ + if generator_backpressure_num_objects is None: + generator_backpressure_num_objects = -1 + + object_refs = worker.core_worker.submit_actor_task( + self._ray_actor_language, + self._ray_actor_id, + function_descriptor, + list_args, + name, + num_returns, + max_task_retries, + retry_exceptions, + retry_exception_allowlist, + self._ray_actor_method_cpus, + concurrency_group_name if concurrency_group_name is not None else b"", + generator_backpressure_num_objects, + enable_task_events, + ) + + if num_returns == STREAMING_GENERATOR_RETURN: + # Streaming generator will return a single ref + # that is for the generator task. + assert len(object_refs) == 1 + generator_ref = object_refs[0] + return ObjectRefGenerator(generator_ref, worker) + if len(object_refs) == 1: + object_refs = object_refs[0] + elif len(object_refs) == 0: + object_refs = None + + return object_refs + + def __getattr__(self, item): + if not self._ray_is_cross_language: + raise AttributeError( + f"'{type(self).__name__}' object has " f"no attribute '{item}'" + ) + if item in ["__ray_terminate__"]: + + class FakeActorMethod(object): + def __call__(self, *args, **kwargs): + raise TypeError( + "Actor methods cannot be called directly. Instead " + "of running 'object.{}()', try 'object.{}.remote()'.".format( + item, item + ) + ) + + def remote(self, *args, **kwargs): + logger.warning( + f"Actor method {item} is not supported by cross language." + ) + + return FakeActorMethod() + + return ActorMethod( + self, # actor + item, # method_name + ray_constants.DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS, + 0, # max_task_retries + False, # retry_exceptions + False, # is_generator + self._ray_method_generator_backpressure_num_objects.get(item, -1), + self._ray_enable_task_events, # enable_task_events + # Currently, cross-lang actor method not support decorator + decorator=None, + signature=None, + ) + + # Make tab completion work. 
+ def __dir__(self): + return self._ray_method_signatures.keys() + + def __repr__(self): + return ( + "Actor(" + f"{self._ray_actor_creation_function_descriptor.class_name}, " + f"{self._actor_id.hex()})" + ) + + def __hash__(self): + return hash(self._actor_id) + + def __eq__(self, __value): + return hash(self) == hash(__value) + + @property + def _actor_id(self): + return self._ray_actor_id + + def _get_local_state(self): + """Get the local actor state. + + NOTE: this method only returns accurate actor state + after a first actor method call is made against + this actor handle due to https://github.com/ray-project/ray/pull/24600. + + Returns: + ActorTableData.ActorState or None if the state is unknown. + """ + worker = ray._private.worker.global_worker + worker.check_connected() + return worker.core_worker.get_local_actor_state(self._ray_actor_id) + + def _serialization_helper(self): + """This is defined in order to make pickling work. + + Returns: + A dictionary of the information needed to reconstruct the object. 
+ """ + worker = ray._private.worker.global_worker + worker.check_connected() + + if hasattr(worker, "core_worker"): + # Non-local mode + state = worker.core_worker.serialize_actor_handle(self._ray_actor_id) + else: + # Local mode + state = ( + { + "actor_language": self._ray_actor_language, + "actor_id": self._ray_actor_id, + "max_task_retries": self._ray_max_task_retries, + "enable_task_events": self._enable_task_events, + "method_is_generator": self._ray_method_is_generator, + "method_decorators": self._ray_method_decorators, + "method_signatures": self._ray_method_signatures, + "method_num_returns": self._ray_method_num_returns, + "method_max_task_retries": self._ray_method_max_task_retries, + "method_retry_exceptions": self._ray_method_retry_exceptions, + "method_generator_backpressure_num_objects": ( + self._ray_method_generator_backpressure_num_objects + ), + "method_enable_task_events": self._ray_method_enable_task_events, + "actor_method_cpus": self._ray_actor_method_cpus, + "actor_creation_function_descriptor": self._ray_actor_creation_function_descriptor, # noqa: E501 + }, + None, + ) + + return (*state, self._ray_weak_ref) + + @classmethod + def _deserialization_helper(cls, state, weak_ref: bool, outer_object_ref=None): + """This is defined in order to make pickling work. + + Args: + state: The serialized state of the actor handle. + outer_object_ref: The ObjectRef that the serialized actor handle + was contained in, if any. This is used for counting references + to the actor handle. + weak_ref: Whether this was serialized from an actor handle with a + weak ref to the actor. 
+ + """ + worker = ray._private.worker.global_worker + worker.check_connected() + + if hasattr(worker, "core_worker"): + # Non-local mode + return worker.core_worker.deserialize_and_register_actor_handle( + state, + outer_object_ref, + weak_ref, + ) + else: + # Local mode + assert worker.current_cluster_and_job == state["current_cluster_and_job"] + return cls( + # TODO(swang): Accessing the worker's current task ID is not + # thread-safe. + state["actor_language"], + state["actor_id"], + state["max_task_retries"], + state["enable_task_events"], + state["method_is_generator"], + state["method_decorators"], + state["method_signatures"], + state["method_num_returns"], + state["method_max_task_retries"], + state["method_retry_exceptions"], + state["method_generator_backpressure_num_objects"], + state["method_enable_task_events"], + state["actor_method_cpus"], + state["actor_creation_function_descriptor"], + state["current_cluster_and_job"], + ) + + def __reduce__(self): + """This code path is used by pickling but not by Ray forking.""" + (serialized, _, weak_ref) = self._serialization_helper() + # There is no outer object ref when the actor handle is + # deserialized out-of-band using pickle. + return ActorHandle._deserialization_helper, (serialized, weak_ref, None) + + +def _modify_class(cls): + # cls has been modified. + if hasattr(cls, "__ray_actor_class__"): + return cls + + # Give an error if cls is an old-style class. + if not issubclass(cls, object): + raise TypeError( + "The @ray.remote decorator cannot be applied to old-style " + "classes. In Python 2, you must declare the class with " + "'class ClassName(object):' instead of 'class ClassName:'." + ) + + # Modify the class to have additional default methods. 
+ class Class(cls): + __ray_actor_class__ = cls # The original actor class + + def __ray_ready__(self): + return True + + def __ray_call__(self, fn, *args, **kwargs): + return fn(self, *args, **kwargs) + + def __ray_terminate__(self): + worker = ray._private.worker.global_worker + if worker.mode != ray.LOCAL_MODE: + ray.actor.exit_actor() + + Class.__module__ = cls.__module__ + Class.__name__ = cls.__name__ + + if not is_function_or_method(getattr(Class, "__init__", None)): + # Add __init__ if it does not exist. + # Actor creation will be executed with __init__ together. + + # Assign an __init__ function will avoid many checks later on. + def __init__(self): + pass + + Class.__init__ = __init__ + + return Class + + +def _make_actor(cls, actor_options): + Class = _modify_class(cls) + _inject_tracing_into_class(Class) + + if "max_restarts" in actor_options: + if actor_options["max_restarts"] != -1: # -1 represents infinite restart + # Make sure we don't pass too big of an int to C++, causing + # an overflow. + actor_options["max_restarts"] = min( + actor_options["max_restarts"], ray_constants.MAX_INT64_VALUE + ) + + return ActorClass._ray_from_modified_class( + Class, + ActorClassID.from_random(), + actor_options, + ) + + +@PublicAPI +def exit_actor(): + """Intentionally exit the current actor. + + This API can be used only inside an actor. Use ray.kill + API if you'd like to kill an actor using actor handle. + + When the API is called, the actor raises an exception and exits. + Any queued methods will fail. Any ``atexit`` + handlers installed in the actor will be run. + + Raises: + TypeError: An exception is raised if this is a driver or this + worker is not an actor. + """ + worker = ray._private.worker.global_worker + if worker.mode == ray.WORKER_MODE and not worker.actor_id.is_nil(): + # In asyncio actor mode, we can't raise SystemExit because it will just + # quit the asycnio event loop thread, not the main thread. 
Instead, we + # raise a custom error to the main thread to tell it to exit. + if worker.core_worker.current_actor_is_asyncio(): + raise AsyncioActorExit() + + # Set a flag to indicate this is an intentional actor exit. This + # reduces log verbosity. + raise_sys_exit_with_custom_error_message("exit_actor() is called.") + else: + raise TypeError( + "exit_actor API is called on a non-actor worker, " + f"{worker.mode}. Call this API inside an actor methods" + "if you'd like to exit the actor gracefully." + ) diff --git a/.venv/lib/python3.11/site-packages/ray/client_builder.py b/.venv/lib/python3.11/site-packages/ray/client_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..59b555fba2c3df3ed8d66d30f68b3a22a1f8a2ea --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/client_builder.py @@ -0,0 +1,379 @@ +import importlib +import inspect +import json +import logging +import os +import sys +import warnings +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple + +import ray.util.client_connect +from ray._private.ray_constants import ( + RAY_ADDRESS_ENVIRONMENT_VARIABLE, + RAY_NAMESPACE_ENVIRONMENT_VARIABLE, + RAY_RUNTIME_ENV_ENVIRONMENT_VARIABLE, +) +from ray._private.utils import check_ray_client_dependencies_installed, split_address +from ray._private.worker import BaseContext +from ray._private.worker import init as ray_driver_init +from ray.job_config import JobConfig +from ray.util.annotations import Deprecated, PublicAPI + +logger = logging.getLogger(__name__) + +CLIENT_DOCS_URL = ( + "https://docs.ray.io/en/latest/cluster/running-applications/" + "job-submission/ray-client.html" +) + + +@dataclass +@PublicAPI +class ClientContext(BaseContext): + """ + Basic context manager for a ClientBuilder connection. + + `protocol_version` is no longer used. 
+ """ + + dashboard_url: Optional[str] + python_version: str + ray_version: str + ray_commit: str + _num_clients: int + _context_to_restore: Optional[ray.util.client.RayAPIStub] + protocol_version: Optional[str] = None # Deprecated + + def __enter__(self) -> "ClientContext": + self._swap_context() + return self + + def __exit__(self, *exc) -> None: + self._disconnect_with_context(False) + self._swap_context() + + def disconnect(self) -> None: + self._swap_context() + self._disconnect_with_context(True) + self._swap_context() + + def _swap_context(self): + if self._context_to_restore is not None: + self._context_to_restore = ray.util.client.ray.set_context( + self._context_to_restore + ) + + def _disconnect_with_context(self, force_disconnect: bool) -> None: + """ + Disconnect Ray. If it's a ray client and created with `allow_multiple`, + it will do nothing. For other cases this either disconnects from the + remote Client Server or shuts the current driver down. + """ + if ray.util.client.ray.is_connected(): + if ray.util.client.ray.is_default() or force_disconnect: + # This is the only client connection + ray.util.client_connect.disconnect() + elif ray._private.worker.global_worker.node is None: + # Already disconnected. + return + elif ray._private.worker.global_worker.node.is_head(): + logger.debug( + "The current Ray Cluster is scoped to this process. " + "Disconnecting is not possible as it will shutdown the " + "cluster." + ) + else: + # This is only a driver connected to an existing cluster. + ray.shutdown() + + +@Deprecated +class ClientBuilder: + """ + Builder for a Ray Client connection. This class can be subclassed by + custom builder classes to modify connection behavior to include additional + features or altered semantics. One example is the ``_LocalClientBuilder``. + """ + + def __init__(self, address: Optional[str]) -> None: + if not check_ray_client_dependencies_installed(): + raise ValueError( + "Ray Client requires pip package `ray[client]`. 
" + "If you installed the minimal Ray (e.g. `pip install ray`), " + "please reinstall by executing `pip install ray[client]`." + ) + self.address = address + self._job_config = JobConfig() + self._remote_init_kwargs = {} + # Whether to allow connections to multiple clusters" + # " (allow_multiple=True). + self._allow_multiple_connections = False + self._credentials = None + self._metadata = None + # Set to False if ClientBuilder is being constructed by internal + # methods + self._deprecation_warn_enabled = True + + def env(self, env: Dict[str, Any]) -> "ClientBuilder": + """ + Set an environment for the session. + Args: + env (Dict[st, Any]): A runtime environment to use for this + connection. See :ref:`runtime-environments` for what values are + accepted in this dict. + """ + self._job_config.set_runtime_env(env) + return self + + def namespace(self, namespace: str) -> "ClientBuilder": + """ + Sets the namespace for the session. + Args: + namespace: Namespace to use. + """ + self._job_config.set_ray_namespace(namespace) + return self + + def connect(self) -> ClientContext: + """ + Begin a connection to the address passed in via ray.client(...). + + Returns: + ClientInfo: Dataclass with information about the setting. This + includes the server's version of Python & Ray as well as the + dashboard_url. + """ + if self._deprecation_warn_enabled: + self._client_deprecation_warn() + # Fill runtime env/namespace from environment if not already set. + # Should be done *after* the deprecation warning, since warning will + # check if those values are already set. + self._fill_defaults_from_env() + + # If it has already connected to the cluster with allow_multiple=True, + # connect to the default one is not allowed. 
+ # But if it has connected to the default one, connect to other clients + # with allow_multiple=True is allowed + default_cli_connected = ray.util.client.ray.is_connected() + has_cli_connected = ray.util.client.num_connected_contexts() > 0 + if ( + not self._allow_multiple_connections + and not default_cli_connected + and has_cli_connected + ): + raise ValueError( + "The client has already connected to the cluster " + "with allow_multiple=True. Please set allow_multiple=True" + " to proceed" + ) + + old_ray_cxt = None + if self._allow_multiple_connections: + old_ray_cxt = ray.util.client.ray.set_context(None) + + client_info_dict = ray.util.client_connect.connect( + self.address, + job_config=self._job_config, + _credentials=self._credentials, + ray_init_kwargs=self._remote_init_kwargs, + metadata=self._metadata, + ) + + dashboard_url = ray.util.client.ray._get_dashboard_url() + + cxt = ClientContext( + dashboard_url=dashboard_url, + python_version=client_info_dict["python_version"], + ray_version=client_info_dict["ray_version"], + ray_commit=client_info_dict["ray_commit"], + _num_clients=client_info_dict["num_clients"], + _context_to_restore=ray.util.client.ray.get_context(), + ) + if self._allow_multiple_connections: + ray.util.client.ray.set_context(old_ray_cxt) + return cxt + + def _fill_defaults_from_env(self): + # Check environment variables for default values + namespace_env_var = os.environ.get(RAY_NAMESPACE_ENVIRONMENT_VARIABLE) + if namespace_env_var and self._job_config.ray_namespace is None: + self.namespace(namespace_env_var) + + runtime_env_var = os.environ.get(RAY_RUNTIME_ENV_ENVIRONMENT_VARIABLE) + if runtime_env_var and self._job_config.runtime_env is None: + self.env(json.loads(runtime_env_var)) + + def _init_args(self, **kwargs) -> "ClientBuilder": + """ + When a client builder is constructed through ray.init, for example + `ray.init(ray://..., namespace=...)`, all of the + arguments passed into ray.init with non-default values are passed + 
again into this method. Custom client builders can override this method + to do their own handling/validation of arguments. + """ + # Use namespace and runtime_env from ray.init call + if kwargs.get("namespace") is not None: + self.namespace(kwargs["namespace"]) + del kwargs["namespace"] + if kwargs.get("runtime_env") is not None: + self.env(kwargs["runtime_env"]) + del kwargs["runtime_env"] + + if kwargs.get("allow_multiple") is True: + self._allow_multiple_connections = True + del kwargs["allow_multiple"] + + if "_credentials" in kwargs.keys(): + self._credentials = kwargs["_credentials"] + del kwargs["_credentials"] + + if "_metadata" in kwargs.keys(): + self._metadata = kwargs["_metadata"] + del kwargs["_metadata"] + + if kwargs: + expected_sig = inspect.signature(ray_driver_init) + extra_args = set(kwargs.keys()).difference(expected_sig.parameters.keys()) + if len(extra_args) > 0: + raise RuntimeError( + "Got unexpected kwargs: {}".format(", ".join(extra_args)) + ) + self._remote_init_kwargs = kwargs + unknown = ", ".join(kwargs) + logger.info( + "Passing the following kwargs to ray.init() " + f"on the server: {unknown}" + ) + return self + + def _client_deprecation_warn(self) -> None: + """ + Generates a warning for user's if this ClientBuilder instance was + created directly or through ray.client, instead of relying on + internal methods (ray.init, or auto init) + """ + namespace = self._job_config.ray_namespace + runtime_env = self._job_config.runtime_env + replacement_args = [] + if self.address: + if isinstance(self, _LocalClientBuilder): + # Address might be set for LocalClientBuilder if ray.client() + # is called while ray_current_cluster is set + # (see _get_builder_from_address). 
In this case, + # leave off the ray:// so the user attaches the driver directly + replacement_args.append(f'"{self.address}"') + else: + replacement_args.append(f'"ray://{self.address}"') + if namespace: + replacement_args.append(f'namespace="{namespace}"') + if runtime_env: + # Use a placeholder here, since the real runtime_env would be + # difficult to read if formatted in directly + replacement_args.append("runtime_env=") + args_str = ", ".join(replacement_args) + replacement_call = f"ray.init({args_str})" + + # Note: stack level is set to 3 since we want the warning to reach the + # call to ray.client(...).connect(). The intervening frames are + # connect() -> client_deprecation_warn() -> warnings.warn() + # https://docs.python.org/3/library/warnings.html#available-functions + warnings.warn( + "Starting a connection through `ray.client` will be deprecated " + "in future ray versions in favor of `ray.init`. See the docs for " + f"more details: {CLIENT_DOCS_URL}. You can replace your call to " + "`ray.client().connect()` with the following:\n" + f" {replacement_call}\n", + DeprecationWarning, + stacklevel=3, + ) + + +class _LocalClientBuilder(ClientBuilder): + def connect(self) -> ClientContext: + """ + Begin a connection to the address passed in via ray.client(...) + """ + if self._deprecation_warn_enabled: + self._client_deprecation_warn() + # Fill runtime env/namespace from environment if not already set. + # Should be done *after* the deprecation warning, since warning will + # check if those values are already set. 
+ self._fill_defaults_from_env() + + connection_dict = ray.init(address=self.address, job_config=self._job_config) + return ClientContext( + dashboard_url=connection_dict["webui_url"], + python_version="{}.{}.{}".format( + sys.version_info[0], sys.version_info[1], sys.version_info[2] + ), + ray_version=ray.__version__, + ray_commit=ray.__commit__, + _num_clients=1, + _context_to_restore=None, + ) + + +def _split_address(address: str) -> Tuple[str, str]: + """ + Splits address into a module string (scheme) and an inner_address. + + If the scheme is not present, then "ray://" is prepended to the address. + """ + if "://" not in address: + address = "ray://" + address + return split_address(address) + + +def _get_builder_from_address(address: Optional[str]) -> ClientBuilder: + if address == "local": + return _LocalClientBuilder("local") + if address is None: + # NOTE: This is not placed in `Node::get_temp_dir_path`, because + # this file is accessed before the `Node` object is created. + address = ray._private.services.canonicalize_bootstrap_address(address) + return _LocalClientBuilder(address) + module_string, inner_address = _split_address(address) + try: + module = importlib.import_module(module_string) + except Exception as e: + raise RuntimeError( + f"Module: {module_string} does not exist.\n" + f"This module was parsed from Address: {address}" + ) from e + assert "ClientBuilder" in dir( + module + ), f"Module: {module_string} does not have ClientBuilder." + return module.ClientBuilder(inner_address) + + +@Deprecated +def client( + address: Optional[str] = None, _deprecation_warn_enabled: bool = True +) -> ClientBuilder: + """ + Creates a ClientBuilder based on the provided address. The address can be + of the following forms: + + * None: Connects to or creates a local cluster and connects to it. + * ``"local"``: Creates a new cluster locally and connects to it. + * ``"IP:Port"``: Connects to a Ray Client Server at the given address. 
+ * ``"module://inner_address"``: load module.ClientBuilder & pass + inner_address + + The _deprecation_warn_enabled flag enables deprecation warnings, and is + for internal use only. Set it to False to suppress client deprecation + warnings. + """ + env_address = os.environ.get(RAY_ADDRESS_ENVIRONMENT_VARIABLE) + if env_address and address is None: + logger.debug( + f"Using address ({env_address}) instead of auto-detection " + f"because {RAY_ADDRESS_ENVIRONMENT_VARIABLE} is set." + ) + address = env_address + + builder = _get_builder_from_address(address) + # Disable client deprecation warn when ray.client is used internally + builder._deprecation_warn_enabled = _deprecation_warn_enabled + return builder diff --git a/.venv/lib/python3.11/site-packages/ray/cluster_utils.py b/.venv/lib/python3.11/site-packages/ray/cluster_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cc5eff92e9e019b9d5e9740b655d4eb2764da78c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/cluster_utils.py @@ -0,0 +1,415 @@ +import copy +import json +import logging +import os +import subprocess +import tempfile +import time +from typing import Dict, Optional + +import yaml + +import ray +import ray._private.services +from ray._private import ray_constants +from ray._private.client_mode_hook import disable_client_hook +from ray._raylet import GcsClientOptions +from ray.autoscaler._private.fake_multi_node.node_provider import FAKE_HEAD_NODE_ID +from ray.util.annotations import DeveloperAPI + +logger = logging.getLogger(__name__) + +cluster_not_supported = os.name == "nt" + + +@DeveloperAPI +class AutoscalingCluster: + """Create a local autoscaling cluster for testing. + + See test_autoscaler_fake_multinode.py for an end-to-end example. + """ + + def __init__( + self, + head_resources: dict, + worker_node_types: dict, + autoscaler_v2: bool = False, + **config_kwargs, + ): + """Create the cluster. 
+ + Args: + head_resources: resources of the head node, including CPU. + worker_node_types: autoscaler node types config for worker nodes. + """ + self._head_resources = head_resources + self._config = self._generate_config( + head_resources, + worker_node_types, + autoscaler_v2=autoscaler_v2, + **config_kwargs, + ) + self._autoscaler_v2 = autoscaler_v2 + + def _generate_config( + self, head_resources, worker_node_types, autoscaler_v2=False, **config_kwargs + ): + base_config = yaml.safe_load( + open( + os.path.join( + os.path.dirname(ray.__file__), + "autoscaler/_private/fake_multi_node/example.yaml", + ) + ) + ) + custom_config = copy.deepcopy(base_config) + custom_config["available_node_types"] = worker_node_types + custom_config["available_node_types"]["ray.head.default"] = { + "resources": head_resources, + "node_config": {}, + "max_workers": 0, + } + + # Autoscaler v2 specific configs + if autoscaler_v2: + custom_config["provider"]["launch_multiple"] = True + custom_config["provider"]["head_node_id"] = FAKE_HEAD_NODE_ID + custom_config.update(config_kwargs) + return custom_config + + def start(self, _system_config=None, override_env: Optional[Dict] = None): + """Start the cluster. + + After this call returns, you can connect to the cluster with + ray.init("auto"). 
+ """ + subprocess.check_call(["ray", "stop", "--force"]) + _, fake_config = tempfile.mkstemp() + with open(fake_config, "w") as f: + f.write(json.dumps(self._config)) + cmd = [ + "ray", + "start", + "--autoscaling-config={}".format(fake_config), + "--head", + ] + if "CPU" in self._head_resources: + cmd.append("--num-cpus={}".format(self._head_resources.pop("CPU"))) + if "GPU" in self._head_resources: + cmd.append("--num-gpus={}".format(self._head_resources.pop("GPU"))) + if "object_store_memory" in self._head_resources: + cmd.append( + "--object-store-memory={}".format( + self._head_resources.pop("object_store_memory") + ) + ) + if self._head_resources: + cmd.append("--resources='{}'".format(json.dumps(self._head_resources))) + if _system_config is not None: + cmd.append( + "--system-config={}".format( + json.dumps(_system_config, separators=(",", ":")) + ) + ) + env = os.environ.copy() + env.update({"AUTOSCALER_UPDATE_INTERVAL_S": "1", "RAY_FAKE_CLUSTER": "1"}) + if self._autoscaler_v2: + # Set the necessary environment variables for autoscaler v2. + env.update( + { + "RAY_enable_autoscaler_v2": "1", + "RAY_CLOUD_INSTANCE_ID": FAKE_HEAD_NODE_ID, + "RAY_OVERRIDE_NODE_ID_FOR_TESTING": FAKE_HEAD_NODE_ID, + } + ) + if override_env: + env.update(override_env) + subprocess.check_call(cmd, env=env) + + def shutdown(self): + """Terminate the cluster.""" + subprocess.check_call(["ray", "stop", "--force"]) + + +@DeveloperAPI +class Cluster: + def __init__( + self, + initialize_head: bool = False, + connect: bool = False, + head_node_args: dict = None, + shutdown_at_exit: bool = True, + ): + """Initializes all services of a Ray cluster. + + Args: + initialize_head: Automatically start a Ray cluster + by initializing the head node. Defaults to False. + connect: If `initialize_head=True` and `connect=True`, + ray.init will be called with the address of this cluster + passed in. + head_node_args: Arguments to be passed into + `start_ray_head` via `self.add_node`. 
+ shutdown_at_exit: If True, registers an exit hook + for shutting down all started processes. + """ + if cluster_not_supported: + logger.warning( + "Ray cluster mode is currently experimental and untested on " + "Windows. If you are using it and running into issues please " + "file a report at https://github.com/ray-project/ray/issues." + ) + self.head_node = None + self.worker_nodes = set() + self.redis_address = None + self.connected = False + # Create a new global state accessor for fetching GCS table. + self.global_state = ray._private.state.GlobalState() + self._shutdown_at_exit = shutdown_at_exit + if not initialize_head and connect: + raise RuntimeError("Cannot connect to uninitialized cluster.") + + if initialize_head: + head_node_args = head_node_args or {} + self.add_node(**head_node_args) + if connect: + self.connect() + + @property + def gcs_address(self): + if self.head_node is None: + return None + return self.head_node.gcs_address + + @property + def address(self): + return self.gcs_address + + def connect(self, namespace=None): + """Connect the driver to the cluster.""" + assert self.address is not None + assert not self.connected + output_info = ray.init( + namespace=namespace, + ignore_reinit_error=True, + address=self.address, + _redis_username=self.redis_username, + _redis_password=self.redis_password, + ) + logger.info(output_info) + self.connected = True + + def add_node(self, wait: bool = True, **node_args): + """Adds a node to the local Ray Cluster. + + All nodes are by default started with the following settings: + cleanup=True, + num_cpus=1, + object_store_memory=150 * 1024 * 1024 # 150 MiB + + Args: + wait: Whether to wait until the node is alive. + node_args: Keyword arguments used in `start_ray_head` and + `start_ray_node`. Overrides defaults. + + Returns: + Node object of the added Ray node. 
+ """ + default_kwargs = { + "num_cpus": 1, + "num_gpus": 0, + "object_store_memory": 150 * 1024 * 1024, # 150 MiB + "min_worker_port": 0, + "max_worker_port": 0, + } + ray_params = ray._private.parameter.RayParams(**node_args) + ray_params.update_if_absent(**default_kwargs) + with disable_client_hook(): + if self.head_node is None: + node = ray._private.node.Node( + ray_params, + head=True, + shutdown_at_exit=self._shutdown_at_exit, + spawn_reaper=self._shutdown_at_exit, + ) + self.head_node = node + self.redis_address = self.head_node.redis_address + self.redis_username = node_args.get( + "redis_username", ray_constants.REDIS_DEFAULT_USERNAME + ) + self.redis_password = node_args.get( + "redis_password", ray_constants.REDIS_DEFAULT_PASSWORD + ) + self.webui_url = self.head_node.webui_url + # Init global state accessor when creating head node. + gcs_options = GcsClientOptions.create( + node.gcs_address, + None, + allow_cluster_id_nil=True, + fetch_cluster_id_if_nil=False, + ) + self.global_state._initialize_global_state(gcs_options) + # Write the Ray cluster address for convenience in unit + # testing. ray.init() and ray.init(address="auto") will connect + # to the local cluster. + ray._private.utils.write_ray_address(self.head_node.gcs_address) + else: + ray_params.update_if_absent(redis_address=self.redis_address) + ray_params.update_if_absent(gcs_address=self.gcs_address) + # We only need one log monitor per physical node. + ray_params.update_if_absent(include_log_monitor=False) + # Let grpc pick a port. 
+ ray_params.update_if_absent(node_manager_port=0) + if "dashboard_agent_listen_port" not in node_args: + # Pick a random one to not conflict + # with the head node dashboard agent + ray_params.dashboard_agent_listen_port = None + + node = ray._private.node.Node( + ray_params, + head=False, + shutdown_at_exit=self._shutdown_at_exit, + spawn_reaper=self._shutdown_at_exit, + ) + self.worker_nodes.add(node) + + if wait: + # Wait for the node to appear in the client table. We do this + # so that the nodes appears in the client table in the order + # that the corresponding calls to add_node were made. We do + # this because in the tests we assume that the driver is + # connected to the first node that is added. + self._wait_for_node(node) + + return node + + def remove_node(self, node, allow_graceful=True): + """Kills all processes associated with worker node. + + Args: + node: Worker node of which all associated processes + will be removed. + """ + global_node = ray._private.worker._global_node + if global_node is not None: + if node._raylet_socket_name == global_node._raylet_socket_name: + ray.shutdown() + raise ValueError( + "Removing a node that is connected to this Ray client " + "is not allowed because it will break the driver." + "You can use the get_other_node utility to avoid removing" + "a node that the Ray client is connected." + ) + + node.destroy_external_storage() + if self.head_node == node: + # We have to wait to prevent the raylet becomes a zombie which will prevent + # worker from exiting + self.head_node.kill_all_processes( + check_alive=False, allow_graceful=allow_graceful, wait=True + ) + self.head_node = None + # TODO(rliaw): Do we need to kill all worker processes? 
+ else: + # We have to wait to prevent the raylet becomes a zombie which will prevent + # worker from exiting + node.kill_all_processes( + check_alive=False, allow_graceful=allow_graceful, wait=True + ) + self.worker_nodes.remove(node) + + assert ( + not node.any_processes_alive() + ), "There are zombie processes left over after killing." + + def _wait_for_node(self, node, timeout: float = 30): + """Wait until this node has appeared in the client table. + + Args: + node (ray._private.node.Node): The node to wait for. + timeout: The amount of time in seconds to wait before raising an + exception. + + Raises: + TimeoutError: An exception is raised if the timeout expires before + the node appears in the client table. + """ + ray._private.services.wait_for_node( + node.gcs_address, + node.plasma_store_socket_name, + timeout, + ) + + def wait_for_nodes(self, timeout: float = 30): + """Waits for correct number of nodes to be registered. + + This will wait until the number of live nodes in the client table + exactly matches the number of "add_node" calls minus the number of + "remove_node" calls that have been made on this cluster. This means + that if a node dies without "remove_node" having been called, this will + raise an exception. + + Args: + timeout: The number of seconds to wait for nodes to join + before failing. + + Raises: + TimeoutError: An exception is raised if we time out while waiting + for nodes to join. + """ + start_time = time.time() + while time.time() - start_time < timeout: + live_clients = self.global_state._live_node_ids() + + expected = len(self.list_all_nodes()) + if len(live_clients) == expected: + logger.debug("All nodes registered as expected.") + return + else: + logger.debug( + f"{len(live_clients)} nodes are currently registered, " + f"but we are expecting {expected}" + ) + time.sleep(0.1) + raise TimeoutError("Timed out while waiting for nodes to join.") + + def list_all_nodes(self): + """Lists all nodes. 
+ + TODO(rliaw): What is the desired behavior if a head node + dies before worker nodes die? + + Returns: + List of all nodes, including the head node. + """ + nodes = list(self.worker_nodes) + if self.head_node: + nodes = [self.head_node] + nodes + return nodes + + def remaining_processes_alive(self): + """Returns a bool indicating whether all processes are alive or not. + + Note that this ignores processes that have been explicitly killed, + e.g., via a command like node.kill_raylet(). + + Returns: + True if all processes are alive and false otherwise. + """ + return all(node.remaining_processes_alive() for node in self.list_all_nodes()) + + def shutdown(self): + """Removes all nodes.""" + + # We create a list here as a copy because `remove_node` + # modifies `self.worker_nodes`. + all_nodes = list(self.worker_nodes) + for node in all_nodes: + self.remove_node(node) + + if self.head_node is not None: + self.remove_node(self.head_node) + # need to reset internal kv since gcs is down + ray.experimental.internal_kv._internal_kv_reset() + # Delete the cluster address. + ray._private.utils.reset_ray_address() diff --git a/.venv/lib/python3.11/site-packages/ray/cross_language.py b/.venv/lib/python3.11/site-packages/ray/cross_language.py new file mode 100644 index 0000000000000000000000000000000000000000..1539954e56f0f45a9f2fbc89b99e72153db0b603 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/cross_language.py @@ -0,0 +1,137 @@ +from __future__ import absolute_import, division, print_function + +from ray import Language +from ray._raylet import CppFunctionDescriptor, JavaFunctionDescriptor +from ray.util.annotations import PublicAPI + +__all__ = [ + "java_function", + "java_actor_class", + "cpp_function", +] + + +@PublicAPI(stability="beta") +def java_function(class_name: str, function_name: str): + """Define a Java function. + + Args: + class_name: Java class name. + function_name: Java function name. 
+ """ + from ray.remote_function import RemoteFunction + + return RemoteFunction( + Language.JAVA, + lambda *args, **kwargs: None, + JavaFunctionDescriptor(class_name, function_name, ""), + {}, + ) + + +@PublicAPI(stability="beta") +def cpp_function(function_name: str): + """Define a Cpp function. + + Args: + function_name: Cpp function name. + """ + from ray.remote_function import RemoteFunction + + return RemoteFunction( + Language.CPP, + lambda *args, **kwargs: None, + CppFunctionDescriptor(function_name, "PYTHON"), + {}, + ) + + +@PublicAPI(stability="beta") +def java_actor_class(class_name: str): + """Define a Java actor class. + + Args: + class_name: Java class name. + """ + from ray.actor import ActorClass + + return ActorClass._ray_from_function_descriptor( + Language.JAVA, + JavaFunctionDescriptor(class_name, "", ""), + {}, + ) + + +@PublicAPI(stability="beta") +def cpp_actor_class(create_function_name: str, class_name: str): + """Define a Cpp actor class. + + Args: + create_function_name: Create cpp class function name. + class_name: Cpp class name. + """ + from ray.actor import ActorClass + + print("create func=", create_function_name, "class_name=", class_name) + return ActorClass._ray_from_function_descriptor( + Language.CPP, + CppFunctionDescriptor(create_function_name, "PYTHON", class_name), + {}, + ) + + +def _format_args(worker, args, kwargs): + """Format args for various languages. + + Args: + worker: The global worker instance. + args: The arguments for cross language. + kwargs: The keyword arguments for cross language. + + Returns: + List of args and kwargs (if supported). + """ + if not worker.load_code_from_local: + raise ValueError( + "Cross language feature needs --load-code-from-local to be set." + ) + if kwargs: + raise TypeError( + f"Cross language remote functions does not support kwargs, " + f"kwargs:{str(kwargs)}." 
+ ) + return args + + +def _get_function_descriptor_for_actor_method( + language: str, actor_creation_function_descriptor, method_name: str, signature: str +): + """Get function descriptor for cross language actor method call. + + Args: + language: Target language. + actor_creation_function_descriptor: + The function signature for actor creation. + method_name: The name of actor method. + signature: The signature for the actor method. When calling Java from Python, + it should be string in the form of "{length_of_args}". + + Returns: + Function descriptor for cross language actor method call. + """ + if language == Language.JAVA: + return JavaFunctionDescriptor( + actor_creation_function_descriptor.class_name, + method_name, + signature, + ) + elif language == Language.CPP: + return CppFunctionDescriptor( + method_name, + "PYTHON", + actor_creation_function_descriptor.class_name, + ) + else: + raise NotImplementedError( + "Cross language remote actor method " f"not support language {language}" + ) diff --git a/.venv/lib/python3.11/site-packages/ray/exceptions.py b/.venv/lib/python3.11/site-packages/ray/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..48ea0baa5e968770293d7841652305eb45258647 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/exceptions.py @@ -0,0 +1,933 @@ +import logging +import os +import sys +from traceback import format_exception +from typing import Optional, Union + +import colorama + +import ray._private.ray_constants as ray_constants +import ray.cloudpickle as pickle +from ray._raylet import ActorID, TaskID, WorkerID +from ray.core.generated.common_pb2 import ( + PYTHON, + ActorDiedErrorContext, + Address, + Language, + NodeDeathInfo, + RayException, +) +from ray.util.annotations import DeveloperAPI, PublicAPI + +import setproctitle + +logger = logging.getLogger(__name__) + + +@PublicAPI +class RayError(Exception): + """Super class of all ray exception types.""" + + def to_bytes(self): + # Extract 
exc_info from exception object. + exc_info = (type(self), self, self.__traceback__) + formatted_exception_string = "\n".join(format_exception(*exc_info)) + return RayException( + language=PYTHON, + serialized_exception=pickle.dumps(self), + formatted_exception_string=formatted_exception_string, + ).SerializeToString() + + @staticmethod + def from_bytes(b): + ray_exception = RayException() + ray_exception.ParseFromString(b) + return RayError.from_ray_exception(ray_exception) + + @staticmethod + def from_ray_exception(ray_exception): + if ray_exception.language == PYTHON: + try: + return pickle.loads(ray_exception.serialized_exception) + except Exception as e: + msg = "Failed to unpickle serialized exception" + raise RuntimeError(msg) from e + else: + return CrossLanguageError(ray_exception) + + +@PublicAPI +class CrossLanguageError(RayError): + """Raised from another language.""" + + def __init__(self, ray_exception): + super().__init__( + "An exception raised from {}:\n{}".format( + Language.Name(ray_exception.language), + ray_exception.formatted_exception_string, + ) + ) + + +@PublicAPI +class TaskCancelledError(RayError): + """Raised when this task is cancelled. + + Args: + task_id: The TaskID of the function that was directly + cancelled. + """ + + def __init__( + self, task_id: Optional[TaskID] = None, error_message: Optional[str] = None + ): + self.task_id = task_id + self.error_message = error_message + + def __str__(self): + msg = "" + if self.task_id: + msg = "Task: " + str(self.task_id) + " was cancelled. " + if self.error_message: + msg += self.error_message + return msg + + +@PublicAPI +class RayTaskError(RayError): + """Indicates that a task threw an exception during execution. + + If a task throws an exception during execution, a RayTaskError is stored in + the object store for each of the task's outputs. 
When an object is + retrieved from the object store, the Python method that retrieved it checks + to see if the object is a RayTaskError and if it is then an exception is + thrown propagating the error message. + """ + + def __init__( + self, + function_name, + traceback_str, + cause, + proctitle=None, + pid=None, + ip=None, + actor_repr=None, + actor_id=None, + ): + """Initialize a RayTaskError.""" + import ray + + if proctitle: + self.proctitle = proctitle + else: + self.proctitle = setproctitle.getproctitle() + self.pid = pid or os.getpid() + self.ip = ip or ray.util.get_node_ip_address() + self.function_name = function_name + self.traceback_str = traceback_str + self.actor_repr = actor_repr + self._actor_id = actor_id + self.cause = cause + + try: + pickle.dumps(cause) + except (pickle.PicklingError, TypeError) as e: + err_msg = ( + "The original cause of the RayTaskError" + f" ({self.cause.__class__}) isn't serializable: {e}." + " Overwriting the cause to a RayError." + ) + logger.warning(err_msg) + self.cause = RayError(err_msg) + + # BaseException implements a __reduce__ method that returns + # a tuple with the type and the value of self.args. + # https://stackoverflow.com/a/49715949/2213289 + self.args = (function_name, traceback_str, self.cause, proctitle, pid, ip) + + assert traceback_str is not None + + def make_dual_exception_instance(self) -> "RayTaskError": + """Makes a object instance that inherits from both RayTaskError and the type of + `self.cause`. Raises TypeError if the cause class can't be subclassed""" + # For normal user Exceptions, we subclass from both + # RayTaskError and the user exception. For ExceptionGroup, + # we special handle it because it has a different __new__() + # signature from Exception. 
+ # Ref: https://docs.python.org/3/library/exceptions.html#exception-groups + if sys.version_info >= (3, 11) and isinstance( + self.cause, ExceptionGroup # noqa: F821 + ): + return self._make_exceptiongroup_dual_exception_instance() + return self._make_normal_dual_exception_instance() + + def _make_normal_dual_exception_instance(self) -> "RayTaskError": + cause_cls = self.cause.__class__ + error_msg = str(self) + + class cls(RayTaskError, cause_cls): + def __init__(self, cause): + self.cause = cause + # BaseException implements a __reduce__ method that returns + # a tuple with the type and the value of self.args. + # https://stackoverflow.com/a/49715949/2213289 + self.args = (cause,) + + def __getattr__(self, name): + return getattr(self.cause, name) + + def __str__(self): + return error_msg + + name = f"RayTaskError({cause_cls.__name__})" + cls.__name__ = name + cls.__qualname__ = name + + return cls(self.cause) + + def _make_exceptiongroup_dual_exception_instance(self) -> "RayTaskError": + cause_cls = self.cause.__class__ + error_msg = str(self) + + class cls(RayTaskError, cause_cls): + def __new__(cls, cause): + self = super().__new__(cls, cause.message, cause.exceptions) + return self + + def __init__(self, cause): + self.cause = cause + # BaseException implements a __reduce__ method that returns + # a tuple with the type and the value of self.args. + # https://stackoverflow.com/a/49715949/2213289 + self.args = (cause,) + + def __getattr__(self, name): + return getattr(self.cause, name) + + def __str__(self): + return error_msg + + name = f"RayTaskError({cause_cls.__name__})" + cls.__name__ = name + cls.__qualname__ = name + + return cls(self.cause) + + def as_instanceof_cause(self): + """Returns an exception that's an instance of the cause's class. + + The returned exception inherits from both RayTaskError and the + cause class and contains all of the attributes of the cause + exception. 
+ + If the cause class can't be subclassed, issues a warning and returns `self`. + """ + cause_cls = self.cause.__class__ + if issubclass(RayTaskError, cause_cls): + return self # already satisfied + + try: + return self.make_dual_exception_instance() + except TypeError as e: + logger.warning( + f"User exception type {type(self.cause)} in RayTaskError can't" + " be subclassed! This exception is raised as" + " RayTaskError only. You can use `ray_task_error.cause` to" + f" access the user exception. Failure in subclassing: {e}" + ) + return self + + def __str__(self): + """Format a RayTaskError as a string.""" + lines = self.traceback_str.strip().split("\n") + out = [] + code_from_internal_file = False + + # Format tracebacks. + # Python stacktrace consists of + # Traceback...: Indicate the next line will be a traceback. + # File [file_name + line number] + # code + # XError: [message] + # NOTE: For _raylet.pyx (Cython), the code is not always included. + for i, line in enumerate(lines): + # Convert traceback to the readable information. + if line.startswith("Traceback "): + traceback_line = ( + f"{colorama.Fore.CYAN}" + f"{self.proctitle}()" + f"{colorama.Fore.RESET} " + f"(pid={self.pid}, ip={self.ip}" + ) + if self.actor_repr: + traceback_line += ( + f", actor_id={self._actor_id}, repr={self.actor_repr})" + ) + else: + traceback_line += ")" + code_from_internal_file = False + out.append(traceback_line) + elif line.startswith(" File ") and ( + "ray/worker.py" in line + or "ray/_private/" in line + or "ray/util/tracing/" in line + or "ray/_raylet.pyx" in line + ): + # TODO(windows) + # Process the internal file line. + # The file line always starts with 2 space and File. + # https://github.com/python/cpython/blob/0a0a135bae2692d069b18d2d590397fbe0a0d39a/Lib/traceback.py#L421 # noqa + if "ray._raylet.raise_if_dependency_failed" in line: + # It means the current task is failed + # due to the dependency failure. 
+ # Print out an user-friendly + # message to explain that.. + out.append( + " At least one of the input arguments for " + "this task could not be computed:" + ) + if i + 1 < len(lines) and lines[i + 1].startswith(" "): + # If the next line is indented with 2 space, + # that means it contains internal code information. + # For example, + # File [file_name] [line] + # [code] # if the next line is indented, it is code. + # Note there there are 4 spaces in the code line. + code_from_internal_file = True + elif code_from_internal_file: + # If the current line is internal file's code, + # the next line is not code anymore. + code_from_internal_file = False + else: + out.append(line) + return "\n".join(out) + + +@PublicAPI +class LocalRayletDiedError(RayError): + """Indicates that the task's local raylet died.""" + + def __str__(self): + return "The task's local raylet died. Check raylet.out for more information." + + +@PublicAPI +class WorkerCrashedError(RayError): + """Indicates that the worker died unexpectedly while executing a task.""" + + def __str__(self): + return ( + "The worker died unexpectedly while executing this task. " + "Check python-core-worker-*.log files for more information." + ) + + +@PublicAPI +class RayActorError(RayError): + """Indicates that the actor has outages unexpectedly before finishing a task. + + This exception could happen because the actor process is dead, or is unavailable for + the moment. Ray raises subclasses `ActorDiedError` and `ActorUnavailableError` + respectively. + """ + + BASE_ERROR_MSG = "The actor experienced an error before finishing this task." + + def __init__( + self, + actor_id: str = None, + error_msg: str = BASE_ERROR_MSG, + actor_init_failed: bool = False, + preempted: bool = False, + ): + #: The actor ID in hex string. + self.actor_id = actor_id + #: Whether the actor failed in the middle of __init__. + self.error_msg = error_msg + #: The full error message. 
+ self._actor_init_failed = actor_init_failed + #: Whether the actor died because the node was preempted. + self._preempted = preempted + + def __str__(self) -> str: + return self.error_msg + + @property + def preempted(self) -> bool: + return self._preempted + + @property + def actor_init_failed(self) -> bool: + return self._actor_init_failed + + +@DeveloperAPI +class ActorDiedError(RayActorError): + """Indicates that the actor died unexpectedly before finishing a task. + + This exception could happen either because the actor process dies while + executing a task, or because a task is submitted to a dead actor. + + Args: + cause: The cause of the actor error. `RayTaskError` type means + the actor has died because of an exception within `__init__`. + `ActorDiedErrorContext` means the actor has died because of + an unexpected system error. None means the cause isn't known. + Theoretically, this shouldn't happen, + but it's there as a safety check. + """ + + BASE_ERROR_MSG = "The actor died unexpectedly before finishing this task." + + def __init__( + self, cause: Optional[Union[RayTaskError, ActorDiedErrorContext]] = None + ): + """ + Construct a RayActorError by building the arguments. + """ + + actor_id = None + error_msg = ActorDiedError.BASE_ERROR_MSG + actor_init_failed = False + preempted = False + + if not cause: + # Use the defaults above. + pass + elif isinstance(cause, RayTaskError): + actor_init_failed = True + actor_id = cause._actor_id + error_msg = ( + "The actor died because of an error" + " raised in its creation task, " + f"{cause.__str__()}" + ) + else: + # Inidicating system-level actor failures. + assert isinstance(cause, ActorDiedErrorContext) + error_msg_lines = [ActorDiedError.BASE_ERROR_MSG] + error_msg_lines.append(f"\tclass_name: {cause.class_name}") + error_msg_lines.append(f"\tactor_id: {ActorID(cause.actor_id).hex()}") + # Below items are optional fields. 
+ if cause.pid != 0: + error_msg_lines.append(f"\tpid: {cause.pid}") + if cause.name != "": + error_msg_lines.append(f"\tname: {cause.name}") + if cause.ray_namespace != "": + error_msg_lines.append(f"\tnamespace: {cause.ray_namespace}") + if cause.node_ip_address != "": + error_msg_lines.append(f"\tip: {cause.node_ip_address}") + error_msg_lines.append(cause.error_message) + if cause.never_started: + error_msg_lines.append( + "The actor never ran - it was cancelled before it started running." + ) + if ( + cause.node_death_info + and cause.node_death_info.reason + == NodeDeathInfo.AUTOSCALER_DRAIN_PREEMPTED + ): + preempted = True + error_msg = "\n".join(error_msg_lines) + actor_id = ActorID(cause.actor_id).hex() + super().__init__(actor_id, error_msg, actor_init_failed, preempted) + + @staticmethod + def from_task_error(task_error: RayTaskError): + return ActorDiedError(task_error) + + +@DeveloperAPI +class ActorUnavailableError(RayActorError): + """Raised when the actor is temporarily unavailable but may be available later.""" + + def __init__(self, error_message: str, actor_id: Optional[bytes]): + actor_id = ActorID(actor_id).hex() if actor_id is not None else None + error_msg = ( + f"The actor {actor_id} is unavailable: {error_message}. The task may or may" + "not have been executed on the actor." + ) + actor_init_failed = False + preempted = False + + super().__init__(actor_id, error_msg, actor_init_failed, preempted) + + +@PublicAPI +class RaySystemError(RayError): + """Indicates that Ray encountered a system error. + + This exception can be thrown when the raylet is killed. 
+ """ + + def __init__(self, client_exc, traceback_str=None): + self.client_exc = client_exc + self.traceback_str = traceback_str + + def __str__(self): + error_msg = f"System error: {self.client_exc}" + if self.traceback_str: + error_msg += f"\ntraceback: {self.traceback_str}" + return error_msg + + +@DeveloperAPI +class UserCodeException(RayError): + """Indicates that an exception occurred while executing user code. + For example, this exception can be used to wrap user code exceptions + from a remote task or actor. The `retry_exceptions` parameter will + still respect the underlying cause of this exception.""" + + pass + + +@PublicAPI +class ObjectStoreFullError(RayError): + """Indicates that the object store is full. + + This is raised if the attempt to store the object fails + because the object store is full even after multiple retries. + """ + + def __str__(self): + return super(ObjectStoreFullError, self).__str__() + ( + "\n" + "The local object store is full of objects that are still in " + "scope and cannot be evicted. Tip: Use the `ray memory` command " + "to list active objects in the cluster." + ) + + +@PublicAPI +class OutOfDiskError(RayError): + """Indicates that the local disk is full. + + This is raised if the attempt to store the object fails + because both the object store and disk are full. + """ + + def __str__(self): + # TODO(scv119): expose more disk usage information and link to a doc. + return super(OutOfDiskError, self).__str__() + ( + "\n" + "The object cannot be created because the local object store" + " is full and the local disk's utilization is over capacity" + " (95% by default)." + "Tip: Use `df` on this node to check disk usage and " + "`ray memory` to check object store memory usage." + ) + + +@PublicAPI +class OutOfMemoryError(RayError): + """Indicates that the node is running out of memory and is close to full. + + This is raised if the node is low on memory and tasks or actors are being + evicted to free up memory. 
+ """ + + # TODO: (clarng) expose the error message string here and format it with proto + def __init__(self, message): + self.message = message + + def __str__(self): + return self.message + + +@PublicAPI +class NodeDiedError(RayError): + """Indicates that the node is either dead or unreachable.""" + + # TODO: (clarng) expose the error message string here and format it with proto + def __init__(self, message): + self.message = message + + def __str__(self): + return self.message + + +@PublicAPI +class ObjectLostError(RayError): + """Indicates that the object is lost from distributed memory, due to + node failure or system error. + + Args: + object_ref_hex: Hex ID of the object. + """ + + def __init__(self, object_ref_hex, owner_address, call_site): + self.object_ref_hex = object_ref_hex + self.owner_address = owner_address + self.call_site = call_site.replace( + ray_constants.CALL_STACK_LINE_DELIMITER, "\n " + ) + + def _base_str(self): + msg = f"Failed to retrieve object {self.object_ref_hex}. " + if self.call_site: + msg += f"The ObjectRef was created at: {self.call_site}" + else: + msg += ( + "To see information about where this ObjectRef was created " + "in Python, set the environment variable " + "RAY_record_ref_creation_sites=1 during `ray start` and " + "`ray.init()`." + ) + return msg + + def __str__(self): + return ( + self._base_str() + + "\n\n" + + ( + f"All copies of {self.object_ref_hex} have been lost due to node " + "failure. Check cluster logs (`/tmp/ray/session_latest/logs`) for " + "more information about the failure." + ) + ) + + +@PublicAPI +class ObjectFetchTimedOutError(ObjectLostError): + """Indicates that an object fetch timed out. + + Args: + object_ref_hex: Hex ID of the object. + """ + + def __str__(self): + return ( + self._base_str() + + "\n\n" + + ( + f"Fetch for object {self.object_ref_hex} timed out because no " + "locations were found for the object. This may indicate a " + "system-level bug." 
+ ) + ) + + +@DeveloperAPI +class RpcError(RayError): + """Indicates an error in the underlying RPC system.""" + + def __init__(self, message, rpc_code=None): + self.message = message + self.rpc_code = rpc_code + + def __str__(self): + return self.message + + +@DeveloperAPI +class ReferenceCountingAssertionError(ObjectLostError, AssertionError): + """Indicates that an object has been deleted while there was still a + reference to it. + + Args: + object_ref_hex: Hex ID of the object. + """ + + def __str__(self): + return ( + self._base_str() + + "\n\n" + + ( + "The object has already been deleted by the reference counting " + "protocol. This should not happen." + ) + ) + + +@DeveloperAPI +class ObjectFreedError(ObjectLostError): + """Indicates that an object was manually freed by the application. + + Attributes: + object_ref_hex: Hex ID of the object. + """ + + def __str__(self): + return ( + self._base_str() + + "\n\n" + + ( + "The object was manually freed using the internal `free` call. " + "Please ensure that `free` is only called once the object is no " + "longer needed." + ) + ) + + +@PublicAPI +class OwnerDiedError(ObjectLostError): + """Indicates that the owner of the object has died while there is still a + reference to the object. + + Args: + object_ref_hex: Hex ID of the object. + """ + + def __str__(self): + log_loc = "`/tmp/ray/session_latest/logs`" + if self.owner_address: + try: + addr = Address() + addr.ParseFromString(self.owner_address) + ip_addr = addr.ip_address + worker_id = WorkerID(addr.worker_id) + log_loc = ( + f"`/tmp/ray/session_latest/logs/*{worker_id.hex()}*`" + f" at IP address {ip_addr}" + ) + except Exception: + # Catch all to make sure we always at least print the default + # message. + pass + + return ( + self._base_str() + + "\n\n" + + ( + "The object's owner has exited. This is the Python " + "worker that first created the ObjectRef via `.remote()` or " + "`ray.put()`. 
" + f"Check cluster logs ({log_loc}) for more " + "information about the Python worker failure." + ) + ) + + +@PublicAPI +class ObjectReconstructionFailedError(ObjectLostError): + """Indicates that the object cannot be reconstructed. + + Args: + object_ref_hex: Hex ID of the object. + """ + + def __str__(self): + return ( + self._base_str() + + "\n\n" + + ( + "The object cannot be reconstructed " + "because it was created by an actor, ray.put() call, or its " + "ObjectRef was created by a different worker." + ) + ) + + +@PublicAPI +class ObjectReconstructionFailedMaxAttemptsExceededError(ObjectLostError): + """Indicates that the object cannot be reconstructed because the maximum + number of task retries has been exceeded. + + Args: + object_ref_hex: Hex ID of the object. + """ + + def __str__(self): + return ( + self._base_str() + + "\n\n" + + ( + "The object cannot be reconstructed " + "because the maximum number of task retries has been exceeded. " + "To prevent this error, set " + "`@ray.remote(max_retries=)` (default 3)." + ) + ) + + +@PublicAPI +class ObjectReconstructionFailedLineageEvictedError(ObjectLostError): + """Indicates that the object cannot be reconstructed because its lineage + was evicted due to memory pressure. + + Args: + object_ref_hex: Hex ID of the object. + """ + + def __str__(self): + return ( + self._base_str() + + "\n\n" + + ( + "The object cannot be reconstructed because its lineage has been " + "evicted to reduce memory pressure. " + "To prevent this error, set the environment variable " + "RAY_max_lineage_bytes= (default 1GB) during `ray start`." 
+ ) + ) + + +@PublicAPI +class GetTimeoutError(RayError, TimeoutError): + """Indicates that a call to the worker timed out.""" + + pass + + +@PublicAPI +class PlasmaObjectNotAvailable(RayError): + """Called when an object was not available within the given timeout.""" + + pass + + +@PublicAPI +class AsyncioActorExit(RayError): + """Raised when an asyncio actor intentionally exits via exit_actor().""" + + pass + + +@PublicAPI +class RuntimeEnvSetupError(RayError): + """Raised when a runtime environment fails to be set up. + + Args: + error_message: The error message that explains + why runtime env setup has failed. + """ + + def __init__(self, error_message: str = None): + self.error_message = error_message + + def __str__(self): + msgs = ["Failed to set up runtime environment."] + if self.error_message: + msgs.append(self.error_message) + return "\n".join(msgs) + + +@PublicAPI +class TaskPlacementGroupRemoved(RayError): + """Raised when the corresponding placement group was removed.""" + + def __str__(self): + return "The placement group corresponding to this task has been removed." + + +@PublicAPI +class ActorPlacementGroupRemoved(RayError): + """Raised when the corresponding placement group was removed.""" + + def __str__(self): + return "The placement group corresponding to this Actor has been removed." + + +@PublicAPI +class PendingCallsLimitExceeded(RayError): + """Raised when the pending actor calls exceeds `max_pending_calls` option. + + This exception could happen probably because the caller calls the callee + too frequently. + """ + + pass + + +@PublicAPI +class TaskUnschedulableError(RayError): + """Raised when the task cannot be scheduled. + + One example is that the node specified through + NodeAffinitySchedulingStrategy is dead. 
+ """ + + def __init__(self, error_message: str): + self.error_message = error_message + + def __str__(self): + return f"The task is not schedulable: {self.error_message}" + + +@PublicAPI +class ActorUnschedulableError(RayError): + """Raised when the actor cannot be scheduled. + + One example is that the node specified through + NodeAffinitySchedulingStrategy is dead. + """ + + def __init__(self, error_message: str): + self.error_message = error_message + + def __str__(self): + return f"The actor is not schedulable: {self.error_message}" + + +@DeveloperAPI +class ObjectRefStreamEndOfStreamError(RayError): + """Raised by streaming generator tasks when there are no more ObjectRefs to + read. + """ + + pass + + +@DeveloperAPI +class OufOfBandObjectRefSerializationException(RayError): + """Raised when an `ray.ObjectRef` is out of band serialized by + `ray.cloudpickle`. It is an anti pattern. + """ + + pass + + +@PublicAPI(stability="alpha") +class RayChannelError(RaySystemError): + """Indicates that Ray encountered a system error related + to ray.experimental.channel. 
+ """ + + pass + + +@PublicAPI(stability="alpha") +class RayChannelTimeoutError(RayChannelError, TimeoutError): + """Raised when the Compiled Graph channel operation times out.""" + + pass + + +@PublicAPI(stability="alpha") +class RayCgraphCapacityExceeded(RaySystemError): + """Raised when the Compiled Graph channel's buffer is at max capacity""" + + pass + + +RAY_EXCEPTION_TYPES = [ + PlasmaObjectNotAvailable, + RayError, + RayTaskError, + WorkerCrashedError, + RayActorError, + ObjectStoreFullError, + ObjectLostError, + ObjectFetchTimedOutError, + ReferenceCountingAssertionError, + ObjectReconstructionFailedError, + ObjectReconstructionFailedMaxAttemptsExceededError, + ObjectReconstructionFailedLineageEvictedError, + OwnerDiedError, + GetTimeoutError, + AsyncioActorExit, + RuntimeEnvSetupError, + TaskPlacementGroupRemoved, + ActorPlacementGroupRemoved, + PendingCallsLimitExceeded, + LocalRayletDiedError, + TaskUnschedulableError, + ActorDiedError, + ActorUnschedulableError, + ActorUnavailableError, + RayChannelError, + RayChannelTimeoutError, + OufOfBandObjectRefSerializationException, + RayCgraphCapacityExceeded, +] diff --git a/.venv/lib/python3.11/site-packages/ray/job_config.py b/.venv/lib/python3.11/site-packages/ray/job_config.py new file mode 100644 index 0000000000000000000000000000000000000000..e4ca00b39fecf3908f0c7fe394080e8f0f3868b8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/job_config.py @@ -0,0 +1,249 @@ +import uuid +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +import ray.cloudpickle as pickle +from ray._private.ray_logging.logging_config import LoggingConfig +from ray.util.annotations import PublicAPI + +if TYPE_CHECKING: + from ray.runtime_env import RuntimeEnv + + +@PublicAPI +class JobConfig: + """A class used to store the configurations of a job. + + Examples: + .. testcode:: + :hide: + + import ray + ray.shutdown() + + .. 
testcode:: + + import ray + from ray.job_config import JobConfig + + ray.init(job_config=JobConfig(default_actor_lifetime="non_detached")) + + Args: + jvm_options: The jvm options for java workers of the job. + code_search_path: A list of directories or jar files that + specify the search path for user code. This will be used as + `CLASSPATH` in Java and `PYTHONPATH` in Python. + See :ref:`Ray cross-language programming ` for more details. + runtime_env: A :ref:`runtime environment ` dictionary. + metadata: An opaque metadata dictionary. + ray_namespace: A :ref:`namespace ` + is a logical grouping of jobs and named actors. + default_actor_lifetime: The default value of actor lifetime, + can be "detached" or "non_detached". + See :ref:`actor lifetimes ` for more details. + """ + + def __init__( + self, + jvm_options: Optional[List[str]] = None, + code_search_path: Optional[List[str]] = None, + runtime_env: Optional[dict] = None, + _client_job: bool = False, + metadata: Optional[dict] = None, + ray_namespace: Optional[str] = None, + default_actor_lifetime: str = "non_detached", + _py_driver_sys_path: Optional[List[str]] = None, + ): + #: The jvm options for java workers of the job. + self.jvm_options = jvm_options or [] + #: A list of directories or jar files that + #: specify the search path for user code. + self.code_search_path = code_search_path or [] + # It's difficult to find the error that caused by the + # code_search_path is a string. So we assert here. + assert isinstance(self.code_search_path, (list, tuple)), ( + f"The type of code search path is incorrect: " f"{type(code_search_path)}" + ) + self._client_job = _client_job + #: An opaque metadata dictionary. + self.metadata = metadata or {} + #: A namespace is a logical grouping of jobs and named actors. + self.ray_namespace = ray_namespace + self.set_runtime_env(runtime_env) + self.set_default_actor_lifetime(default_actor_lifetime) + # A list of directories that specify the search path for python workers. 
+ self._py_driver_sys_path = _py_driver_sys_path or [] + # Python logging configurations that will be passed to Ray tasks/actors. + self.py_logging_config = None + + def set_metadata(self, key: str, value: str) -> None: + """Add key-value pair to the metadata dictionary. + + If the key already exists, the value is overwritten to the new value. + + Examples: + .. testcode:: + + import ray + from ray.job_config import JobConfig + + job_config = JobConfig() + job_config.set_metadata("submitter", "foo") + + Args: + key: The key of the metadata. + value: The value of the metadata. + """ + self.metadata[key] = value + + def _serialize(self) -> str: + """Serialize the struct into protobuf string""" + return self._get_proto_job_config().SerializeToString() + + def set_runtime_env( + self, + runtime_env: Optional[Union[Dict[str, Any], "RuntimeEnv"]], + validate: bool = False, + ) -> None: + """Modify the runtime_env of the JobConfig. + + We don't validate the runtime_env by default here because it may go + through some translation before actually being passed to C++ (e.g., + working_dir translated from a local directory to a URI). + + Args: + runtime_env: A :ref:`runtime environment ` dictionary. + validate: Whether to validate the runtime env. + """ + self.runtime_env = runtime_env if runtime_env is not None else {} + if validate: + self.runtime_env = self._validate_runtime_env() + self._cached_pb = None + + def set_py_logging_config( + self, + logging_config: Optional[LoggingConfig] = None, + ): + """Set the logging configuration for the job. + + The logging configuration will be applied to the root loggers of + all Ray task and actor processes that belong to this job. + + Args: + logging_config: The logging configuration to set. + """ + self.py_logging_config = logging_config + + def set_ray_namespace(self, ray_namespace: str) -> None: + """Set Ray :ref:`namespace `. + + Args: + ray_namespace: The namespace to set. 
+ """ + + if ray_namespace != self.ray_namespace: + self.ray_namespace = ray_namespace + self._cached_pb = None + + def set_default_actor_lifetime(self, default_actor_lifetime: str) -> None: + """Set the default actor lifetime, which can be "detached" or "non_detached". + + See :ref:`actor lifetimes ` for more details. + + Args: + default_actor_lifetime: The default actor lifetime to set. + """ + import ray.core.generated.common_pb2 as common_pb2 + + if default_actor_lifetime == "detached": + self._default_actor_lifetime = common_pb2.JobConfig.ActorLifetime.DETACHED + elif default_actor_lifetime == "non_detached": + self._default_actor_lifetime = ( + common_pb2.JobConfig.ActorLifetime.NON_DETACHED + ) + else: + raise ValueError( + "Default actor lifetime must be one of `detached`, `non_detached`" + ) + + def _validate_runtime_env(self): + # TODO(edoakes): this is really unfortunate, but JobConfig is imported + # all over the place so this causes circular imports. We should remove + # this dependency and pass in a validated runtime_env instead. + from ray.runtime_env import RuntimeEnv + + if isinstance(self.runtime_env, RuntimeEnv): + return self.runtime_env + return RuntimeEnv(**self.runtime_env) + + def _get_proto_job_config(self): + """Return the protobuf structure of JobConfig.""" + # TODO(edoakes): this is really unfortunate, but JobConfig is imported + # all over the place so this causes circular imports. We should remove + # this dependency and pass in a validated runtime_env instead. 
+ import ray.core.generated.common_pb2 as common_pb2 + from ray._private.utils import get_runtime_env_info + + if self._cached_pb is None: + pb = common_pb2.JobConfig() + if self.ray_namespace is None: + pb.ray_namespace = str(uuid.uuid4()) + else: + pb.ray_namespace = self.ray_namespace + pb.jvm_options.extend(self.jvm_options) + pb.code_search_path.extend(self.code_search_path) + pb.py_driver_sys_path.extend(self._py_driver_sys_path) + for k, v in self.metadata.items(): + pb.metadata[k] = v + + parsed_env = self._validate_runtime_env() + pb.runtime_env_info.CopyFrom( + get_runtime_env_info( + parsed_env, + is_job_runtime_env=True, + serialize=False, + ) + ) + + if self._default_actor_lifetime is not None: + pb.default_actor_lifetime = self._default_actor_lifetime + if self.py_logging_config: + pb.serialized_py_logging_config = pickle.dumps(self.py_logging_config) + self._cached_pb = pb + + return self._cached_pb + + def _runtime_env_has_working_dir(self): + return self._validate_runtime_env().has_working_dir() + + def _get_serialized_runtime_env(self) -> str: + """Return the JSON-serialized parsed runtime env dict""" + return self._validate_runtime_env().serialize() + + def _get_proto_runtime_env_config(self) -> str: + """Return the JSON-serialized parsed runtime env info""" + return self._get_proto_job_config().runtime_env_info.runtime_env_config + + @classmethod + def from_json(cls, job_config_json): + """Generates a JobConfig object from json. + + Examples: + .. testcode:: + + from ray.job_config import JobConfig + + job_config = JobConfig.from_json( + {"runtime_env": {"working_dir": "uri://abc"}}) + + Args: + job_config_json: The job config json dictionary. 
+ """ + return cls( + jvm_options=job_config_json.get("jvm_options", None), + code_search_path=job_config_json.get("code_search_path", None), + runtime_env=job_config_json.get("runtime_env", None), + metadata=job_config_json.get("metadata", None), + ray_namespace=job_config_json.get("ray_namespace", None), + _client_job=job_config_json.get("client_job", False), + _py_driver_sys_path=job_config_json.get("py_driver_sys_path", None), + ) diff --git a/.venv/lib/python3.11/site-packages/ray/nightly-wheels.yaml b/.venv/lib/python3.11/site-packages/ray/nightly-wheels.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1648a1d01400826ac652f1cb6fb7cf9a38c6aa7a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/nightly-wheels.yaml @@ -0,0 +1,11 @@ +linux: + "3.8": https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl + "3.7": https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp37-cp37m-manylinux2014_x86_64.whl + +darwin: + "3.8": https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp38-cp38-macosx_10_15_x86_64.whl + "3.7": https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp37-cp37m-macosx_10_15_intel.whl + +win32: + "3.8": https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp38-cp38-win_amd64.whl + "3.7": https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp37-cp37m-win_amd64.whl diff --git a/.venv/lib/python3.11/site-packages/ray/py.typed b/.venv/lib/python3.11/site-packages/ray/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/remote_function.py b/.venv/lib/python3.11/site-packages/ray/remote_function.py new file mode 100644 index 0000000000000000000000000000000000000000..b44eae3d84ce6480868591c4b3e5315dd92629da --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/remote_function.py @@ 
-0,0 +1,515 @@ +import inspect +import logging +import os +import uuid +from functools import wraps +from threading import Lock +from typing import Optional + +import ray._private.signature +from ray import Language, cross_language +from ray._private import ray_option_utils +from ray._private.auto_init_hook import wrap_auto_init +from ray._private.client_mode_hook import ( + client_mode_convert_function, + client_mode_should_convert, +) +from ray._private.ray_option_utils import _warn_if_using_deprecated_placement_group +from ray._private.serialization import pickle_dumps +from ray._private.utils import get_runtime_env_info, parse_runtime_env +from ray._raylet import ( + STREAMING_GENERATOR_RETURN, + ObjectRefGenerator, + PythonFunctionDescriptor, +) +from ray.util.annotations import DeveloperAPI, PublicAPI +from ray.util.placement_group import _configure_placement_group_based_on_context +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from ray.util.tracing.tracing_helper import ( + _inject_tracing_into_function, + _tracing_task_invocation, +) + +logger = logging.getLogger(__name__) + + +# Hook to call with (fn, resources, strategy) on each local task submission. +_task_launch_hook = None + + +@PublicAPI +class RemoteFunction: + """A remote function. + + This is a decorated function. It can be used to spawn tasks. + + Attributes: + _language: The target language. + _function: The original function. + _function_descriptor: The function descriptor. This is not defined + until the remote function is first invoked because that is when the + function is pickled, and the pickled function is used to compute + the function descriptor. + _function_name: The module and function name. + _num_cpus: The default number of CPUs to use for invocations of this + remote function. + _num_gpus: The default number of GPUs to use for invocations of this + remote function. 
+ _memory: The heap memory request in bytes for this task/actor, + rounded down to the nearest integer. + _resources: The default custom resource requirements for invocations of + this remote function. + _num_returns: The default number of return values for invocations + of this remote function. + _max_calls: The number of times a worker can execute this function + before exiting. + _max_retries: The number of times this task may be retried + on worker failure. + _retry_exceptions: Whether application-level errors should be retried. + This can be a boolean or a list/tuple of exceptions that should be retried. + _runtime_env: The runtime environment for this task. + _decorator: An optional decorator that should be applied to the remote + function invocation (as opposed to the function execution) before + invoking the function. The decorator must return a function that + takes in two arguments ("args" and "kwargs"). In most cases, it + should call the function that was passed into the decorator and + return the resulting ObjectRefs. For an example, see + "test_decorated_function" in "python/ray/tests/test_basic.py". + _function_signature: The function signature. + _last_export_cluster_and_job: A pair of the last exported cluster + and job to help us to know whether this function was exported. + This is an imperfect mechanism used to determine if we need to + export the remote function again. It is imperfect in the sense that + the actor class definition could be exported multiple times by + different workers. + _scheduling_strategy: Strategy about how to schedule + this remote function. + """ + + def __init__( + self, + language, + function, + function_descriptor, + task_options, + ): + if inspect.iscoroutinefunction(function): + raise ValueError( + "'async def' should not be used for remote tasks. You can wrap the " + "async function with `asyncio.run(f())`. 
See more at:" + "https://docs.ray.io/en/latest/ray-core/actors/async_api.html " + ) + self._default_options = task_options + + # When gpu is used, set the task non-recyclable by default. + # https://github.com/ray-project/ray/issues/29624 for more context. + # Note: Ray task worker process is not being reused when nsight + # profiler is running, as nsight generate report once the process exit. + num_gpus = self._default_options.get("num_gpus") or 0 + if ( + num_gpus > 0 and self._default_options.get("max_calls", None) is None + ) or "nsight" in (self._default_options.get("runtime_env") or {}): + self._default_options["max_calls"] = 1 + + # TODO(suquark): This is a workaround for class attributes of options. + # They are being used in some other places, mostly tests. Need cleanup later. + # E.g., actors uses "__ray_metadata__" to collect options, we can so something + # similar for remote functions. + for k, v in ray_option_utils.task_options.items(): + setattr(self, "_" + k, task_options.get(k, v.default_value)) + self._runtime_env = parse_runtime_env(self._runtime_env) + if "runtime_env" in self._default_options: + self._default_options["runtime_env"] = self._runtime_env + + # Pre-calculate runtime env info, to avoid re-calculation at `remote` + # invocation. When `remote` call has specified extra `option` field, + # runtime env will be overwritten and re-serialized. + # + # Caveat: To support dynamic runtime envs in + # `func.option(runtime_env={...}).remote()`, we recalculate the serialized + # runtime env info in the `option` call. But it's acceptable since + # pre-calculation here only happens once at `RemoteFunction` initialization. 
+ self._serialized_base_runtime_env_info = "" + if self._runtime_env: + self._serialized_base_runtime_env_info = get_runtime_env_info( + self._runtime_env, + is_job_runtime_env=False, + serialize=True, + ) + + self._language = language + self._is_generator = inspect.isgeneratorfunction(function) + self._function = function + self._function_signature = None + # Guards trace injection to enforce exactly once semantics + self._inject_lock = Lock() + self._function_name = function.__module__ + "." + function.__name__ + self._function_descriptor = function_descriptor + self._is_cross_language = language != Language.PYTHON + self._decorator = getattr(function, "__ray_invocation_decorator__", None) + self._last_export_cluster_and_job = None + self._uuid = uuid.uuid4() + + # Override task.remote's signature and docstring + @wraps(function) + def _remote_proxy(*args, **kwargs): + return self._remote( + serialized_runtime_env_info=self._serialized_base_runtime_env_info, + args=args, + kwargs=kwargs, + **self._default_options, + ) + + self.remote = _remote_proxy + + def __call__(self, *args, **kwargs): + raise TypeError( + "Remote functions cannot be called directly. Instead " + f"of running '{self._function_name}()', " + f"try '{self._function_name}.remote()'." + ) + + # Lock is not picklable + def __getstate__(self): + attrs = self.__dict__.copy() + del attrs["_inject_lock"] + return attrs + + def __setstate__(self, state): + self.__dict__.update(state) + self.__dict__["_inject_lock"] = Lock() + + def options(self, **task_options): + """Configures and overrides the task invocation parameters. + + The arguments are the same as those that can be passed to :obj:`ray.remote`. + Overriding `max_calls` is not supported. + + Args: + num_returns: It specifies the number of object refs returned by + the remote function invocation. + num_cpus: The quantity of CPU cores to reserve + for this task or for the lifetime of the actor. 
+ num_gpus: The quantity of GPUs to reserve + for this task or for the lifetime of the actor. + resources (Dict[str, float]): The quantity of various custom resources + to reserve for this task or for the lifetime of the actor. + This is a dictionary mapping strings (resource names) to floats. + accelerator_type: If specified, requires that the task or actor run + on a node with the specified type of accelerator. + See :ref:`accelerator types `. + memory: The heap memory request in bytes for this task/actor, + rounded down to the nearest integer. + object_store_memory: The object store memory request for actors only. + max_calls: This specifies the + maximum number of times that a given worker can execute + the given remote function before it must exit + (this can be used to address memory leaks in third-party + libraries or to reclaim resources that cannot easily be + released, e.g., GPU memory that was acquired by TensorFlow). + By default this is infinite for CPU tasks and 1 for GPU tasks + (to force GPU tasks to release resources after finishing). + max_retries: This specifies the maximum number of times that the remote + function should be rerun when the worker process executing it + crashes unexpectedly. The minimum valid value is 0, + the default is 3 (default), and a value of -1 indicates + infinite retries. + runtime_env (Dict[str, Any]): Specifies the runtime environment for + this actor or task and its children. See + :ref:`runtime-environments` for detailed documentation. + retry_exceptions: This specifies whether application-level errors + should be retried up to max_retries times. + scheduling_strategy: Strategy about how to + schedule a remote function or actor. 
Possible values are + None: ray will figure out the scheduling strategy to use, it + will either be the PlacementGroupSchedulingStrategy using parent's + placement group if parent has one and has + placement_group_capture_child_tasks set to true, + or "DEFAULT"; + "DEFAULT": default hybrid scheduling; + "SPREAD": best effort spread scheduling; + `PlacementGroupSchedulingStrategy`: + placement group based scheduling; + `NodeAffinitySchedulingStrategy`: + node id based affinity scheduling. + enable_task_events: This specifies whether to enable task events for this + task. If set to True, task events such as (task running, finished) + are emitted, and available to Ray Dashboard and State API. + See :ref:`state-api-overview-ref` for more details. + _metadata: Extended options for Ray libraries. For example, + _metadata={"workflows.io/options": } for + Ray workflows. + _labels: The key-value labels of a task. + + Examples: + + .. code-block:: python + + @ray.remote(num_gpus=1, max_calls=1, num_returns=2) + def f(): + return 1, 2 + # Task g will require 2 gpus instead of 1. + g = f.options(num_gpus=2) + """ + + func_cls = self + + # override original options + default_options = self._default_options.copy() + # max_calls could not be used in ".options()", we should remove it before + # merging options from '@ray.remote'. + default_options.pop("max_calls", None) + updated_options = ray_option_utils.update_options(default_options, task_options) + ray_option_utils.validate_task_options(updated_options, in_options=True) + + # Only update runtime_env and re-calculate serialized runtime env info when + # ".options()" specifies new runtime_env. + serialized_runtime_env_info = self._serialized_base_runtime_env_info + if "runtime_env" in task_options: + updated_options["runtime_env"] = parse_runtime_env( + updated_options["runtime_env"] + ) + # Re-calculate runtime env info based on updated runtime env. 
+ if updated_options["runtime_env"]: + serialized_runtime_env_info = get_runtime_env_info( + updated_options["runtime_env"], + is_job_runtime_env=False, + serialize=True, + ) + + class FuncWrapper: + def remote(self, *args, **kwargs): + return func_cls._remote( + args=args, + kwargs=kwargs, + serialized_runtime_env_info=serialized_runtime_env_info, + **updated_options, + ) + + @DeveloperAPI + def bind(self, *args, **kwargs): + """ + For Ray DAG building that creates static graph from decorated + class or functions. + """ + from ray.dag.function_node import FunctionNode + + return FunctionNode(func_cls._function, args, kwargs, updated_options) + + return FuncWrapper() + + @wrap_auto_init + @_tracing_task_invocation + def _remote( + self, + args=None, + kwargs=None, + serialized_runtime_env_info: Optional[str] = None, + **task_options, + ): + """Submit the remote function for execution.""" + # We pop the "max_calls" coming from "@ray.remote" here. We no longer need + # it in "_remote()". + task_options.pop("max_calls", None) + if client_mode_should_convert(): + return client_mode_convert_function(self, args, kwargs, **task_options) + + worker = ray._private.worker.global_worker + worker.check_connected() + + # We cannot do this when the function is first defined, because we need + # ray.init() to have been called when this executes + with self._inject_lock: + if self._function_signature is None: + self._function = _inject_tracing_into_function(self._function) + self._function_signature = ray._private.signature.extract_signature( + self._function + ) + + # If this function was not exported in this cluster and job, we need to + # export this function again, because the current GCS doesn't have it. + if ( + not self._is_cross_language + and self._last_export_cluster_and_job != worker.current_cluster_and_job + ): + self._function_descriptor = PythonFunctionDescriptor.from_function( + self._function, self._uuid + ) + # There is an interesting question here. 
If the remote function is + # used by a subsequent driver (in the same script), should the + # second driver pickle the function again? If yes, then the remote + # function definition can differ in the second driver (e.g., if + # variables in its closure have changed). We probably want the + # behavior of the remote function in the second driver to be + # independent of whether or not the function was invoked by the + # first driver. This is an argument for repickling the function, + # which we do here. + self._pickled_function = pickle_dumps( + self._function, + f"Could not serialize the function {self._function_descriptor.repr}", + ) + + self._last_export_cluster_and_job = worker.current_cluster_and_job + worker.function_actor_manager.export(self) + + kwargs = {} if kwargs is None else kwargs + args = [] if args is None else args + + # fill task required options + for k, v in ray_option_utils.task_options.items(): + if k == "max_retries": + # TODO(swang): We need to override max_retries here because the default + # value gets set at Ray import time. Ideally, we should allow setting + # default values from env vars for other options too. + v.default_value = os.environ.get( + "RAY_TASK_MAX_RETRIES", v.default_value + ) + v.default_value = int(v.default_value) + task_options[k] = task_options.get(k, v.default_value) + # "max_calls" already takes effects and should not apply again. + # Remove the default value here. 
+ task_options.pop("max_calls", None) + + # TODO(suquark): cleanup these fields + name = task_options["name"] + placement_group = task_options["placement_group"] + placement_group_bundle_index = task_options["placement_group_bundle_index"] + placement_group_capture_child_tasks = task_options[ + "placement_group_capture_child_tasks" + ] + scheduling_strategy = task_options["scheduling_strategy"] + + num_returns = task_options["num_returns"] + if num_returns is None: + if self._is_generator: + num_returns = "streaming" + else: + num_returns = 1 + + if num_returns == "dynamic": + num_returns = -1 + elif num_returns == "streaming": + # TODO(sang): This is a temporary private API. + # Remove it when we migrate to the streaming generator. + num_returns = ray._raylet.STREAMING_GENERATOR_RETURN + generator_backpressure_num_objects = task_options[ + "_generator_backpressure_num_objects" + ] + if generator_backpressure_num_objects is None: + generator_backpressure_num_objects = -1 + + max_retries = task_options["max_retries"] + retry_exceptions = task_options["retry_exceptions"] + if isinstance(retry_exceptions, (list, tuple)): + retry_exception_allowlist = tuple(retry_exceptions) + retry_exceptions = True + else: + retry_exception_allowlist = None + + if scheduling_strategy is None or not isinstance( + scheduling_strategy, PlacementGroupSchedulingStrategy + ): + _warn_if_using_deprecated_placement_group(task_options, 4) + + resources = ray._private.utils.resources_from_ray_options(task_options) + + if scheduling_strategy is None or isinstance( + scheduling_strategy, PlacementGroupSchedulingStrategy + ): + if isinstance(scheduling_strategy, PlacementGroupSchedulingStrategy): + placement_group = scheduling_strategy.placement_group + placement_group_bundle_index = ( + scheduling_strategy.placement_group_bundle_index + ) + placement_group_capture_child_tasks = ( + scheduling_strategy.placement_group_capture_child_tasks + ) + + if placement_group_capture_child_tasks is None: + 
placement_group_capture_child_tasks = ( + worker.should_capture_child_tasks_in_placement_group + ) + placement_group = _configure_placement_group_based_on_context( + placement_group_capture_child_tasks, + placement_group_bundle_index, + resources, + {}, # no placement_resources for tasks + self._function_descriptor.function_name, + placement_group=placement_group, + ) + if not placement_group.is_empty: + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group, + placement_group_bundle_index, + placement_group_capture_child_tasks, + ) + else: + scheduling_strategy = "DEFAULT" + + if _task_launch_hook: + _task_launch_hook(self._function_descriptor, resources, scheduling_strategy) + + # Override enable_task_events to default for actor if not specified (i.e. None) + enable_task_events = task_options.get("enable_task_events") + labels = task_options.get("_labels") + + def invocation(args, kwargs): + if self._is_cross_language: + list_args = cross_language._format_args(worker, args, kwargs) + elif not args and not kwargs and not self._function_signature: + list_args = [] + else: + list_args = ray._private.signature.flatten_args( + self._function_signature, args, kwargs + ) + + if worker.mode == ray._private.worker.LOCAL_MODE: + assert ( + not self._is_cross_language + ), "Cross language remote function cannot be executed locally." + object_refs = worker.core_worker.submit_task( + self._language, + self._function_descriptor, + list_args, + name if name is not None else "", + num_returns, + resources, + max_retries, + retry_exceptions, + retry_exception_allowlist, + scheduling_strategy, + worker.debugger_breakpoint, + serialized_runtime_env_info or "{}", + generator_backpressure_num_objects, + enable_task_events, + labels, + ) + # Reset worker's debug context from the last "remote" command + # (which applies only to this .remote call). 
+ worker.debugger_breakpoint = b"" + if num_returns == STREAMING_GENERATOR_RETURN: + # Streaming generator will return a single ref + # that is for the generator task. + assert len(object_refs) == 1 + generator_ref = object_refs[0] + return ObjectRefGenerator(generator_ref, worker) + if len(object_refs) == 1: + return object_refs[0] + elif len(object_refs) > 1: + return object_refs + + if self._decorator is not None: + invocation = self._decorator(invocation) + + return invocation(args, kwargs) + + @DeveloperAPI + def bind(self, *args, **kwargs): + """ + For Ray DAG building that creates static graph from decorated + class or functions. + """ + + from ray.dag.function_node import FunctionNode + + return FunctionNode(self._function, args, kwargs, self._default_options) diff --git a/.venv/lib/python3.11/site-packages/ray/runtime_context.py b/.venv/lib/python3.11/site-packages/ray/runtime_context.py new file mode 100644 index 0000000000000000000000000000000000000000..5cacae69371b7dc3ed194a8cedee52b06bdfeac1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/runtime_context.py @@ -0,0 +1,564 @@ +import logging +from typing import Any, Dict, List, Optional + +import ray._private.worker +from ray._private.client_mode_hook import client_mode_hook +from ray._private.utils import parse_pg_formatted_resources_to_original +from ray._raylet import TaskID +from ray.runtime_env import RuntimeEnv +from ray.util.annotations import Deprecated, PublicAPI + +logger = logging.getLogger(__name__) + + +@PublicAPI +class RuntimeContext(object): + """A class used for getting runtime context.""" + + def __init__(self, worker): + assert worker is not None + self.worker = worker + + @Deprecated( + message="Use get_xxx_id() methods to get relevant ids instead", warning=True + ) + def get(self) -> Dict[str, Any]: + """Get a dictionary of the current context. + + Returns: + dict: Dictionary of the current context. 
+ """ + context = { + "job_id": self.job_id, + "node_id": self.node_id, + "namespace": self.namespace, + } + if self.worker.mode == ray._private.worker.WORKER_MODE: + if self.task_id is not None: + context["task_id"] = self.task_id + if self.actor_id is not None: + context["actor_id"] = self.actor_id + + return context + + @property + @Deprecated(message="Use get_job_id() instead", warning=True) + def job_id(self): + """Get current job ID for this worker or driver. + + Job ID is the id of your Ray drivers that create tasks or actors. + + Returns: + If called by a driver, this returns the job ID. If called in + a task, return the job ID of the associated driver. + + """ + job_id = self.worker.current_job_id + assert not job_id.is_nil() + return job_id + + def get_job_id(self) -> str: + """Get current job ID for this worker or driver. + + Job ID is the id of your Ray drivers that create tasks or actors. + + Returns: + If called by a driver, this returns the job ID. If called in + a task, return the job ID of the associated driver. The + job ID will be hex format. + + Raises: + AssertionError: If not called in a driver or worker. Generally, + this means that ray.init() was not called. + """ + assert ray.is_initialized(), ( + "Job ID is not available because " "Ray has not been initialized." + ) + job_id = self.worker.current_job_id + return job_id.hex() + + @property + @Deprecated(message="Use get_node_id() instead", warning=True) + def node_id(self): + """Get current node ID for this worker or driver. + + Node ID is the id of a node that your driver, task, or actor runs. + + Returns: + A node id for this worker or driver. + """ + node_id = self.worker.current_node_id + assert not node_id.is_nil() + return node_id + + def get_node_id(self) -> str: + """Get current node ID for this worker or driver. + + Node ID is the id of a node that your driver, task, or actor runs. + The ID will be in hex format. + + Returns: + A node id in hex format for this worker or driver. 
+ + Raises: + AssertionError: If not called in a driver or worker. Generally, + this means that ray.init() was not called. + """ + assert ray.is_initialized(), ( + "Node ID is not available because " "Ray has not been initialized." + ) + node_id = self.worker.current_node_id + return node_id.hex() + + def get_worker_id(self) -> str: + """Get current worker ID for this worker or driver process. + + Returns: + A worker id in hex format for this worker or driver process. + """ + assert ( + ray.is_initialized() + ), "Worker ID is not available because Ray has not been initialized." + return self.worker.worker_id.hex() + + @property + @Deprecated(message="Use get_task_id() instead", warning=True) + def task_id(self): + """Get current task ID for this worker. + + Task ID is the id of a Ray task. + This shouldn't be used in a driver process. + + Example: + + .. testcode:: + + import ray + + @ray.remote + class Actor: + def ready(self): + return True + + @ray.remote + def f(): + return True + + # All the below code generates different task ids. + # Task ids are available for actor creation. + a = Actor.remote() + # Task ids are available for actor tasks. + a.ready.remote() + # Task ids are available for normal tasks. + f.remote() + + Returns: + The current worker's task id. None if there's no task id. + """ + # only worker mode has task_id + assert ( + self.worker.mode == ray._private.worker.WORKER_MODE + ), f"This method is only available when the process is a\ + worker. Current mode: {self.worker.mode}" + + task_id = self._get_current_task_id() + return task_id if not task_id.is_nil() else None + + def get_task_id(self) -> Optional[str]: + """Get current task ID for this worker. + + Task ID is the id of a Ray task. The ID will be in hex format. + This shouldn't be used in a driver process. + + Example: + + .. 
testcode:: + + import ray + + @ray.remote + class Actor: + def get_task_id(self): + return ray.get_runtime_context().get_task_id() + + @ray.remote + def get_task_id(): + return ray.get_runtime_context().get_task_id() + + # All the below code generates different task ids. + a = Actor.remote() + # Task ids are available for actor tasks. + print(ray.get(a.get_task_id.remote())) + # Task ids are available for normal tasks. + print(ray.get(get_task_id.remote())) + + .. testoutput:: + :options: +MOCK + + 16310a0f0a45af5c2746a0e6efb235c0962896a201000000 + c2668a65bda616c1ffffffffffffffffffffffff01000000 + + Returns: + The current worker's task id in hex. None if there's no task id. + """ + # only worker mode has task_id + if self.worker.mode != ray._private.worker.WORKER_MODE: + logger.warning( + "This method is only available when the process is a " + f"worker. Current mode: {self.worker.mode}" + ) + return None + task_id = self._get_current_task_id() + return task_id.hex() if not task_id.is_nil() else None + + def _get_current_task_id(self) -> TaskID: + return self.worker.current_task_id + + def get_task_name(self) -> Optional[str]: + """Get current task name for this worker. + + Task name by default is the task's funciton call string. It can also be + specified in options when triggering a task. + + Example: + + .. testcode:: + + import ray + + @ray.remote + class Actor: + def get_task_name(self): + return ray.get_runtime_context().get_task_name() + + @ray.remote + class AsyncActor: + async def get_task_name(self): + return ray.get_runtime_context().get_task_name() + + @ray.remote + def get_task_name(): + return ray.get_runtime_context().get_task_name() + + a = Actor.remote() + b = AsyncActor.remote() + # Task names are available for actor tasks. + print(ray.get(a.get_task_name.remote())) + # Task names are avaiable for async actor tasks. + print(ray.get(b.get_task_name.remote())) + # Task names are available for normal tasks. 
+ # Get default task name + print(ray.get(get_task_name.remote())) + # Get specified task name + print(ray.get(get_task_name.options(name="task_name").remote())) + + .. testoutput:: + :options: +MOCK + + Actor.get_task_name + AsyncActor.get_task_name + get_task_name + task_nams + + Returns: + The current worker's task name + """ + # only worker mode has task_name + if self.worker.mode != ray._private.worker.WORKER_MODE: + logger.warning( + "This method is only available when the process is a " + f"worker. Current mode: {self.worker.mode}" + ) + return None + return self.worker.current_task_name + + def get_task_function_name(self) -> Optional[str]: + """Get current task function name string for this worker. + + Example: + + .. testcode:: + + import ray + + @ray.remote + class Actor: + def get_task_function_name(self): + return ray.get_runtime_context().get_task_function_name() + + @ray.remote + class AsyncActor: + async def get_task_function_name(self): + return ray.get_runtime_context().get_task_function_name() + + @ray.remote + def get_task_function_name(): + return ray.get_runtime_context().get_task_function_name() + + a = Actor.remote() + b = AsyncActor.remote() + # Task functions are available for actor tasks. + print(ray.get(a.get_task_function_name.remote())) + # Task functions are available for async actor tasks. + print(ray.get(b.get_task_function_name.remote())) + # Task functions are available for normal tasks. + print(ray.get(get_task_function_name.remote())) + + .. testoutput:: + :options: +MOCK + + [python modual name].Actor.get_task_function_name + [python modual name].AsyncActor.get_task_function_name + [python modual name].get_task_function_name + + Returns: + The current worker's task function call string + """ + # only worker mode has task_function_name + if self.worker.mode != ray._private.worker.WORKER_MODE: + logger.warning( + "This method is only available when the process is a " + f"worker. 
Current mode: {self.worker.mode}" + ) + return None + return self.worker.current_task_function_name + + @property + @Deprecated(message="Use get_actor_id() instead", warning=True) + def actor_id(self): + """Get the current actor ID in this worker. + + ID of the actor of the current process. + This shouldn't be used in a driver process. + + Returns: + The current actor id in this worker. None if there's no actor id. + """ + # only worker mode has actor_id + assert ( + self.worker.mode == ray._private.worker.WORKER_MODE + ), f"This method is only available when the process is a\ + worker. Current mode: {self.worker.mode}" + actor_id = self.worker.actor_id + return actor_id if not actor_id.is_nil() else None + + def get_actor_id(self) -> Optional[str]: + """Get the current actor ID in this worker. + + ID of the actor of the current process. + This shouldn't be used in a driver process. + The ID will be in hex format. + + Returns: + The current actor id in hex format in this worker. None if there's no + actor id. + """ + # only worker mode has actor_id + if self.worker.mode != ray._private.worker.WORKER_MODE: + logger.debug( + "This method is only available when the process is a " + f"worker. Current mode: {self.worker.mode}" + ) + return None + actor_id = self.worker.actor_id + return actor_id.hex() if not actor_id.is_nil() else None + + def get_actor_name(self) -> Optional[str]: + """Get the current actor name of this worker. + + This shouldn't be used in a driver process. + The name is in string format. + + Returns: + The current actor name of this worker. + If a current worker is an actor, and + if actor name doesn't exist, it returns an empty string. + If a current worker is not an actor, it returns None. + """ + # only worker mode has actor_id + if self.worker.mode != ray._private.worker.WORKER_MODE: + logger.warning( + "This method is only available when the process is a " + f"worker. 
Current mode: {self.worker.mode}" + ) + return None + actor_id = self.worker.actor_id + return self.worker.actor_name if not actor_id.is_nil() else None + + @property + def namespace(self): + """Get the current namespace of this worker. + + Returns: + The current namespace of this worker. + """ + return self.worker.namespace + + @property + def was_current_actor_reconstructed(self): + """Check whether this actor has been restarted. + + Returns: + Whether this actor has been ever restarted. + """ + assert ( + not self.actor_id.is_nil() + ), "This method should't be called inside Ray tasks." + actor_info = ray._private.state.actors(self.actor_id.hex()) + return actor_info and actor_info["NumRestarts"] != 0 + + @property + @Deprecated(message="Use get_placement_group_id() instead", warning=True) + def current_placement_group_id(self): + """Get the current Placement group ID of this worker. + + Returns: + The current placement group id of this worker. + """ + return self.worker.placement_group_id + + def get_placement_group_id(self) -> Optional[str]: + """Get the current Placement group ID of this worker. + + Returns: + The current placement group id in hex format of this worker. + """ + pg_id = self.worker.placement_group_id + return pg_id.hex() if not pg_id.is_nil() else None + + @property + def should_capture_child_tasks_in_placement_group(self): + """Get if the current task should capture parent's placement group. + + This returns True if it is called inside a driver. + + Returns: + Return True if the current task should implicitly + capture the parent placement group. + """ + return self.worker.should_capture_child_tasks_in_placement_group + + def get_assigned_resources(self): + """Get the assigned resources to this worker. + + By default for tasks, this will return {"CPU": 1}. + By default for actors, this will return {}. This is because + actors do not have CPUs assigned to them by default. 
+ + Returns: + A dictionary mapping the name of a resource to a float, where + the float represents the amount of that resource reserved + for this worker. + """ + assert ( + self.worker.mode == ray._private.worker.WORKER_MODE + ), f"This method is only available when the process is a\ + worker. Current mode: {self.worker.mode}" + self.worker.check_connected() + resource_id_map = self.worker.core_worker.resource_ids() + resource_map = { + res: sum(amt for _, amt in mapping) + for res, mapping in resource_id_map.items() + } + result = parse_pg_formatted_resources_to_original(resource_map) + return result + + def get_runtime_env_string(self): + """Get the runtime env string used for the current driver or worker. + + Returns: + The runtime env string currently using by this worker. + """ + return self.worker.runtime_env + + @property + def runtime_env(self): + """Get the runtime env used for the current driver or worker. + + Returns: + The runtime env currently using by this worker. The type of + return value is ray.runtime_env.RuntimeEnv. + """ + + return RuntimeEnv.deserialize(self.get_runtime_env_string()) + + @property + def current_actor(self): + """Get the current actor handle of this actor itself. + + Returns: + The handle of current actor. + """ + worker = self.worker + worker.check_connected() + actor_id = worker.actor_id + if actor_id.is_nil(): + raise RuntimeError("This method is only available in an actor.") + + return worker.core_worker.get_actor_handle(actor_id) + + @property + def gcs_address(self): + """Get the GCS address of the ray cluster. + Returns: + The GCS address of the cluster. + """ + self.worker.check_connected() + return self.worker.gcs_client.address + + @Deprecated(message="Use get_accelerator_ids() instead", warning=True) + def get_resource_ids(self) -> Dict[str, List[str]]: + return self.get_accelerator_ids() + + def get_accelerator_ids(self) -> Dict[str, List[str]]: + """ + Get the current worker's visible accelerator ids. 
+ + Returns: + A dictionary keyed by the accelerator resource name. The values are a list + of ids `{'GPU': ['0', '1'], 'neuron_cores': ['0', '1'], + 'TPU': ['0', '1']}`. + """ + worker = self.worker + worker.check_connected() + ids_dict: Dict[str, List[str]] = {} + for ( + accelerator_resource_name + ) in ray._private.accelerators.get_all_accelerator_resource_names(): + accelerator_ids = worker.get_accelerator_ids_for_accelerator_resource( + accelerator_resource_name, + f"^{accelerator_resource_name}_group_[0-9A-Za-z]+$", + ) + ids_dict[accelerator_resource_name] = [str(id) for id in accelerator_ids] + return ids_dict + + +_runtime_context = None + + +@PublicAPI +@client_mode_hook +def get_runtime_context() -> RuntimeContext: + """Get the runtime context of the current driver/worker. + + The obtained runtime context can be used to get the metadata + of the current task and actor. + + Example: + + .. testcode:: + + import ray + # Get the job id. + ray.get_runtime_context().get_job_id() + # Get the actor id. + ray.get_runtime_context().get_actor_id() + # Get the task id. + ray.get_runtime_context().get_task_id() + + """ + global _runtime_context + if _runtime_context is None: + _runtime_context = RuntimeContext(ray._private.worker.global_worker) + + return _runtime_context diff --git a/.venv/lib/python3.11/site-packages/ray/setup-dev.py b/.venv/lib/python3.11/site-packages/ray/setup-dev.py new file mode 100644 index 0000000000000000000000000000000000000000..75c7bae79e1a2422e06bbdc03cbab7ddefdce5fb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/setup-dev.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python +# ruff: noqa: E402 +"""This script allows you to develop Ray Python code without needing to compile +Ray. +See https://docs.ray.io/en/master/development.html#building-ray-python-only""" + +import os +import sys + +# types.py can conflict with stdlib's types.py in some python versions, +# see https://github.com/python/cpython/issues/101210. 
+# To avoid import errors, we move the current working dir to the end of sys.path. +this_dir = os.path.dirname(__file__) +if this_dir in sys.path: + cur = sys.path.remove(this_dir) + sys.path.append(this_dir) + +import argparse +import click +import shutil +import subprocess + +import ray + + +def do_link(package, force=False, skip_list=None, local_path=None): + if skip_list and package in skip_list: + print(f"Skip creating symbolic link for {package}") + return + package_home = os.path.abspath(os.path.join(ray.__file__, f"../{package}")) + # Infer local_path automatically. + if local_path is None: + local_path = f"../{package}" + local_home = os.path.abspath(os.path.join(__file__, local_path)) + # If installed package dir does not exist, continue either way. We'll + # remove it/create a link from there anyways. + if not os.path.isdir(package_home) and not os.path.isfile(package_home): + print(f"{package_home} does not exist. Continuing to link.") + # Make sure the path we are linking to does exist. + assert os.path.exists(local_home), local_home + # Confirm with user. + if not force and not click.confirm( + f"This will replace:\n {package_home}\nwith " f"a symlink to:\n {local_home}", + default=True, + ): + return + + # Windows: Create directory junction. + if os.name == "nt": + try: + shutil.rmtree(package_home) + except FileNotFoundError: + pass + except OSError: + os.remove(package_home) + + # create symlink for directory or file + if os.path.isdir(local_home): + subprocess.check_call( + ["mklink", "/J", package_home, local_home], shell=True + ) + elif os.path.isfile(local_home): + subprocess.check_call( + ["mklink", "/H", package_home, local_home], shell=True + ) + else: + print(f"{local_home} is neither directory nor file. Link failed.") + + # Posix: Use `ln -s` to create softlink. 
+ else: + sudo = [] + if not os.access(os.path.dirname(package_home), os.W_OK): + print("You don't have write permission " f"to {package_home}, using sudo:") + sudo = ["sudo"] + print(f"Creating symbolic link from \n {local_home} to \n {package_home}") + + # Preserve ray/serve/generated + if package == "serve": + # Copy generated folder to a temp dir + generated_folder = os.path.join(package_home, "generated") + temp_dir = "/tmp/ray/_serve/" + if not os.path.exists(temp_dir): + os.makedirs(temp_dir) + subprocess.check_call(["cp", "-r", generated_folder, temp_dir]) + + subprocess.check_call(sudo + ["rm", "-rf", package_home]) + subprocess.check_call(sudo + ["ln", "-s", local_home, package_home]) + + # Move generated folder to local_home + if package == "serve": + tmp_generated_folder = os.path.join(temp_dir, "generated") + package_generated_folder = os.path.join(package_home, "generated") + subprocess.check_call( + ["mv", tmp_generated_folder, package_generated_folder] + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.RawDescriptionHelpFormatter, description="Setup dev." + ) + parser.add_argument( + "--yes", "-y", action="store_true", help="Don't ask for confirmation." 
+ ) + parser.add_argument( + "--skip", + "-s", + nargs="*", + help="List of folders to skip linking to facilitate workspace dev", + required=False, + ) + parser.add_argument( + "--extras", + "-e", + nargs="*", + help="List of extra folders to link to facilitate workspace dev", + required=False, + ) + + args = parser.parse_args() + if not args.yes: + print("NOTE: Use '-y' to override all python files without confirmation.") + + do_link("rllib", force=args.yes, skip_list=args.skip, local_path="../../../rllib") + do_link("air", force=args.yes, skip_list=args.skip) + do_link("tune", force=args.yes, skip_list=args.skip) + do_link("train", force=args.yes, skip_list=args.skip) + do_link("autoscaler", force=args.yes, skip_list=args.skip) + do_link("cloudpickle", force=args.yes, skip_list=args.skip) + do_link("data", force=args.yes, skip_list=args.skip) + do_link("scripts", force=args.yes, skip_list=args.skip) + do_link("internal", force=args.yes, skip_list=args.skip) + do_link("tests", force=args.yes, skip_list=args.skip) + do_link("experimental", force=args.yes, skip_list=args.skip) + do_link("util", force=args.yes, skip_list=args.skip) + do_link("workflow", force=args.yes, skip_list=args.skip) + do_link("serve", force=args.yes, skip_list=args.skip) + do_link("dag", force=args.yes, skip_list=args.skip) + do_link("widgets", force=args.yes, skip_list=args.skip) + do_link("cluster_utils.py", force=args.yes, skip_list=args.skip) + do_link("_private", force=args.yes, skip_list=args.skip) + do_link("dashboard", force=args.yes, skip_list=args.skip) + + if args.extras is not None: + for package in args.extras: + do_link(package, force=args.yes, skip_list=args.skip) + + print( + "Created links.\n\nIf you run into issues initializing Ray, please " + "ensure that your local repo and the installed Ray are in sync " + "(pip install -U the latest wheels at " + "https://docs.ray.io/en/master/installation.html, " + "and ensure you are up-to-date on the master branch on git).\n\n" + "Note 
that you may need to delete the package symlinks when pip " + "installing new Ray versions to prevent pip from overwriting files " + "in your git repo." + ) diff --git a/.venv/lib/python3.11/site-packages/ray/types.py b/.venv/lib/python3.11/site-packages/ray/types.py new file mode 100644 index 0000000000000000000000000000000000000000..a4733f44e98eb8660449dd46bf0c45f81856d8a6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/types.py @@ -0,0 +1,14 @@ +from typing import Generic, TypeVar + +from ray.util.annotations import PublicAPI + +T = TypeVar("T") + + +# TODO(ekl) this is a dummy generic ref type for documentation purposes only. +# We should try to make the Cython ray.ObjectRef properly generic. +# NOTE(sang): Looks like using Generic in Cython is not currently possible. +# We should update Cython > 3.0 for this. +@PublicAPI +class ObjectRef(Generic[T]): + pass diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/__init__.py b/.venv/lib/python3.11/site-packages/ray/workflow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f5c2e703afc929801fc03f68bf9d884a84f4e11d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/workflow/__init__.py @@ -0,0 +1,55 @@ +from ray.workflow.api import ( + init, + run, + run_async, + resume, + resume_all, + resume_async, + cancel, + list_all, + delete, + get_output, + get_output_async, + get_status, + get_metadata, + sleep, + wait_for_event, + continuation, + options, +) +from ray.workflow.exceptions import ( + WorkflowError, + WorkflowExecutionError, + WorkflowCancellationError, +) +from ray.workflow.common import WorkflowStatus +from ray.workflow.event_listener import EventListener + +globals().update(WorkflowStatus.__members__) + + +__all__ = [ + "init", + "run", + "run_async", + "resume", + "resume_async", + "resume_all", + "cancel", + "list_all", + "delete", + "get_output", + "get_output_async", + "get_status", + "get_metadata", + "sleep", + "wait_for_event", + "options", + 
"continuation", + # events + "EventListener", + # exceptions + "WorkflowError", + "WorkflowExecutionError", + "WorkflowCancellationError", +] diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/debug_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/debug_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..839fe9eed93c82b28e30e537da042d48d57f3c47 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/debug_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/serialization.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/serialization.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5e00577cfa881e35a45b91762476b0866363c0c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/serialization.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_access.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_access.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ee502f1afcafabfa3cf52baebfe2b677dcd6605 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_access.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_context.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_context.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e2fb6d28d698e9f9965a8fde7de0a8dae20c52c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_context.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_executor.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_executor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26db30c7d3df646589a101db2fedf8590f1032b2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_executor.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_state.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_state.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e2113e1386720290edf75e14bec89aabe81a968 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_state.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_state_from_dag.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_state_from_dag.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4044af2ebe796381497dc531210327f8233fe1f8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_state_from_dag.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_state_from_storage.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_state_from_storage.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f7719c404df54f1c85231ade7962af5fddbf206 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_state_from_storage.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/api.py b/.venv/lib/python3.11/site-packages/ray/workflow/api.py new file mode 100644 index 0000000000000000000000000000000000000000..664f7b54f32ba7edec3827a9e183568487d15197 --- /dev/null +++ 
b/.venv/lib/python3.11/site-packages/ray/workflow/api.py @@ -0,0 +1,869 @@ +import functools +import logging +import tempfile +from typing import Dict, Set, List, Tuple, Union, Optional, Any +import time +import uuid +from pathlib import Path + +import ray +from ray.dag import DAGNode +from ray.dag.input_node import DAGInputData +from ray.remote_function import RemoteFunction + +# avoid collision with arguments & APIs + +from ray.workflow.common import ( + WorkflowStatus, + Event, + asyncio_run, + validate_user_metadata, +) +from ray.workflow import serialization, workflow_access, workflow_context +from ray.workflow.event_listener import EventListener, EventListenerType, TimerListener +from ray.workflow.workflow_storage import WorkflowStorage +from ray.workflow.workflow_state_from_dag import workflow_state_from_dag + +from ray.util.annotations import PublicAPI +from ray._private.usage import usage_lib + +logger = logging.getLogger(__name__) + + +@PublicAPI(stability="alpha") +def init( + *, + max_running_workflows: Optional[int] = None, + max_pending_workflows: Optional[int] = None, +) -> None: + """Initialize workflow. + + If Ray is not initialized, we will initialize Ray and + use ``/tmp/ray/workflow_data`` as the default storage. + + Args: + max_running_workflows: The maximum number of concurrently running workflows. + Use -1 as infinity. 'None' means preserving previous setting or initialize + the setting with infinity. + max_pending_workflows: The maximum number of queued workflows. + Use -1 as infinity. 'None' means preserving previous setting or initialize + the setting with infinity. + """ + usage_lib.record_library_usage("workflow") + + if max_running_workflows is not None: + if not isinstance(max_running_workflows, int): + raise TypeError("'max_running_workflows' must be None or an integer.") + if max_running_workflows < -1 or max_running_workflows == 0: + raise ValueError( + "'max_running_workflows' must be a positive integer " + "or use -1 as infinity." 
+ ) + if max_pending_workflows is not None: + if not isinstance(max_pending_workflows, int): + raise TypeError("'max_pending_workflows' must be None or an integer.") + if max_pending_workflows < -1: + raise ValueError( + "'max_pending_workflows' must be a non-negative integer " + "or use -1 as infinity." + ) + + if not ray.is_initialized(): + # We should use get_temp_dir_path, but for ray client, we don't + # have this one. We need a flag to tell whether it's a client + # or a driver to use the right dir. + # For now, just use $TMP/ray/workflow_data + workflow_dir = Path(tempfile.gettempdir()) / "ray" / "workflow_data" + ray.init(storage=workflow_dir.as_uri()) + workflow_access.init_management_actor(max_running_workflows, max_pending_workflows) + serialization.init_manager() + + +def _ensure_workflow_initialized() -> None: + # NOTE: Trying to get the actor has a side effect: it initializes Ray with + # default arguments. This is different in "init()": it assigns a temporary + # storage. This is why we need to check "ray.is_initialized()" first. + if not ray.is_initialized(): + init() + else: + try: + workflow_access.get_management_actor() + except ValueError: + init() + + +def client_mode_wrap(func): + """Wraps a function called during client mode for execution as a remote task. + + Adopted from "ray._private.client_mode_hook.client_mode_wrap". Some changes are made + (e.g., init the workflow instead of init Ray; the latter does not specify a storage + during Ray init and will result in workflow failures). + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + from ray._private.client_mode_hook import client_mode_should_convert + from ray._private.auto_init_hook import enable_auto_connect + + if enable_auto_connect: + _ensure_workflow_initialized() + + # `is_client_mode_enabled_by_default` is used for testing with + # `RAY_CLIENT_MODE=1`. This flag means all tests run with client mode. 
+ if client_mode_should_convert(): + f = ray.remote(num_cpus=0)(func) + ref = f.remote(*args, **kwargs) + return ray.get(ref) + return func(*args, **kwargs) + + return wrapper + + +@PublicAPI(stability="alpha") +def run( + dag: DAGNode, + *args, + workflow_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs, +) -> Any: + """Run a workflow. + + If the workflow with the given id already exists, it will be resumed. + + Examples: + .. testcode:: + + import ray + from ray import workflow + + @ray.remote + def book_flight(origin: str, dest: str): + return f"Flight: {origin}->{dest}" + + @ray.remote + def book_hotel(location: str): + return f"Hotel: {location}" + + @ray.remote + def finalize_trip(bookings: List[Any]): + return ' | '.join(ray.get(bookings)) + + flight1 = book_flight.bind("OAK", "SAN") + flight2 = book_flight.bind("SAN", "OAK") + hotel = book_hotel.bind("SAN") + trip = finalize_trip.bind([flight1, flight2, hotel]) + print(workflow.run(trip)) + + .. testoutput:: + + Flight: OAK->SAN | Flight: SAN->OAK | Hotel: SAN + + Args: + workflow_id: A unique identifier that can be used to resume the + workflow. If not specified, a random id will be generated. + metadata: The metadata to add to the workflow. It has to be able + to serialize to json. + + Returns: + The running result. + """ + return ray.get( + run_async(dag, *args, workflow_id=workflow_id, metadata=metadata, **kwargs) + ) + + +@PublicAPI(stability="alpha") +def run_async( + dag: DAGNode, + *args, + workflow_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs, +) -> ray.ObjectRef: + """Run a workflow asynchronously. + + If the workflow with the given id already exists, it will be resumed. + + Args: + workflow_id: A unique identifier that can be used to resume the + workflow. If not specified, a random id will be generated. + metadata: The metadata to add to the workflow. It has to be able + to serialize to json. 
+ + Returns: + The running result as ray.ObjectRef. + + """ + _ensure_workflow_initialized() + if not isinstance(dag, DAGNode): + raise TypeError("Input should be a DAG.") + input_data = DAGInputData(*args, **kwargs) + validate_user_metadata(metadata) + metadata = metadata or {} + + if workflow_id is None: + # Workflow ID format: {Entry workflow UUID}.{Unix time to nanoseconds} + workflow_id = f"{str(uuid.uuid4())}.{time.time():.9f}" + + workflow_manager = workflow_access.get_management_actor() + if ray.get(workflow_manager.is_workflow_non_terminating.remote(workflow_id)): + raise RuntimeError(f"Workflow '{workflow_id}' is already running or pending.") + + state = workflow_state_from_dag(dag, input_data, workflow_id) + logger.info(f'Workflow job created. [id="{workflow_id}"].') + context = workflow_context.WorkflowTaskContext(workflow_id=workflow_id) + with workflow_context.workflow_task_context(context): + # checkpoint the workflow + @client_mode_wrap + def _try_checkpoint_workflow(workflow_state) -> bool: + ws = WorkflowStorage(workflow_id) + ws.save_workflow_user_metadata(metadata) + try: + ws.get_entrypoint_task_id() + return True + except Exception: + # The workflow does not exist. We must checkpoint entry workflow. + ws.save_workflow_execution_state("", workflow_state) + return False + + wf_exists = _try_checkpoint_workflow(state) + if wf_exists: + return resume_async(workflow_id) + ray.get( + workflow_manager.submit_workflow.remote( + workflow_id, state, ignore_existing=False + ) + ) + job_id = ray.get_runtime_context().get_job_id() + return workflow_manager.execute_workflow.remote(job_id, context) + + +@PublicAPI(stability="alpha") +def resume(workflow_id: str) -> Any: + """Resume a workflow. + + Resume a workflow and retrieve its output. If the workflow was incomplete, + it will be re-executed from its checkpointed outputs. If the workflow was + complete, returns the result immediately. + + Examples: + .. 
testcode:: + + from ray import workflow + + @ray.remote + def start_trip(): + return 3 + + trip = start_trip.bind() + res1 = workflow.run_async(trip, workflow_id="trip1") + res2 = workflow.resume("trip1") + assert ray.get(res1) == res2 + + Args: + workflow_id: The id of the workflow to resume. + + Returns: + The output of the workflow. + """ + return ray.get(resume_async(workflow_id)) + + +@PublicAPI(stability="alpha") +def resume_async(workflow_id: str) -> ray.ObjectRef: + """Resume a workflow asynchronously. + + Resume a workflow and retrieve its output. If the workflow was incomplete, + it will be re-executed from its checkpointed outputs. If the workflow was + complete, returns the result immediately. + + Examples: + .. testcode:: + + from ray import workflow + + @ray.remote + def start_trip(): + return 3 + + trip = start_trip.bind() + res1 = workflow.run_async(trip, workflow_id="trip1") + res2 = workflow.resume_async("trip1") + assert ray.get(res1) == ray.get(res2) + + Args: + workflow_id: The id of the workflow to resume. + + Returns: + An object reference that can be used to retrieve the workflow result. + """ + _ensure_workflow_initialized() + logger.info(f'Resuming workflow [id="{workflow_id}"].') + workflow_manager = workflow_access.get_management_actor() + if ray.get(workflow_manager.is_workflow_non_terminating.remote(workflow_id)): + raise RuntimeError(f"Workflow '{workflow_id}' is already running or pending.") + # NOTE: It is important to 'ray.get' the returned output. This + # ensures caller of 'run()' holds the reference to the workflow + # result. Otherwise if the actor removes the reference of the + # workflow output, the caller may fail to resolve the result. 
+ job_id = ray.get_runtime_context().get_job_id() + + context = workflow_context.WorkflowTaskContext(workflow_id=workflow_id) + ray.get(workflow_manager.reconstruct_workflow.remote(job_id, context)) + result = workflow_manager.execute_workflow.remote(job_id, context) + logger.info(f"Workflow job {workflow_id} resumed.") + return result + + +@PublicAPI(stability="alpha") +def get_output(workflow_id: str, *, task_id: Optional[str] = None) -> Any: + """Get the output of a running workflow. + + Args: + workflow_id: The workflow to get the output of. + task_id: If set, fetch the specific task instead of the output of the + workflow. + + Examples: + .. testcode:: + + from ray import workflow + + @ray.remote + def start_trip(): + return 1 + + trip = start_trip.options(**workflow.options(task_id="trip")).bind() + res1 = workflow.run_async(trip, workflow_id="trip1") + # you could "get_output()" in another machine + res2 = workflow.get_output("trip1") + assert ray.get(res1) == res2 + task_output = workflow.get_output_async("trip1", task_id="trip") + assert ray.get(task_output) == ray.get(res1) + + Returns: + The output of the workflow task. + """ + return ray.get(get_output_async(workflow_id, task_id=task_id)) + + +@PublicAPI(stability="alpha") +@client_mode_wrap +def get_output_async( + workflow_id: str, *, task_id: Optional[str] = None +) -> ray.ObjectRef: + """Get the output of a running workflow asynchronously. + + Args: + workflow_id: The workflow to get the output of. + task_id: If set, fetch the specific task output instead of the output + of the workflow. + + Returns: + An object reference that can be used to retrieve the workflow task result. + """ + _ensure_workflow_initialized() + try: + workflow_manager = workflow_access.get_management_actor() + except ValueError as e: + raise ValueError( + "Failed to connect to the workflow management " + "actor. The workflow could have already failed. 
You can use " + "workflow.resume() or workflow.resume_async() to resume the " + "workflow." + ) from e + return workflow_manager.get_output.remote(workflow_id, task_id) + + +@PublicAPI(stability="alpha") +@client_mode_wrap +def list_all( + status_filter: Optional[ + Union[Union[WorkflowStatus, str], Set[Union[WorkflowStatus, str]]] + ] = None +) -> List[Tuple[str, WorkflowStatus]]: + """List all workflows matching a given status filter. When returning "RESUMEABLE" + workflows, the workflows that was running ranks before the workflow that was pending + in the result list. + + Args: + status_filter: If given, only returns workflow with that status. This can + be a single status or set of statuses. The string form of the + status is also acceptable, i.e., + "RUNNING"/"FAILED"/"SUCCESSFUL"/"CANCELED"/"RESUMABLE"/"PENDING". + + Examples: + .. testcode:: + + from ray import workflow + + @ray.remote + def long_running_job(): + import time + time.sleep(2) + + workflow_task = long_running_job.bind() + wf = workflow.run_async(workflow_task, + workflow_id="long_running_job") + jobs = workflow.list_all(workflow.RUNNING) + assert jobs == [ ("long_running_job", workflow.RUNNING) ] + ray.get(wf) + jobs = workflow.list_all({workflow.RUNNING}) + assert jobs == [] + + Returns: + A list of tuple with workflow id and workflow status + """ + _ensure_workflow_initialized() + if isinstance(status_filter, str): + status_filter = set({WorkflowStatus(status_filter)}) + elif isinstance(status_filter, WorkflowStatus): + status_filter = set({status_filter}) + elif isinstance(status_filter, set): + if all(isinstance(s, str) for s in status_filter): + status_filter = {WorkflowStatus(s) for s in status_filter} + elif not all(isinstance(s, WorkflowStatus) for s in status_filter): + raise TypeError( + "status_filter contains element which is not" + " a type of `WorkflowStatus or str`." 
+ f" {status_filter}" + ) + elif status_filter is None: + status_filter = set(WorkflowStatus) + status_filter.discard(WorkflowStatus.NONE) + else: + raise TypeError( + "status_filter must be WorkflowStatus or a set of WorkflowStatus." + ) + + try: + workflow_manager = workflow_access.get_management_actor() + except ValueError: + workflow_manager = None + + if workflow_manager is None: + non_terminating_workflows = {} + else: + non_terminating_workflows = ray.get( + workflow_manager.list_non_terminating_workflows.remote() + ) + + ret = [] + if set(non_terminating_workflows.keys()).issuperset(status_filter): + for status, workflows in non_terminating_workflows.items(): + if status in status_filter: + for w in workflows: + ret.append((w, status)) + return ret + + ret = [] + # Here we don't have workflow id, so use empty one instead + store = WorkflowStorage("") + modified_status_filter = status_filter.copy() + # Here we have to add non-terminating status to the status filter, because some + # "RESUMABLE" workflows are converted from non-terminating workflows below. + # This is the tricky part: the status "RESUMABLE" neither come from + # the workflow management actor nor the storage. It is the status where + # the storage says it is non-terminating but the workflow management actor + # is not running it. This usually happened when there was a sudden crash + # of the whole Ray runtime or the workflow management actor + # (due to cluster etc.). So we includes non terminating status in the storage + # filter to get "RESUMABLE" candidates. 
+ modified_status_filter.update(WorkflowStatus.non_terminating_status()) + status_from_storage = store.list_workflow(modified_status_filter) + non_terminating_workflows = { + k: set(v) for k, v in non_terminating_workflows.items() + } + resume_running = [] + resume_pending = [] + for (k, s) in status_from_storage: + if s in non_terminating_workflows and k not in non_terminating_workflows[s]: + if s == WorkflowStatus.RUNNING: + resume_running.append(k) + elif s == WorkflowStatus.PENDING: + resume_pending.append(k) + else: + assert False, "This line of code should not be reachable." + continue + if s in status_filter: + ret.append((k, s)) + if WorkflowStatus.RESUMABLE in status_filter: + # The running workflows ranks before the pending workflows. + for w in resume_running: + ret.append((w, WorkflowStatus.RESUMABLE)) + for w in resume_pending: + ret.append((w, WorkflowStatus.RESUMABLE)) + return ret + + +@PublicAPI(stability="alpha") +@client_mode_wrap +def resume_all(include_failed: bool = False) -> List[Tuple[str, ray.ObjectRef]]: + """Resume all resumable workflow jobs. + + This can be used after cluster restart to resume all tasks. + + Args: + include_failed: Whether to resume FAILED workflows. + + Examples: + .. testcode:: + + from ray import workflow + + @ray.remote + def failed_job(): + raise ValueError() + + workflow_task = failed_job.bind() + output = workflow.run_async( + workflow_task, workflow_id="failed_job") + try: + ray.get(output) + except Exception: + print("JobFailed") + + assert workflow.get_status("failed_job") == workflow.FAILED + print(workflow.resume_all(include_failed=True)) + + .. testoutput:: + + JobFailed + [('failed_job', ObjectRef(...))] + + Returns: + A list of (workflow_id, returned_obj_ref) resumed. 
+ """ + _ensure_workflow_initialized() + filter_set = {WorkflowStatus.RESUMABLE} + if include_failed: + filter_set.add(WorkflowStatus.FAILED) + all_failed = list_all(filter_set) + + try: + workflow_manager = workflow_access.get_management_actor() + except Exception as e: + raise RuntimeError("Failed to get management actor") from e + + job_id = ray.get_runtime_context().get_job_id() + reconstructed_workflows = [] + for wid, _ in all_failed: + context = workflow_context.WorkflowTaskContext(workflow_id=wid) + # TODO(suquark): This is not very efficient, but it makes sure + # running workflows has higher priority when getting reconstructed. + try: + ray.get(workflow_manager.reconstruct_workflow.remote(job_id, context)) + except Exception as e: + # TODO(suquark): Here some workflows got resumed successfully but some + # failed and the user has no idea about this, which is very wired. + # Maybe we should raise an exception here instead? + logger.error(f"Failed to resume workflow {context.workflow_id}", exc_info=e) + raise + reconstructed_workflows.append(context) + + results = [] + for context in reconstructed_workflows: + results.append( + ( + context.workflow_id, + workflow_manager.execute_workflow.remote(job_id, context), + ) + ) + return results + + +@PublicAPI(stability="alpha") +def get_status(workflow_id: str) -> WorkflowStatus: + """Get the status for a given workflow. + + Args: + workflow_id: The workflow to query. + + Examples: + .. 
testcode:: + + from ray import workflow + + @ray.remote + def trip(): + pass + + workflow_task = trip.bind() + output = workflow.run(workflow_task, workflow_id="local_trip") + assert workflow.SUCCESSFUL == workflow.get_status("local_trip") + + Returns: + The status of that workflow + """ + _ensure_workflow_initialized() + if not isinstance(workflow_id, str): + raise TypeError("workflow_id has to be a string type.") + workflow_manager = workflow_access.get_management_actor() + return ray.get(workflow_manager.get_workflow_status.remote(workflow_id)) + + +@PublicAPI(stability="alpha") +def wait_for_event( + event_listener_type: EventListenerType, *args, **kwargs +) -> "DAGNode[Event]": + if not issubclass(event_listener_type, EventListener): + raise TypeError( + f"Event listener type is {event_listener_type.__name__}" + ", which is not a subclass of workflow.EventListener" + ) + + @ray.remote + def get_message(event_listener_type: EventListenerType, *args, **kwargs) -> Event: + event_listener = event_listener_type() + return asyncio_run(event_listener.poll_for_event(*args, **kwargs)) + + @ray.remote + def message_committed( + event_listener_type: EventListenerType, event: Event + ) -> Event: + event_listener = event_listener_type() + asyncio_run(event_listener.event_checkpointed(event)) + return event + + return message_committed.bind( + event_listener_type, get_message.bind(event_listener_type, *args, **kwargs) + ) + + +@PublicAPI(stability="alpha") +def sleep(duration: float) -> "DAGNode[Event]": + """ + A workfow that resolves after sleeping for a given duration. + """ + + @ray.remote + def end_time(): + return time.time() + duration + + return wait_for_event(TimerListener, end_time.bind()) + + +@PublicAPI(stability="alpha") +@client_mode_wrap +def get_metadata(workflow_id: str, task_id: Optional[str] = None) -> Dict[str, Any]: + """Get the metadata of the workflow. 
+ + This will return a dict of metadata of either the workflow ( + if only workflow_id is given) or a specific workflow task (if + both workflow_id and task id are given). Exception will be + raised if the given workflow id or task id does not exist. + + If only workflow id is given, this will return metadata on + workflow level, which includes running status, workflow-level + user metadata and workflow-level running stats (e.g. the + start time and end time of the workflow). + + If both workflow id and task id are given, this will return + metadata on workflow task level, which includes task inputs, + task-level user metadata and task-level running stats (e.g. + the start time and end time of the task). + + + Args: + workflow_id: The workflow to get the metadata of. + task_id: If set, fetch the metadata of the specific task instead of + the metadata of the workflow. + + Examples: + .. testcode:: + + from ray import workflow + + @ray.remote + def trip(): + pass + + workflow_task = trip.options( + **workflow.options(task_id="trip", metadata={"k1": "v1"})).bind() + workflow.run(workflow_task, + workflow_id="trip1", metadata={"k2": "v2"}) + workflow_metadata = workflow.get_metadata("trip1") + print(workflow_metadata) + + task_metadata = workflow.get_metadata("trip1", "trip") + print(task_metadata) + + .. testoutput:: + + {'status': 'SUCCESSFUL', 'user_metadata': {'k2': 'v2'}, 'stats': {'start_time': ..., 'end_time': ...}} + {'task_id': 'trip', 'task_options': {'task_type': 'FUNCTION', 'max_retries': 3, 'catch_exceptions': False, 'retry_exceptions': False, 'checkpoint': True, 'ray_options': {'_metadata': {'workflow.io/options': {'task_id': 'trip', 'metadata': {'k1': 'v1'}}}}}, 'user_metadata': {'k1': 'v1'}, 'workflow_refs': [], 'stats': {'start_time': ..., 'end_time': ...}} + + Returns: + A dictionary containing the metadata of the workflow. + + Raises: + ValueError: if given workflow or workflow task does not exist. 
+ """ # noqa: E501 + _ensure_workflow_initialized() + store = WorkflowStorage(workflow_id) + if task_id is None: + return store.load_workflow_metadata() + else: + return store.load_task_metadata(task_id) + + +@PublicAPI(stability="alpha") +def cancel(workflow_id: str) -> None: + """Cancel a workflow. Workflow checkpoints will still be saved in storage. To + clean up saved checkpoints, see `workflow.delete()`. + + Args: + workflow_id: The workflow to cancel. + + Examples: + .. testcode:: + + from ray import workflow + + @ray.remote + def some_job(): + return 1 + + workflow_task = some_job.bind() + workflow.run(workflow_task, workflow_id="some_job") + workflow.cancel(workflow_id="some_job") + assert workflow.get_status("some_job") == workflow.CANCELED + + Returns: + None + + """ + _ensure_workflow_initialized() + if not isinstance(workflow_id, str): + raise TypeError("workflow_id has to be a string type.") + workflow_manager = workflow_access.get_management_actor() + ray.get(workflow_manager.cancel_workflow.remote(workflow_id)) + + +@PublicAPI(stability="alpha") +def delete(workflow_id: str) -> None: + """Delete a workflow, its checkpoints, and other information it may have + persisted to storage. To stop a running workflow, see + `workflow.cancel()`. + + Args: + workflow_id: The workflow to delete. + + Raises: + WorkflowStillActiveError: The workflow is still active. + WorkflowNotFoundError: The workflow does not exist. + + Examples: + .. testcode:: + + from ray import workflow + + @ray.remote + def some_job(): + pass + + workflow_task = some_job.bind() + workflow.run(workflow_task, workflow_id="some_job") + workflow.delete(workflow_id="some_job") + """ + _ensure_workflow_initialized() + workflow_manager = workflow_access.get_management_actor() + ray.get(workflow_manager.delete_workflow.remote(workflow_id)) + + +@PublicAPI(stability="alpha") +def continuation(dag_node: "DAGNode") -> Union["DAGNode", Any]: + """Converts a DAG into a continuation. 
+ + The result depends on the context. If it is inside a workflow, it + returns a workflow; otherwise it executes and get the result of + the DAG. + + Args: + dag_node: The DAG to be converted. + """ + from ray.workflow.workflow_context import in_workflow_execution + + if not isinstance(dag_node, DAGNode): + raise TypeError("Input should be a DAG.") + + if in_workflow_execution(): + return dag_node + return ray.get(dag_node.execute()) + + +@PublicAPI(stability="alpha") +class options: + """This class serves both as a decorator and options for workflow. + + Examples: + + .. testcode:: + + import ray + from ray import workflow + + # specify workflow options with a decorator + @workflow.options(catch_exceptions=True) + @ray.remote + def foo(): + return 1 + + # specify workflow options in ".options" + foo_new = foo.options(**workflow.options(catch_exceptions=False)) + """ + + def __init__(self, **workflow_options: Dict[str, Any]): + # TODO(suquark): More rigid arguments check like @ray.remote arguments. This is + # fairly complex, but we should enable it later. + valid_options = { + "task_id", + "metadata", + "catch_exceptions", + "checkpoint", + } + invalid_keywords = set(workflow_options.keys()) - valid_options + if invalid_keywords: + raise ValueError( + f"Invalid option keywords {invalid_keywords} for workflow tasks. " + f"Valid ones are {valid_options}." 
+ ) + from ray.workflow.common import WORKFLOW_OPTIONS + + validate_user_metadata(workflow_options.get("metadata")) + + self.options = {"_metadata": {WORKFLOW_OPTIONS: workflow_options}} + + def keys(self): + return ("_metadata",) + + def __getitem__(self, key): + return self.options[key] + + def __call__(self, f: RemoteFunction) -> RemoteFunction: + if not isinstance(f, RemoteFunction): + raise ValueError("Only apply 'workflow.options' to Ray remote functions.") + f._default_options.update(self.options) + return f + + +__all__ = ( + "init", + "run", + "run_async", + "resume", + "resume_async", + "resume_all", + "cancel", + "list_all", + "delete", + "get_output", + "get_output_async", + "get_status", + "get_metadata", + "sleep", + "wait_for_event", + "options", + "continuation", +) diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/common.py b/.venv/lib/python3.11/site-packages/ray/workflow/common.py new file mode 100644 index 0000000000000000000000000000000000000000..37888d7b634972ce981d1d18c7368d30a2d1e6e8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/workflow/common.py @@ -0,0 +1,199 @@ +import base64 +import json + +from ray import cloudpickle +from enum import Enum, unique +import hashlib +from typing import Dict, Optional, Any, Tuple + +from dataclasses import dataclass + +import ray +from ray import ObjectRef +from ray._private.utils import get_or_create_event_loop +from ray.util.annotations import PublicAPI + +# Alias types +Event = Any +TaskID = str +WorkflowOutputType = ObjectRef + +MANAGEMENT_ACTOR_NAMESPACE = "workflow" +MANAGEMENT_ACTOR_NAME = "WorkflowManagementActor" +HTTP_EVENT_PROVIDER_NAME = "WorkflowHttpEventProvider" +STORAGE_ACTOR_NAME = "StorageManagementActor" +WORKFLOW_OPTIONS = "workflow.io/options" + + +def asyncio_run(coro): + return get_or_create_event_loop().run_until_complete(coro) + + +def validate_user_metadata(metadata): + if metadata is not None: + if not isinstance(metadata, dict): + raise 
ValueError("metadata must be a dict.") + try: + json.dumps(metadata) + except TypeError as e: + raise ValueError( + "metadata must be JSON serializable, instead, " + "we got 'TypeError: {}'".format(e) + ) + + +@dataclass +class WorkflowRef: + """This class represents a reference of a workflow output. + + A reference means the workflow has already been executed, + and we have both the workflow task ID and the object ref to it + living outputs. + + This could be used when you want to return a running workflow + from a workflow task. For example, the remaining workflows + returned by 'workflow.wait' contains a static ref to these + pending workflows. + """ + + # The ID of the task that produces the output of the workflow. + task_id: TaskID + # The ObjectRef of the output. If it is "None", then the output has been + # saved in the storage, and we need to check the workflow management actor + # for the object ref. + ref: Optional[ObjectRef] = None + + @classmethod + def from_output(cls, task_id: str, output: Any): + """Create static ref from given output.""" + if not isinstance(output, cls): + if not isinstance(output, ray.ObjectRef): + output = ray.put(output) + output = cls(task_id=task_id, ref=output) + return output + + def __hash__(self): + return hash(self.task_id) + + +@PublicAPI(stability="alpha") +@unique +class WorkflowStatus(str, Enum): + # No status is set for this workflow. + NONE = "NONE" + # There is at least a remote task running in ray cluster + RUNNING = "RUNNING" + # It got canceled and can't be resumed later. + CANCELED = "CANCELED" + # The workflow runs successfully. + SUCCESSFUL = "SUCCESSFUL" + # The workflow failed with an application error. + # It can be resumed. + FAILED = "FAILED" + # The workflow failed with a system error, i.e., ray shutdown. + # It can be resumed. + RESUMABLE = "RESUMABLE" + # The workflow is queued and waited to be executed. 
+ PENDING = "PENDING" + + @classmethod + def non_terminating_status(cls) -> "Tuple[WorkflowStatus, ...]": + return cls.RUNNING, cls.PENDING + + +@unique +class TaskType(str, Enum): + """All task types.""" + + FUNCTION = "FUNCTION" + WAIT = "WAIT" + + +CheckpointModeType = bool + + +@unique +class CheckpointMode(Enum): + """All checkpoint modes.""" + + # Keep the checkpoint of the workflow task. + SYNC = True + # Skip the checkpoint of the workflow task. + SKIP = False + + +@ray.remote +def _hash(obj: Any) -> bytes: + m = hashlib.sha256() + m.update(cloudpickle.dumps(obj)) + return m.digest() + + +@ray.remote +def calculate_identifier(obj: Any) -> str: + """Calculate a url-safe identifier for an object.""" + + # Task 1: Serialize the object. + # Task 2: Calculate its sha256 hash. + # Task 3: Get the url safe, base64 representation of it. + + # TODO (Alex): Ideally we should use the existing ObjectRef serializer to + # avoid duplicate serialization passes and support nested object refs. + m = hashlib.sha256() + m.update(cloudpickle.dumps(obj)) + hash = m.digest() + encoded = base64.urlsafe_b64encode(hash).decode("ascii") + return encoded + + +@dataclass +class WorkflowTaskRuntimeOptions: + """Options that will affect a workflow task at runtime.""" + + # Type of the task. + task_type: "TaskType" + # Whether the user want to handle the exception manually. + catch_exceptions: bool + # Whether application-level errors should be retried. + retry_exceptions: bool + # The num of retry for application exceptions & system failures. + max_retries: int + # Checkpoint mode. 
+ checkpoint: CheckpointModeType + # ray_remote options + ray_options: Dict[str, Any] + + def to_dict(self) -> Dict[str, Any]: + return { + "task_type": self.task_type, + "max_retries": self.max_retries, + "catch_exceptions": self.catch_exceptions, + "retry_exceptions": self.retry_exceptions, + "checkpoint": self.checkpoint, + "ray_options": self.ray_options, + } + + @classmethod + def from_dict(cls, value: Dict[str, Any]): + return cls( + task_type=TaskType[value["task_type"]], + max_retries=value["max_retries"], + catch_exceptions=value["catch_exceptions"], + retry_exceptions=value["retry_exceptions"], + checkpoint=value["checkpoint"], + ray_options=value["ray_options"], + ) + + +@dataclass +class WorkflowExecutionMetadata: + """Dataclass for the metadata of the workflow execution.""" + + # True if the workflow task returns a workflow DAG. + is_output_workflow: bool = False + + +@dataclass +class WorkflowMetaData: + # The current status of the workflow + status: WorkflowStatus diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/debug_utils.py b/.venv/lib/python3.11/site-packages/ray/workflow/debug_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f6d95bb59d4132d92805450378cb5da1e6b12f57 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/workflow/debug_utils.py @@ -0,0 +1,40 @@ +"""Utils for debugging purpose.""" +import ray +from ray.dag import DAGNode, DAGInputData + +from ray.workflow.common import asyncio_run +from ray.workflow.workflow_executor import WorkflowExecutor +from ray.workflow.workflow_context import workflow_task_context, WorkflowTaskContext +from ray.workflow.workflow_storage import get_workflow_storage + + +def execute_workflow_local(dag: DAGNode, workflow_id: str, *args, **kwargs): + """Execute the workflow locally.""" + from ray.workflow.workflow_state_from_dag import workflow_state_from_dag + + job_id = ray.get_runtime_context().get_job_id() + context = WorkflowTaskContext(workflow_id=workflow_id) + 
with workflow_task_context(context): + wf_store = get_workflow_storage() + state = workflow_state_from_dag( + dag, DAGInputData(*args, **kwargs), workflow_id=workflow_id + ) + executor = WorkflowExecutor(state) + fut = executor.get_task_output_async(state.output_task_id) + asyncio_run(executor.run_until_complete(job_id, context, wf_store)) + return asyncio_run(fut) + + +def resume_workflow_local(workflow_id: str): + """Resume the workflow locally.""" + from ray.workflow.workflow_state_from_storage import workflow_state_from_storage + + job_id = ray.get_runtime_context().get_job_id() + context = WorkflowTaskContext(workflow_id=workflow_id) + with workflow_task_context(context): + wf_store = get_workflow_storage() + state = workflow_state_from_storage(workflow_id, None) + executor = WorkflowExecutor(state) + fut = executor.get_task_output_async(state.output_task_id) + asyncio_run(executor.run_until_complete(job_id, context, wf_store)) + return asyncio_run(fut) diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/event_listener.py b/.venv/lib/python3.11/site-packages/ray/workflow/event_listener.py new file mode 100644 index 0000000000000000000000000000000000000000..03babc47b711a84dc34e86629265b2f02dc343af --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/workflow/event_listener.py @@ -0,0 +1,70 @@ +import asyncio +from ray.util.annotations import PublicAPI +from ray.workflow.common import Event +import time +from typing import Callable + +EventListenerType = Callable[[], "EventListener"] + + +@PublicAPI(stability="alpha") +class EventListener: + """Defining a custom event listener. Event listeners provide an efficient way + to listen for a custom event. + + Event listeners should be stateless. They will be instantiated from a + coordinator actor. + + Example definition + ================== + + ``` + class CustomEventListener: + + def __init__(self): + self.event_provider = ... 
+ + async def poll_for_event(self, topic, partition): + return await self.event_provider.poll(topic, partition) + + async def event_checkpointed(self, event: Event): + self.event_provider.commit(event.offset) + ``` + + Example Usage + ============= + .. testcode:: + :skipif: True + + from ray import workflow + CustomEventListener = ... + event_task = workflow.wait_for_event( + CustomEventListener, "topic1", "partition2") + handle_event = ... + workflow.run(handle_event.task(event_task)) + + """ + + def __init__(self): + """Optional constructor. Only the constructor with now arguments will be + called.""" + pass + + async def poll_for_event(self, *args, **kwargs) -> Event: + """Should return only when the event is received.""" + raise NotImplementedError + + async def event_checkpointed(self, event: Event) -> None: + """Optional. Called after an event has been checkpointed and a transaction can + be safely committed.""" + pass + + +@PublicAPI(stability="alpha") +class TimerListener(EventListener): + """ + A listener that produces an event at a given timestamp. + """ + + async def poll_for_event(self, timestamp): + await asyncio.sleep(timestamp - time.time()) diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/exceptions.py b/.venv/lib/python3.11/site-packages/ray/workflow/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..dca4b4ab171744408a2ea94735911d45e520a7ce --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/workflow/exceptions.py @@ -0,0 +1,57 @@ +from ray.util.annotations import PublicAPI +from ray.workflow.common import TaskID + + +@PublicAPI(stability="alpha") +class WorkflowError(Exception): + """Workflow error base class.""" + + +@PublicAPI(stability="alpha") +class WorkflowExecutionError(WorkflowError): + def __init__(self, workflow_id: str): + self.message = f"Workflow[id={workflow_id}] failed during execution." 
+ super().__init__(self.message) + + +@PublicAPI(stability="alpha") +class WorkflowCancellationError(WorkflowError): + def __init__(self, workflow_id: str): + self.message = f"Workflow[id={workflow_id}] is cancelled during execution." + super().__init__(self.message) + + +@PublicAPI(stability="alpha") +class WorkflowNotResumableError(WorkflowError): + """Raise the exception when we cannot resume from a workflow.""" + + def __init__(self, workflow_id: str): + self.message = f"Workflow[id={workflow_id}] is not resumable." + super().__init__(self.message) + + +@PublicAPI(stability="alpha") +class WorkflowTaskNotRecoverableError(WorkflowNotResumableError): + """Raise the exception when we find a workflow task cannot be recovered + using the checkpointed inputs.""" + + def __init__(self, task_id: TaskID): + self.message = f"Workflow task[id={task_id}] is not recoverable" + super(WorkflowError, self).__init__(self.message) + + +@PublicAPI(stability="alpha") +class WorkflowNotFoundError(WorkflowError): + def __init__(self, workflow_id: str): + self.message = f"Workflow[id={workflow_id}] was referenced but doesn't exist." + super().__init__(self.message) + + +@PublicAPI(stability="alpha") +class WorkflowStillActiveError(WorkflowError): + def __init__(self, operation: str, workflow_id: str): + self.message = ( + f"{operation} couldn't be completed because " + f"Workflow[id={workflow_id}] is still running or pending." 
+ ) + super().__init__(self.message) diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/http_event_provider.py b/.venv/lib/python3.11/site-packages/ray/workflow/http_event_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..5a25fc97c1ee91353213325ed01aeb66e0c37a27 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/workflow/http_event_provider.py @@ -0,0 +1,272 @@ +import asyncio +from typing import Dict +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse + +import ray +from ray import serve +from ray.workflow import common, workflow_context, workflow_access +from ray.workflow.event_listener import EventListener +from ray.workflow.common import Event + + +import logging + +logger = logging.getLogger(__name__) + + +class WorkflowEventHandleError(Exception): + """Raise when event processing failed""" + + def __init__(self, workflow_id: str, what_happened: str): + self.message = ( + f"Workflow[id={workflow_id}] HTTP event handle failed: {what_happened}" + ) + super().__init__(self.message) + + +app = FastAPI() + + +@serve.deployment(num_replicas=1) +@serve.ingress(app) +class HTTPEventProvider: + """HTTPEventProvider is defined to be a Ray Serve deployment with route_prefix='/event', + which will receive external events via an HTTP endpoint. It supports FastAPI, + e.g. post. It responds to both poll_for_event() and event_checkpointed() from + an HTTPListener instance. + + HTTPListener class is designed to work with the current workflow.wait_for_event() + implementation, where an HTTPListener instance will be initiated by the + get_message() and message_committed() of the workflow.wait_for_event(). + + HTTPEventProvider requires an event to arrive after HTTPListner registers + its event_key. If an event arrived before the registration, it returns HTTP + error code 404 with the error "workflow_id and event_key need to be registered + to receive event. 
Please make sure they are registered before resending." + + Example definition + ================== + + ``` + class HTTPEventProvider: + + def __init__(self): + + @app.post("/send_event/{workflow_id}") + async def send_event(self, workflow_id: str, req: Request): + Receive an external event message and acknowledge if it was processed + by the workflow + async def get_event_payload(self, workflow_id, event_key): + Internal method used by HTTPListner to subscribe to an event matched by + workflow_id and event_key + async def report_checkpointed(self, workflow_id, event, confirmation): + Internal method used by HTTPListner to confirm the received event has been + checkpointed by workflow + ``` + + Example Usage + ============= + .. testcode:: + :skipif: True + + from ray.workflow.http_event_provider import HTTPEventProvider, HTTPListener + ray.init(address='auto', namespace='serve') + serve.start(detached=True) + event_node = workflow.wait_for_event( + HTTPListener, event_key='') + handle_event = ... + workflow.run_aync(handle_event.bind(event_node)) + + On a separate python process, it sends an event to the HTTPEventProvider: + + .. testcode:: + :skipif: True + + import requests + resp = requests.post('http://127.0.0.1:8000/event/send_event/{workflow_id}', + json={'event_key':'my_key','event_payload':'testMessage'}) + + """ + + def __init__(self): + """Maintain two data structures to track pending events and confirmations + event_key_payload: for each registered workflow_id and event_key, + keep the Future to be set after an event is received. + event_checkpoint_pending: for each received event_key, keep its Future + after checkpointing is confirmed so HTTP 200 can be returned. 
+ """ + self.event_key_payload: Dict[str, Dict[str, asyncio.Future]] = {} + self.event_checkpoint_pending: Dict[str, asyncio.Future] = {} + + @app.post("/send_event/{workflow_id}") + async def send_event(self, workflow_id: str, req: Request) -> JSONResponse: + """Receive an external event message and acknowledge if it was processed + by the workflow + Args: + workflow_id: the workflow that this event is submitted for + req: the JSON formatted request that contains two string fields: ' + event_key' and 'event_payload' + 'event_key' uniquely identifies a node in the receiving workflow; + 'event_payload' refers to the event's content + Example: + JSON formatted request {"event_key":"node_event","event_payload":"approved"} + Returns: + if the event was received and processed, HTTP response status 200 + if the event was not expected or the workflow_id did not exist, HTTP + response status 404 + if the event was received but failed at checkpointing, HTTP response 500 + + """ + req_json = await req.json() + try: + event_key = req_json["event_key"] + event_payload = req_json["event_payload"] + except KeyError as e: + return JSONResponse( + status_code=404, + content={ + "error": { + "code": 404, + "message": f"{e} field is not found in the request JSON", + } + }, + ) + try: + self.event_key_payload[workflow_id][event_key].set_result( + (event_key, event_payload) + ) + except KeyError: + return JSONResponse( + status_code=404, + content={ + "error": { + "code": 404, + "message": "workflow_id and event_key need to be registered " + "to receive event. 
Please make sure they are " + "registered before resending.", + } + }, + ) + + self.event_checkpoint_pending[event_key] = asyncio.Future() + confirmed = await self.event_checkpoint_pending[event_key] + self.event_checkpoint_pending.pop(event_key) + if confirmed: + return JSONResponse(status_code=200, content={}) + return JSONResponse( + status_code=500, + content={"error": {"code": 500, "message": "event processing failed"}}, + ) + + async def get_event_payload(self, workflow_id: str, event_key: str) -> Event: + """Internal method used by HTTPListener to subscribe to an event matched + by workflow_id and event_key""" + if workflow_id not in self.event_key_payload: + self.event_key_payload[workflow_id] = {} + + if event_key in self.event_key_payload[workflow_id]: + raise WorkflowEventHandleError( + workflow_id, f"The same {event_key} is used to get payload again." + ) + + self.event_key_payload[workflow_id][event_key] = asyncio.Future() + return await self.event_key_payload[workflow_id][event_key] + + async def report_checkpointed( + self, workflow_id: str, event_key: str, confirmation: bool + ) -> str: + """Internal method used by HTTPListner to confirm the received event has + been checkpointed by workflow""" + try: + self.event_checkpoint_pending[event_key].set_result(confirmation) + except KeyError: + logger.error( + f"{event_key} cannot be found to acknowledge request. " + f"The event provider may have been restarted." + ) + raise WorkflowEventHandleError( + workflow_id, f"{event_key} cannot be found to acknowledge request." + ) + return "OK" + + +class HTTPListener(EventListener): + """HTTPLister is defined to work with the HTTPEventProvider. It implements two + APIs, poll_for_event() and event_checkpointed(). An instance of HTTPListener will + be started by the get_message() of the workflow.wait_for_event() to listen for + an event from the HTTPEventProvider instance (a Ray Serve deployment). 
Another + instance of HTTPListener will be started by the message_committed() of the + workflow.wait_for_event() to confirm that the event has been checkpointed. + + + Example definition + ================== + + ``` + class HTTPListener: + + def __init__(self): + + async def poll_for_event(self, event_key) -> Event: + + async def event_checkpointed(self, event) -> None: + + ``` + + Example Usage + ============= + + .. testcode:: + + import tempfile + from ray import workflow + from ray.workflow.http_event_provider import HTTPListener + + temp_dir = tempfile.TemporaryDirectory() + ray.init(storage=f"file://{temp_dir.name}") + + serve.start(detached=True) + event_node = workflow.wait_for_event(HTTPListener, event_key='') + + @ray.remote + def handle_event(arg): + return arg + + workflow.run_async(handle_event.bind(event_node), workflow_id="http_listener") + """ + + def __init__(self): + super().__init__() + try: + self.handle = ray.serve.get_app_handle(common.HTTP_EVENT_PROVIDER_NAME) + except ray.serve.exceptions.RayServeException: + mgr = workflow_access.get_management_actor() + ray.get(mgr.create_http_event_provider.remote()) + self.handle = ray.serve.get_app_handle(common.HTTP_EVENT_PROVIDER_NAME) + + async def poll_for_event(self, event_key: str = None) -> Event: + """workflow.wait_for_event calls this method to subscribe to the + HTTPEventProvider and return the received external event + Args: + event_key: a unique identifier to the receiving node in a workflow; + if missing, default to current workflow task id + Returns: + tuple(event_key, event_payload) + """ + workflow_id = workflow_context.get_current_workflow_id() + if event_key is None: + event_key = workflow_context.get_current_task_id() + + event_key_payload = await self.handle.get_event_payload.remote( + workflow_id, event_key + ) + return event_key_payload + + async def event_checkpointed(self, event: Event) -> None: + """workflow.wait_for_event calls this method after the event has + been 
checkpointed and a transaction can be safely committed.""" + (event_key, _) = event + await self.handle.report_checkpointed.remote( + workflow_context.get_current_workflow_id(), event_key, True + ) diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/serialization.py b/.venv/lib/python3.11/site-packages/ray/workflow/serialization.py new file mode 100644 index 0000000000000000000000000000000000000000..f858577bc9df875abb3dde26828eb0aab21047b7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/workflow/serialization.py @@ -0,0 +1,235 @@ +import contextlib +from dataclasses import dataclass +import logging +import os + +import ray +from ray import cloudpickle +from ray.types import ObjectRef +from ray.workflow import common, workflow_storage +from typing import Any, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING + +from collections import ChainMap +import io + +if TYPE_CHECKING: + from ray.actor import ActorHandle + +logger = logging.getLogger(__name__) + + +def init_manager() -> None: + get_or_create_manager(warn_on_creation=False) + + +def get_or_create_manager(warn_on_creation: bool = True) -> "ActorHandle": + """Get or create the storage manager.""" + # TODO(suquark): We should not get the actor everytime. We also need to + # resume the actor if it failed. Using a global variable to cache the + # actor seems not enough to resume the actor, because there is no + # aliveness detection for an actor. + try: + return ray.get_actor( + common.STORAGE_ACTOR_NAME, namespace=common.MANAGEMENT_ACTOR_NAMESPACE + ) + except ValueError: + if warn_on_creation: + logger.warning( + "Cannot access workflow serialization manager. It " + "could be because " + "the workflow manager exited unexpectedly. A new " + "workflow manager is being created. 
" + ) + handle = Manager.options( + name=common.STORAGE_ACTOR_NAME, + namespace=common.MANAGEMENT_ACTOR_NAMESPACE, + lifetime="detached", + ).remote() + ray.get(handle.ping.remote()) + return handle + + +@dataclass +class Upload: + identifier_ref: ObjectRef[str] + upload_task: ObjectRef[None] + + +@ray.remote(num_cpus=0) +class Manager: + """ + Responsible for deduping the serialization/upload of object references. + """ + + def __init__(self): + self._uploads: Dict[ray.ObjectRef, Upload] = {} + self._num_uploads = 0 + + def ping(self) -> None: + """ + Trivial function to ensure actor creation is successful. + """ + return None + + async def save_objectref( + self, ref_tuple: Tuple[ray.ObjectRef], workflow_id: "str" + ) -> Tuple[List[str], ray.ObjectRef]: + """Serialize and upload an object reference exactly once. + + Args: + ref_tuple: A 1-element tuple which wraps the reference. + + Returns: + A pair. The first element is the paths the ref will be uploaded to. + The second is an object reference to the upload task. + """ + (ref,) = ref_tuple + # Use the hex as the key to avoid holding a reference to the object. + key = (ref.hex(), workflow_id) + + if key not in self._uploads: + # TODO(Alex): We should probably eventually free these refs. 
+ identifier_ref = common.calculate_identifier.remote(ref) + upload_task = _put_helper.remote(identifier_ref, ref, workflow_id) + self._uploads[key] = Upload( + identifier_ref=identifier_ref, upload_task=upload_task + ) + self._num_uploads += 1 + + info = self._uploads[key] + identifer = await info.identifier_ref + key = _obj_id_to_key(identifer) + return key, info.upload_task + + async def export_stats(self) -> Dict[str, Any]: + return {"num_uploads": self._num_uploads} + + +OBJECTS_DIR = "objects" + + +def _obj_id_to_key(object_id: str) -> str: + return os.path.join(OBJECTS_DIR, object_id) + + +@ray.remote(num_cpus=0) +def _put_helper(identifier: str, obj: Any, workflow_id: str) -> None: + # TODO (Alex): This check isn't sufficient, it only works for directly + # nested object refs. + if isinstance(obj, ray.ObjectRef): + raise NotImplementedError( + "Workflow does not support checkpointing nested object references yet." + ) + key = _obj_id_to_key(identifier) + + dump_to_storage( + key, + obj, + workflow_id, + workflow_storage.WorkflowStorage(workflow_id), + update_existing=False, + ) + + +def _reduce_objectref( + workflow_id: str, + obj_ref: ObjectRef, + tasks: List[ObjectRef], +): + manager = get_or_create_manager() + key, task = ray.get(manager.save_objectref.remote((obj_ref,), workflow_id)) + + assert task + tasks.append(task) + + return _load_object_ref, (key, workflow_id) + + +def dump_to_storage( + key: str, + obj: Any, + workflow_id: str, + storage: "workflow_storage.WorkflowStorage", + update_existing=True, +) -> None: + """Serializes and puts arbitrary object, handling references. The object will + be uploaded at `paths`. Any object references will be uploaded to their + global, remote storage. + + Args: + key: The key of the object. + obj: The object to serialize. If it contains object references, those + will be serialized too. + workflow_id: The workflow id. + storage: The storage to use. 
If obj contains object references, + `storage.put` will be called on them individually. + update_existing: If False, the object will not be uploaded if the path + exists. + """ + if not update_existing: + if storage._exists(key): + return + + tasks = [] + + # NOTE: Cloudpickle doesn't support private dispatch tables, so we extend + # the cloudpickler instead to avoid changing cloudpickle's global dispatch + # table which is shared with `ray.put`. See + # https://github.com/cloudpipe/cloudpickle/issues/437 + class ObjectRefPickler(cloudpickle.CloudPickler): + _object_ref_reducer = { + ray.ObjectRef: lambda ref: _reduce_objectref(workflow_id, ref, tasks) + } + dispatch_table = ChainMap( + _object_ref_reducer, cloudpickle.CloudPickler.dispatch_table + ) + dispatch = dispatch_table + + ray.get(tasks) + + # TODO(Alex): We should be able to do this without the extra buffer. + with io.BytesIO() as f: + pickler = ObjectRefPickler(f) + pickler.dump(obj) + f.seek(0) + # use the underlying storage to avoid cyclic calls of "dump_to_storage" + storage._storage.put(key, f.read()) + + +@ray.remote +def _load_ref_helper(key: str, workflow_id: str): + # TODO(Alex): We should stream the data directly into `cloudpickle.load`. + storage = workflow_storage.WorkflowStorage(workflow_id) + return storage._get(key) + + +# TODO (Alex): We should use weakrefs here instead requiring a context manager. 
+_object_cache: Optional[Dict[str, ray.ObjectRef]] = None + + +def _load_object_ref(key: str, workflow_id: str) -> ray.ObjectRef: + global _object_cache + if _object_cache is None: + return _load_ref_helper.remote(key, workflow_id) + + if _object_cache is None: + return _load_ref_helper.remote(key, workflow_id) + + if key not in _object_cache: + _object_cache[key] = _load_ref_helper.remote(key, workflow_id) + + return _object_cache[key] + + +@contextlib.contextmanager +def objectref_cache() -> Generator: + """A reentrant caching context for object refs.""" + global _object_cache + clear_cache = _object_cache is None + if clear_cache: + _object_cache = {} + try: + yield + finally: + if clear_cache: + _object_cache = None diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/serialization_context.py b/.venv/lib/python3.11/site-packages/ray/workflow/serialization_context.py new file mode 100644 index 0000000000000000000000000000000000000000..1c70ae5f4527398567a92d6a55a675c8747490cb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/workflow/serialization_context.py @@ -0,0 +1,112 @@ +import contextlib +from typing import List, Any, Dict + +from ray.util.serialization import register_serializer, deregister_serializer +from ray.workflow.common import WorkflowRef + + +def _resolve_workflow_refs(index: int) -> Any: + raise ValueError("There is no context for resolving workflow refs.") + + +@contextlib.contextmanager +def workflow_args_serialization_context(workflow_refs: List[WorkflowRef]) -> None: + """ + This serialization context reduces workflow input arguments to three + parts: + + 1. A workflow input placeholder. It is an object without 'Workflow' and + 'ObjectRef' object. They are replaced with integer indices. During + deserialization, we can refill the placeholder with a list of + 'Workflow' and a list of 'ObjectRef'. 
This provides us great + flexibility, for example, during recovery we can plug an alternative + list of 'Workflow' and 'ObjectRef', since we lose the original ones. + 2. A list of 'Workflow'. There is no duplication in it. + 3. A list of 'ObjectRef'. There is no duplication in it. + + We do not allow duplication because in the arguments duplicated workflows + and object refs are shared by reference. So when deserialized, we also + want them to be shared by reference. See + "tests/test_object_deref.py:deref_shared" as an example. + + The deduplication works like this: + Inputs: [A B A B C C A] + Output List: [A B C] + Index in placeholder: [0 1 0 1 2 2 0] + + Args: + workflow_refs: Output list of workflows or references to workflows. + """ + deduplicator: Dict[WorkflowRef, int] = {} + + def serializer(w): + if w in deduplicator: + return deduplicator[w] + if isinstance(w, WorkflowRef): + # The ref should be resolved by the workflow management actor + # when treated as the input of a workflow, so we remove the ref here. + w.ref = None + i = len(workflow_refs) + workflow_refs.append(w) + deduplicator[w] = i + return i + + register_serializer( + WorkflowRef, + serializer=serializer, + deserializer=_resolve_workflow_refs, + ) + + try: + yield + finally: + # we do not want to serialize Workflow objects in other places. + deregister_serializer(WorkflowRef) + + +@contextlib.contextmanager +def workflow_args_resolving_context(workflow_ref_mapping: List[Any]) -> None: + """ + This context resolves workflows and object refs inside workflow + arguments into correct values. + + Args: + workflow_ref_mapping: List of workflow refs. 
+ """ + global _resolve_workflow_refs + _resolve_workflow_refs_bak = _resolve_workflow_refs + _resolve_workflow_refs = workflow_ref_mapping.__getitem__ + + try: + yield + finally: + _resolve_workflow_refs = _resolve_workflow_refs_bak + + +class _KeepWorkflowRefs: + def __init__(self, index: int): + self._index = index + + def __reduce__(self): + return _resolve_workflow_refs, (self._index,) + + +@contextlib.contextmanager +def workflow_args_keeping_context() -> None: + """ + This context only read workflow arguments. Workflows inside + are untouched and can be serialized again properly. + """ + global _resolve_workflow_refs + _resolve_workflow_refs_bak = _resolve_workflow_refs + + # we must capture the old functions to prevent self-referencing. + def _keep_workflow_refs(index: int): + return _KeepWorkflowRefs(index) + + _resolve_workflow_refs = _keep_workflow_refs + + try: + yield + finally: + _resolve_workflow_refs = _resolve_workflow_refs_bak diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/task_executor.py b/.venv/lib/python3.11/site-packages/ray/workflow/task_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c921622d0cc8cd4c9582d7ab4c3f64581245e0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/workflow/task_executor.py @@ -0,0 +1,163 @@ +import time +from dataclasses import dataclass +import logging +from typing import List, Tuple, Any, Dict, Callable, TYPE_CHECKING +import ray +from ray import ObjectRef +from ray._private import signature + +from ray.dag import DAGNode +from ray.workflow import workflow_context +from ray.workflow.workflow_context import get_task_status_info +from ray.workflow import serialization_context +from ray.workflow import workflow_storage + +from ray.workflow.common import ( + WorkflowStatus, + WorkflowExecutionMetadata, + TaskType, + TaskID, + WorkflowRef, + CheckpointMode, +) +from ray.workflow.workflow_state import WorkflowExecutionState +from 
ray.workflow.workflow_state_from_dag import workflow_state_from_dag + +if TYPE_CHECKING: + from ray.workflow.common import ( + WorkflowTaskRuntimeOptions, + ) + from ray.workflow.workflow_context import WorkflowTaskContext + + +logger = logging.getLogger(__name__) + + +def get_task_executor(task_options: "WorkflowTaskRuntimeOptions"): + if task_options.task_type == TaskType.FUNCTION: + # prevent automatic lineage reconstruction + task_options.ray_options["max_retries"] = 0 + # prevent retrying exception by Ray + task_options.ray_options["retry_exceptions"] = False + executor = _workflow_task_executor_remote.options( + **task_options.ray_options + ).remote + else: + raise ValueError(f"Invalid task type {task_options.task_type}") + return executor + + +def _workflow_task_executor( + func: Callable, + context: "WorkflowTaskContext", + task_id: "TaskID", + baked_inputs: "_BakedWorkflowInputs", + runtime_options: "WorkflowTaskRuntimeOptions", +) -> Tuple[Any, Any]: + """Executor function for workflow task. + + Args: + task_id: ID of the task. + func: The workflow task function. + baked_inputs: The processed inputs for the task. + context: Workflow task context. Used to access correct storage etc. + runtime_options: Parameters for workflow task execution. + + Returns: + Workflow task output. + """ + with workflow_context.workflow_task_context(context): + store = workflow_storage.get_workflow_storage() + # Part 1: resolve inputs + args, kwargs = baked_inputs.resolve(store) + + # Part 2: execute the task + try: + store.save_task_prerun_metadata(task_id, {"start_time": time.time()}) + with workflow_context.workflow_execution(): + logger.info(f"{get_task_status_info(WorkflowStatus.RUNNING)}") + output = func(*args, **kwargs) + store.save_task_postrun_metadata(task_id, {"end_time": time.time()}) + except Exception as e: + # Always checkpoint the exception. 
+ store.save_task_output(task_id, None, exception=e) + raise e + + if isinstance(output, DAGNode): + output = workflow_state_from_dag(output, None, context.workflow_id) + execution_metadata = WorkflowExecutionMetadata(is_output_workflow=True) + else: + execution_metadata = WorkflowExecutionMetadata() + if runtime_options.catch_exceptions: + output = (output, None) + + # Part 3: save outputs + # TODO(suquark): Validate checkpoint options before commit the task. + if CheckpointMode(runtime_options.checkpoint) == CheckpointMode.SYNC: + if isinstance(output, WorkflowExecutionState): + store.save_workflow_execution_state(task_id, output) + else: + store.save_task_output(task_id, output, exception=None) + return execution_metadata, output + + +@ray.remote(num_returns=2) +def _workflow_task_executor_remote( + func: Callable, + context: "WorkflowTaskContext", + job_id: str, + task_id: "TaskID", + baked_inputs: "_BakedWorkflowInputs", + runtime_options: "WorkflowTaskRuntimeOptions", +) -> Any: + """The remote version of '_workflow_task_executor'.""" + with workflow_context.workflow_logging_context(job_id): + return _workflow_task_executor( + func, context, task_id, baked_inputs, runtime_options + ) + + +@dataclass +class _BakedWorkflowInputs: + """This class stores pre-processed inputs for workflow task execution. + Especially, all input workflows to the workflow task will be scheduled, + and their outputs (ObjectRefs) replace the original workflows.""" + + args: "ObjectRef" + workflow_refs: "List[WorkflowRef]" + + def resolve(self, store: workflow_storage.WorkflowStorage) -> Tuple[List, Dict]: + """ + This function resolves the inputs for the code inside + a workflow task (works on the callee side). For outputs from other + workflows, we resolve them into object instances inplace. + + For each ObjectRef argument, the function returns both the ObjectRef + and the object instance. 
If the ObjectRef is a chain of nested + ObjectRefs, then we resolve it recursively until we get the + object instance, and we return the *direct* ObjectRef of the + instance. This function does not resolve ObjectRef + inside another object (e.g. list of ObjectRefs) to give users some + flexibility. + + Returns: + Instances of arguments. + """ + workflow_ref_mapping = [] + for r in self.workflow_refs: + if r.ref is None: + workflow_ref_mapping.append(store.load_task_output(r.task_id)) + else: + workflow_ref_mapping.append(r.ref) + + with serialization_context.workflow_args_resolving_context( + workflow_ref_mapping + ): + # reconstruct input arguments under correct serialization context + flattened_args: List[Any] = ray.get(self.args) + + # dereference arguments like Ray remote functions + flattened_args = [ + ray.get(a) if isinstance(a, ObjectRef) else a for a in flattened_args + ] + return signature.recover_args(flattened_args) diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/workflow_access.py b/.venv/lib/python3.11/site-packages/ray/workflow/workflow_access.py new file mode 100644 index 0000000000000000000000000000000000000000..c5ec59e8052ffcbe7ec0c03699b24f9a9f28be6a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/workflow/workflow_access.py @@ -0,0 +1,379 @@ +import asyncio +import logging +import queue +from typing import Dict, List, Set, Optional, TYPE_CHECKING + +import ray + +from ray.workflow import common +from ray.workflow.common import WorkflowStatus, TaskID +from ray.workflow import workflow_state_from_storage +from ray.workflow import workflow_context +from ray.workflow import workflow_storage +from ray.workflow.exceptions import ( + WorkflowCancellationError, + WorkflowNotFoundError, + WorkflowNotResumableError, + WorkflowStillActiveError, +) +from ray.workflow.workflow_executor import WorkflowExecutor +from ray.workflow.workflow_state import WorkflowExecutionState +from ray.workflow.workflow_context import WorkflowTaskContext 
+ +if TYPE_CHECKING: + from ray.actor import ActorHandle + +logger = logging.getLogger(__name__) + + +class SelfResolvingObject: + def __init__(self, x): + self.x = x + + def __reduce__(self): + return ray.get, (self.x,) + + +@ray.remote(num_cpus=0) +def load_task_output_from_storage(workflow_id: str, task_id: Optional[TaskID]): + wf_store = workflow_storage.WorkflowStorage(workflow_id) + tid = wf_store.inspect_output(task_id) + if tid is not None: + return wf_store.load_task_output(tid) + # TODO(suquark): Unify the error from "workflow.get_output" & "workflow.run_async". + # Currently they could be different, because "workflow.get_output" could + # get the output from a stopped workflow, it does not may sense to raise + # "WorkflowExecutionError" as the workflow is not running. + if task_id is not None: + raise ValueError( + f"Cannot load output from task id '{task_id}' in workflow '{workflow_id}'" + ) + else: + raise ValueError(f"Cannot load output from workflow '{workflow_id}'") + + +@ray.remote(num_cpus=0) +def resume_workflow_task( + job_id: str, + workflow_id: str, + task_id: Optional[TaskID] = None, +) -> WorkflowExecutionState: + """Resume a task of a workflow. + + Args: + job_id: The ID of the job that submits the workflow execution. The ID + is used to identify the submitter of the workflow. + workflow_id: The ID of the workflow job. The ID is used to identify + the workflow. + task_id: The task to resume in the workflow. + + Raises: + WorkflowNotResumableException: fail to resume the workflow. + + Returns: + The execution result of the workflow, represented by Ray ObjectRef. + """ + with workflow_context.workflow_logging_context(job_id): + try: + return workflow_state_from_storage.workflow_state_from_storage( + workflow_id, task_id + ) + except Exception as e: + raise WorkflowNotResumableError(workflow_id) from e + + +# TODO(suquark): we may use an actor pool in the future if too much +# concurrent workflow access blocks the actor. 
+@ray.remote(num_cpus=0) +class WorkflowManagementActor: + """Keep the ownership and manage the workflow output.""" + + def __init__(self, max_running_workflows: int, max_pending_workflows: int): + self._workflow_executors: Dict[str, WorkflowExecutor] = {} + + self._max_running_workflows: int = max_running_workflows + self._max_pending_workflows: int = max_pending_workflows + + # 0 means infinite for queue + self._workflow_queue = queue.Queue( + max_pending_workflows if max_pending_workflows != -1 else 0 + ) + + self._running_workflows: Set[str] = set() + self._queued_workflows: Dict[str, asyncio.Future] = {} + # TODO(suquark): We do not cleanup "_executed_workflows" because we need to + # know if users are running the same workflow again long after a workflow + # completes. One possible alternative solution is to check the workflow + # status in the storage. + self._executed_workflows: Set[str] = set() + + def validate_init_options( + self, max_running_workflows: Optional[int], max_pending_workflows: Optional[int] + ): + if ( + max_running_workflows is not None + and max_running_workflows != self._max_running_workflows + ) or ( + max_pending_workflows is not None + and max_pending_workflows != self._max_pending_workflows + ): + raise ValueError( + "The workflow init is called again but the init options" + "does not match the original ones. Original options: " + f"max_running_workflows={self._max_running_workflows} " + f"max_pending_workflows={self._max_pending_workflows}; " + f"New options: max_running_workflows={max_running_workflows} " + f"max_pending_workflows={max_pending_workflows}." 
+ ) + + def gen_task_id(self, workflow_id: str, task_name: str) -> str: + wf_store = workflow_storage.WorkflowStorage(workflow_id) + idx = wf_store.gen_task_id(task_name) + if idx == 0: + return task_name + else: + return f"{task_name}_{idx}" + + def submit_workflow( + self, + workflow_id: str, + state: WorkflowExecutionState, + ignore_existing: bool = False, + ): + """Submit workflow. A submitted workflow can be executed later. + + Args: + workflow_id: ID of the workflow. + state: The initial state of the workflow. + ignore_existing: Ignore existing executed workflows. + """ + if workflow_id in self._workflow_executors: + raise RuntimeError(f"Workflow[id={workflow_id}] is being executed.") + if workflow_id in self._executed_workflows and not ignore_existing: + raise RuntimeError(f"Workflow[id={workflow_id}] has been executed.") + + if state.output_task_id is None: + raise ValueError( + "No root DAG specified that generates output for the workflow." + ) + + wf_store = workflow_storage.WorkflowStorage(workflow_id) + if ( + self._max_running_workflows != -1 + and len(self._running_workflows) >= self._max_running_workflows + ): + try: + self._workflow_queue.put_nowait(workflow_id) + self._queued_workflows[workflow_id] = asyncio.Future() + wf_store.update_workflow_status(WorkflowStatus.PENDING) + except queue.Full: + # override with our error message + raise queue.Full("Workflow queue has been full") from None + else: + self._running_workflows.add(workflow_id) + wf_store.update_workflow_status(WorkflowStatus.RUNNING) + # initialize executor + self._workflow_executors[workflow_id] = WorkflowExecutor(state) + + async def reconstruct_workflow( + self, job_id: str, context: WorkflowTaskContext + ) -> None: + """Reconstruct a (failed) workflow and submit it.""" + state = await resume_workflow_task.remote(job_id, context.workflow_id) + self.submit_workflow(context.workflow_id, state, ignore_existing=True) + + async def execute_workflow( + self, + job_id: str, + context: 
WorkflowTaskContext, + ) -> ray.ObjectRef: + """Execute a submitted workflow. + + Args: + job_id: The ID of the job for logging. + context: The execution context. + Returns: + An object ref that represent the result. + """ + workflow_id = context.workflow_id + if workflow_id not in self._workflow_executors: + raise RuntimeError(f"Workflow '{workflow_id}' has not been submitted.") + + pending_fut = self._queued_workflows.get(workflow_id) + if pending_fut is not None: + await pending_fut # wait until this workflow is ready to go + + wf_store = workflow_storage.WorkflowStorage(workflow_id) + executor = self._workflow_executors[workflow_id] + try: + await executor.run_until_complete(job_id, context, wf_store) + return await self.get_output(workflow_id, executor.output_task_id) + finally: + self._workflow_executors.pop(workflow_id) + self._running_workflows.remove(workflow_id) + self._executed_workflows.add(workflow_id) + if not self._workflow_queue.empty(): + # schedule another workflow from the pending queue + next_workflow_id = self._workflow_queue.get_nowait() + self._running_workflows.add(next_workflow_id) + fut = self._queued_workflows.pop(next_workflow_id) + fut.set_result(None) + + async def cancel_workflow(self, workflow_id: str) -> None: + """Cancel workflow execution.""" + if workflow_id in self._workflow_executors: + executor = self._workflow_executors[workflow_id] + fut = executor.get_task_output_async(executor.output_task_id) + executor.cancel() + try: + # Wait until cancelled, otherwise workflow status may not + # get updated after "workflow.cancel()" is called. 
+ await fut + except WorkflowCancellationError: + pass + else: + wf_store = workflow_storage.WorkflowStorage(workflow_id) + wf_store.update_workflow_status(WorkflowStatus.CANCELED) + + def get_workflow_status(self, workflow_id: str) -> WorkflowStatus: + """Get the status of the workflow.""" + if workflow_id in self._workflow_executors: + if workflow_id in self._queued_workflows: + return WorkflowStatus.PENDING + return WorkflowStatus.RUNNING + store = workflow_storage.get_workflow_storage(workflow_id) + status = store.load_workflow_status() + if status == WorkflowStatus.NONE: + raise WorkflowNotFoundError(workflow_id) + elif status in WorkflowStatus.non_terminating_status(): + return WorkflowStatus.RESUMABLE + return status + + def is_workflow_non_terminating(self, workflow_id: str) -> bool: + """True if the workflow is still running or pending.""" + return workflow_id in self._workflow_executors + + def list_non_terminating_workflows(self) -> Dict[WorkflowStatus, List[str]]: + """List workflows whose status are not of terminated status.""" + result = {WorkflowStatus.RUNNING: [], WorkflowStatus.PENDING: []} + for wf in self._workflow_executors.keys(): + if wf in self._running_workflows: + result[WorkflowStatus.RUNNING].append(wf) + else: + result[WorkflowStatus.PENDING].append(wf) + return result + + async def get_output( + self, workflow_id: str, task_id: Optional[TaskID] + ) -> ray.ObjectRef: + """Get the output of a running workflow. + + Args: + workflow_id: The ID of a workflow job. + task_id: If set, fetch the specific task output instead of the output + of the workflow. + + Returns: + An object reference that can be used to retrieve the workflow result. 
+ """ + ref = None + if self.is_workflow_non_terminating(workflow_id): + executor = self._workflow_executors[workflow_id] + if task_id is None: + task_id = executor.output_task_id + workflow_ref = await executor.get_task_output_async(task_id) + task_id, ref = workflow_ref.task_id, workflow_ref.ref + if ref is None: + wf_store = workflow_storage.WorkflowStorage(workflow_id) + tid = wf_store.inspect_output(task_id) + if tid is not None: + ref = load_task_output_from_storage.remote(workflow_id, task_id) + elif task_id is not None: + raise ValueError( + f"Cannot load output from task id '{task_id}' in workflow " + f"'{workflow_id}'" + ) + else: + raise ValueError(f"Cannot load output from workflow '{workflow_id}'") + return SelfResolvingObject(ref) + + def delete_workflow(self, workflow_id: str) -> None: + """Delete a workflow, its checkpoints, and other information it may have + persisted to storage. + + Args: + workflow_id: The workflow to delete. + + Raises: + WorkflowStillActiveError: The workflow is still active. + WorkflowNotFoundError: The workflow does not exist. 
+ """ + if self.is_workflow_non_terminating(workflow_id): + raise WorkflowStillActiveError("DELETE", workflow_id) + wf_storage = workflow_storage.WorkflowStorage(workflow_id) + wf_storage.delete_workflow() + self._executed_workflows.discard(workflow_id) + + def create_http_event_provider(self) -> None: + """Deploy an HTTPEventProvider as a Serve deployment with + name = common.HTTP_EVENT_PROVIDER_NAME, if one doesn't exist + """ + ray.serve.start(detached=True) + provider_exists = ( + common.HTTP_EVENT_PROVIDER_NAME in ray.serve.status().applications + ) + if not provider_exists: + from ray.workflow.http_event_provider import HTTPEventProvider + + ray.serve.run( + HTTPEventProvider.bind(), + name=common.HTTP_EVENT_PROVIDER_NAME, + route_prefix="/event", + ) + + def ready(self) -> None: + """A no-op to make sure the actor is ready.""" + + +def init_management_actor( + max_running_workflows: Optional[int], max_pending_workflows: Optional[int] +) -> None: + """Initialize WorkflowManagementActor. + + Args: + max_running_workflows: The maximum number of concurrently running workflows. + Use -1 as infinity. Use 'None' for keeping the original value if the actor + exists, or it is equivalent to infinity if the actor does not exist. + max_pending_workflows: The maximum number of queued workflows. + Use -1 as infinity. Use 'None' for keeping the original value if the actor + exists, or it is equivalent to infinity if the actor does not exist. + """ + try: + actor = get_management_actor() + # Check if max_running_workflows/max_pending_workflows + # matches the previous settings. 
+ ray.get( + actor.validate_init_options.remote( + max_running_workflows, max_pending_workflows + ) + ) + except ValueError: + logger.info("Initializing workflow manager...") + if max_running_workflows is None: + max_running_workflows = -1 + if max_pending_workflows is None: + max_pending_workflows = -1 + # the actor does not exist + actor = WorkflowManagementActor.options( + name=common.MANAGEMENT_ACTOR_NAME, + namespace=common.MANAGEMENT_ACTOR_NAMESPACE, + lifetime="detached", + ).remote(max_running_workflows, max_pending_workflows) + # No-op to ensure the actor is created before the driver exits. + ray.get(actor.ready.remote()) + + +def get_management_actor() -> "ActorHandle": + return ray.get_actor( + common.MANAGEMENT_ACTOR_NAME, namespace=common.MANAGEMENT_ACTOR_NAMESPACE + ) diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/workflow_context.py b/.venv/lib/python3.11/site-packages/ray/workflow/workflow_context.py new file mode 100644 index 0000000000000000000000000000000000000000..7c797e44679431f9c0642b1f67984a4ac788c03b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/workflow/workflow_context.py @@ -0,0 +1,123 @@ +import logging +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Optional + +import ray +from ray._private.ray_logging import configure_log_file, get_worker_log_file_name +from ray.workflow.common import CheckpointModeType, WorkflowStatus + +logger = logging.getLogger(__name__) + + +@dataclass +class WorkflowTaskContext: + """ + The structure for saving workflow task context. The context provides + critical info (e.g. where to checkpoint, which is its parent task) + for the task to execute correctly. + """ + + # ID of the workflow. + workflow_id: Optional[str] = None + # ID of the current task. + task_id: str = "" + # ID of the task that creates the current task. + creator_task_id: str = "" + # The checkpoint context of parent workflow tasks. 
+ checkpoint: CheckpointModeType = True + # The context of catching exceptions. + catch_exceptions: bool = False + + +_context: Optional[WorkflowTaskContext] = None + + +@contextmanager +def workflow_task_context(context) -> None: + """Initialize the workflow task context. + + Args: + context: The new context. + """ + global _context + original_context = _context + try: + _context = context + yield + finally: + _context = original_context + + +def get_workflow_task_context() -> Optional[WorkflowTaskContext]: + return _context + + +def get_current_task_id() -> str: + """Get the current workflow task ID. Empty means we are in + the workflow job driver.""" + return get_workflow_task_context().task_id + + +def get_current_workflow_id() -> str: + assert _context is not None + return _context.workflow_id + + +def get_name() -> str: + return f"{get_current_workflow_id()}@{get_current_task_id()}" + + +def get_task_status_info(status: WorkflowStatus) -> str: + assert _context is not None + return f"Task status [{status.value}]\t[{get_name()}]" + + +_in_workflow_execution = False + + +@contextmanager +def workflow_execution() -> None: + """Scope for workflow task execution.""" + global _in_workflow_execution + try: + _in_workflow_execution = True + yield + finally: + _in_workflow_execution = False + + +def in_workflow_execution() -> bool: + """Whether we are in workflow task execution.""" + global _in_workflow_execution + return _in_workflow_execution + + +@contextmanager +def workflow_logging_context(job_id) -> None: + """Initialize the workflow logging context. + + Workflow executions are running as remote functions from + WorkflowManagementActor. Without logging redirection, workflow + inner execution logs will be pushed to the driver that initially + created WorkflowManagementActor rather than the driver that + actually submits the current workflow execution. 
+ We use this conext manager to re-configure the log files to send + the logs to the correct driver, and to restore the log files once + the execution is done. + + Args: + job_id: The ID of the job that submits the workflow execution. + """ + node = ray._private.worker._global_node + original_out_file, original_err_file = node.get_log_file_handles( + get_worker_log_file_name("WORKER") + ) + out_file, err_file = node.get_log_file_handles( + get_worker_log_file_name("WORKER", job_id) + ) + try: + configure_log_file(out_file, err_file) + yield + finally: + configure_log_file(original_out_file, original_err_file) diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/workflow_executor.py b/.venv/lib/python3.11/site-packages/ray/workflow/workflow_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..fa0637e06881520346af76f4f386287f0f986012 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/workflow/workflow_executor.py @@ -0,0 +1,433 @@ +from typing import Dict, List, Iterator, Optional, Tuple, TYPE_CHECKING + +import asyncio +import logging +import time +from collections import defaultdict + +import ray +from ray.exceptions import RayTaskError, RayError + +from ray.workflow.common import ( + WorkflowRef, + WorkflowExecutionMetadata, + WorkflowStatus, + TaskID, +) +from ray.workflow.exceptions import WorkflowCancellationError, WorkflowExecutionError +from ray.workflow.task_executor import get_task_executor, _BakedWorkflowInputs +from ray.workflow.workflow_state import ( + WorkflowExecutionState, + TaskExecutionMetadata, + Task, +) + +if TYPE_CHECKING: + from ray.workflow.workflow_context import WorkflowTaskContext + from ray.workflow.workflow_storage import WorkflowStorage + +logger = logging.getLogger(__name__) + + +class WorkflowExecutor: + def __init__( + self, + state: WorkflowExecutionState, + ): + """The core logic of executing a workflow. + + This class is responsible for: + + - Dependency resolving. + - Task scheduling. 
+ - Reference counting. + - Garbage collection. + - Continuation handling and scheduling. + - Error handling. + - Responding callbacks. + + It borrows some design of event loop in asyncio, + e.g., 'run_until_complete'. + + Args: + state: The initial state of the workflow. + """ + self._state = state + self._completion_queue = asyncio.Queue() + self._task_done_callbacks: Dict[TaskID, List[asyncio.Future]] = defaultdict( + list + ) + + def is_running(self) -> bool: + """The state is running, if there are tasks to be run or running tasks.""" + return bool(self._state.frontier_to_run or self._state.running_frontier) + + def get_state(self) -> WorkflowExecutionState: + return self._state + + @property + def output_task_id(self) -> TaskID: + return self._state.output_task_id + + async def run_until_complete( + self, job_id: str, context: "WorkflowTaskContext", wf_store: "WorkflowStorage" + ): + """Drive the state util it completes. + + Args: + job_id: The Ray JobID for logging properly. + context: The context of workflow execution. + wf_store: The store for the workflow. + + # TODO(suquark): move job_id inside context + """ + workflow_id = context.workflow_id + wf_store.update_workflow_status(WorkflowStatus.RUNNING) + logger.info(f"Workflow job [id={workflow_id}] started.") + + self._state.construct_scheduling_plan(self._state.output_task_id) + self._state.init_context(context) + + while self.is_running(): + # ------------ poll queued tasks ------------ + queued_tasks = self._poll_queued_tasks() + + # --------------- submit task --------------- + for task_id in queued_tasks: + # '_submit_ray_task' submit a Ray task based on the workflow task. + self._submit_ray_task(task_id, job_id=job_id) + # '_post_process_submit_task' updates the state related to task + # submission. 
+ self._post_process_submit_task(task_id, wf_store) + + self._garbage_collect() + + # ------------ poll ready tasks ------------ + ready_futures = await self._poll_ready_tasks() + + # ----------- handle ready tasks ----------- + await asyncio.gather( + *[ + self._handle_ready_task( + fut, workflow_id=workflow_id, wf_store=wf_store + ) + for fut in ready_futures + ] + ) + + # prevent leaking ObjectRefs into the next iteration + del ready_futures + + wf_store.update_workflow_status(WorkflowStatus.SUCCESSFUL) + logger.info(f"Workflow '{workflow_id}' completes successfully.") + + # set errors for pending workflow outputs + for task_id, futures in self._task_done_callbacks.items(): + err = ValueError( + f"The workflow haven't yet produced output of task '{task_id}' " + f"after workflow execution completes." + ) + for fut in futures: + if not fut.done(): + fut.set_exception(err) + + def cancel(self) -> None: + """Cancel the running workflow.""" + for fut, workflow_ref in self._state.running_frontier.items(): + fut.cancel() + try: + ray.cancel(workflow_ref.ref, force=True) + except Exception: + pass + + def _poll_queued_tasks(self) -> List[TaskID]: + tasks = [] + while True: + task_id = self._state.pop_frontier_to_run() + if task_id is None: + break + tasks.append(task_id) + return tasks + + def _submit_ray_task(self, task_id: TaskID, job_id: str) -> None: + """Submit a workflow task as a Ray task.""" + state = self._state + baked_inputs = _BakedWorkflowInputs( + args=state.task_input_args[task_id], + workflow_refs=[ + state.get_input(d) for d in state.upstream_dependencies[task_id] + ], + ) + task = state.tasks[task_id] + executor = get_task_executor(task.options) + metadata_ref, output_ref = executor( + task.func_body, + state.task_context[task_id], + job_id, + task_id, + baked_inputs, + task.options, + ) + # The input workflow is not a reference to an executed workflow. 
+ future = asyncio.wrap_future(metadata_ref.future()) + future.add_done_callback(self._completion_queue.put_nowait) + + state.insert_running_frontier(future, WorkflowRef(task_id, ref=output_ref)) + state.task_execution_metadata[task_id] = TaskExecutionMetadata( + submit_time=time.time() + ) + + def _post_process_submit_task( + self, task_id: TaskID, store: "WorkflowStorage" + ) -> None: + """Update dependencies and reference count etc. after task submission.""" + state = self._state + if task_id in state.continuation_root: + if state.tasks[task_id].options.checkpoint: + store.update_continuation_output_link( + state.continuation_root[task_id], task_id + ) + else: + # update reference counting if the task is not a continuation + for c in state.upstream_dependencies[task_id]: + state.reference_set[c].remove(task_id) + if not state.reference_set[c]: + del state.reference_set[c] + state.free_outputs.add(c) + + def _garbage_collect(self) -> None: + """Garbage collect the output refs of tasks. + + Currently, this is done after task submission, because when a task + starts, we no longer needs its inputs (i.e. outputs from other tasks). + + # TODO(suquark): We may need to improve garbage collection + # when taking more fault tolerant cases into consideration. 
+ """ + state = self._state + while state.free_outputs: + # garbage collect all free outputs immediately + gc_task_id = state.free_outputs.pop() + assert state.get_input(gc_task_id) is not None + state.output_map.pop(gc_task_id, None) + + async def _poll_ready_tasks(self) -> List[asyncio.Future]: + cq = self._completion_queue + ready_futures = [] + rf = await cq.get() + ready_futures.append(rf) + # get all remaining futures in the queue + while not cq.empty(): + ready_futures.append(cq.get_nowait()) + return ready_futures + + def _iter_callstack(self, task_id: TaskID) -> Iterator[Tuple[TaskID, Task]]: + state = self._state + while task_id in state.task_context and task_id in state.tasks: + yield task_id, state.tasks[task_id] + task_id = state.task_context[task_id].creator_task_id + + def _retry_failed_task( + self, workflow_id: str, failed_task_id: TaskID, exc: Exception + ) -> bool: + state = self._state + is_application_error = isinstance(exc, RayTaskError) + options = state.tasks[failed_task_id].options + if not is_application_error or options.retry_exceptions: + if state.task_retries[failed_task_id] < options.max_retries: + state.task_retries[failed_task_id] += 1 + logger.info( + f"Retry [{workflow_id}@{failed_task_id}] " + f"({state.task_retries[failed_task_id]}/{options.max_retries})" + ) + state.construct_scheduling_plan(failed_task_id) + return True + return False + + async def _catch_failed_task( + self, workflow_id: str, failed_task_id: TaskID, exc: Exception + ) -> bool: + # lookup a creator task that catches the exception + is_application_error = isinstance(exc, RayTaskError) + exception_catcher = None + if is_application_error: + for t, task in self._iter_callstack(failed_task_id): + if task.options.catch_exceptions: + exception_catcher = t + break + if exception_catcher is not None: + logger.info( + f"Exception raised by '{workflow_id}@{failed_task_id}' is caught by " + f"'{workflow_id}@{exception_catcher}'" + ) + # assign output to exception catching 
task; + # compose output with caught exception + await self._post_process_ready_task( + exception_catcher, + metadata=WorkflowExecutionMetadata(), + output_ref=WorkflowRef(failed_task_id, ray.put((None, exc))), + ) + # TODO(suquark): cancel other running tasks? + return True + return False + + async def _handle_ready_task( + self, fut: asyncio.Future, workflow_id: str, wf_store: "WorkflowStorage" + ) -> None: + """Handle ready task, especially about its exception.""" + state = self._state + output_ref = state.pop_running_frontier(fut) + task_id = output_ref.task_id + try: + metadata: WorkflowExecutionMetadata = fut.result() + state.task_execution_metadata[task_id].finish_time = time.time() + logger.info( + f"Task status [{WorkflowStatus.SUCCESSFUL.value}]\t" + f"[{workflow_id}@{task_id}]" + ) + await self._post_process_ready_task(task_id, metadata, output_ref) + except asyncio.CancelledError: + # NOTE: We must update the workflow status before broadcasting + # the exception. Otherwise, the workflow status would still be + # 'RUNNING' if check the status immediately after cancellation. 
+ wf_store.update_workflow_status(WorkflowStatus.CANCELED) + logger.warning(f"Workflow '{workflow_id}' is cancelled.") + # broadcasting cancellation to all outputs + err = WorkflowCancellationError(workflow_id) + self._broadcast_exception(err) + raise err from None + except Exception as e: + if isinstance(e, RayTaskError): + reason = "an exception raised by the task" + elif isinstance(e, RayError): + reason = "a system error" + else: + reason = "an unknown error" + logger.error( + f"Task status [{WorkflowStatus.FAILED.value}] due to {reason}.\t" + f"[{workflow_id}@{task_id}]" + ) + + is_application_error = isinstance(e, RayTaskError) + options = state.tasks[task_id].options + + # ---------------------- retry the task ---------------------- + if not is_application_error or options.retry_exceptions: + if state.task_retries[task_id] < options.max_retries: + state.task_retries[task_id] += 1 + logger.info( + f"Retry [{workflow_id}@{task_id}] " + f"({state.task_retries[task_id]}/{options.max_retries})" + ) + state.construct_scheduling_plan(task_id) + return + + # ----------- retry used up, handle the task error ----------- + exception_catcher = None + if is_application_error: + for t, task in self._iter_callstack(task_id): + if task.options.catch_exceptions: + exception_catcher = t + break + if exception_catcher is not None: + logger.info( + f"Exception raised by '{workflow_id}@{task_id}' is caught by " + f"'{workflow_id}@{exception_catcher}'" + ) + # assign output to exception catching task; + # compose output with caught exception + await self._post_process_ready_task( + exception_catcher, + metadata=WorkflowExecutionMetadata(), + output_ref=WorkflowRef(task_id, ray.put((None, e))), + ) + # TODO(suquark): cancel other running tasks? + return + + # ------------------- raise the task error ------------------- + # NOTE: We must update the workflow status before broadcasting + # the exception. 
Otherwise, the workflow status would still be + # 'RUNNING' if check the status immediately after the exception. + wf_store.update_workflow_status(WorkflowStatus.FAILED) + logger.error(f"Workflow '{workflow_id}' failed due to {e}") + err = WorkflowExecutionError(workflow_id) + err.__cause__ = e # chain exceptions + self._broadcast_exception(err) + raise err + + async def _post_process_ready_task( + self, + task_id: TaskID, + metadata: WorkflowExecutionMetadata, + output_ref: WorkflowRef, + ) -> None: + state = self._state + state.task_retries.pop(task_id, None) + if metadata.is_output_workflow: # The task returns a continuation + sub_workflow_state: WorkflowExecutionState = await output_ref.ref + # init the context just for "sub_workflow_state" + sub_workflow_state.init_context(state.task_context[task_id]) + state.merge_state(sub_workflow_state) + # build up runtime dependency + continuation_task_id = sub_workflow_state.output_task_id + state.append_continuation(task_id, continuation_task_id) + # Migrate callbacks - all continuation callbacks are moved + # under the root of continuation, so when the continuation + # completes, all callbacks in the continuation can be triggered. + if continuation_task_id in self._task_done_callbacks: + self._task_done_callbacks[ + state.continuation_root[continuation_task_id] + ].extend(self._task_done_callbacks.pop(continuation_task_id)) + state.construct_scheduling_plan(sub_workflow_state.output_task_id) + else: # The task returns a normal object + target_task_id = state.continuation_root.get(task_id, task_id) + state.output_map[target_task_id] = output_ref + if state.tasks[task_id].options.checkpoint: + state.checkpoint_map[target_task_id] = WorkflowRef(task_id) + state.done_tasks.add(target_task_id) + # TODO(suquark): cleanup callbacks when a result is set? 
+ if target_task_id in self._task_done_callbacks: + for callback in self._task_done_callbacks[target_task_id]: + callback.set_result(output_ref) + for m in state.reference_set[target_task_id]: + # we ensure that each reference corresponds to a pending input + state.pending_input_set[m].remove(target_task_id) + if not state.pending_input_set[m]: + state.append_frontier_to_run(m) + + def _broadcast_exception(self, err: Exception): + for _, futures in self._task_done_callbacks.items(): + for fut in futures: + if not fut.done(): + fut.set_exception(err) + + def get_task_output_async(self, task_id: Optional[TaskID]) -> asyncio.Future: + """Get the output of a task asynchronously. + + Args: + task_id: The ID of task the callback associates with. + + Returns: + A callback in the form of a future that associates with the task. + """ + state = self._state + if self._task_done_callbacks[task_id]: + return self._task_done_callbacks[task_id][0] + + fut = asyncio.Future() + task_id = state.continuation_root.get(task_id, task_id) + output = state.get_input(task_id) + if output is not None: + fut.set_result(output) + elif task_id in state.done_tasks: + fut.set_exception( + ValueError( + f"Task '{task_id}' is done but neither in memory or in storage " + "could we find its output. It could because its in memory " + "output has been garbage collected and the task did not" + "checkpoint its output." 
+ ) + ) + else: + self._task_done_callbacks[task_id].append(fut) + return fut diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/workflow_state.py b/.venv/lib/python3.11/site-packages/ray/workflow/workflow_state.py new file mode 100644 index 0000000000000000000000000000000000000000..19a7cfad3bfef91ff37e8c38caea602fe323a56e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/workflow/workflow_state.py @@ -0,0 +1,251 @@ +import asyncio + +from collections import deque, defaultdict +import dataclasses +from dataclasses import field +import logging +from typing import List, Dict, Optional, Set, Deque, Callable + +import ray +from ray.workflow.common import ( + TaskID, + WorkflowRef, + WorkflowTaskRuntimeOptions, +) +from ray.workflow.workflow_context import WorkflowTaskContext + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class TaskExecutionMetadata: + submit_time: Optional[float] = None + finish_time: Optional[float] = None + output_size: Optional[int] = None + + @property + def duration(self): + return self.finish_time - self.submit_time + + +@dataclasses.dataclass +class Task: + """Data class for a workflow task.""" + + task_id: str + options: WorkflowTaskRuntimeOptions + user_metadata: Dict + func_body: Optional[Callable] + + def to_dict(self) -> Dict: + return { + "task_id": self.task_id, + "task_options": self.options.to_dict(), + "user_metadata": self.user_metadata, + } + + +@dataclasses.dataclass +class WorkflowExecutionState: + """The execution state of a workflow. This dataclass helps with observation + and debugging.""" + + # -------------------------------- dependencies -------------------------------- # + + # The mapping from all tasks to immediately upstream tasks. + upstream_dependencies: Dict[TaskID, List[TaskID]] = field(default_factory=dict) + # A reverse mapping of the above. The dependency mapping from tasks to + # immediately downstream tasks. 
+ downstream_dependencies: Dict[TaskID, List[TaskID]] = field( + default_factory=lambda: defaultdict(list) + ) + # The mapping from a task to its immediate continuation. + next_continuation: Dict[TaskID, TaskID] = field(default_factory=dict) + # The reversed mapping from continuation to its immediate task. + prev_continuation: Dict[TaskID, TaskID] = field(default_factory=dict) + # The mapping from a task to its latest continuation. The latest continuation is + # a task that returns a value instead of a continuation. + latest_continuation: Dict[TaskID, TaskID] = field(default_factory=dict) + # The mapping from a task to the root of the continuation, i.e. the initial task + # that generates the lineage of continuation. + continuation_root: Dict[TaskID, TaskID] = field(default_factory=dict) + + # ------------------------------- task properties ------------------------------- # + + # Workflow tasks. + tasks: Dict[TaskID, Task] = field(default_factory=dict) + + # The arguments for the task. + task_input_args: Dict[TaskID, ray.ObjectRef] = field(default_factory=dict) + # The context of the task. + task_context: Dict[TaskID, WorkflowTaskContext] = field(default_factory=dict) + # The execution metadata of a task. + task_execution_metadata: Dict[TaskID, TaskExecutionMetadata] = field( + default_factory=dict + ) + task_retries: Dict[TaskID, int] = field(default_factory=lambda: defaultdict(int)) + + # ------------------------------ object management ------------------------------ # + + # Set of references to upstream outputs. + reference_set: Dict[TaskID, Set[TaskID]] = field( + default_factory=lambda: defaultdict(set) + ) + # The set of pending inputs of a task. We are able to run the task + # when it becomes empty. + pending_input_set: Dict[TaskID, Set[TaskID]] = field(default_factory=dict) + # The map from a task to its in-memory outputs. Normally it is the ObjectRef + # returned by the underlying Ray task. 
Things are different for continuation: + # because the true output of a continuation is created by the last task in + # the continuation lineage, so all other tasks in the continuation points + # to the output of the last task instead of the output of themselves. + output_map: Dict[TaskID, WorkflowRef] = field(default_factory=dict) + # The map from a task to its in-storage checkpoints. Normally it is the checkpoint + # created by the underlying Ray task. For continuations, the semantics is similar + # to 'output_map'. + checkpoint_map: Dict[TaskID, WorkflowRef] = field(default_factory=dict) + # Outputs that are free (no reference to this output in the workflow) and + # can be garbage collected. + free_outputs: Set[TaskID] = field(default_factory=set) + + # -------------------------------- scheduling -------------------------------- # + + # The frontier that is ready to run. + frontier_to_run: Deque[TaskID] = field(default_factory=deque) + # The set of frontier tasks to run. This field helps deduplicate tasks or + # look up task quickly. It contains the same elements as 'frontier_to_run', + # they act like a 'DequeSet' when combined. + frontier_to_run_set: Set[TaskID] = field(default_factory=set) + # The frontier that is running. + running_frontier: Dict[asyncio.Future, WorkflowRef] = field(default_factory=dict) + # The set of running frontier. This field helps deduplicate tasks or + # look up task quickly. It contains the same elements as 'running_frontier', + # they act like a dict but its values are in a set when combined. + running_frontier_set: Set[TaskID] = field(default_factory=set) + # The set of completed tasks. They are tasks are actually executed with the state, + # so inspected during recovery does not count. + # + # Normally, a task will be added in 'done_tasks' immediately after its completion. + # However, a task that is the root of continuations (i.e. 
it returns a continuation + # but itself is not a continuation) is only added to 'done_tasks' when all its + # continuation completes. We do not add its continuations in 'done_tasks' because + # we indicate their completion from the continuation structure - if a continuation + # is appended to a previous continuation, then the previous continuation must + # already complete; if the task that is the root of all continuation completes, + # then all its continuations would complete. + done_tasks: Set[TaskID] = field(default_factory=set) + + # -------------------------------- external -------------------------------- # + + # The ID of the output task. + output_task_id: Optional[TaskID] = None + + def get_input(self, task_id: TaskID) -> Optional[WorkflowRef]: + """Get the input. It checks memory first and storage later. It returns None if + the input does not exist. + """ + return self.output_map.get(task_id, self.checkpoint_map.get(task_id)) + + def pop_frontier_to_run(self) -> Optional[TaskID]: + """Pop one task to run from the frontier queue.""" + try: + t = self.frontier_to_run.popleft() + self.frontier_to_run_set.remove(t) + return t + except IndexError: + return None + + def append_frontier_to_run(self, task_id: TaskID) -> None: + """Insert one task to the frontier queue.""" + if ( + task_id not in self.frontier_to_run_set + and task_id not in self.running_frontier_set + ): + self.frontier_to_run.append(task_id) + self.frontier_to_run_set.add(task_id) + + def add_dependencies(self, task_id: TaskID, in_dependencies: List[TaskID]) -> None: + """Add dependencies between a task and it input dependencies.""" + self.upstream_dependencies[task_id] = in_dependencies + for in_task_id in in_dependencies: + self.downstream_dependencies[in_task_id].append(task_id) + + def pop_running_frontier(self, fut: asyncio.Future) -> WorkflowRef: + """Pop a task from the running frontier.""" + ref = self.running_frontier.pop(fut) + self.running_frontier_set.remove(ref.task_id) + return 
ref + + def insert_running_frontier(self, fut: asyncio.Future, ref: WorkflowRef) -> None: + """Insert a task to the running frontier.""" + self.running_frontier[fut] = ref + self.running_frontier_set.add(ref.task_id) + + def append_continuation( + self, task_id: TaskID, continuation_task_id: TaskID + ) -> None: + """Append continuation to a task.""" + continuation_root = self.continuation_root.get(task_id, task_id) + self.prev_continuation[continuation_task_id] = task_id + self.next_continuation[task_id] = continuation_task_id + self.continuation_root[continuation_task_id] = continuation_root + self.latest_continuation[continuation_root] = continuation_task_id + + def merge_state(self, state: "WorkflowExecutionState") -> None: + """Merge with another execution state.""" + self.upstream_dependencies.update(state.upstream_dependencies) + self.downstream_dependencies.update(state.downstream_dependencies) + self.task_input_args.update(state.task_input_args) + self.tasks.update(state.tasks) + self.task_context.update(state.task_context) + self.output_map.update(state.output_map) + self.checkpoint_map.update(state.checkpoint_map) + + def construct_scheduling_plan(self, task_id: TaskID) -> None: + """Analyze upstream dependencies of a task to construct the scheduling plan.""" + if self.get_input(task_id) is not None: + # This case corresponds to the scenario that the task is a + # checkpoint or ref. + return + + visited_nodes = set() + dag_visit_queue = deque([task_id]) + while dag_visit_queue: + tid = dag_visit_queue.popleft() + if tid in visited_nodes: + continue + visited_nodes.add(tid) + self.pending_input_set[tid] = set() + for in_task_id in self.upstream_dependencies[tid]: + self.reference_set[in_task_id].add(tid) + # All upstream deps should already complete here, + # so we just check their checkpoints. 
+ task_input = self.get_input(in_task_id) + if task_input is None: + self.pending_input_set[tid].add(in_task_id) + dag_visit_queue.append(in_task_id) + if tid in self.latest_continuation: + if self.pending_input_set[tid]: + raise ValueError( + "A task that already returns a continuation cannot be pending." + ) + # construct continuations, as they are not directly connected to + # the DAG dependency + self.construct_scheduling_plan(self.latest_continuation[tid]) + elif not self.pending_input_set[tid]: + self.append_frontier_to_run(tid) + + def init_context(self, context: WorkflowTaskContext) -> None: + """Initialize the context of all tasks.""" + for task_id, task in self.tasks.items(): + options = task.options + self.task_context.setdefault( + task_id, + dataclasses.replace( + context, + task_id=task_id, + creator_task_id=context.task_id, + checkpoint=options.checkpoint, + catch_exceptions=options.catch_exceptions, + ), + ) diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/workflow_state_from_dag.py b/.venv/lib/python3.11/site-packages/ray/workflow/workflow_state_from_dag.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f39ad6dc9d8ae64b97a4b739258e146da61e0f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/workflow/workflow_state_from_dag.py @@ -0,0 +1,205 @@ +from typing import Any, List, Optional +import re +import unicodedata + +import ray +from ray.workflow.common import WORKFLOW_OPTIONS +from ray.dag import DAGNode, FunctionNode, InputNode +from ray.dag.input_node import InputAttributeNode, DAGInputData +from ray import cloudpickle +from ray._private import signature +from ray._private.client_mode_hook import client_mode_should_convert +from ray.workflow import serialization_context +from ray.workflow.common import ( + TaskType, + WorkflowTaskRuntimeOptions, + WorkflowRef, + validate_user_metadata, +) +from ray.workflow import workflow_context +from ray.workflow.workflow_state import WorkflowExecutionState, Task + 
+ +def get_module(f): + return f.__module__ if hasattr(f, "__module__") else "__anonymous_module__" + + +def get_qualname(f): + return f.__qualname__ if hasattr(f, "__qualname__") else "__anonymous_func__" + + +def slugify(value: str, allow_unicode=False) -> str: + """Adopted from + https://github.com/django/django/blob/master/django/utils/text.py + Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated + dashes to single dashes. Remove characters that aren't alphanumerics, + underscores, dots or hyphens. Also strip leading and + trailing whitespace. + """ + if allow_unicode: + value = unicodedata.normalize("NFKC", value) + else: + value = ( + unicodedata.normalize("NFKD", value) + .encode("ascii", "ignore") + .decode("ascii") + ) + value = re.sub(r"[^\w.\-]", "", value).strip() + return re.sub(r"[-\s]+", "-", value) + + +class _DelayedDeserialization: + def __init__(self, serialized: bytes): + self._serialized = serialized + + def __reduce__(self): + return cloudpickle.loads, (self._serialized,) + + +class _SerializationContextPreservingWrapper: + """This class is a workaround for preserving serialization context + in client mode.""" + + def __init__(self, obj: Any): + self._serialized = cloudpickle.dumps(obj) + + def __reduce__(self): + # This delays the deserialization to the actual worker + # instead of the Ray client server. + return _DelayedDeserialization, (self._serialized,) + + +def workflow_state_from_dag( + dag_node: DAGNode, input_context: Optional[DAGInputData], workflow_id: str +): + """ + Transform a Ray DAG to a workflow. Map FunctionNode to workflow task with + the workflow decorator. + + Args: + dag_node: The DAG to be converted to a workflow. + input_context: The input data that wraps varibles for the input node of the DAG. + workflow_id: The ID of the workflow. 
+ """ + if not isinstance(dag_node, FunctionNode): + raise TypeError("Currently workflow does not support classes as DAG inputs.") + + state = WorkflowExecutionState() + + # TODO(suquark): remove this cyclic importing later by changing the way of + # task ID assignment. + from ray.workflow.workflow_access import get_management_actor + + mgr = get_management_actor() + context = workflow_context.get_workflow_task_context() + + def _node_visitor(node: Any) -> Any: + if isinstance(node, FunctionNode): + bound_options = node._bound_options.copy() + num_returns = bound_options.get("num_returns", 1) + if num_returns is None: # ray could use `None` as default value + num_returns = 1 + if num_returns > 1: + raise ValueError("Workflow task can only have one return.") + + workflow_options = bound_options.get("_metadata", {}).get( + WORKFLOW_OPTIONS, {} + ) + + # If checkpoint option is not specified, inherit checkpoint + # options from context (i.e. checkpoint options of the outer + # task). If it is still not specified, it's True by default. + checkpoint = workflow_options.get("checkpoint", None) + if checkpoint is None: + checkpoint = context.checkpoint if context is not None else True + # When it returns a nested workflow, catch_exception + # should be passed recursively. + catch_exceptions = workflow_options.get("catch_exceptions", None) + if catch_exceptions is None: + if node.get_stable_uuid() == dag_node.get_stable_uuid(): + # 'catch_exception' context should be passed down to + # its direct continuation task. + # In this case, the direct continuation is the output node. + catch_exceptions = ( + context.catch_exceptions if context is not None else False + ) + else: + catch_exceptions = False + + # We do not need to check the validness of bound options, because + # Ray option has already checked them for us. 
+ max_retries = bound_options.get("max_retries", 3) + retry_exceptions = bound_options.get("retry_exceptions", False) + + task_options = WorkflowTaskRuntimeOptions( + task_type=TaskType.FUNCTION, + catch_exceptions=catch_exceptions, + retry_exceptions=retry_exceptions, + max_retries=max_retries, + checkpoint=checkpoint, + ray_options=bound_options, + ) + + workflow_refs: List[WorkflowRef] = [] + with serialization_context.workflow_args_serialization_context( + workflow_refs + ): + _func_signature = signature.extract_signature(node._body) + flattened_args = signature.flatten_args( + _func_signature, node._bound_args, node._bound_kwargs + ) + # NOTE: When calling 'ray.put', we trigger python object + # serialization. Under our serialization context, + # Workflows are separated from the arguments, + # leaving a placeholder object with all other python objects. + # Then we put the placeholder object to object store, + # so it won't be mutated later. This guarantees correct + # semantics. See "tests/test_variable_mutable.py" as + # an example. + if client_mode_should_convert(): + # Handle client mode. The Ray client would serialize and + # then deserialize objects in the Ray client server. When + # the object is being deserialized, the serialization context + # will be missing, resulting in failures. Here we protect the + # object from deserialization in client server, and we make sure + # the 'real' deserialization happens under the serialization + # context later. + flattened_args = _SerializationContextPreservingWrapper( + flattened_args + ) + # Set the owner of the objects to the actor so that even the driver + # exits, these objects are still available. 
+ input_placeholder: ray.ObjectRef = ray.put(flattened_args, _owner=mgr) + + orig_task_id = workflow_options.get("task_id", None) + if orig_task_id is None: + orig_task_id = ( + f"{get_module(node._body)}.{slugify(get_qualname(node._body))}" + ) + + task_id = ray.get(mgr.gen_task_id.remote(workflow_id, orig_task_id)) + state.add_dependencies(task_id, [s.task_id for s in workflow_refs]) + state.task_input_args[task_id] = input_placeholder + + user_metadata = workflow_options.get("metadata", {}) + + validate_user_metadata(user_metadata) + state.tasks[task_id] = Task( + task_id=task_id, + options=task_options, + user_metadata=user_metadata, + func_body=node._body, + ) + return WorkflowRef(task_id) + + if isinstance(node, InputAttributeNode): + return node._execute_impl() # get data from input node + if isinstance(node, InputNode): + return input_context # replace input node with input data + if not isinstance(node, DAGNode): + return node # return normal objects + raise TypeError(f"Unsupported DAG node: {node}") + + output_workflow_ref = dag_node.apply_recursive(_node_visitor) + state.output_task_id = output_workflow_ref.task_id + return state diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/workflow_state_from_storage.py b/.venv/lib/python3.11/site-packages/ray/workflow/workflow_state_from_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..a13e31283202f238cf9afc910762cfc819f89a1b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/workflow/workflow_state_from_storage.py @@ -0,0 +1,71 @@ +from typing import Optional +from collections import deque + +from ray.workflow import serialization +from ray.workflow.common import TaskID, WorkflowRef +from ray.workflow.exceptions import WorkflowTaskNotRecoverableError +from ray.workflow import workflow_storage +from ray.workflow.workflow_state import WorkflowExecutionState, Task + + +def workflow_state_from_storage( + workflow_id: str, task_id: Optional[TaskID] +) -> 
WorkflowExecutionState: + """Try to construct a workflow (task) that recovers the workflow task. + If the workflow task already has an output checkpointing file, we return + the workflow task id instead. + + Args: + workflow_id: The ID of the workflow. + task_id: The ID of the output task. If None, it will be the entrypoint of + the workflow. + + Returns: + A workflow that recovers the task, or the output of the task + if it has been checkpointed. + """ + reader = workflow_storage.WorkflowStorage(workflow_id) + if task_id is None: + task_id = reader.get_entrypoint_task_id() + + # Construct the workflow execution state. + state = WorkflowExecutionState(output_task_id=task_id) + state.output_task_id = task_id + + visited_tasks = set() + dag_visit_queue = deque([task_id]) + with serialization.objectref_cache(): + while dag_visit_queue: + task_id: TaskID = dag_visit_queue.popleft() + if task_id in visited_tasks: + continue + visited_tasks.add(task_id) + r = reader.inspect_task(task_id) + if not r.is_recoverable(): + raise WorkflowTaskNotRecoverableError(task_id) + if r.output_object_valid: + target = state.continuation_root.get(task_id, task_id) + state.checkpoint_map[target] = WorkflowRef(task_id) + continue + if isinstance(r.output_task_id, str): + # no input dependencies here because the task has already + # returned a continuation + state.upstream_dependencies[task_id] = [] + state.append_continuation(task_id, r.output_task_id) + dag_visit_queue.append(r.output_task_id) + continue + # transfer task info to state + state.add_dependencies(task_id, r.workflow_refs) + state.task_input_args[task_id] = reader.load_task_args(task_id) + # TODO(suquark): although not necessary, but for completeness, + # we may also load name and metadata. 
+ state.tasks[task_id] = Task( + task_id="", + options=r.task_options, + user_metadata={}, + func_body=reader.load_task_func_body(task_id), + ) + + dag_visit_queue.extend(r.workflow_refs) + + return state diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/workflow_storage.py b/.venv/lib/python3.11/site-packages/ray/workflow/workflow_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..ff73d17c47e2696134e3c0842a570d5f013eb8bb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/workflow/workflow_storage.py @@ -0,0 +1,880 @@ +""" +This module is higher-level abstraction of storage directly used by +workflows. +""" + +import json +import logging +import os +import time +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Set, Tuple + +import ray +from ray import cloudpickle +from ray._private import storage +from ray.types import ObjectRef +from ray.workflow.common import ( + TaskID, + WorkflowStatus, + WorkflowTaskRuntimeOptions, +) +from ray.workflow.exceptions import WorkflowNotFoundError +from ray.workflow import workflow_context +from ray.workflow import serialization +from ray.workflow import serialization_context +from ray.workflow.workflow_state import WorkflowExecutionState +from ray.workflow.storage import DataLoadError, DataSaveError, KeyNotFoundError + +logger = logging.getLogger(__name__) + +ArgsType = Tuple[List[Any], Dict[str, Any]] # args and kwargs + +# constants used for keys +WORKFLOW_ROOT = "workflows" # The workflow root directory under global Ray storage. 
# File-layout constants: names of the per-task / per-workflow checkpoint
# files and the status-index directories under the workflow storage root.
OBJECTS_DIR = "objects"
STEPS_DIR = "tasks"
STEP_INPUTS_METADATA = "inputs.json"
STEP_USER_METADATA = "user_task_metadata.json"
STEP_PRERUN_METADATA = "pre_task_metadata.json"
STEP_POSTRUN_METADATA = "post_task_metadata.json"
STEP_OUTPUTS_METADATA = "outputs.json"
STEP_ARGS = "args.pkl"
STEP_OUTPUT = "output.pkl"
STEP_EXCEPTION = "exception.pkl"
STEP_FUNC_BODY = "func_body.pkl"
CLASS_BODY = "class_body.pkl"
WORKFLOW_META = "workflow_meta.json"
WORKFLOW_USER_METADATA = "user_run_metadata.json"
WORKFLOW_PRERUN_METADATA = "pre_run_metadata.json"
WORKFLOW_POSTRUN_METADATA = "post_run_metadata.json"
WORKFLOW_PROGRESS = "progress.json"
WORKFLOW_STATUS_DIR = "__status__"
WORKFLOW_STATUS_DIRTY_DIR = "dirty"
# Without this counter, we're going to scan all tasks to get the number of
# tasks with a given name. This can be very expensive if there are too
# many duplicates.
DUPLICATE_NAME_COUNTER = "duplicate_name_counter"


@dataclass
class TaskInspectResult:
    """Result of inspecting the checkpoint state of one workflow task."""

    # The task output checkpoint exists and is valid. If this field
    # is set, we do not set all other fields below.
    output_object_valid: bool = False
    # The ID of the task that could contain the output checkpoint of this
    # task. If this field is set, we do not set all other fields below.
    output_task_id: Optional[TaskID] = None
    # The task input arguments checkpoint exists and is valid.
    args_valid: bool = False
    # The task function body checkpoint exists and is valid.
    func_body_valid: bool = False
    # The dynamically referenced workflows in the input of the workflow.
    workflow_refs: Optional[List[str]] = None
    # The options of the workflow task.
    task_options: Optional[WorkflowTaskRuntimeOptions] = None
    # The task raised an exception (an exception checkpoint exists).
    task_raised_exception: bool = False

    def is_recoverable(self) -> bool:
        """Return True if the task can be recovered from its checkpoints.

        A task is recoverable if its output checkpoint exists, if another
        task holds its output, or if its inputs (args, function body and
        workflow references) are all checkpointed.
        """
        # Wrap in bool(): 'output_task_id' is a string, so the bare 'or'
        # chain would otherwise leak a TaskID (or None) to callers even
        # though the signature promises a bool.
        return bool(
            self.output_object_valid
            or self.output_task_id
            or (
                self.args_valid
                and self.workflow_refs is not None
                and self.func_body_valid
            )
        )


class WorkflowIndexingStorage:
    """Access and maintain the indexing of workflow status.

    It runs a protocol that guarantees we can recover from any interrupted
    status updating. This protocol is **not thread-safe** for updating the
    status of the same workflow; currently it is executed by the workflow
    management actor with a single thread.

    Here is how the protocol works:

    Update the status of a workflow
    1. Load workflow status from workflow data. If it is the same as the new
       status, return.
    2. Check if the workflow status updating is dirty. If it is, fix the
       workflow status; otherwise, mark the workflow status updating dirty.
    3. Update status in the workflow metadata.
    4. Insert the workflow ID key in the status indexing directory of the new
       status.
    5. Delete the workflow ID key in the status indexing directory of
       the previous status.
    6. Remove the workflow status updating dirty mark.

    Load a status of a workflow
    1. Read the status of the workflow from the workflow metadata.
    2. Return the status.

    List the status of all workflows
    1. Get status of all workflows by listing workflow ID keys in each
       workflow status indexing directory.
    2. List all workflows with dirty updating status. Get their status from
       workflow data. Override the status of the corresponding workflow.
    3. Return all the status.
    """

    def __init__(self):
        self._storage = storage.get_client(WORKFLOW_ROOT)

    def update_workflow_status(self, workflow_id: str, status: WorkflowStatus):
        """Update the status of the workflow.
        Try fixing indexing if workflow status updating was marked dirty.

        This method is NOT thread-safe. It is handled by the workflow
        management actor.

        Args:
            workflow_id: ID of the workflow.
            status: The new status of the workflow.
        """
        prev_status = self.load_workflow_status(workflow_id)
        if prev_status != status:
            # Try fixing indexing if workflow status updating was marked dirty.
            if (
                self._storage.get_info(self._key_workflow_status_dirty(workflow_id))
                is not None
            ):
                # This means the previous status update failed. Fix it:
                # restore the index entry of the committed status and drop
                # any other (partially written) index entries.
                self._storage.put(
                    self._key_workflow_with_status(workflow_id, prev_status), b""
                )
                for s in WorkflowStatus:
                    if s != prev_status:
                        self._storage.delete(
                            self._key_workflow_with_status(workflow_id, s)
                        )
            else:
                self._storage.put(self._key_workflow_status_dirty(workflow_id), b"")
            # Transactional update of workflow status
            self._storage.put(
                self._key_workflow_metadata(workflow_id),
                json.dumps({"status": status.value}).encode(),
            )
            self._storage.put(self._key_workflow_with_status(workflow_id, status), b"")
            if prev_status is not WorkflowStatus.NONE:
                self._storage.delete(
                    self._key_workflow_with_status(workflow_id, prev_status)
                )
            self._storage.delete(self._key_workflow_status_dirty(workflow_id))

    def load_workflow_status(self, workflow_id: str) -> WorkflowStatus:
        """Load the committed workflow status.

        Returns:
            The committed status, or WorkflowStatus.NONE if the workflow
            metadata does not exist.
        """
        raw_data = self._storage.get(self._key_workflow_metadata(workflow_id))
        if raw_data is not None:
            metadata = json.loads(raw_data)
            return WorkflowStatus(metadata["status"])
        return WorkflowStatus.NONE

    def list_workflow(
        self, status_filter: Optional[Set[WorkflowStatus]] = None
    ) -> List[Tuple[str, WorkflowStatus]]:
        """List workflow status. Override status of the workflows whose status
        updating were marked dirty with the workflow status from workflow
        metadata.

        Args:
            status_filter: If given, only returns workflows with that status.
                This can be a single status or set of statuses.

        Raises:
            TypeError: If the filter is neither None nor a set.
            ValueError: If the filter contains WorkflowStatus.NONE.
        """
        if status_filter is None:
            status_filter = set(WorkflowStatus)
            status_filter.discard(WorkflowStatus.NONE)
        elif not isinstance(status_filter, set):
            raise TypeError("'status_filter' should either be 'None' or a set.")
        elif WorkflowStatus.NONE in status_filter:
            raise ValueError("'WorkflowStatus.NONE' is not a valid filter value.")

        results = {}
        for status in status_filter:
            try:
                # empty string points the key to the dir
                for p in self._storage.list(self._key_workflow_with_status("", status)):
                    workflow_id = p.base_name
                    results[workflow_id] = status
            except FileNotFoundError:
                pass
        # Get "correct" status of workflows whose status update was
        # interrupted (marked dirty): trust the committed metadata.
        try:
            for p in self._storage.list(self._key_workflow_status_dirty("")):
                workflow_id = p.base_name
                # overwrite status
                results.pop(workflow_id, None)
                status = self.load_workflow_status(workflow_id)
                if status in status_filter:
                    results[workflow_id] = status
        except FileNotFoundError:
            pass
        return list(results.items())

    def delete_workflow_status(self, workflow_id: str):
        """Delete status indexing (and the dirty mark) for the workflow."""
        for status in WorkflowStatus:
            self._storage.delete(self._key_workflow_with_status(workflow_id, status))
        self._storage.delete(self._key_workflow_status_dirty(workflow_id))

    def _key_workflow_with_status(self, workflow_id: str, status: WorkflowStatus):
        """A key whose existence marks the status of the workflow."""
        return os.path.join(WORKFLOW_STATUS_DIR, status.value, workflow_id)

    def _key_workflow_status_dirty(self, workflow_id: str):
        """A key that marks the workflow status dirty while it is under change."""
        return os.path.join(WORKFLOW_STATUS_DIR, WORKFLOW_STATUS_DIRTY_DIR, workflow_id)

    def _key_workflow_metadata(self, workflow_id: str):
        return os.path.join(workflow_id, WORKFLOW_META)


class WorkflowStorage:
    """Access workflow in storage. This is a higher-level abstraction,
    which does not care about the underlining storage implementation."""

    def __init__(self, workflow_id: str):
        self._storage = storage.get_client(os.path.join(WORKFLOW_ROOT, workflow_id))
        self._status_storage = WorkflowIndexingStorage()
        self._workflow_id = workflow_id

    def load_task_output(self, task_id: TaskID) -> Any:
        """Load the output of the workflow task from checkpoint.

        Args:
            task_id: ID of the workflow task.

        Returns:
            Output of the workflow task.

        Raises:
            The checkpointed exception if the task raised one; a load error
            if neither an output nor an exception checkpoint exists.
        """
        tasks = [
            self._get(self._key_task_output(task_id), no_exception=True),
            self._get(self._key_task_exception(task_id), no_exception=True),
        ]
        (output_ret, output_err), (exception_ret, exception_err) = tasks
        # When we have output, always return output first
        if output_err is None:
            return output_ret

        # When we don't have output, check exception
        if exception_err is None:
            raise exception_ret

        # In this case, there is no such task
        raise output_err

    def save_workflow_execution_state(
        self, creator_task_id: TaskID, state: WorkflowExecutionState
    ) -> None:
        """Save a workflow execution state.
        Typically, the state is translated from a Ray DAG.

        Args:
            creator_task_id: The ID of the task that creates the state.
            state: The state converted from the DAG.
        """
        assert creator_task_id != state.output_task_id

        for task_id, task in state.tasks.items():
            # TODO (Alex): Handle the json case better?
            metadata = {
                **task.to_dict(),
                "workflow_refs": state.upstream_dependencies[task_id],
            }
            self._put(self._key_task_input_metadata(task_id), metadata, True)
            # TODO(suquark): The task user metadata duplicates.
            self._put(
                self._key_task_user_metadata(task_id),
                task.user_metadata,
                True,
            )
            workflow_id = self._workflow_id
            serialization.dump_to_storage(
                self._key_task_function_body(task_id),
                task.func_body,
                workflow_id,
                self,
            )
            with serialization_context.workflow_args_keeping_context():
                # TODO(suquark): in the future we should write to storage directly
                # with plasma store object in memory.
                args_obj = ray.get(state.task_input_args[task_id])
                serialization.dump_to_storage(
                    self._key_task_args(task_id),
                    args_obj,
                    workflow_id,
                    self,
                )

        # Finally, point to the output ID of the DAG. The DAG is a continuation
        # of the creator task.
        self._put(
            self._key_task_output_metadata(creator_task_id),
            {"output_task_id": state.output_task_id},
            True,
        )

    def save_task_output(
        self,
        task_id: TaskID,
        ret: Any,
        *,
        exception: Optional[Exception],
    ) -> None:
        """When a workflow task returns,
        1. If the returned object is a workflow, this means we are a nested
           workflow. We save the output metadata that points to the workflow.
        2. Otherwise, checkpoint the output.

        Args:
            task_id: The ID of the workflow task. If it is an empty string,
                it means we are in the workflow job driver process.
            ret: The returned object from a workflow task.
            exception: This task should throw exception.
        """
        if exception is None:
            # This workflow task returns an object.
            ret = ray.get(ret) if isinstance(ret, ray.ObjectRef) else ret
            serialization.dump_to_storage(
                self._key_task_output(task_id),
                ret,
                self._workflow_id,
                storage=self,
            )
            # TODO (yic): Delete exception file
        else:
            assert ret is None
            serialization.dump_to_storage(
                self._key_task_exception(task_id),
                exception,
                self._workflow_id,
                storage=self,
            )
        # Finish checkpointing.
        # TODO(suquark): batching all tasks above.

    def load_task_func_body(self, task_id: TaskID) -> Callable:
        """Load the function body of the workflow task.

        Args:
            task_id: ID of the workflow task.

        Returns:
            A callable function.
        """
        return self._get(self._key_task_function_body(task_id))

    def gen_task_id(self, task_name: str) -> int:
        """Generate the next sequence number for tasks sharing 'task_name'.

        Maintains a persistent per-name counter so we do not need to scan
        all tasks to disambiguate duplicated task names.

        Returns:
            0 for the first task with this name, then 1, 2, ...
        """
        key = self._key_num_tasks_with_name(task_name)
        try:
            val = self._get(key, True)
            self._put(key, val + 1, True)
            return val + 1
        except KeyNotFoundError:
            self._put(key, 0, True)
            return 0

    def load_task_args(self, task_id: TaskID) -> ray.ObjectRef:
        """Load the input arguments of the workflow task. This must be
        done under a serialization context, otherwise the arguments would
        not be reconstructed successfully.

        Args:
            task_id: ID of the workflow task.

        Returns:
            An object ref of the input args.
        """
        with serialization_context.workflow_args_keeping_context():
            x = self._get(self._key_task_args(task_id))
            return ray.put(x)

    def save_object_ref(self, obj_ref: ray.ObjectRef) -> None:
        """Save the object ref.

        Args:
            obj_ref: The object reference.

        Returns:
            None
        """
        # Fix: _save_object_ref requires an identifier as its first argument
        # (the original call passed only obj_ref, raising TypeError). Use the
        # ref's hex ID so load_object_ref(object_id) can find it again.
        return self._save_object_ref(obj_ref.hex(), obj_ref)

    def load_object_ref(self, object_id: str) -> ray.ObjectRef:
        """Load the input object ref.

        Args:
            object_id: The hex ObjectID.

        Returns:
            The object ref.
        """
        data = self._get(self._key_obj_id(object_id))
        return _put_obj_ref.remote((data,))

    def update_continuation_output_link(
        self, continuation_root_id: TaskID, latest_continuation_task_id: TaskID
    ) -> None:
        """Update the link of the continuation output. The link points
        to the ID of the latest finished continuation task.

        Args:
            continuation_root_id: The ID of the task that returns all later
                continuations.
            latest_continuation_task_id: The ID of the latest finished
                continuation task.
        """
        try:
            metadata = self._get(
                self._key_task_output_metadata(continuation_root_id), True
            )
        except KeyNotFoundError:
            # This is because we skipped checkpointing of the
            # task [id=continuation_root_id]. Return a dummy
            # metadata instead.
            metadata = {}
        if latest_continuation_task_id != metadata.get(
            "output_task_id"
        ) and latest_continuation_task_id != metadata.get("dynamic_output_task_id"):
            metadata["dynamic_output_task_id"] = latest_continuation_task_id
            self._put(
                self._key_task_output_metadata(continuation_root_id), metadata, True
            )

    def _locate_output_task_id(self, task_id: TaskID) -> str:
        # Prefer the dynamic (continuation) output link over the static one.
        metadata = self._get(self._key_task_output_metadata(task_id), True)
        return metadata.get("dynamic_output_task_id") or metadata["output_task_id"]

    def get_entrypoint_task_id(self) -> TaskID:
        """Load the entrypoint task ID of the workflow.

        Returns:
            The ID of the entrypoint task.

        Raises:
            ValueError: If the entrypoint task ID cannot be located.
        """
        # empty TaskID represents the workflow driver
        try:
            return self._locate_output_task_id("")
        except Exception as e:
            raise ValueError(
                "Failed to get entrypoint task ID from workflow "
                f"[id={self._workflow_id}]"
            ) from e

    def _locate_output_in_storage(self, task_id: TaskID) -> Optional[TaskID]:
        # Follow the chain of output links until we reach a task that keeps
        # an actual output checkpoint.
        result = self.inspect_task(task_id)
        while isinstance(result.output_task_id, str):
            task_id = result.output_task_id
            result = self.inspect_task(result.output_task_id)
        if result.output_object_valid:
            return task_id
        return None

    def inspect_output(self, task_id: Optional[TaskID]) -> Optional[TaskID]:
        """Get the actual checkpointed output for a task, represented by the ID
        of the task that actually keeps the checkpoint.

        Raises:
            ValueError: The workflow does not exist or the workflow state is
                not valid.

        Args:
            task_id: The ID of the task we are looking for its checkpoint.
                If None, use the entrypoint task of the workflow.

        Returns:
            The ID of the task that actually keeps the checkpoint.
            'None' if the checkpoint does not exist.
        """
        status = self.load_workflow_status()
        if status == WorkflowStatus.NONE:
            raise ValueError(f"No such workflow '{self._workflow_id}'")
        if status == WorkflowStatus.CANCELED:
            raise ValueError(f"Workflow {self._workflow_id} is canceled")
        # For resumable workflow, the workflow result is not ready.
        # It has to be resumed first.
        if status == WorkflowStatus.RESUMABLE:
            raise ValueError(
                f"Workflow {self._workflow_id} is in resumable status, please resume it"
            )
        if task_id is None:
            task_id = self.get_entrypoint_task_id()
        return self._locate_output_in_storage(task_id)

    def inspect_task(self, task_id: TaskID) -> TaskInspectResult:
        """Get the status of a workflow task. The status indicates whether
        the workflow task can be recovered etc.

        Args:
            task_id: The ID of a workflow task

        Returns:
            The status of the task.
        """
        return self._inspect_task(task_id)

    def _inspect_task(self, task_id: TaskID) -> TaskInspectResult:
        items = self._scan(self._key_task_prefix(task_id), ignore_errors=True)
        keys = set(items)
        # does this task contain an output checkpoint file?
        if STEP_OUTPUT in keys:
            return TaskInspectResult(output_object_valid=True)
        # do we know where the output comes from?
        if STEP_OUTPUTS_METADATA in keys:
            output_task_id = self._locate_output_task_id(task_id)
            return TaskInspectResult(output_task_id=output_task_id)

        # read inputs metadata
        try:
            metadata = self._get(self._key_task_input_metadata(task_id), True)
            return TaskInspectResult(
                args_valid=(STEP_ARGS in keys),
                func_body_valid=(STEP_FUNC_BODY in keys),
                workflow_refs=metadata["workflow_refs"],
                task_options=WorkflowTaskRuntimeOptions.from_dict(
                    metadata["task_options"]
                ),
                task_raised_exception=(STEP_EXCEPTION in keys),
            )
        except Exception:
            # Input metadata is missing or corrupted; report what we can tell
            # from the raw keys only.
            return TaskInspectResult(
                args_valid=(STEP_ARGS in keys),
                func_body_valid=(STEP_FUNC_BODY in keys),
                task_raised_exception=(STEP_EXCEPTION in keys),
            )

    def _save_object_ref(self, identifier: str, obj_ref: ray.ObjectRef):
        data = ray.get(obj_ref)
        self._put(self._key_obj_id(identifier), data)

    def load_actor_class_body(self) -> type:
        """Load the class body of the virtual actor.

        Raises:
            DataLoadError: if we fail to load the class body.
        """
        return self._get(self._key_class_body())

    def save_actor_class_body(self, cls: type) -> None:
        """Save the class body of the virtual actor.

        Args:
            cls: The class body used by the virtual actor.

        Raises:
            DataSaveError: if we fail to save the class body.
        """
        self._put(self._key_class_body(), cls)

    def save_task_prerun_metadata(self, task_id: TaskID, metadata: Dict[str, Any]):
        """Save pre-run metadata of the current task.

        Args:
            task_id: ID of the workflow task.
            metadata: pre-run metadata of the current task.

        Raises:
            DataSaveError: if we fail to save the pre-run metadata.
        """
        self._put(self._key_task_prerun_metadata(task_id), metadata, True)

    def save_task_postrun_metadata(self, task_id: TaskID, metadata: Dict[str, Any]):
        """Save post-run metadata of the current task.

        Args:
            task_id: ID of the workflow task.
            metadata: post-run metadata of the current task.

        Raises:
            DataSaveError: if we fail to save the post-run metadata.
        """
        self._put(self._key_task_postrun_metadata(task_id), metadata, True)

    def save_workflow_user_metadata(self, metadata: Dict[str, Any]):
        """Save user metadata of the current workflow.

        Args:
            metadata: user metadata of the current workflow.

        Raises:
            DataSaveError: if we fail to save the user metadata.
        """
        self._put(self._key_workflow_user_metadata(), metadata, True)

    def load_task_metadata(self, task_id: TaskID) -> Dict[str, Any]:
        """Load the metadata of the given task.

        Returns:
            The metadata of the given task.

        Raises:
            ValueError: if the workflow or the task does not exist.
        """
        if not self._scan(self._key_task_prefix(task_id), ignore_errors=True):
            if not self._scan("", ignore_errors=True):
                raise ValueError("No such workflow_id '{}'".format(self._workflow_id))
            else:
                raise ValueError(
                    "No such task_id '{}' in workflow '{}'".format(
                        task_id, self._workflow_id
                    )
                )

        (
            (input_metadata, _),
            (prerun_metadata, _),
            (postrun_metadata, _),
        ) = [
            self._get(self._key_task_input_metadata(task_id), True, True),
            self._get(self._key_task_prerun_metadata(task_id), True, True),
            self._get(self._key_task_postrun_metadata(task_id), True, True),
        ]

        input_metadata = input_metadata or {}
        prerun_metadata = prerun_metadata or {}
        postrun_metadata = postrun_metadata or {}

        metadata = input_metadata
        metadata["stats"] = {**prerun_metadata, **postrun_metadata}
        return metadata

    def load_workflow_metadata(self) -> Dict[str, Any]:
        """Load the metadata of the current workflow.

        Returns:
            The metadata of the current workflow.

        Raises:
            ValueError: if the workflow does not exist.
        """
        if not self._scan("", ignore_errors=True):
            raise ValueError("No such workflow_id '{}'".format(self._workflow_id))

        (
            (status_metadata, _),
            (user_metadata, _),
            (prerun_metadata, _),
            (postrun_metadata, _),
        ) = [
            self._get(self._key_workflow_metadata(), True, True),
            self._get(self._key_workflow_user_metadata(), True, True),
            self._get(self._key_workflow_prerun_metadata(), True, True),
            self._get(self._key_workflow_postrun_metadata(), True, True),
        ]

        status_metadata = status_metadata or {}
        user_metadata = user_metadata or {}
        prerun_metadata = prerun_metadata or {}
        postrun_metadata = postrun_metadata or {}

        metadata = status_metadata
        metadata["user_metadata"] = user_metadata
        metadata["stats"] = {**prerun_metadata, **postrun_metadata}
        return metadata

    def list_workflow(
        self, status_filter: Optional[Set[WorkflowStatus]] = None
    ) -> List[Tuple[str, WorkflowStatus]]:
        """List all workflows matching a given status filter.

        Args:
            status_filter: If given, only returns workflows with that status.
                This can be a single status or set of statuses.
        """
        return self._status_storage.list_workflow(status_filter)

    def delete_workflow(self) -> None:
        """Delete the workflow data and its status indexing.

        Raises:
            WorkflowNotFoundError: if the workflow directory does not exist.
        """
        # TODO (Alex): There's a race condition here if someone tries to
        # start the workflow between these ops.
        self._status_storage.delete_workflow_status(self._workflow_id)
        found = self._storage.delete_dir("")
        # TODO (Alex): Different file systems seem to have different
        # behavior when deleting a prefix that doesn't exist, so we may
        # need to catch a broader class of exceptions.
        if not found:
            raise WorkflowNotFoundError(self._workflow_id)

    def update_workflow_status(self, status: WorkflowStatus):
        """Update the status of the workflow.
        This method is NOT thread-safe. It is handled by the workflow
        management actor.
        """
        self._status_storage.update_workflow_status(self._workflow_id, status)
        if status == WorkflowStatus.RUNNING:
            self._put(
                self._key_workflow_prerun_metadata(), {"start_time": time.time()}, True
            )
        elif status in (WorkflowStatus.SUCCESSFUL, WorkflowStatus.FAILED):
            self._put(
                self._key_workflow_postrun_metadata(), {"end_time": time.time()}, True
            )

    def load_workflow_status(self):
        """Load workflow status. If we find the previous status updating failed,
        fix it with redo-log transaction recovery."""
        return self._status_storage.load_workflow_status(self._workflow_id)

    def _put(self, key: str, data: Any, is_json: bool = False) -> str:
        """Serialize and put an object in the object store.

        Args:
            key: The key of the object.
            data: The data to be stored.
            is_json: If true, json encode the data, otherwise pickle it.

        Raises:
            DataSaveError: if serialization or the storage write fails.
        """
        # TODO(suquark): Currently put to file is not atomic -- you can get a
        # partial file. This could fail workflow recovery.
        try:
            if not is_json:
                serialization.dump_to_storage(key, data, self._workflow_id, storage=self)
            else:
                serialized_data = json.dumps(data).encode()
                self._storage.put(key, serialized_data)
        except Exception as e:
            raise DataSaveError from e
        return key

    def _get(self, key: str, is_json: bool = False, no_exception: bool = False) -> Any:
        """Get and deserialize an object from storage.

        Args:
            key: The key of the object.
            is_json: If true, json decode the data, otherwise unpickle it.
            no_exception: If true, return a (value, error) tuple instead of
                raising.

        Raises:
            KeyNotFoundError: if the key does not exist (unless no_exception).
            DataLoadError: if deserialization fails (unless no_exception).
        """
        err = None
        ret = None
        try:
            unmarshaled = self._storage.get(key)
            if unmarshaled is None:
                raise KeyNotFoundError
            if is_json:
                ret = json.loads(unmarshaled.decode())
            else:
                ret = cloudpickle.loads(unmarshaled)
        except KeyNotFoundError as e:
            err = e
        except Exception as e:
            err = DataLoadError()
            err.__cause__ = e

        if no_exception:
            return (ret, err)
        elif err is None:
            return ret
        else:
            raise err

    def _scan(self, prefix: str, ignore_errors: bool = False) -> List[str]:
        """List the base names of all keys under 'prefix'."""
        try:
            return [p.base_name for p in self._storage.list(prefix)]
        except Exception:
            if ignore_errors:
                return []
            raise

    def _exists(self, key: str) -> bool:
        """Return True if 'key' exists in storage."""
        return self._storage.get_info(key) is not None

    # The following functions are helper functions to get the key
    # for specific fields

    def _key_task_input_metadata(self, task_id):
        return os.path.join(STEPS_DIR, task_id, STEP_INPUTS_METADATA)

    def _key_task_user_metadata(self, task_id):
        return os.path.join(STEPS_DIR, task_id, STEP_USER_METADATA)

    def _key_task_prerun_metadata(self, task_id):
        return os.path.join(STEPS_DIR, task_id, STEP_PRERUN_METADATA)

    def _key_task_postrun_metadata(self, task_id):
        return os.path.join(STEPS_DIR, task_id, STEP_POSTRUN_METADATA)

    def _key_task_output(self, task_id):
        return os.path.join(STEPS_DIR, task_id, STEP_OUTPUT)

    def _key_task_exception(self, task_id):
        return os.path.join(STEPS_DIR, task_id, STEP_EXCEPTION)

    def _key_task_output_metadata(self, task_id):
        return os.path.join(STEPS_DIR, task_id, STEP_OUTPUTS_METADATA)

    def _key_task_function_body(self, task_id):
        return os.path.join(STEPS_DIR, task_id, STEP_FUNC_BODY)

    def _key_task_args(self, task_id):
        return os.path.join(STEPS_DIR, task_id, STEP_ARGS)

    def _key_obj_id(self, object_id):
        return os.path.join(OBJECTS_DIR, object_id)

    def _key_task_prefix(self, task_id):
        return os.path.join(STEPS_DIR, task_id, "")

    def _key_class_body(self):
        return os.path.join(CLASS_BODY)

    def _key_workflow_metadata(self):
        return os.path.join(WORKFLOW_META)

    def _key_workflow_user_metadata(self):
        return os.path.join(WORKFLOW_USER_METADATA)

    def _key_workflow_prerun_metadata(self):
        return os.path.join(WORKFLOW_PRERUN_METADATA)

    def _key_workflow_postrun_metadata(self):
        return os.path.join(WORKFLOW_POSTRUN_METADATA)

    def _key_num_tasks_with_name(self, task_name):
        return os.path.join(DUPLICATE_NAME_COUNTER, task_name)


def get_workflow_storage(workflow_id: Optional[str] = None) -> "WorkflowStorage":
    """Get the storage for the workflow.

    Args:
        workflow_id: The ID of the storage. If None, use the workflow ID of
            the current workflow task context.

    Returns:
        A workflow storage.
    """
    if workflow_id is None:
        workflow_id = workflow_context.get_workflow_task_context().workflow_id
    return WorkflowStorage(workflow_id)
+ + Args: + workflow_id: The ID of the storage. + + Returns: + A workflow storage. + """ + if workflow_id is None: + workflow_id = workflow_context.get_workflow_task_context().workflow_id + return WorkflowStorage(workflow_id) + + +def _load_object_ref(paths: List[str], wf_storage: WorkflowStorage) -> ObjectRef: + @ray.remote(num_cpus=0) + def load_ref(paths: List[str], wf_storage: WorkflowStorage): + return wf_storage._get(paths) + + return load_ref.remote(paths, wf_storage) + + +@ray.remote(num_cpus=0) +def _put_obj_ref(ref: Tuple[ObjectRef]): + """ + Return a ref to an object ref. (This can't be done with + `ray.put(obj_ref)`). + + """ + return ref[0] diff --git a/.venv/lib/python3.11/site-packages/torchgen/__pycache__/model.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb931da262f28e12d0d311580ccdb16be9a9b0c7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/model.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37a50a2f6a1b9d5981382db3935480f1eb48babfd975e7cc5008363f657ed26e +size 120846 diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/__init__.py b/.venv/lib/python3.11/site-packages/torchgen/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53b47df9e82a19b1e01c996d39566c10a4dafa73 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/autograd.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/autograd.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c91245deea44bd68d88d4c4794024c7fae639c1d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/autograd.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/cpp.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/cpp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8006fc34ad1ff1ad47ebc5cdfa5253d3ff833b9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/cpp.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/dispatcher.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/dispatcher.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ca769445bf7c266f63f335ec9b646023de9d4ca Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/dispatcher.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/functionalization.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/functionalization.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccdb40ee5c3d6f243716498742b3108ac015072a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/functionalization.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/lazy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/lazy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b1908dbc35be77c2789aae294a127f2f1548a85 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/lazy.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/meta.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/meta.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd1569ffd7f06197a51e01dbc4ba3a4dfb605155 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/meta.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/native.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/native.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b2e4c60ceaa29382aba6d462cbfc37a02c543ce Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/native.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/python.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/python.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a952e58d06628f69577283ea9cad8cf1a8b615c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/python.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/structured.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/structured.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28c97d46b3574800ec405032f2226ee2b4770009 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/structured.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/translate.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/translate.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..504ea0519d5e16c303eaa4b6d6e392e8db724529 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/translate.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/ufunc.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/ufunc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a8a15cff7d61d532a6e45dcb3eaf651639c52a5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/ufunc.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/unboxing.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/unboxing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..056b9bf08464d5c9a4eb8ee4feec1b83d7fe826d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/unboxing.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/cpp.py b/.venv/lib/python3.11/site-packages/torchgen/api/cpp.py new file mode 100644 index 0000000000000000000000000000000000000000..c657570ee3e2494053e0a618d5707bcb5def0d19 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/api/cpp.py @@ -0,0 +1,472 @@ +from __future__ import annotations + +from typing import Sequence + +from torchgen import local +from torchgen.api.types import ( + ArgName, + ArrayCType, + ArrayRefCType, + BaseCType, + BaseTypeToCppMapping, + Binding, + boolT, + ConstRefCType, + CType, + dimnameListT, + intArrayRefT, + iTensorListRefT, + ListCType, + longT, + MutRefCType, + NamedCType, + OptionalCType, + optionalIntArrayRefT, + optionalSymIntArrayRefT, + scalarT, + SpecialArgName, + symIntArrayRefT, + SymIntT, + tensorListT, + tensorOptionsT, + tensorT, + TupleCType, + VectorCType, + voidT, +) +from torchgen.model import ( + Argument, + Arguments, + BaseTy, + BaseType, + FunctionSchema, + ListType, + NativeFunction, + OptionalType, + Return, + SelfArgument, + 
    TensorOptionsArguments,
    Type,
)
from torchgen.utils import assert_never


# This file describes the translation of JIT schema to the public C++
# API, which is what people use when they call functions like at::add.
#
# Prominent characteristics of the C++ API:
#
# - dtype, layout, device and pin_memory are collected into
#   a single C++ type TensorOptions (the native functions API
#   also has this, but tensor options is really most relevant
#   for the C++ API; it makes calling kwarg factory functions
#   pleasant)
#
# - defaulting lives here (in fact, the dispatcher is completely
#   oblivious of defaults!)
#
# BTW: policy on name collisions: we try not to have types with
# collisions, but functions are fair game to collide


def name(
    func: FunctionSchema,
    *,
    faithful_name_for_out_overloads: bool = False,
    symint_overload: bool = False,
) -> str:
    """Return the public C++ API name for *func*.

    Appends ``_symint`` for SymInt overloads, and ``_out`` (or ``_outf``
    when ``faithful_name_for_out_overloads`` is set) for out-variant
    functions.
    """
    name = str(func.name.name)
    if symint_overload:
        name += "_symint"
    if func.is_out_fn():
        if faithful_name_for_out_overloads:
            name += "_outf"
        else:
            name += "_out"

    return name


# Translation of "value types" in JIT schema to C++ API type. Value
# types look the same no matter if they are argument types or return
# types. Returns None if the type in question is not a value type.
def valuetype_type(
    t: Type,
    *,
    binds: ArgName,
    mutable: bool = True,
    remove_non_owning_ref_types: bool = False,
    symint: bool = False,
) -> NamedCType | None:
    """Translate a JIT "value type" to its C++ NamedCType.

    Returns None for non-value types (Tensor, Scalar, and lists other than
    fixed-size bool lists), which are handled by argumenttype_type instead.
    """
    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
            return None
        elif str(t) == "SymInt":
            # SymInt maps to SymInt only when symint codegen is requested;
            # otherwise it degrades to a plain int64_t.
            if symint:
                return NamedCType(binds, BaseCType(SymIntT))
            else:
                return NamedCType(binds, BaseCType(longT))
        if remove_non_owning_ref_types:
            if t.name == BaseTy.str:
                raise AssertionError(
                    "string ref->value conversion: not implemented yet"
                )
        # All other BaseType currently map directly to BaseCppTypes.
        return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))
    elif isinstance(t, OptionalType):
        elem = valuetype_type(t.elem, binds=binds, mutable=mutable, symint=symint)
        if elem is None:
            return None
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        if str(t.elem) == "bool":
            # Fixed-size bool lists become std::array<bool, N>.
            assert t.size is not None
            return NamedCType(binds, ArrayCType(BaseCType(boolT), t.size))
        else:
            return None
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")


# Translation of types occurring in JIT arguments to a C++ argument type.
# If remove_non_owning_ref_types is set, we'll guarantee that the outputed CType is not a non-owning reference type.
# For example, we'll return std::vector<int64_t> instead of IntArrayRef.
# See Note [translation from C++ reference to value types]
def argumenttype_type(
    t: Type,
    *,
    mutable: bool,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
    symint: bool = False,
) -> NamedCType:
    """Translate a JIT argument type to the C++ NamedCType used in the
    public C++ API signature.
    """
    # If it's a value type, do the value type translation
    r = valuetype_type(
        t,
        binds=binds,
        mutable=mutable,
        symint=symint,
        remove_non_owning_ref_types=remove_non_owning_ref_types,
    )
    if r is not None:
        return r

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable and not local.use_const_ref_for_mutable_tensors():
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
            else:
                return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        elif t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        else:
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        if str(t.elem) == "Tensor":
            if mutable and not local.use_const_ref_for_mutable_tensors():
                return NamedCType(
                    binds, MutRefCType(BaseCType(tensorT))
                )  # TODO: fix this discrepancy
            else:
                return NamedCType(
                    binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
                )
        elif str(t.elem) == "Scalar":
            return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
        elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int":
            return NamedCType(binds, BaseCType(optionalIntArrayRefT))
        elif isinstance(t.elem, ListType) and str(t.elem.elem) == "SymInt":
            if symint:
                return NamedCType(binds, BaseCType(optionalSymIntArrayRefT))
            else:
                return NamedCType(binds, BaseCType(optionalIntArrayRefT))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        if str(t.elem) == "int":
            if remove_non_owning_ref_types:
                return NamedCType(binds, VectorCType(BaseCType(longT)))
            else:
                return NamedCType(binds, BaseCType(intArrayRefT))
        if str(t.elem) == "SymInt":
            if remove_non_owning_ref_types:
                if symint:
                    return NamedCType(binds, VectorCType(BaseCType(SymIntT)))
                else:
                    return NamedCType(binds, VectorCType(BaseCType(longT)))
            else:
                if symint:
                    return NamedCType(binds, BaseCType(symIntArrayRefT))
                else:
                    return NamedCType(binds, BaseCType(intArrayRefT))
        if str(t.elem) == "Tensor":
            if local.use_ilistref_for_tensor_lists():
                return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT)))
            else:
                return NamedCType(binds, BaseCType(tensorListT))
        elif str(t.elem) == "Scalar":
            return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
        elif str(t.elem) == "Dimname":
            return NamedCType(binds, BaseCType(dimnameListT))
        elif str(t.elem) == "Tensor?":
            return NamedCType(
                binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
            )
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
        return NamedCType(binds, ArrayRefCType(elem.type))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")


# Translate a JIT argument into its C++ type
def argument_type(a: Argument, *, binds: ArgName, symint: bool = False) -> 
NamedCType: + return argumenttype_type(a.type, mutable=a.is_write, symint=symint, binds=binds) + + +# Translation of a (non-multi) return type from JIT to C++ +# N.B: returntype_type returns a CType, not a NamedCType. +# This is mostly because of the mismatch between return types and return names. +# e.g. a function with a return type of 'void' has 0 return names, +# and a function with a return type of 'std::tuple' has >1 return name. +def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType: + # placeholder is ignored + # NB: symint is ALWAYS respected for return types. So symint argument + # here is IGNORED + r = valuetype_type(t, binds="__placeholder__", mutable=mutable, symint=True) + if r is not None: + return r.type + + if isinstance(t, BaseType): + if t.name == BaseTy.Tensor: + if mutable: + if local.use_const_ref_for_mutable_tensors(): + return ConstRefCType(BaseCType(tensorT)) + else: + return MutRefCType(BaseCType(tensorT)) + else: + # Note [Tensor Copy Returns] + # Currently, we use "Argument.is_write" to determine + # whether or not Tensor return types should be copies or references. + # If that ever changes, take a look at other locations of this note! + return BaseCType(tensorT) + elif t.name == BaseTy.Scalar: + return BaseCType(scalarT) + elif isinstance(t, ListType): + assert ( + not mutable + ), "Native functions should never return a mutable tensor list. They should return void." 
+ elem = returntype_type(t.elem, mutable=False) + assert t.size is None, f"fixed size list returns not supported: {t}" + return VectorCType(elem) + elif isinstance(t, OptionalType): + elem = returntype_type(t.elem, mutable=mutable) + if str(t.elem) == "Tensor": + return OptionalCType(elem) + + raise AssertionError(f"unrecognized return type {t}") + + +# Translation of a single return to its C++ type +def return_type(r: Return, *, symint: bool = False) -> CType: + return returntype_type(r.type, mutable=r.is_write, symint=symint) + + +# Translation of a full (possibly multi) return from JIT to its C++ type +def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType: + if len(rs) == 0: + return BaseCType(voidT) + elif len(rs) == 1: + return return_type(rs[0], symint=symint) + else: + return TupleCType([return_type(r, symint=symint) for r in rs]) + + +def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]: + returns: list[str] = [] + for i, r in enumerate(f.func.returns): + # If we have an inplace function, the return argument is + # implicitly named self. + # TODO: Consider incorporating this into the data model + if f.func.name.name.inplace: + assert i == 0, "illegal inplace function with multiple returns" + name = "self" + # If we are out function, the name is the name of the + # corresponding output function (r.name will get recorded + # in field_name later.) + elif f.func.is_out_fn(): + name = f.func.arguments.out[i].name + # If the return argument is explicitly named... 
+ elif r.name: + name_conflict = any( + r.name == a.name for a in f.func.schema_order_arguments() + ) + if name_conflict and not f.func.is_out_fn(): + name = f"{r.name}_return" + else: + name = r.name + # If there is no explicit name and no fallback name was passed in, we just name the output result, + # unless it's a multi-return, in which case it's result0, + # result1, etc (zero-indexed) + else: + name = fallback_name if len(f.func.returns) == 1 else f"{fallback_name}{i}" + returns.append(name) + return returns + + +JIT_TO_CPP_DEFAULT = { + "False": "false", + "True": "true", + "None": "::std::nullopt", # UGH this one is type directed + "Mean": "at::Reduction::Mean", + "[]": "{}", + "contiguous_format": "c10::MemoryFormat::Contiguous", + "long": "at::kLong", +} + + +# Convert a JIT default into C++ expression representing the default +def default_expr(d: str, t: Type, *, symint: bool) -> str: + if d == "None" and str(t) == "Tensor?": + return "{}" + if isinstance(t, BaseType) and t.name is BaseTy.str: + # Schema allows single quotes but C++ needs double + if len(d) >= 2 and d[0] == "'" and d[-1] == "'": + s = "" + i = 1 + while i + 1 < len(d): + if d[i] != "\\": + if d[i] == '"': + s += '\\"' + else: + s += d[i] + i += 1 + else: + if d[i + 1] == "'": + s += "'" + else: + s += d[i : i + 2] + i += 2 + + return f'"{s}"' + + if isinstance(t, OptionalType): + if d == "None": + return "::std::nullopt" + + return default_expr(d, t.elem, symint=symint) + + if isinstance(t, ListType): + if d.startswith("[") and d.endswith("]"): + return "{" + d[1:-1] + "}" + elif symint and d.isdigit() and str(t.elem) == "SymInt": + return f"c10::SymInt({d})" + elif t.size is None: + # NOTE: Sized lists can have scalar defaults + raise ValueError(f"Expected a list default '[...]' but found: '{d}'") + + return JIT_TO_CPP_DEFAULT.get(d, d) + + +# Convert an argument into its C++ API form + + +def argument( + a: Argument | TensorOptionsArguments | SelfArgument, + *, + cpp_no_default_args: 
set[str], + method: bool, + faithful: bool, + symint: bool = False, + has_tensor_options: bool, +) -> list[Binding]: + def sub_argument( + a: Argument | TensorOptionsArguments | SelfArgument, + ) -> list[Binding]: + return argument( + a, + cpp_no_default_args=cpp_no_default_args, + method=method, + faithful=faithful, + symint=symint, + has_tensor_options=has_tensor_options, + ) + + if isinstance(a, Argument): + binds: ArgName + if a.name == "memory_format" and has_tensor_options: + binds = SpecialArgName.possibly_redundant_memory_format + else: + binds = a.name + default: str | None = None + if a.name not in cpp_no_default_args and a.default is not None: + default = default_expr(a.default, a.type, symint=symint) + return [ + Binding( + nctype=argument_type(a, binds=binds, symint=symint), + name=a.name, + default=default, + argument=a, + ) + ] + elif isinstance(a, TensorOptionsArguments): + if faithful: + return ( + sub_argument(a.dtype) + + sub_argument(a.layout) + + sub_argument(a.device) + + sub_argument(a.pin_memory) + ) + else: + default = None + # Enforced by NativeFunction.__post_init__ + assert "options" not in cpp_no_default_args + if all(x.default == "None" for x in a.all()): + default = "{}" + elif a.dtype.default == "long": + default = "at::kLong" # TODO: this is wrong + return [ + Binding( + nctype=NamedCType("options", BaseCType(tensorOptionsT)), + name="options", + default=default, + argument=a, + ) + ] + elif isinstance(a, SelfArgument): + if method: + # Caller is responsible for installing implicit this in context! 
+ return [] + else: + return sub_argument(a.argument) + else: + assert_never(a) + + +def arguments( + arguments: Arguments, + *, + faithful: bool, + symint: bool = False, + method: bool, + cpp_no_default_args: set[str], +) -> list[Binding]: + args: list[Argument | TensorOptionsArguments | SelfArgument] = [] + if faithful: + args.extend(arguments.non_out) + args.extend(arguments.out) + else: + args.extend(arguments.out) + args.extend(arguments.non_out) + return [ + r.no_default() if faithful else r + for a in args + for r in argument( + a, + faithful=faithful, + symint=symint, + method=method, + has_tensor_options=arguments.tensor_options is not None, + cpp_no_default_args=cpp_no_default_args, + ) + ] diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/dispatcher.py b/.venv/lib/python3.11/site-packages/torchgen/api/dispatcher.py new file mode 100644 index 0000000000000000000000000000000000000000..103e6cf429907d1577c3d9caca6f3e28de9e129a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/api/dispatcher.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +import itertools +from typing import Sequence + +from torchgen.api import cpp +from torchgen.api.types import ArgName, Binding, CType, NamedCType +from torchgen.model import ( + Argument, + FunctionSchema, + Return, + SelfArgument, + TensorOptionsArguments, + Type, +) +from torchgen.utils import assert_never, concatMap + + +# This file describes the translation of JIT schema to the dispatcher +# API, the *unboxed* calling convention by which invocations through +# the dispatcher are made. Historically, the dispatcher API matched +# the C++ API, but with the establishment of the boxed API, we've +# made changes to the dispatcher API to so that the unboxed API +# better aligns with the boxed API. The dispatcher API hooks heavily +# into our template based boxing/unboxing machinery, so changes +# to this convention will usually need template updates too. 
+# +# Prominent characteristics of the dispatcher API: +# +# - dtype, layout, device and pin_memory are represented as separate +# arguments. +# + + +def name(func: FunctionSchema) -> str: + return cpp.name(func) + + +def argumenttype_type( + t: Type, + *, + mutable: bool, + binds: ArgName, + remove_non_owning_ref_types: bool = False, + symint: bool = True, +) -> NamedCType: + # This is a faux amis. If it makes sense in the future to add + # more special cases here, or invert things so cpp.argument_type + # calls this, or just completely inline the function, please do + # it. + return cpp.argumenttype_type( + t, + mutable=mutable, + binds=binds, + symint=symint, + remove_non_owning_ref_types=remove_non_owning_ref_types, + ) + + +def argument_type( + a: Argument, + *, + binds: ArgName, + remove_non_owning_ref_types: bool = False, + symint: bool = True, +) -> NamedCType: + return argumenttype_type( + a.type, + mutable=a.is_write, + binds=binds, + remove_non_owning_ref_types=remove_non_owning_ref_types, + symint=symint, + ) + + +def returns_type(rs: Sequence[Return], *, symint: bool = True) -> CType: + # At present, there is no difference. But there could be! 
+ return cpp.returns_type(rs, symint=symint) + + +def jit_arguments(func: FunctionSchema) -> list[Argument]: + def to_argument( + a: Argument | TensorOptionsArguments | SelfArgument, + ) -> list[Argument]: + if isinstance(a, Argument): + return [a] + elif isinstance(a, SelfArgument): + return [a.argument] + elif isinstance(a, TensorOptionsArguments): + return [a.dtype, a.layout, a.device, a.pin_memory] + else: + assert_never(a) + + return list( + concatMap( + to_argument, + itertools.chain( + func.arguments.positional, func.arguments.kwarg_only, func.arguments.out + ), + ) + ) + + +def argument( + a: Argument, *, remove_non_owning_ref_types: bool = False, symint: bool = True +) -> Binding: + return Binding( + nctype=argument_type( + a, + binds=a.name, + remove_non_owning_ref_types=remove_non_owning_ref_types, + symint=symint, + ), + name=a.name, + argument=a, + ) + + +def arguments(func: FunctionSchema, *, symint: bool = True) -> list[Binding]: + return [argument(a, symint=symint) for a in jit_arguments(func)] diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/native.py b/.venv/lib/python3.11/site-packages/torchgen/api/native.py new file mode 100644 index 0000000000000000000000000000000000000000..a00e8266b8daa7a2614e516a010cc23c497d6151 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/api/native.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +from typing import Sequence + +from torchgen import local +from torchgen.api import cpp +from torchgen.api.types import ( + ArgName, + BaseCType, + Binding, + boolT, + ConstRefCType, + CType, + deviceT, + layoutT, + ListCType, + MutRefCType, + NamedCType, + OptionalCType, + scalarT, + scalarTypeT, + tensorT, +) +from torchgen.model import ( + Argument, + FunctionSchema, + Return, + SelfArgument, + TensorOptionsArguments, + Type, +) +from torchgen.utils import assert_never + + +# This file describes the translation of JIT schema to the native functions API. 
+# This looks a lot like the C++ API (which makes historical sense, because the +# idea was you wrote native functions to implement functions in the C++ API), +# but over time we have evolved the C++ API without actually changing our +# native:: kernels. The intention is to make native API and dispatcher API +# line up as closely as possible, since this results in the least overhead +# (no translation is needed from dispatcher API to native API). +# +# NB: this is symint aware, you will get the non-SymInt variant for some +# dispatch entries and SymInt for others. + + +def name(func: FunctionSchema) -> str: + name = str(func.name.name) + # TODO: delete this! + if func.is_out_fn(): + name += "_out" + if func.name.overload_name: + name += f"_{func.name.overload_name}" + return name + + +def argumenttype_type( + t: Type, *, mutable: bool, binds: ArgName, symint: bool +) -> NamedCType: + if str(t) == "Tensor?": + tensor_type: OptionalCType = OptionalCType(BaseCType(tensorT)) + if mutable and not local.use_const_ref_for_mutable_tensors(): + return NamedCType(binds, MutRefCType(tensor_type)) + else: + return NamedCType(binds, ConstRefCType(tensor_type)) + elif str(t) == "Tensor?[]": + return NamedCType( + binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))) + ) + elif str(t) == "Scalar": + return NamedCType(binds, ConstRefCType(BaseCType(scalarT))) + elif str(t) == "Scalar?": + return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT)))) + return cpp.argumenttype_type(t, mutable=mutable, binds=binds, symint=symint) + + +def returns_type(rs: Sequence[Return], *, symint: bool) -> CType: + return cpp.returns_type(rs, symint=symint) + + +def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType: + return argumenttype_type(a.type, mutable=a.is_write, binds=binds, symint=symint) + + +def argument( + a: Argument | SelfArgument | TensorOptionsArguments, + *, + is_out: bool, + symint: bool, +) -> list[Binding]: + # Ideally, we NEVER 
default native functions. However, there are a number + # of functions that call native:: directly and rely on the defaulting + # existing. So for BC, we generate defaults for non-out variants (but not + # for out variants, where it is impossible to generate an appropriate + # default) + should_default = not is_out + if isinstance(a, Argument): + default: str | None = None + if should_default and a.default is not None: + default = cpp.default_expr(a.default, a.type, symint=symint) + return [ + Binding( + nctype=argument_type(a, binds=a.name, symint=symint), + name=a.name, + default=default, + argument=a, + ) + ] + elif isinstance(a, SelfArgument): + # Erase SelfArgument from the distinction + return argument(a.argument, is_out=is_out, symint=symint) + elif isinstance(a, TensorOptionsArguments): + default = None + if should_default: + default = "{}" + # TODO: Not sure why the arguments assigned here are for + # TensorOptionsArguments and not the constituent pieces. It seems + # to matter + return [ + Binding( + nctype=NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))), + name="dtype", + default=default, + argument=a, + ), + Binding( + nctype=NamedCType("layout", OptionalCType(BaseCType(layoutT))), + name="layout", + default=default, + argument=a, + ), + Binding( + nctype=NamedCType("device", OptionalCType(BaseCType(deviceT))), + name="device", + default=default, + argument=a, + ), + Binding( + nctype=NamedCType("pin_memory", OptionalCType(BaseCType(boolT))), + name="pin_memory", + default=default, + argument=a, + ), + ] + else: + assert_never(a) + + +def arguments(func: FunctionSchema, *, symint: bool) -> list[Binding]: + args: list[Argument | TensorOptionsArguments | SelfArgument] = [] + args.extend(func.arguments.non_out) + args.extend(func.arguments.out) + return [ + r for arg in args for r in argument(arg, symint=symint, is_out=func.is_out_fn()) + ] diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/structured.py 
b/.venv/lib/python3.11/site-packages/torchgen/api/structured.py new file mode 100644 index 0000000000000000000000000000000000000000..93a72eb2b4a5c119ee8f60ce04f0517fe862b4d5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/api/structured.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +from torchgen.api import cpp +from torchgen.api.types import ( + ArgName, + ArrayRefCType, + BaseCType, + Binding, + ConstRefCType, + dimnameListT, + intArrayRefT, + iOptTensorListRefT, + iTensorListRefT, + NamedCType, + OptionalCType, + optionalIntArrayRefT, + optionalScalarRefT, + optionalTensorRefT, + scalarT, + tensorT, +) +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + ListType, + NativeFunctionsGroup, + OptionalType, + SelfArgument, + TensorOptionsArguments, + Type, +) +from torchgen.utils import assert_never + + +# This file describes the translation of JIT schema to the structured functions API. +# This is similar to native API, but a number of historical problems with native +# API have been fixed. + + +# Translation of types occurring in JIT arguments to a C++ argument type. +# NB: For now, mutable doesn't do anything; but it could if we make +# some more nominal types +def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType: + # If it's a value type, do the value type translation + # NB: structured kernels ALWAYS have symint off, since they involve actual + # kernels that require real ints. 
The one exception is the + # CompositeExplicitAutograd and the meta function (which could + # hypothetically be SymInt), but for simplicity we plan for these to just + # be handled in Python + r = cpp.valuetype_type(t, symint=False, binds=binds, mutable=mutable) + if r is not None: + return r + + if isinstance(t, BaseType): + if t.name == BaseTy.Tensor: + return NamedCType(binds, ConstRefCType(BaseCType(tensorT))) + elif t.name == BaseTy.Scalar: + return NamedCType(binds, ConstRefCType(BaseCType(scalarT))) + else: + raise AssertionError(f"base type should have been value type {t}") + elif isinstance(t, OptionalType): + if t.elem == BaseType(BaseTy.Tensor): + return NamedCType(binds, BaseCType(optionalTensorRefT)) + elif t.elem == BaseType(BaseTy.Scalar): + return NamedCType(binds, BaseCType(optionalScalarRefT)) + elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int": + return NamedCType(binds, BaseCType(optionalIntArrayRefT)) + elem = argumenttype_type(t.elem, mutable=mutable, binds=binds) + return NamedCType(binds, OptionalCType(elem.type)) + elif isinstance(t, ListType): + if t.elem == BaseType(BaseTy.Tensor): + return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT))) + elif t.elem == OptionalType(BaseType(BaseTy.Tensor)): + return NamedCType(binds, BaseCType(iOptTensorListRefT)) + # TODO: delete these special cases; see torchgen.api.cpp--these + # must be changed in tandem, but there are problems; see + # https://github.com/pytorch/pytorch/pull/51485 + elif str(t.elem) == "int": + return NamedCType(binds, BaseCType(intArrayRefT)) + elif str(t.elem) == "Dimname": + return NamedCType(binds, BaseCType(dimnameListT)) + elem = argumenttype_type(t.elem, mutable=mutable, binds=binds) + return NamedCType(binds, ArrayRefCType(elem.type)) + else: + raise AssertionError(f"unrecognized type {repr(t)}") + + +def argument_type(a: Argument, *, binds: ArgName) -> NamedCType: + return argumenttype_type(a.type, mutable=a.is_write, binds=binds) + + +# 
returns_type intentionally omitted, because structured kernels never "return"; +# instead, they always indirectly report their outputs (in the case of a meta +# function, by calling set_output; in the case of an impl function, by writing +# directly into the provided out argument). + + +# Structured kernels are never defaulted +def argument(a: Argument | SelfArgument | TensorOptionsArguments) -> list[Binding]: + if isinstance(a, Argument): + return [ + Binding( + nctype=argument_type(a, binds=a.name), + name=a.name, + default=None, + argument=a, + ) + ] + elif isinstance(a, SelfArgument): + return argument(a.argument) + elif isinstance(a, TensorOptionsArguments): + raise AssertionError("structured kernels don't support TensorOptions yet") + else: + assert_never(a) + + +def impl_arguments(g: NativeFunctionsGroup) -> list[Binding]: + args: list[Argument | TensorOptionsArguments | SelfArgument] = [] + + if g.out.precomputed: + # A list of parameters for the impl function with + # certain parameters replaced with precomputed counterparts + # as specified in native_functions.yaml. + non_out_args_replaced: list[ + Argument | TensorOptionsArguments | SelfArgument + ] = [] + for a in g.out.func.arguments.non_out: + if isinstance(a, Argument) and a.name in g.out.precomputed.replace: + # If a is in precompute.replace, append the parameters + # that should replace it onto non_out_args_replaced. + non_out_args_replaced.extend(g.out.precomputed.replace[a.name]) + else: + # If not, push a as it is. 
+ non_out_args_replaced.append(a) + + args.extend(non_out_args_replaced) + # g.out.precomputed.add is the list of parameters that are added + # without replacement after the non out args and just before the out args + args.extend(g.out.precomputed.add) + else: + args.extend(g.out.func.arguments.non_out) + + args.extend(g.out.func.arguments.out) + return [r for arg in args for r in argument(arg)] + + +def meta_arguments(g: NativeFunctionsGroup) -> list[Binding]: + args: list[Argument | TensorOptionsArguments | SelfArgument] = [] + args.extend(g.functional.func.arguments.non_out) + return [r for arg in args for r in argument(arg)] + + +def out_arguments(g: NativeFunctionsGroup) -> list[Binding]: + args: list[Argument | TensorOptionsArguments | SelfArgument] = [] + args.extend(g.out.func.arguments.out) + return [r for arg in args for r in argument(arg)] diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/translate.py b/.venv/lib/python3.11/site-packages/torchgen/api/translate.py new file mode 100644 index 0000000000000000000000000000000000000000..761fb3c7c2b98707bd9b9f79a8a5842fc7ce11a8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/api/translate.py @@ -0,0 +1,433 @@ +from __future__ import annotations + +from typing import NoReturn, Sequence + +from torchgen.api.types import ( + ArrayRefCType, + BaseCType, + Binding, + boolT, + ConstRefCType, + deviceT, + Expr, + intArrayRefT, + iOptTensorListRefT, + layoutT, + ListCType, + longT, + memoryFormatT, + MutRefCType, + NamedCType, + opmath_t, + OptionalCType, + optionalIntArrayRefT, + optionalScalarRefT, + optionalSymIntArrayRefT, + optionalTensorRefT, + scalar_t, + scalarT, + scalarTypeT, + SpecialArgName, + symIntArrayRefT, + SymIntT, + tensorOptionsT, + tensorT, + VectorCType, +) + + +# This file implements a small program synthesis engine that implements +# conversions between one API to another. +# +# The key data type in this file in NamedCType, short for Named C++ semantic type. 
A NamedCType +# represents a C++ type, plus semantic information about what it represents. +# For example, consider the argument "bool pin_memory"; its normal C++ type is +# "bool", but its C++ semantic type also keeps track that this represents a +# "pin_memory"; you can't just use a random other boolean in a context where you +# need a "pin_memory"! +# +# The translator takes a list of needed NamedCTypes, and then figures out how +# to construct expressions with these NamedCTypes from the given bindings. Many +# of these expressions are trivial (I need a Tensor other; there's a Tensor +# other scope); others are more nontrivial and may require packing/unpacking. +# Some examples of non-trivial action: +# +# - Need the "dtype" binding? Well, maybe "dtype" isn't available +# in the context, instead, "options" is, and you need to extract +# it from there. (Gather) +# +# - Need the "context" binding? Well, maybe "context" isn't available +# in the context, and you need to construct it from "dtype", "device", +# etc. (Scatter) +# +# - Need the "memory_format" binding? Well, actually, it's available +# from both "memory_format" and "options", so you had better make sure +# they are consistent. (Join) + +options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT))) + +out_tensor_ctype = NamedCType("out", ConstRefCType(BaseCType(tensorT))) + +longVec_ctype = VectorCType(BaseCType(longT)) +longSymVec_ctype = VectorCType(BaseCType(SymIntT)) +optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT))) +optionalScalar_ctype = OptionalCType(BaseCType(scalarT)) +optionalTensor_ctype = OptionalCType(BaseCType(tensorT)) + + +class UnsatError(RuntimeError): + pass + + +# Given a set of in-scope bindings and a set of target bindings, synthesize +# a list of expressions that uses only the in-scope bindings (bindings) that +# have all of the types of goals. 
You may want to use this function if +# you're generating code for a function like: +# +# void f({args}) { +# g({exprs}); // g is a different API +# } +# +# and you need to generate "exprs". +# +# Typically, a list of Bindings is convenient to get (you usually call something +# like arguments() to get them); but technically you only need less information: +# for 'bindings' an (un-ordered) list of Exprs is sufficient; similarly, for +# 'goals', an (ordered) list of NamedCType goals is sufficient. If you are doing +# something more complicated, e.g., tracking the set of bindings in a context, +# you may find using these smaller types more convenient. +def translate( + bindings: Sequence[Expr | Binding], + goals: Sequence[NamedCType | Binding], + *, + method: bool = False, + allow_expensive_conversions: bool = False, +) -> list[Expr]: + binding_exprs: list[Expr] = [] + for b in bindings: + if isinstance(b, Binding): + binding_exprs.append( + Expr( + expr=b.name, + type=b.nctype, + ) + ) + else: + binding_exprs.append(b) + + goal_ctypes: list[NamedCType] = [] + for g in goals: + if isinstance(g, Binding): + goal_ctypes.append(g.nctype) + else: + goal_ctypes.append(g) + + # Add all the bindings to the context + ctx: dict[NamedCType, str] = {} + for b in binding_exprs: + ctx[b.type] = b.expr + + # While we're at it, do some simple forward inference, looking through + # constructors. + # + # NB: When should you do forward inference versus backward inference? + # The general idea: + # + # - Backward inference WHEN the goal gets smaller + # - Forward inference WHEN the hypothesis gets smaller + # + # This helps ensure termination: backward inference starts with a goal + # and tries to make it simpler and simpler until it's trivial; if the + # goal can grow in size, we blow up to a really huge goal size. 
+ # Similarly, with forward inference we take hypotheses and decompose + # them into simpler hypotheses; if hypotheses could expand in size, + # we also have potential nontermination. (In the code below, forward + # inference is only ever carried out at a single step, but you could + # imagine repeated application of forward inference being profitable.) + # + # A good starting point in the literature for exploring more about proof + # search are these lecture notes + # https://www.cs.cmu.edu/~fp/courses/oregon-m10/04-focusing.pdf + # + # TODO: My kingdom for a pattern matcher + # https://www.python.org/dev/peps/pep-0634/ + # + # TODO: This could get us in recomputation trouble if b.expr is nontrivial. + # Fix this by implementing some sort of sharing so that if multiple + # goals share the same expression, we only compute it once. This seems + # to matter in practice as compiler is often unwilling to CSE nontrivial + # expressions like scalar.to() + t = b.type + if ( + isinstance(t, ConstRefCType) + and isinstance(t.elem, OptionalCType) + and isinstance(t.elem.elem, BaseCType) + and str(t.elem.elem.type) == "at::Tensor" + ): + ctx[ + NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT))) + ] = f"({b.expr}.has_value() ? *{b.expr} : at::Tensor())" + + if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))): + ctx[ + NamedCType(t.name, BaseCType(optionalTensorRefT)) + ] = f"(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())" + + if t.type == ConstRefCType(BaseCType(scalarT)): + ctx[NamedCType(t.name, BaseCType(opmath_t))] = f"({b.expr}).to()" + + if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))): + ctx[ + NamedCType(t.name, BaseCType(optionalScalarRefT)) + ] = f"({b.expr}.has_value() ? 
at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())" + + if t.type == BaseCType(scalar_t): + ctx[ + NamedCType(t.name, BaseCType(opmath_t)) + ] = f"static_cast({b.expr})" + + # [Note: IOptTensorListRef] + if t.type == ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))): + ctx[ + NamedCType(t.name, BaseCType(iOptTensorListRefT)) + ] = f"at::IOptTensorListRef({b.expr})" + + # Add implicit bindings if the generated code is inside a Tensor method + if method: + ctx[ + NamedCType("self", MutRefCType(BaseCType(tensorT))) + ] = "const_cast(*this)" + ctx[ + NamedCType("self", ConstRefCType(BaseCType(tensorT))) + ] = "const_cast(*this)" + # This is better! Byte-for-byte compat + # ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "*this" + + def unsat(goal: NamedCType) -> NoReturn: + ctx_desc = "\n".join( + f" {t.cpp_type()} {t.name}; // {e}" for t, e in ctx.items() + ) + raise UnsatError( + f""" +Failed to synthesize the expression "{goal.cpp_type()} {goal.name}". +When I failed, the following bindings were available in the context: + +{ctx_desc} + +This probably means there is a missing rule in the rules of torchgen.api.translate. +Check this module for more information. +""" + ) + + # A shitty backtracking search implementation. It's shitty because it + # does backtracking via stack (bad idea!) and for the most part tries to + # avoid backtracking. In particular, if + # direct=True, we won't try to do any fancy synthesis, just trivial + # conversions (e.g., "T a" is OK for "const T& a"). So all of the + # existing rules in this function simply try to solve immediately, + # and bail if things don't work out. 
+ def solve(goal: NamedCType, *, direct: bool) -> str: + def direct_solve(goal: NamedCType) -> str: + return solve(goal, direct=True) + + if goal in ctx: + # Trivial + return ctx[goal] + + # const & is satisfied with mutable & + if isinstance(goal.type, ConstRefCType): + try: + # WARNING: not strictly decreasing; be careful not + # to add a direct conversion that goes satisfies + # mutable& with const& + return solve( + NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct + ) + except UnsatError: + pass + + # mutable & is satisfied with value + if isinstance(goal.type, MutRefCType): + try: + return solve(NamedCType(goal.name, goal.type.elem), direct=direct) + except UnsatError: + pass + + # TODO: These are referentially equal, shouldn't have to do this; + # ensuring we don't use type synonym IntArrayRef in codegen would + # help + if goal.type == ArrayRefCType(BaseCType(longT)): + return solve(NamedCType(goal.name, BaseCType(intArrayRefT)), direct=direct) + + if direct: + unsat(goal) + + # For now, all of these rules are mutually exclusive. + if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))): + memory_format = direct_solve( + NamedCType( + SpecialArgName.possibly_redundant_memory_format, + OptionalCType(BaseCType(memoryFormatT)), + ) + ) + # No need to join "memory_format" and "options" if the target API takes "options" directly. + # Otherwise it will cause the redundant memory_format error. 
+ if options_ctype in goal_ctypes: + return memory_format + try: + options = direct_solve(options_ctype) + return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})" + except UnsatError: + return memory_format + elif goal == NamedCType("options", BaseCType(tensorOptionsT)): + dtype = direct_solve( + NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))) + ) + pin_memory = direct_solve( + NamedCType("pin_memory", OptionalCType(BaseCType(boolT))) + ) + device = direct_solve( + NamedCType("device", OptionalCType(BaseCType(deviceT))) + ) + layout = direct_solve( + NamedCType("layout", OptionalCType(BaseCType(layoutT))) + ) + return f"TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})" + + elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))): + try: + options = direct_solve(options_ctype) + return f"c10::optTypeMetaToScalarType({options}.dtype_opt())" + except UnsatError: + out_tensor = direct_solve(out_tensor_ctype) + return f"{out_tensor}.scalar_type()" + + elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))): + try: + options = direct_solve(options_ctype) + return f"{options}.layout_opt()" + except UnsatError: + out_tensor = direct_solve(out_tensor_ctype) + return f"{out_tensor}.layout()" + + elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))): + try: + options = direct_solve(options_ctype) + return f"{options}.device_opt()" + except UnsatError: + out_tensor = direct_solve(out_tensor_ctype) + return f"{out_tensor}.device()" + + elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))): + try: + options = direct_solve(options_ctype) + return f"{options}.pinned_memory_opt()" + except UnsatError: + # If we're calling a factory op from its out= variant, + # We don't actually care about the value of pin_memory. 
+ out_tensor = direct_solve(out_tensor_ctype) + return "::std::nullopt" + + # We can always do translations from value types to reference types, like vector -> IntArrayRef + elif goal.type == BaseCType(intArrayRefT): + try: + return direct_solve(NamedCType(goal.name, longVec_ctype)) + except UnsatError: + # We can also go SymIntArrayRef -> IntArrayRef + symIntArrayRef_type = direct_solve( + NamedCType(goal.name, BaseCType(symIntArrayRefT)) + ) + return f"C10_AS_INTARRAYREF_SLOW({symIntArrayRef_type})" + elif goal.type == BaseCType(symIntArrayRefT): + try: + r = direct_solve(NamedCType(goal.name, BaseCType(intArrayRefT))) + return f"c10::fromIntArrayRefSlow({r})" + except UnsatError: + return direct_solve(NamedCType(goal.name, longSymVec_ctype)) + elif goal.type == BaseCType(SymIntT): + return direct_solve(NamedCType(goal.name, BaseCType(longT))) + elif goal.type == OptionalCType(BaseCType(SymIntT)): + argname = direct_solve( + NamedCType(goal.name, OptionalCType(BaseCType(longT))) + ) + return f"{argname}.has_value() ? ::std::make_optional(c10::SymInt(*{argname})) : ::std::nullopt" + elif goal.type == BaseCType(longT): + symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT))) + return f"{symInt_type}.guard_int(__FILE__, __LINE__)" + elif goal.type == OptionalCType(BaseCType(longT)): + argname = direct_solve( + NamedCType(goal.name, OptionalCType(BaseCType(SymIntT))) + ) + return f"{argname}.has_value() ? ::std::make_optional({argname}->guard_int(__FILE__, __LINE__)) : ::std::nullopt" + elif goal.type == BaseCType(optionalIntArrayRefT): + try: + return direct_solve(NamedCType(goal.name, optionalLongVec_ctype)) + except UnsatError: + argname = direct_solve( + NamedCType(goal.name, BaseCType(optionalSymIntArrayRefT)) + ) + return f"{argname}.has_value() ? 
::std::make_optional(C10_AS_INTARRAYREF_SLOW(*{argname})) : ::std::nullopt" + elif goal.type == BaseCType(optionalSymIntArrayRefT): + # TODO: You might also want to solve this from longSymVec_ctype or + # an optional version of it + argname = direct_solve( + NamedCType(goal.name, BaseCType(optionalIntArrayRefT)) + ) + return f"{argname}.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*{argname})) : ::std::nullopt" + elif goal.type == BaseCType(optionalScalarRefT): + return direct_solve(NamedCType(goal.name, optionalScalar_ctype)) + elif goal.type == BaseCType(optionalTensorRefT): + return direct_solve(NamedCType(goal.name, optionalTensor_ctype)) + + # Note [translation from C++ reference to value types] + # The below cases are all for when we have an argument with a reference type, + # and a corresponding goal with a value type. + # These are needed when we populate the inputs to a lambda capture and we need + # to guarantee the lifetime of each captured argument. + # We guard it with an explicit kwarg because converting to a value type is expensive + # (O(n)) to convert from IntArrayRef to vector), + # so the caller of translate() should be explicit that they need it. + if allow_expensive_conversions: + if goal.type == VectorCType(BaseCType(longT)): + intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT)) + argname = direct_solve(intArrayRef_ctype) + return f"{argname}.vec()" + if goal.type == VectorCType(BaseCType(SymIntT)): + symIntArrayRef_ctype = NamedCType(goal.name, BaseCType(symIntArrayRefT)) + argname = direct_solve(symIntArrayRef_ctype) + return f"{argname}.vec()" + elif goal.type == OptionalCType(VectorCType(BaseCType(longT))): + optionalIntArrayRef_ctype = NamedCType( + goal.name, BaseCType(optionalIntArrayRefT) + ) + argname = direct_solve(optionalIntArrayRef_ctype) + return f"{argname}.has_value() ? 
::std::make_optional({argname}->vec()) : ::std::nullopt" + elif goal.type == OptionalCType(BaseCType(scalarT)): + optionalScalarRef_ctype = NamedCType( + goal.name, BaseCType(optionalScalarRefT) + ) + argname = direct_solve(optionalScalarRef_ctype) + return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt" + elif goal.type == OptionalCType(BaseCType(scalarT)): + optionalTensorRef_ctype = NamedCType( + goal.name, BaseCType(optionalTensorRefT) + ) + argname = direct_solve(optionalTensorRef_ctype) + return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt" + # Technically, we also need to handle cases of C++ containers holding reference types. + # But there currently aren't any ops that require lambda capture codegen + # With arguments like ::std::vector. + # If that changes, we'll have to add the translation here. + + # We allow const casting on tensors, since const-correctness is a bit broken for at::Tensor. + # We could probably generalize this to non-tensor types too. 
+ if goal.type == MutRefCType(BaseCType(tensorT)): + const_ref_tensor_ctype = NamedCType( + goal.name, ConstRefCType(BaseCType(tensorT)) + ) + argname = direct_solve(const_ref_tensor_ctype) + return f"const_cast({argname})" + + unsat(goal) + + return [Expr(solve(g, direct=False), g) for g in goal_ctypes] diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/types/__init__.py b/.venv/lib/python3.11/site-packages/torchgen/api/types/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4e98bb8df493f2375b514e6c6aeb897cebe8ec7d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/api/types/__init__.py @@ -0,0 +1,5 @@ +from torchgen.api.types.types import * +from torchgen.api.types.types_base import * + + +from torchgen.api.types.signatures import * # usort: skip diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/types/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/api/types/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8004ae2e9b999a1484ba70e180885d5c8738fda5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/api/types/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/types/__pycache__/signatures.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/api/types/__pycache__/signatures.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8409ba436b5b6dca8972097d82d093f2756ab832 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/api/types/__pycache__/signatures.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/types/__pycache__/types.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/api/types/__pycache__/types.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48f326e87cab5ff85fb511940c95e1e151fd78d0 Binary 
files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/api/types/__pycache__/types.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/types/__pycache__/types_base.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/api/types/__pycache__/types_base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ba0d1d7fe3d135e4301025ea5cf2b9917557583 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/api/types/__pycache__/types_base.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/types/signatures.py b/.venv/lib/python3.11/site-packages/torchgen/api/types/signatures.py new file mode 100644 index 0000000000000000000000000000000000000000..f7d85ca6e2fe88e3a7047b2f3b1c887f5e583846 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/api/types/signatures.py @@ -0,0 +1,426 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Iterator, Sequence, TYPE_CHECKING + +from torchgen.api.types.types_base import Binding, CType, Expr + + +if TYPE_CHECKING: + from torchgen.model import ( + BackendIndex, + FunctionSchema, + NativeFunction, + NativeFunctionsGroup, + NativeFunctionsViewGroup, + ) + + +@dataclass(frozen=True) +class CppSignature: + """ + A CppSignature represents a single overload in the C++ API. For + any given function schema, there may be multiple CppSignatures + corresponding to it, based on how we desugar to C++. See also + CppSignatureGroup. + """ + + # The schema this signature is derived from + func: FunctionSchema + + # Is this a C++ signature for a method, i.e. Tensor::my_op(...)? + method: bool + + # Is this a faithful C++ signature (i.e. following the JIT schema) or a convenience API + # (i.e. with a potential TensorOptions argument and out arguments in the front) + faithful: bool + + # Is this a symint C++ signature. 
For BC reasons, functions that take + # SymInts still present as int64_t in C++, and the SymInt variant is + # offered at a different overload name + # + # NB: If a function RETURNS a SymInt, this is ALWAYS false + symint: bool + + # The set of C++ arguments which should not have defaults applied to them + cpp_no_default_args: set[str] + + # Is this a fallback C++ binding? Fallback bindings are enabled by + # manual_cpp_binding: True and are alternate, non-public API that + # lets manual C++ binding implementors access the binding that would + # have been automatically generated + fallback_binding: bool = False + + # Return the unpacked argument structure of this signature, + # discarding information about which arguments are semantically + # related to each other. + def arguments(self) -> Sequence[Binding]: + return cpp.arguments( + self.func.arguments, + faithful=self.faithful, + symint=self.symint, + method=self.method, + cpp_no_default_args=self.cpp_no_default_args, + ) + + def name(self, *, suppress_symint_suffix: bool = False) -> str: + n = cpp.name( + self.func, + faithful_name_for_out_overloads=self.faithful, + symint_overload=False if suppress_symint_suffix else self.symint, + ) + if self.fallback_binding: + n = f"__dispatch_{n}" + return n + + # Render the C++ declaration for this signature + def decl( + self, + *, + name: str | None = None, + prefix: str = "", + is_redispatching_fn: bool = False, + suppress_symint_suffix: bool = False, + ) -> str: + returns_type = cpp.returns_type( + self.func.returns, symint=self.symint + ).cpp_type() + cpp_args = [a.decl() for a in self.arguments()] + if is_redispatching_fn: + cpp_args = ["c10::DispatchKeySet dispatchKeySet"] + cpp_args + cpp_args_str = ", ".join(cpp_args) + if name is None: + name = prefix + self.name(suppress_symint_suffix=suppress_symint_suffix) + return f"{returns_type} {name}({cpp_args_str})" + + # Render the C++ definition for this signature, not including + # the body (with curly braces) + def 
defn( + self, + *, + name: str | None = None, + prefix: str = "", + is_redispatching_fn: bool = False, + ) -> str: + returns_type = cpp.returns_type( + self.func.returns, symint=self.symint + ).cpp_type() + cpp_args = [a.defn() for a in self.arguments()] + if is_redispatching_fn: + cpp_args = ["c10::DispatchKeySet dispatchKeySet"] + cpp_args + cpp_args_str = ", ".join(cpp_args) + if name is None: + name = prefix + self.name() + return f"{returns_type} {name}({cpp_args_str})" + + def ptr_type(self) -> str: + args_types_str = ", ".join(a.type for a in self.arguments()) + return f"{cpp.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_types_str})" + + # Return the C++ function type, e.g., something like int(bool) + def type(self) -> str: + args_types_str = ", ".join(a.type for a in self.arguments()) + return f"{cpp.returns_type(self.func.returns, symint=self.symint).cpp_type()} ({args_types_str})" + + +# Represents group of all CppSignatures associated with a +# FunctionSchema. Right now, that's the regular, user-visible +# signature, as well as a "faithful" signature which doesn't +# have grouping. 
+@dataclass(frozen=True) +class CppSignatureGroup: + func: FunctionSchema + signature: CppSignature + faithful_signature: CppSignature | None + symint_signature: CppSignature | None + symint_faithful_signature: CppSignature | None + + def most_faithful_signature(self) -> CppSignature: + if self.faithful_signature: + return self.faithful_signature + else: + return self.signature + + def signatures(self, *, symint: bool = True) -> Iterator[CppSignature]: + yield self.signature + if self.faithful_signature: + yield self.faithful_signature + if symint: + if self.symint_signature: + yield self.symint_signature + if self.symint_faithful_signature: + yield self.symint_faithful_signature + + @staticmethod + def from_native_function( + f: NativeFunction, *, method: bool, fallback_binding: bool = False + ) -> CppSignatureGroup: + func = f.func + + def make_sig(*, faithful: bool, symint: bool) -> CppSignature: + return CppSignature( + func=func, + faithful=faithful, + symint=symint, + method=method, + fallback_binding=fallback_binding, + cpp_no_default_args=f.cpp_no_default_args, + ) + + def make_sigs(*, symint: bool) -> tuple[CppSignature, CppSignature | None]: + faithful_signature: CppSignature | None = None + if func.arguments.tensor_options is not None or len(func.arguments.out) > 0: + faithful_signature = make_sig(faithful=True, symint=symint) + signature = make_sig(faithful=False, symint=symint) + return signature, faithful_signature + + signature, faithful_signature = make_sigs(symint=False) + symint_signature: CppSignature | None = None + symint_faithful_signature: CppSignature | None = None + if func.has_symint(): + symint_signature, symint_faithful_signature = make_sigs(symint=True) + + return CppSignatureGroup( + func=func, + signature=signature, + faithful_signature=faithful_signature, + symint_signature=symint_signature, + symint_faithful_signature=symint_faithful_signature, + ) + + +@dataclass(frozen=True) +class DispatcherSignature: + # The schema this 
signature is derived from + func: FunctionSchema + + # Allows you to prepend an arbitrary prefix to the signature name. + # This is useful for parts of the codegen that generate wrappers around kernels, + # and need to avoid naming collisions. + prefix: str = "" + + symint: bool = True + + def arguments(self) -> list[Binding]: + return dispatcher.arguments(self.func, symint=self.symint) + + def name(self) -> str: + return self.prefix + dispatcher.name(self.func) + + def decl(self, name: str | None = None) -> str: + args_str = ", ".join(a.decl() for a in self.arguments()) + if name is None: + name = self.name() + return f"{self.returns_type().cpp_type()} {name}({args_str})" + + def defn( + self, name: str | None = None, *, is_redispatching_fn: bool = False + ) -> str: + args = [a.defn() for a in self.arguments()] + if is_redispatching_fn: + args = ["c10::DispatchKeySet dispatchKeySet"] + args + args_str = ", ".join(args) + if name is None: + name = self.name() + return f"{self.returns_type().cpp_type()} {name}({args_str})" + + def exprs(self) -> list[Expr]: + return [Expr(a.name, a.nctype) for a in self.arguments()] + + def returns_type(self) -> CType: + return dispatcher.returns_type(self.func.returns, symint=self.symint) + + def ptr_type(self) -> str: + dispatcher_args_types_str = ", ".join(a.type for a in self.arguments()) + return f"{self.returns_type().cpp_type()} (*)({dispatcher_args_types_str})" + + # Return the C++ function type, e.g., something like int(bool) + def type(self) -> str: + dispatcher_args_types_str = ", ".join(a.type for a in self.arguments()) + return f"{self.returns_type().cpp_type()} ({dispatcher_args_types_str})" + + @staticmethod + def from_schema( + func: FunctionSchema, *, prefix: str = "", symint: bool = True + ) -> DispatcherSignature: + return DispatcherSignature(func, prefix, symint) + + +@dataclass(frozen=True) +class NativeSignature: + # The schema this signature is derived from + func: FunctionSchema + + symint: bool + + prefix: 
str = "" + + def name(self) -> str: + return self.prefix + native.name(self.func) + + def decl(self, name: str | None = None) -> str: + args_str = ", ".join(a.decl() for a in self.arguments()) + if name is None: + name = self.name() + return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})" + + def defn(self, name: str | None = None) -> str: + args_str = ", ".join(a.defn() for a in self.arguments()) + if name is None: + name = self.name() + return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})" + + def ptr_type(self) -> str: + # don't include defaults in type signature! + args_str = ", ".join(a.defn() for a in self.arguments()) + return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_str})" + + def arguments(self) -> list[Binding]: + return native.arguments(self.func, symint=self.symint) + + def returns_type(self) -> CType: + return native.returns_type(self.func.returns, symint=self.symint) + + def dispatcher_exprs(self) -> list[Expr]: + return translate.translate( + self.arguments(), dispatcher.arguments(self.func), method=False + ) + + +@dataclass(frozen=True) +class ViewInverseSignature: + g: NativeFunctionsViewGroup + + def name(self) -> str: + return functionalization.reverse_name(self.g.view, include_namespace=False) + + def decl(self) -> str: + return_type = functionalization.returns_type(self.g.view.func) + decls = [ + a.decl() + for a in functionalization.inner_arguments( + self.g.view.func, is_reverse=True + ) + ] + return f"static {return_type.cpp_type()} {self.name()}({', '.join(decls)});" + + +@dataclass(frozen=True) +class FunctionalizationLambda: + g: NativeFunctionsViewGroup + + # are we generating the forward lambda or the reverse lambda? 
+ is_reverse: bool + + def captures(self) -> list[Expr]: + # The lambda lives inside of a kernel following the dispatcher API, so its outer context is the dispatcher arguments + # We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed, + # and plumb it into the lambda. + outer_ctx = dispatcher.arguments(self.g.view.func) + [ + functionalization.reapply_views_binding, + functionalization.inverse_return_mode_binding, + ] + capture_bindings = functionalization.capture_arguments( + self.g.view.func, is_reverse=self.is_reverse + ) + # allow_expensive_conversions is set because we want to convert + # some reference types (IntArrayRef) to value types (vector). + capture_exprs = translate.translate( + outer_ctx, capture_bindings, method=False, allow_expensive_conversions=True + ) + return capture_exprs + + def decl(self) -> str: + return_type = functionalization.returns_type(self.g.view.func) + capture_str = ", ".join( + f"{val.type.name} = {val.expr}" for val in self.captures() + ) + decls = [ + a.decl() + for a in functionalization.outer_arguments(is_reverse=self.is_reverse) + ] + return f"[{capture_str}]({', '.join(decls)}) -> {return_type.cpp_type()}" + + def inner_call(self, *, reapply_views: bool | None = None) -> str: + inner_call_name = functionalization.name( + self.g, + is_reverse=self.is_reverse, + include_namespace=True, + reapply_views=reapply_views, + ) + + arg_ctx = functionalization.outer_arguments(is_reverse=self.is_reverse) + capture_ctx = functionalization.capture_arguments( + self.g.view.func, is_reverse=self.is_reverse + ) + full_ctx = arg_ctx + capture_ctx + + assert self.g.view_copy is not None + call_bindings = functionalization.inner_arguments( + self.g.view_copy.func, is_reverse=self.is_reverse + ) + maybe_index = functionalization.inner_call_index(self.g.view_copy.func) + call_exprs = [ + e.expr for e in translate.translate(full_ctx, call_bindings, method=False) + ] + if not self.is_reverse and 
maybe_index is not None: + return f'{inner_call_name}({", ".join(call_exprs)})[{maybe_index.name}];' + else: + return f'{inner_call_name}({", ".join(call_exprs)});' + + @staticmethod + def from_func( + g: NativeFunctionsViewGroup, *, is_reverse: bool + ) -> FunctionalizationLambda: + return FunctionalizationLambda(g, is_reverse) + + +@dataclass(frozen=True) +class StructuredImplSignature: + g: NativeFunctionsGroup + name: str + + def defn(self, name: str | None = None) -> str: + args_str = ", ".join(a.defn() for a in self.arguments()) + return f"TORCH_IMPL_FUNC({self.name})({args_str})" + + def arguments(self) -> list[Binding]: + return structured.impl_arguments(self.g) + + +# Helper functions + + +def kernel_signature( + f: NativeFunction, backend_index: BackendIndex, *, prefix: str = "" +) -> NativeSignature | DispatcherSignature: + # Note [External Backends Follow Dispatcher API] + # Kernel signatures for in-tree backends follow the "native" API, + # while kernels for out-of-tree backends follow the dispatcher API. + # See the comments in `native.py` for details, but historically there have been + # some small differences in schema convention between them and the Dispatcher API. + # Any differences that require translating between the two will results in a runtime cost, + # so we'd like to keep the differences as small as possible. + # With external backends, we'd like to enforce that they write their kernels with schemas + # that match the Dispatcher API directly, if they can. 
+ meta = backend_index.get_kernel(f) + symint = meta is not None and meta.supports_symint() + if symint: + assert ( + f.func.has_symint() + ), f"attempted to define symint kernel for {backend_index.dispatch_key} without SymInt in schema" + if backend_index.external: + return DispatcherSignature.from_schema(f.func, prefix=prefix, symint=symint) + else: + return NativeSignature(f.func, prefix=prefix, symint=symint) + + +# Functions only, no types +from torchgen.api import ( + cpp, + dispatcher, + functionalization, + native, + structured, + translate, +) diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/types/types.py b/.venv/lib/python3.11/site-packages/torchgen/api/types/types.py new file mode 100644 index 0000000000000000000000000000000000000000..30e027a631200029e01f337b96c77013193bfd4f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/api/types/types.py @@ -0,0 +1,191 @@ +""" +Where should I add a new type? `types_base.py` vs `types.py` + +This file defines data model classes for torchgen typing system, as well as some base types such as int32_t. + +`types.py` defines ATen Tensor type and some c10 types, along with signatures that use these types. + +The difference between these two files, is `types_base.py` should be implementation-agnostic, meaning it shouldn't +contain any type definition that is tight to a specific C++ library (e.g., ATen), so that it can be easily reused +if we want to generate code for another C++ library. + +Add new types to `types.py` if these types are ATen/c10 related. +Add new types to `types_base.py` if they are basic and not attached to ATen/c10. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass + +from torchgen.api.types.types_base import ( + BaseCppType, + BaseCType, + boolT, + byteT, + charT, + CType, + doubleT, + floatT, + int32T, + longT, + shortT, +) +from torchgen.model import BaseTy, ScalarType + + +TENSOR_LIST_LIKE_CTYPES = [ + "at::TensorList", + "const c10::List<::std::optional> &", + "const at::ITensorListRef &", +] + + +halfT = BaseCppType("at", "Half") +complexHalfT = BaseCppType( + "c10", "complex" +) # stuffing template param here is an abuse +complexFloatT = BaseCppType("c10", "complex") +complexDoubleT = BaseCppType("c10", "complex") +bfloat16T = BaseCppType("at", "BFloat16") +float8_e5m2T = BaseCppType("at", "Float8_e5m2") +float8_e5m2fnuzT = BaseCppType("at", "Float8_e5m2fnuz") +float8_e4m3fnT = BaseCppType("at", "Float8_e4m3fn") +float8_e4m3fnuzT = BaseCppType("at", "Float8_e4m3fnuz") +stringT = BaseCppType("c10", "string_view") +generatorT = BaseCppType("at", "Generator") +scalarTypeT = BaseCppType("at", "ScalarType") +tensorT = BaseCppType("at", "Tensor") +optionalTensorRefT = BaseCppType("at", "OptionalTensorRef") +tensorListT = BaseCppType("at", "TensorList") +iTensorListRefT = BaseCppType("at", "ITensorListRef") +iOptTensorListRefT = BaseCppType("at", "IOptTensorListRef") +dimnameT = BaseCppType("at", "Dimname") +dimnameListT = BaseCppType("at", "DimnameList") +dimVectorT = BaseCppType("at", "DimVector") +layoutT = BaseCppType("at", "Layout") +deviceT = BaseCppType("at", "Device") +deviceIndexT = BaseCppType("at", "DeviceIndex") +scalarT = BaseCppType("at", "Scalar") +optionalScalarRefT = BaseCppType("at", "OptionalScalarRef") +memoryFormatT = BaseCppType("at", "MemoryFormat") +qschemeT = BaseCppType("at", "QScheme") +storageT = BaseCppType("at", "Storage") +streamT = BaseCppType("at", "Stream") +intArrayRefT = BaseCppType("at", "IntArrayRef") +optionalIntArrayRefT = BaseCppType("at", "OptionalIntArrayRef") +optionalSymIntArrayRefT = BaseCppType("at", 
"OptionalSymIntArrayRef") +tensorOptionsT = BaseCppType("at", "TensorOptions") +typeAndSizeT = BaseCppType("torch::autograd::generated", "TypeAndSize") +tensorGeometryT = BaseCppType("at", "TensorGeometry") +SymIntT = BaseCppType("c10", "SymInt") +symIntArrayRefT = BaseCppType("c10", "SymIntArrayRef") + +# Types representing template parameters. Technically, we probably shouldn't +# represent them this way in codegen, but it was pretty convenient. +scalar_t = BaseCppType("", "scalar_t") +opmath_t = BaseCppType("", "opmath_t") + +ScalarTypeToCppMapping: dict[ScalarType, BaseCppType] = { + ScalarType.Byte: byteT, + ScalarType.Char: charT, + ScalarType.Short: shortT, + ScalarType.Int: int32T, + ScalarType.Long: longT, + ScalarType.Half: halfT, + ScalarType.Float: floatT, + ScalarType.Double: doubleT, + ScalarType.ComplexHalf: complexHalfT, + ScalarType.ComplexFloat: complexFloatT, + ScalarType.ComplexDouble: complexDoubleT, + ScalarType.Bool: boolT, + ScalarType.Float8_e5m2: float8_e5m2T, + ScalarType.Float8_e5m2fnuz: float8_e5m2fnuzT, + ScalarType.Float8_e4m3fn: float8_e4m3fnT, + ScalarType.Float8_e4m3fnuz: float8_e4m3fnuzT, +} + +BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = { + BaseTy.int: longT, + BaseTy.float: doubleT, + BaseTy.bool: boolT, + BaseTy.str: stringT, + BaseTy.Generator: generatorT, + BaseTy.ScalarType: scalarTypeT, + BaseTy.Tensor: tensorT, + BaseTy.Dimname: dimnameT, + BaseTy.DimVector: dimVectorT, + BaseTy.Layout: layoutT, + BaseTy.Device: deviceT, + BaseTy.DeviceIndex: deviceIndexT, + BaseTy.Scalar: scalarT, + BaseTy.MemoryFormat: memoryFormatT, + BaseTy.QScheme: qschemeT, + BaseTy.Storage: storageT, + BaseTy.Stream: streamT, + BaseTy.SymInt: SymIntT, +} + +# CTypes encode C++ type structure as needed for translation. + + +@dataclass(frozen=True) +class OptionalCType(CType): + elem: CType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + # Do not pass `strip_ref` recursively. 
+ return f"::std::optional<{self.elem.cpp_type()}>" + + def cpp_type_registration_declarations(self) -> str: + return f"::std::optional<{self.elem.cpp_type_registration_declarations()}>" + + def remove_const_ref(self) -> CType: + return OptionalCType(self.elem.remove_const_ref()) + + +@dataclass(frozen=True) +class ListCType(CType): + elem: CType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + # Do not pass `strip_ref` recursively. + return f"c10::List<{self.elem.cpp_type()}>" + + def cpp_type_registration_declarations(self) -> str: + return f"c10::List<{self.elem.cpp_type_registration_declarations()}>" + + def remove_const_ref(self) -> CType: + return ListCType(self.elem.remove_const_ref()) + + +@dataclass(frozen=True) +class ArrayRefCType(CType): + elem: CType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + # Do not pass `strip_ref` recursively. + return f"at::ArrayRef<{self.elem.cpp_type()}>" + + def cpp_type_registration_declarations(self) -> str: + return f"ArrayRef<{self.elem.cpp_type_registration_declarations()}>" + + def remove_const_ref(self) -> CType: + return ArrayRefCType(self.elem.remove_const_ref()) + + +@dataclass(frozen=True) +class VectorizedCType(CType): + # This template is explicitly specialized, so the only valid + # elems are those we have specializations for (e.g., float, double, ...) 
+ # scalar_t is also a common argument here (when we are codegen in + # a templated context) + elem: BaseCType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + return f"at::vec::Vectorized<{self.elem.cpp_type()}>" + + def cpp_type_registration_declarations(self) -> str: + raise NotImplementedError + + def remove_const_ref(self) -> CType: + return self diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/types/types_base.py b/.venv/lib/python3.11/site-packages/torchgen/api/types/types_base.py new file mode 100644 index 0000000000000000000000000000000000000000..e031b79485e057769302149369500cdb3df4c1e2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/api/types/types_base.py @@ -0,0 +1,276 @@ +""" +Where should I add a new type? `types_base.py` vs `types.py` + +This file defines data model classes for torchgen typing system, as well as some base types such as int32_t. + +`types.py` defines ATen Tensor type and some c10 types, along with signatures that use these types. + +The difference between these two files, is `types_base.py` should be implementation-agnostic, meaning it shouldn't +contain any type definition that is tight to a specific C++ library (e.g., ATen), so that it can be easily reused +if we want to generate code for another C++ library. + +Add new types to `types.py` if these types are ATen/c10 related. +Add new types to `types_base.py` if they are basic and not attached to ATen/c10. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import auto, Enum +from typing import TYPE_CHECKING, Union + + +if TYPE_CHECKING: + from torchgen.model import Argument, SelfArgument, TensorOptionsArguments + + +# An ArgName is just the str name of the argument in schema; +# but in some special circumstances, we may add a little extra +# context. The Enum SpecialArgName covers all of these cases; +# grep for their construction sites to see when they can occur. 
+ + +class SpecialArgName(Enum): + possibly_redundant_memory_format = auto() + + +ArgName = Union[str, SpecialArgName] + + +# This class shouldn't be created directly; instead, use/create one of the singletons below. +@dataclass(frozen=True) +class BaseCppType: + ns: str | None + name: str + + def __str__(self) -> str: + if self.ns is None or self.ns == "": + return self.name + return f"{self.ns}::{self.name}" + + +# The set of all non-templated, valid, fully-qualified names of C++ types that are used in the codegen. +# Templated types get their own dataclass, mainly to make namespace parsing easier. +byteT = BaseCppType("", "uint8_t") +charT = BaseCppType("", "int8_t") +shortT = BaseCppType("", "int16_t") +# It would be more symmetric for this to be called intT, but it easy to mix +# this up with JIT int (which is int64_t in C++), so we intentionally don't +# define intT to make it obvious when you've stuffed it up +int32T = BaseCppType("", "int32_t") +longT = BaseCppType("", "int64_t") +doubleT = BaseCppType("", "double") +floatT = BaseCppType("", "float") +boolT = BaseCppType("", "bool") +voidT = BaseCppType("", "void") + + +class CType(ABC): + @abstractmethod + def cpp_type(self, *, strip_ref: bool = False) -> str: + raise NotImplementedError + + @abstractmethod + def cpp_type_registration_declarations(self) -> str: + raise NotImplementedError + + @abstractmethod + def remove_const_ref(self) -> CType: + return self + + +@dataclass(frozen=True) +class BaseCType(CType): + type: BaseCppType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + return str(self.type) + + # For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml + # TODO: Kill this when we eventually remove it! 
+ def cpp_type_registration_declarations(self) -> str: + return str(self.type).replace("at::", "") + + def remove_const_ref(self) -> CType: + return self + + +@dataclass(frozen=True) +class ConstRefCType(CType): + elem: CType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + if strip_ref: + return self.elem.cpp_type(strip_ref=strip_ref) + return f"const {self.elem.cpp_type()} &" + + def cpp_type_registration_declarations(self) -> str: + return f"const {self.elem.cpp_type_registration_declarations()} &" + + def remove_const_ref(self) -> CType: + return self.elem.remove_const_ref() + + +@dataclass(frozen=True) +class VectorCType(CType): + elem: CType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + # Do not pass `strip_ref` recursively. + return f"::std::vector<{self.elem.cpp_type()}>" + + def cpp_type_registration_declarations(self) -> str: + return f"::std::vector<{self.elem.cpp_type_registration_declarations()}>" + + def remove_const_ref(self) -> CType: + return VectorCType(self.elem.remove_const_ref()) + + +@dataclass(frozen=True) +class ArrayCType(CType): + elem: CType + size: int + + def cpp_type(self, *, strip_ref: bool = False) -> str: + # Do not pass `strip_ref` recursively. + return f"::std::array<{self.elem.cpp_type()},{self.size}>" + + def cpp_type_registration_declarations(self) -> str: + return f"::std::array<{self.elem.cpp_type_registration_declarations()},{self.size}>" + + def remove_const_ref(self) -> CType: + return ArrayCType(self.elem.remove_const_ref(), self.size) + + +@dataclass(frozen=True) +class TupleCType(CType): + elems: list[CType] + + def cpp_type(self, *, strip_ref: bool = False) -> str: + # Do not pass `strip_ref` recursively. 
+ return f'::std::tuple<{",".join([e.cpp_type() for e in self.elems])}>' + + def cpp_type_registration_declarations(self) -> str: + return f'::std::tuple<{",".join([e.cpp_type_registration_declarations() for e in self.elems])}>' + + def remove_const_ref(self) -> CType: + return TupleCType([e.remove_const_ref() for e in self.elems]) + + +@dataclass(frozen=True) +class MutRefCType(CType): + elem: CType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + if strip_ref: + return self.elem.cpp_type(strip_ref=strip_ref) + return f"{self.elem.cpp_type()} &" + + def cpp_type_registration_declarations(self) -> str: + return f"{self.elem.cpp_type_registration_declarations()} &" + + def remove_const_ref(self) -> CType: + return self.elem.remove_const_ref() + + +# A NamedCType is short for Named C++ semantic type. A NamedCType represents a C++ type, plus +# semantic information about what it represents. For example, consider the +# argument "bool pin_memory"; its normal C++ type is "bool", but its C++ +# semantic type also keeps track that this represents a "pin_memory"; you can't +# just use a random other boolean in a context where you need a "pin_memory"! +# + + +@dataclass(frozen=True) +class NamedCType: + name: ArgName + type: CType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + return self.type.cpp_type(strip_ref=strip_ref) + + # For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml + # TODO: Kill this when we eventually remove it! + def cpp_type_registration_declarations(self) -> str: + return self.type.cpp_type_registration_declarations() + + def remove_const_ref(self) -> NamedCType: + return NamedCType(self.name, self.type.remove_const_ref()) + + def with_name(self, name: str) -> NamedCType: + return NamedCType(name, self.type) + + +# A binding represents any C++ binding site for a formal parameter. 
+# We don't distinguish between binding sites for different APIs; +# instead, all of the important distinctions are encoded in CType, +# which you can use to figure out if a given Binding is appropriate +# for use in another context. (See torchgen.api.translate) + + +@dataclass(frozen=True) +class Binding: + name: str + nctype: NamedCType + argument: Argument | TensorOptionsArguments | SelfArgument + # TODO: maybe don't represent default here + default: str | None = None + + def rename(self, name: str) -> Binding: + return Binding( + name=name, + nctype=self.nctype, + argument=self.argument, + default=self.default, + ) + + @property + def type(self) -> str: + return self.nctype.cpp_type() + + def no_default(self) -> Binding: + return Binding( + name=self.name, + nctype=self.nctype, + default=None, + argument=self.argument, + ) + + def decl(self, *, func_ptr_cast: bool = False) -> str: + mb_default = "" + if self.default is not None: + mb_default = f"={self.default}" + + # casting only needs to know the type + if func_ptr_cast: + return f"{self.type}" + else: + return f"{self.type} {self.name}{mb_default}" + + # For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml + # TODO: Kill this when we eventually remove it! + def decl_registration_declarations(self) -> str: + type_s = self.nctype.cpp_type_registration_declarations() + mb_default = "" + if self.default is not None: + mb_default = f"={self.default}" + return f"{type_s} {self.name}{mb_default}" + + def defn(self) -> str: + return f"{self.type} {self.name}" + + def with_name(self, name: str) -> Binding: + return Binding( + name=name, nctype=self.nctype, argument=self.argument, default=self.default + ) + + +# An Expr is a C++ expression. It has a C++ string representing its syntax, +# as well as a CType saying what it provides. 
+ + +@dataclass(frozen=True) +class Expr: + expr: str + type: NamedCType diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/ufunc.py b/.venv/lib/python3.11/site-packages/torchgen/api/ufunc.py new file mode 100644 index 0000000000000000000000000000000000000000..17adcccecab563b6a4003215c778a00d5e1399c4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/api/ufunc.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import torchgen.api.types as api_types +from torchgen.api import cpp, structured +from torchgen.api.types import ( + ArgName, + BaseCppType, + BaseCType, + Binding, + ConstRefCType, + CType, + NamedCType, + scalarT, +) +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + DispatchKey, + FunctionSchema, + NativeFunctionsGroup, + Type, +) + + +def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str: + assert func.is_out_fn(), "ufunc.kernel_name should only be invoked on out schemas" + return f"ufunc_{func.name.name}_{dispatch_key}" + + +def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str: + return schema_kernel_name(g.out.func, dispatch_key) + + +# Tensors are omitted (as they are stored in TensorIterator), everything else is +# passed along (technically, we can pass tensors along too, it just wastes +# argument registers) +# +# NB: used for CPU only +def dispatchstub_type(t: Type, *, binds: ArgName) -> NamedCType | None: + # Dispatch stubs are always plain ints + r = cpp.valuetype_type(t, binds=binds, symint=False) + if r is not None: + return r + + if t == BaseType(BaseTy.Scalar): + return NamedCType(binds, ConstRefCType(BaseCType(scalarT))) + elif t == BaseType(BaseTy.Tensor): + return None + else: + raise AssertionError(f"unrecognized type {repr(t)}") + + +def opmath_type(scalar_t: BaseCppType) -> BaseCppType: + if scalar_t == api_types.scalar_t: + return api_types.opmath_t + raise NotImplementedError + + +# NB: Tensors in 
constructor are stored in opmath_t, not scalar_t +# because Tensor in constructor = its a scalar tensor partially applied = +# it can be higher precision and we want to compute in that higher precision +# +# NB: CUDA only +def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType: + r = cpp.valuetype_type(t, binds=binds, symint=False) + if r is not None: + return r + + if t == BaseType(BaseTy.Scalar): + return NamedCType(binds, BaseCType(opmath_type(scalar_t))) + elif t == BaseType(BaseTy.Tensor): + return NamedCType(binds, BaseCType(opmath_type(scalar_t))) + else: + raise AssertionError(f"unrecognized type {repr(t)}") + + +# Only Tensors ever get passed directly to operator() +# +# NB: CUDA only +# (Actually, this works for CPU too) +def ufunctor_apply_type( + t: Type, *, binds: ArgName, scalar_t: BaseCppType +) -> NamedCType: + if t == BaseType(BaseTy.Tensor): + return NamedCType(binds, BaseCType(scalar_t)) + else: + raise AssertionError(f"unrecognized type {repr(t)}") + + +# The actual ufunc template function the user writes. Everything here +# is done in the computation type. 
compute_t is opmath_t in CUDA and scalar_t +# in CPU +def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType: + r = cpp.valuetype_type(t, binds=binds, symint=False) + if r is not None: + return r + + if t == BaseType(BaseTy.Scalar): + return NamedCType(binds, compute_t) + elif t == BaseType(BaseTy.Tensor): + return NamedCType(binds, compute_t) + else: + raise AssertionError(f"unrecognized type {repr(t)}") + + +def ufunctor_ctor_argument(a: Argument, scalar_t: BaseCppType) -> Binding: + return Binding( + nctype=ufunctor_ctor_type(a.type, binds=a.name, scalar_t=scalar_t), + name=a.name, + default=None, + argument=a, + ) + + +def ufunctor_apply_argument(a: Argument, scalar_t: BaseCppType) -> Binding: + return Binding( + nctype=ufunctor_apply_type(a.type, binds=a.name, scalar_t=scalar_t), + name=a.name, + default=None, + argument=a, + ) + + +def ufunc_argument(a: Argument, compute_t: CType) -> Binding: + return Binding( + nctype=ufunc_type(a.type, binds=a.name, compute_t=compute_t), + name=a.name, + default=None, + argument=a, + ) + + +@dataclass(frozen=True) +class UfunctorBindings: + ctor: list[Binding] + apply: list[Binding] + + +# ufunctors are a CUDA-only concept representing functors that take some of +# their arguments on a host-side constructor, and the rest in the device-side +# apply. 
E.g., +# +# template +# struct CUDAFunctorOnSelf_add { +# using opmath_t = at::opmath_type; +# opmath_t other_; +# opmath_t alpha_; +# CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) : other_(other), alpha_(alpha) {} +# __device__ scalar_t operator()(scalar_t self) { +# return ufunc::add(static_cast(self), other_, alpha_); +# } +# }; +# +# The ctor refers to the constructor CUDAFunctorOnSelf_add, while apply refers +# to the operator() definition +def ufunctor_arguments( + g: NativeFunctionsGroup, *, scalar_tensor_idx: int | None, scalar_t: BaseCppType +) -> UfunctorBindings: + ctor = [] + apply = [] + for a in g.functional.func.arguments.flat_non_out: + if a.type.is_tensor_like(): + if scalar_tensor_idx == 0: + # put it in the ctor anyway + ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t)) + scalar_tensor_idx = None + else: + if scalar_tensor_idx is not None: + scalar_tensor_idx -= 1 + apply.append(ufunctor_apply_argument(a, scalar_t=scalar_t)) + else: + ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t)) + assert scalar_tensor_idx is None + return UfunctorBindings(ctor=ctor, apply=apply) + + +# ufuncs are the inner loop template functions that you wrote in ufunc/add.h +# which do the actual computation in question. E.g., +# +# template +# C10_HOST_DEVICE T add(T self, T other, T alpha) __ubsan_ignore_undefined__ { +# return self + alpha * other; +# } +# +# In this file, we refer to T as compute_t which is bound by caller +def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> list[Binding]: + return [ + ufunc_argument(a, compute_t=compute_t) + for a in g.functional.func.arguments.flat_non_out + ] + + +# Stubs are the DispatchStub trampolines that CPU kernels use to get to their +# vectorized versions. 
E.g., +# +# using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha); +# DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub); +def stub_arguments(g: NativeFunctionsGroup) -> list[Binding]: + # stubs drop all tensor arguments (they are implicit in the TensorIterator + # argument and keep everything else) + return [ + r + for a in g.out.func.arguments.flat_non_out + if not a.type.is_tensor_like() + for r in structured.argument(a) + ] diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/unboxing.py b/.venv/lib/python3.11/site-packages/torchgen/api/unboxing.py new file mode 100644 index 0000000000000000000000000000000000000000..1e649b7517889d284bf13fe8d0bd737e4e81f5f5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/api/unboxing.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +from torchgen.api import cpp +from torchgen.api.types import Binding, CppSignatureGroup, CType +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + ListType, + NativeFunction, + OptionalType, + Type, +) + + +# This file generates the code for unboxing wrappers, i.e., the glue logic to unbox a boxed operator and convert the +# ivalues from stack to correct arguments to the unboxed kernel, based on corresponding JIT schema. This codegen is +# an alternative way to generate unboxing wrappers similar to the existing C++ metaprogramming approach but gets the +# job done statically. These generated unboxing wrappers will be useful under the scenario where we need to register +# a fixed set of operators known at compile time and thus can save some time in runtime initialization phase. +# +# Here's an example on how the codegen works: +# +# - Function Schema (source of truth) +# +# aten::empty.names(int[] size, *, Dimname[]? names, +# ScalarType? dtype=None, Layout? layout=None, +# Device? device=None, bool? pin_memory=None, +# MemoryFormat? 
memory_format=None) -> Tensor +# - Argument Conversion +# Generates C++ code to convert an ivalue (from stack) to its underlying C++ type. +# - int[] size +# ```cpp +# const c10::List size_list_in = (std::move(peek(stack, 0, 7))).toList(); +# +# std::vector size_vec; +# for (c10::IValue size_elem: size_list_in) { +# int64_t size_base = size_elem.to(); +# size_vec.push_back(size_base); +# } +# at::ArrayRef size_list_out(size_vec); +# ~~~~~~~~~~~~~ <-- The converted argument from ivalues in the stack. +# Will be passed to unboxed kernel. +# ``` +# - Dimname[]? names +# ```cpp +# ::std::optional names_opt = (std::move(peek(stack, 1, 7))).toOptional(); +# ::std::optional> names_opt_out; +# if (names_opt.has_value()) { +# ~~~~~~~~~~~ <-- Unwrapping optional shell +# const c10::IValue names_opt_in = names_opt.value(); +# const c10::List names_list_in = names_opt_in.toList(); +# +# std::vector names_vec; +# for (c10::IValue names_elem: names_list_in) { +# ~~~~~~~~~~~~~~~~~~~~~~~~~ <-- Unrolling list, then convert elements one by one. +# at::Dimname names_base = names_elem.to(); +# names_vec.push_back(names_base); +# } +# at::ArrayRef names_list_out(names_vec); +# +# names_opt_out = ::std::optional>(names_list_out); +# } else { +# names_opt_out = ::std::optional>(); +# } +# ``` +# - ScalarType? dtype (similarly for the rest of the arguments) +# ```cpp +# ::std::optional dtype_opt = (std::move(peek(stack, 2, 7))).toOptional(); +# ::std::optional dtype_opt_out; +# if (dtype_opt.has_value()) { +# const c10::IValue dtype_opt_in = dtype_opt.value(); +# at::ScalarType dtype_base = dtype_opt_in.to(); +# ~~~~~~~~~~~~~~~~~~~~ <-- For base types, convert ivalue to it +# directly using ".to()" API. 
+# dtype_opt_out = ::std::optional(dtype_base); +# } else { +# dtype_opt_out = ::std::optional(); +# } +# ``` +# +# - Unboxed Kernel Call +# ```cpp +# auto result_ = torch::empty( +# size_list_out, +# names_opt_out, +# options, +# memory_format_opt_out +# ); +# ``` +# +# - Push Result Back to Stack +# ```cpp +# drop(stack, 7); +# pack(stack, std::move(result_)); +# ``` +connector = "\n\t" + + +# Return unboxing function name for a NativeFunction +def name(f: NativeFunction) -> str: + return f.func.name.unambiguous_name() + + +# Convert all the arguments in a NativeFunction to C++ code +def convert_arguments(f: NativeFunction) -> tuple[list[Binding], list[str]]: + # we need the 'self' argument so method needs to be False + args = ( + CppSignatureGroup.from_native_function(f, method=False) + .most_faithful_signature() + .arguments() + ) + code_list = [ + f"c10::IValue {args[i].name} = std::move(peek(stack, {i}, {len(args)}));" + for i in range(len(args)) + ] + [""] + binding_list = [] + for arg in args: + # expecting only Argument + if not isinstance(arg.argument, Argument): + raise Exception( # noqa: TRY002 + f"Unexpected argument type, expecting `Argument` but got {arg}" + ) + argument: Argument = arg.argument + unboxed_name, _, code, decl = argumenttype_ivalue_convert( + argument.type, + argument.name, + mutable=argument.is_write, + ) + code_list.extend(decl) + code_list.extend(code) + binding_list.append(arg.with_name(unboxed_name)) + return binding_list, code_list + + +# Takes in the type, name and mutability corresponding to an argument, and generates a tuple of: +# (1) the C++ code necessary to unbox the argument +# (2) A Binding corresponding to the newly created unboxed variable, including variable name and its CType +def argumenttype_ivalue_convert( + t: Type, arg_name: str, *, mutable: bool = False +) -> tuple[str, CType, list[str], list[str]]: + # Unboxing is for mobile, which doesn't care about SymInts + ctype = cpp.argumenttype_type( + t=t, 
mutable=mutable, binds=arg_name, symint=False + ).type + + if isinstance(t, BaseType): + out_name = f"{arg_name}_base" + code, decl = _gen_code_base_type( + arg_name=arg_name, out_name=out_name, ctype=ctype + ) + elif isinstance(t, OptionalType): + out_name = f"{arg_name}_opt_out" + code, decl = _gen_code_optional_type( + arg_name=arg_name, + out_name=out_name, + t=t, + ctype=ctype, + ) + elif isinstance(t, ListType): + out_name = f"{arg_name}_list_out" + code, decl = _gen_code_list_type( + arg_name=arg_name, + out_name=out_name, + t=t, + ctype=ctype, + ) + else: + raise Exception(f"Cannot handle type {t}. arg_name: {arg_name}") # noqa: TRY002 + return out_name, ctype, code, decl + + +def _gen_code_base_type( + arg_name: str, out_name: str, ctype: CType +) -> tuple[list[str], list[str]]: + return [ + f"{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();" + ], [] + + +def _gen_code_optional_type( + arg_name: str, out_name: str, t: OptionalType, ctype: CType +) -> tuple[list[str], list[str]]: + in_name = f"{arg_name}_opt_in" + res_name, _, res_code, decl = argumenttype_ivalue_convert(t.elem, in_name) + return ( + f""" +auto {arg_name}_opt = {arg_name}.toOptional(); +{ctype.cpp_type(strip_ref=True)} {out_name}; +if ({arg_name}_opt.has_value()) {{ + const c10::IValue {in_name} = {arg_name}_opt.value(); + {connector.join(res_code)} + {out_name} = {ctype.cpp_type(strip_ref=True)}({res_name}); +}} else {{ + {out_name} = {ctype.cpp_type(strip_ref=True)}(); +}} + """.split( + "\n" + ), + decl, + ) + + +def _gen_code_list_type( + arg_name: str, out_name: str, t: ListType, ctype: CType +) -> tuple[list[str], list[str]]: + in_name = f"{arg_name}_list_in" + elem_name = f"{arg_name}_elem" + code = [f"const c10::List {in_name} = {arg_name}.toList();"] + res_name, res_ctype, res_code, decl = argumenttype_ivalue_convert(t.elem, elem_name) + # handle list type with size, e.g., bool[4] + if isinstance(t.elem, BaseType) and t.elem.name == 
BaseTy.bool and t.size: + code.extend( + f""" +{ctype.cpp_type(strip_ref=True)} {out_name} = as_array<{res_ctype.cpp_type(strip_ref=True)}, {t.size}>({in_name}); + """.split( + "\n" + ) + ) + # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional> + elif isinstance(t.elem, OptionalType): + code.extend( + f""" +{ctype.cpp_type(strip_ref=True)} {out_name}; +for (c10::IValue {elem_name}: {in_name}) {{ + {connector.join(res_code)} + {out_name}.push_back({res_name}); +}} + """.split( + "\n" + ) + ) + else: + # use ArrayRef as default. + vec_name = arg_name + "_vec" + # need to bring vector instantiation out of scope so that ArrayRef has valid data + decl.append(f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};") + code.extend( + f""" +for (c10::IValue {elem_name}: {in_name}) {{ + {connector.join(res_code)} + {vec_name}.push_back({res_name}); +}} +{ctype.cpp_type(strip_ref=True)} {out_name}({vec_name}); + """.split( + "\n" + ) + ) + return code, decl