koichi12 commited on
Commit
80c179b
·
verified ·
1 Parent(s): 344d5fe

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .venv/lib/python3.11/site-packages/annotated_types-0.7.0.dist-info/INSTALLER +1 -0
  3. .venv/lib/python3.11/site-packages/annotated_types-0.7.0.dist-info/METADATA +295 -0
  4. .venv/lib/python3.11/site-packages/annotated_types-0.7.0.dist-info/RECORD +10 -0
  5. .venv/lib/python3.11/site-packages/annotated_types-0.7.0.dist-info/WHEEL +4 -0
  6. .venv/lib/python3.11/site-packages/annotated_types-0.7.0.dist-info/licenses/LICENSE +21 -0
  7. .venv/lib/python3.11/site-packages/ray/__init__.py +294 -0
  8. .venv/lib/python3.11/site-packages/ray/_raylet.pxd +176 -0
  9. .venv/lib/python3.11/site-packages/ray/_raylet.pyi +11 -0
  10. .venv/lib/python3.11/site-packages/ray/_version.py +6 -0
  11. .venv/lib/python3.11/site-packages/ray/actor.py +1790 -0
  12. .venv/lib/python3.11/site-packages/ray/client_builder.py +379 -0
  13. .venv/lib/python3.11/site-packages/ray/cluster_utils.py +415 -0
  14. .venv/lib/python3.11/site-packages/ray/cross_language.py +137 -0
  15. .venv/lib/python3.11/site-packages/ray/exceptions.py +933 -0
  16. .venv/lib/python3.11/site-packages/ray/job_config.py +249 -0
  17. .venv/lib/python3.11/site-packages/ray/nightly-wheels.yaml +11 -0
  18. .venv/lib/python3.11/site-packages/ray/py.typed +0 -0
  19. .venv/lib/python3.11/site-packages/ray/remote_function.py +515 -0
  20. .venv/lib/python3.11/site-packages/ray/runtime_context.py +564 -0
  21. .venv/lib/python3.11/site-packages/ray/setup-dev.py +157 -0
  22. .venv/lib/python3.11/site-packages/ray/types.py +14 -0
  23. .venv/lib/python3.11/site-packages/ray/workflow/__init__.py +55 -0
  24. .venv/lib/python3.11/site-packages/ray/workflow/__pycache__/debug_utils.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/ray/workflow/__pycache__/serialization.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_access.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_context.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_executor.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_state.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_state_from_dag.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_state_from_storage.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/ray/workflow/api.py +869 -0
  33. .venv/lib/python3.11/site-packages/ray/workflow/common.py +199 -0
  34. .venv/lib/python3.11/site-packages/ray/workflow/debug_utils.py +40 -0
  35. .venv/lib/python3.11/site-packages/ray/workflow/event_listener.py +70 -0
  36. .venv/lib/python3.11/site-packages/ray/workflow/exceptions.py +57 -0
  37. .venv/lib/python3.11/site-packages/ray/workflow/http_event_provider.py +272 -0
  38. .venv/lib/python3.11/site-packages/ray/workflow/serialization.py +235 -0
  39. .venv/lib/python3.11/site-packages/ray/workflow/serialization_context.py +112 -0
  40. .venv/lib/python3.11/site-packages/ray/workflow/task_executor.py +163 -0
  41. .venv/lib/python3.11/site-packages/ray/workflow/workflow_access.py +379 -0
  42. .venv/lib/python3.11/site-packages/ray/workflow/workflow_context.py +123 -0
  43. .venv/lib/python3.11/site-packages/ray/workflow/workflow_executor.py +433 -0
  44. .venv/lib/python3.11/site-packages/ray/workflow/workflow_state.py +251 -0
  45. .venv/lib/python3.11/site-packages/ray/workflow/workflow_state_from_dag.py +205 -0
  46. .venv/lib/python3.11/site-packages/ray/workflow/workflow_state_from_storage.py +71 -0
  47. .venv/lib/python3.11/site-packages/ray/workflow/workflow_storage.py +880 -0
  48. .venv/lib/python3.11/site-packages/torchgen/__pycache__/model.cpython-311.pyc +3 -0
  49. .venv/lib/python3.11/site-packages/torchgen/api/__init__.py +0 -0
  50. .venv/lib/python3.11/site-packages/torchgen/api/__pycache__/__init__.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -150,3 +150,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
150
  .venv/lib/python3.11/site-packages/nvidia/cusolver/lib/libcusolver.so.11 filter=lfs diff=lfs merge=lfs -text
151
  .venv/lib/python3.11/site-packages/nvidia/cusparse/lib/libcusparse.so.12 filter=lfs diff=lfs merge=lfs -text
152
  .venv/lib/python3.11/site-packages/nvidia/cufft/lib/libcufft.so.11 filter=lfs diff=lfs merge=lfs -text
 
 
150
  .venv/lib/python3.11/site-packages/nvidia/cusolver/lib/libcusolver.so.11 filter=lfs diff=lfs merge=lfs -text
151
  .venv/lib/python3.11/site-packages/nvidia/cusparse/lib/libcusparse.so.12 filter=lfs diff=lfs merge=lfs -text
152
  .venv/lib/python3.11/site-packages/nvidia/cufft/lib/libcufft.so.11 filter=lfs diff=lfs merge=lfs -text
153
+ .venv/lib/python3.11/site-packages/torchgen/__pycache__/model.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/annotated_types-0.7.0.dist-info/INSTALLER ADDED
@@ -0,0 +1 @@
 
 
1
+ pip
.venv/lib/python3.11/site-packages/annotated_types-0.7.0.dist-info/METADATA ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.3
2
+ Name: annotated-types
3
+ Version: 0.7.0
4
+ Summary: Reusable constraint types to use with typing.Annotated
5
+ Project-URL: Homepage, https://github.com/annotated-types/annotated-types
6
+ Project-URL: Source, https://github.com/annotated-types/annotated-types
7
+ Project-URL: Changelog, https://github.com/annotated-types/annotated-types/releases
8
+ Author-email: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>, Samuel Colvin <s@muelcolvin.com>, Zac Hatfield-Dodds <zac@zhd.dev>
9
+ License-File: LICENSE
10
+ Classifier: Development Status :: 4 - Beta
11
+ Classifier: Environment :: Console
12
+ Classifier: Environment :: MacOS X
13
+ Classifier: Intended Audience :: Developers
14
+ Classifier: Intended Audience :: Information Technology
15
+ Classifier: License :: OSI Approved :: MIT License
16
+ Classifier: Operating System :: POSIX :: Linux
17
+ Classifier: Operating System :: Unix
18
+ Classifier: Programming Language :: Python :: 3 :: Only
19
+ Classifier: Programming Language :: Python :: 3.8
20
+ Classifier: Programming Language :: Python :: 3.9
21
+ Classifier: Programming Language :: Python :: 3.10
22
+ Classifier: Programming Language :: Python :: 3.11
23
+ Classifier: Programming Language :: Python :: 3.12
24
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
25
+ Classifier: Typing :: Typed
26
+ Requires-Python: >=3.8
27
+ Requires-Dist: typing-extensions>=4.0.0; python_version < '3.9'
28
+ Description-Content-Type: text/markdown
29
+
30
+ # annotated-types
31
+
32
+ [![CI](https://github.com/annotated-types/annotated-types/workflows/CI/badge.svg?event=push)](https://github.com/annotated-types/annotated-types/actions?query=event%3Apush+branch%3Amain+workflow%3ACI)
33
+ [![pypi](https://img.shields.io/pypi/v/annotated-types.svg)](https://pypi.python.org/pypi/annotated-types)
34
+ [![versions](https://img.shields.io/pypi/pyversions/annotated-types.svg)](https://github.com/annotated-types/annotated-types)
35
+ [![license](https://img.shields.io/github/license/annotated-types/annotated-types.svg)](https://github.com/annotated-types/annotated-types/blob/main/LICENSE)
36
+
37
+ [PEP-593](https://peps.python.org/pep-0593/) added `typing.Annotated` as a way of
38
+ adding context-specific metadata to existing types, and specifies that
39
+ `Annotated[T, x]` _should_ be treated as `T` by any tool or library without special
40
+ logic for `x`.
41
+
42
+ This package provides metadata objects which can be used to represent common
43
+ constraints such as upper and lower bounds on scalar values and collection sizes,
44
+ a `Predicate` marker for runtime checks, and
45
+ descriptions of how we intend these metadata to be interpreted. In some cases,
46
+ we also note alternative representations which do not require this package.
47
+
48
+ ## Install
49
+
50
+ ```bash
51
+ pip install annotated-types
52
+ ```
53
+
54
+ ## Examples
55
+
56
+ ```python
57
+ from typing import Annotated
58
+ from annotated_types import Gt, Len, Predicate
59
+
60
+ class MyClass:
61
+ age: Annotated[int, Gt(18)] # Valid: 19, 20, ...
62
+ # Invalid: 17, 18, "19", 19.0, ...
63
+ factors: list[Annotated[int, Predicate(is_prime)]] # Valid: 2, 3, 5, 7, 11, ...
64
+ # Invalid: 4, 8, -2, 5.0, "prime", ...
65
+
66
+ my_list: Annotated[list[int], Len(0, 10)] # Valid: [], [10, 20, 30, 40, 50]
67
+ # Invalid: (1, 2), ["abc"], [0] * 20
68
+ ```
69
+
70
+ ## Documentation
71
+
72
+ _While `annotated-types` avoids runtime checks for performance, users should not
73
+ construct invalid combinations such as `MultipleOf("non-numeric")` or `Annotated[int, Len(3)]`.
74
+ Downstream implementors may choose to raise an error, emit a warning, silently ignore
75
+ a metadata item, etc., if the metadata objects described below are used with an
76
+ incompatible type - or for any other reason!_
77
+
78
+ ### Gt, Ge, Lt, Le
79
+
80
+ Express inclusive and/or exclusive bounds on orderable values - which may be numbers,
81
+ dates, times, strings, sets, etc. Note that the boundary value need not be of the
82
+ same type that was annotated, so long as they can be compared: `Annotated[int, Gt(1.5)]`
83
+ is fine, for example, and implies that the value is an integer x such that `x > 1.5`.
84
+
85
+ We suggest that implementors may also interpret `functools.partial(operator.le, 1.5)`
86
+ as being equivalent to `Gt(1.5)`, for users who wish to avoid a runtime dependency on
87
+ the `annotated-types` package.
88
+
89
+ To be explicit, these types have the following meanings:
90
+
91
+ * `Gt(x)` - value must be "Greater Than" `x` - equivalent to exclusive minimum
92
+ * `Ge(x)` - value must be "Greater than or Equal" to `x` - equivalent to inclusive minimum
93
+ * `Lt(x)` - value must be "Less Than" `x` - equivalent to exclusive maximum
94
+ * `Le(x)` - value must be "Less than or Equal" to `x` - equivalent to inclusive maximum
95
+
96
+ ### Interval
97
+
98
+ `Interval(gt, ge, lt, le)` allows you to specify an upper and lower bound with a single
99
+ metadata object. `None` attributes should be ignored, and non-`None` attributes
100
+ treated as per the single bounds above.
101
+
102
+ ### MultipleOf
103
+
104
+ `MultipleOf(multiple_of=x)` might be interpreted in two ways:
105
+
106
+ 1. Python semantics, implying `value % multiple_of == 0`, or
107
+ 2. [JSONschema semantics](https://json-schema.org/draft/2020-12/json-schema-validation.html#rfc.section.6.2.1),
108
+ where `int(value / multiple_of) == value / multiple_of`.
109
+
110
+ We encourage users to be aware of these two common interpretations and their
111
+ distinct behaviours, especially since very large or non-integer numbers make
112
+ it easy to cause silent data corruption due to floating-point imprecision.
113
+
114
+ We encourage libraries to carefully document which interpretation they implement.
115
+
116
+ ### MinLen, MaxLen, Len
117
+
118
+ `Len()` implies that `min_length <= len(value) <= max_length` - lower and upper bounds are inclusive.
119
+
120
+ As well as `Len()` which can optionally include upper and lower bounds, we also
121
+ provide `MinLen(x)` and `MaxLen(y)` which are equivalent to `Len(min_length=x)`
122
+ and `Len(max_length=y)` respectively.
123
+
124
+ `Len`, `MinLen`, and `MaxLen` may be used with any type which supports `len(value)`.
125
+
126
+ Examples of usage:
127
+
128
+ * `Annotated[list, MaxLen(10)]` (or `Annotated[list, Len(max_length=10))`) - list must have a length of 10 or less
129
+ * `Annotated[str, MaxLen(10)]` - string must have a length of 10 or less
130
+ * `Annotated[list, MinLen(3))` (or `Annotated[list, Len(min_length=3))`) - list must have a length of 3 or more
131
+ * `Annotated[list, Len(4, 6)]` - list must have a length of 4, 5, or 6
132
+ * `Annotated[list, Len(8, 8)]` - list must have a length of exactly 8
133
+
134
+ #### Changed in v0.4.0
135
+
136
+ * `min_inclusive` has been renamed to `min_length`, no change in meaning
137
+ * `max_exclusive` has been renamed to `max_length`, upper bound is now **inclusive** instead of **exclusive**
138
+ * The recommendation that slices are interpreted as `Len` has been removed due to ambiguity and different semantic
139
+ meaning of the upper bound in slices vs. `Len`
140
+
141
+ See [issue #23](https://github.com/annotated-types/annotated-types/issues/23) for discussion.
142
+
143
+ ### Timezone
144
+
145
+ `Timezone` can be used with a `datetime` or a `time` to express which timezones
146
+ are allowed. `Annotated[datetime, Timezone(None)]` must be a naive datetime.
147
+ `Timezone[...]` ([literal ellipsis](https://docs.python.org/3/library/constants.html#Ellipsis))
148
+ expresses that any timezone-aware datetime is allowed. You may also pass a specific
149
+ timezone string or [`tzinfo`](https://docs.python.org/3/library/datetime.html#tzinfo-objects)
150
+ object such as `Timezone(timezone.utc)` or `Timezone("Africa/Abidjan")` to express that you only
151
+ allow a specific timezone, though we note that this is often a symptom of fragile design.
152
+
153
+ #### Changed in v0.x.x
154
+
155
+ * `Timezone` accepts [`tzinfo`](https://docs.python.org/3/library/datetime.html#tzinfo-objects) objects instead of
156
+ `timezone`, extending compatibility to [`zoneinfo`](https://docs.python.org/3/library/zoneinfo.html) and third party libraries.
157
+
158
+ ### Unit
159
+
160
+ `Unit(unit: str)` expresses that the annotated numeric value is the magnitude of
161
+ a quantity with the specified unit. For example, `Annotated[float, Unit("m/s")]`
162
+ would be a float representing a velocity in meters per second.
163
+
164
+ Please note that `annotated_types` itself makes no attempt to parse or validate
165
+ the unit string in any way. That is left entirely to downstream libraries,
166
+ such as [`pint`](https://pint.readthedocs.io) or
167
+ [`astropy.units`](https://docs.astropy.org/en/stable/units/).
168
+
169
+ An example of how a library might use this metadata:
170
+
171
+ ```python
172
+ from annotated_types import Unit
173
+ from typing import Annotated, TypeVar, Callable, Any, get_origin, get_args
174
+
175
+ # given a type annotated with a unit:
176
+ Meters = Annotated[float, Unit("m")]
177
+
178
+
179
+ # you can cast the annotation to a specific unit type with any
180
+ # callable that accepts a string and returns the desired type
181
+ T = TypeVar("T")
182
+ def cast_unit(tp: Any, unit_cls: Callable[[str], T]) -> T | None:
183
+ if get_origin(tp) is Annotated:
184
+ for arg in get_args(tp):
185
+ if isinstance(arg, Unit):
186
+ return unit_cls(arg.unit)
187
+ return None
188
+
189
+
190
+ # using `pint`
191
+ import pint
192
+ pint_unit = cast_unit(Meters, pint.Unit)
193
+
194
+
195
+ # using `astropy.units`
196
+ import astropy.units as u
197
+ astropy_unit = cast_unit(Meters, u.Unit)
198
+ ```
199
+
200
+ ### Predicate
201
+
202
+ `Predicate(func: Callable)` expresses that `func(value)` is truthy for valid values.
203
+ Users should prefer the statically inspectable metadata above, but if you need
204
+ the full power and flexibility of arbitrary runtime predicates... here it is.
205
+
206
+ For some common constraints, we provide generic types:
207
+
208
+ * `IsLower = Annotated[T, Predicate(str.islower)]`
209
+ * `IsUpper = Annotated[T, Predicate(str.isupper)]`
210
+ * `IsDigit = Annotated[T, Predicate(str.isdigit)]`
211
+ * `IsFinite = Annotated[T, Predicate(math.isfinite)]`
212
+ * `IsNotFinite = Annotated[T, Predicate(Not(math.isfinite))]`
213
+ * `IsNan = Annotated[T, Predicate(math.isnan)]`
214
+ * `IsNotNan = Annotated[T, Predicate(Not(math.isnan))]`
215
+ * `IsInfinite = Annotated[T, Predicate(math.isinf)]`
216
+ * `IsNotInfinite = Annotated[T, Predicate(Not(math.isinf))]`
217
+
218
+ so that you can write e.g. `x: IsFinite[float] = 2.0` instead of the longer
219
+ (but exactly equivalent) `x: Annotated[float, Predicate(math.isfinite)] = 2.0`.
220
+
221
+ Some libraries might have special logic to handle known or understandable predicates,
222
+ for example by checking for `str.isdigit` and using its presence to both call custom
223
+ logic to enforce digit-only strings, and customise some generated external schema.
224
+ Users are therefore encouraged to avoid indirection like `lambda s: s.lower()`, in
225
+ favor of introspectable methods such as `str.lower` or `re.compile("pattern").search`.
226
+
227
+ To enable basic negation of commonly used predicates like `math.isnan` without introducing introspection that makes it impossible for implementers to introspect the predicate we provide a `Not` wrapper that simply negates the predicate in an introspectable manner. Several of the predicates listed above are created in this manner.
228
+
229
+ We do not specify what behaviour should be expected for predicates that raise
230
+ an exception. For example `Annotated[int, Predicate(str.isdigit)]` might silently
231
+ skip invalid constraints, or statically raise an error; or it might try calling it
232
+ and then propagate or discard the resulting
233
+ `TypeError: descriptor 'isdigit' for 'str' objects doesn't apply to a 'int' object`
234
+ exception. We encourage libraries to document the behaviour they choose.
235
+
236
+ ### Doc
237
+
238
+ `doc()` can be used to add documentation information in `Annotated`, for function and method parameters, variables, class attributes, return types, and any place where `Annotated` can be used.
239
+
240
+ It expects a value that can be statically analyzed, as the main use case is for static analysis, editors, documentation generators, and similar tools.
241
+
242
+ It returns a `DocInfo` class with a single attribute `documentation` containing the value passed to `doc()`.
243
+
244
+ This is the early adopter's alternative form of the [`typing-doc` proposal](https://github.com/tiangolo/fastapi/blob/typing-doc/typing_doc.md).
245
+
246
+ ### Integrating downstream types with `GroupedMetadata`
247
+
248
+ Implementers may choose to provide a convenience wrapper that groups multiple pieces of metadata.
249
+ This can help reduce verbosity and cognitive overhead for users.
250
+ For example, an implementer like Pydantic might provide a `Field` or `Meta` type that accepts keyword arguments and transforms these into low-level metadata:
251
+
252
+ ```python
253
+ from dataclasses import dataclass
254
+ from typing import Iterator
255
+ from annotated_types import GroupedMetadata, Ge
256
+
257
+ @dataclass
258
+ class Field(GroupedMetadata):
259
+ ge: int | None = None
260
+ description: str | None = None
261
+
262
+ def __iter__(self) -> Iterator[object]:
263
+ # Iterating over a GroupedMetadata object should yield annotated-types
264
+ # constraint metadata objects which describe it as fully as possible,
265
+ # and may include other unknown objects too.
266
+ if self.ge is not None:
267
+ yield Ge(self.ge)
268
+ if self.description is not None:
269
+ yield Description(self.description)
270
+ ```
271
+
272
+ Libraries consuming annotated-types constraints should check for `GroupedMetadata` and unpack it by iterating over the object and treating the results as if they had been "unpacked" in the `Annotated` type. The same logic should be applied to the [PEP 646 `Unpack` type](https://peps.python.org/pep-0646/), so that `Annotated[T, Field(...)]`, `Annotated[T, Unpack[Field(...)]]` and `Annotated[T, *Field(...)]` are all treated consistently.
273
+
274
+ Libraries consuming annotated-types should also ignore any metadata they do not recongize that came from unpacking a `GroupedMetadata`, just like they ignore unrecognized metadata in `Annotated` itself.
275
+
276
+ Our own `annotated_types.Interval` class is a `GroupedMetadata` which unpacks itself into `Gt`, `Lt`, etc., so this is not an abstract concern. Similarly, `annotated_types.Len` is a `GroupedMetadata` which unpacks itself into `MinLen` (optionally) and `MaxLen`.
277
+
278
+ ### Consuming metadata
279
+
280
+ We intend to not be prescriptive as to _how_ the metadata and constraints are used, but as an example of how one might parse constraints from types annotations see our [implementation in `test_main.py`](https://github.com/annotated-types/annotated-types/blob/f59cf6d1b5255a0fe359b93896759a180bec30ae/tests/test_main.py#L94-L103).
281
+
282
+ It is up to the implementer to determine how this metadata is used.
283
+ You could use the metadata for runtime type checking, for generating schemas or to generate example data, amongst other use cases.
284
+
285
+ ## Design & History
286
+
287
+ This package was designed at the PyCon 2022 sprints by the maintainers of Pydantic
288
+ and Hypothesis, with the goal of making it as easy as possible for end-users to
289
+ provide more informative annotations for use by runtime libraries.
290
+
291
+ It is deliberately minimal, and following PEP-593 allows considerable downstream
292
+ discretion in what (if anything!) they choose to support. Nonetheless, we expect
293
+ that staying simple and covering _only_ the most common use-cases will give users
294
+ and maintainers the best experience we can. If you'd like more constraints for your
295
+ types - follow our lead, by defining them and documenting them downstream!
.venv/lib/python3.11/site-packages/annotated_types-0.7.0.dist-info/RECORD ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ annotated_types-0.7.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
2
+ annotated_types-0.7.0.dist-info/METADATA,sha256=7ltqxksJJ0wCYFGBNIQCWTlWQGeAH0hRFdnK3CB895E,15046
3
+ annotated_types-0.7.0.dist-info/RECORD,,
4
+ annotated_types-0.7.0.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
5
+ annotated_types-0.7.0.dist-info/licenses/LICENSE,sha256=_hBJiEsaDZNCkB6I4H8ykl0ksxIdmXK2poBfuYJLCV0,1083
6
+ annotated_types/__init__.py,sha256=RynLsRKUEGI0KimXydlD1fZEfEzWwDo0Uon3zOKhG1Q,13819
7
+ annotated_types/__pycache__/__init__.cpython-311.pyc,,
8
+ annotated_types/__pycache__/test_cases.cpython-311.pyc,,
9
+ annotated_types/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
+ annotated_types/test_cases.py,sha256=zHFX6EpcMbGJ8FzBYDbO56bPwx_DYIVSKbZM-4B3_lg,6421
.venv/lib/python3.11/site-packages/annotated_types-0.7.0.dist-info/WHEEL ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.24.2
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
.venv/lib/python3.11/site-packages/annotated_types-0.7.0.dist-info/licenses/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ The MIT License (MIT)
2
+
3
+ Copyright (c) 2022 the contributors
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
.venv/lib/python3.11/site-packages/ray/__init__.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # isort: skip_file
2
+ from ray._private import log # isort: skip # noqa: F401
3
+ import logging
4
+ import os
5
+ import sys
6
+
7
+ log.generate_logging_config()
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
+ def _configure_system():
12
+ import os
13
+ import platform
14
+ import sys
15
+
16
+ """Wraps system configuration to avoid 'leaking' variables into ray."""
17
+
18
+ # Sanity check pickle5 if it has been installed.
19
+ if "pickle5" in sys.modules:
20
+ if sys.version_info >= (3, 8):
21
+ logger.warning(
22
+ "Package pickle5 becomes unnecessary in Python 3.8 and above. "
23
+ "Its presence may confuse libraries including Ray. "
24
+ "Please uninstall the package."
25
+ )
26
+
27
+ import importlib.metadata
28
+
29
+ try:
30
+ version_str = importlib.metadata.version("pickle5")
31
+ version = tuple(int(n) for n in version_str.split("."))
32
+ if version < (0, 0, 10):
33
+ logger.warning(
34
+ "Although not used by Ray, a version of pickle5 that leaks memory "
35
+ "is found in the environment. Please run 'pip install pickle5 -U' "
36
+ "to upgrade."
37
+ )
38
+ except importlib.metadata.PackageNotFoundError:
39
+ logger.warning(
40
+ "You are using the 'pickle5' module, but "
41
+ "the exact version is unknown (possibly carried as "
42
+ "an internal component by another module). Please "
43
+ "make sure you are using pickle5 >= 0.0.10 because "
44
+ "previous versions may leak memory."
45
+ )
46
+
47
+ # Importing psutil & setproctitle. Must be before ray._raylet is
48
+ # initialized.
49
+ thirdparty_files = os.path.join(
50
+ os.path.abspath(os.path.dirname(__file__)), "thirdparty_files"
51
+ )
52
+ sys.path.insert(0, thirdparty_files)
53
+
54
+ if (
55
+ platform.system() == "Linux"
56
+ and "Microsoft".lower() in platform.release().lower()
57
+ ):
58
+ from ray._private import compat # noqa: E402
59
+
60
+ compat.patch_psutil()
61
+
62
+ # Expose ray ABI symbols which may be dependent by other shared
63
+ # libraries such as _streaming.so. See BUILD.bazel:_raylet
64
+ python_shared_lib_suffix = ".so" if sys.platform != "win32" else ".pyd"
65
+ so_path = os.path.join(
66
+ os.path.dirname(__file__), "_raylet" + python_shared_lib_suffix
67
+ )
68
+ if os.path.exists(so_path):
69
+ import ctypes
70
+ from ctypes import CDLL
71
+
72
+ CDLL(so_path, ctypes.RTLD_GLOBAL)
73
+
74
+
75
+ _configure_system()
76
+ # Delete configuration function.
77
+ del _configure_system
78
+
79
+ from ray import _version # noqa: E402
80
+
81
+ __commit__ = _version.commit
82
+ __version__ = _version.version
83
+
84
+ import ray._raylet # noqa: E402
85
+
86
+ from ray._raylet import ( # noqa: E402,F401
87
+ ActorClassID,
88
+ ActorID,
89
+ NodeID,
90
+ Config as _Config,
91
+ JobID,
92
+ WorkerID,
93
+ FunctionID,
94
+ ObjectID,
95
+ ObjectRef,
96
+ ObjectRefGenerator,
97
+ DynamicObjectRefGenerator,
98
+ TaskID,
99
+ UniqueID,
100
+ Language,
101
+ PlacementGroupID,
102
+ ClusterID,
103
+ )
104
+
105
+ _config = _Config()
106
+
107
+ from ray._private.state import ( # noqa: E402,F401
108
+ nodes,
109
+ timeline,
110
+ cluster_resources,
111
+ available_resources,
112
+ )
113
+ from ray._private.worker import ( # noqa: E402,F401
114
+ LOCAL_MODE,
115
+ SCRIPT_MODE,
116
+ WORKER_MODE,
117
+ RESTORE_WORKER_MODE,
118
+ SPILL_WORKER_MODE,
119
+ cancel,
120
+ get,
121
+ get_actor,
122
+ get_gpu_ids,
123
+ init,
124
+ is_initialized,
125
+ put,
126
+ kill,
127
+ remote,
128
+ shutdown,
129
+ wait,
130
+ )
131
+
132
+ from ray._private.ray_logging.logging_config import LoggingConfig # noqa: E402
133
+
134
+ # We import ray.actor because some code is run in actor.py which initializes
135
+ # some functions in the worker.
136
+ import ray.actor # noqa: E402,F401
137
+ from ray.actor import method # noqa: E402,F401
138
+
139
+ # TODO(qwang): We should remove this exporting in Ray2.0.
140
+ from ray.cross_language import java_function, java_actor_class # noqa: E402,F401
141
+ from ray.runtime_context import get_runtime_context # noqa: E402,F401
142
+ from ray import internal # noqa: E402,F401
143
+ from ray import util # noqa: E402,F401
144
+ from ray import _private # noqa: E402,F401
145
+
146
+ # We import ClientBuilder so that modules can inherit from `ray.ClientBuilder`.
147
+ from ray.client_builder import client, ClientBuilder # noqa: E402,F401
148
+
149
+
150
+ class _DeprecationWrapper:
151
+ def __init__(self, name, real_worker):
152
+ self._name = name
153
+ self._real_worker = real_worker
154
+ self._warned = set()
155
+
156
+ def __getattr__(self, attr):
157
+ value = getattr(self._real_worker, attr)
158
+ if attr not in self._warned:
159
+ self._warned.add(attr)
160
+ logger.warning(
161
+ f"DeprecationWarning: `ray.{self._name}.{attr}` is a private "
162
+ "attribute and access will be removed in a future Ray version."
163
+ )
164
+ return value
165
+
166
+
167
+ # TODO(ekl) remove this entirely after 3rd party libraries are all migrated.
168
+ worker = _DeprecationWrapper("worker", ray._private.worker)
169
+ ray_constants = _DeprecationWrapper("ray_constants", ray._private.ray_constants)
170
+ serialization = _DeprecationWrapper("serialization", ray._private.serialization)
171
+ state = _DeprecationWrapper("state", ray._private.state)
172
+
173
+
174
+ # Pulic Ray APIs
175
+ __all__ = [
176
+ "__version__",
177
+ "_config",
178
+ "get_runtime_context",
179
+ "autoscaler",
180
+ "available_resources",
181
+ "cancel",
182
+ "client",
183
+ "ClientBuilder",
184
+ "cluster_resources",
185
+ "get",
186
+ "get_actor",
187
+ "get_gpu_ids",
188
+ "init",
189
+ "is_initialized",
190
+ "java_actor_class",
191
+ "java_function",
192
+ "cpp_function",
193
+ "kill",
194
+ "Language",
195
+ "method",
196
+ "nodes",
197
+ "put",
198
+ "remote",
199
+ "shutdown",
200
+ "show_in_dashboard",
201
+ "timeline",
202
+ "wait",
203
+ "LOCAL_MODE",
204
+ "SCRIPT_MODE",
205
+ "WORKER_MODE",
206
+ "LoggingConfig",
207
+ ]
208
+
209
+ # Public APIs that should automatically trigger ray.init().
210
+ AUTO_INIT_APIS = {
211
+ "cancel",
212
+ "get",
213
+ "get_actor",
214
+ "get_gpu_ids",
215
+ "kill",
216
+ "put",
217
+ "wait",
218
+ "get_runtime_context",
219
+ }
220
+
221
+ # Public APIs that should not automatically trigger ray.init().
222
+ NON_AUTO_INIT_APIS = {
223
+ "ClientBuilder",
224
+ "LOCAL_MODE",
225
+ "Language",
226
+ "SCRIPT_MODE",
227
+ "WORKER_MODE",
228
+ "__version__",
229
+ "_config",
230
+ "autoscaler",
231
+ "available_resources",
232
+ "client",
233
+ "cluster_resources",
234
+ "cpp_function",
235
+ "init",
236
+ "is_initialized",
237
+ "java_actor_class",
238
+ "java_function",
239
+ "method",
240
+ "nodes",
241
+ "remote",
242
+ "show_in_dashboard",
243
+ "shutdown",
244
+ "timeline",
245
+ "LoggingConfig",
246
+ }
247
+
248
+ assert set(__all__) == AUTO_INIT_APIS | NON_AUTO_INIT_APIS
249
+ from ray._private.auto_init_hook import wrap_auto_init_for_all_apis # noqa: E402
250
+
251
+ wrap_auto_init_for_all_apis(AUTO_INIT_APIS)
252
+ del wrap_auto_init_for_all_apis
253
+
254
+ # Subpackages
255
+ __all__ += [
256
+ "actor",
257
+ "autoscaler",
258
+ "data",
259
+ "internal",
260
+ "util",
261
+ "widgets",
262
+ "workflow",
263
+ ]
264
+
265
+ # ID types
266
+ __all__ += [
267
+ "ActorClassID",
268
+ "ActorID",
269
+ "NodeID",
270
+ "JobID",
271
+ "WorkerID",
272
+ "FunctionID",
273
+ "ObjectID",
274
+ "ObjectRef",
275
+ "ObjectRefGenerator",
276
+ "DynamicObjectRefGenerator",
277
+ "TaskID",
278
+ "UniqueID",
279
+ "PlacementGroupID",
280
+ ]
281
+
282
+
283
+ # Delay importing of expensive, isolated subpackages.
284
+ def __getattr__(name: str):
285
+ import importlib
286
+
287
+ if name in ["data", "workflow", "autoscaler"]:
288
+ return importlib.import_module("." + name, __name__)
289
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
290
+
291
+
292
+ del os
293
+ del logging
294
+ del sys
.venv/lib/python3.11/site-packages/ray/_raylet.pxd ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cython: profile=False
2
+ # distutils: language = c++
3
+ # cython: embedsignature = True
4
+ # cython: language_level = 3
5
+
6
+ from cpython.pystate cimport PyThreadState_Get
7
+
8
+ from libc.stdint cimport (
9
+ int64_t,
10
+ )
11
+ from libcpp cimport bool as c_bool
12
+ from libcpp.string cimport string as c_string
13
+ from libcpp.vector cimport vector as c_vector
14
+ from libcpp.unordered_map cimport unordered_map
15
+ from libcpp.memory cimport (
16
+ shared_ptr,
17
+ unique_ptr
18
+ )
19
+ from libcpp.pair cimport pair as c_pair
20
+ from libcpp.utility cimport pair
21
+ from ray.includes.optional cimport (
22
+ optional,
23
+ nullopt,
24
+ make_optional,
25
+ )
26
+
27
+ from ray.includes.common cimport (
28
+ CBuffer,
29
+ CRayObject,
30
+ CAddress,
31
+ CConcurrencyGroup,
32
+ CSchedulingStrategy,
33
+ CLabelMatchExpressions,
34
+ )
35
+ from ray.includes.libcoreworker cimport (
36
+ ActorHandleSharedPtr,
37
+ CActorHandle,
38
+ CFiberEvent,
39
+ )
40
+
41
+ from ray.includes.unique_ids cimport (
42
+ CObjectID,
43
+ CActorID,
44
+ CTaskID,
45
+ )
46
+ from ray.includes.function_descriptor cimport (
47
+ CFunctionDescriptor,
48
+ )
49
+
50
+ cdef extern from *:
51
+ """
52
+ #if __OPTIMIZE__ && __OPTIMIZE__ == 1
53
+ #undef __OPTIMIZE__
54
+ int __OPTIMIZE__ = 1;
55
+ #define __OPTIMIZE__ 1
56
+ #elif defined(BAZEL_OPT)
57
+ // For compilers that don't define __OPTIMIZE__
58
+ int __OPTIMIZE__ = 1;
59
+ #else
60
+ int __OPTIMIZE__ = 0;
61
+ #endif
62
+ """
63
+ int __OPTIMIZE__
64
+
65
+ cdef extern from "Python.h":
66
+ # Note(simon): This is used to configure asyncio actor stack size.
67
+ # Cython made PyThreadState an opaque types. Saying that if the user wants
68
+ # specific attributes, they can be declared manually.
69
+
70
+ # You can find the cpython definition in Include/cpython/pystate.h#L59
71
+ ctypedef struct CPyThreadState "PyThreadState":
72
+ int recursion_limit
73
+ int recursion_remaining
74
+
75
+ # From Include/ceveal.h#67
76
+ int Py_GetRecursionLimit()
77
+ void Py_SetRecursionLimit(int)
78
+
79
+ cdef class Buffer:
80
+ cdef:
81
+ shared_ptr[CBuffer] buffer
82
+ Py_ssize_t shape
83
+ Py_ssize_t strides
84
+
85
+ @staticmethod
86
+ cdef make(const shared_ptr[CBuffer]& buffer)
87
+
88
+ cdef class BaseID:
89
+ # To avoid the error of "Python int too large to convert to C ssize_t",
90
+ # here `cdef size_t` is required.
91
+ cdef size_t hash(self)
92
+
93
+ cdef class ObjectRef(BaseID):
94
+ cdef:
95
+ CObjectID data
96
+ c_string owner_addr
97
+ # Flag indicating whether or not this object ref was added to the set
98
+ # of active IDs in the core worker so we know whether we should clean
99
+ # it up.
100
+ c_bool in_core_worker
101
+ c_string call_site_data
102
+
103
+ cdef CObjectID native(self)
104
+
105
+ cdef class ActorID(BaseID):
106
+ cdef CActorID data
107
+
108
+ cdef CActorID native(self)
109
+
110
+ cdef size_t hash(self)
111
+
112
+
113
+ cdef class CoreWorker:
114
+ cdef:
115
+ c_bool is_driver
116
+ object async_thread
117
+ object async_event_loop
118
+ object plasma_event_handler
119
+ object job_config
120
+ object current_runtime_env
121
+ c_bool is_local_mode
122
+
123
+ object cgname_to_eventloop_dict
124
+ object eventloop_for_default_cg
125
+ object thread_for_default_cg
126
+ object fd_to_cgname_dict
127
+ object _task_id_to_future_lock
128
+ dict _task_id_to_future
129
+ object event_loop_executor
130
+
131
+ cdef _create_put_buffer(self, shared_ptr[CBuffer] &metadata,
132
+ size_t data_size, ObjectRef object_ref,
133
+ c_vector[CObjectID] contained_ids,
134
+ CObjectID *c_object_id, shared_ptr[CBuffer] *data,
135
+ c_bool created_by_worker,
136
+ owner_address=*,
137
+ c_bool inline_small_object=*,
138
+ c_bool is_experimental_channel=*)
139
+ cdef unique_ptr[CAddress] _convert_python_address(self, address=*)
140
+ cdef store_task_output(
141
+ self, serialized_object,
142
+ const CObjectID &return_id,
143
+ const CObjectID &generator_id,
144
+ size_t data_size, shared_ptr[CBuffer] &metadata, const c_vector[CObjectID]
145
+ &contained_id, const CAddress &caller_address,
146
+ int64_t *task_output_inlined_bytes,
147
+ shared_ptr[CRayObject] *return_ptr)
148
+ cdef store_task_outputs(
149
+ self,
150
+ worker, outputs,
151
+ const CAddress &caller_address,
152
+ c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *returns,
153
+ CObjectID ref_generator_id=*)
154
+ cdef make_actor_handle(self, ActorHandleSharedPtr c_actor_handle,
155
+ c_bool weak_ref)
156
+ cdef c_function_descriptors_to_python(
157
+ self, const c_vector[CFunctionDescriptor] &c_function_descriptors)
158
+ cdef initialize_eventloops_for_actor_concurrency_group(
159
+ self, const c_vector[CConcurrencyGroup] &c_defined_concurrency_groups)
160
+ cdef python_scheduling_strategy_to_c(
161
+ self, python_scheduling_strategy,
162
+ CSchedulingStrategy *c_scheduling_strategy)
163
+ cdef python_label_match_expressions_to_c(
164
+ self, python_expressions,
165
+ CLabelMatchExpressions *c_expressions)
166
+ cdef CObjectID allocate_dynamic_return_id_for_generator(
167
+ self,
168
+ const CAddress &owner_address,
169
+ const CTaskID &task_id,
170
+ return_size,
171
+ generator_index,
172
+ is_async_actor)
173
+
174
+ cdef class FunctionDescriptor:
175
+ cdef:
176
+ CFunctionDescriptor descriptor
.venv/lib/python3.11/site-packages/ray/_raylet.pyi ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Awaitable, TypeVar
2
+
3
+ R = TypeVar("R")
4
+
5
+
6
+ class ObjectRef(Awaitable[R]): # type: ignore
7
+ pass
8
+
9
+
10
+ class ObjectID(Awaitable[R]): # type: ignore
11
+ pass
.venv/lib/python3.11/site-packages/ray/_version.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Replaced with the current commit when building the wheels.
2
+ commit = "637116a090c052d061af5ba3bef8a467c8c3fc25"
3
+ version = "2.42.0"
4
+
5
+ if __name__ == "__main__":
6
+ print("%s %s" % (version, commit))
.venv/lib/python3.11/site-packages/ray/actor.py ADDED
@@ -0,0 +1,1790 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import logging
3
+ import weakref
4
+ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
5
+
6
+ import ray._private.ray_constants as ray_constants
7
+ import ray._private.signature as signature
8
+ import ray._private.worker
9
+ import ray._raylet
10
+ from ray import ActorClassID, Language, cross_language
11
+ from ray._private import ray_option_utils
12
+ from ray._private.async_compat import has_async_methods
13
+ from ray._private.auto_init_hook import wrap_auto_init
14
+ from ray._private.client_mode_hook import (
15
+ client_mode_convert_actor,
16
+ client_mode_hook,
17
+ client_mode_should_convert,
18
+ )
19
+ from ray._private.inspect_util import (
20
+ is_class_method,
21
+ is_function_or_method,
22
+ is_static_method,
23
+ )
24
+ from ray._private.ray_option_utils import _warn_if_using_deprecated_placement_group
25
+ from ray._private.utils import get_runtime_env_info, parse_runtime_env
26
+ from ray._raylet import (
27
+ STREAMING_GENERATOR_RETURN,
28
+ ObjectRefGenerator,
29
+ PythonFunctionDescriptor,
30
+ raise_sys_exit_with_custom_error_message,
31
+ )
32
+ from ray.exceptions import AsyncioActorExit
33
+ from ray.util.annotations import DeveloperAPI, PublicAPI
34
+ from ray.util.placement_group import _configure_placement_group_based_on_context
35
+ from ray.util.scheduling_strategies import (
36
+ PlacementGroupSchedulingStrategy,
37
+ SchedulingStrategyT,
38
+ )
39
+ from ray.util.tracing.tracing_helper import (
40
+ _inject_tracing_into_class,
41
+ _tracing_actor_creation,
42
+ _tracing_actor_method_invocation,
43
+ )
44
+
45
+ logger = logging.getLogger(__name__)
46
+
47
+ # Hook to call with (actor, resources, strategy) on each local actor creation.
48
+ _actor_launch_hook = None
49
+
50
+
51
+ @PublicAPI
52
+ @client_mode_hook
53
+ def method(*args, **kwargs):
54
+ """Annotate an actor method.
55
+
56
+ .. code-block:: python
57
+
58
+ @ray.remote
59
+ class Foo:
60
+ @ray.method(num_returns=2)
61
+ def bar(self):
62
+ return 1, 2
63
+
64
+ f = Foo.remote()
65
+
66
+ _, _ = f.bar.remote()
67
+
68
+ Args:
69
+ num_returns: The number of object refs that should be returned by
70
+ invocations of this actor method.
71
+ """
72
+ valid_kwargs = [
73
+ "num_returns",
74
+ "concurrency_group",
75
+ "max_task_retries",
76
+ "retry_exceptions",
77
+ "_generator_backpressure_num_objects",
78
+ "enable_task_events",
79
+ ]
80
+ error_string = (
81
+ "The @ray.method decorator must be applied using at least one of "
82
+ f"the arguments in the list {valid_kwargs}, for example "
83
+ "'@ray.method(num_returns=2)'."
84
+ )
85
+ assert len(args) == 0 and len(kwargs) > 0, error_string
86
+ for key in kwargs:
87
+ key_error_string = (
88
+ f"Unexpected keyword argument to @ray.method: '{key}'. The "
89
+ f"supported keyword arguments are {valid_kwargs}"
90
+ )
91
+ assert key in valid_kwargs, key_error_string
92
+
93
+ def annotate_method(method):
94
+ if "num_returns" in kwargs:
95
+ method.__ray_num_returns__ = kwargs["num_returns"]
96
+ if "max_task_retries" in kwargs:
97
+ method.__ray_max_task_retries__ = kwargs["max_task_retries"]
98
+ if "retry_exceptions" in kwargs:
99
+ method.__ray_retry_exceptions__ = kwargs["retry_exceptions"]
100
+ if "concurrency_group" in kwargs:
101
+ method.__ray_concurrency_group__ = kwargs["concurrency_group"]
102
+ if "_generator_backpressure_num_objects" in kwargs:
103
+ method.__ray_generator_backpressure_num_objects__ = kwargs[
104
+ "_generator_backpressure_num_objects"
105
+ ]
106
+ if "enable_task_events" in kwargs and kwargs["enable_task_events"] is not None:
107
+ method.__ray_enable_task_events__ = kwargs["enable_task_events"]
108
+ return method
109
+
110
+ return annotate_method
111
+
112
+
113
+ # Create objects to wrap method invocations. This is done so that we can
114
+ # invoke methods with actor.method.remote() instead of actor.method().
115
+ @PublicAPI
116
+ class ActorMethod:
117
+ """A class used to invoke an actor method.
118
+
119
+ Note: This class only keeps a weak ref to the actor, unless it has been
120
+ passed to a remote function. This avoids delays in GC of the actor.
121
+
122
+ Attributes:
123
+ _actor_ref: A weakref handle to the actor.
124
+ _method_name: The name of the actor method.
125
+ _num_returns: The default number of return values that the method
126
+ invocation should return. If None is given, it uses
127
+ DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS for a normal actor task
128
+ and "streaming" for a generator task (when `is_generator` is True).
129
+ _max_task_retries: Number of retries on method failure.
130
+ _retry_exceptions: Boolean of whether you want to retry all user-raised
131
+ exceptions, or a list of allowlist exceptions to retry.
132
+ _is_generator: True if a given method is a Python generator.
133
+ _generator_backpressure_num_objects: Generator-only config.
134
+ If a number of unconsumed objects reach this threshold,
135
+ a actor task stop pausing.
136
+ enable_task_events: True if task events is enabled, i.e., task events from
137
+ the actor should be reported. Defaults to True.
138
+ _signature: The signature of the actor method. It is None only when cross
139
+ language feature is used.
140
+ _decorator: An optional decorator that should be applied to the actor
141
+ method invocation (as opposed to the actor method execution) before
142
+ invoking the method. The decorator must return a function that
143
+ takes in two arguments ("args" and "kwargs"). In most cases, it
144
+ should call the function that was passed into the decorator and
145
+ return the resulting ObjectRefs. For an example, see
146
+ "test_decorated_method" in "python/ray/tests/test_actor.py".
147
+ """
148
+
149
+ def __init__(
150
+ self,
151
+ actor,
152
+ method_name,
153
+ num_returns: Optional[Union[int, Literal["streaming"]]],
154
+ max_task_retries: int,
155
+ retry_exceptions: Union[bool, list, tuple],
156
+ is_generator: bool,
157
+ generator_backpressure_num_objects: int,
158
+ enable_task_events: bool,
159
+ decorator=None,
160
+ signature: Optional[List[inspect.Parameter]] = None,
161
+ hardref=False,
162
+ ):
163
+ self._actor_ref = weakref.ref(actor)
164
+ self._method_name = method_name
165
+ self._num_returns = num_returns
166
+
167
+ # Default case.
168
+ if self._num_returns is None:
169
+ if is_generator:
170
+ self._num_returns = "streaming"
171
+ else:
172
+ self._num_returns = ray_constants.DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS
173
+
174
+ self._max_task_retries = max_task_retries
175
+ self._retry_exceptions = retry_exceptions
176
+ self._is_generator = is_generator
177
+ self._generator_backpressure_num_objects = generator_backpressure_num_objects
178
+ self._enable_task_events = enable_task_events
179
+ self._signature = signature
180
+ # This is a decorator that is used to wrap the function invocation (as
181
+ # opposed to the function execution). The decorator must return a
182
+ # function that takes in two arguments ("args" and "kwargs"). In most
183
+ # cases, it should call the function that was passed into the decorator
184
+ # and return the resulting ObjectRefs.
185
+ self._decorator = decorator
186
+
187
+ # Acquire a hard ref to the actor, this is useful mainly when passing
188
+ # actor method handles to remote functions.
189
+ if hardref:
190
+ self._actor_hard_ref = actor
191
+ else:
192
+ self._actor_hard_ref = None
193
+
194
+ def __call__(self, *args, **kwargs):
195
+ raise TypeError(
196
+ "Actor methods cannot be called directly. Instead "
197
+ f"of running 'object.{self._method_name}()', try "
198
+ f"'object.{self._method_name}.remote()'."
199
+ )
200
+
201
+ @DeveloperAPI
202
+ def bind(self, *args, **kwargs):
203
+ return self._bind(args, kwargs)
204
+
205
+ def remote(self, *args, **kwargs):
206
+ return self._remote(args, kwargs)
207
+
208
+ def options(self, **options):
209
+ """Convenience method for executing an actor method call with options.
210
+
211
+ Same arguments as func._remote(), but returns a wrapped function
212
+ that a non-underscore .remote() can be called on.
213
+
214
+ Examples:
215
+ # The following two calls are equivalent.
216
+ >>> actor.my_method._remote(args=[x, y], name="foo", num_returns=2)
217
+ >>> actor.my_method.options(name="foo", num_returns=2).remote(x, y)
218
+ """
219
+
220
+ func_cls = self
221
+
222
+ class FuncWrapper:
223
+ def remote(self, *args, **kwargs):
224
+ return func_cls._remote(args=args, kwargs=kwargs, **options)
225
+
226
+ @DeveloperAPI
227
+ def bind(self, *args, **kwargs):
228
+ return func_cls._bind(args=args, kwargs=kwargs, **options)
229
+
230
+ return FuncWrapper()
231
+
232
+ @wrap_auto_init
233
+ @_tracing_actor_method_invocation
234
+ def _bind(
235
+ self,
236
+ args=None,
237
+ kwargs=None,
238
+ name="",
239
+ num_returns=None,
240
+ concurrency_group=None,
241
+ _generator_backpressure_num_objects=None,
242
+ ) -> Union["ray.dag.ClassMethodNode", Tuple["ray.dag.ClassMethodNode", ...]]:
243
+ from ray.dag.class_node import (
244
+ BIND_INDEX_KEY,
245
+ IS_CLASS_METHOD_OUTPUT_KEY,
246
+ PARENT_CLASS_NODE_KEY,
247
+ PREV_CLASS_METHOD_CALL_KEY,
248
+ ClassMethodNode,
249
+ )
250
+
251
+ # TODO(sang): unify option passing
252
+ options = {
253
+ "name": name,
254
+ "num_returns": num_returns,
255
+ "concurrency_group": concurrency_group,
256
+ "_generator_backpressure_num_objects": _generator_backpressure_num_objects,
257
+ }
258
+
259
+ actor = self._actor_ref()
260
+ if actor is None:
261
+ # Ref is GC'ed. It happens when the actor handle is GC'ed
262
+ # when bind is called.
263
+ raise RuntimeError("Lost reference to actor")
264
+
265
+ other_args_to_resolve = {
266
+ PARENT_CLASS_NODE_KEY: actor,
267
+ PREV_CLASS_METHOD_CALL_KEY: None,
268
+ BIND_INDEX_KEY: actor._ray_dag_bind_index,
269
+ }
270
+ actor._ray_dag_bind_index += 1
271
+
272
+ assert (
273
+ self._signature is not None
274
+ ), "self._signature should be set for .bind API."
275
+ try:
276
+ signature.validate_args(self._signature, args, kwargs)
277
+ except TypeError as e:
278
+ signature_copy = self._signature.copy()
279
+ if len(signature_copy) > 0 and signature_copy[-1].name == "_ray_trace_ctx":
280
+ # Remove the trace context arg for readability.
281
+ signature_copy.pop(-1)
282
+ signature_copy = inspect.Signature(parameters=signature_copy)
283
+ raise TypeError(
284
+ f"{str(e)}. The function `{self._method_name}` has a signature "
285
+ f"`{signature_copy}`, but the given arguments to `bind` doesn't "
286
+ f"match. args: {args}. kwargs: {kwargs}."
287
+ ) from None
288
+
289
+ node = ClassMethodNode(
290
+ self._method_name,
291
+ args,
292
+ kwargs,
293
+ options,
294
+ other_args_to_resolve=other_args_to_resolve,
295
+ )
296
+
297
+ if node.num_returns > 1:
298
+ output_nodes: List[ClassMethodNode] = []
299
+ for i in range(node.num_returns):
300
+ output_node = ClassMethodNode(
301
+ f"return_idx_{i}",
302
+ (node, i),
303
+ dict(),
304
+ dict(),
305
+ {IS_CLASS_METHOD_OUTPUT_KEY: True, PARENT_CLASS_NODE_KEY: actor},
306
+ )
307
+ output_nodes.append(output_node)
308
+ return tuple(output_nodes)
309
+ else:
310
+ return node
311
+
312
+ @wrap_auto_init
313
+ @_tracing_actor_method_invocation
314
+ def _remote(
315
+ self,
316
+ args=None,
317
+ kwargs=None,
318
+ name="",
319
+ num_returns=None,
320
+ max_task_retries=None,
321
+ retry_exceptions=None,
322
+ concurrency_group=None,
323
+ _generator_backpressure_num_objects=None,
324
+ enable_task_events=None,
325
+ ):
326
+ if num_returns is None:
327
+ num_returns = self._num_returns
328
+ if max_task_retries is None:
329
+ max_task_retries = self._max_task_retries
330
+ if max_task_retries is None:
331
+ max_task_retries = 0
332
+ if retry_exceptions is None:
333
+ retry_exceptions = self._retry_exceptions
334
+ if enable_task_events is None:
335
+ enable_task_events = self._enable_task_events
336
+ if _generator_backpressure_num_objects is None:
337
+ _generator_backpressure_num_objects = (
338
+ self._generator_backpressure_num_objects
339
+ )
340
+
341
+ def invocation(args, kwargs):
342
+ actor = self._actor_hard_ref or self._actor_ref()
343
+
344
+ if actor is None:
345
+ raise RuntimeError("Lost reference to actor")
346
+
347
+ return actor._actor_method_call(
348
+ self._method_name,
349
+ args=args,
350
+ kwargs=kwargs,
351
+ name=name,
352
+ num_returns=num_returns,
353
+ max_task_retries=max_task_retries,
354
+ retry_exceptions=retry_exceptions,
355
+ concurrency_group_name=concurrency_group,
356
+ generator_backpressure_num_objects=(
357
+ _generator_backpressure_num_objects
358
+ ),
359
+ enable_task_events=enable_task_events,
360
+ )
361
+
362
+ # Apply the decorator if there is one.
363
+ if self._decorator is not None:
364
+ invocation = self._decorator(invocation)
365
+
366
+ return invocation(args, kwargs)
367
+
368
+ def __getstate__(self):
369
+ return {
370
+ "actor": self._actor_ref(),
371
+ "method_name": self._method_name,
372
+ "num_returns": self._num_returns,
373
+ "max_task_retries": self._max_task_retries,
374
+ "retry_exceptions": self._retry_exceptions,
375
+ "decorator": self._decorator,
376
+ "is_generator": self._is_generator,
377
+ "generator_backpressure_num_objects": self._generator_backpressure_num_objects, # noqa
378
+ "enable_task_events": self._enable_task_events,
379
+ }
380
+
381
+ def __setstate__(self, state):
382
+ self.__init__(
383
+ state["actor"],
384
+ state["method_name"],
385
+ state["num_returns"],
386
+ state["max_task_retries"],
387
+ state["retry_exceptions"],
388
+ state["is_generator"],
389
+ state["generator_backpressure_num_objects"],
390
+ state["enable_task_events"],
391
+ state["decorator"],
392
+ hardref=True,
393
+ )
394
+
395
+
396
+ class _ActorClassMethodMetadata(object):
397
+ """Metadata for all methods in an actor class. This data can be cached.
398
+
399
+ Attributes:
400
+ methods: The actor methods.
401
+ decorators: Optional decorators that should be applied to the
402
+ method invocation function before invoking the actor methods. These
403
+ can be set by attaching the attribute
404
+ "__ray_invocation_decorator__" to the actor method.
405
+ signatures: The signatures of the methods.
406
+ num_returns: The default number of return values for
407
+ each actor method.
408
+ max_task_retries: Number of retries on method failure.
409
+ retry_exceptions: Boolean of whether you want to retry all user-raised
410
+ exceptions, or a list of allowlist exceptions to retry, for each method.
411
+ enable_task_events: True if tracing is enabled, i.e., task events from
412
+ the actor should be reported. Defaults to True.
413
+ """
414
+
415
+ _cache = {} # This cache will be cleared in ray._private.worker.disconnect()
416
+
417
+ def __init__(self):
418
+ class_name = type(self).__name__
419
+ raise TypeError(
420
+ f"{class_name} can not be constructed directly, "
421
+ f"instead of running '{class_name}()', "
422
+ f"try '{class_name}.create()'"
423
+ )
424
+
425
+ @classmethod
426
+ def reset_cache(cls):
427
+ cls._cache.clear()
428
+
429
+ @classmethod
430
+ def create(cls, modified_class, actor_creation_function_descriptor):
431
+ # Try to create an instance from cache.
432
+ cached_meta = cls._cache.get(actor_creation_function_descriptor)
433
+ if cached_meta is not None:
434
+ return cached_meta
435
+
436
+ # Create an instance without __init__ called.
437
+ self = cls.__new__(cls)
438
+
439
+ actor_methods = inspect.getmembers(modified_class, is_function_or_method)
440
+ self.methods = dict(actor_methods)
441
+
442
+ # Extract the signatures of each of the methods. This will be used
443
+ # to catch some errors if the methods are called with inappropriate
444
+ # arguments.
445
+ self.decorators = {}
446
+ self.signatures = {}
447
+ self.num_returns = {}
448
+ self.max_task_retries = {}
449
+ self.retry_exceptions = {}
450
+ self.method_is_generator = {}
451
+ self.enable_task_events = {}
452
+ self.generator_backpressure_num_objects = {}
453
+ self.concurrency_group_for_methods = {}
454
+
455
+ for method_name, method in actor_methods:
456
+ # Whether or not this method requires binding of its first
457
+ # argument. For class and static methods, we do not want to bind
458
+ # the first argument, but we do for instance methods
459
+ method = inspect.unwrap(method)
460
+ is_bound = is_class_method(method) or is_static_method(
461
+ modified_class, method_name
462
+ )
463
+
464
+ # Print a warning message if the method signature is not
465
+ # supported. We don't raise an exception because if the actor
466
+ # inherits from a class that has a method whose signature we
467
+ # don't support, there may not be much the user can do about it.
468
+ self.signatures[method_name] = signature.extract_signature(
469
+ method, ignore_first=not is_bound
470
+ )
471
+ # Set the default number of return values for this method.
472
+ if hasattr(method, "__ray_num_returns__"):
473
+ self.num_returns[method_name] = method.__ray_num_returns__
474
+ else:
475
+ self.num_returns[method_name] = None
476
+
477
+ # Only contains entries from `@ray.method(max_task_retries=...)`
478
+ # Ray may not populate the others with max_task_retries here because you may
479
+ # have set in `actor.method.options(max_task_retries=...)`. So Ray always
480
+ # stores max_task_retries both from the method and from the actor, and
481
+ # favors the former.
482
+ if hasattr(method, "__ray_max_task_retries__"):
483
+ self.max_task_retries[method_name] = method.__ray_max_task_retries__
484
+
485
+ if hasattr(method, "__ray_retry_exceptions__"):
486
+ self.retry_exceptions[method_name] = method.__ray_retry_exceptions__
487
+
488
+ if hasattr(method, "__ray_invocation_decorator__"):
489
+ self.decorators[method_name] = method.__ray_invocation_decorator__
490
+
491
+ if hasattr(method, "__ray_concurrency_group__"):
492
+ self.concurrency_group_for_methods[
493
+ method_name
494
+ ] = method.__ray_concurrency_group__
495
+
496
+ if hasattr(method, "__ray_enable_task_events__"):
497
+ self.enable_task_events[method_name] = method.__ray_enable_task_events__
498
+
499
+ is_generator = inspect.isgeneratorfunction(
500
+ method
501
+ ) or inspect.isasyncgenfunction(method)
502
+ self.method_is_generator[method_name] = is_generator
503
+
504
+ if hasattr(method, "__ray_generator_backpressure_num_objects__"):
505
+ self.generator_backpressure_num_objects[
506
+ method_name
507
+ ] = method.__ray_generator_backpressure_num_objects__
508
+
509
+ # Update cache.
510
+ cls._cache[actor_creation_function_descriptor] = self
511
+ return self
512
+
513
+
514
+ class _ActorClassMetadata:
515
+ """Metadata for an actor class.
516
+
517
+ Attributes:
518
+ language: The actor language, e.g. Python, Java.
519
+ modified_class: The original class that was decorated (with some
520
+ additional methods added like __ray_terminate__).
521
+ actor_creation_function_descriptor: The function descriptor for
522
+ the actor creation task.
523
+ class_id: The ID of this actor class.
524
+ class_name: The name of this class.
525
+ num_cpus: The default number of CPUs required by the actor creation
526
+ task.
527
+ num_gpus: The default number of GPUs required by the actor creation
528
+ task.
529
+ memory: The heap memory quota for this actor.
530
+ resources: The default resources required by the actor creation task.
531
+ accelerator_type: The specified type of accelerator required for the
532
+ node on which this actor runs.
533
+ See :ref:`accelerator types <accelerator_types>`.
534
+ runtime_env: The runtime environment for this actor.
535
+ scheduling_strategy: Strategy about how to schedule this actor.
536
+ last_export_cluster_and_job: A pair of the last exported cluster
537
+ and job to help us to know whether this function was exported.
538
+ This is an imperfect mechanism used to determine if we need to
539
+ export the remote function again. It is imperfect in the sense that
540
+ the actor class definition could be exported multiple times by
541
+ different workers.
542
+ method_meta: The actor method metadata.
543
+ """
544
+
545
+ def __init__(
546
+ self,
547
+ language,
548
+ modified_class,
549
+ actor_creation_function_descriptor,
550
+ class_id,
551
+ max_restarts,
552
+ max_task_retries,
553
+ num_cpus,
554
+ num_gpus,
555
+ memory,
556
+ object_store_memory,
557
+ resources,
558
+ accelerator_type,
559
+ runtime_env,
560
+ concurrency_groups,
561
+ scheduling_strategy: SchedulingStrategyT,
562
+ ):
563
+ self.language = language
564
+ self.modified_class = modified_class
565
+ self.actor_creation_function_descriptor = actor_creation_function_descriptor
566
+ self.class_name = actor_creation_function_descriptor.class_name
567
+ self.is_cross_language = language != Language.PYTHON
568
+ self.class_id = class_id
569
+ self.max_restarts = max_restarts
570
+ self.max_task_retries = max_task_retries
571
+ self.num_cpus = num_cpus
572
+ self.num_gpus = num_gpus
573
+ self.memory = memory
574
+ self.object_store_memory = object_store_memory
575
+ self.resources = resources
576
+ self.accelerator_type = accelerator_type
577
+ self.runtime_env = runtime_env
578
+ self.concurrency_groups = concurrency_groups
579
+ self.scheduling_strategy = scheduling_strategy
580
+ self.last_export_cluster_and_job = None
581
+ self.method_meta = _ActorClassMethodMetadata.create(
582
+ modified_class, actor_creation_function_descriptor
583
+ )
584
+
585
+
586
+ @PublicAPI
587
+ class ActorClassInheritanceException(TypeError):
588
+ pass
589
+
590
+
591
+ def _process_option_dict(actor_options):
592
+ _filled_options = {}
593
+ arg_names = set(inspect.getfullargspec(_ActorClassMetadata.__init__)[0])
594
+ for k, v in ray_option_utils.actor_options.items():
595
+ if k in arg_names:
596
+ _filled_options[k] = actor_options.get(k, v.default_value)
597
+ _filled_options["runtime_env"] = parse_runtime_env(_filled_options["runtime_env"])
598
+ return _filled_options
599
+
600
+
601
@PublicAPI
class ActorClass:
    """An actor class.

    This is a decorated class. It can be used to create actors.

    Attributes:
        __ray_metadata__: Contains metadata for the actor.
    """

    def __init__(cls, name, bases, attr):
        """Prevents users from directly inheriting from an ActorClass.

        This will be called when a class is defined with an ActorClass object
        as one of its base classes. To intentionally construct an ActorClass,
        use the '_ray_from_modified_class' classmethod.

        Raises:
            ActorClassInheritanceException: When ActorClass is inherited.
            AssertionError: If ActorClassInheritanceException is not raised i.e.,
                            conditions for raising it are not met in any
                            iteration of the loop.
            TypeError: In all other cases.
        """
        # If any base is itself an ActorClass, the user wrote
        # "class X(SomeActor)" — reject that explicitly.
        for base in bases:
            if isinstance(base, ActorClass):
                raise ActorClassInheritanceException(
                    f"Attempted to define subclass '{name}' of actor "
                    f"class '{base.__ray_metadata__.class_name}'. "
                    "Inheriting from actor classes is "
                    "not currently supported. You can instead "
                    "inherit from a non-actor base class and make "
                    "the derived class an actor class (with "
                    "@ray.remote)."
                )

        # This shouldn't be reached because one of the base classes must be
        # an actor class if this was meant to be subclassed.
        assert False, (
            "ActorClass.__init__ should not be called. Please use "
            "the @ray.remote decorator instead."
        )
643
+
644
+ def __call__(self, *args, **kwargs):
645
+ """Prevents users from directly instantiating an ActorClass.
646
+
647
+ This will be called instead of __init__ when 'ActorClass()' is executed
648
+ because an is an object rather than a metaobject. To properly
649
+ instantiated a remote actor, use 'ActorClass.remote()'.
650
+
651
+ Raises:
652
+ Exception: Always.
653
+ """
654
+ raise TypeError(
655
+ "Actors cannot be instantiated directly. "
656
+ f"Instead of '{self.__ray_metadata__.class_name}()', "
657
+ f"use '{self.__ray_metadata__.class_name}.remote()'."
658
+ )
659
+
660
    @classmethod
    def _ray_from_modified_class(
        cls,
        modified_class,
        class_id,
        actor_options,
    ):
        """Construct an ActorClass from a @ray.remote-decorated Python class.

        Args:
            modified_class: The decorated user class.
            class_id: The ID to assign to this actor class.
            actor_options: The options passed to @ray.remote.

        Returns:
            An ActorClass instance wrapping the user class.
        """
        # Warn if the actor machinery will shadow attributes the user
        # already defined on the class.
        for attribute in [
            "remote",
            "_remote",
            "_ray_from_modified_class",
            "_ray_from_function_descriptor",
        ]:
            if hasattr(modified_class, attribute):
                logger.warning(
                    "Creating an actor from class "
                    f"{modified_class.__name__} overwrites "
                    f"attribute {attribute} of that class"
                )

        # Make sure the actor class we are constructing inherits from the
        # original class so it retains all class properties.
        class DerivedActorClass(cls, modified_class):
            def __init__(self, *args, **kwargs):
                try:
                    cls.__init__(self, *args, **kwargs)
                except Exception as e:
                    # Delegate call to modified_class.__init__ only
                    # if the exception raised by cls.__init__ is
                    # TypeError and not ActorClassInheritanceException(TypeError).
                    # In all other cases proceed with raise e.
                    if isinstance(e, TypeError) and not isinstance(
                        e, ActorClassInheritanceException
                    ):
                        modified_class.__init__(self, *args, **kwargs)
                    else:
                        raise e

        name = f"ActorClass({modified_class.__name__})"
        DerivedActorClass.__module__ = modified_class.__module__
        DerivedActorClass.__name__ = name
        DerivedActorClass.__qualname__ = name
        # Construct the base object. __new__ is used directly so that
        # ActorClass.__init__ (which always raises) is never invoked.
        self = DerivedActorClass.__new__(DerivedActorClass)
        # Actor creation function descriptor.
        actor_creation_function_descriptor = PythonFunctionDescriptor.from_class(
            modified_class.__ray_actor_class__
        )

        self.__ray_metadata__ = _ActorClassMetadata(
            Language.PYTHON,
            modified_class,
            actor_creation_function_descriptor,
            class_id,
            **_process_option_dict(actor_options),
        )
        self._default_options = actor_options
        # Keep the default options in sync with the parsed runtime_env
        # stored on the metadata.
        if "runtime_env" in self._default_options:
            self._default_options["runtime_env"] = self.__ray_metadata__.runtime_env

        return self
721
+
722
+ @classmethod
723
+ def _ray_from_function_descriptor(
724
+ cls,
725
+ language,
726
+ actor_creation_function_descriptor,
727
+ actor_options,
728
+ ):
729
+ self = ActorClass.__new__(ActorClass)
730
+ self.__ray_metadata__ = _ActorClassMetadata(
731
+ language,
732
+ None,
733
+ actor_creation_function_descriptor,
734
+ None,
735
+ **_process_option_dict(actor_options),
736
+ )
737
+ self._default_options = actor_options
738
+ if "runtime_env" in self._default_options:
739
+ self._default_options["runtime_env"] = self.__ray_metadata__.runtime_env
740
+ return self
741
+
742
+ def remote(self, *args, **kwargs):
743
+ """Create an actor.
744
+
745
+ Args:
746
+ args: These arguments are forwarded directly to the actor
747
+ constructor.
748
+ kwargs: These arguments are forwarded directly to the actor
749
+ constructor.
750
+
751
+ Returns:
752
+ A handle to the newly created actor.
753
+ """
754
+ return self._remote(args=args, kwargs=kwargs, **self._default_options)
755
+
756
    def options(self, **actor_options):
        """Configures and overrides the actor instantiation parameters.

        The arguments are the same as those that can be passed
        to :obj:`ray.remote`.

        Args:
            num_cpus: The quantity of CPU cores to reserve
                for this task or for the lifetime of the actor.
            num_gpus: The quantity of GPUs to reserve
                for this task or for the lifetime of the actor.
            resources (Dict[str, float]): The quantity of various custom resources
                to reserve for this task or for the lifetime of the actor.
                This is a dictionary mapping strings (resource names) to floats.
            accelerator_type: If specified, requires that the task or actor run
                on a node with the specified type of accelerator.
                See :ref:`accelerator types <accelerator_types>`.
            memory: The heap memory request in bytes for this task/actor,
                rounded down to the nearest integer.
            object_store_memory: The object store memory request for actors only.
            max_restarts: This specifies the maximum
                number of times that the actor should be restarted when it dies
                unexpectedly. The minimum valid value is 0 (default),
                which indicates that the actor doesn't need to be restarted.
                A value of -1 indicates that an actor should be restarted
                indefinitely.
            max_task_retries: How many times to
                retry an actor task if the task fails due to a runtime error,
                e.g., the actor has died. If set to -1, the system will
                retry the failed task until the task succeeds, or the actor
                has reached its max_restarts limit. If set to `n > 0`, the
                system will retry the failed task up to n times, after which the
                task will throw a `RayActorError` exception upon :obj:`ray.get`.
                Note that Python exceptions may trigger retries *only if*
                `retry_exceptions` is set for the method, in that case when
                `max_task_retries` runs out the task will rethrow the exception from
                the task. You can override this number with the method's
                `max_task_retries` option in `@ray.method` decorator or in `.option()`.
            max_pending_calls: Set the max number of pending calls
                allowed on the actor handle. When this value is exceeded,
                PendingCallsLimitExceeded will be raised for further tasks.
                Note that this limit is counted per handle. -1 means that the
                number of pending calls is unlimited.
            max_concurrency: The max number of concurrent calls to allow for
                this actor. This only works with direct actor calls. The max
                concurrency defaults to 1 for threaded execution, and 1000 for
                asyncio execution. Note that the execution order is not
                guaranteed when max_concurrency > 1.
            name: The globally unique name for the actor, which can be used
                to retrieve the actor via ray.get_actor(name) as long as the
                actor is still alive.
            namespace: Override the namespace to use for the actor. By default,
                actors are created in an anonymous namespace. The actor can
                be retrieved via ray.get_actor(name=name, namespace=namespace).
            lifetime: Either `None`, which defaults to the actor will fate
                share with its creator and will be deleted once its refcount
                drops to zero, or "detached", which means the actor will live
                as a global object independent of the creator.
            runtime_env (Dict[str, Any]): Specifies the runtime environment for
                this actor or task and its children. See
                :ref:`runtime-environments` for detailed documentation.
            scheduling_strategy: Strategy about how to
                schedule a remote function or actor. Possible values are
                None: ray will figure out the scheduling strategy to use, it
                will either be the PlacementGroupSchedulingStrategy using parent's
                placement group if parent has one and has
                placement_group_capture_child_tasks set to true,
                or "DEFAULT";
                "DEFAULT": default hybrid scheduling;
                "SPREAD": best effort spread scheduling;
                `PlacementGroupSchedulingStrategy`:
                placement group based scheduling;
                `NodeAffinitySchedulingStrategy`:
                node id based affinity scheduling.
            _metadata: Extended options for Ray libraries. For example,
                _metadata={"workflows.io/options": <workflow options>} for
                Ray workflows.
            enable_task_events: True if tracing is enabled, i.e., task events from
                the actor should be reported. Defaults to True.

        Examples:

        .. code-block:: python

            @ray.remote(num_cpus=2, resources={"CustomResource": 1})
            class Foo:
                def method(self):
                    return 1
            # Class Bar will require 1 cpu instead of 2.
            # It will also require no custom resources.
            Bar = Foo.options(num_cpus=1, resources=None)
        """

        # Captured by the wrapper class below.
        actor_cls = self

        # override original options
        default_options = self._default_options.copy()
        # "concurrency_groups" could not be used in ".options()",
        # we should remove it before merging options from '@ray.remote'.
        default_options.pop("concurrency_groups", None)
        updated_options = ray_option_utils.update_options(
            default_options, actor_options
        )
        ray_option_utils.validate_actor_options(updated_options, in_options=True)

        # only update runtime_env when ".options()" specifies new runtime_env
        if "runtime_env" in actor_options:
            updated_options["runtime_env"] = parse_runtime_env(
                updated_options["runtime_env"]
            )

        # A lightweight proxy exposing only .remote() / .bind(), closing
        # over the merged options above.
        class ActorOptionWrapper:
            def remote(self, *args, **kwargs):
                return actor_cls._remote(args=args, kwargs=kwargs, **updated_options)

            @DeveloperAPI
            def bind(self, *args, **kwargs):
                """
                For Ray DAG building that creates static graph from decorated
                class or functions.
                """
                from ray.dag.class_node import ClassNode

                return ClassNode(
                    actor_cls.__ray_metadata__.modified_class,
                    args,
                    kwargs,
                    updated_options,
                )

        return ActorOptionWrapper()
887
+
888
    @wrap_auto_init
    @_tracing_actor_creation
    def _remote(self, args=None, kwargs=None, **actor_options):
        """Create an actor.

        This method allows more flexibility than the remote method because
        resource requirements can be specified and override the defaults in the
        decorator.

        Args:
            args: The arguments to forward to the actor constructor.
            kwargs: The keyword arguments to forward to the actor constructor.
            num_cpus: The number of CPUs required by the actor creation task.
            num_gpus: The number of GPUs required by the actor creation task.
            memory: Restrict the heap memory usage of this actor.
            resources: The custom resources required by the actor creation
                task.
            max_concurrency: The max number of concurrent calls to allow for
                this actor. This only works with direct actor calls. The max
                concurrency defaults to 1 for threaded execution, and 1000 for
                asyncio execution. Note that the execution order is not
                guaranteed when max_concurrency > 1.
            name: The globally unique name for the actor, which can be used
                to retrieve the actor via ray.get_actor(name) as long as the
                actor is still alive.
            namespace: Override the namespace to use for the actor. By default,
                actors are created in an anonymous namespace. The actor can
                be retrieved via ray.get_actor(name=name, namespace=namespace).
            lifetime: Either `None`, which defaults to the actor will fate
                share with its creator and will be deleted once its refcount
                drops to zero, or "detached", which means the actor will live
                as a global object independent of the creator.
            placement_group: (This has been deprecated, please use
                `PlacementGroupSchedulingStrategy` scheduling_strategy)
                the placement group this actor belongs to,
                or None if it doesn't belong to any group. Setting to "default"
                autodetects the placement group based on the current setting of
                placement_group_capture_child_tasks.
            placement_group_bundle_index: (This has been deprecated, please use
                `PlacementGroupSchedulingStrategy` scheduling_strategy)
                the index of the bundle
                if the actor belongs to a placement group, which may be -1 to
                specify any available bundle.
            placement_group_capture_child_tasks: (This has been deprecated,
                please use `PlacementGroupSchedulingStrategy`
                scheduling_strategy)
                Whether or not children tasks
                of this actor should implicitly use the same placement group
                as its parent. It is False by default.
            runtime_env (Dict[str, Any]): Specifies the runtime environment for
                this actor or task and its children (see
                :ref:`runtime-environments` for details).
            max_pending_calls: Set the max number of pending calls
                allowed on the actor handle. When this value is exceeded,
                PendingCallsLimitExceeded will be raised for further tasks.
                Note that this limit is counted per handle. -1 means that the
                number of pending calls is unlimited.
            scheduling_strategy: Strategy about how to schedule this actor.
            enable_task_events: True if tracing is enabled, i.e., task events from
                the actor should be reported. Defaults to True.
            _labels: The key-value labels of the actor.

        Returns:
            A handle to the newly created actor.
        """
        # Validate the name/namespace options before doing any work.
        name = actor_options.get("name")
        namespace = actor_options.get("namespace")
        if name is not None:
            if not isinstance(name, str):
                raise TypeError(f"name must be None or a string, got: '{type(name)}'.")
            elif name == "":
                raise ValueError("Actor name cannot be an empty string.")
        if namespace is not None:
            ray._private.utils.validate_namespace(namespace)

        # Handle the get-or-create case.
        if actor_options.get("get_if_exists"):
            try:
                return ray.get_actor(name, namespace=namespace)
            except ValueError:
                # Attempt to create it (may race with other attempts).
                updated_options = actor_options.copy()
                updated_options["get_if_exists"] = False  # prevent infinite loop
                try:
                    return self._remote(args, kwargs, **updated_options)
                except ValueError:
                    # We lost the creation race, ignore.
                    pass
                return ray.get_actor(name, namespace=namespace)

        # We pop the "concurrency_groups" coming from "@ray.remote" here. We no longer
        # need it in "_remote()".
        actor_options.pop("concurrency_groups", None)

        if args is None:
            args = []
        if kwargs is None:
            kwargs = {}
        meta = self.__ray_metadata__
        is_asyncio = has_async_methods(meta.modified_class)

        # Asyncio actors default to a much higher max_concurrency than
        # threaded actors.
        if actor_options.get("max_concurrency") is None:
            actor_options["max_concurrency"] = (
                ray_constants.DEFAULT_MAX_CONCURRENCY_ASYNC
                if is_asyncio
                else ray_constants.DEFAULT_MAX_CONCURRENCY_THREADED
            )

        # In Ray client mode, delegate actor creation to the client.
        if client_mode_should_convert():
            return client_mode_convert_actor(self, args, kwargs, **actor_options)

        # fill actor required options
        for k, v in ray_option_utils.actor_options.items():
            actor_options[k] = actor_options.get(k, v.default_value)
        # "concurrency_groups" already takes effects and should not apply again.
        # Remove the default value here.
        actor_options.pop("concurrency_groups", None)

        # TODO(suquark): cleanup these fields
        max_concurrency = actor_options["max_concurrency"]
        lifetime = actor_options["lifetime"]
        runtime_env = actor_options["runtime_env"]
        placement_group = actor_options["placement_group"]
        placement_group_bundle_index = actor_options["placement_group_bundle_index"]
        placement_group_capture_child_tasks = actor_options[
            "placement_group_capture_child_tasks"
        ]
        scheduling_strategy = actor_options["scheduling_strategy"]
        max_restarts = actor_options["max_restarts"]
        max_task_retries = actor_options["max_task_retries"]
        max_pending_calls = actor_options["max_pending_calls"]

        # Override enable_task_events to default for actor if not specified (i.e. None)
        enable_task_events = actor_options.get("enable_task_events")

        if scheduling_strategy is None or not isinstance(
            scheduling_strategy, PlacementGroupSchedulingStrategy
        ):
            _warn_if_using_deprecated_placement_group(actor_options, 3)

        worker = ray._private.worker.global_worker
        worker.check_connected()

        # Check whether the name is already taken.
        # TODO(edoakes): this check has a race condition because two drivers
        # could pass the check and then create the same named actor. We should
        # instead check this when we create the actor, but that's currently an
        # async call.
        if name is not None:
            try:
                ray.get_actor(name, namespace=namespace)
            except ValueError:  # Name is not taken.
                pass
            else:
                raise ValueError(
                    f"The name {name} (namespace={namespace}) is already "
                    "taken. Please use "
                    "a different name or get the existing actor using "
                    f"ray.get_actor('{name}', namespace='{namespace}')"
                )

        # Map the user-facing "lifetime" string onto the tri-state
        # detached flag expected by the core worker.
        if lifetime is None:
            detached = None
        elif lifetime == "detached":
            detached = True
        elif lifetime == "non_detached":
            detached = False
        else:
            raise ValueError(
                "actor `lifetime` argument must be one of 'detached', "
                "'non_detached' and 'None'."
            )

        # LOCAL_MODE cannot handle cross_language
        if worker.mode == ray.LOCAL_MODE:
            assert (
                not meta.is_cross_language
            ), "Cross language ActorClass cannot be executed locally."

        # Export the actor.
        if not meta.is_cross_language and (
            meta.last_export_cluster_and_job != worker.current_cluster_and_job
        ):
            # If this actor class was not exported in this cluster and job,
            # we need to export this function again, because current GCS
            # doesn't have it.

            # After serialize / deserialize modified class, the __module__
            # of modified class will be ray.cloudpickle.cloudpickle.
            # So, here pass actor_creation_function_descriptor to make
            # sure export actor class correct.
            worker.function_actor_manager.export_actor_class(
                meta.modified_class,
                meta.actor_creation_function_descriptor,
                meta.method_meta.methods.keys(),
            )
            meta.last_export_cluster_and_job = worker.current_cluster_and_job

        resources = ray._private.utils.resources_from_ray_options(actor_options)
        # Set the actor's default resources if not already set. First three
        # conditions are to check that no resources were specified in the
        # decorator. Last three conditions are to check that no resources were
        # specified when _remote() was called.
        # TODO(suquark): In the original code, memory is not considered as resources,
        # when deciding the default CPUs. It is strange, but we keep the original
        # semantics in case that it breaks user applications & tests.
        if not set(resources.keys()).difference({"memory", "object_store_memory"}):
            # In the default case, actors acquire no resources for
            # their lifetime, and actor methods will require 1 CPU.
            resources.setdefault("CPU", ray_constants.DEFAULT_ACTOR_CREATION_CPU_SIMPLE)
            actor_method_cpu = ray_constants.DEFAULT_ACTOR_METHOD_CPU_SIMPLE
        else:
            # If any resources are specified (here or in decorator), then
            # all resources are acquired for the actor's lifetime and no
            # resources are associated with methods.
            resources.setdefault(
                "CPU", ray_constants.DEFAULT_ACTOR_CREATION_CPU_SPECIFIED
            )
            actor_method_cpu = ray_constants.DEFAULT_ACTOR_METHOD_CPU_SPECIFIED

        # If the actor methods require CPU resources, then set the required
        # placement resources. If actor_placement_resources is empty, then
        # the required placement resources will be the same as resources.
        actor_placement_resources = {}
        assert actor_method_cpu in [0, 1]
        if actor_method_cpu == 1:
            actor_placement_resources = resources.copy()
            actor_placement_resources["CPU"] += 1
        if meta.is_cross_language:
            creation_args = cross_language._format_args(worker, args, kwargs)
        else:
            function_signature = meta.method_meta.signatures["__init__"]
            creation_args = signature.flatten_args(function_signature, args, kwargs)

        if scheduling_strategy is None or isinstance(
            scheduling_strategy, PlacementGroupSchedulingStrategy
        ):
            # TODO(jjyao) Clean this up once the
            # placement_group option is removed.
            # We should also consider pushing this logic down to c++
            # so that it can be reused by all languages.
            if isinstance(scheduling_strategy, PlacementGroupSchedulingStrategy):
                placement_group = scheduling_strategy.placement_group
                placement_group_bundle_index = (
                    scheduling_strategy.placement_group_bundle_index
                )
                placement_group_capture_child_tasks = (
                    scheduling_strategy.placement_group_capture_child_tasks
                )

            if placement_group_capture_child_tasks is None:
                placement_group_capture_child_tasks = (
                    worker.should_capture_child_tasks_in_placement_group
                )
            placement_group = _configure_placement_group_based_on_context(
                placement_group_capture_child_tasks,
                placement_group_bundle_index,
                resources,
                actor_placement_resources,
                meta.class_name,
                placement_group=placement_group,
            )
            if not placement_group.is_empty:
                scheduling_strategy = PlacementGroupSchedulingStrategy(
                    placement_group,
                    placement_group_bundle_index,
                    placement_group_capture_child_tasks,
                )
            else:
                scheduling_strategy = "DEFAULT"

        serialized_runtime_env_info = None
        if runtime_env is not None:
            serialized_runtime_env_info = get_runtime_env_info(
                runtime_env,
                is_job_runtime_env=False,
                serialize=True,
            )

        # Build the concurrency-group descriptors to pass to the core worker.
        concurrency_groups_dict = {}
        if meta.concurrency_groups is None:
            meta.concurrency_groups = []
        for cg_name in meta.concurrency_groups:
            concurrency_groups_dict[cg_name] = {
                "name": cg_name,
                "max_concurrency": meta.concurrency_groups[cg_name],
                "function_descriptors": [],
            }

        # Update methods
        for method_name in meta.method_meta.concurrency_group_for_methods:
            cg_name = meta.method_meta.concurrency_group_for_methods[method_name]
            assert cg_name in concurrency_groups_dict

            module_name = meta.actor_creation_function_descriptor.module_name
            class_name = meta.actor_creation_function_descriptor.class_name
            concurrency_groups_dict[cg_name]["function_descriptors"].append(
                PythonFunctionDescriptor(module_name, method_name, class_name)
            )

        # Update the creation descriptor based on number of arguments
        if meta.is_cross_language:
            func_name = "<init>"
            if meta.language == Language.CPP:
                func_name = meta.actor_creation_function_descriptor.function_name
            meta.actor_creation_function_descriptor = (
                cross_language._get_function_descriptor_for_actor_method(
                    meta.language,
                    meta.actor_creation_function_descriptor,
                    func_name,
                    str(len(args) + len(kwargs)),
                )
            )

        actor_id = worker.core_worker.create_actor(
            meta.language,
            meta.actor_creation_function_descriptor,
            creation_args,
            max_restarts,
            max_task_retries,
            resources,
            actor_placement_resources,
            max_concurrency,
            detached,
            name if name is not None else "",
            namespace if namespace is not None else "",
            is_asyncio,
            # Store actor_method_cpu in actor handle's extension data.
            extension_data=str(actor_method_cpu),
            serialized_runtime_env_info=serialized_runtime_env_info or "{}",
            concurrency_groups_dict=concurrency_groups_dict or dict(),
            max_pending_calls=max_pending_calls,
            scheduling_strategy=scheduling_strategy,
            enable_task_events=enable_task_events,
            labels=actor_options.get("_labels"),
        )

        if _actor_launch_hook:
            _actor_launch_hook(
                meta.actor_creation_function_descriptor, resources, scheduling_strategy
            )

        actor_handle = ActorHandle(
            meta.language,
            actor_id,
            max_task_retries,
            enable_task_events,
            meta.method_meta.method_is_generator,
            meta.method_meta.decorators,
            meta.method_meta.signatures,
            meta.method_meta.num_returns,
            meta.method_meta.max_task_retries,
            meta.method_meta.retry_exceptions,
            meta.method_meta.generator_backpressure_num_objects,
            meta.method_meta.enable_task_events,
            actor_method_cpu,
            meta.actor_creation_function_descriptor,
            worker.current_cluster_and_job,
            original_handle=True,
        )

        return actor_handle
1250
+
1251
+ @DeveloperAPI
1252
+ def bind(self, *args, **kwargs):
1253
+ """
1254
+ For Ray DAG building that creates static graph from decorated
1255
+ class or functions.
1256
+ """
1257
+ from ray.dag.class_node import ClassNode
1258
+
1259
+ return ClassNode(
1260
+ self.__ray_metadata__.modified_class, args, kwargs, self._default_options
1261
+ )
1262
+
1263
+
1264
@PublicAPI
class ActorHandle:
    """A handle to an actor.

    The fields in this class are prefixed with _ray_ to hide them from the user
    and to avoid collision with actor method names.

    An ActorHandle can be created in three ways. First, by calling .remote() on
    an ActorClass. Second, by passing an actor handle into a task (forking the
    ActorHandle). Third, by directly serializing the ActorHandle (e.g., with
    cloudpickle).

    Attributes:
        _ray_actor_language: The actor language.
        _ray_actor_id: Actor ID.
        _ray_enable_task_events: The default value of whether task events is
            enabled, i.e., task events from the actor should be reported.
        _ray_method_is_generator: Map of method name -> if it is a generator
            method.
        _ray_method_decorators: Optional decorators for the function
            invocation. This can be used to change the behavior on the
            invocation side, whereas a regular decorator can be used to change
            the behavior on the execution side.
        _ray_method_signatures: The signatures of the actor methods.
        _ray_method_max_task_retries: Max number of retries on method failure.
        _ray_method_num_returns: The default number of return values for
            each method.
        _ray_method_retry_exceptions: The default value of boolean of whether you want
            to retry all user-raised exceptions, or a list of allowlist exceptions to
            retry.
        _ray_method_generator_backpressure_num_objects: Generator-only
            config. The max number of objects to generate before it
            starts pausing a generator.
        _ray_method_enable_task_events: The value of whether task
            tracing is enabled for the actor methods. This overrides the
            actor's default value (`_ray_enable_task_events`).
        _ray_actor_method_cpus: The number of CPUs required by actor methods.
        _ray_original_handle: True if this is the original actor handle for a
            given actor. If this is true, then the actor will be destroyed when
            this handle goes out of scope.
        _ray_weak_ref: True means that this handle does not count towards the
            distributed ref count for the actor, i.e. the actor may be GCed
            while this handle is still in scope. This is set to True if the
            handle was created by getting an actor by name or by getting the
            self handle. It is set to False if this is the original handle or
            if it was created by passing the original handle through task args
            and returns.
        _ray_is_cross_language: Whether this actor is cross language.
        _ray_actor_creation_function_descriptor: The function descriptor
            of the actor creation task.
    """

    def __init__(
        self,
        language,
        actor_id,
        max_task_retries: Optional[int],
        enable_task_events: bool,
        method_is_generator: Dict[str, bool],
        method_decorators,
        method_signatures,
        method_num_returns: Dict[str, Union[int, Literal["streaming"]]],
        method_max_task_retries: Dict[str, int],
        method_retry_exceptions: Dict[str, Union[bool, list, tuple]],
        method_generator_backpressure_num_objects: Dict[str, int],
        method_enable_task_events: Dict[str, bool],
        actor_method_cpus: int,
        actor_creation_function_descriptor,
        cluster_and_job,
        original_handle=False,
        weak_ref: bool = False,
    ):
        """Initialize the handle and bind one ActorMethod per actor method."""
        self._ray_actor_language = language
        self._ray_actor_id = actor_id
        self._ray_max_task_retries = max_task_retries
        self._ray_original_handle = original_handle
        self._ray_weak_ref = weak_ref
        self._ray_enable_task_events = enable_task_events

        self._ray_method_is_generator = method_is_generator
        self._ray_method_decorators = method_decorators
        self._ray_method_signatures = method_signatures
        self._ray_method_num_returns = method_num_returns
        self._ray_method_max_task_retries = method_max_task_retries
        self._ray_method_retry_exceptions = method_retry_exceptions
        self._ray_method_generator_backpressure_num_objects = (
            method_generator_backpressure_num_objects
        )
        self._ray_method_enable_task_events = method_enable_task_events
        self._ray_actor_method_cpus = actor_method_cpus
        self._ray_cluster_and_job = cluster_and_job
        self._ray_is_cross_language = language != Language.PYTHON
        self._ray_actor_creation_function_descriptor = (
            actor_creation_function_descriptor
        )
        # Lazily-populated map of method name -> PythonFunctionDescriptor.
        self._ray_function_descriptor = {}
        # This is incremented each time `bind()` is called on an actor handle
        # (in Ray DAGs), therefore capturing the bind order of the actor methods.
        # TODO: this does not work properly if the caller has two copies of the
        # same actor handle, and needs to be fixed.
        self._ray_dag_bind_index = 0

        # For Python actors, materialize an ActorMethod attribute per method
        # so that `handle.method.remote(...)` works.
        if not self._ray_is_cross_language:
            assert isinstance(
                actor_creation_function_descriptor, PythonFunctionDescriptor
            )
            module_name = actor_creation_function_descriptor.module_name
            class_name = actor_creation_function_descriptor.class_name
            for method_name in self._ray_method_signatures.keys():
                function_descriptor = PythonFunctionDescriptor(
                    module_name, method_name, class_name
                )
                self._ray_function_descriptor[method_name] = function_descriptor
                method = ActorMethod(
                    self,
                    method_name,
                    self._ray_method_num_returns[method_name],
                    self._ray_method_max_task_retries.get(
                        method_name, self._ray_max_task_retries
                    )
                    or 0,  # never None
                    self._ray_method_retry_exceptions.get(method_name),
                    self._ray_method_is_generator[method_name],
                    self._ray_method_generator_backpressure_num_objects.get(
                        method_name
                    ),  # noqa
                    self._ray_method_enable_task_events.get(
                        method_name,
                        self._ray_enable_task_events,  # Use actor's default value
                    ),
                    decorator=self._ray_method_decorators.get(method_name),
                    signature=self._ray_method_signatures[method_name],
                )
                setattr(self, method_name, method)
1398
+
1399
    def __del__(self):
        """Decrement the distributed reference count for this handle.

        Once every non-weak handle to the actor has gone out of scope
        cluster-wide, the actor is allowed to exit.
        """
        # Weak references don't count towards the distributed ref count, so no
        # need to decrement the ref count.
        if self._ray_weak_ref:
            return

        try:
            # Mark that this actor handle has gone out of scope. Once all actor
            # handles are out of scope, the actor will exit.
            if ray._private.worker:
                worker = ray._private.worker.global_worker
                if worker.connected and hasattr(worker, "core_worker"):
                    worker.core_worker.remove_actor_handle_reference(self._ray_actor_id)
        except AttributeError:
            # Suppress the attribute error which is caused by
            # Python destruction ordering issues.
            # It only happens when the interpreter exits.
            pass
    def _actor_method_call(
        self,
        method_name: str,
        args: List[Any] = None,
        kwargs: Dict[str, Any] = None,
        name: str = "",
        num_returns: Optional[Union[int, Literal["streaming"]]] = None,
        max_task_retries: int = None,
        retry_exceptions: Union[bool, list, tuple] = None,
        concurrency_group_name: Optional[str] = None,
        generator_backpressure_num_objects: Optional[int] = None,
        enable_task_events: Optional[bool] = None,
    ):
        """Method execution stub for an actor handle.

        This is the function that executes when
        `actor.method_name.remote(*args, **kwargs)` is called. Instead of
        executing locally, the method is packaged as a task and scheduled
        to the remote actor instance.

        Args:
            method_name: The name of the actor method to execute.
            args: A list of arguments for the actor method.
            kwargs: A dictionary of keyword arguments for the actor method.
            name: The name to give the actor method call task.
            num_returns: The number of return values for the method.
            max_task_retries: Number of retries when method fails.
            retry_exceptions: Boolean of whether you want to retry all user-raised
                exceptions, or a list of allowlist exceptions to retry.
            enable_task_events: True if tracing is enabled, i.e., task events from
                the actor should be reported.

        Returns:
            object_refs: A list of object refs returned by the remote actor
                method.
        """
        worker = ray._private.worker.global_worker

        # Normalize None into empty containers so downstream code can iterate.
        args = args or []
        kwargs = kwargs or {}
        if self._ray_is_cross_language:
            # Cross-language (e.g. Java/C++) actors cannot use Python
            # signatures; arguments are serialized positionally.
            list_args = cross_language._format_args(worker, args, kwargs)
            function_descriptor = cross_language._get_function_descriptor_for_actor_method(  # noqa: E501
                self._ray_actor_language,
                self._ray_actor_creation_function_descriptor,
                method_name,
                # The signature for xlang should be "{length_of_arguments}" to handle
                # overloaded methods.
                signature=str(len(args) + len(kwargs)),
            )
        else:
            function_signature = self._ray_method_signatures[method_name]

            if not args and not kwargs and not function_signature:
                list_args = []
            else:
                list_args = signature.flatten_args(function_signature, args, kwargs)
            function_descriptor = self._ray_function_descriptor[method_name]

        if worker.mode == ray.LOCAL_MODE:
            assert (
                not self._ray_is_cross_language
            ), "Cross language remote actor method cannot be executed locally."

        # NOTE(review): "dynamic" is accepted here although the annotation
        # only advertises int | "streaming" — presumably a legacy value;
        # confirm against the public num_returns documentation.
        if num_returns == "dynamic":
            num_returns = -1
        elif num_returns == "streaming":
            # TODO(sang): This is a temporary private API.
            # Remove it when we migrate to the streaming generator.
            num_returns = ray._raylet.STREAMING_GENERATOR_RETURN

        # A list/tuple means "retry only these exception types"; a bare bool
        # toggles retrying all user exceptions.
        retry_exception_allowlist = None
        if retry_exceptions is None:
            retry_exceptions = False
        elif isinstance(retry_exceptions, (list, tuple)):
            retry_exception_allowlist = tuple(retry_exceptions)
            retry_exceptions = True
        assert isinstance(
            retry_exceptions, bool
        ), "retry_exceptions can either be \
            boolean or list/tuple of exception types."

        # -1 signals "no backpressure" to the core worker.
        if generator_backpressure_num_objects is None:
            generator_backpressure_num_objects = -1

        object_refs = worker.core_worker.submit_actor_task(
            self._ray_actor_language,
            self._ray_actor_id,
            function_descriptor,
            list_args,
            name,
            num_returns,
            max_task_retries,
            retry_exceptions,
            retry_exception_allowlist,
            self._ray_actor_method_cpus,
            concurrency_group_name if concurrency_group_name is not None else b"",
            generator_backpressure_num_objects,
            enable_task_events,
        )

        if num_returns == STREAMING_GENERATOR_RETURN:
            # Streaming generator will return a single ref
            # that is for the generator task.
            assert len(object_refs) == 1
            generator_ref = object_refs[0]
            return ObjectRefGenerator(generator_ref, worker)
        # Unwrap: a single ref is returned bare, zero refs become None.
        if len(object_refs) == 1:
            object_refs = object_refs[0]
        elif len(object_refs) == 0:
            object_refs = None

        return object_refs
    def __getattr__(self, item):
        """Fallback attribute lookup, reached only for names not set in __init__.

        For Python actors this means the method does not exist, so raise.
        For cross-language actors, method names are not known ahead of time,
        so synthesize an ActorMethod on the fly.
        """
        if not self._ray_is_cross_language:
            raise AttributeError(
                f"'{type(self).__name__}' object has " f"no attribute '{item}'"
            )
        if item in ["__ray_terminate__"]:
            # __ray_terminate__ is Python-only; return a stub that warns
            # instead of submitting an unsupported cross-language task.

            class FakeActorMethod(object):
                def __call__(self, *args, **kwargs):
                    raise TypeError(
                        "Actor methods cannot be called directly. Instead "
                        "of running 'object.{}()', try 'object.{}.remote()'.".format(
                            item, item
                        )
                    )

                def remote(self, *args, **kwargs):
                    logger.warning(
                        f"Actor method {item} is not supported by cross language."
                    )

            return FakeActorMethod()

        return ActorMethod(
            self,  # actor
            item,  # method_name
            ray_constants.DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS,
            0,  # max_task_retries
            False,  # retry_exceptions
            False,  # is_generator
            self._ray_method_generator_backpressure_num_objects.get(item, -1),
            self._ray_enable_task_events,  # enable_task_events
            # Currently, cross-lang actor method not support decorator
            decorator=None,
            signature=None,
        )
    # Make tab completion work.
    def __dir__(self):
        # Expose only the actor's remote method names (the keys of the
        # signature table) for interactive completion.
        return self._ray_method_signatures.keys()
+ def __repr__(self):
1574
+ return (
1575
+ "Actor("
1576
+ f"{self._ray_actor_creation_function_descriptor.class_name}, "
1577
+ f"{self._actor_id.hex()})"
1578
+ )
1579
+
1580
    def __hash__(self):
        # Hash on the actor ID so that all handles to the same actor
        # hash (and therefore compare, see __eq__) equal.
        return hash(self._actor_id)
    def __eq__(self, __value):
        # NOTE(review): equality is defined purely by hash equality, so any
        # object whose hash happens to collide with hash(self._actor_id)
        # compares equal. Handles to the same actor compare equal as
        # intended; consider comparing actor IDs directly instead.
        return hash(self) == hash(__value)
    @property
    def _actor_id(self):
        """The unique ID of the actor this handle points to."""
        return self._ray_actor_id
    def _get_local_state(self):
        """Get the local actor state.

        NOTE: this method only returns accurate actor state
        after a first actor method call is made against
        this actor handle due to https://github.com/ray-project/ray/pull/24600.

        Returns:
            ActorTableData.ActorState or None if the state is unknown.
        """
        worker = ray._private.worker.global_worker
        # Raises if no Ray connection has been established.
        worker.check_connected()
        return worker.core_worker.get_local_actor_state(self._ray_actor_id)
+ def _serialization_helper(self):
1605
+ """This is defined in order to make pickling work.
1606
+
1607
+ Returns:
1608
+ A dictionary of the information needed to reconstruct the object.
1609
+ """
1610
+ worker = ray._private.worker.global_worker
1611
+ worker.check_connected()
1612
+
1613
+ if hasattr(worker, "core_worker"):
1614
+ # Non-local mode
1615
+ state = worker.core_worker.serialize_actor_handle(self._ray_actor_id)
1616
+ else:
1617
+ # Local mode
1618
+ state = (
1619
+ {
1620
+ "actor_language": self._ray_actor_language,
1621
+ "actor_id": self._ray_actor_id,
1622
+ "max_task_retries": self._ray_max_task_retries,
1623
+ "enable_task_events": self._enable_task_events,
1624
+ "method_is_generator": self._ray_method_is_generator,
1625
+ "method_decorators": self._ray_method_decorators,
1626
+ "method_signatures": self._ray_method_signatures,
1627
+ "method_num_returns": self._ray_method_num_returns,
1628
+ "method_max_task_retries": self._ray_method_max_task_retries,
1629
+ "method_retry_exceptions": self._ray_method_retry_exceptions,
1630
+ "method_generator_backpressure_num_objects": (
1631
+ self._ray_method_generator_backpressure_num_objects
1632
+ ),
1633
+ "method_enable_task_events": self._ray_method_enable_task_events,
1634
+ "actor_method_cpus": self._ray_actor_method_cpus,
1635
+ "actor_creation_function_descriptor": self._ray_actor_creation_function_descriptor, # noqa: E501
1636
+ },
1637
+ None,
1638
+ )
1639
+
1640
+ return (*state, self._ray_weak_ref)
1641
+
1642
    @classmethod
    def _deserialization_helper(cls, state, weak_ref: bool, outer_object_ref=None):
        """This is defined in order to make pickling work.

        Args:
            state: The serialized state of the actor handle.
            outer_object_ref: The ObjectRef that the serialized actor handle
                was contained in, if any. This is used for counting references
                to the actor handle.
            weak_ref: Whether this was serialized from an actor handle with a
                weak ref to the actor.

        """
        worker = ray._private.worker.global_worker
        worker.check_connected()

        if hasattr(worker, "core_worker"):
            # Non-local mode
            return worker.core_worker.deserialize_and_register_actor_handle(
                state,
                outer_object_ref,
                weak_ref,
            )
        else:
            # Local mode.
            # NOTE(review): this branch requires state to carry a
            # "current_cluster_and_job" entry and to match the current
            # worker's cluster/job — verify the serializer provides it.
            assert worker.current_cluster_and_job == state["current_cluster_and_job"]
            return cls(
                # TODO(swang): Accessing the worker's current task ID is not
                # thread-safe.
                state["actor_language"],
                state["actor_id"],
                state["max_task_retries"],
                state["enable_task_events"],
                state["method_is_generator"],
                state["method_decorators"],
                state["method_signatures"],
                state["method_num_returns"],
                state["method_max_task_retries"],
                state["method_retry_exceptions"],
                state["method_generator_backpressure_num_objects"],
                state["method_enable_task_events"],
                state["actor_method_cpus"],
                state["actor_creation_function_descriptor"],
                state["current_cluster_and_job"],
            )
    def __reduce__(self):
        """This code path is used by pickling but not by Ray forking."""
        (serialized, _, weak_ref) = self._serialization_helper()
        # There is no outer object ref when the actor handle is
        # deserialized out-of-band using pickle.
        return ActorHandle._deserialization_helper, (serialized, weak_ref, None)
1696
def _modify_class(cls):
    """Wrap a user class with the built-in actor plumbing methods.

    Returns a subclass that adds ``__ray_ready__``, ``__ray_call__`` and
    ``__ray_terminate__``, stores the original class under
    ``__ray_actor_class__``, and guarantees an ``__init__`` exists.
    Idempotent: an already-modified class is returned unchanged.
    """
    # cls has been modified.
    if hasattr(cls, "__ray_actor_class__"):
        return cls

    # Give an error if cls is an old-style class.
    if not issubclass(cls, object):
        raise TypeError(
            "The @ray.remote decorator cannot be applied to old-style "
            "classes. In Python 2, you must declare the class with "
            "'class ClassName(object):' instead of 'class ClassName:'."
        )

    # Modify the class to have additional default methods.
    class Class(cls):
        __ray_actor_class__ = cls  # The original actor class

        def __ray_ready__(self):
            # Cheap no-op used to block until the actor is up.
            return True

        def __ray_call__(self, fn, *args, **kwargs):
            # Run an arbitrary function on the actor with self as first arg.
            return fn(self, *args, **kwargs)

        def __ray_terminate__(self):
            worker = ray._private.worker.global_worker
            if worker.mode != ray.LOCAL_MODE:
                ray.actor.exit_actor()

    # Preserve the user class's identity for reprs/pickling.
    Class.__module__ = cls.__module__
    Class.__name__ = cls.__name__

    if not is_function_or_method(getattr(Class, "__init__", None)):
        # Add __init__ if it does not exist.
        # Actor creation will be executed with __init__ together.

        # Assign an __init__ function will avoid many checks later on.
        def __init__(self):
            pass

        Class.__init__ = __init__

    return Class
1740
def _make_actor(cls, actor_options):
    """Turn a user class into an ActorClass with the given remote options."""
    Class = _modify_class(cls)
    # Instrument methods for distributed tracing (no-op if disabled).
    _inject_tracing_into_class(Class)

    if "max_restarts" in actor_options:
        if actor_options["max_restarts"] != -1:  # -1 represents infinite restart
            # Make sure we don't pass too big of an int to C++, causing
            # an overflow.
            actor_options["max_restarts"] = min(
                actor_options["max_restarts"], ray_constants.MAX_INT64_VALUE
            )

    return ActorClass._ray_from_modified_class(
        Class,
        ActorClassID.from_random(),
        actor_options,
    )
1759
@PublicAPI
def exit_actor():
    """Intentionally exit the current actor.

    This API can be used only inside an actor. Use ray.kill
    API if you'd like to kill an actor using actor handle.

    When the API is called, the actor raises an exception and exits.
    Any queued methods will fail. Any ``atexit``
    handlers installed in the actor will be run.

    Raises:
        TypeError: An exception is raised if this is a driver or this
            worker is not an actor.
    """
    worker = ray._private.worker.global_worker
    if worker.mode == ray.WORKER_MODE and not worker.actor_id.is_nil():
        # In asyncio actor mode, we can't raise SystemExit because it will just
        # quit the asyncio event loop thread, not the main thread. Instead, we
        # raise a custom error to the main thread to tell it to exit.
        if worker.core_worker.current_actor_is_asyncio():
            raise AsyncioActorExit()

        # Set a flag to indicate this is an intentional actor exit. This
        # reduces log verbosity.
        raise_sys_exit_with_custom_error_message("exit_actor() is called.")
    else:
        # Bug fix: the original adjacent-string concatenation produced
        # "... inside an actor methodsif you'd like ..." (missing space,
        # wrong plural) in the user-facing error message.
        raise TypeError(
            "exit_actor API is called on a non-actor worker, "
            f"{worker.mode}. Call this API inside an actor method "
            "if you'd like to exit the actor gracefully."
        )
.venv/lib/python3.11/site-packages/ray/client_builder.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import inspect
3
+ import json
4
+ import logging
5
+ import os
6
+ import sys
7
+ import warnings
8
+ from dataclasses import dataclass
9
+ from typing import Any, Dict, Optional, Tuple
10
+
11
+ import ray.util.client_connect
12
+ from ray._private.ray_constants import (
13
+ RAY_ADDRESS_ENVIRONMENT_VARIABLE,
14
+ RAY_NAMESPACE_ENVIRONMENT_VARIABLE,
15
+ RAY_RUNTIME_ENV_ENVIRONMENT_VARIABLE,
16
+ )
17
+ from ray._private.utils import check_ray_client_dependencies_installed, split_address
18
+ from ray._private.worker import BaseContext
19
+ from ray._private.worker import init as ray_driver_init
20
+ from ray.job_config import JobConfig
21
+ from ray.util.annotations import Deprecated, PublicAPI
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+ CLIENT_DOCS_URL = (
26
+ "https://docs.ray.io/en/latest/cluster/running-applications/"
27
+ "job-submission/ray-client.html"
28
+ )
29
+
30
+
31
@dataclass
@PublicAPI
class ClientContext(BaseContext):
    """
    Basic context manager for a ClientBuilder connection.

    `protocol_version` is no longer used.
    """

    # Dashboard URL of the connected cluster, if one is running.
    dashboard_url: Optional[str]
    # Server-side Python / Ray version and commit, for mismatch diagnostics.
    python_version: str
    ray_version: str
    ray_commit: str
    # Number of clients currently connected to the server.
    _num_clients: int
    # The client API context to restore when entering/exiting this context.
    _context_to_restore: Optional[ray.util.client.RayAPIStub]
    protocol_version: Optional[str] = None  # Deprecated

    def __enter__(self) -> "ClientContext":
        self._swap_context()
        return self

    def __exit__(self, *exc) -> None:
        # Best-effort disconnect (force=False), then restore the previous
        # client context.
        self._disconnect_with_context(False)
        self._swap_context()

    def disconnect(self) -> None:
        """Explicitly (force) disconnect this client connection."""
        self._swap_context()
        self._disconnect_with_context(True)
        self._swap_context()

    def _swap_context(self):
        # Swap the globally active client context with the one held here,
        # remembering the previous one so swaps are reversible.
        if self._context_to_restore is not None:
            self._context_to_restore = ray.util.client.ray.set_context(
                self._context_to_restore
            )

    def _disconnect_with_context(self, force_disconnect: bool) -> None:
        """
        Disconnect Ray. If it's a ray client and created with `allow_multiple`,
        it will do nothing. For other cases this either disconnects from the
        remote Client Server or shuts the current driver down.
        """
        if ray.util.client.ray.is_connected():
            if ray.util.client.ray.is_default() or force_disconnect:
                # This is the only client connection
                ray.util.client_connect.disconnect()
        elif ray._private.worker.global_worker.node is None:
            # Already disconnected.
            return
        elif ray._private.worker.global_worker.node.is_head():
            logger.debug(
                "The current Ray Cluster is scoped to this process. "
                "Disconnecting is not possible as it will shutdown the "
                "cluster."
            )
        else:
            # This is only a driver connected to an existing cluster.
            ray.shutdown()
91
@Deprecated
class ClientBuilder:
    """
    Builder for a Ray Client connection. This class can be subclassed by
    custom builder classes to modify connection behavior to include additional
    features or altered semantics. One example is the ``_LocalClientBuilder``.
    """

    def __init__(self, address: Optional[str]) -> None:
        # Ray Client needs optional extras (grpc etc.); fail fast with an
        # actionable install message if they are missing.
        if not check_ray_client_dependencies_installed():
            raise ValueError(
                "Ray Client requires pip package `ray[client]`. "
                "If you installed the minimal Ray (e.g. `pip install ray`), "
                "please reinstall by executing `pip install ray[client]`."
            )
        self.address = address
        self._job_config = JobConfig()
        self._remote_init_kwargs = {}
        # Whether to allow connections to multiple clusters"
        # " (allow_multiple=True).
        self._allow_multiple_connections = False
        self._credentials = None
        self._metadata = None
        # Set to False if ClientBuilder is being constructed by internal
        # methods
        self._deprecation_warn_enabled = True

    def env(self, env: Dict[str, Any]) -> "ClientBuilder":
        """
        Set an environment for the session.
        Args:
            env (Dict[str, Any]): A runtime environment to use for this
                connection. See :ref:`runtime-environments` for what values are
                accepted in this dict.
        """
        self._job_config.set_runtime_env(env)
        return self

    def namespace(self, namespace: str) -> "ClientBuilder":
        """
        Sets the namespace for the session.
        Args:
            namespace: Namespace to use.
        """
        self._job_config.set_ray_namespace(namespace)
        return self

    def connect(self) -> ClientContext:
        """
        Begin a connection to the address passed in via ray.client(...).

        Returns:
            ClientInfo: Dataclass with information about the setting. This
                includes the server's version of Python & Ray as well as the
                dashboard_url.
        """
        if self._deprecation_warn_enabled:
            self._client_deprecation_warn()
        # Fill runtime env/namespace from environment if not already set.
        # Should be done *after* the deprecation warning, since warning will
        # check if those values are already set.
        self._fill_defaults_from_env()

        # If it has already connected to the cluster with allow_multiple=True,
        # connect to the default one is not allowed.
        # But if it has connected to the default one, connect to other clients
        # with allow_multiple=True is allowed
        default_cli_connected = ray.util.client.ray.is_connected()
        has_cli_connected = ray.util.client.num_connected_contexts() > 0
        if (
            not self._allow_multiple_connections
            and not default_cli_connected
            and has_cli_connected
        ):
            raise ValueError(
                "The client has already connected to the cluster "
                "with allow_multiple=True. Please set allow_multiple=True"
                " to proceed"
            )

        # For allow_multiple connections, detach the current global context
        # so the new connection gets a fresh one; restore it afterwards.
        old_ray_cxt = None
        if self._allow_multiple_connections:
            old_ray_cxt = ray.util.client.ray.set_context(None)

        client_info_dict = ray.util.client_connect.connect(
            self.address,
            job_config=self._job_config,
            _credentials=self._credentials,
            ray_init_kwargs=self._remote_init_kwargs,
            metadata=self._metadata,
        )

        dashboard_url = ray.util.client.ray._get_dashboard_url()

        cxt = ClientContext(
            dashboard_url=dashboard_url,
            python_version=client_info_dict["python_version"],
            ray_version=client_info_dict["ray_version"],
            ray_commit=client_info_dict["ray_commit"],
            _num_clients=client_info_dict["num_clients"],
            _context_to_restore=ray.util.client.ray.get_context(),
        )
        if self._allow_multiple_connections:
            ray.util.client.ray.set_context(old_ray_cxt)
        return cxt

    def _fill_defaults_from_env(self):
        # Check environment variables for default values
        namespace_env_var = os.environ.get(RAY_NAMESPACE_ENVIRONMENT_VARIABLE)
        if namespace_env_var and self._job_config.ray_namespace is None:
            self.namespace(namespace_env_var)

        runtime_env_var = os.environ.get(RAY_RUNTIME_ENV_ENVIRONMENT_VARIABLE)
        if runtime_env_var and self._job_config.runtime_env is None:
            self.env(json.loads(runtime_env_var))

    def _init_args(self, **kwargs) -> "ClientBuilder":
        """
        When a client builder is constructed through ray.init, for example
        `ray.init(ray://..., namespace=...)`, all of the
        arguments passed into ray.init with non-default values are passed
        again into this method. Custom client builders can override this method
        to do their own handling/validation of arguments.
        """
        # Use namespace and runtime_env from ray.init call
        if kwargs.get("namespace") is not None:
            self.namespace(kwargs["namespace"])
            del kwargs["namespace"]
        if kwargs.get("runtime_env") is not None:
            self.env(kwargs["runtime_env"])
            del kwargs["runtime_env"]

        if kwargs.get("allow_multiple") is True:
            self._allow_multiple_connections = True
            del kwargs["allow_multiple"]

        if "_credentials" in kwargs.keys():
            self._credentials = kwargs["_credentials"]
            del kwargs["_credentials"]

        if "_metadata" in kwargs.keys():
            self._metadata = kwargs["_metadata"]
            del kwargs["_metadata"]

        # Any remaining kwargs are validated against ray.init's signature and
        # forwarded to the server-side ray.init call.
        if kwargs:
            expected_sig = inspect.signature(ray_driver_init)
            extra_args = set(kwargs.keys()).difference(expected_sig.parameters.keys())
            if len(extra_args) > 0:
                raise RuntimeError(
                    "Got unexpected kwargs: {}".format(", ".join(extra_args))
                )
            self._remote_init_kwargs = kwargs
            unknown = ", ".join(kwargs)
            logger.info(
                "Passing the following kwargs to ray.init() "
                f"on the server: {unknown}"
            )
        return self

    def _client_deprecation_warn(self) -> None:
        """
        Generates a warning for user's if this ClientBuilder instance was
        created directly or through ray.client, instead of relying on
        internal methods (ray.init, or auto init)
        """
        namespace = self._job_config.ray_namespace
        runtime_env = self._job_config.runtime_env
        # Build an equivalent ray.init(...) call to show in the warning.
        replacement_args = []
        if self.address:
            if isinstance(self, _LocalClientBuilder):
                # Address might be set for LocalClientBuilder if ray.client()
                # is called while ray_current_cluster is set
                # (see _get_builder_from_address). In this case,
                # leave off the ray:// so the user attaches the driver directly
                replacement_args.append(f'"{self.address}"')
            else:
                replacement_args.append(f'"ray://{self.address}"')
        if namespace:
            replacement_args.append(f'namespace="{namespace}"')
        if runtime_env:
            # Use a placeholder here, since the real runtime_env would be
            # difficult to read if formatted in directly
            replacement_args.append("runtime_env=<your_runtime_env>")
        args_str = ", ".join(replacement_args)
        replacement_call = f"ray.init({args_str})"

        # Note: stack level is set to 3 since we want the warning to reach the
        # call to ray.client(...).connect(). The intervening frames are
        # connect() -> client_deprecation_warn() -> warnings.warn()
        # https://docs.python.org/3/library/warnings.html#available-functions
        warnings.warn(
            "Starting a connection through `ray.client` will be deprecated "
            "in future ray versions in favor of `ray.init`. See the docs for "
            f"more details: {CLIENT_DOCS_URL}. You can replace your call to "
            "`ray.client().connect()` with the following:\n"
            f"      {replacement_call}\n",
            DeprecationWarning,
            stacklevel=3,
        )
292
class _LocalClientBuilder(ClientBuilder):
    """ClientBuilder variant that attaches a local driver instead of a
    remote Ray Client connection."""

    def connect(self) -> ClientContext:
        """
        Begin a connection to the address passed in via ray.client(...)
        """
        if self._deprecation_warn_enabled:
            self._client_deprecation_warn()
        # Fill runtime env/namespace from environment if not already set.
        # This must run *after* the deprecation warning, which inspects
        # whether those values were supplied explicitly.
        self._fill_defaults_from_env()

        connection_dict = ray.init(address=self.address, job_config=self._job_config)
        major, minor, micro = sys.version_info[:3]
        return ClientContext(
            dashboard_url=connection_dict["webui_url"],
            python_version="{}.{}.{}".format(major, minor, micro),
            ray_version=ray.__version__,
            ray_commit=ray.__commit__,
            _num_clients=1,
            _context_to_restore=None,
        )
317
def _split_address(address: str) -> Tuple[str, str]:
    """
    Splits address into a module string (scheme) and an inner_address.

    If the scheme is not present, then "ray://" is prepended to the address.
    """
    normalized = address if "://" in address else "ray://" + address
    return split_address(normalized)
+
328
def _get_builder_from_address(address: Optional[str]) -> ClientBuilder:
    """Select a ClientBuilder implementation for the given address.

    - "local" or None -> a _LocalClientBuilder (local/auto-detected cluster).
    - "module://inner" -> dynamically imported module's ClientBuilder.
    """
    if address == "local":
        return _LocalClientBuilder("local")
    if address is None:
        # NOTE: This is not placed in `Node::get_temp_dir_path`, because
        # this file is accessed before the `Node` object is created.
        address = ray._private.services.canonicalize_bootstrap_address(address)
        return _LocalClientBuilder(address)
    # "module://inner_address": the scheme names a module that must expose
    # a ClientBuilder class.
    module_string, inner_address = _split_address(address)
    try:
        module = importlib.import_module(module_string)
    except Exception as e:
        raise RuntimeError(
            f"Module: {module_string} does not exist.\n"
            f"This module was parsed from Address: {address}"
        ) from e
    assert "ClientBuilder" in dir(
        module
    ), f"Module: {module_string} does not have ClientBuilder."
    return module.ClientBuilder(inner_address)
350
@Deprecated
def client(
    address: Optional[str] = None, _deprecation_warn_enabled: bool = True
) -> ClientBuilder:
    """
    Creates a ClientBuilder based on the provided address. The address can be
    of the following forms:

        * None: Connects to or creates a local cluster and connects to it.
        * ``"local"``: Creates a new cluster locally and connects to it.
        * ``"IP:Port"``: Connects to a Ray Client Server at the given address.
        * ``"module://inner_address"``: load module.ClientBuilder & pass
          inner_address

    The _deprecation_warn_enabled flag enables deprecation warnings, and is
    for internal use only. Set it to False to suppress client deprecation
    warnings.
    """
    # RAY_ADDRESS overrides auto-detection, but never an explicit address.
    env_address = os.environ.get(RAY_ADDRESS_ENVIRONMENT_VARIABLE)
    if env_address and address is None:
        logger.debug(
            f"Using address ({env_address}) instead of auto-detection "
            f"because {RAY_ADDRESS_ENVIRONMENT_VARIABLE} is set."
        )
        address = env_address

    builder = _get_builder_from_address(address)
    # Disable client deprecation warn when ray.client is used internally
    builder._deprecation_warn_enabled = _deprecation_warn_enabled
    return builder
.venv/lib/python3.11/site-packages/ray/cluster_utils.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import logging
4
+ import os
5
+ import subprocess
6
+ import tempfile
7
+ import time
8
+ from typing import Dict, Optional
9
+
10
+ import yaml
11
+
12
+ import ray
13
+ import ray._private.services
14
+ from ray._private import ray_constants
15
+ from ray._private.client_mode_hook import disable_client_hook
16
+ from ray._raylet import GcsClientOptions
17
+ from ray.autoscaler._private.fake_multi_node.node_provider import FAKE_HEAD_NODE_ID
18
+ from ray.util.annotations import DeveloperAPI
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+ cluster_not_supported = os.name == "nt"
23
+
24
+
25
@DeveloperAPI
class AutoscalingCluster:
    """Create a local autoscaling cluster for testing.

    See test_autoscaler_fake_multinode.py for an end-to-end example.
    """

    def __init__(
        self,
        head_resources: dict,
        worker_node_types: dict,
        autoscaler_v2: bool = False,
        **config_kwargs,
    ):
        """Create the cluster.

        Args:
            head_resources: resources of the head node, including CPU.
            worker_node_types: autoscaler node types config for worker nodes.
            autoscaler_v2: whether to enable the v2 autoscaler code path.
            **config_kwargs: extra top-level keys merged into the generated
                autoscaler config.
        """
        self._head_resources = head_resources
        self._config = self._generate_config(
            head_resources,
            worker_node_types,
            autoscaler_v2=autoscaler_v2,
            **config_kwargs,
        )
        self._autoscaler_v2 = autoscaler_v2

    def _generate_config(
        self, head_resources, worker_node_types, autoscaler_v2=False, **config_kwargs
    ):
        """Build the fake multi-node autoscaler config.

        Starts from the example config bundled with the ray package and
        overrides the node types with the requested head/worker resources.

        Returns:
            The customized autoscaler config as a dict.
        """
        example_config_path = os.path.join(
            os.path.dirname(ray.__file__),
            "autoscaler/_private/fake_multi_node/example.yaml",
        )
        # FIX: the original called yaml.safe_load(open(...)) and leaked the
        # file handle; a context manager closes it deterministically.
        with open(example_config_path) as f:
            base_config = yaml.safe_load(f)
        custom_config = copy.deepcopy(base_config)
        custom_config["available_node_types"] = worker_node_types
        custom_config["available_node_types"]["ray.head.default"] = {
            "resources": head_resources,
            "node_config": {},
            "max_workers": 0,
        }

        # Autoscaler v2 specific configs
        if autoscaler_v2:
            custom_config["provider"]["launch_multiple"] = True
            custom_config["provider"]["head_node_id"] = FAKE_HEAD_NODE_ID
        custom_config.update(config_kwargs)
        return custom_config

    def start(self, _system_config=None, override_env: Optional[Dict] = None):
        """Start the cluster.

        After this call returns, you can connect to the cluster with
        ray.init("auto").

        Args:
            _system_config: optional Ray system config dict forwarded to
                ``ray start --system-config``.
            override_env: extra environment variables for the launched
                ``ray start`` process.
        """
        subprocess.check_call(["ray", "stop", "--force"])
        _, fake_config = tempfile.mkstemp()
        with open(fake_config, "w") as f:
            f.write(json.dumps(self._config))
        cmd = [
            "ray",
            "start",
            "--autoscaling-config={}".format(fake_config),
            "--head",
        ]
        # NOTE: these pops mutate self._head_resources; the remaining entries
        # are passed through as custom resources below.
        if "CPU" in self._head_resources:
            cmd.append("--num-cpus={}".format(self._head_resources.pop("CPU")))
        if "GPU" in self._head_resources:
            cmd.append("--num-gpus={}".format(self._head_resources.pop("GPU")))
        if "object_store_memory" in self._head_resources:
            cmd.append(
                "--object-store-memory={}".format(
                    self._head_resources.pop("object_store_memory")
                )
            )
        if self._head_resources:
            cmd.append("--resources='{}'".format(json.dumps(self._head_resources)))
        if _system_config is not None:
            cmd.append(
                "--system-config={}".format(
                    json.dumps(_system_config, separators=(",", ":"))
                )
            )
        env = os.environ.copy()
        env.update({"AUTOSCALER_UPDATE_INTERVAL_S": "1", "RAY_FAKE_CLUSTER": "1"})
        if self._autoscaler_v2:
            # Set the necessary environment variables for autoscaler v2.
            env.update(
                {
                    "RAY_enable_autoscaler_v2": "1",
                    "RAY_CLOUD_INSTANCE_ID": FAKE_HEAD_NODE_ID,
                    "RAY_OVERRIDE_NODE_ID_FOR_TESTING": FAKE_HEAD_NODE_ID,
                }
            )
        if override_env:
            env.update(override_env)
        subprocess.check_call(cmd, env=env)

    def shutdown(self):
        """Terminate the cluster."""
        subprocess.check_call(["ray", "stop", "--force"])
132
+
133
+
134
@DeveloperAPI
class Cluster:
    """Manages a local multi-node Ray cluster for testing.

    Nodes run in-process via ``ray._private.node.Node``; the first node
    added becomes the head node.
    """

    def __init__(
        self,
        initialize_head: bool = False,
        connect: bool = False,
        head_node_args: dict = None,
        shutdown_at_exit: bool = True,
    ):
        """Initializes all services of a Ray cluster.

        Args:
            initialize_head: Automatically start a Ray cluster
                by initializing the head node. Defaults to False.
            connect: If `initialize_head=True` and `connect=True`,
                ray.init will be called with the address of this cluster
                passed in.
            head_node_args: Arguments to be passed into
                `start_ray_head` via `self.add_node`.
            shutdown_at_exit: If True, registers an exit hook
                for shutting down all started processes.
        """
        if cluster_not_supported:
            logger.warning(
                "Ray cluster mode is currently experimental and untested on "
                "Windows. If you are using it and running into issues please "
                "file a report at https://github.com/ray-project/ray/issues."
            )
        self.head_node = None
        self.worker_nodes = set()
        self.redis_address = None
        self.connected = False
        # Create a new global state accessor for fetching GCS table.
        self.global_state = ray._private.state.GlobalState()
        self._shutdown_at_exit = shutdown_at_exit
        if not initialize_head and connect:
            raise RuntimeError("Cannot connect to uninitialized cluster.")

        if initialize_head:
            head_node_args = head_node_args or {}
            self.add_node(**head_node_args)
            if connect:
                self.connect()

    @property
    def gcs_address(self):
        """GCS address of the head node, or None if no head node exists."""
        if self.head_node is None:
            return None
        return self.head_node.gcs_address

    @property
    def address(self):
        """Alias for ``gcs_address``; used as the ray.init() address."""
        return self.gcs_address

    def connect(self, namespace=None):
        """Connect the driver to the cluster.

        Args:
            namespace: Optional Ray namespace to connect with.
        """
        assert self.address is not None
        assert not self.connected
        output_info = ray.init(
            namespace=namespace,
            ignore_reinit_error=True,
            address=self.address,
            _redis_username=self.redis_username,
            _redis_password=self.redis_password,
        )
        logger.info(output_info)
        self.connected = True

    def add_node(self, wait: bool = True, **node_args):
        """Adds a node to the local Ray Cluster.

        All nodes are by default started with the following settings:
            cleanup=True,
            num_cpus=1,
            object_store_memory=150 * 1024 * 1024  # 150 MiB

        Args:
            wait: Whether to wait until the node is alive.
            node_args: Keyword arguments used in `start_ray_head` and
                `start_ray_node`. Overrides defaults.

        Returns:
            Node object of the added Ray node.
        """
        default_kwargs = {
            "num_cpus": 1,
            "num_gpus": 0,
            "object_store_memory": 150 * 1024 * 1024,  # 150 MiB
            "min_worker_port": 0,
            "max_worker_port": 0,
        }
        ray_params = ray._private.parameter.RayParams(**node_args)
        ray_params.update_if_absent(**default_kwargs)
        with disable_client_hook():
            if self.head_node is None:
                # First node added becomes the head node.
                node = ray._private.node.Node(
                    ray_params,
                    head=True,
                    shutdown_at_exit=self._shutdown_at_exit,
                    spawn_reaper=self._shutdown_at_exit,
                )
                self.head_node = node
                self.redis_address = self.head_node.redis_address
                self.redis_username = node_args.get(
                    "redis_username", ray_constants.REDIS_DEFAULT_USERNAME
                )
                self.redis_password = node_args.get(
                    "redis_password", ray_constants.REDIS_DEFAULT_PASSWORD
                )
                self.webui_url = self.head_node.webui_url
                # Init global state accessor when creating head node.
                gcs_options = GcsClientOptions.create(
                    node.gcs_address,
                    None,
                    allow_cluster_id_nil=True,
                    fetch_cluster_id_if_nil=False,
                )
                self.global_state._initialize_global_state(gcs_options)
                # Write the Ray cluster address for convenience in unit
                # testing. ray.init() and ray.init(address="auto") will connect
                # to the local cluster.
                ray._private.utils.write_ray_address(self.head_node.gcs_address)
            else:
                ray_params.update_if_absent(redis_address=self.redis_address)
                ray_params.update_if_absent(gcs_address=self.gcs_address)
                # We only need one log monitor per physical node.
                ray_params.update_if_absent(include_log_monitor=False)
                # Let grpc pick a port.
                ray_params.update_if_absent(node_manager_port=0)
                if "dashboard_agent_listen_port" not in node_args:
                    # Pick a random one to not conflict
                    # with the head node dashboard agent
                    ray_params.dashboard_agent_listen_port = None

                node = ray._private.node.Node(
                    ray_params,
                    head=False,
                    shutdown_at_exit=self._shutdown_at_exit,
                    spawn_reaper=self._shutdown_at_exit,
                )
                self.worker_nodes.add(node)

            if wait:
                # Wait for the node to appear in the client table. We do this
                # so that the nodes appears in the client table in the order
                # that the corresponding calls to add_node were made. We do
                # this because in the tests we assume that the driver is
                # connected to the first node that is added.
                self._wait_for_node(node)

        return node

    def remove_node(self, node, allow_graceful=True):
        """Kills all processes associated with worker node.

        Args:
            node: Worker node of which all associated processes
                will be removed.
            allow_graceful: Whether to allow the node's processes to
                exit gracefully before being killed.
        """
        global_node = ray._private.worker._global_node
        if global_node is not None:
            if node._raylet_socket_name == global_node._raylet_socket_name:
                ray.shutdown()
                # FIX: the original message was missing the spaces between
                # the concatenated string fragments.
                raise ValueError(
                    "Removing a node that is connected to this Ray client "
                    "is not allowed because it will break the driver. "
                    "You can use the get_other_node utility to avoid removing "
                    "a node that the Ray client is connected to."
                )

        node.destroy_external_storage()
        if self.head_node == node:
            # We have to wait to prevent the raylet becomes a zombie which will prevent
            # worker from exiting
            self.head_node.kill_all_processes(
                check_alive=False, allow_graceful=allow_graceful, wait=True
            )
            self.head_node = None
            # TODO(rliaw): Do we need to kill all worker processes?
        else:
            # We have to wait to prevent the raylet becomes a zombie which will prevent
            # worker from exiting
            node.kill_all_processes(
                check_alive=False, allow_graceful=allow_graceful, wait=True
            )
            self.worker_nodes.remove(node)

        assert (
            not node.any_processes_alive()
        ), "There are zombie processes left over after killing."

    def _wait_for_node(self, node, timeout: float = 30):
        """Wait until this node has appeared in the client table.

        Args:
            node (ray._private.node.Node): The node to wait for.
            timeout: The amount of time in seconds to wait before raising an
                exception.

        Raises:
            TimeoutError: An exception is raised if the timeout expires before
                the node appears in the client table.
        """
        ray._private.services.wait_for_node(
            node.gcs_address,
            node.plasma_store_socket_name,
            timeout,
        )

    def wait_for_nodes(self, timeout: float = 30):
        """Waits for correct number of nodes to be registered.

        This will wait until the number of live nodes in the client table
        exactly matches the number of "add_node" calls minus the number of
        "remove_node" calls that have been made on this cluster. This means
        that if a node dies without "remove_node" having been called, this will
        raise an exception.

        Args:
            timeout: The number of seconds to wait for nodes to join
                before failing.

        Raises:
            TimeoutError: An exception is raised if we time out while waiting
                for nodes to join.
        """
        start_time = time.time()
        while time.time() - start_time < timeout:
            live_clients = self.global_state._live_node_ids()

            expected = len(self.list_all_nodes())
            if len(live_clients) == expected:
                logger.debug("All nodes registered as expected.")
                return
            else:
                logger.debug(
                    f"{len(live_clients)} nodes are currently registered, "
                    f"but we are expecting {expected}"
                )
                time.sleep(0.1)
        raise TimeoutError("Timed out while waiting for nodes to join.")

    def list_all_nodes(self):
        """Lists all nodes.

        TODO(rliaw): What is the desired behavior if a head node
        dies before worker nodes die?

        Returns:
            List of all nodes, including the head node.
        """
        nodes = list(self.worker_nodes)
        if self.head_node:
            nodes = [self.head_node] + nodes
        return nodes

    def remaining_processes_alive(self):
        """Returns a bool indicating whether all processes are alive or not.

        Note that this ignores processes that have been explicitly killed,
        e.g., via a command like node.kill_raylet().

        Returns:
            True if all processes are alive and false otherwise.
        """
        return all(node.remaining_processes_alive() for node in self.list_all_nodes())

    def shutdown(self):
        """Removes all nodes."""

        # We create a list here as a copy because `remove_node`
        # modifies `self.worker_nodes`.
        all_nodes = list(self.worker_nodes)
        for node in all_nodes:
            self.remove_node(node)

        if self.head_node is not None:
            self.remove_node(self.head_node)
            # need to reset internal kv since gcs is down
            ray.experimental.internal_kv._internal_kv_reset()
        # Delete the cluster address.
        ray._private.utils.reset_ray_address()
.venv/lib/python3.11/site-packages/ray/cross_language.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import, division, print_function
2
+
3
+ from ray import Language
4
+ from ray._raylet import CppFunctionDescriptor, JavaFunctionDescriptor
5
+ from ray.util.annotations import PublicAPI
6
+
7
+ __all__ = [
8
+ "java_function",
9
+ "java_actor_class",
10
+ "cpp_function",
11
+ ]
12
+
13
+
14
@PublicAPI(stability="beta")
def java_function(class_name: str, function_name: str):
    """Create a handle to a remote Java function.

    Args:
        class_name: Java class name.
        function_name: Java function name.
    """
    from ray.remote_function import RemoteFunction

    # The Python body is a no-op placeholder; the descriptor routes the
    # invocation to a Java worker.
    def _placeholder(*args, **kwargs):
        return None

    descriptor = JavaFunctionDescriptor(class_name, function_name, "")
    return RemoteFunction(Language.JAVA, _placeholder, descriptor, {})
30
+
31
+
32
@PublicAPI(stability="beta")
def cpp_function(function_name: str):
    """Create a handle to a remote Cpp function.

    Args:
        function_name: Cpp function name.
    """
    from ray.remote_function import RemoteFunction

    # The Python body is a no-op placeholder; the descriptor routes the
    # invocation to a C++ worker.
    def _placeholder(*args, **kwargs):
        return None

    descriptor = CppFunctionDescriptor(function_name, "PYTHON")
    return RemoteFunction(Language.CPP, _placeholder, descriptor, {})
47
+
48
+
49
@PublicAPI(stability="beta")
def java_actor_class(class_name: str):
    """Create a handle to a remote Java actor class.

    Args:
        class_name: Java class name.
    """
    from ray.actor import ActorClass

    # "<init>" is the JVM's constructor method name.
    constructor_descriptor = JavaFunctionDescriptor(class_name, "<init>", "")
    return ActorClass._ray_from_function_descriptor(
        Language.JAVA, constructor_descriptor, {}
    )
63
+
64
+
65
@PublicAPI(stability="beta")
def cpp_actor_class(create_function_name: str, class_name: str):
    """Define a Cpp actor class.

    Args:
        create_function_name: Create cpp class function name.
        class_name: Cpp class name.
    """
    from ray.actor import ActorClass

    # FIX: removed a leftover debug print() of the arguments.
    return ActorClass._ray_from_function_descriptor(
        Language.CPP,
        CppFunctionDescriptor(create_function_name, "PYTHON", class_name),
        {},
    )
81
+
82
+
83
+ def _format_args(worker, args, kwargs):
84
+ """Format args for various languages.
85
+
86
+ Args:
87
+ worker: The global worker instance.
88
+ args: The arguments for cross language.
89
+ kwargs: The keyword arguments for cross language.
90
+
91
+ Returns:
92
+ List of args and kwargs (if supported).
93
+ """
94
+ if not worker.load_code_from_local:
95
+ raise ValueError(
96
+ "Cross language feature needs --load-code-from-local to be set."
97
+ )
98
+ if kwargs:
99
+ raise TypeError(
100
+ f"Cross language remote functions does not support kwargs, "
101
+ f"kwargs:{str(kwargs)}."
102
+ )
103
+ return args
104
+
105
+
106
def _get_function_descriptor_for_actor_method(
    language: str, actor_creation_function_descriptor, method_name: str, signature: str
):
    """Build a function descriptor for a cross-language actor method call.

    Args:
        language: Target language.
        actor_creation_function_descriptor:
            The function signature for actor creation.
        method_name: The name of actor method.
        signature: The signature for the actor method. When calling Java from
            Python, it should be string in the form of "{length_of_args}".

    Returns:
        Function descriptor for cross language actor method call.

    Raises:
        NotImplementedError: If ``language`` is neither Java nor Cpp.
    """
    if language == Language.JAVA:
        return JavaFunctionDescriptor(
            actor_creation_function_descriptor.class_name,
            method_name,
            signature,
        )
    if language == Language.CPP:
        return CppFunctionDescriptor(
            method_name,
            "PYTHON",
            actor_creation_function_descriptor.class_name,
        )
    raise NotImplementedError(
        "Cross language remote actor method " f"not support language {language}"
    )
.venv/lib/python3.11/site-packages/ray/exceptions.py ADDED
@@ -0,0 +1,933 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import sys
4
+ from traceback import format_exception
5
+ from typing import Optional, Union
6
+
7
+ import colorama
8
+
9
+ import ray._private.ray_constants as ray_constants
10
+ import ray.cloudpickle as pickle
11
+ from ray._raylet import ActorID, TaskID, WorkerID
12
+ from ray.core.generated.common_pb2 import (
13
+ PYTHON,
14
+ ActorDiedErrorContext,
15
+ Address,
16
+ Language,
17
+ NodeDeathInfo,
18
+ RayException,
19
+ )
20
+ from ray.util.annotations import DeveloperAPI, PublicAPI
21
+
22
+ import setproctitle
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
@PublicAPI
class RayError(Exception):
    """Super class of all ray exception types."""

    def to_bytes(self):
        """Serialize this error into RayException protobuf bytes."""
        # Extract exc_info from exception object.
        exc_info = (type(self), self, self.__traceback__)
        formatted = "\n".join(format_exception(*exc_info))
        proto = RayException(
            language=PYTHON,
            serialized_exception=pickle.dumps(self),
            formatted_exception_string=formatted,
        )
        return proto.SerializeToString()

    @staticmethod
    def from_bytes(b):
        """Deserialize a RayException protobuf byte string into an error."""
        proto = RayException()
        proto.ParseFromString(b)
        return RayError.from_ray_exception(proto)

    @staticmethod
    def from_ray_exception(ray_exception):
        """Reconstruct a Python exception from a RayException protobuf."""
        if ray_exception.language != PYTHON:
            # Non-Python errors cannot be unpickled; wrap the protobuf so the
            # foreign traceback is still surfaced to the caller.
            return CrossLanguageError(ray_exception)
        try:
            return pickle.loads(ray_exception.serialized_exception)
        except Exception as e:
            raise RuntimeError("Failed to unpickle serialized exception") from e
57
+
58
+
59
@PublicAPI
class CrossLanguageError(RayError):
    """Raised from another language."""

    def __init__(self, ray_exception):
        # Embed the source language and its pre-formatted traceback in the
        # message since the foreign exception cannot be reconstructed here.
        language_name = Language.Name(ray_exception.language)
        message = "An exception raised from {}:\n{}".format(
            language_name, ray_exception.formatted_exception_string
        )
        super().__init__(message)
70
+
71
+
72
@PublicAPI
class TaskCancelledError(RayError):
    """Raised when this task is cancelled.

    Args:
        task_id: The TaskID of the function that was directly
            cancelled.
        error_message: Optional extra detail appended to the message.
    """

    def __init__(
        self, task_id: Optional[TaskID] = None, error_message: Optional[str] = None
    ):
        # Both fields are optional; __str__ only includes what is present.
        self.task_id = task_id
        self.error_message = error_message

    def __str__(self):
        parts = []
        if self.task_id:
            parts.append("Task: " + str(self.task_id) + " was cancelled. ")
        if self.error_message:
            parts.append(self.error_message)
        return "".join(parts)
94
+
95
+
96
@PublicAPI
class RayTaskError(RayError):
    """Indicates that a task threw an exception during execution.

    If a task throws an exception during execution, a RayTaskError is stored in
    the object store for each of the task's outputs. When an object is
    retrieved from the object store, the Python method that retrieved it checks
    to see if the object is a RayTaskError and if it is then an exception is
    thrown propagating the error message.
    """

    def __init__(
        self,
        function_name,
        traceback_str,
        cause,
        proctitle=None,
        pid=None,
        ip=None,
        actor_repr=None,
        actor_id=None,
    ):
        """Initialize a RayTaskError."""
        # Imported here rather than at module level — presumably to avoid a
        # circular import between ray and ray.exceptions; confirm before moving.
        import ray

        # Fall back to the current process metadata when the failing worker's
        # details were not supplied by the caller.
        if proctitle:
            self.proctitle = proctitle
        else:
            self.proctitle = setproctitle.getproctitle()
        self.pid = pid or os.getpid()
        self.ip = ip or ray.util.get_node_ip_address()
        self.function_name = function_name
        self.traceback_str = traceback_str
        self.actor_repr = actor_repr
        self._actor_id = actor_id
        self.cause = cause

        # The cause must survive serialization through the object store; if it
        # can't be pickled, replace it with a picklable RayError describing why.
        try:
            pickle.dumps(cause)
        except (pickle.PicklingError, TypeError) as e:
            err_msg = (
                "The original cause of the RayTaskError"
                f" ({self.cause.__class__}) isn't serializable: {e}."
                " Overwriting the cause to a RayError."
            )
            logger.warning(err_msg)
            self.cause = RayError(err_msg)

        # BaseException implements a __reduce__ method that returns
        # a tuple with the type and the value of self.args.
        # https://stackoverflow.com/a/49715949/2213289
        self.args = (function_name, traceback_str, self.cause, proctitle, pid, ip)

        assert traceback_str is not None

    def make_dual_exception_instance(self) -> "RayTaskError":
        """Makes a object instance that inherits from both RayTaskError and the type of
        `self.cause`. Raises TypeError if the cause class can't be subclassed"""
        # For normal user Exceptions, we subclass from both
        # RayTaskError and the user exception. For ExceptionGroup,
        # we special handle it because it has a different __new__()
        # signature from Exception.
        # Ref: https://docs.python.org/3/library/exceptions.html#exception-groups
        if sys.version_info >= (3, 11) and isinstance(
            self.cause, ExceptionGroup  # noqa: F821
        ):
            return self._make_exceptiongroup_dual_exception_instance()
        return self._make_normal_dual_exception_instance()

    def _make_normal_dual_exception_instance(self) -> "RayTaskError":
        """Build the dual instance for a plain (non-ExceptionGroup) cause."""
        cause_cls = self.cause.__class__
        # Capture the formatted message now; the dynamic class's __str__
        # closes over it so the dual instance prints like this RayTaskError.
        error_msg = str(self)

        class cls(RayTaskError, cause_cls):
            def __init__(self, cause):
                self.cause = cause
                # BaseException implements a __reduce__ method that returns
                # a tuple with the type and the value of self.args.
                # https://stackoverflow.com/a/49715949/2213289
                self.args = (cause,)

            def __getattr__(self, name):
                # Delegate unknown attribute lookups to the wrapped cause.
                return getattr(self.cause, name)

            def __str__(self):
                return error_msg

        name = f"RayTaskError({cause_cls.__name__})"
        cls.__name__ = name
        cls.__qualname__ = name

        return cls(self.cause)

    def _make_exceptiongroup_dual_exception_instance(self) -> "RayTaskError":
        """Build the dual instance when the cause is an ExceptionGroup (3.11+)."""
        cause_cls = self.cause.__class__
        error_msg = str(self)

        class cls(RayTaskError, cause_cls):
            def __new__(cls, cause):
                # ExceptionGroup.__new__ requires (message, exceptions),
                # unlike Exception, so forward the cause's own fields.
                self = super().__new__(cls, cause.message, cause.exceptions)
                return self

            def __init__(self, cause):
                self.cause = cause
                # BaseException implements a __reduce__ method that returns
                # a tuple with the type and the value of self.args.
                # https://stackoverflow.com/a/49715949/2213289
                self.args = (cause,)

            def __getattr__(self, name):
                # Delegate unknown attribute lookups to the wrapped cause.
                return getattr(self.cause, name)

            def __str__(self):
                return error_msg

        name = f"RayTaskError({cause_cls.__name__})"
        cls.__name__ = name
        cls.__qualname__ = name

        return cls(self.cause)

    def as_instanceof_cause(self):
        """Returns an exception that's an instance of the cause's class.

        The returned exception inherits from both RayTaskError and the
        cause class and contains all of the attributes of the cause
        exception.

        If the cause class can't be subclassed, issues a warning and returns `self`.
        """
        cause_cls = self.cause.__class__
        if issubclass(RayTaskError, cause_cls):
            return self  # already satisfied

        try:
            return self.make_dual_exception_instance()
        except TypeError as e:
            logger.warning(
                f"User exception type {type(self.cause)} in RayTaskError can't"
                " be subclassed! This exception is raised as"
                " RayTaskError only. You can use `ray_task_error.cause` to"
                f" access the user exception. Failure in subclassing: {e}"
            )
            return self

    def __str__(self):
        """Format a RayTaskError as a string."""
        lines = self.traceback_str.strip().split("\n")
        out = []
        code_from_internal_file = False

        # Format tracebacks.
        # Python stacktrace consists of
        # Traceback...: Indicate the next line will be a traceback.
        #   File [file_name + line number]
        #     code
        # XError: [message]
        # NOTE: For _raylet.pyx (Cython), the code is not always included.
        for i, line in enumerate(lines):
            # Convert traceback to the readable information.
            if line.startswith("Traceback "):
                # Replace the generic "Traceback" header with a colorized
                # "proctitle() (pid=..., ip=...)" line identifying the worker.
                traceback_line = (
                    f"{colorama.Fore.CYAN}"
                    f"{self.proctitle}()"
                    f"{colorama.Fore.RESET} "
                    f"(pid={self.pid}, ip={self.ip}"
                )
                if self.actor_repr:
                    traceback_line += (
                        f", actor_id={self._actor_id}, repr={self.actor_repr})"
                    )
                else:
                    traceback_line += ")"
                code_from_internal_file = False
                out.append(traceback_line)
            elif line.startswith("  File ") and (
                "ray/worker.py" in line
                or "ray/_private/" in line
                or "ray/util/tracing/" in line
                or "ray/_raylet.pyx" in line
            ):
                # TODO(windows)
                # Process the internal file line.
                # The file line always starts with 2 space and File.
                # https://github.com/python/cpython/blob/0a0a135bae2692d069b18d2d590397fbe0a0d39a/Lib/traceback.py#L421 # noqa
                if "ray._raylet.raise_if_dependency_failed" in line:
                    # It means the current task is failed
                    # due to the dependency failure.
                    # Print out an user-friendly
                    # message to explain that..
                    out.append(
                        "  At least one of the input arguments for "
                        "this task could not be computed:"
                    )
                if i + 1 < len(lines) and lines[i + 1].startswith("    "):
                    # If the next line is indented with 2 space,
                    # that means it contains internal code information.
                    # For example,
                    #   File [file_name] [line]
                    #     [code] # if the next line is indented, it is code.
                    # Note there there are 4 spaces in the code line.
                    code_from_internal_file = True
            elif code_from_internal_file:
                # If the current line is internal file's code,
                # the next line is not code anymore.
                code_from_internal_file = False
            else:
                out.append(line)
        return "\n".join(out)
305
+
306
+
307
@PublicAPI
class LocalRayletDiedError(RayError):
    """Indicates that the task's local raylet died."""

    # Fixed message: there is no extra context to attach for this failure.
    _MESSAGE = "The task's local raylet died. Check raylet.out for more information."

    def __str__(self):
        return self._MESSAGE
313
+
314
+
315
@PublicAPI
class WorkerCrashedError(RayError):
    """Indicates that the worker died unexpectedly while executing a task."""

    # Fixed message: there is no extra context to attach for this failure.
    _MESSAGE = (
        "The worker died unexpectedly while executing this task. "
        "Check python-core-worker-*.log files for more information."
    )

    def __str__(self):
        return self._MESSAGE
324
+
325
+
326
@PublicAPI
class RayActorError(RayError):
    """Indicates that the actor has outages unexpectedly before finishing a task.

    This exception could happen because the actor process is dead, or is unavailable for
    the moment. Ray raises subclasses `ActorDiedError` and `ActorUnavailableError`
    respectively.
    """

    BASE_ERROR_MSG = "The actor experienced an error before finishing this task."

    def __init__(
        self,
        actor_id: str = None,
        error_msg: str = BASE_ERROR_MSG,
        actor_init_failed: bool = False,
        preempted: bool = False,
    ):
        # The actor ID in hex string (may be None when unknown).
        self.actor_id = actor_id
        # The full, human-readable error message; also what __str__ returns.
        self.error_msg = error_msg
        # Whether the actor failed in the middle of __init__.
        self._actor_init_failed = actor_init_failed
        # Whether the actor died because its node was preempted.
        self._preempted = preempted

    def __str__(self) -> str:
        return self.error_msg

    @property
    def actor_init_failed(self) -> bool:
        return self._actor_init_failed

    @property
    def preempted(self) -> bool:
        return self._preempted
363
+
364
+
365
@DeveloperAPI
class ActorDiedError(RayActorError):
    """Indicates that the actor died unexpectedly before finishing a task.

    This exception could happen either because the actor process dies while
    executing a task, or because a task is submitted to a dead actor.

    Args:
        cause: The cause of the actor error. `RayTaskError` type means
            the actor has died because of an exception within `__init__`.
            `ActorDiedErrorContext` means the actor has died because of
            an unexpected system error. None means the cause isn't known.
            Theoretically, this shouldn't happen,
            but it's there as a safety check.
    """

    BASE_ERROR_MSG = "The actor died unexpectedly before finishing this task."

    def __init__(
        self, cause: Optional[Union[RayTaskError, ActorDiedErrorContext]] = None
    ):
        """
        Construct a RayActorError by building the arguments.
        """

        # Defaults used when no cause is supplied.
        actor_id = None
        error_msg = ActorDiedError.BASE_ERROR_MSG
        actor_init_failed = False
        preempted = False

        if not cause:
            # Use the defaults above.
            pass
        elif isinstance(cause, RayTaskError):
            # The actor failed inside its own creation task (__init__);
            # surface the wrapped user exception in the message.
            actor_init_failed = True
            actor_id = cause._actor_id
            error_msg = (
                "The actor died because of an error"
                " raised in its creation task, "
                f"{cause.__str__()}"
            )
        else:
            # Indicating system-level actor failures.
            assert isinstance(cause, ActorDiedErrorContext)
            error_msg_lines = [ActorDiedError.BASE_ERROR_MSG]
            error_msg_lines.append(f"\tclass_name: {cause.class_name}")
            error_msg_lines.append(f"\tactor_id: {ActorID(cause.actor_id).hex()}")
            # Below items are optional fields.
            if cause.pid != 0:
                error_msg_lines.append(f"\tpid: {cause.pid}")
            if cause.name != "":
                error_msg_lines.append(f"\tname: {cause.name}")
            if cause.ray_namespace != "":
                error_msg_lines.append(f"\tnamespace: {cause.ray_namespace}")
            if cause.node_ip_address != "":
                error_msg_lines.append(f"\tip: {cause.node_ip_address}")
            error_msg_lines.append(cause.error_message)
            if cause.never_started:
                error_msg_lines.append(
                    "The actor never ran - it was cancelled before it started running."
                )
            if (
                cause.node_death_info
                and cause.node_death_info.reason
                == NodeDeathInfo.AUTOSCALER_DRAIN_PREEMPTED
            ):
                # The actor's node was drained by the autoscaler (preemption).
                preempted = True
            error_msg = "\n".join(error_msg_lines)
            actor_id = ActorID(cause.actor_id).hex()
        super().__init__(actor_id, error_msg, actor_init_failed, preempted)

    @staticmethod
    def from_task_error(task_error: RayTaskError) -> "ActorDiedError":
        """Wrap a creation-task failure into an ActorDiedError."""
        return ActorDiedError(task_error)
439
+
440
+
441
@DeveloperAPI
class ActorUnavailableError(RayActorError):
    """Raised when the actor is temporarily unavailable but may be available later.

    Args:
        error_message: Explanation of why the actor is unavailable.
        actor_id: Binary actor ID, or None if unknown.
    """

    def __init__(self, error_message: str, actor_id: Optional[bytes]):
        # Render the binary ID as hex for readability; it may be absent.
        actor_id = ActorID(actor_id).hex() if actor_id is not None else None
        # Bug fix: the previous implicit string concatenation rendered
        # "The task may or maynot have been executed"; the space is restored.
        error_msg = (
            f"The actor {actor_id} is unavailable: {error_message}. The task may or "
            "may not have been executed on the actor."
        )
        actor_init_failed = False
        preempted = False

        super().__init__(actor_id, error_msg, actor_init_failed, preempted)
455
+
456
+
457
@PublicAPI
class RaySystemError(RayError):
    """Indicates that Ray encountered a system error.

    This exception can be thrown when the raylet is killed.
    """

    def __init__(self, client_exc, traceback_str=None):
        # The underlying exception (or its representation).
        self.client_exc = client_exc
        # Optional formatted traceback accompanying the error.
        self.traceback_str = traceback_str

    def __str__(self):
        parts = [f"System error: {self.client_exc}"]
        if self.traceback_str:
            parts.append(f"traceback: {self.traceback_str}")
        return "\n".join(parts)
473
+
474
+
475
@DeveloperAPI
class UserCodeException(RayError):
    """Indicates that an exception occurred while executing user code.

    For example, this exception can be used to wrap user code exceptions
    from a remote task or actor. The `retry_exceptions` parameter will
    still respect the underlying cause of this exception.
    """
483
+
484
+
485
@PublicAPI
class ObjectStoreFullError(RayError):
    """Indicates that the object store is full.

    This is raised if the attempt to store the object fails
    because the object store is full even after multiple retries.
    """

    def __str__(self):
        # Append a troubleshooting hint to the message attached to the
        # underlying exception.
        hint = (
            "\n"
            "The local object store is full of objects that are still in "
            "scope and cannot be evicted. Tip: Use the `ray memory` command "
            "to list active objects in the cluster."
        )
        return super().__str__() + hint
500
+
501
+
502
@PublicAPI
class OutOfDiskError(RayError):
    """Indicates that the local disk is full.

    This is raised if the attempt to store the object fails
    because both the object store and disk are full.
    """

    def __str__(self):
        # TODO(scv119): expose more disk usage information and link to a doc.
        # Bug fix: the previous concatenation rendered "(95% by default).Tip:"
        # with no separator; a space is inserted between the sentences.
        return super().__str__() + (
            "\n"
            "The object cannot be created because the local object store"
            " is full and the local disk's utilization is over capacity"
            " (95% by default). "
            "Tip: Use `df` on this node to check disk usage and "
            "`ray memory` to check object store memory usage."
        )
520
+
521
+
522
@PublicAPI
class OutOfMemoryError(RayError):
    """Indicates that the node is running out of memory and is close to full.

    This is raised if the node is low on memory and tasks or actors are being
    evicted to free up memory.
    """

    # TODO: (clarng) expose the error message string here and format it with proto
    def __init__(self, message):
        # Pre-formatted, human-readable description of the OOM condition.
        self.message = message

    def __str__(self):
        return self.message
536
+
537
+
538
@PublicAPI
class NodeDiedError(RayError):
    """Indicates that the node is either dead or unreachable."""

    # TODO: (clarng) expose the error message string here and format it with proto
    def __init__(self, message):
        # Pre-formatted description of the node failure.
        self.message = message

    def __str__(self):
        return self.message
548
+
549
+
550
@PublicAPI
class ObjectLostError(RayError):
    """Indicates that the object is lost from distributed memory, due to
    node failure or system error.

    Args:
        object_ref_hex: Hex ID of the object.
        owner_address: Serialized address of the worker that owns the object
            (parsed by some subclasses, e.g. OwnerDiedError).
        call_site: Call site where the ObjectRef was created; may be empty
            when RAY_record_ref_creation_sites is disabled.
    """

    def __init__(self, object_ref_hex, owner_address, call_site):
        #: Hex ID of the lost object.
        self.object_ref_hex = object_ref_hex
        #: Serialized owner address.
        self.owner_address = owner_address
        # Replace the internal delimiter so the recorded call stack renders
        # one frame per (indented) line.
        self.call_site = call_site.replace(
            ray_constants.CALL_STACK_LINE_DELIMITER, "\n  "
        )

    def _base_str(self):
        """Return the shared message prefix used by this class and subclasses."""
        msg = f"Failed to retrieve object {self.object_ref_hex}. "
        if self.call_site:
            msg += f"The ObjectRef was created at: {self.call_site}"
        else:
            # Tell the user how to enable call-site recording for next time.
            msg += (
                "To see information about where this ObjectRef was created "
                "in Python, set the environment variable "
                "RAY_record_ref_creation_sites=1 during `ray start` and "
                "`ray.init()`."
            )
        return msg

    def __str__(self):
        return (
            self._base_str()
            + "\n\n"
            + (
                f"All copies of {self.object_ref_hex} have been lost due to node "
                "failure. Check cluster logs (`/tmp/ray/session_latest/logs`) for "
                "more information about the failure."
            )
        )
589
+
590
+
591
@PublicAPI
class ObjectFetchTimedOutError(ObjectLostError):
    """Indicates that an object fetch timed out.

    Args:
        object_ref_hex: Hex ID of the object.
    """

    def __str__(self):
        detail = (
            f"Fetch for object {self.object_ref_hex} timed out because no "
            "locations were found for the object. This may indicate a "
            "system-level bug."
        )
        return self._base_str() + "\n\n" + detail
609
+
610
+
611
@DeveloperAPI
class RpcError(RayError):
    """Indicates an error in the underlying RPC system."""

    def __init__(self, message, rpc_code=None):
        # Human-readable description of the RPC failure.
        self.message = message
        # Optional status code reported by the RPC layer.
        self.rpc_code = rpc_code

    def __str__(self):
        return self.message
621
+
622
+
623
@DeveloperAPI
class ReferenceCountingAssertionError(ObjectLostError, AssertionError):
    """Indicates that an object has been deleted while there was still a
    reference to it.

    Args:
        object_ref_hex: Hex ID of the object.
    """

    def __str__(self):
        detail = (
            "The object has already been deleted by the reference counting "
            "protocol. This should not happen."
        )
        return self._base_str() + "\n\n" + detail
641
+
642
+
643
@DeveloperAPI
class ObjectFreedError(ObjectLostError):
    """Indicates that an object was manually freed by the application.

    Args:
        object_ref_hex: Hex ID of the object.
    """

    def __str__(self):
        detail = (
            "The object was manually freed using the internal `free` call. "
            "Please ensure that `free` is only called once the object is no "
            "longer needed."
        )
        return self._base_str() + "\n\n" + detail
661
+
662
+
663
@PublicAPI
class OwnerDiedError(ObjectLostError):
    """Indicates that the owner of the object has died while there is still a
    reference to the object.

    Args:
        object_ref_hex: Hex ID of the object.
    """

    def __str__(self):
        # Default log location, used when the owner address is missing or
        # cannot be parsed.
        log_loc = "`/tmp/ray/session_latest/logs`"
        if self.owner_address:
            try:
                # owner_address is a serialized protobuf Address; parse it to
                # point the user at the owner worker's log files and node IP.
                addr = Address()
                addr.ParseFromString(self.owner_address)
                ip_addr = addr.ip_address
                worker_id = WorkerID(addr.worker_id)
                log_loc = (
                    f"`/tmp/ray/session_latest/logs/*{worker_id.hex()}*`"
                    f" at IP address {ip_addr}"
                )
            except Exception:
                # Catch all to make sure we always at least print the default
                # message.
                pass

        return (
            self._base_str()
            + "\n\n"
            + (
                "The object's owner has exited. This is the Python "
                "worker that first created the ObjectRef via `.remote()` or "
                "`ray.put()`. "
                f"Check cluster logs ({log_loc}) for more "
                "information about the Python worker failure."
            )
        )
700
+
701
+
702
@PublicAPI
class ObjectReconstructionFailedError(ObjectLostError):
    """Indicates that the object cannot be reconstructed.

    Args:
        object_ref_hex: Hex ID of the object.
    """

    def __str__(self):
        detail = (
            "The object cannot be reconstructed "
            "because it was created by an actor, ray.put() call, or its "
            "ObjectRef was created by a different worker."
        )
        return self._base_str() + "\n\n" + detail
720
+
721
+
722
@PublicAPI
class ObjectReconstructionFailedMaxAttemptsExceededError(ObjectLostError):
    """Indicates that the object cannot be reconstructed because the maximum
    number of task retries has been exceeded.

    Args:
        object_ref_hex: Hex ID of the object.
    """

    def __str__(self):
        detail = (
            "The object cannot be reconstructed "
            "because the maximum number of task retries has been exceeded. "
            "To prevent this error, set "
            "`@ray.remote(max_retries=<num retries>)` (default 3)."
        )
        return self._base_str() + "\n\n" + detail
742
+
743
+
744
@PublicAPI
class ObjectReconstructionFailedLineageEvictedError(ObjectLostError):
    """Indicates that the object cannot be reconstructed because its lineage
    was evicted due to memory pressure.

    Args:
        object_ref_hex: Hex ID of the object.
    """

    def __str__(self):
        detail = (
            "The object cannot be reconstructed because its lineage has been "
            "evicted to reduce memory pressure. "
            "To prevent this error, set the environment variable "
            "RAY_max_lineage_bytes=<bytes> (default 1GB) during `ray start`."
        )
        return self._base_str() + "\n\n" + detail
764
+
765
+
766
@PublicAPI
class GetTimeoutError(RayError, TimeoutError):
    """Indicates that a call to the worker timed out."""
771
+
772
+
773
@PublicAPI
class PlasmaObjectNotAvailable(RayError):
    """Raised when an object was not available within the given timeout."""
778
+
779
+
780
@PublicAPI
class AsyncioActorExit(RayError):
    """Raised when an asyncio actor intentionally exits via exit_actor()."""
785
+
786
+
787
@PublicAPI
class RuntimeEnvSetupError(RayError):
    """Raised when a runtime environment fails to be set up.

    Args:
        error_message: The error message that explains
            why runtime env setup has failed.
    """

    def __init__(self, error_message: Optional[str] = None):
        #: Optional detail appended to the generic failure message.
        self.error_message = error_message

    def __str__(self):
        msgs = ["Failed to set up runtime environment."]
        if self.error_message:
            # Append the specific cause on its own line, when available.
            msgs.append(self.error_message)
        return "\n".join(msgs)
804
+
805
+
806
@PublicAPI
class TaskPlacementGroupRemoved(RayError):
    """Raised when the corresponding placement group was removed."""

    def __str__(self) -> str:
        return "The placement group corresponding to this task has been removed."
812
+
813
+
814
@PublicAPI
class ActorPlacementGroupRemoved(RayError):
    """Raised when the corresponding placement group was removed."""

    def __str__(self) -> str:
        return "The placement group corresponding to this Actor has been removed."
820
+
821
+
822
@PublicAPI
class PendingCallsLimitExceeded(RayError):
    """Raised when the pending actor calls exceed the `max_pending_calls` option.

    This usually means the caller is submitting calls to the callee faster
    than they can be processed.
    """
831
+
832
+
833
@PublicAPI
class TaskUnschedulableError(RayError):
    """Raised when the task cannot be scheduled.

    One example is that the node specified through
    NodeAffinitySchedulingStrategy is dead.
    """

    def __init__(self, error_message: str):
        # Explanation supplied by the scheduler.
        self.error_message = error_message

    def __str__(self) -> str:
        return f"The task is not schedulable: {self.error_message}"
846
+
847
+
848
@PublicAPI
class ActorUnschedulableError(RayError):
    """Raised when the actor cannot be scheduled.

    One example is that the node specified through
    NodeAffinitySchedulingStrategy is dead.
    """

    def __init__(self, error_message: str):
        # Explanation supplied by the scheduler.
        self.error_message = error_message

    def __str__(self) -> str:
        return f"The actor is not schedulable: {self.error_message}"
861
+
862
+
863
@DeveloperAPI
class ObjectRefStreamEndOfStreamError(RayError):
    """Raised by streaming generator tasks when there are no more ObjectRefs to
    read.
    """
870
+
871
+
872
@DeveloperAPI
class OufOfBandObjectRefSerializationException(RayError):
    """Raised when a `ray.ObjectRef` is out-of-band serialized by
    `ray.cloudpickle`. It is an anti-pattern.
    """

    # NOTE(review): the class name misspells "Out" as "Ouf". It is kept
    # as-is because renaming would break the public exception surface.
    pass
879
+
880
+
881
@PublicAPI(stability="alpha")
class RayChannelError(RaySystemError):
    """Indicates that Ray encountered a system error related
    to ray.experimental.channel.
    """
888
+
889
+
890
@PublicAPI(stability="alpha")
class RayChannelTimeoutError(RayChannelError, TimeoutError):
    """Raised when the Compiled Graph channel operation times out."""
895
+
896
+
897
@PublicAPI(stability="alpha")
class RayCgraphCapacityExceeded(RaySystemError):
    """Raised when the Compiled Graph channel's buffer is at max capacity."""
902
+
903
+
904
# Canonical list of Ray exception types defined by this module.
# NOTE(review): not every class defined above is listed (e.g. NodeDiedError,
# OutOfDiskError, RpcError are absent) — confirm at the use sites whether
# that omission is intentional.
RAY_EXCEPTION_TYPES = [
    PlasmaObjectNotAvailable,
    RayError,
    RayTaskError,
    WorkerCrashedError,
    RayActorError,
    ObjectStoreFullError,
    ObjectLostError,
    ObjectFetchTimedOutError,
    ReferenceCountingAssertionError,
    ObjectReconstructionFailedError,
    ObjectReconstructionFailedMaxAttemptsExceededError,
    ObjectReconstructionFailedLineageEvictedError,
    OwnerDiedError,
    GetTimeoutError,
    AsyncioActorExit,
    RuntimeEnvSetupError,
    TaskPlacementGroupRemoved,
    ActorPlacementGroupRemoved,
    PendingCallsLimitExceeded,
    LocalRayletDiedError,
    TaskUnschedulableError,
    ActorDiedError,
    ActorUnschedulableError,
    ActorUnavailableError,
    RayChannelError,
    RayChannelTimeoutError,
    OufOfBandObjectRefSerializationException,
    RayCgraphCapacityExceeded,
]
.venv/lib/python3.11/site-packages/ray/job_config.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
3
+
4
+ import ray.cloudpickle as pickle
5
+ from ray._private.ray_logging.logging_config import LoggingConfig
6
+ from ray.util.annotations import PublicAPI
7
+
8
+ if TYPE_CHECKING:
9
+ from ray.runtime_env import RuntimeEnv
10
+
11
+
12
@PublicAPI
class JobConfig:
    """A class used to store the configurations of a job.

    Examples:
        .. testcode::
            :hide:

            import ray
            ray.shutdown()

        .. testcode::

            import ray
            from ray.job_config import JobConfig

            ray.init(job_config=JobConfig(default_actor_lifetime="non_detached"))

    Args:
        jvm_options: The jvm options for java workers of the job.
        code_search_path: A list of directories or jar files that
            specify the search path for user code. This will be used as
            `CLASSPATH` in Java and `PYTHONPATH` in Python.
            See :ref:`Ray cross-language programming <cross_language>` for more details.
        runtime_env: A :ref:`runtime environment <runtime-environments>` dictionary.
        metadata: An opaque metadata dictionary.
        ray_namespace: A :ref:`namespace <namespaces-guide>`
            is a logical grouping of jobs and named actors.
        default_actor_lifetime: The default value of actor lifetime,
            can be "detached" or "non_detached".
            See :ref:`actor lifetimes <actor-lifetimes>` for more details.
    """

    def __init__(
        self,
        jvm_options: Optional[List[str]] = None,
        code_search_path: Optional[List[str]] = None,
        runtime_env: Optional[dict] = None,
        _client_job: bool = False,
        metadata: Optional[dict] = None,
        ray_namespace: Optional[str] = None,
        default_actor_lifetime: str = "non_detached",
        _py_driver_sys_path: Optional[List[str]] = None,
    ):
        #: The jvm options for java workers of the job.
        self.jvm_options = jvm_options or []
        #: A list of directories or jar files that
        #: specify the search path for user code.
        self.code_search_path = code_search_path or []
        # Passing a plain string here is a hard-to-diagnose mistake, so fail
        # fast with an explicit assertion.
        assert isinstance(self.code_search_path, (list, tuple)), (
            f"The type of code search path is incorrect: " f"{type(code_search_path)}"
        )
        # Whether this config belongs to a Ray Client job.
        self._client_job = _client_job
        #: An opaque metadata dictionary.
        self.metadata = metadata or {}
        #: A namespace is a logical grouping of jobs and named actors.
        self.ray_namespace = ray_namespace
        # These setters also initialize self.runtime_env,
        # self._default_actor_lifetime, and the protobuf cache.
        self.set_runtime_env(runtime_env)
        self.set_default_actor_lifetime(default_actor_lifetime)
        # A list of directories that specify the search path for python workers.
        self._py_driver_sys_path = _py_driver_sys_path or []
        # Python logging configurations that will be passed to Ray tasks/actors.
        self.py_logging_config = None

    def set_metadata(self, key: str, value: str) -> None:
        """Add key-value pair to the metadata dictionary.

        If the key already exists, the value is overwritten to the new value.

        Examples:
            .. testcode::

                import ray
                from ray.job_config import JobConfig

                job_config = JobConfig()
                job_config.set_metadata("submitter", "foo")

        Args:
            key: The key of the metadata.
            value: The value of the metadata.
        """
        # NOTE(review): this does not reset self._cached_pb, so metadata set
        # after _get_proto_job_config() has been called is not reflected in
        # the serialized config — confirm callers only set metadata up front.
        self.metadata[key] = value

    def _serialize(self) -> str:
        """Serialize the struct into protobuf string"""
        # NOTE(review): SerializeToString() returns bytes, not str — the
        # annotation looks inaccurate; confirm before changing it.
        return self._get_proto_job_config().SerializeToString()

    def set_runtime_env(
        self,
        runtime_env: Optional[Union[Dict[str, Any], "RuntimeEnv"]],
        validate: bool = False,
    ) -> None:
        """Modify the runtime_env of the JobConfig.

        We don't validate the runtime_env by default here because it may go
        through some translation before actually being passed to C++ (e.g.,
        working_dir translated from a local directory to a URI).

        Args:
            runtime_env: A :ref:`runtime environment <runtime-environments>` dictionary.
            validate: Whether to validate the runtime env.
        """
        self.runtime_env = runtime_env if runtime_env is not None else {}
        if validate:
            self.runtime_env = self._validate_runtime_env()
        # Invalidate the cached protobuf so the new runtime_env is picked up.
        self._cached_pb = None

    def set_py_logging_config(
        self,
        logging_config: Optional[LoggingConfig] = None,
    ):
        """Set the logging configuration for the job.

        The logging configuration will be applied to the root loggers of
        all Ray task and actor processes that belong to this job.

        Args:
            logging_config: The logging configuration to set.
        """
        # NOTE(review): like set_metadata, this does not reset
        # self._cached_pb — confirm this is always called before the config
        # is first serialized.
        self.py_logging_config = logging_config

    def set_ray_namespace(self, ray_namespace: str) -> None:
        """Set Ray :ref:`namespace <namespaces-guide>`.

        Args:
            ray_namespace: The namespace to set.
        """

        # Only invalidate the protobuf cache when the namespace changes.
        if ray_namespace != self.ray_namespace:
            self.ray_namespace = ray_namespace
            self._cached_pb = None

    def set_default_actor_lifetime(self, default_actor_lifetime: str) -> None:
        """Set the default actor lifetime, which can be "detached" or "non_detached".

        See :ref:`actor lifetimes <actor-lifetimes>` for more details.

        Args:
            default_actor_lifetime: The default actor lifetime to set.

        Raises:
            ValueError: If the value is neither "detached" nor "non_detached".
        """
        # Imported lazily; see the circular-import notes in
        # _validate_runtime_env / _get_proto_job_config.
        import ray.core.generated.common_pb2 as common_pb2

        if default_actor_lifetime == "detached":
            self._default_actor_lifetime = common_pb2.JobConfig.ActorLifetime.DETACHED
        elif default_actor_lifetime == "non_detached":
            self._default_actor_lifetime = (
                common_pb2.JobConfig.ActorLifetime.NON_DETACHED
            )
        else:
            raise ValueError(
                "Default actor lifetime must be one of `detached`, `non_detached`"
            )

    def _validate_runtime_env(self):
        """Return self.runtime_env parsed into a RuntimeEnv instance."""
        # TODO(edoakes): this is really unfortunate, but JobConfig is imported
        # all over the place so this causes circular imports. We should remove
        # this dependency and pass in a validated runtime_env instead.
        from ray.runtime_env import RuntimeEnv

        if isinstance(self.runtime_env, RuntimeEnv):
            return self.runtime_env
        return RuntimeEnv(**self.runtime_env)

    def _get_proto_job_config(self):
        """Return the protobuf structure of JobConfig."""
        # TODO(edoakes): this is really unfortunate, but JobConfig is imported
        # all over the place so this causes circular imports. We should remove
        # this dependency and pass in a validated runtime_env instead.
        import ray.core.generated.common_pb2 as common_pb2
        from ray._private.utils import get_runtime_env_info

        # The protobuf is built once and cached; set_runtime_env and
        # set_ray_namespace reset the cache.
        if self._cached_pb is None:
            pb = common_pb2.JobConfig()
            if self.ray_namespace is None:
                # No namespace was specified: generate an anonymous one.
                pb.ray_namespace = str(uuid.uuid4())
            else:
                pb.ray_namespace = self.ray_namespace
            pb.jvm_options.extend(self.jvm_options)
            pb.code_search_path.extend(self.code_search_path)
            pb.py_driver_sys_path.extend(self._py_driver_sys_path)
            for k, v in self.metadata.items():
                pb.metadata[k] = v

            parsed_env = self._validate_runtime_env()
            pb.runtime_env_info.CopyFrom(
                get_runtime_env_info(
                    parsed_env,
                    is_job_runtime_env=True,
                    serialize=False,
                )
            )

            if self._default_actor_lifetime is not None:
                pb.default_actor_lifetime = self._default_actor_lifetime
            if self.py_logging_config:
                # The logging config is shipped to workers pickled.
                pb.serialized_py_logging_config = pickle.dumps(self.py_logging_config)
            self._cached_pb = pb

        return self._cached_pb

    def _runtime_env_has_working_dir(self):
        """Return whether the parsed runtime env specifies a working_dir."""
        return self._validate_runtime_env().has_working_dir()

    def _get_serialized_runtime_env(self) -> str:
        """Return the JSON-serialized parsed runtime env dict"""
        return self._validate_runtime_env().serialize()

    def _get_proto_runtime_env_config(self) -> str:
        """Return the JSON-serialized parsed runtime env info"""
        # NOTE(review): this returns the runtime_env_config protobuf field,
        # which does not look like a JSON string — confirm the docstring and
        # annotation against the callers.
        return self._get_proto_job_config().runtime_env_info.runtime_env_config

    @classmethod
    def from_json(cls, job_config_json):
        """Generates a JobConfig object from json.

        Examples:
            .. testcode::

                from ray.job_config import JobConfig

                job_config = JobConfig.from_json(
                    {"runtime_env": {"working_dir": "uri://abc"}})

        Args:
            job_config_json: The job config json dictionary.
        """
        return cls(
            jvm_options=job_config_json.get("jvm_options", None),
            code_search_path=job_config_json.get("code_search_path", None),
            runtime_env=job_config_json.get("runtime_env", None),
            metadata=job_config_json.get("metadata", None),
            ray_namespace=job_config_json.get("ray_namespace", None),
            _client_job=job_config_json.get("client_job", False),
            _py_driver_sys_path=job_config_json.get("py_driver_sys_path", None),
        )
.venv/lib/python3.11/site-packages/ray/nightly-wheels.yaml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ linux:
2
+ "3.8": https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl
3
+ "3.7": https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp37-cp37m-manylinux2014_x86_64.whl
4
+
5
+ darwin:
6
+ "3.8": https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp38-cp38-macosx_10_15_x86_64.whl
7
+ "3.7": https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp37-cp37m-macosx_10_15_intel.whl
8
+
9
+ win32:
10
+ "3.8": https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp38-cp38-win_amd64.whl
11
+ "3.7": https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp37-cp37m-win_amd64.whl
.venv/lib/python3.11/site-packages/ray/py.typed ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/remote_function.py ADDED
@@ -0,0 +1,515 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import logging
3
+ import os
4
+ import uuid
5
+ from functools import wraps
6
+ from threading import Lock
7
+ from typing import Optional
8
+
9
+ import ray._private.signature
10
+ from ray import Language, cross_language
11
+ from ray._private import ray_option_utils
12
+ from ray._private.auto_init_hook import wrap_auto_init
13
+ from ray._private.client_mode_hook import (
14
+ client_mode_convert_function,
15
+ client_mode_should_convert,
16
+ )
17
+ from ray._private.ray_option_utils import _warn_if_using_deprecated_placement_group
18
+ from ray._private.serialization import pickle_dumps
19
+ from ray._private.utils import get_runtime_env_info, parse_runtime_env
20
+ from ray._raylet import (
21
+ STREAMING_GENERATOR_RETURN,
22
+ ObjectRefGenerator,
23
+ PythonFunctionDescriptor,
24
+ )
25
+ from ray.util.annotations import DeveloperAPI, PublicAPI
26
+ from ray.util.placement_group import _configure_placement_group_based_on_context
27
+ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
28
+ from ray.util.tracing.tracing_helper import (
29
+ _inject_tracing_into_function,
30
+ _tracing_task_invocation,
31
+ )
32
+
33
# Module-level logger for the remote-function machinery.
logger = logging.getLogger(__name__)


# Hook to call with (fn, resources, strategy) on each local task submission.
_task_launch_hook = None
38
+
39
+
40
@PublicAPI
class RemoteFunction:
    """A remote function.

    This is a decorated function. It can be used to spawn tasks.

    Attributes:
        _language: The target language.
        _function: The original function.
        _function_descriptor: The function descriptor. This is not defined
            until the remote function is first invoked because that is when the
            function is pickled, and the pickled function is used to compute
            the function descriptor.
        _function_name: The module and function name.
        _num_cpus: The default number of CPUs to use for invocations of this
            remote function.
        _num_gpus: The default number of GPUs to use for invocations of this
            remote function.
        _memory: The heap memory request in bytes for this task/actor,
            rounded down to the nearest integer.
        _resources: The default custom resource requirements for invocations of
            this remote function.
        _num_returns: The default number of return values for invocations
            of this remote function.
        _max_calls: The number of times a worker can execute this function
            before exiting.
        _max_retries: The number of times this task may be retried
            on worker failure.
        _retry_exceptions: Whether application-level errors should be retried.
            This can be a boolean or a list/tuple of exceptions that should be retried.
        _runtime_env: The runtime environment for this task.
        _decorator: An optional decorator that should be applied to the remote
            function invocation (as opposed to the function execution) before
            invoking the function. The decorator must return a function that
            takes in two arguments ("args" and "kwargs"). In most cases, it
            should call the function that was passed into the decorator and
            return the resulting ObjectRefs. For an example, see
            "test_decorated_function" in "python/ray/tests/test_basic.py".
        _function_signature: The function signature.
        _last_export_cluster_and_job: A pair of the last exported cluster
            and job to help us to know whether this function was exported.
            This is an imperfect mechanism used to determine if we need to
            export the remote function again. It is imperfect in the sense that
            the actor class definition could be exported multiple times by
            different workers.
        _scheduling_strategy: Strategy about how to schedule
            this remote function.
    """

    def __init__(
        self,
        language,
        function,
        function_descriptor,
        task_options,
    ):
        """Wrap ``function`` as a remote task definition.

        Args:
            language: Target language of the task (compared against
                ``Language.PYTHON`` to detect cross-language functions).
            function: The user function being decorated. ``async def``
                functions are rejected.
            function_descriptor: Descriptor identifying the function.
            task_options: Default task options captured from ``@ray.remote``.

        Raises:
            ValueError: If ``function`` is a coroutine function.
        """
        if inspect.iscoroutinefunction(function):
            raise ValueError(
                "'async def' should not be used for remote tasks. You can wrap the "
                "async function with `asyncio.run(f())`. See more at:"
                "https://docs.ray.io/en/latest/ray-core/actors/async_api.html "
            )
        self._default_options = task_options

        # When gpu is used, set the task non-recyclable by default.
        # https://github.com/ray-project/ray/issues/29624 for more context.
        # Note: Ray task worker process is not being reused when nsight
        # profiler is running, as nsight generate report once the process exit.
        num_gpus = self._default_options.get("num_gpus") or 0
        if (
            num_gpus > 0 and self._default_options.get("max_calls", None) is None
        ) or "nsight" in (self._default_options.get("runtime_env") or {}):
            self._default_options["max_calls"] = 1

        # TODO(suquark): This is a workaround for class attributes of options.
        # They are being used in some other places, mostly tests. Need cleanup later.
        # E.g., actors uses "__ray_metadata__" to collect options, we can so something
        # similar for remote functions.
        for k, v in ray_option_utils.task_options.items():
            setattr(self, "_" + k, task_options.get(k, v.default_value))
        # `_runtime_env` was just set by the loop above; normalize it.
        self._runtime_env = parse_runtime_env(self._runtime_env)
        if "runtime_env" in self._default_options:
            self._default_options["runtime_env"] = self._runtime_env

        # Pre-calculate runtime env info, to avoid re-calculation at `remote`
        # invocation. When `remote` call has specified extra `option` field,
        # runtime env will be overwritten and re-serialized.
        #
        # Caveat: To support dynamic runtime envs in
        # `func.option(runtime_env={...}).remote()`, we recalculate the serialized
        # runtime env info in the `option` call. But it's acceptable since
        # pre-calculation here only happens once at `RemoteFunction` initialization.
        self._serialized_base_runtime_env_info = ""
        if self._runtime_env:
            self._serialized_base_runtime_env_info = get_runtime_env_info(
                self._runtime_env,
                is_job_runtime_env=False,
                serialize=True,
            )

        self._language = language
        self._is_generator = inspect.isgeneratorfunction(function)
        self._function = function
        self._function_signature = None
        # Guards trace injection to enforce exactly once semantics
        self._inject_lock = Lock()
        self._function_name = function.__module__ + "." + function.__name__
        self._function_descriptor = function_descriptor
        self._is_cross_language = language != Language.PYTHON
        self._decorator = getattr(function, "__ray_invocation_decorator__", None)
        self._last_export_cluster_and_job = None
        # Distinguishes re-definitions of the same function when computing
        # the function descriptor.
        self._uuid = uuid.uuid4()

        # Override task.remote's signature and docstring
        @wraps(function)
        def _remote_proxy(*args, **kwargs):
            return self._remote(
                serialized_runtime_env_info=self._serialized_base_runtime_env_info,
                args=args,
                kwargs=kwargs,
                **self._default_options,
            )

        self.remote = _remote_proxy
164
+
165
+ def __call__(self, *args, **kwargs):
166
+ raise TypeError(
167
+ "Remote functions cannot be called directly. Instead "
168
+ f"of running '{self._function_name}()', "
169
+ f"try '{self._function_name}.remote()'."
170
+ )
171
+
172
+ # Lock is not picklable
173
+ def __getstate__(self):
174
+ attrs = self.__dict__.copy()
175
+ del attrs["_inject_lock"]
176
+ return attrs
177
+
178
+ def __setstate__(self, state):
179
+ self.__dict__.update(state)
180
+ self.__dict__["_inject_lock"] = Lock()
181
+
182
    def options(self, **task_options):
        """Configures and overrides the task invocation parameters.

        The arguments are the same as those that can be passed to :obj:`ray.remote`.
        Overriding `max_calls` is not supported.

        Args:
            num_returns: It specifies the number of object refs returned by
                the remote function invocation.
            num_cpus: The quantity of CPU cores to reserve
                for this task or for the lifetime of the actor.
            num_gpus: The quantity of GPUs to reserve
                for this task or for the lifetime of the actor.
            resources (Dict[str, float]): The quantity of various custom resources
                to reserve for this task or for the lifetime of the actor.
                This is a dictionary mapping strings (resource names) to floats.
            accelerator_type: If specified, requires that the task or actor run
                on a node with the specified type of accelerator.
                See :ref:`accelerator types <accelerator_types>`.
            memory: The heap memory request in bytes for this task/actor,
                rounded down to the nearest integer.
            object_store_memory: The object store memory request for actors only.
            max_calls: This specifies the
                maximum number of times that a given worker can execute
                the given remote function before it must exit
                (this can be used to address memory leaks in third-party
                libraries or to reclaim resources that cannot easily be
                released, e.g., GPU memory that was acquired by TensorFlow).
                By default this is infinite for CPU tasks and 1 for GPU tasks
                (to force GPU tasks to release resources after finishing).
            max_retries: This specifies the maximum number of times that the remote
                function should be rerun when the worker process executing it
                crashes unexpectedly. The minimum valid value is 0,
                the default is 3, and a value of -1 indicates
                infinite retries.
            runtime_env (Dict[str, Any]): Specifies the runtime environment for
                this actor or task and its children. See
                :ref:`runtime-environments` for detailed documentation.
            retry_exceptions: This specifies whether application-level errors
                should be retried up to max_retries times.
            scheduling_strategy: Strategy about how to
                schedule a remote function or actor. Possible values are
                None: ray will figure out the scheduling strategy to use, it
                will either be the PlacementGroupSchedulingStrategy using parent's
                placement group if parent has one and has
                placement_group_capture_child_tasks set to true,
                or "DEFAULT";
                "DEFAULT": default hybrid scheduling;
                "SPREAD": best effort spread scheduling;
                `PlacementGroupSchedulingStrategy`:
                placement group based scheduling;
                `NodeAffinitySchedulingStrategy`:
                node id based affinity scheduling.
            enable_task_events: This specifies whether to enable task events for this
                task. If set to True, task events such as (task running, finished)
                are emitted, and available to Ray Dashboard and State API.
                See :ref:`state-api-overview-ref` for more details.
            _metadata: Extended options for Ray libraries. For example,
                _metadata={"workflows.io/options": <workflow options>} for
                Ray workflows.
            _labels: The key-value labels of a task.

        Examples:

            .. code-block:: python

                @ray.remote(num_gpus=1, max_calls=1, num_returns=2)
                def f():
                    return 1, 2
                # Task g will require 2 gpus instead of 1.
                g = f.options(num_gpus=2)
        """

        # Captured by the FuncWrapper closure below ("self" inside FuncWrapper
        # refers to the wrapper, not to this RemoteFunction).
        func_cls = self

        # override original options
        default_options = self._default_options.copy()
        # max_calls could not be used in ".options()", we should remove it before
        # merging options from '@ray.remote'.
        default_options.pop("max_calls", None)
        updated_options = ray_option_utils.update_options(default_options, task_options)
        ray_option_utils.validate_task_options(updated_options, in_options=True)

        # Only update runtime_env and re-calculate serialized runtime env info when
        # ".options()" specifies new runtime_env.
        serialized_runtime_env_info = self._serialized_base_runtime_env_info
        if "runtime_env" in task_options:
            updated_options["runtime_env"] = parse_runtime_env(
                updated_options["runtime_env"]
            )
            # Re-calculate runtime env info based on updated runtime env.
            if updated_options["runtime_env"]:
                serialized_runtime_env_info = get_runtime_env_info(
                    updated_options["runtime_env"],
                    is_job_runtime_env=False,
                    serialize=True,
                )

        # Lightweight proxy: forwards .remote()/.bind() to the original
        # RemoteFunction with the merged options applied.
        class FuncWrapper:
            def remote(self, *args, **kwargs):
                return func_cls._remote(
                    args=args,
                    kwargs=kwargs,
                    serialized_runtime_env_info=serialized_runtime_env_info,
                    **updated_options,
                )

            @DeveloperAPI
            def bind(self, *args, **kwargs):
                """
                For Ray DAG building that creates static graph from decorated
                class or functions.
                """
                from ray.dag.function_node import FunctionNode

                return FunctionNode(func_cls._function, args, kwargs, updated_options)

        return FuncWrapper()
300
+
301
    @wrap_auto_init
    @_tracing_task_invocation
    def _remote(
        self,
        args=None,
        kwargs=None,
        serialized_runtime_env_info: Optional[str] = None,
        **task_options,
    ):
        """Submit the remote function for execution.

        Args:
            args: Positional arguments for the task (defaults to []).
            kwargs: Keyword arguments for the task (defaults to {}).
            serialized_runtime_env_info: Pre-serialized runtime env info;
                when falsy, "{}" is submitted instead.
            **task_options: Per-invocation task options; missing keys are
                filled from ``ray_option_utils.task_options`` defaults.

        Returns:
            A single ObjectRef, a list of ObjectRefs (more than one return
            value), or an ObjectRefGenerator for streaming-generator tasks —
            unless ``self._decorator`` wraps the invocation and changes this.
        """
        # We pop the "max_calls" coming from "@ray.remote" here. We no longer need
        # it in "_remote()".
        task_options.pop("max_calls", None)
        if client_mode_should_convert():
            return client_mode_convert_function(self, args, kwargs, **task_options)

        worker = ray._private.worker.global_worker
        worker.check_connected()

        # We cannot do this when the function is first defined, because we need
        # ray.init() to have been called when this executes
        with self._inject_lock:
            if self._function_signature is None:
                self._function = _inject_tracing_into_function(self._function)
                self._function_signature = ray._private.signature.extract_signature(
                    self._function
                )

        # If this function was not exported in this cluster and job, we need to
        # export this function again, because the current GCS doesn't have it.
        if (
            not self._is_cross_language
            and self._last_export_cluster_and_job != worker.current_cluster_and_job
        ):
            self._function_descriptor = PythonFunctionDescriptor.from_function(
                self._function, self._uuid
            )
            # There is an interesting question here. If the remote function is
            # used by a subsequent driver (in the same script), should the
            # second driver pickle the function again? If yes, then the remote
            # function definition can differ in the second driver (e.g., if
            # variables in its closure have changed). We probably want the
            # behavior of the remote function in the second driver to be
            # independent of whether or not the function was invoked by the
            # first driver. This is an argument for repickling the function,
            # which we do here.
            self._pickled_function = pickle_dumps(
                self._function,
                f"Could not serialize the function {self._function_descriptor.repr}",
            )

            self._last_export_cluster_and_job = worker.current_cluster_and_job
            worker.function_actor_manager.export(self)

        kwargs = {} if kwargs is None else kwargs
        args = [] if args is None else args

        # fill task required options
        for k, v in ray_option_utils.task_options.items():
            if k == "max_retries":
                # TODO(swang): We need to override max_retries here because the default
                # value gets set at Ray import time. Ideally, we should allow setting
                # default values from env vars for other options too.
                # NOTE(review): this mutates the shared Option object in
                # ray_option_utils.task_options on every submission —
                # presumably benign since the env var is stable; confirm.
                v.default_value = os.environ.get(
                    "RAY_TASK_MAX_RETRIES", v.default_value
                )
                v.default_value = int(v.default_value)
            task_options[k] = task_options.get(k, v.default_value)
        # "max_calls" already takes effects and should not apply again.
        # Remove the default value here.
        task_options.pop("max_calls", None)

        # TODO(suquark): cleanup these fields
        name = task_options["name"]
        placement_group = task_options["placement_group"]
        placement_group_bundle_index = task_options["placement_group_bundle_index"]
        placement_group_capture_child_tasks = task_options[
            "placement_group_capture_child_tasks"
        ]
        scheduling_strategy = task_options["scheduling_strategy"]

        # Generator functions default to streaming returns; plain functions
        # default to a single return value.
        num_returns = task_options["num_returns"]
        if num_returns is None:
            if self._is_generator:
                num_returns = "streaming"
            else:
                num_returns = 1

        if num_returns == "dynamic":
            num_returns = -1
        elif num_returns == "streaming":
            # TODO(sang): This is a temporary private API.
            # Remove it when we migrate to the streaming generator.
            num_returns = ray._raylet.STREAMING_GENERATOR_RETURN
        generator_backpressure_num_objects = task_options[
            "_generator_backpressure_num_objects"
        ]
        if generator_backpressure_num_objects is None:
            generator_backpressure_num_objects = -1

        max_retries = task_options["max_retries"]
        retry_exceptions = task_options["retry_exceptions"]
        # A list/tuple of exception types means "retry, but only these".
        if isinstance(retry_exceptions, (list, tuple)):
            retry_exception_allowlist = tuple(retry_exceptions)
            retry_exceptions = True
        else:
            retry_exception_allowlist = None

        if scheduling_strategy is None or not isinstance(
            scheduling_strategy, PlacementGroupSchedulingStrategy
        ):
            # 4 presumably is the warning's stacklevel — TODO confirm.
            _warn_if_using_deprecated_placement_group(task_options, 4)

        resources = ray._private.utils.resources_from_ray_options(task_options)

        if scheduling_strategy is None or isinstance(
            scheduling_strategy, PlacementGroupSchedulingStrategy
        ):
            if isinstance(scheduling_strategy, PlacementGroupSchedulingStrategy):
                placement_group = scheduling_strategy.placement_group
                placement_group_bundle_index = (
                    scheduling_strategy.placement_group_bundle_index
                )
                placement_group_capture_child_tasks = (
                    scheduling_strategy.placement_group_capture_child_tasks
                )

            if placement_group_capture_child_tasks is None:
                placement_group_capture_child_tasks = (
                    worker.should_capture_child_tasks_in_placement_group
                )
            placement_group = _configure_placement_group_based_on_context(
                placement_group_capture_child_tasks,
                placement_group_bundle_index,
                resources,
                {},  # no placement_resources for tasks
                self._function_descriptor.function_name,
                placement_group=placement_group,
            )
            if not placement_group.is_empty:
                scheduling_strategy = PlacementGroupSchedulingStrategy(
                    placement_group,
                    placement_group_bundle_index,
                    placement_group_capture_child_tasks,
                )
            else:
                scheduling_strategy = "DEFAULT"

        if _task_launch_hook:
            _task_launch_hook(self._function_descriptor, resources, scheduling_strategy)

        # Override enable_task_events to default for actor if not specified (i.e. None)
        enable_task_events = task_options.get("enable_task_events")
        labels = task_options.get("_labels")

        def invocation(args, kwargs):
            if self._is_cross_language:
                list_args = cross_language._format_args(worker, args, kwargs)
            elif not args and not kwargs and not self._function_signature:
                list_args = []
            else:
                list_args = ray._private.signature.flatten_args(
                    self._function_signature, args, kwargs
                )

            if worker.mode == ray._private.worker.LOCAL_MODE:
                assert (
                    not self._is_cross_language
                ), "Cross language remote function cannot be executed locally."
            object_refs = worker.core_worker.submit_task(
                self._language,
                self._function_descriptor,
                list_args,
                name if name is not None else "",
                num_returns,
                resources,
                max_retries,
                retry_exceptions,
                retry_exception_allowlist,
                scheduling_strategy,
                worker.debugger_breakpoint,
                serialized_runtime_env_info or "{}",
                generator_backpressure_num_objects,
                enable_task_events,
                labels,
            )
            # Reset worker's debug context from the last "remote" command
            # (which applies only to this .remote call).
            worker.debugger_breakpoint = b""
            if num_returns == STREAMING_GENERATOR_RETURN:
                # Streaming generator will return a single ref
                # that is for the generator task.
                assert len(object_refs) == 1
                generator_ref = object_refs[0]
                return ObjectRefGenerator(generator_ref, worker)
            if len(object_refs) == 1:
                return object_refs[0]
            elif len(object_refs) > 1:
                return object_refs

        if self._decorator is not None:
            invocation = self._decorator(invocation)

        return invocation(args, kwargs)
505
+
506
+ @DeveloperAPI
507
+ def bind(self, *args, **kwargs):
508
+ """
509
+ For Ray DAG building that creates static graph from decorated
510
+ class or functions.
511
+ """
512
+
513
+ from ray.dag.function_node import FunctionNode
514
+
515
+ return FunctionNode(self._function, args, kwargs, self._default_options)
.venv/lib/python3.11/site-packages/ray/runtime_context.py ADDED
@@ -0,0 +1,564 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ import ray._private.worker
5
+ from ray._private.client_mode_hook import client_mode_hook
6
+ from ray._private.utils import parse_pg_formatted_resources_to_original
7
+ from ray._raylet import TaskID
8
+ from ray.runtime_env import RuntimeEnv
9
+ from ray.util.annotations import Deprecated, PublicAPI
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
@PublicAPI
class RuntimeContext(object):
    """A class used for getting runtime context.

    Exposes the IDs (job, node, task, actor, placement group) and related
    metadata of the current driver or worker process.
    """

    def __init__(self, worker):
        """Wrap the given worker handle.

        Args:
            worker: The underlying driver/worker object; must not be None.
        """
        assert worker is not None
        # All context queries delegate to this worker object.
        self.worker = worker
21
+
22
+ @Deprecated(
23
+ message="Use get_xxx_id() methods to get relevant ids instead", warning=True
24
+ )
25
+ def get(self) -> Dict[str, Any]:
26
+ """Get a dictionary of the current context.
27
+
28
+ Returns:
29
+ dict: Dictionary of the current context.
30
+ """
31
+ context = {
32
+ "job_id": self.job_id,
33
+ "node_id": self.node_id,
34
+ "namespace": self.namespace,
35
+ }
36
+ if self.worker.mode == ray._private.worker.WORKER_MODE:
37
+ if self.task_id is not None:
38
+ context["task_id"] = self.task_id
39
+ if self.actor_id is not None:
40
+ context["actor_id"] = self.actor_id
41
+
42
+ return context
43
+
44
+ @property
45
+ @Deprecated(message="Use get_job_id() instead", warning=True)
46
+ def job_id(self):
47
+ """Get current job ID for this worker or driver.
48
+
49
+ Job ID is the id of your Ray drivers that create tasks or actors.
50
+
51
+ Returns:
52
+ If called by a driver, this returns the job ID. If called in
53
+ a task, return the job ID of the associated driver.
54
+
55
+ """
56
+ job_id = self.worker.current_job_id
57
+ assert not job_id.is_nil()
58
+ return job_id
59
+
60
+ def get_job_id(self) -> str:
61
+ """Get current job ID for this worker or driver.
62
+
63
+ Job ID is the id of your Ray drivers that create tasks or actors.
64
+
65
+ Returns:
66
+ If called by a driver, this returns the job ID. If called in
67
+ a task, return the job ID of the associated driver. The
68
+ job ID will be hex format.
69
+
70
+ Raises:
71
+ AssertionError: If not called in a driver or worker. Generally,
72
+ this means that ray.init() was not called.
73
+ """
74
+ assert ray.is_initialized(), (
75
+ "Job ID is not available because " "Ray has not been initialized."
76
+ )
77
+ job_id = self.worker.current_job_id
78
+ return job_id.hex()
79
+
80
+ @property
81
+ @Deprecated(message="Use get_node_id() instead", warning=True)
82
+ def node_id(self):
83
+ """Get current node ID for this worker or driver.
84
+
85
+ Node ID is the id of a node that your driver, task, or actor runs.
86
+
87
+ Returns:
88
+ A node id for this worker or driver.
89
+ """
90
+ node_id = self.worker.current_node_id
91
+ assert not node_id.is_nil()
92
+ return node_id
93
+
94
+ def get_node_id(self) -> str:
95
+ """Get current node ID for this worker or driver.
96
+
97
+ Node ID is the id of a node that your driver, task, or actor runs.
98
+ The ID will be in hex format.
99
+
100
+ Returns:
101
+ A node id in hex format for this worker or driver.
102
+
103
+ Raises:
104
+ AssertionError: If not called in a driver or worker. Generally,
105
+ this means that ray.init() was not called.
106
+ """
107
+ assert ray.is_initialized(), (
108
+ "Node ID is not available because " "Ray has not been initialized."
109
+ )
110
+ node_id = self.worker.current_node_id
111
+ return node_id.hex()
112
+
113
+ def get_worker_id(self) -> str:
114
+ """Get current worker ID for this worker or driver process.
115
+
116
+ Returns:
117
+ A worker id in hex format for this worker or driver process.
118
+ """
119
+ assert (
120
+ ray.is_initialized()
121
+ ), "Worker ID is not available because Ray has not been initialized."
122
+ return self.worker.worker_id.hex()
123
+
124
    @property
    @Deprecated(message="Use get_task_id() instead", warning=True)
    def task_id(self):
        """Get current task ID for this worker.

        Task ID is the id of a Ray task.
        This shouldn't be used in a driver process.

        Example:

            .. testcode::

                import ray

                @ray.remote
                class Actor:
                    def ready(self):
                        return True

                @ray.remote
                def f():
                    return True

                # All the below code generates different task ids.
                # Task ids are available for actor creation.
                a = Actor.remote()
                # Task ids are available for actor tasks.
                a.ready.remote()
                # Task ids are available for normal tasks.
                f.remote()

        Returns:
            The current worker's task id. None if there's no task id.
        """
        # only worker mode has task_id
        # NOTE(review): the backslash continuation below embeds the next
        # line's leading spaces into the assertion message — presumably
        # unintended formatting; confirm before changing.
        assert (
            self.worker.mode == ray._private.worker.WORKER_MODE
        ), f"This method is only available when the process is a\
            worker. Current mode: {self.worker.mode}"

        task_id = self._get_current_task_id()
        return task_id if not task_id.is_nil() else None
166
+
167
    def get_task_id(self) -> Optional[str]:
        """Get current task ID for this worker.

        Task ID is the id of a Ray task. The ID will be in hex format.
        This shouldn't be used in a driver process.

        Example:

            .. testcode::

                import ray

                @ray.remote
                class Actor:
                    def get_task_id(self):
                        return ray.get_runtime_context().get_task_id()

                @ray.remote
                def get_task_id():
                    return ray.get_runtime_context().get_task_id()

                # All the below code generates different task ids.
                a = Actor.remote()
                # Task ids are available for actor tasks.
                print(ray.get(a.get_task_id.remote()))
                # Task ids are available for normal tasks.
                print(ray.get(get_task_id.remote()))

            .. testoutput::
                :options: +MOCK

                16310a0f0a45af5c2746a0e6efb235c0962896a201000000
                c2668a65bda616c1ffffffffffffffffffffffff01000000

        Returns:
            The current worker's task id in hex. None if there's no task id.
        """
        # only worker mode has task_id; warn (not raise) in other modes.
        if self.worker.mode != ray._private.worker.WORKER_MODE:
            logger.warning(
                "This method is only available when the process is a "
                f"worker. Current mode: {self.worker.mode}"
            )
            return None
        task_id = self._get_current_task_id()
        return task_id.hex() if not task_id.is_nil() else None
213
+
214
    def _get_current_task_id(self) -> TaskID:
        """Return the raw TaskID of the task this worker is currently running."""
        return self.worker.current_task_id
216
+
217
    def get_task_name(self) -> Optional[str]:
        """Get current task name for this worker.

        Task name by default is the task's function call string. It can also be
        specified in options when triggering a task.

        Example:

            .. testcode::

                import ray

                @ray.remote
                class Actor:
                    def get_task_name(self):
                        return ray.get_runtime_context().get_task_name()

                @ray.remote
                class AsyncActor:
                    async def get_task_name(self):
                        return ray.get_runtime_context().get_task_name()

                @ray.remote
                def get_task_name():
                    return ray.get_runtime_context().get_task_name()

                a = Actor.remote()
                b = AsyncActor.remote()
                # Task names are available for actor tasks.
                print(ray.get(a.get_task_name.remote()))
                # Task names are available for async actor tasks.
                print(ray.get(b.get_task_name.remote()))
                # Task names are available for normal tasks.
                # Get default task name
                print(ray.get(get_task_name.remote()))
                # Get specified task name
                print(ray.get(get_task_name.options(name="task_name").remote()))

            .. testoutput::
                :options: +MOCK

                Actor.get_task_name
                AsyncActor.get_task_name
                get_task_name
                task_name

        Returns:
            The current worker's task name, or None when not called from a
            worker process.
        """
        # only worker mode has task_name; warn (not raise) in other modes.
        if self.worker.mode != ray._private.worker.WORKER_MODE:
            logger.warning(
                "This method is only available when the process is a "
                f"worker. Current mode: {self.worker.mode}"
            )
            return None
        return self.worker.current_task_name
274
+
275
    def get_task_function_name(self) -> Optional[str]:
        """Get current task function name string for this worker.

        Example:

            .. testcode::

                import ray

                @ray.remote
                class Actor:
                    def get_task_function_name(self):
                        return ray.get_runtime_context().get_task_function_name()

                @ray.remote
                class AsyncActor:
                    async def get_task_function_name(self):
                        return ray.get_runtime_context().get_task_function_name()

                @ray.remote
                def get_task_function_name():
                    return ray.get_runtime_context().get_task_function_name()

                a = Actor.remote()
                b = AsyncActor.remote()
                # Task functions are available for actor tasks.
                print(ray.get(a.get_task_function_name.remote()))
                # Task functions are available for async actor tasks.
                print(ray.get(b.get_task_function_name.remote()))
                # Task functions are available for normal tasks.
                print(ray.get(get_task_function_name.remote()))

            .. testoutput::
                :options: +MOCK

                [python module name].Actor.get_task_function_name
                [python module name].AsyncActor.get_task_function_name
                [python module name].get_task_function_name

        Returns:
            The current worker's task function call string, or None when not
            called from a worker process.
        """
        # only worker mode has task_function_name; warn (not raise) otherwise.
        if self.worker.mode != ray._private.worker.WORKER_MODE:
            logger.warning(
                "This method is only available when the process is a "
                f"worker. Current mode: {self.worker.mode}"
            )
            return None
        return self.worker.current_task_function_name
325
+
326
    @property
    @Deprecated(message="Use get_actor_id() instead", warning=True)
    def actor_id(self):
        """Get the current actor ID in this worker.

        ID of the actor of the current process.
        This shouldn't be used in a driver process.

        Returns:
            The current actor id in this worker. None if there's no actor id.
        """
        # only worker mode has actor_id
        # NOTE(review): the backslash continuation below embeds the next
        # line's leading spaces into the assertion message — presumably
        # unintended formatting; confirm before changing.
        assert (
            self.worker.mode == ray._private.worker.WORKER_MODE
        ), f"This method is only available when the process is a\
            worker. Current mode: {self.worker.mode}"
        actor_id = self.worker.actor_id
        return actor_id if not actor_id.is_nil() else None
344
+
345
    def get_actor_id(self) -> Optional[str]:
        """Get the current actor ID in this worker.

        ID of the actor of the current process.
        This shouldn't be used in a driver process.
        The ID will be in hex format.

        Returns:
            The current actor id in hex format in this worker. None if there's no
            actor id.
        """
        # only worker mode has actor_id
        if self.worker.mode != ray._private.worker.WORKER_MODE:
            # NOTE(review): debug level here, while sibling getters use
            # warning — presumably intentional to keep this quiet on hot
            # paths; confirm before unifying.
            logger.debug(
                "This method is only available when the process is a "
                f"worker. Current mode: {self.worker.mode}"
            )
            return None
        actor_id = self.worker.actor_id
        return actor_id.hex() if not actor_id.is_nil() else None
365
+
366
+ def get_actor_name(self) -> Optional[str]:
367
+ """Get the current actor name of this worker.
368
+
369
+ This shouldn't be used in a driver process.
370
+ The name is in string format.
371
+
372
+ Returns:
373
+ The current actor name of this worker.
374
+ If a current worker is an actor, and
375
+ if actor name doesn't exist, it returns an empty string.
376
+ If a current worker is not an actor, it returns None.
377
+ """
378
+ # only worker mode has actor_id
379
+ if self.worker.mode != ray._private.worker.WORKER_MODE:
380
+ logger.warning(
381
+ "This method is only available when the process is a "
382
+ f"worker. Current mode: {self.worker.mode}"
383
+ )
384
+ return None
385
+ actor_id = self.worker.actor_id
386
+ return self.worker.actor_name if not actor_id.is_nil() else None
387
+
388
    @property
    def namespace(self):
        """Get the current namespace of this worker.

        Returns:
            The current namespace of this worker.
        """
        return self.worker.namespace
396
+
397
+ @property
398
+ def was_current_actor_reconstructed(self):
399
+ """Check whether this actor has been restarted.
400
+
401
+ Returns:
402
+ Whether this actor has been ever restarted.
403
+ """
404
+ assert (
405
+ not self.actor_id.is_nil()
406
+ ), "This method should't be called inside Ray tasks."
407
+ actor_info = ray._private.state.actors(self.actor_id.hex())
408
+ return actor_info and actor_info["NumRestarts"] != 0
409
+
410
+ @property
411
+ @Deprecated(message="Use get_placement_group_id() instead", warning=True)
412
+ def current_placement_group_id(self):
413
+ """Get the current Placement group ID of this worker.
414
+
415
+ Returns:
416
+ The current placement group id of this worker.
417
+ """
418
+ return self.worker.placement_group_id
419
+
420
+ def get_placement_group_id(self) -> Optional[str]:
421
+ """Get the current Placement group ID of this worker.
422
+
423
+ Returns:
424
+ The current placement group id in hex format of this worker.
425
+ """
426
+ pg_id = self.worker.placement_group_id
427
+ return pg_id.hex() if not pg_id.is_nil() else None
428
+
429
+ @property
430
+ def should_capture_child_tasks_in_placement_group(self):
431
+ """Get if the current task should capture parent's placement group.
432
+
433
+ This returns True if it is called inside a driver.
434
+
435
+ Returns:
436
+ Return True if the current task should implicitly
437
+ capture the parent placement group.
438
+ """
439
+ return self.worker.should_capture_child_tasks_in_placement_group
440
+
441
+ def get_assigned_resources(self):
442
+ """Get the assigned resources to this worker.
443
+
444
+ By default for tasks, this will return {"CPU": 1}.
445
+ By default for actors, this will return {}. This is because
446
+ actors do not have CPUs assigned to them by default.
447
+
448
+ Returns:
449
+ A dictionary mapping the name of a resource to a float, where
450
+ the float represents the amount of that resource reserved
451
+ for this worker.
452
+ """
453
+ assert (
454
+ self.worker.mode == ray._private.worker.WORKER_MODE
455
+ ), f"This method is only available when the process is a\
456
+ worker. Current mode: {self.worker.mode}"
457
+ self.worker.check_connected()
458
+ resource_id_map = self.worker.core_worker.resource_ids()
459
+ resource_map = {
460
+ res: sum(amt for _, amt in mapping)
461
+ for res, mapping in resource_id_map.items()
462
+ }
463
+ result = parse_pg_formatted_resources_to_original(resource_map)
464
+ return result
465
+
466
+ def get_runtime_env_string(self):
467
+ """Get the runtime env string used for the current driver or worker.
468
+
469
+ Returns:
470
+ The runtime env string currently using by this worker.
471
+ """
472
+ return self.worker.runtime_env
473
+
474
+ @property
475
+ def runtime_env(self):
476
+ """Get the runtime env used for the current driver or worker.
477
+
478
+ Returns:
479
+ The runtime env currently using by this worker. The type of
480
+ return value is ray.runtime_env.RuntimeEnv.
481
+ """
482
+
483
+ return RuntimeEnv.deserialize(self.get_runtime_env_string())
484
+
485
+ @property
486
+ def current_actor(self):
487
+ """Get the current actor handle of this actor itself.
488
+
489
+ Returns:
490
+ The handle of current actor.
491
+ """
492
+ worker = self.worker
493
+ worker.check_connected()
494
+ actor_id = worker.actor_id
495
+ if actor_id.is_nil():
496
+ raise RuntimeError("This method is only available in an actor.")
497
+
498
+ return worker.core_worker.get_actor_handle(actor_id)
499
+
500
+ @property
501
+ def gcs_address(self):
502
+ """Get the GCS address of the ray cluster.
503
+ Returns:
504
+ The GCS address of the cluster.
505
+ """
506
+ self.worker.check_connected()
507
+ return self.worker.gcs_client.address
508
+
509
    @Deprecated(message="Use get_accelerator_ids() instead", warning=True)
    def get_resource_ids(self) -> Dict[str, List[str]]:
        """Deprecated alias for :meth:`get_accelerator_ids`."""
        return self.get_accelerator_ids()
512
+
513
+ def get_accelerator_ids(self) -> Dict[str, List[str]]:
514
+ """
515
+ Get the current worker's visible accelerator ids.
516
+
517
+ Returns:
518
+ A dictionary keyed by the accelerator resource name. The values are a list
519
+ of ids `{'GPU': ['0', '1'], 'neuron_cores': ['0', '1'],
520
+ 'TPU': ['0', '1']}`.
521
+ """
522
+ worker = self.worker
523
+ worker.check_connected()
524
+ ids_dict: Dict[str, List[str]] = {}
525
+ for (
526
+ accelerator_resource_name
527
+ ) in ray._private.accelerators.get_all_accelerator_resource_names():
528
+ accelerator_ids = worker.get_accelerator_ids_for_accelerator_resource(
529
+ accelerator_resource_name,
530
+ f"^{accelerator_resource_name}_group_[0-9A-Za-z]+$",
531
+ )
532
+ ids_dict[accelerator_resource_name] = [str(id) for id in accelerator_ids]
533
+ return ids_dict
534
+
535
+
536
# Module-level cache for the singleton RuntimeContext instance.
_runtime_context = None


@PublicAPI
@client_mode_hook
def get_runtime_context() -> RuntimeContext:
    """Get the runtime context of the current driver/worker.

    The obtained runtime context can be used to get the metadata
    of the current task and actor.

    Example:

        .. testcode::

            import ray
            # Get the job id.
            ray.get_runtime_context().get_job_id()
            # Get the actor id.
            ray.get_runtime_context().get_actor_id()
            # Get the task id.
            ray.get_runtime_context().get_task_id()

    """
    # Lazily create a single RuntimeContext wrapping the process-wide global
    # worker; all subsequent calls return the same instance.
    global _runtime_context
    if _runtime_context is None:
        _runtime_context = RuntimeContext(ray._private.worker.global_worker)

    return _runtime_context
.venv/lib/python3.11/site-packages/ray/setup-dev.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python
# ruff: noqa: E402
"""This script allows you to develop Ray Python code without needing to compile
Ray.
See https://docs.ray.io/en/master/development.html#building-ray-python-only"""

import os
import sys

# types.py can conflict with stdlib's types.py in some python versions,
# see https://github.com/python/cpython/issues/101210.
# To avoid import errors, we move the current working dir to the end of sys.path.
this_dir = os.path.dirname(__file__)
if this_dir in sys.path:
    # FIX: list.remove() always returns None, so the previous
    # `cur = sys.path.remove(this_dir)` assignment was dead; call it for its
    # side effect only.
    sys.path.remove(this_dir)
    sys.path.append(this_dir)

import argparse
import click
import shutil
import subprocess

import ray
24
+
25
+
26
def do_link(package, force=False, skip_list=None, local_path=None):
    """Replace an installed ray subpackage (or file) with a link to the local
    repo copy so edits in the repo take effect without reinstalling.

    Args:
        package: Name of the ray subpackage or file to link (e.g. "tune").
        force: If True, skip the interactive confirmation prompt.
        skip_list: Optional collection of package names to skip.
        local_path: Path to the local source relative to this file; inferred
            as ``../<package>`` when not given.
    """
    if skip_list and package in skip_list:
        print(f"Skip creating symbolic link for {package}")
        return
    # Location of the installed package inside site-packages.
    package_home = os.path.abspath(os.path.join(ray.__file__, f"../{package}"))
    # Infer local_path automatically.
    if local_path is None:
        local_path = f"../{package}"
    local_home = os.path.abspath(os.path.join(__file__, local_path))
    # If installed package dir does not exist, continue either way. We'll
    # remove it/create a link from there anyways.
    if not os.path.isdir(package_home) and not os.path.isfile(package_home):
        print(f"{package_home} does not exist. Continuing to link.")
    # Make sure the path we are linking to does exist.
    assert os.path.exists(local_home), local_home
    # Confirm with user.
    if not force and not click.confirm(
        f"This will replace:\n {package_home}\nwith " f"a symlink to:\n {local_home}",
        default=True,
    ):
        return

    # Windows: Create directory junction.
    if os.name == "nt":
        # Remove the installed copy first; rmtree handles directories,
        # os.remove handles the single-file case (rmtree raises OSError).
        try:
            shutil.rmtree(package_home)
        except FileNotFoundError:
            pass
        except OSError:
            os.remove(package_home)

        # create symlink for directory or file
        # NOTE: mklink is a cmd.exe builtin, hence shell=True.
        if os.path.isdir(local_home):
            subprocess.check_call(
                ["mklink", "/J", package_home, local_home], shell=True
            )
        elif os.path.isfile(local_home):
            subprocess.check_call(
                ["mklink", "/H", package_home, local_home], shell=True
            )
        else:
            print(f"{local_home} is neither directory nor file. Link failed.")

    # Posix: Use `ln -s` to create softlink.
    else:
        sudo = []
        if not os.access(os.path.dirname(package_home), os.W_OK):
            print("You don't have write permission " f"to {package_home}, using sudo:")
            sudo = ["sudo"]
        print(f"Creating symbolic link from \n {local_home} to \n {package_home}")

        # Preserve ray/serve/generated: it only exists in the installed copy
        # (generated protobuf code), so stash it before removing package_home.
        if package == "serve":
            # Copy generated folder to a temp dir
            generated_folder = os.path.join(package_home, "generated")
            temp_dir = "/tmp/ray/_serve/"
            if not os.path.exists(temp_dir):
                os.makedirs(temp_dir)
            subprocess.check_call(["cp", "-r", generated_folder, temp_dir])

        subprocess.check_call(sudo + ["rm", "-rf", package_home])
        subprocess.check_call(sudo + ["ln", "-s", local_home, package_home])

        # Move generated folder to local_home (package_home is now a symlink
        # to the repo checkout, so this lands inside the repo's serve dir).
        if package == "serve":
            tmp_generated_folder = os.path.join(temp_dir, "generated")
            package_generated_folder = os.path.join(package_home, "generated")
            subprocess.check_call(
                ["mv", tmp_generated_folder, package_generated_folder]
            )
96
+
97
+
98
+ if __name__ == "__main__":
99
+ parser = argparse.ArgumentParser(
100
+ formatter_class=argparse.RawDescriptionHelpFormatter, description="Setup dev."
101
+ )
102
+ parser.add_argument(
103
+ "--yes", "-y", action="store_true", help="Don't ask for confirmation."
104
+ )
105
+ parser.add_argument(
106
+ "--skip",
107
+ "-s",
108
+ nargs="*",
109
+ help="List of folders to skip linking to facilitate workspace dev",
110
+ required=False,
111
+ )
112
+ parser.add_argument(
113
+ "--extras",
114
+ "-e",
115
+ nargs="*",
116
+ help="List of extra folders to link to facilitate workspace dev",
117
+ required=False,
118
+ )
119
+
120
+ args = parser.parse_args()
121
+ if not args.yes:
122
+ print("NOTE: Use '-y' to override all python files without confirmation.")
123
+
124
+ do_link("rllib", force=args.yes, skip_list=args.skip, local_path="../../../rllib")
125
+ do_link("air", force=args.yes, skip_list=args.skip)
126
+ do_link("tune", force=args.yes, skip_list=args.skip)
127
+ do_link("train", force=args.yes, skip_list=args.skip)
128
+ do_link("autoscaler", force=args.yes, skip_list=args.skip)
129
+ do_link("cloudpickle", force=args.yes, skip_list=args.skip)
130
+ do_link("data", force=args.yes, skip_list=args.skip)
131
+ do_link("scripts", force=args.yes, skip_list=args.skip)
132
+ do_link("internal", force=args.yes, skip_list=args.skip)
133
+ do_link("tests", force=args.yes, skip_list=args.skip)
134
+ do_link("experimental", force=args.yes, skip_list=args.skip)
135
+ do_link("util", force=args.yes, skip_list=args.skip)
136
+ do_link("workflow", force=args.yes, skip_list=args.skip)
137
+ do_link("serve", force=args.yes, skip_list=args.skip)
138
+ do_link("dag", force=args.yes, skip_list=args.skip)
139
+ do_link("widgets", force=args.yes, skip_list=args.skip)
140
+ do_link("cluster_utils.py", force=args.yes, skip_list=args.skip)
141
+ do_link("_private", force=args.yes, skip_list=args.skip)
142
+ do_link("dashboard", force=args.yes, skip_list=args.skip)
143
+
144
+ if args.extras is not None:
145
+ for package in args.extras:
146
+ do_link(package, force=args.yes, skip_list=args.skip)
147
+
148
+ print(
149
+ "Created links.\n\nIf you run into issues initializing Ray, please "
150
+ "ensure that your local repo and the installed Ray are in sync "
151
+ "(pip install -U the latest wheels at "
152
+ "https://docs.ray.io/en/master/installation.html, "
153
+ "and ensure you are up-to-date on the master branch on git).\n\n"
154
+ "Note that you may need to delete the package symlinks when pip "
155
+ "installing new Ray versions to prevent pip from overwriting files "
156
+ "in your git repo."
157
+ )
.venv/lib/python3.11/site-packages/ray/types.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import Generic, TypeVar

from ray.util.annotations import PublicAPI

T = TypeVar("T")


# TODO(ekl) this is a dummy generic ref type for documentation purposes only.
# We should try to make the Cython ray.ObjectRef properly generic.
# NOTE(sang): Looks like using Generic in Cython is not currently possible.
# We should update Cython > 3.0 for this.
@PublicAPI
class ObjectRef(Generic[T]):
    """Generic stand-in for ``ray.ObjectRef`` used in type annotations.

    The real ObjectRef is implemented in Cython and cannot be parameterized,
    so this class exists purely so annotations like ``ObjectRef[int]`` work.
    """

    pass
.venv/lib/python3.11/site-packages/ray/workflow/__init__.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Public entry points of Ray Workflows, re-exported from the implementation
# modules so users can simply do `from ray import workflow`.
from ray.workflow.api import (
    init,
    run,
    run_async,
    resume,
    resume_all,
    resume_async,
    cancel,
    list_all,
    delete,
    get_output,
    get_output_async,
    get_status,
    get_metadata,
    sleep,
    wait_for_event,
    continuation,
    options,
)
from ray.workflow.exceptions import (
    WorkflowError,
    WorkflowExecutionError,
    WorkflowCancellationError,
)
from ray.workflow.common import WorkflowStatus
from ray.workflow.event_listener import EventListener

# Expose WorkflowStatus members (e.g. workflow.RUNNING, workflow.FAILED)
# directly as module-level attributes.
globals().update(WorkflowStatus.__members__)


__all__ = [
    "init",
    "run",
    "run_async",
    "resume",
    "resume_async",
    "resume_all",
    "cancel",
    "list_all",
    "delete",
    "get_output",
    "get_output_async",
    "get_status",
    "get_metadata",
    "sleep",
    "wait_for_event",
    "options",
    "continuation",
    # events
    "EventListener",
    # exceptions
    "WorkflowError",
    "WorkflowExecutionError",
    "WorkflowCancellationError",
]
.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/debug_utils.cpython-311.pyc ADDED
Binary file (3.09 kB). View file
 
.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/serialization.cpython-311.pyc ADDED
Binary file (10.6 kB). View file
 
.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_access.cpython-311.pyc ADDED
Binary file (19.6 kB). View file
 
.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_context.cpython-311.pyc ADDED
Binary file (5.2 kB). View file
 
.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_executor.cpython-311.pyc ADDED
Binary file (21.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_state.cpython-311.pyc ADDED
Binary file (12.6 kB). View file
 
.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_state_from_dag.cpython-311.pyc ADDED
Binary file (9.56 kB). View file
 
.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_state_from_storage.cpython-311.pyc ADDED
Binary file (3.64 kB). View file
 
.venv/lib/python3.11/site-packages/ray/workflow/api.py ADDED
@@ -0,0 +1,869 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import logging
3
+ import tempfile
4
+ from typing import Dict, Set, List, Tuple, Union, Optional, Any
5
+ import time
6
+ import uuid
7
+ from pathlib import Path
8
+
9
+ import ray
10
+ from ray.dag import DAGNode
11
+ from ray.dag.input_node import DAGInputData
12
+ from ray.remote_function import RemoteFunction
13
+
14
+ # avoid collision with arguments & APIs
15
+
16
+ from ray.workflow.common import (
17
+ WorkflowStatus,
18
+ Event,
19
+ asyncio_run,
20
+ validate_user_metadata,
21
+ )
22
+ from ray.workflow import serialization, workflow_access, workflow_context
23
+ from ray.workflow.event_listener import EventListener, EventListenerType, TimerListener
24
+ from ray.workflow.workflow_storage import WorkflowStorage
25
+ from ray.workflow.workflow_state_from_dag import workflow_state_from_dag
26
+
27
+ from ray.util.annotations import PublicAPI
28
+ from ray._private.usage import usage_lib
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
+ @PublicAPI(stability="alpha")
34
+ def init(
35
+ *,
36
+ max_running_workflows: Optional[int] = None,
37
+ max_pending_workflows: Optional[int] = None,
38
+ ) -> None:
39
+ """Initialize workflow.
40
+
41
+ If Ray is not initialized, we will initialize Ray and
42
+ use ``/tmp/ray/workflow_data`` as the default storage.
43
+
44
+ Args:
45
+ max_running_workflows: The maximum number of concurrently running workflows.
46
+ Use -1 as infinity. 'None' means preserving previous setting or initialize
47
+ the setting with infinity.
48
+ max_pending_workflows: The maximum number of queued workflows.
49
+ Use -1 as infinity. 'None' means preserving previous setting or initialize
50
+ the setting with infinity.
51
+ """
52
+ usage_lib.record_library_usage("workflow")
53
+
54
+ if max_running_workflows is not None:
55
+ if not isinstance(max_running_workflows, int):
56
+ raise TypeError("'max_running_workflows' must be None or an integer.")
57
+ if max_running_workflows < -1 or max_running_workflows == 0:
58
+ raise ValueError(
59
+ "'max_running_workflows' must be a positive integer "
60
+ "or use -1 as infinity."
61
+ )
62
+ if max_pending_workflows is not None:
63
+ if not isinstance(max_pending_workflows, int):
64
+ raise TypeError("'max_pending_workflows' must be None or an integer.")
65
+ if max_pending_workflows < -1:
66
+ raise ValueError(
67
+ "'max_pending_workflows' must be a non-negative integer "
68
+ "or use -1 as infinity."
69
+ )
70
+
71
+ if not ray.is_initialized():
72
+ # We should use get_temp_dir_path, but for ray client, we don't
73
+ # have this one. We need a flag to tell whether it's a client
74
+ # or a driver to use the right dir.
75
+ # For now, just use $TMP/ray/workflow_data
76
+ workflow_dir = Path(tempfile.gettempdir()) / "ray" / "workflow_data"
77
+ ray.init(storage=workflow_dir.as_uri())
78
+ workflow_access.init_management_actor(max_running_workflows, max_pending_workflows)
79
+ serialization.init_manager()
80
+
81
+
82
def _ensure_workflow_initialized() -> None:
    """Initialize the workflow subsystem if it is not running yet."""
    # NOTE: Trying to get the actor has a side effect: it initializes Ray with
    # default arguments. This is different in "init()": it assigns a temporary
    # storage. This is why we need to check "ray.is_initialized()" first.
    if not ray.is_initialized():
        init()
        return
    try:
        workflow_access.get_management_actor()
    except ValueError:
        # Management actor is missing; (re-)initialize the workflow subsystem.
        init()
93
+
94
+
95
def client_mode_wrap(func):
    """Wraps a function called during client mode for execution as a remote task.

    Adopted from "ray._private.client_mode_hook.client_mode_wrap". Some changes are made
    (e.g., init the workflow instead of init Ray; the latter does not specify a storage
    during Ray init and will result in workflow failures).
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Imported lazily to pick up the current hook state at call time.
        from ray._private.client_mode_hook import client_mode_should_convert
        from ray._private.auto_init_hook import enable_auto_connect

        if enable_auto_connect:
            _ensure_workflow_initialized()

        # `is_client_mode_enabled_by_default` is used for testing with
        # `RAY_CLIENT_MODE=1`. This flag means all tests run with client mode.
        if client_mode_should_convert():
            # In client mode, run the wrapped function as a zero-CPU remote
            # task on the cluster instead of locally.
            f = ray.remote(num_cpus=0)(func)
            ref = f.remote(*args, **kwargs)
            return ray.get(ref)
        return func(*args, **kwargs)

    return wrapper
120
+
121
+
122
+ @PublicAPI(stability="alpha")
123
+ def run(
124
+ dag: DAGNode,
125
+ *args,
126
+ workflow_id: Optional[str] = None,
127
+ metadata: Optional[Dict[str, Any]] = None,
128
+ **kwargs,
129
+ ) -> Any:
130
+ """Run a workflow.
131
+
132
+ If the workflow with the given id already exists, it will be resumed.
133
+
134
+ Examples:
135
+ .. testcode::
136
+
137
+ import ray
138
+ from ray import workflow
139
+
140
+ @ray.remote
141
+ def book_flight(origin: str, dest: str):
142
+ return f"Flight: {origin}->{dest}"
143
+
144
+ @ray.remote
145
+ def book_hotel(location: str):
146
+ return f"Hotel: {location}"
147
+
148
+ @ray.remote
149
+ def finalize_trip(bookings: List[Any]):
150
+ return ' | '.join(ray.get(bookings))
151
+
152
+ flight1 = book_flight.bind("OAK", "SAN")
153
+ flight2 = book_flight.bind("SAN", "OAK")
154
+ hotel = book_hotel.bind("SAN")
155
+ trip = finalize_trip.bind([flight1, flight2, hotel])
156
+ print(workflow.run(trip))
157
+
158
+ .. testoutput::
159
+
160
+ Flight: OAK->SAN | Flight: SAN->OAK | Hotel: SAN
161
+
162
+ Args:
163
+ workflow_id: A unique identifier that can be used to resume the
164
+ workflow. If not specified, a random id will be generated.
165
+ metadata: The metadata to add to the workflow. It has to be able
166
+ to serialize to json.
167
+
168
+ Returns:
169
+ The running result.
170
+ """
171
+ return ray.get(
172
+ run_async(dag, *args, workflow_id=workflow_id, metadata=metadata, **kwargs)
173
+ )
174
+
175
+
176
+ @PublicAPI(stability="alpha")
177
+ def run_async(
178
+ dag: DAGNode,
179
+ *args,
180
+ workflow_id: Optional[str] = None,
181
+ metadata: Optional[Dict[str, Any]] = None,
182
+ **kwargs,
183
+ ) -> ray.ObjectRef:
184
+ """Run a workflow asynchronously.
185
+
186
+ If the workflow with the given id already exists, it will be resumed.
187
+
188
+ Args:
189
+ workflow_id: A unique identifier that can be used to resume the
190
+ workflow. If not specified, a random id will be generated.
191
+ metadata: The metadata to add to the workflow. It has to be able
192
+ to serialize to json.
193
+
194
+ Returns:
195
+ The running result as ray.ObjectRef.
196
+
197
+ """
198
+ _ensure_workflow_initialized()
199
+ if not isinstance(dag, DAGNode):
200
+ raise TypeError("Input should be a DAG.")
201
+ input_data = DAGInputData(*args, **kwargs)
202
+ validate_user_metadata(metadata)
203
+ metadata = metadata or {}
204
+
205
+ if workflow_id is None:
206
+ # Workflow ID format: {Entry workflow UUID}.{Unix time to nanoseconds}
207
+ workflow_id = f"{str(uuid.uuid4())}.{time.time():.9f}"
208
+
209
+ workflow_manager = workflow_access.get_management_actor()
210
+ if ray.get(workflow_manager.is_workflow_non_terminating.remote(workflow_id)):
211
+ raise RuntimeError(f"Workflow '{workflow_id}' is already running or pending.")
212
+
213
+ state = workflow_state_from_dag(dag, input_data, workflow_id)
214
+ logger.info(f'Workflow job created. [id="{workflow_id}"].')
215
+ context = workflow_context.WorkflowTaskContext(workflow_id=workflow_id)
216
+ with workflow_context.workflow_task_context(context):
217
+ # checkpoint the workflow
218
+ @client_mode_wrap
219
+ def _try_checkpoint_workflow(workflow_state) -> bool:
220
+ ws = WorkflowStorage(workflow_id)
221
+ ws.save_workflow_user_metadata(metadata)
222
+ try:
223
+ ws.get_entrypoint_task_id()
224
+ return True
225
+ except Exception:
226
+ # The workflow does not exist. We must checkpoint entry workflow.
227
+ ws.save_workflow_execution_state("", workflow_state)
228
+ return False
229
+
230
+ wf_exists = _try_checkpoint_workflow(state)
231
+ if wf_exists:
232
+ return resume_async(workflow_id)
233
+ ray.get(
234
+ workflow_manager.submit_workflow.remote(
235
+ workflow_id, state, ignore_existing=False
236
+ )
237
+ )
238
+ job_id = ray.get_runtime_context().get_job_id()
239
+ return workflow_manager.execute_workflow.remote(job_id, context)
240
+
241
+
242
+ @PublicAPI(stability="alpha")
243
+ def resume(workflow_id: str) -> Any:
244
+ """Resume a workflow.
245
+
246
+ Resume a workflow and retrieve its output. If the workflow was incomplete,
247
+ it will be re-executed from its checkpointed outputs. If the workflow was
248
+ complete, returns the result immediately.
249
+
250
+ Examples:
251
+ .. testcode::
252
+
253
+ from ray import workflow
254
+
255
+ @ray.remote
256
+ def start_trip():
257
+ return 3
258
+
259
+ trip = start_trip.bind()
260
+ res1 = workflow.run_async(trip, workflow_id="trip1")
261
+ res2 = workflow.resume("trip1")
262
+ assert ray.get(res1) == res2
263
+
264
+ Args:
265
+ workflow_id: The id of the workflow to resume.
266
+
267
+ Returns:
268
+ The output of the workflow.
269
+ """
270
+ return ray.get(resume_async(workflow_id))
271
+
272
+
273
+ @PublicAPI(stability="alpha")
274
+ def resume_async(workflow_id: str) -> ray.ObjectRef:
275
+ """Resume a workflow asynchronously.
276
+
277
+ Resume a workflow and retrieve its output. If the workflow was incomplete,
278
+ it will be re-executed from its checkpointed outputs. If the workflow was
279
+ complete, returns the result immediately.
280
+
281
+ Examples:
282
+ .. testcode::
283
+
284
+ from ray import workflow
285
+
286
+ @ray.remote
287
+ def start_trip():
288
+ return 3
289
+
290
+ trip = start_trip.bind()
291
+ res1 = workflow.run_async(trip, workflow_id="trip1")
292
+ res2 = workflow.resume_async("trip1")
293
+ assert ray.get(res1) == ray.get(res2)
294
+
295
+ Args:
296
+ workflow_id: The id of the workflow to resume.
297
+
298
+ Returns:
299
+ An object reference that can be used to retrieve the workflow result.
300
+ """
301
+ _ensure_workflow_initialized()
302
+ logger.info(f'Resuming workflow [id="{workflow_id}"].')
303
+ workflow_manager = workflow_access.get_management_actor()
304
+ if ray.get(workflow_manager.is_workflow_non_terminating.remote(workflow_id)):
305
+ raise RuntimeError(f"Workflow '{workflow_id}' is already running or pending.")
306
+ # NOTE: It is important to 'ray.get' the returned output. This
307
+ # ensures caller of 'run()' holds the reference to the workflow
308
+ # result. Otherwise if the actor removes the reference of the
309
+ # workflow output, the caller may fail to resolve the result.
310
+ job_id = ray.get_runtime_context().get_job_id()
311
+
312
+ context = workflow_context.WorkflowTaskContext(workflow_id=workflow_id)
313
+ ray.get(workflow_manager.reconstruct_workflow.remote(job_id, context))
314
+ result = workflow_manager.execute_workflow.remote(job_id, context)
315
+ logger.info(f"Workflow job {workflow_id} resumed.")
316
+ return result
317
+
318
+
319
+ @PublicAPI(stability="alpha")
320
+ def get_output(workflow_id: str, *, task_id: Optional[str] = None) -> Any:
321
+ """Get the output of a running workflow.
322
+
323
+ Args:
324
+ workflow_id: The workflow to get the output of.
325
+ task_id: If set, fetch the specific task instead of the output of the
326
+ workflow.
327
+
328
+ Examples:
329
+ .. testcode::
330
+
331
+ from ray import workflow
332
+
333
+ @ray.remote
334
+ def start_trip():
335
+ return 1
336
+
337
+ trip = start_trip.options(**workflow.options(task_id="trip")).bind()
338
+ res1 = workflow.run_async(trip, workflow_id="trip1")
339
+ # you could "get_output()" in another machine
340
+ res2 = workflow.get_output("trip1")
341
+ assert ray.get(res1) == res2
342
+ task_output = workflow.get_output_async("trip1", task_id="trip")
343
+ assert ray.get(task_output) == ray.get(res1)
344
+
345
+ Returns:
346
+ The output of the workflow task.
347
+ """
348
+ return ray.get(get_output_async(workflow_id, task_id=task_id))
349
+
350
+
351
+ @PublicAPI(stability="alpha")
352
+ @client_mode_wrap
353
+ def get_output_async(
354
+ workflow_id: str, *, task_id: Optional[str] = None
355
+ ) -> ray.ObjectRef:
356
+ """Get the output of a running workflow asynchronously.
357
+
358
+ Args:
359
+ workflow_id: The workflow to get the output of.
360
+ task_id: If set, fetch the specific task output instead of the output
361
+ of the workflow.
362
+
363
+ Returns:
364
+ An object reference that can be used to retrieve the workflow task result.
365
+ """
366
+ _ensure_workflow_initialized()
367
+ try:
368
+ workflow_manager = workflow_access.get_management_actor()
369
+ except ValueError as e:
370
+ raise ValueError(
371
+ "Failed to connect to the workflow management "
372
+ "actor. The workflow could have already failed. You can use "
373
+ "workflow.resume() or workflow.resume_async() to resume the "
374
+ "workflow."
375
+ ) from e
376
+ return workflow_manager.get_output.remote(workflow_id, task_id)
377
+
378
+
379
+ @PublicAPI(stability="alpha")
380
+ @client_mode_wrap
381
+ def list_all(
382
+ status_filter: Optional[
383
+ Union[Union[WorkflowStatus, str], Set[Union[WorkflowStatus, str]]]
384
+ ] = None
385
+ ) -> List[Tuple[str, WorkflowStatus]]:
386
+ """List all workflows matching a given status filter. When returning "RESUMEABLE"
387
+ workflows, the workflows that was running ranks before the workflow that was pending
388
+ in the result list.
389
+
390
+ Args:
391
+ status_filter: If given, only returns workflow with that status. This can
392
+ be a single status or set of statuses. The string form of the
393
+ status is also acceptable, i.e.,
394
+ "RUNNING"/"FAILED"/"SUCCESSFUL"/"CANCELED"/"RESUMABLE"/"PENDING".
395
+
396
+ Examples:
397
+ .. testcode::
398
+
399
+ from ray import workflow
400
+
401
+ @ray.remote
402
+ def long_running_job():
403
+ import time
404
+ time.sleep(2)
405
+
406
+ workflow_task = long_running_job.bind()
407
+ wf = workflow.run_async(workflow_task,
408
+ workflow_id="long_running_job")
409
+ jobs = workflow.list_all(workflow.RUNNING)
410
+ assert jobs == [ ("long_running_job", workflow.RUNNING) ]
411
+ ray.get(wf)
412
+ jobs = workflow.list_all({workflow.RUNNING})
413
+ assert jobs == []
414
+
415
+ Returns:
416
+ A list of tuple with workflow id and workflow status
417
+ """
418
+ _ensure_workflow_initialized()
419
+ if isinstance(status_filter, str):
420
+ status_filter = set({WorkflowStatus(status_filter)})
421
+ elif isinstance(status_filter, WorkflowStatus):
422
+ status_filter = set({status_filter})
423
+ elif isinstance(status_filter, set):
424
+ if all(isinstance(s, str) for s in status_filter):
425
+ status_filter = {WorkflowStatus(s) for s in status_filter}
426
+ elif not all(isinstance(s, WorkflowStatus) for s in status_filter):
427
+ raise TypeError(
428
+ "status_filter contains element which is not"
429
+ " a type of `WorkflowStatus or str`."
430
+ f" {status_filter}"
431
+ )
432
+ elif status_filter is None:
433
+ status_filter = set(WorkflowStatus)
434
+ status_filter.discard(WorkflowStatus.NONE)
435
+ else:
436
+ raise TypeError(
437
+ "status_filter must be WorkflowStatus or a set of WorkflowStatus."
438
+ )
439
+
440
+ try:
441
+ workflow_manager = workflow_access.get_management_actor()
442
+ except ValueError:
443
+ workflow_manager = None
444
+
445
+ if workflow_manager is None:
446
+ non_terminating_workflows = {}
447
+ else:
448
+ non_terminating_workflows = ray.get(
449
+ workflow_manager.list_non_terminating_workflows.remote()
450
+ )
451
+
452
+ ret = []
453
+ if set(non_terminating_workflows.keys()).issuperset(status_filter):
454
+ for status, workflows in non_terminating_workflows.items():
455
+ if status in status_filter:
456
+ for w in workflows:
457
+ ret.append((w, status))
458
+ return ret
459
+
460
+ ret = []
461
+ # Here we don't have workflow id, so use empty one instead
462
+ store = WorkflowStorage("")
463
+ modified_status_filter = status_filter.copy()
464
+ # Here we have to add non-terminating status to the status filter, because some
465
+ # "RESUMABLE" workflows are converted from non-terminating workflows below.
466
+ # This is the tricky part: the status "RESUMABLE" neither come from
467
+ # the workflow management actor nor the storage. It is the status where
468
+ # the storage says it is non-terminating but the workflow management actor
469
+ # is not running it. This usually happened when there was a sudden crash
470
+ # of the whole Ray runtime or the workflow management actor
471
+ # (due to cluster etc.). So we includes non terminating status in the storage
472
+ # filter to get "RESUMABLE" candidates.
473
+ modified_status_filter.update(WorkflowStatus.non_terminating_status())
474
+ status_from_storage = store.list_workflow(modified_status_filter)
475
+ non_terminating_workflows = {
476
+ k: set(v) for k, v in non_terminating_workflows.items()
477
+ }
478
+ resume_running = []
479
+ resume_pending = []
480
+ for (k, s) in status_from_storage:
481
+ if s in non_terminating_workflows and k not in non_terminating_workflows[s]:
482
+ if s == WorkflowStatus.RUNNING:
483
+ resume_running.append(k)
484
+ elif s == WorkflowStatus.PENDING:
485
+ resume_pending.append(k)
486
+ else:
487
+ assert False, "This line of code should not be reachable."
488
+ continue
489
+ if s in status_filter:
490
+ ret.append((k, s))
491
+ if WorkflowStatus.RESUMABLE in status_filter:
492
+ # The running workflows ranks before the pending workflows.
493
+ for w in resume_running:
494
+ ret.append((w, WorkflowStatus.RESUMABLE))
495
+ for w in resume_pending:
496
+ ret.append((w, WorkflowStatus.RESUMABLE))
497
+ return ret
498
+
499
+
500
+ @PublicAPI(stability="alpha")
501
+ @client_mode_wrap
502
+ def resume_all(include_failed: bool = False) -> List[Tuple[str, ray.ObjectRef]]:
503
+ """Resume all resumable workflow jobs.
504
+
505
+ This can be used after cluster restart to resume all tasks.
506
+
507
+ Args:
508
+ include_failed: Whether to resume FAILED workflows.
509
+
510
+ Examples:
511
+ .. testcode::
512
+
513
+ from ray import workflow
514
+
515
+ @ray.remote
516
+ def failed_job():
517
+ raise ValueError()
518
+
519
+ workflow_task = failed_job.bind()
520
+ output = workflow.run_async(
521
+ workflow_task, workflow_id="failed_job")
522
+ try:
523
+ ray.get(output)
524
+ except Exception:
525
+ print("JobFailed")
526
+
527
+ assert workflow.get_status("failed_job") == workflow.FAILED
528
+ print(workflow.resume_all(include_failed=True))
529
+
530
+ .. testoutput::
531
+
532
+ JobFailed
533
+ [('failed_job', ObjectRef(...))]
534
+
535
+ Returns:
536
+ A list of (workflow_id, returned_obj_ref) resumed.
537
+ """
538
+ _ensure_workflow_initialized()
539
+ filter_set = {WorkflowStatus.RESUMABLE}
540
+ if include_failed:
541
+ filter_set.add(WorkflowStatus.FAILED)
542
+ all_failed = list_all(filter_set)
543
+
544
+ try:
545
+ workflow_manager = workflow_access.get_management_actor()
546
+ except Exception as e:
547
+ raise RuntimeError("Failed to get management actor") from e
548
+
549
+ job_id = ray.get_runtime_context().get_job_id()
550
+ reconstructed_workflows = []
551
+ for wid, _ in all_failed:
552
+ context = workflow_context.WorkflowTaskContext(workflow_id=wid)
553
+ # TODO(suquark): This is not very efficient, but it makes sure
554
+ # running workflows has higher priority when getting reconstructed.
555
+ try:
556
+ ray.get(workflow_manager.reconstruct_workflow.remote(job_id, context))
557
+ except Exception as e:
558
+ # TODO(suquark): Here some workflows got resumed successfully but some
559
+ # failed and the user has no idea about this, which is very wired.
560
+ # Maybe we should raise an exception here instead?
561
+ logger.error(f"Failed to resume workflow {context.workflow_id}", exc_info=e)
562
+ raise
563
+ reconstructed_workflows.append(context)
564
+
565
+ results = []
566
+ for context in reconstructed_workflows:
567
+ results.append(
568
+ (
569
+ context.workflow_id,
570
+ workflow_manager.execute_workflow.remote(job_id, context),
571
+ )
572
+ )
573
+ return results
574
+
575
+
576
+ @PublicAPI(stability="alpha")
577
+ def get_status(workflow_id: str) -> WorkflowStatus:
578
+ """Get the status for a given workflow.
579
+
580
+ Args:
581
+ workflow_id: The workflow to query.
582
+
583
+ Examples:
584
+ .. testcode::
585
+
586
+ from ray import workflow
587
+
588
+ @ray.remote
589
+ def trip():
590
+ pass
591
+
592
+ workflow_task = trip.bind()
593
+ output = workflow.run(workflow_task, workflow_id="local_trip")
594
+ assert workflow.SUCCESSFUL == workflow.get_status("local_trip")
595
+
596
+ Returns:
597
+ The status of that workflow
598
+ """
599
+ _ensure_workflow_initialized()
600
+ if not isinstance(workflow_id, str):
601
+ raise TypeError("workflow_id has to be a string type.")
602
+ workflow_manager = workflow_access.get_management_actor()
603
+ return ray.get(workflow_manager.get_workflow_status.remote(workflow_id))
604
+
605
+
606
+ @PublicAPI(stability="alpha")
607
+ def wait_for_event(
608
+ event_listener_type: EventListenerType, *args, **kwargs
609
+ ) -> "DAGNode[Event]":
610
+ if not issubclass(event_listener_type, EventListener):
611
+ raise TypeError(
612
+ f"Event listener type is {event_listener_type.__name__}"
613
+ ", which is not a subclass of workflow.EventListener"
614
+ )
615
+
616
+ @ray.remote
617
+ def get_message(event_listener_type: EventListenerType, *args, **kwargs) -> Event:
618
+ event_listener = event_listener_type()
619
+ return asyncio_run(event_listener.poll_for_event(*args, **kwargs))
620
+
621
+ @ray.remote
622
+ def message_committed(
623
+ event_listener_type: EventListenerType, event: Event
624
+ ) -> Event:
625
+ event_listener = event_listener_type()
626
+ asyncio_run(event_listener.event_checkpointed(event))
627
+ return event
628
+
629
+ return message_committed.bind(
630
+ event_listener_type, get_message.bind(event_listener_type, *args, **kwargs)
631
+ )
632
+
633
+
634
+ @PublicAPI(stability="alpha")
635
+ def sleep(duration: float) -> "DAGNode[Event]":
636
+ """
637
+ A workfow that resolves after sleeping for a given duration.
638
+ """
639
+
640
+ @ray.remote
641
+ def end_time():
642
+ return time.time() + duration
643
+
644
+ return wait_for_event(TimerListener, end_time.bind())
645
+
646
+
647
+ @PublicAPI(stability="alpha")
648
+ @client_mode_wrap
649
+ def get_metadata(workflow_id: str, task_id: Optional[str] = None) -> Dict[str, Any]:
650
+ """Get the metadata of the workflow.
651
+
652
+ This will return a dict of metadata of either the workflow (
653
+ if only workflow_id is given) or a specific workflow task (if
654
+ both workflow_id and task id are given). Exception will be
655
+ raised if the given workflow id or task id does not exist.
656
+
657
+ If only workflow id is given, this will return metadata on
658
+ workflow level, which includes running status, workflow-level
659
+ user metadata and workflow-level running stats (e.g. the
660
+ start time and end time of the workflow).
661
+
662
+ If both workflow id and task id are given, this will return
663
+ metadata on workflow task level, which includes task inputs,
664
+ task-level user metadata and task-level running stats (e.g.
665
+ the start time and end time of the task).
666
+
667
+
668
+ Args:
669
+ workflow_id: The workflow to get the metadata of.
670
+ task_id: If set, fetch the metadata of the specific task instead of
671
+ the metadata of the workflow.
672
+
673
+ Examples:
674
+ .. testcode::
675
+
676
+ from ray import workflow
677
+
678
+ @ray.remote
679
+ def trip():
680
+ pass
681
+
682
+ workflow_task = trip.options(
683
+ **workflow.options(task_id="trip", metadata={"k1": "v1"})).bind()
684
+ workflow.run(workflow_task,
685
+ workflow_id="trip1", metadata={"k2": "v2"})
686
+ workflow_metadata = workflow.get_metadata("trip1")
687
+ print(workflow_metadata)
688
+
689
+ task_metadata = workflow.get_metadata("trip1", "trip")
690
+ print(task_metadata)
691
+
692
+ .. testoutput::
693
+
694
+ {'status': 'SUCCESSFUL', 'user_metadata': {'k2': 'v2'}, 'stats': {'start_time': ..., 'end_time': ...}}
695
+ {'task_id': 'trip', 'task_options': {'task_type': 'FUNCTION', 'max_retries': 3, 'catch_exceptions': False, 'retry_exceptions': False, 'checkpoint': True, 'ray_options': {'_metadata': {'workflow.io/options': {'task_id': 'trip', 'metadata': {'k1': 'v1'}}}}}, 'user_metadata': {'k1': 'v1'}, 'workflow_refs': [], 'stats': {'start_time': ..., 'end_time': ...}}
696
+
697
+ Returns:
698
+ A dictionary containing the metadata of the workflow.
699
+
700
+ Raises:
701
+ ValueError: if given workflow or workflow task does not exist.
702
+ """ # noqa: E501
703
+ _ensure_workflow_initialized()
704
+ store = WorkflowStorage(workflow_id)
705
+ if task_id is None:
706
+ return store.load_workflow_metadata()
707
+ else:
708
+ return store.load_task_metadata(task_id)
709
+
710
+
711
+ @PublicAPI(stability="alpha")
712
+ def cancel(workflow_id: str) -> None:
713
+ """Cancel a workflow. Workflow checkpoints will still be saved in storage. To
714
+ clean up saved checkpoints, see `workflow.delete()`.
715
+
716
+ Args:
717
+ workflow_id: The workflow to cancel.
718
+
719
+ Examples:
720
+ .. testcode::
721
+
722
+ from ray import workflow
723
+
724
+ @ray.remote
725
+ def some_job():
726
+ return 1
727
+
728
+ workflow_task = some_job.bind()
729
+ workflow.run(workflow_task, workflow_id="some_job")
730
+ workflow.cancel(workflow_id="some_job")
731
+ assert workflow.get_status("some_job") == workflow.CANCELED
732
+
733
+ Returns:
734
+ None
735
+
736
+ """
737
+ _ensure_workflow_initialized()
738
+ if not isinstance(workflow_id, str):
739
+ raise TypeError("workflow_id has to be a string type.")
740
+ workflow_manager = workflow_access.get_management_actor()
741
+ ray.get(workflow_manager.cancel_workflow.remote(workflow_id))
742
+
743
+
744
+ @PublicAPI(stability="alpha")
745
+ def delete(workflow_id: str) -> None:
746
+ """Delete a workflow, its checkpoints, and other information it may have
747
+ persisted to storage. To stop a running workflow, see
748
+ `workflow.cancel()`.
749
+
750
+ Args:
751
+ workflow_id: The workflow to delete.
752
+
753
+ Raises:
754
+ WorkflowStillActiveError: The workflow is still active.
755
+ WorkflowNotFoundError: The workflow does not exist.
756
+
757
+ Examples:
758
+ .. testcode::
759
+
760
+ from ray import workflow
761
+
762
+ @ray.remote
763
+ def some_job():
764
+ pass
765
+
766
+ workflow_task = some_job.bind()
767
+ workflow.run(workflow_task, workflow_id="some_job")
768
+ workflow.delete(workflow_id="some_job")
769
+ """
770
+ _ensure_workflow_initialized()
771
+ workflow_manager = workflow_access.get_management_actor()
772
+ ray.get(workflow_manager.delete_workflow.remote(workflow_id))
773
+
774
+
775
+ @PublicAPI(stability="alpha")
776
+ def continuation(dag_node: "DAGNode") -> Union["DAGNode", Any]:
777
+ """Converts a DAG into a continuation.
778
+
779
+ The result depends on the context. If it is inside a workflow, it
780
+ returns a workflow; otherwise it executes and get the result of
781
+ the DAG.
782
+
783
+ Args:
784
+ dag_node: The DAG to be converted.
785
+ """
786
+ from ray.workflow.workflow_context import in_workflow_execution
787
+
788
+ if not isinstance(dag_node, DAGNode):
789
+ raise TypeError("Input should be a DAG.")
790
+
791
+ if in_workflow_execution():
792
+ return dag_node
793
+ return ray.get(dag_node.execute())
794
+
795
+
796
+ @PublicAPI(stability="alpha")
797
+ class options:
798
+ """This class serves both as a decorator and options for workflow.
799
+
800
+ Examples:
801
+
802
+ .. testcode::
803
+
804
+ import ray
805
+ from ray import workflow
806
+
807
+ # specify workflow options with a decorator
808
+ @workflow.options(catch_exceptions=True)
809
+ @ray.remote
810
+ def foo():
811
+ return 1
812
+
813
+ # specify workflow options in ".options"
814
+ foo_new = foo.options(**workflow.options(catch_exceptions=False))
815
+ """
816
+
817
+ def __init__(self, **workflow_options: Dict[str, Any]):
818
+ # TODO(suquark): More rigid arguments check like @ray.remote arguments. This is
819
+ # fairly complex, but we should enable it later.
820
+ valid_options = {
821
+ "task_id",
822
+ "metadata",
823
+ "catch_exceptions",
824
+ "checkpoint",
825
+ }
826
+ invalid_keywords = set(workflow_options.keys()) - valid_options
827
+ if invalid_keywords:
828
+ raise ValueError(
829
+ f"Invalid option keywords {invalid_keywords} for workflow tasks. "
830
+ f"Valid ones are {valid_options}."
831
+ )
832
+ from ray.workflow.common import WORKFLOW_OPTIONS
833
+
834
+ validate_user_metadata(workflow_options.get("metadata"))
835
+
836
+ self.options = {"_metadata": {WORKFLOW_OPTIONS: workflow_options}}
837
+
838
+ def keys(self):
839
+ return ("_metadata",)
840
+
841
+ def __getitem__(self, key):
842
+ return self.options[key]
843
+
844
+ def __call__(self, f: RemoteFunction) -> RemoteFunction:
845
+ if not isinstance(f, RemoteFunction):
846
+ raise ValueError("Only apply 'workflow.options' to Ray remote functions.")
847
+ f._default_options.update(self.options)
848
+ return f
849
+
850
+
851
+ __all__ = (
852
+ "init",
853
+ "run",
854
+ "run_async",
855
+ "resume",
856
+ "resume_async",
857
+ "resume_all",
858
+ "cancel",
859
+ "list_all",
860
+ "delete",
861
+ "get_output",
862
+ "get_output_async",
863
+ "get_status",
864
+ "get_metadata",
865
+ "sleep",
866
+ "wait_for_event",
867
+ "options",
868
+ "continuation",
869
+ )
.venv/lib/python3.11/site-packages/ray/workflow/common.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+
4
+ from ray import cloudpickle
5
+ from enum import Enum, unique
6
+ import hashlib
7
+ from typing import Dict, Optional, Any, Tuple
8
+
9
+ from dataclasses import dataclass
10
+
11
+ import ray
12
+ from ray import ObjectRef
13
+ from ray._private.utils import get_or_create_event_loop
14
+ from ray.util.annotations import PublicAPI
15
+
16
+ # Alias types
17
+ Event = Any
18
+ TaskID = str
19
+ WorkflowOutputType = ObjectRef
20
+
21
+ MANAGEMENT_ACTOR_NAMESPACE = "workflow"
22
+ MANAGEMENT_ACTOR_NAME = "WorkflowManagementActor"
23
+ HTTP_EVENT_PROVIDER_NAME = "WorkflowHttpEventProvider"
24
+ STORAGE_ACTOR_NAME = "StorageManagementActor"
25
+ WORKFLOW_OPTIONS = "workflow.io/options"
26
+
27
+
28
+ def asyncio_run(coro):
29
+ return get_or_create_event_loop().run_until_complete(coro)
30
+
31
+
32
+ def validate_user_metadata(metadata):
33
+ if metadata is not None:
34
+ if not isinstance(metadata, dict):
35
+ raise ValueError("metadata must be a dict.")
36
+ try:
37
+ json.dumps(metadata)
38
+ except TypeError as e:
39
+ raise ValueError(
40
+ "metadata must be JSON serializable, instead, "
41
+ "we got 'TypeError: {}'".format(e)
42
+ )
43
+
44
+
45
+ @dataclass
46
+ class WorkflowRef:
47
+ """This class represents a reference of a workflow output.
48
+
49
+ A reference means the workflow has already been executed,
50
+ and we have both the workflow task ID and the object ref to it
51
+ living outputs.
52
+
53
+ This could be used when you want to return a running workflow
54
+ from a workflow task. For example, the remaining workflows
55
+ returned by 'workflow.wait' contains a static ref to these
56
+ pending workflows.
57
+ """
58
+
59
+ # The ID of the task that produces the output of the workflow.
60
+ task_id: TaskID
61
+ # The ObjectRef of the output. If it is "None", then the output has been
62
+ # saved in the storage, and we need to check the workflow management actor
63
+ # for the object ref.
64
+ ref: Optional[ObjectRef] = None
65
+
66
+ @classmethod
67
+ def from_output(cls, task_id: str, output: Any):
68
+ """Create static ref from given output."""
69
+ if not isinstance(output, cls):
70
+ if not isinstance(output, ray.ObjectRef):
71
+ output = ray.put(output)
72
+ output = cls(task_id=task_id, ref=output)
73
+ return output
74
+
75
+ def __hash__(self):
76
+ return hash(self.task_id)
77
+
78
+
79
+ @PublicAPI(stability="alpha")
80
+ @unique
81
+ class WorkflowStatus(str, Enum):
82
+ # No status is set for this workflow.
83
+ NONE = "NONE"
84
+ # There is at least a remote task running in ray cluster
85
+ RUNNING = "RUNNING"
86
+ # It got canceled and can't be resumed later.
87
+ CANCELED = "CANCELED"
88
+ # The workflow runs successfully.
89
+ SUCCESSFUL = "SUCCESSFUL"
90
+ # The workflow failed with an application error.
91
+ # It can be resumed.
92
+ FAILED = "FAILED"
93
+ # The workflow failed with a system error, i.e., ray shutdown.
94
+ # It can be resumed.
95
+ RESUMABLE = "RESUMABLE"
96
+ # The workflow is queued and waited to be executed.
97
+ PENDING = "PENDING"
98
+
99
+ @classmethod
100
+ def non_terminating_status(cls) -> "Tuple[WorkflowStatus, ...]":
101
+ return cls.RUNNING, cls.PENDING
102
+
103
+
104
+ @unique
105
+ class TaskType(str, Enum):
106
+ """All task types."""
107
+
108
+ FUNCTION = "FUNCTION"
109
+ WAIT = "WAIT"
110
+
111
+
112
+ CheckpointModeType = bool
113
+
114
+
115
+ @unique
116
+ class CheckpointMode(Enum):
117
+ """All checkpoint modes."""
118
+
119
+ # Keep the checkpoint of the workflow task.
120
+ SYNC = True
121
+ # Skip the checkpoint of the workflow task.
122
+ SKIP = False
123
+
124
+
125
+ @ray.remote
126
+ def _hash(obj: Any) -> bytes:
127
+ m = hashlib.sha256()
128
+ m.update(cloudpickle.dumps(obj))
129
+ return m.digest()
130
+
131
+
132
+ @ray.remote
133
+ def calculate_identifier(obj: Any) -> str:
134
+ """Calculate a url-safe identifier for an object."""
135
+
136
+ # Task 1: Serialize the object.
137
+ # Task 2: Calculate its sha256 hash.
138
+ # Task 3: Get the url safe, base64 representation of it.
139
+
140
+ # TODO (Alex): Ideally we should use the existing ObjectRef serializer to
141
+ # avoid duplicate serialization passes and support nested object refs.
142
+ m = hashlib.sha256()
143
+ m.update(cloudpickle.dumps(obj))
144
+ hash = m.digest()
145
+ encoded = base64.urlsafe_b64encode(hash).decode("ascii")
146
+ return encoded
147
+
148
+
149
+ @dataclass
150
+ class WorkflowTaskRuntimeOptions:
151
+ """Options that will affect a workflow task at runtime."""
152
+
153
+ # Type of the task.
154
+ task_type: "TaskType"
155
+ # Whether the user want to handle the exception manually.
156
+ catch_exceptions: bool
157
+ # Whether application-level errors should be retried.
158
+ retry_exceptions: bool
159
+ # The num of retry for application exceptions & system failures.
160
+ max_retries: int
161
+ # Checkpoint mode.
162
+ checkpoint: CheckpointModeType
163
+ # ray_remote options
164
+ ray_options: Dict[str, Any]
165
+
166
+ def to_dict(self) -> Dict[str, Any]:
167
+ return {
168
+ "task_type": self.task_type,
169
+ "max_retries": self.max_retries,
170
+ "catch_exceptions": self.catch_exceptions,
171
+ "retry_exceptions": self.retry_exceptions,
172
+ "checkpoint": self.checkpoint,
173
+ "ray_options": self.ray_options,
174
+ }
175
+
176
+ @classmethod
177
+ def from_dict(cls, value: Dict[str, Any]):
178
+ return cls(
179
+ task_type=TaskType[value["task_type"]],
180
+ max_retries=value["max_retries"],
181
+ catch_exceptions=value["catch_exceptions"],
182
+ retry_exceptions=value["retry_exceptions"],
183
+ checkpoint=value["checkpoint"],
184
+ ray_options=value["ray_options"],
185
+ )
186
+
187
+
188
+ @dataclass
189
+ class WorkflowExecutionMetadata:
190
+ """Dataclass for the metadata of the workflow execution."""
191
+
192
+ # True if the workflow task returns a workflow DAG.
193
+ is_output_workflow: bool = False
194
+
195
+
196
+ @dataclass
197
+ class WorkflowMetaData:
198
+ # The current status of the workflow
199
+ status: WorkflowStatus
.venv/lib/python3.11/site-packages/ray/workflow/debug_utils.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utils for debugging purpose."""
2
+ import ray
3
+ from ray.dag import DAGNode, DAGInputData
4
+
5
+ from ray.workflow.common import asyncio_run
6
+ from ray.workflow.workflow_executor import WorkflowExecutor
7
+ from ray.workflow.workflow_context import workflow_task_context, WorkflowTaskContext
8
+ from ray.workflow.workflow_storage import get_workflow_storage
9
+
10
+
11
+ def execute_workflow_local(dag: DAGNode, workflow_id: str, *args, **kwargs):
12
+ """Execute the workflow locally."""
13
+ from ray.workflow.workflow_state_from_dag import workflow_state_from_dag
14
+
15
+ job_id = ray.get_runtime_context().get_job_id()
16
+ context = WorkflowTaskContext(workflow_id=workflow_id)
17
+ with workflow_task_context(context):
18
+ wf_store = get_workflow_storage()
19
+ state = workflow_state_from_dag(
20
+ dag, DAGInputData(*args, **kwargs), workflow_id=workflow_id
21
+ )
22
+ executor = WorkflowExecutor(state)
23
+ fut = executor.get_task_output_async(state.output_task_id)
24
+ asyncio_run(executor.run_until_complete(job_id, context, wf_store))
25
+ return asyncio_run(fut)
26
+
27
+
28
+ def resume_workflow_local(workflow_id: str):
29
+ """Resume the workflow locally."""
30
+ from ray.workflow.workflow_state_from_storage import workflow_state_from_storage
31
+
32
+ job_id = ray.get_runtime_context().get_job_id()
33
+ context = WorkflowTaskContext(workflow_id=workflow_id)
34
+ with workflow_task_context(context):
35
+ wf_store = get_workflow_storage()
36
+ state = workflow_state_from_storage(workflow_id, None)
37
+ executor = WorkflowExecutor(state)
38
+ fut = executor.get_task_output_async(state.output_task_id)
39
+ asyncio_run(executor.run_until_complete(job_id, context, wf_store))
40
+ return asyncio_run(fut)
.venv/lib/python3.11/site-packages/ray/workflow/event_listener.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from ray.util.annotations import PublicAPI
3
+ from ray.workflow.common import Event
4
+ import time
5
+ from typing import Callable
6
+
7
+ EventListenerType = Callable[[], "EventListener"]
8
+
9
+
10
+ @PublicAPI(stability="alpha")
11
+ class EventListener:
12
+ """Defining a custom event listener. Event listeners provide an efficient way
13
+ to listen for a custom event.
14
+
15
+ Event listeners should be stateless. They will be instantiated from a
16
+ coordinator actor.
17
+
18
+ Example definition
19
+ ==================
20
+
21
+ ```
22
+ class CustomEventListener:
23
+
24
+ def __init__(self):
25
+ self.event_provider = ...
26
+
27
+ async def poll_for_event(self, topic, partition):
28
+ return await self.event_provider.poll(topic, partition)
29
+
30
+ async def event_checkpointed(self, event: Event):
31
+ self.event_provider.commit(event.offset)
32
+ ```
33
+
34
+ Example Usage
35
+ =============
36
+ .. testcode::
37
+ :skipif: True
38
+
39
+ from ray import workflow
40
+ CustomEventListener = ...
41
+ event_task = workflow.wait_for_event(
42
+ CustomEventListener, "topic1", "partition2")
43
+ handle_event = ...
44
+ workflow.run(handle_event.task(event_task))
45
+
46
+ """
47
+
48
+ def __init__(self):
49
+ """Optional constructor. Only the constructor with now arguments will be
50
+ called."""
51
+ pass
52
+
53
+ async def poll_for_event(self, *args, **kwargs) -> Event:
54
+ """Should return only when the event is received."""
55
+ raise NotImplementedError
56
+
57
+ async def event_checkpointed(self, event: Event) -> None:
58
+ """Optional. Called after an event has been checkpointed and a transaction can
59
+ be safely committed."""
60
+ pass
61
+
62
+
63
+ @PublicAPI(stability="alpha")
64
+ class TimerListener(EventListener):
65
+ """
66
+ A listener that produces an event at a given timestamp.
67
+ """
68
+
69
+ async def poll_for_event(self, timestamp):
70
+ await asyncio.sleep(timestamp - time.time())
.venv/lib/python3.11/site-packages/ray/workflow/exceptions.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.util.annotations import PublicAPI
2
+ from ray.workflow.common import TaskID
3
+
4
+
5
+ @PublicAPI(stability="alpha")
6
+ class WorkflowError(Exception):
7
+ """Workflow error base class."""
8
+
9
+
10
+ @PublicAPI(stability="alpha")
11
+ class WorkflowExecutionError(WorkflowError):
12
+ def __init__(self, workflow_id: str):
13
+ self.message = f"Workflow[id={workflow_id}] failed during execution."
14
+ super().__init__(self.message)
15
+
16
+
17
+ @PublicAPI(stability="alpha")
18
+ class WorkflowCancellationError(WorkflowError):
19
+ def __init__(self, workflow_id: str):
20
+ self.message = f"Workflow[id={workflow_id}] is cancelled during execution."
21
+ super().__init__(self.message)
22
+
23
+
24
+ @PublicAPI(stability="alpha")
25
+ class WorkflowNotResumableError(WorkflowError):
26
+ """Raise the exception when we cannot resume from a workflow."""
27
+
28
+ def __init__(self, workflow_id: str):
29
+ self.message = f"Workflow[id={workflow_id}] is not resumable."
30
+ super().__init__(self.message)
31
+
32
+
33
+ @PublicAPI(stability="alpha")
34
+ class WorkflowTaskNotRecoverableError(WorkflowNotResumableError):
35
+ """Raise the exception when we find a workflow task cannot be recovered
36
+ using the checkpointed inputs."""
37
+
38
+ def __init__(self, task_id: TaskID):
39
+ self.message = f"Workflow task[id={task_id}] is not recoverable"
40
+ super(WorkflowError, self).__init__(self.message)
41
+
42
+
43
+ @PublicAPI(stability="alpha")
44
+ class WorkflowNotFoundError(WorkflowError):
45
+ def __init__(self, workflow_id: str):
46
+ self.message = f"Workflow[id={workflow_id}] was referenced but doesn't exist."
47
+ super().__init__(self.message)
48
+
49
+
50
+ @PublicAPI(stability="alpha")
51
+ class WorkflowStillActiveError(WorkflowError):
52
+ def __init__(self, operation: str, workflow_id: str):
53
+ self.message = (
54
+ f"{operation} couldn't be completed because "
55
+ f"Workflow[id={workflow_id}] is still running or pending."
56
+ )
57
+ super().__init__(self.message)
.venv/lib/python3.11/site-packages/ray/workflow/http_event_provider.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from typing import Dict
3
+ from fastapi import FastAPI, Request
4
+ from fastapi.responses import JSONResponse
5
+
6
+ import ray
7
+ from ray import serve
8
+ from ray.workflow import common, workflow_context, workflow_access
9
+ from ray.workflow.event_listener import EventListener
10
+ from ray.workflow.common import Event
11
+
12
+
13
+ import logging
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class WorkflowEventHandleError(Exception):
19
+ """Raise when event processing failed"""
20
+
21
+ def __init__(self, workflow_id: str, what_happened: str):
22
+ self.message = (
23
+ f"Workflow[id={workflow_id}] HTTP event handle failed: {what_happened}"
24
+ )
25
+ super().__init__(self.message)
26
+
27
+
28
+ app = FastAPI()
29
+
30
+
31
+ @serve.deployment(num_replicas=1)
32
+ @serve.ingress(app)
33
+ class HTTPEventProvider:
34
+ """HTTPEventProvider is defined to be a Ray Serve deployment with route_prefix='/event',
35
+ which will receive external events via an HTTP endpoint. It supports FastAPI,
36
+ e.g. post. It responds to both poll_for_event() and event_checkpointed() from
37
+ an HTTPListener instance.
38
+
39
+ HTTPListener class is designed to work with the current workflow.wait_for_event()
40
+ implementation, where an HTTPListener instance will be initiated by the
41
+ get_message() and message_committed() of the workflow.wait_for_event().
42
+
43
+ HTTPEventProvider requires an event to arrive after HTTPListner registers
44
+ its event_key. If an event arrived before the registration, it returns HTTP
45
+ error code 404 with the error "workflow_id and event_key need to be registered
46
+ to receive event. Please make sure they are registered before resending."
47
+
48
+ Example definition
49
+ ==================
50
+
51
+ ```
52
+ class HTTPEventProvider:
53
+
54
+ def __init__(self):
55
+
56
+ @app.post("/send_event/{workflow_id}")
57
+ async def send_event(self, workflow_id: str, req: Request):
58
+ Receive an external event message and acknowledge if it was processed
59
+ by the workflow
60
+ async def get_event_payload(self, workflow_id, event_key):
61
+ Internal method used by HTTPListner to subscribe to an event matched by
62
+ workflow_id and event_key
63
+ async def report_checkpointed(self, workflow_id, event, confirmation):
64
+ Internal method used by HTTPListner to confirm the received event has been
65
+ checkpointed by workflow
66
+ ```
67
+
68
+ Example Usage
69
+ =============
70
+ .. testcode::
71
+ :skipif: True
72
+
73
+ from ray.workflow.http_event_provider import HTTPEventProvider, HTTPListener
74
+ ray.init(address='auto', namespace='serve')
75
+ serve.start(detached=True)
76
+ event_node = workflow.wait_for_event(
77
+ HTTPListener, event_key='')
78
+ handle_event = ...
79
+ workflow.run_aync(handle_event.bind(event_node))
80
+
81
+ On a separate python process, it sends an event to the HTTPEventProvider:
82
+
83
+ .. testcode::
84
+ :skipif: True
85
+
86
+ import requests
87
+ resp = requests.post('http://127.0.0.1:8000/event/send_event/{workflow_id}',
88
+ json={'event_key':'my_key','event_payload':'testMessage'})
89
+
90
+ """
91
+
92
+ def __init__(self):
93
+ """Maintain two data structures to track pending events and confirmations
94
+ event_key_payload: for each registered workflow_id and event_key,
95
+ keep the Future to be set after an event is received.
96
+ event_checkpoint_pending: for each received event_key, keep its Future
97
+ after checkpointing is confirmed so HTTP 200 can be returned.
98
+ """
99
+ self.event_key_payload: Dict[str, Dict[str, asyncio.Future]] = {}
100
+ self.event_checkpoint_pending: Dict[str, asyncio.Future] = {}
101
+
102
    @app.post("/send_event/{workflow_id}")
    async def send_event(self, workflow_id: str, req: Request) -> JSONResponse:
        """Receive an external event message and acknowledge if it was processed
        by the workflow.

        Args:
            workflow_id: the workflow that this event is submitted for
            req: the JSON formatted request that contains two string fields:
                'event_key' and 'event_payload'.
                'event_key' uniquely identifies a node in the receiving workflow;
                'event_payload' refers to the event's content

        Example:
            JSON formatted request {"event_key":"node_event","event_payload":"approved"}

        Returns:
            if the event was received and processed, HTTP response status 200
            if the event was not expected or the workflow_id did not exist, HTTP
            response status 404
            if the event was received but failed at checkpointing, HTTP response 500
        """
        req_json = await req.json()
        try:
            event_key = req_json["event_key"]
            event_payload = req_json["event_payload"]
        except KeyError as e:
            # Malformed request: one of the required JSON fields is missing.
            return JSONResponse(
                status_code=404,
                content={
                    "error": {
                        "code": 404,
                        "message": f"{e} field is not found in the request JSON",
                    }
                },
            )
        try:
            # Wake the subscriber registered via get_event_payload() for this
            # (workflow_id, event_key) pair.
            self.event_key_payload[workflow_id][event_key].set_result(
                (event_key, event_payload)
            )
        except KeyError:
            # No subscriber registered for this workflow_id/event_key.
            return JSONResponse(
                status_code=404,
                content={
                    "error": {
                        "code": 404,
                        "message": "workflow_id and event_key need to be registered "
                        "to receive event. Please make sure they are "
                        "registered before resending.",
                    }
                },
            )

        # Hold the HTTP response open until report_checkpointed() resolves
        # this future with the checkpoint outcome.
        self.event_checkpoint_pending[event_key] = asyncio.Future()
        confirmed = await self.event_checkpoint_pending[event_key]
        self.event_checkpoint_pending.pop(event_key)
        if confirmed:
            return JSONResponse(status_code=200, content={})
        return JSONResponse(
            status_code=500,
            content={"error": {"code": 500, "message": "event processing failed"}},
        )
161
+
162
+ async def get_event_payload(self, workflow_id: str, event_key: str) -> Event:
163
+ """Internal method used by HTTPListener to subscribe to an event matched
164
+ by workflow_id and event_key"""
165
+ if workflow_id not in self.event_key_payload:
166
+ self.event_key_payload[workflow_id] = {}
167
+
168
+ if event_key in self.event_key_payload[workflow_id]:
169
+ raise WorkflowEventHandleError(
170
+ workflow_id, f"The same {event_key} is used to get payload again."
171
+ )
172
+
173
+ self.event_key_payload[workflow_id][event_key] = asyncio.Future()
174
+ return await self.event_key_payload[workflow_id][event_key]
175
+
176
    async def report_checkpointed(
        self, workflow_id: str, event_key: str, confirmation: bool
    ) -> str:
        """Internal method used by HTTPListener to confirm the received event has
        been checkpointed by workflow.

        Args:
            workflow_id: the workflow the event belongs to (used for error
                reporting only).
            event_key: the key whose HTTP response is pending in send_event().
            confirmation: checkpoint outcome forwarded to the blocked
                send_event() call (True -> HTTP 200, False -> HTTP 500).

        Raises:
            WorkflowEventHandleError: if no response is pending for event_key,
                e.g. because the event provider was restarted.
        """
        try:
            self.event_checkpoint_pending[event_key].set_result(confirmation)
        except KeyError:
            logger.error(
                f"{event_key} cannot be found to acknowledge request. "
                f"The event provider may have been restarted."
            )
            raise WorkflowEventHandleError(
                workflow_id, f"{event_key} cannot be found to acknowledge request."
            )
        return "OK"
192
+
193
+
194
class HTTPListener(EventListener):
    """HTTPListener is defined to work with the HTTPEventProvider. It implements two
    APIs, poll_for_event() and event_checkpointed(). An instance of HTTPListener will
    be started by the get_message() of the workflow.wait_for_event() to listen for
    an event from the HTTPEventProvider instance (a Ray Serve deployment). Another
    instance of HTTPListener will be started by the message_committed() of the
    workflow.wait_for_event() to confirm that the event has been checkpointed.


    Example definition
    ==================

    ```
    class HTTPListener:

        def __init__(self):

        async def poll_for_event(self, event_key) -> Event:

        async def event_checkpointed(self, event) -> None:

    ```

    Example Usage
    =============

    .. testcode::

        import tempfile
        from ray import workflow
        from ray.workflow.http_event_provider import HTTPListener

        temp_dir = tempfile.TemporaryDirectory()
        ray.init(storage=f"file://{temp_dir.name}")

        serve.start(detached=True)
        event_node = workflow.wait_for_event(HTTPListener, event_key='')

        @ray.remote
        def handle_event(arg):
            return arg

        workflow.run_async(handle_event.bind(event_node), workflow_id="http_listener")
    """

    def __init__(self):
        super().__init__()
        # Obtain a handle to the shared HTTPEventProvider Serve app; if it is
        # not deployed yet, ask the workflow management actor to create it.
        try:
            self.handle = ray.serve.get_app_handle(common.HTTP_EVENT_PROVIDER_NAME)
        except ray.serve.exceptions.RayServeException:
            mgr = workflow_access.get_management_actor()
            ray.get(mgr.create_http_event_provider.remote())
            self.handle = ray.serve.get_app_handle(common.HTTP_EVENT_PROVIDER_NAME)

    async def poll_for_event(self, event_key: str = None) -> Event:
        """workflow.wait_for_event calls this method to subscribe to the
        HTTPEventProvider and return the received external event.

        Args:
            event_key: a unique identifier to the receiving node in a workflow;
                if missing, default to current workflow task id

        Returns:
            tuple(event_key, event_payload)
        """
        workflow_id = workflow_context.get_current_workflow_id()
        if event_key is None:
            event_key = workflow_context.get_current_task_id()

        # Blocks until the provider receives a matching HTTP event.
        event_key_payload = await self.handle.get_event_payload.remote(
            workflow_id, event_key
        )
        return event_key_payload

    async def event_checkpointed(self, event: Event) -> None:
        """workflow.wait_for_event calls this method after the event has
        been checkpointed and a transaction can be safely committed."""
        (event_key, _) = event
        await self.handle.report_checkpointed.remote(
            workflow_context.get_current_workflow_id(), event_key, True
        )
.venv/lib/python3.11/site-packages/ray/workflow/serialization.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ from dataclasses import dataclass
3
+ import logging
4
+ import os
5
+
6
+ import ray
7
+ from ray import cloudpickle
8
+ from ray.types import ObjectRef
9
+ from ray.workflow import common, workflow_storage
10
+ from typing import Any, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING
11
+
12
+ from collections import ChainMap
13
+ import io
14
+
15
+ if TYPE_CHECKING:
16
+ from ray.actor import ActorHandle
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
def init_manager() -> None:
    """Eagerly create the serialization manager actor, without the
    'manager exited unexpectedly' warning that lazy creation logs."""
    get_or_create_manager(warn_on_creation=False)
23
+
24
+
25
def get_or_create_manager(warn_on_creation: bool = True) -> "ActorHandle":
    """Get or create the storage manager.

    Args:
        warn_on_creation: If True, log a warning when the named actor is not
            found and a new one has to be created.

    Returns:
        A handle to the named, detached ``Manager`` actor.
    """
    # TODO(suquark): We should not get the actor everytime. We also need to
    # resume the actor if it failed. Using a global variable to cache the
    # actor seems not enough to resume the actor, because there is no
    # aliveness detection for an actor.
    try:
        return ray.get_actor(
            common.STORAGE_ACTOR_NAME, namespace=common.MANAGEMENT_ACTOR_NAMESPACE
        )
    except ValueError:
        # Named actor does not exist yet: create it as a detached actor.
        if warn_on_creation:
            logger.warning(
                "Cannot access workflow serialization manager. It "
                "could be because "
                "the workflow manager exited unexpectedly. A new "
                "workflow manager is being created. "
            )
        handle = Manager.options(
            name=common.STORAGE_ACTOR_NAME,
            namespace=common.MANAGEMENT_ACTOR_NAMESPACE,
            lifetime="detached",
        ).remote()
        # Block until the actor is alive before handing out the handle.
        ray.get(handle.ping.remote())
        return handle
50
+
51
+
52
@dataclass
class Upload:
    """Record of one deduplicated object-ref upload tracked by ``Manager``."""

    # Ref to the task computing the object's stable identifier.
    identifier_ref: ObjectRef[str]
    # Ref to the task performing the actual upload (see _put_helper).
    upload_task: ObjectRef[None]
56
+
57
+
58
@ray.remote(num_cpus=0)
class Manager:
    """
    Responsible for deduping the serialization/upload of object references.
    """

    def __init__(self):
        # (ref hex, workflow_id) -> Upload record for in-flight/finished uploads.
        self._uploads: Dict[ray.ObjectRef, Upload] = {}
        self._num_uploads = 0

    def ping(self) -> None:
        """
        Trivial function to ensure actor creation is successful.
        """
        return None

    async def save_objectref(
        self, ref_tuple: Tuple[ray.ObjectRef], workflow_id: "str"
    ) -> Tuple[List[str], ray.ObjectRef]:
        """Serialize and upload an object reference exactly once.

        Args:
            ref_tuple: A 1-element tuple which wraps the reference.

        Returns:
            A pair. The first element is the paths the ref will be uploaded to.
            The second is an object reference to the upload task.
        """
        (ref,) = ref_tuple
        # Use the hex as the key to avoid holding a reference to the object.
        dedup_key = (ref.hex(), workflow_id)

        if dedup_key not in self._uploads:
            # First sight of this ref: kick off identifier + upload tasks.
            # TODO(Alex): We should probably eventually free these refs.
            identifier_ref = common.calculate_identifier.remote(ref)
            upload_task = _put_helper.remote(identifier_ref, ref, workflow_id)
            self._uploads[dedup_key] = Upload(
                identifier_ref=identifier_ref, upload_task=upload_task
            )
            self._num_uploads += 1

        upload = self._uploads[dedup_key]
        identifier = await upload.identifier_ref
        return _obj_id_to_key(identifier), upload.upload_task

    async def export_stats(self) -> Dict[str, Any]:
        return {"num_uploads": self._num_uploads}
106
+
107
+
108
# Relative key prefix under which serialized objects are stored.
OBJECTS_DIR = "objects"


def _obj_id_to_key(object_id: str) -> str:
    """Map an object identifier to its storage key under OBJECTS_DIR."""
    return os.path.join(OBJECTS_DIR, object_id)
113
+
114
+
115
@ray.remote(num_cpus=0)
def _put_helper(identifier: str, obj: Any, workflow_id: str) -> None:
    """Upload task: dump one resolved object to workflow storage.

    ``obj`` arrives resolved (Ray dereferences the ObjectRef argument), so a
    remaining ObjectRef here means a nested reference, which is unsupported.
    """
    # TODO (Alex): This check isn't sufficient, it only works for directly
    # nested object refs.
    if isinstance(obj, ray.ObjectRef):
        raise NotImplementedError(
            "Workflow does not support checkpointing nested object references yet."
        )
    key = _obj_id_to_key(identifier)

    # update_existing=False: identical content was already uploaded once.
    dump_to_storage(
        key,
        obj,
        workflow_id,
        workflow_storage.WorkflowStorage(workflow_id),
        update_existing=False,
    )
132
+
133
+
134
def _reduce_objectref(
    workflow_id: str,
    obj_ref: ObjectRef,
    tasks: List[ObjectRef],
):
    """Pickle reducer for ObjectRef used by dump_to_storage's custom pickler.

    Registers the ref for a deduplicated upload via the manager actor,
    records the upload task in ``tasks`` (so the caller can wait on it), and
    reduces the ref to a (_load_object_ref, (key, workflow_id)) callable pair.
    """
    manager = get_or_create_manager()
    key, task = ray.get(manager.save_objectref.remote((obj_ref,), workflow_id))

    assert task
    tasks.append(task)

    return _load_object_ref, (key, workflow_id)
146
+
147
+
148
def dump_to_storage(
    key: str,
    obj: Any,
    workflow_id: str,
    storage: "workflow_storage.WorkflowStorage",
    update_existing=True,
) -> None:
    """Serializes and puts arbitrary object, handling references. The object will
    be uploaded at `paths`. Any object references will be uploaded to their
    global, remote storage.

    Args:
        key: The key of the object.
        obj: The object to serialize. If it contains object references, those
            will be serialized too.
        workflow_id: The workflow id.
        storage: The storage to use. If obj contains object references,
            `storage.put` will be called on them individually.
        update_existing: If False, the object will not be uploaded if the path
            exists.
    """
    if not update_existing:
        if storage._exists(key):
            return

    # Upload tasks for nested object refs; _reduce_objectref appends to this
    # list while `pickler.dump` runs below.
    tasks = []

    # NOTE: Cloudpickle doesn't support private dispatch tables, so we extend
    # the cloudpickler instead to avoid changing cloudpickle's global dispatch
    # table which is shared with `ray.put`. See
    # https://github.com/cloudpipe/cloudpickle/issues/437
    class ObjectRefPickler(cloudpickle.CloudPickler):
        _object_ref_reducer = {
            ray.ObjectRef: lambda ref: _reduce_objectref(workflow_id, ref, tasks)
        }
        dispatch_table = ChainMap(
            _object_ref_reducer, cloudpickle.CloudPickler.dispatch_table
        )
        dispatch = dispatch_table

    # TODO(Alex): We should be able to do this without the extra buffer.
    with io.BytesIO() as f:
        pickler = ObjectRefPickler(f)
        pickler.dump(obj)
        f.seek(0)
        # use the underlying storage to avoid cyclic calls of "dump_to_storage"
        storage._storage.put(key, f.read())

    # BUGFIX: wait for the nested-ref uploads *after* pickling. Previously
    # `ray.get(tasks)` ran before `pickler.dump`, when `tasks` was still
    # empty, so the upload tasks were never awaited.
    ray.get(tasks)
197
+
198
+
199
@ray.remote
def _load_ref_helper(key: str, workflow_id: str):
    """Load a previously checkpointed object from workflow storage by key."""
    # TODO(Alex): We should stream the data directly into `cloudpickle.load`.
    storage = workflow_storage.WorkflowStorage(workflow_id)
    return storage._get(key)
204
+
205
+
206
# TODO (Alex): We should use weakrefs here instead requiring a context manager.
# None means "no caching context active" (see objectref_cache()).
_object_cache: Optional[Dict[str, ray.ObjectRef]] = None


def _load_object_ref(key: str, workflow_id: str) -> ray.ObjectRef:
    """Return a ref that loads the checkpointed object stored at ``key``.

    When an ``objectref_cache()`` context is active, repeated loads of the
    same key share one load task; otherwise each call spawns a fresh one.
    """
    global _object_cache
    if _object_cache is None:
        # BUGFIX: the original body repeated this identical guard twice;
        # the duplicate was unreachable dead code and has been removed.
        return _load_ref_helper.remote(key, workflow_id)

    if key not in _object_cache:
        _object_cache[key] = _load_ref_helper.remote(key, workflow_id)

    return _object_cache[key]
222
+
223
+
224
@contextlib.contextmanager
def objectref_cache() -> Generator:
    """A reentrant caching context for object refs.

    Only the outermost entry installs (and later clears) the module-level
    cache; nested entries reuse it, making the context safely reentrant.
    """
    global _object_cache
    # True only for the outermost context; inner ones must not clear the cache.
    clear_cache = _object_cache is None
    if clear_cache:
        _object_cache = {}
    try:
        yield
    finally:
        if clear_cache:
            _object_cache = None
.venv/lib/python3.11/site-packages/ray/workflow/serialization_context.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ from typing import List, Any, Dict
3
+
4
+ from ray.util.serialization import register_serializer, deregister_serializer
5
+ from ray.workflow.common import WorkflowRef
6
+
7
+
8
def _resolve_workflow_refs(index: int) -> Any:
    """Placeholder resolver, rebound by the context managers below.

    Outside any resolving/keeping context, deserializing a WorkflowRef
    placeholder is an error.
    """
    raise ValueError("There is no context for resolving workflow refs.")
10
+
11
+
12
+ @contextlib.contextmanager
13
+ def workflow_args_serialization_context(workflow_refs: List[WorkflowRef]) -> None:
14
+ """
15
+ This serialization context reduces workflow input arguments to three
16
+ parts:
17
+
18
+ 1. A workflow input placeholder. It is an object without 'Workflow' and
19
+ 'ObjectRef' object. They are replaced with integer indices. During
20
+ deserialization, we can refill the placeholder with a list of
21
+ 'Workflow' and a list of 'ObjectRef'. This provides us great
22
+ flexibility, for example, during recovery we can plug an alternative
23
+ list of 'Workflow' and 'ObjectRef', since we lose the original ones.
24
+ 2. A list of 'Workflow'. There is no duplication in it.
25
+ 3. A list of 'ObjectRef'. There is no duplication in it.
26
+
27
+ We do not allow duplication because in the arguments duplicated workflows
28
+ and object refs are shared by reference. So when deserialized, we also
29
+ want them to be shared by reference. See
30
+ "tests/test_object_deref.py:deref_shared" as an example.
31
+
32
+ The deduplication works like this:
33
+ Inputs: [A B A B C C A]
34
+ Output List: [A B C]
35
+ Index in placeholder: [0 1 0 1 2 2 0]
36
+
37
+ Args:
38
+ workflow_refs: Output list of workflows or references to workflows.
39
+ """
40
+ deduplicator: Dict[WorkflowRef, int] = {}
41
+
42
+ def serializer(w):
43
+ if w in deduplicator:
44
+ return deduplicator[w]
45
+ if isinstance(w, WorkflowRef):
46
+ # The ref should be resolved by the workflow management actor
47
+ # when treated as the input of a workflow, so we remove the ref here.
48
+ w.ref = None
49
+ i = len(workflow_refs)
50
+ workflow_refs.append(w)
51
+ deduplicator[w] = i
52
+ return i
53
+
54
+ register_serializer(
55
+ WorkflowRef,
56
+ serializer=serializer,
57
+ deserializer=_resolve_workflow_refs,
58
+ )
59
+
60
+ try:
61
+ yield
62
+ finally:
63
+ # we do not want to serialize Workflow objects in other places.
64
+ deregister_serializer(WorkflowRef)
65
+
66
+
67
@contextlib.contextmanager
def workflow_args_resolving_context(workflow_ref_mapping: List[Any]) -> None:
    """
    This context resolves workflows and object refs inside workflow
    arguments into correct values.

    Args:
        workflow_ref_mapping: List of workflow refs.
    """
    global _resolve_workflow_refs
    # Temporarily rebind the module-level resolver so that placeholder index
    # i deserializes to workflow_ref_mapping[i]; restore it on exit.
    _resolve_workflow_refs_bak = _resolve_workflow_refs
    _resolve_workflow_refs = workflow_ref_mapping.__getitem__

    try:
        yield
    finally:
        _resolve_workflow_refs = _resolve_workflow_refs_bak
+
85
+
86
+ class _KeepWorkflowRefs:
87
+ def __init__(self, index: int):
88
+ self._index = index
89
+
90
+ def __reduce__(self):
91
+ return _resolve_workflow_refs, (self._index,)
92
+
93
+
94
+ @contextlib.contextmanager
95
+ def workflow_args_keeping_context() -> None:
96
+ """
97
+ This context only read workflow arguments. Workflows inside
98
+ are untouched and can be serialized again properly.
99
+ """
100
+ global _resolve_workflow_refs
101
+ _resolve_workflow_refs_bak = _resolve_workflow_refs
102
+
103
+ # we must capture the old functions to prevent self-referencing.
104
+ def _keep_workflow_refs(index: int):
105
+ return _KeepWorkflowRefs(index)
106
+
107
+ _resolve_workflow_refs = _keep_workflow_refs
108
+
109
+ try:
110
+ yield
111
+ finally:
112
+ _resolve_workflow_refs = _resolve_workflow_refs_bak
.venv/lib/python3.11/site-packages/ray/workflow/task_executor.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from dataclasses import dataclass
3
+ import logging
4
+ from typing import List, Tuple, Any, Dict, Callable, TYPE_CHECKING
5
+ import ray
6
+ from ray import ObjectRef
7
+ from ray._private import signature
8
+
9
+ from ray.dag import DAGNode
10
+ from ray.workflow import workflow_context
11
+ from ray.workflow.workflow_context import get_task_status_info
12
+ from ray.workflow import serialization_context
13
+ from ray.workflow import workflow_storage
14
+
15
+ from ray.workflow.common import (
16
+ WorkflowStatus,
17
+ WorkflowExecutionMetadata,
18
+ TaskType,
19
+ TaskID,
20
+ WorkflowRef,
21
+ CheckpointMode,
22
+ )
23
+ from ray.workflow.workflow_state import WorkflowExecutionState
24
+ from ray.workflow.workflow_state_from_dag import workflow_state_from_dag
25
+
26
+ if TYPE_CHECKING:
27
+ from ray.workflow.common import (
28
+ WorkflowTaskRuntimeOptions,
29
+ )
30
+ from ray.workflow.workflow_context import WorkflowTaskContext
31
+
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
def get_task_executor(task_options: "WorkflowTaskRuntimeOptions"):
    """Build the remote executor for a workflow task of the given options.

    Only function-type tasks are supported; retries are disabled so that the
    workflow layer stays in control of failure handling.
    """
    if task_options.task_type != TaskType.FUNCTION:
        raise ValueError(f"Invalid task type {task_options.task_type}")
    ray_options = task_options.ray_options
    # prevent automatic lineage reconstruction
    ray_options["max_retries"] = 0
    # prevent retrying exception by Ray
    ray_options["retry_exceptions"] = False
    return _workflow_task_executor_remote.options(**ray_options).remote
48
+
49
+
50
def _workflow_task_executor(
    func: Callable,
    context: "WorkflowTaskContext",
    task_id: "TaskID",
    baked_inputs: "_BakedWorkflowInputs",
    runtime_options: "WorkflowTaskRuntimeOptions",
) -> Tuple[Any, Any]:
    """Executor function for workflow task.

    Args:
        task_id: ID of the task.
        func: The workflow task function.
        baked_inputs: The processed inputs for the task.
        context: Workflow task context. Used to access correct storage etc.
        runtime_options: Parameters for workflow task execution.

    Returns:
        Workflow task output.
    """
    with workflow_context.workflow_task_context(context):
        store = workflow_storage.get_workflow_storage()
        # Part 1: resolve inputs
        args, kwargs = baked_inputs.resolve(store)

        # Part 2: execute the task
        try:
            store.save_task_prerun_metadata(task_id, {"start_time": time.time()})
            with workflow_context.workflow_execution():
                logger.info(f"{get_task_status_info(WorkflowStatus.RUNNING)}")
                output = func(*args, **kwargs)
            store.save_task_postrun_metadata(task_id, {"end_time": time.time()})
        except Exception as e:
            # Always checkpoint the exception.
            store.save_task_output(task_id, None, exception=e)
            raise e

        if isinstance(output, DAGNode):
            # The task returned a sub-DAG: convert it to an execution state
            # so the caller continues the workflow from it.
            output = workflow_state_from_dag(output, None, context.workflow_id)
            execution_metadata = WorkflowExecutionMetadata(is_output_workflow=True)
        else:
            execution_metadata = WorkflowExecutionMetadata()
            if runtime_options.catch_exceptions:
                # catch_exceptions contract: wrap success as (result, None).
                output = (output, None)

        # Part 3: save outputs
        # TODO(suquark): Validate checkpoint options before commit the task.
        if CheckpointMode(runtime_options.checkpoint) == CheckpointMode.SYNC:
            if isinstance(output, WorkflowExecutionState):
                store.save_workflow_execution_state(task_id, output)
            else:
                store.save_task_output(task_id, output, exception=None)
        return execution_metadata, output
102
+
103
+
104
@ray.remote(num_returns=2)
def _workflow_task_executor_remote(
    func: Callable,
    context: "WorkflowTaskContext",
    job_id: str,
    task_id: "TaskID",
    baked_inputs: "_BakedWorkflowInputs",
    runtime_options: "WorkflowTaskRuntimeOptions",
) -> Any:
    """The remote version of '_workflow_task_executor'."""
    # Route this task's logs to the submitting job before delegating.
    with workflow_context.workflow_logging_context(job_id):
        return _workflow_task_executor(
            func, context, task_id, baked_inputs, runtime_options
        )
118
+
119
+
120
@dataclass
class _BakedWorkflowInputs:
    """This class stores pre-processed inputs for workflow task execution.
    Especially, all input workflows to the workflow task will be scheduled,
    and their outputs (ObjectRefs) replace the original workflows."""

    # Flattened, serialized argument list (an ObjectRef to it).
    args: "ObjectRef"
    # References to upstream workflow outputs feeding this task.
    workflow_refs: "List[WorkflowRef]"

    def resolve(self, store: workflow_storage.WorkflowStorage) -> Tuple[List, Dict]:
        """
        This function resolves the inputs for the code inside
        a workflow task (works on the callee side). For outputs from other
        workflows, we resolve them into object instances inplace.

        For each ObjectRef argument, the function returns both the ObjectRef
        and the object instance. If the ObjectRef is a chain of nested
        ObjectRefs, then we resolve it recursively until we get the
        object instance, and we return the *direct* ObjectRef of the
        instance. This function does not resolve ObjectRef
        inside another object (e.g. list of ObjectRefs) to give users some
        flexibility.

        Returns:
            Instances of arguments.
        """
        workflow_ref_mapping = []
        for r in self.workflow_refs:
            if r.ref is None:
                # Ref was stripped during serialization: reload the upstream
                # task's output from checkpointed storage.
                workflow_ref_mapping.append(store.load_task_output(r.task_id))
            else:
                workflow_ref_mapping.append(r.ref)

        with serialization_context.workflow_args_resolving_context(
            workflow_ref_mapping
        ):
            # reconstruct input arguments under correct serialization context
            flattened_args: List[Any] = ray.get(self.args)

        # dereference arguments like Ray remote functions
        flattened_args = [
            ray.get(a) if isinstance(a, ObjectRef) else a for a in flattened_args
        ]
        return signature.recover_args(flattened_args)
.venv/lib/python3.11/site-packages/ray/workflow/workflow_access.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ import queue
4
+ from typing import Dict, List, Set, Optional, TYPE_CHECKING
5
+
6
+ import ray
7
+
8
+ from ray.workflow import common
9
+ from ray.workflow.common import WorkflowStatus, TaskID
10
+ from ray.workflow import workflow_state_from_storage
11
+ from ray.workflow import workflow_context
12
+ from ray.workflow import workflow_storage
13
+ from ray.workflow.exceptions import (
14
+ WorkflowCancellationError,
15
+ WorkflowNotFoundError,
16
+ WorkflowNotResumableError,
17
+ WorkflowStillActiveError,
18
+ )
19
+ from ray.workflow.workflow_executor import WorkflowExecutor
20
+ from ray.workflow.workflow_state import WorkflowExecutionState
21
+ from ray.workflow.workflow_context import WorkflowTaskContext
22
+
23
+ if TYPE_CHECKING:
24
+ from ray.actor import ActorHandle
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
class SelfResolvingObject:
    """Wrapper whose pickled form resolves the wrapped value.

    __reduce__ returns (ray.get, (x,)), so deserializing this wrapper yields
    the underlying object instead of the wrapper itself.
    """

    def __init__(self, x):
        self.x = x

    def __reduce__(self):
        return ray.get, (self.x,)
35
+
36
+
37
@ray.remote(num_cpus=0)
def load_task_output_from_storage(workflow_id: str, task_id: Optional[TaskID]):
    """Load a checkpointed task output from workflow storage.

    Args:
        workflow_id: The workflow to read from.
        task_id: The task whose output to load; None means the workflow's
            output task.

    Raises:
        ValueError: If no checkpointed output exists for the task/workflow.
    """
    wf_store = workflow_storage.WorkflowStorage(workflow_id)
    # inspect_output maps the requested task to the task whose output is
    # actually available in storage (or None if there is none).
    tid = wf_store.inspect_output(task_id)
    if tid is not None:
        return wf_store.load_task_output(tid)
    # TODO(suquark): Unify the error from "workflow.get_output" & "workflow.run_async".
    # Currently they could be different, because "workflow.get_output" could
    # get the output from a stopped workflow, it does not may sense to raise
    # "WorkflowExecutionError" as the workflow is not running.
    if task_id is not None:
        raise ValueError(
            f"Cannot load output from task id '{task_id}' in workflow '{workflow_id}'"
        )
    else:
        raise ValueError(f"Cannot load output from workflow '{workflow_id}'")
53
+
54
+
55
@ray.remote(num_cpus=0)
def resume_workflow_task(
    job_id: str,
    workflow_id: str,
    task_id: Optional[TaskID] = None,
) -> WorkflowExecutionState:
    """Resume a task of a workflow.

    Args:
        job_id: The ID of the job that submits the workflow execution. The ID
            is used to identify the submitter of the workflow.
        workflow_id: The ID of the workflow job. The ID is used to identify
            the workflow.
        task_id: The task to resume in the workflow.

    Raises:
        WorkflowNotResumableError: fail to resume the workflow.

    Returns:
        The execution result of the workflow, represented by Ray ObjectRef.
    """
    with workflow_context.workflow_logging_context(job_id):
        try:
            return workflow_state_from_storage.workflow_state_from_storage(
                workflow_id, task_id
            )
        except Exception as e:
            # Any storage/reconstruction failure is surfaced uniformly.
            raise WorkflowNotResumableError(workflow_id) from e
83
+
84
+
85
+ # TODO(suquark): we may use an actor pool in the future if too much
86
+ # concurrent workflow access blocks the actor.
87
+ @ray.remote(num_cpus=0)
88
+ class WorkflowManagementActor:
89
+ """Keep the ownership and manage the workflow output."""
90
+
91
    def __init__(self, max_running_workflows: int, max_pending_workflows: int):
        """Initialize workflow bookkeeping.

        Args:
            max_running_workflows: Max number of concurrently running
                workflows; -1 means unlimited.
            max_pending_workflows: Max number of queued (pending) workflows;
                -1 means unlimited.
        """
        # workflow_id -> executor for workflows currently submitted/running.
        self._workflow_executors: Dict[str, WorkflowExecutor] = {}

        self._max_running_workflows: int = max_running_workflows
        self._max_pending_workflows: int = max_pending_workflows

        # 0 means infinite for queue
        self._workflow_queue = queue.Queue(
            max_pending_workflows if max_pending_workflows != -1 else 0
        )

        self._running_workflows: Set[str] = set()
        # workflow_id -> future resolved when the queued workflow may start.
        self._queued_workflows: Dict[str, asyncio.Future] = {}
        # TODO(suquark): We do not cleanup "_executed_workflows" because we need to
        # know if users are running the same workflow again long after a workflow
        # completes. One possible alternative solution is to check the workflow
        # status in the storage.
        self._executed_workflows: Set[str] = set()
109
+
110
+ def validate_init_options(
111
+ self, max_running_workflows: Optional[int], max_pending_workflows: Optional[int]
112
+ ):
113
+ if (
114
+ max_running_workflows is not None
115
+ and max_running_workflows != self._max_running_workflows
116
+ ) or (
117
+ max_pending_workflows is not None
118
+ and max_pending_workflows != self._max_pending_workflows
119
+ ):
120
+ raise ValueError(
121
+ "The workflow init is called again but the init options"
122
+ "does not match the original ones. Original options: "
123
+ f"max_running_workflows={self._max_running_workflows} "
124
+ f"max_pending_workflows={self._max_pending_workflows}; "
125
+ f"New options: max_running_workflows={max_running_workflows} "
126
+ f"max_pending_workflows={max_pending_workflows}."
127
+ )
128
+
129
    def gen_task_id(self, workflow_id: str, task_name: str) -> str:
        """Generate a unique task ID for a task name within a workflow.

        The first occurrence of a name maps to the bare name; subsequent
        occurrences get a numeric suffix from the storage-backed counter.
        """
        wf_store = workflow_storage.WorkflowStorage(workflow_id)
        idx = wf_store.gen_task_id(task_name)
        if idx == 0:
            return task_name
        else:
            return f"{task_name}_{idx}"
136
+
137
    def submit_workflow(
        self,
        workflow_id: str,
        state: WorkflowExecutionState,
        ignore_existing: bool = False,
    ):
        """Submit workflow. A submitted workflow can be executed later.

        Args:
            workflow_id: ID of the workflow.
            state: The initial state of the workflow.
            ignore_existing: Ignore existing executed workflows.

        Raises:
            RuntimeError: If the workflow is already executing, or has
                executed and ignore_existing is False.
            ValueError: If the state has no output task.
            queue.Full: If the pending-workflow queue is at capacity.
        """
        if workflow_id in self._workflow_executors:
            raise RuntimeError(f"Workflow[id={workflow_id}] is being executed.")
        if workflow_id in self._executed_workflows and not ignore_existing:
            raise RuntimeError(f"Workflow[id={workflow_id}] has been executed.")

        if state.output_task_id is None:
            raise ValueError(
                "No root DAG specified that generates output for the workflow."
            )

        wf_store = workflow_storage.WorkflowStorage(workflow_id)
        if (
            self._max_running_workflows != -1
            and len(self._running_workflows) >= self._max_running_workflows
        ):
            # At running capacity: queue the workflow as PENDING; its future
            # is resolved by execute_workflow() when a slot frees up.
            try:
                self._workflow_queue.put_nowait(workflow_id)
                self._queued_workflows[workflow_id] = asyncio.Future()
                wf_store.update_workflow_status(WorkflowStatus.PENDING)
            except queue.Full:
                # override with our error message
                raise queue.Full("Workflow queue has been full") from None
        else:
            self._running_workflows.add(workflow_id)
            wf_store.update_workflow_status(WorkflowStatus.RUNNING)
        # initialize executor
        self._workflow_executors[workflow_id] = WorkflowExecutor(state)
177
+
178
    async def reconstruct_workflow(
        self, job_id: str, context: WorkflowTaskContext
    ) -> None:
        """Reconstruct a (failed) workflow and submit it."""
        # Rebuild the execution state from checkpointed storage, then
        # resubmit it (ignoring that it may have executed before).
        state = await resume_workflow_task.remote(job_id, context.workflow_id)
        self.submit_workflow(context.workflow_id, state, ignore_existing=True)
184
+
185
    async def execute_workflow(
        self,
        job_id: str,
        context: WorkflowTaskContext,
    ) -> ray.ObjectRef:
        """Execute a submitted workflow.

        Args:
            job_id: The ID of the job for logging.
            context: The execution context.

        Returns:
            An object ref that represent the result.

        Raises:
            RuntimeError: If the workflow was never submitted.
        """
        workflow_id = context.workflow_id
        if workflow_id not in self._workflow_executors:
            raise RuntimeError(f"Workflow '{workflow_id}' has not been submitted.")

        pending_fut = self._queued_workflows.get(workflow_id)
        if pending_fut is not None:
            await pending_fut  # wait until this workflow is ready to go

        wf_store = workflow_storage.WorkflowStorage(workflow_id)
        executor = self._workflow_executors[workflow_id]
        try:
            await executor.run_until_complete(job_id, context, wf_store)
            return await self.get_output(workflow_id, executor.output_task_id)
        finally:
            # Bookkeeping runs whether execution succeeded or failed.
            self._workflow_executors.pop(workflow_id)
            self._running_workflows.remove(workflow_id)
            self._executed_workflows.add(workflow_id)
            if not self._workflow_queue.empty():
                # schedule another workflow from the pending queue
                next_workflow_id = self._workflow_queue.get_nowait()
                self._running_workflows.add(next_workflow_id)
                fut = self._queued_workflows.pop(next_workflow_id)
                fut.set_result(None)
221
+
222
    async def cancel_workflow(self, workflow_id: str) -> None:
        """Cancel workflow execution.

        Args:
            workflow_id: The ID of the workflow to cancel.
        """
        if workflow_id in self._workflow_executors:
            # Live workflow: cancel its executor and wait for the
            # cancellation to take effect.
            executor = self._workflow_executors[workflow_id]
            fut = executor.get_task_output_async(executor.output_task_id)
            executor.cancel()
            try:
                # Wait until cancelled, otherwise workflow status may not
                # get updated after "workflow.cancel()" is called.
                await fut
            except WorkflowCancellationError:
                pass
        else:
            # Not running: just persist the CANCELED status.
            wf_store = workflow_storage.WorkflowStorage(workflow_id)
            wf_store.update_workflow_status(WorkflowStatus.CANCELED)
237
+
238
    def get_workflow_status(self, workflow_id: str) -> WorkflowStatus:
        """Get the status of the workflow.

        Args:
            workflow_id: The ID of the workflow.

        Raises:
            WorkflowNotFoundError: If no such workflow exists in storage.
        """
        if workflow_id in self._workflow_executors:
            # An executor exists while the workflow is queued or running.
            if workflow_id in self._queued_workflows:
                return WorkflowStatus.PENDING
            return WorkflowStatus.RUNNING
        store = workflow_storage.get_workflow_storage(workflow_id)
        status = store.load_workflow_status()
        if status == WorkflowStatus.NONE:
            raise WorkflowNotFoundError(workflow_id)
        elif status in WorkflowStatus.non_terminating_status():
            # Stored as running/pending but no executor exists here, so the
            # workflow must have been interrupted; report it as resumable.
            return WorkflowStatus.RESUMABLE
        return status
251
+
252
    def is_workflow_non_terminating(self, workflow_id: str) -> bool:
        """True if the workflow is still running or pending."""
        # An executor exists exactly while a workflow is RUNNING or PENDING.
        return workflow_id in self._workflow_executors

    def list_non_terminating_workflows(self) -> Dict[WorkflowStatus, List[str]]:
        """List workflows whose status are not of terminated status."""
        result = {WorkflowStatus.RUNNING: [], WorkflowStatus.PENDING: []}
        for wf in self._workflow_executors.keys():
            # Workflows with an executor but no running slot are queued.
            if wf in self._running_workflows:
                result[WorkflowStatus.RUNNING].append(wf)
            else:
                result[WorkflowStatus.PENDING].append(wf)
        return result
265
+
266
    async def get_output(
        self, workflow_id: str, task_id: Optional[TaskID]
    ) -> ray.ObjectRef:
        """Get the output of a running workflow.

        Args:
            workflow_id: The ID of a workflow job.
            task_id: If set, fetch the specific task output instead of the output
                of the workflow.

        Returns:
            An object reference that can be used to retrieve the workflow result.

        Raises:
            ValueError: If the output cannot be found in memory or storage.
        """
        ref = None
        if self.is_workflow_non_terminating(workflow_id):
            # Live workflow: ask the executor for the (possibly future) output.
            executor = self._workflow_executors[workflow_id]
            if task_id is None:
                task_id = executor.output_task_id
            workflow_ref = await executor.get_task_output_async(task_id)
            task_id, ref = workflow_ref.task_id, workflow_ref.ref
        if ref is None:
            # No in-memory output available; fall back to storage.
            wf_store = workflow_storage.WorkflowStorage(workflow_id)
            tid = wf_store.inspect_output(task_id)
            # NOTE(review): 'tid' is only used as an existence check here;
            # loading still uses the original 'task_id' — confirm intended.
            if tid is not None:
                ref = load_task_output_from_storage.remote(workflow_id, task_id)
            elif task_id is not None:
                raise ValueError(
                    f"Cannot load output from task id '{task_id}' in workflow "
                    f"'{workflow_id}'"
                )
            else:
                raise ValueError(f"Cannot load output from workflow '{workflow_id}'")
        return SelfResolvingObject(ref)
299
+
300
    def delete_workflow(self, workflow_id: str) -> None:
        """Delete a workflow, its checkpoints, and other information it may have
        persisted to storage.

        Args:
            workflow_id: The workflow to delete.

        Raises:
            WorkflowStillActiveError: The workflow is still active.
            WorkflowNotFoundError: The workflow does not exist.
        """
        # Refuse to delete a workflow that is still running or pending.
        if self.is_workflow_non_terminating(workflow_id):
            raise WorkflowStillActiveError("DELETE", workflow_id)
        wf_storage = workflow_storage.WorkflowStorage(workflow_id)
        wf_storage.delete_workflow()
        # 'discard' (not 'remove'): the workflow may never have been executed
        # by this actor instance.
        self._executed_workflows.discard(workflow_id)
316
+
317
    def create_http_event_provider(self) -> None:
        """Deploy an HTTPEventProvider as a Serve deployment with
        name = common.HTTP_EVENT_PROVIDER_NAME, if one doesn't exist.
        """
        # Ensure Ray Serve is running; 'detached' so it outlives this driver.
        ray.serve.start(detached=True)
        provider_exists = (
            common.HTTP_EVENT_PROVIDER_NAME in ray.serve.status().applications
        )
        if not provider_exists:
            # Imported lazily to avoid importing Serve machinery unless needed.
            from ray.workflow.http_event_provider import HTTPEventProvider

            ray.serve.run(
                HTTPEventProvider.bind(),
                name=common.HTTP_EVENT_PROVIDER_NAME,
                route_prefix="/event",
            )

    def ready(self) -> None:
        """A no-op to make sure the actor is ready."""
336
+
337
+
338
def init_management_actor(
    max_running_workflows: Optional[int], max_pending_workflows: Optional[int]
) -> None:
    """Initialize WorkflowManagementActor.

    Args:
        max_running_workflows: The maximum number of concurrently running workflows.
            Use -1 as infinity. Use 'None' for keeping the original value if the actor
            exists, or it is equivalent to infinity if the actor does not exist.
        max_pending_workflows: The maximum number of queued workflows.
            Use -1 as infinity. Use 'None' for keeping the original value if the actor
            exists, or it is equivalent to infinity if the actor does not exist.
    """
    try:
        # Raises ValueError when the named actor does not exist yet.
        actor = get_management_actor()
        # Check if max_running_workflows/max_pending_workflows
        # matches the previous settings.
        ray.get(
            actor.validate_init_options.remote(
                max_running_workflows, max_pending_workflows
            )
        )
    except ValueError:
        logger.info("Initializing workflow manager...")
        # 'None' means "no limit" when creating a fresh actor.
        if max_running_workflows is None:
            max_running_workflows = -1
        if max_pending_workflows is None:
            max_pending_workflows = -1
        # the actor does not exist
        actor = WorkflowManagementActor.options(
            name=common.MANAGEMENT_ACTOR_NAME,
            namespace=common.MANAGEMENT_ACTOR_NAMESPACE,
            lifetime="detached",
        ).remote(max_running_workflows, max_pending_workflows)
        # No-op to ensure the actor is created before the driver exits.
        ray.get(actor.ready.remote())
374
+
375
+
376
def get_management_actor() -> "ActorHandle":
    """Look up the detached workflow management actor by its well-known name.

    Raises:
        ValueError: If the actor has not been created yet.
    """
    return ray.get_actor(
        common.MANAGEMENT_ACTOR_NAME, namespace=common.MANAGEMENT_ACTOR_NAMESPACE
    )
.venv/lib/python3.11/site-packages/ray/workflow/workflow_context.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from contextlib import contextmanager
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
6
+ import ray
7
+ from ray._private.ray_logging import configure_log_file, get_worker_log_file_name
8
+ from ray.workflow.common import CheckpointModeType, WorkflowStatus
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
@dataclass
class WorkflowTaskContext:
    """
    The structure for saving workflow task context. The context provides
    critical info (e.g. where to checkpoint, which is its parent task)
    for the task to execute correctly.
    """

    # ID of the workflow.
    workflow_id: Optional[str] = None
    # ID of the current task.
    task_id: str = ""
    # ID of the task that creates the current task.
    creator_task_id: str = ""
    # The checkpoint context of parent workflow tasks.
    checkpoint: CheckpointModeType = True
    # The context of catching exceptions.
    catch_exceptions: bool = False


# The process-wide active task context, managed by 'workflow_task_context'.
_context: Optional[WorkflowTaskContext] = None
34
+
35
+
36
@contextmanager
def workflow_task_context(context) -> None:
    """Temporarily install ``context`` as the active workflow task context.

    The previously active context is restored when the scope exits,
    whether normally or via an exception.

    Args:
        context: The new context to activate inside the ``with`` block.
    """
    global _context
    saved_context = _context
    try:
        _context = context
        yield
    finally:
        _context = saved_context
50
+
51
+
52
def get_workflow_task_context() -> Optional[WorkflowTaskContext]:
    """Return the active task context, or None outside any workflow task."""
    return _context


def get_current_task_id() -> str:
    """Get the current workflow task ID. Empty means we are in
    the workflow job driver."""
    # NOTE(review): raises AttributeError if no context is active — confirm
    # callers only invoke this inside a workflow task scope.
    return get_workflow_task_context().task_id


def get_current_workflow_id() -> str:
    """Return the workflow ID of the active context (a context must exist)."""
    assert _context is not None
    return _context.workflow_id


def get_name() -> str:
    """Return a '<workflow_id>@<task_id>' identifier for the current task."""
    return f"{get_current_workflow_id()}@{get_current_task_id()}"


def get_task_status_info(status: WorkflowStatus) -> str:
    """Format a log line describing the given status of the current task."""
    assert _context is not None
    return f"Task status [{status.value}]\t[{get_name()}]"
74
+
75
+
76
# Process-wide flag: True only while a workflow task body is executing.
_in_workflow_execution = False


@contextmanager
def workflow_execution() -> None:
    """Scope for workflow task execution."""
    global _in_workflow_execution
    try:
        _in_workflow_execution = True
        yield
    finally:
        # Always reset, even if the task body raised.
        _in_workflow_execution = False


def in_workflow_execution() -> bool:
    """Whether we are in workflow task execution."""
    global _in_workflow_execution
    return _in_workflow_execution
94
+
95
+
96
@contextmanager
def workflow_logging_context(job_id) -> None:
    """Initialize the workflow logging context.

    Workflow executions are running as remote functions from
    WorkflowManagementActor. Without logging redirection, workflow
    inner execution logs will be pushed to the driver that initially
    created WorkflowManagementActor rather than the driver that
    actually submits the current workflow execution.
    We use this context manager to re-configure the log files to send
    the logs to the correct driver, and to restore the log files once
    the execution is done.

    Args:
        job_id: The ID of the job that submits the workflow execution.
    """
    node = ray._private.worker._global_node
    # Handles of this worker's default log files (no job suffix).
    original_out_file, original_err_file = node.get_log_file_handles(
        get_worker_log_file_name("WORKER")
    )
    # Handles of the log files associated with the submitting job.
    out_file, err_file = node.get_log_file_handles(
        get_worker_log_file_name("WORKER", job_id)
    )
    try:
        configure_log_file(out_file, err_file)
        yield
    finally:
        # Restore the original redirection when execution finishes.
        configure_log_file(original_out_file, original_err_file)
.venv/lib/python3.11/site-packages/ray/workflow/workflow_executor.py ADDED
@@ -0,0 +1,433 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Iterator, Optional, Tuple, TYPE_CHECKING
2
+
3
+ import asyncio
4
+ import logging
5
+ import time
6
+ from collections import defaultdict
7
+
8
+ import ray
9
+ from ray.exceptions import RayTaskError, RayError
10
+
11
+ from ray.workflow.common import (
12
+ WorkflowRef,
13
+ WorkflowExecutionMetadata,
14
+ WorkflowStatus,
15
+ TaskID,
16
+ )
17
+ from ray.workflow.exceptions import WorkflowCancellationError, WorkflowExecutionError
18
+ from ray.workflow.task_executor import get_task_executor, _BakedWorkflowInputs
19
+ from ray.workflow.workflow_state import (
20
+ WorkflowExecutionState,
21
+ TaskExecutionMetadata,
22
+ Task,
23
+ )
24
+
25
+ if TYPE_CHECKING:
26
+ from ray.workflow.workflow_context import WorkflowTaskContext
27
+ from ray.workflow.workflow_storage import WorkflowStorage
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
+ class WorkflowExecutor:
33
    def __init__(
        self,
        state: WorkflowExecutionState,
    ):
        """The core logic of executing a workflow.

        This class is responsible for:

        - Dependency resolving.
        - Task scheduling.
        - Reference counting.
        - Garbage collection.
        - Continuation handling and scheduling.
        - Error handling.
        - Responding callbacks.

        It borrows some design of event loop in asyncio,
        e.g., 'run_until_complete'.

        Args:
            state: The initial state of the workflow.
        """
        self._state = state
        # Futures of finished Ray tasks, drained by '_poll_ready_tasks'.
        self._completion_queue = asyncio.Queue()
        # Futures handed out by 'get_task_output_async'; resolved when the
        # corresponding task produces its output.
        self._task_done_callbacks: Dict[TaskID, List[asyncio.Future]] = defaultdict(
            list
        )
60
+
61
    def is_running(self) -> bool:
        """The state is running, if there are tasks to be run or running tasks."""
        return bool(self._state.frontier_to_run or self._state.running_frontier)

    def get_state(self) -> WorkflowExecutionState:
        """Return the underlying (mutable) execution state."""
        return self._state

    @property
    def output_task_id(self) -> TaskID:
        """ID of the task whose output is the output of the whole workflow."""
        return self._state.output_task_id
71
+
72
    async def run_until_complete(
        self, job_id: str, context: "WorkflowTaskContext", wf_store: "WorkflowStorage"
    ):
        """Drive the state until it completes.

        Args:
            job_id: The Ray JobID for logging properly.
            context: The context of workflow execution.
            wf_store: The store for the workflow.

        # TODO(suquark): move job_id inside context
        """
        workflow_id = context.workflow_id
        wf_store.update_workflow_status(WorkflowStatus.RUNNING)
        logger.info(f"Workflow job [id={workflow_id}] started.")

        # Seed the scheduling frontier starting from the output task.
        self._state.construct_scheduling_plan(self._state.output_task_id)
        self._state.init_context(context)

        while self.is_running():
            # ------------ poll queued tasks ------------
            queued_tasks = self._poll_queued_tasks()

            # --------------- submit task ---------------
            for task_id in queued_tasks:
                # '_submit_ray_task' submit a Ray task based on the workflow task.
                self._submit_ray_task(task_id, job_id=job_id)
                # '_post_process_submit_task' updates the state related to task
                # submission.
                self._post_process_submit_task(task_id, wf_store)

            self._garbage_collect()

            # ------------ poll ready tasks ------------
            ready_futures = await self._poll_ready_tasks()

            # ----------- handle ready tasks -----------
            await asyncio.gather(
                *[
                    self._handle_ready_task(
                        fut, workflow_id=workflow_id, wf_store=wf_store
                    )
                    for fut in ready_futures
                ]
            )

            # prevent leaking ObjectRefs into the next iteration
            del ready_futures

        wf_store.update_workflow_status(WorkflowStatus.SUCCESSFUL)
        logger.info(f"Workflow '{workflow_id}' completes successfully.")

        # set errors for pending workflow outputs
        for task_id, futures in self._task_done_callbacks.items():
            err = ValueError(
                f"The workflow haven't yet produced output of task '{task_id}' "
                f"after workflow execution completes."
            )
            for fut in futures:
                if not fut.done():
                    fut.set_exception(err)
133
+
134
    def cancel(self) -> None:
        """Cancel the running workflow."""
        for fut, workflow_ref in self._state.running_frontier.items():
            # Cancel the local awaiting future first, then force-cancel the
            # underlying Ray task; best-effort, so swallow cancel errors.
            fut.cancel()
            try:
                ray.cancel(workflow_ref.ref, force=True)
            except Exception:
                pass
142
+
143
+ def _poll_queued_tasks(self) -> List[TaskID]:
144
+ tasks = []
145
+ while True:
146
+ task_id = self._state.pop_frontier_to_run()
147
+ if task_id is None:
148
+ break
149
+ tasks.append(task_id)
150
+ return tasks
151
+
152
    def _submit_ray_task(self, task_id: TaskID, job_id: str) -> None:
        """Submit a workflow task as a Ray task."""
        state = self._state
        # Bundle the task's stored args with the outputs of its upstream tasks.
        baked_inputs = _BakedWorkflowInputs(
            args=state.task_input_args[task_id],
            workflow_refs=[
                state.get_input(d) for d in state.upstream_dependencies[task_id]
            ],
        )
        task = state.tasks[task_id]
        executor = get_task_executor(task.options)
        metadata_ref, output_ref = executor(
            task.func_body,
            state.task_context[task_id],
            job_id,
            task_id,
            baked_inputs,
            task.options,
        )
        # The input workflow is not a reference to an executed workflow.
        # Completion of the metadata ref feeds the completion queue, which
        # drives the main loop in 'run_until_complete'.
        future = asyncio.wrap_future(metadata_ref.future())
        future.add_done_callback(self._completion_queue.put_nowait)

        state.insert_running_frontier(future, WorkflowRef(task_id, ref=output_ref))
        state.task_execution_metadata[task_id] = TaskExecutionMetadata(
            submit_time=time.time()
        )
179
+
180
    def _post_process_submit_task(
        self, task_id: TaskID, store: "WorkflowStorage"
    ) -> None:
        """Update dependencies and reference count etc. after task submission."""
        state = self._state
        if task_id in state.continuation_root:
            # Continuations point their checkpointed output at the root task.
            if state.tasks[task_id].options.checkpoint:
                store.update_continuation_output_link(
                    state.continuation_root[task_id], task_id
                )
        else:
            # update reference counting if the task is not a continuation
            for c in state.upstream_dependencies[task_id]:
                state.reference_set[c].remove(task_id)
                if not state.reference_set[c]:
                    # No one references 'c' anymore; its output can be freed.
                    del state.reference_set[c]
                    state.free_outputs.add(c)
197
+
198
    def _garbage_collect(self) -> None:
        """Garbage collect the output refs of tasks.

        Currently, this is done after task submission, because when a task
        starts, we no longer needs its inputs (i.e. outputs from other tasks).

        # TODO(suquark): We may need to improve garbage collection
        # when taking more fault tolerant cases into consideration.
        """
        state = self._state
        while state.free_outputs:
            # garbage collect all free outputs immediately
            gc_task_id = state.free_outputs.pop()
            assert state.get_input(gc_task_id) is not None
            # Dropping the entry releases the in-memory ObjectRef.
            state.output_map.pop(gc_task_id, None)
213
+
214
    async def _poll_ready_tasks(self) -> List[asyncio.Future]:
        """Block for at least one completed task, then drain the whole queue."""
        cq = self._completion_queue
        ready_futures = []
        rf = await cq.get()
        ready_futures.append(rf)
        # get all remaining futures in the queue
        while not cq.empty():
            ready_futures.append(cq.get_nowait())
        return ready_futures
223
+
224
    def _iter_callstack(self, task_id: TaskID) -> Iterator[Tuple[TaskID, Task]]:
        """Walk from 'task_id' up through its creator tasks, yielding each."""
        state = self._state
        while task_id in state.task_context and task_id in state.tasks:
            yield task_id, state.tasks[task_id]
            task_id = state.task_context[task_id].creator_task_id
229
+
230
    def _retry_failed_task(
        self, workflow_id: str, failed_task_id: TaskID, exc: Exception
    ) -> bool:
        """Reschedule a failed task if it still has retries left.

        NOTE(review): equivalent logic is inlined in '_handle_ready_task';
        this helper appears unused within this view — confirm callers.

        Returns:
            True if the task was rescheduled for a retry.
        """
        state = self._state
        # RayTaskError means the task body itself raised; other errors are
        # treated as system errors and retried regardless of options.
        is_application_error = isinstance(exc, RayTaskError)
        options = state.tasks[failed_task_id].options
        if not is_application_error or options.retry_exceptions:
            if state.task_retries[failed_task_id] < options.max_retries:
                state.task_retries[failed_task_id] += 1
                logger.info(
                    f"Retry [{workflow_id}@{failed_task_id}] "
                    f"({state.task_retries[failed_task_id]}/{options.max_retries})"
                )
                state.construct_scheduling_plan(failed_task_id)
                return True
        return False

    async def _catch_failed_task(
        self, workflow_id: str, failed_task_id: TaskID, exc: Exception
    ) -> bool:
        """Deliver an application error to the nearest catching creator task.

        NOTE(review): equivalent logic is inlined in '_handle_ready_task';
        this helper appears unused within this view — confirm callers.

        Returns:
            True if a creator task caught the exception.
        """
        # lookup a creator task that catches the exception
        is_application_error = isinstance(exc, RayTaskError)
        exception_catcher = None
        if is_application_error:
            for t, task in self._iter_callstack(failed_task_id):
                if task.options.catch_exceptions:
                    exception_catcher = t
                    break
        if exception_catcher is not None:
            logger.info(
                f"Exception raised by '{workflow_id}@{failed_task_id}' is caught by "
                f"'{workflow_id}@{exception_catcher}'"
            )
            # assign output to exception catching task;
            # compose output with caught exception
            await self._post_process_ready_task(
                exception_catcher,
                metadata=WorkflowExecutionMetadata(),
                output_ref=WorkflowRef(failed_task_id, ray.put((None, exc))),
            )
            # TODO(suquark): cancel other running tasks?
            return True
        return False
273
+
274
    async def _handle_ready_task(
        self, fut: asyncio.Future, workflow_id: str, wf_store: "WorkflowStorage"
    ) -> None:
        """Handle ready task, especially about its exception.

        On success the task output is post-processed; on failure the task is
        retried, or its error is routed to an exception-catching creator
        task, or the whole workflow fails.
        """
        state = self._state
        output_ref = state.pop_running_frontier(fut)
        task_id = output_ref.task_id
        try:
            metadata: WorkflowExecutionMetadata = fut.result()
            state.task_execution_metadata[task_id].finish_time = time.time()
            logger.info(
                f"Task status [{WorkflowStatus.SUCCESSFUL.value}]\t"
                f"[{workflow_id}@{task_id}]"
            )
            await self._post_process_ready_task(task_id, metadata, output_ref)
        except asyncio.CancelledError:
            # NOTE: We must update the workflow status before broadcasting
            # the exception. Otherwise, the workflow status would still be
            # 'RUNNING' if check the status immediately after cancellation.
            wf_store.update_workflow_status(WorkflowStatus.CANCELED)
            logger.warning(f"Workflow '{workflow_id}' is cancelled.")
            # broadcasting cancellation to all outputs
            err = WorkflowCancellationError(workflow_id)
            self._broadcast_exception(err)
            raise err from None
        except Exception as e:
            if isinstance(e, RayTaskError):
                reason = "an exception raised by the task"
            elif isinstance(e, RayError):
                reason = "a system error"
            else:
                reason = "an unknown error"
            logger.error(
                f"Task status [{WorkflowStatus.FAILED.value}] due to {reason}.\t"
                f"[{workflow_id}@{task_id}]"
            )

            is_application_error = isinstance(e, RayTaskError)
            options = state.tasks[task_id].options

            # ---------------------- retry the task ----------------------
            # System errors are always retried; application errors only
            # when 'retry_exceptions' is set.
            if not is_application_error or options.retry_exceptions:
                if state.task_retries[task_id] < options.max_retries:
                    state.task_retries[task_id] += 1
                    logger.info(
                        f"Retry [{workflow_id}@{task_id}] "
                        f"({state.task_retries[task_id]}/{options.max_retries})"
                    )
                    state.construct_scheduling_plan(task_id)
                    return

            # ----------- retry used up, handle the task error -----------
            exception_catcher = None
            if is_application_error:
                for t, task in self._iter_callstack(task_id):
                    if task.options.catch_exceptions:
                        exception_catcher = t
                        break
            if exception_catcher is not None:
                logger.info(
                    f"Exception raised by '{workflow_id}@{task_id}' is caught by "
                    f"'{workflow_id}@{exception_catcher}'"
                )
                # assign output to exception catching task;
                # compose output with caught exception
                await self._post_process_ready_task(
                    exception_catcher,
                    metadata=WorkflowExecutionMetadata(),
                    output_ref=WorkflowRef(task_id, ray.put((None, e))),
                )
                # TODO(suquark): cancel other running tasks?
                return

            # ------------------- raise the task error -------------------
            # NOTE: We must update the workflow status before broadcasting
            # the exception. Otherwise, the workflow status would still be
            # 'RUNNING' if check the status immediately after the exception.
            wf_store.update_workflow_status(WorkflowStatus.FAILED)
            logger.error(f"Workflow '{workflow_id}' failed due to {e}")
            err = WorkflowExecutionError(workflow_id)
            err.__cause__ = e  # chain exceptions
            self._broadcast_exception(err)
            raise err
357
+
358
    async def _post_process_ready_task(
        self,
        task_id: TaskID,
        metadata: WorkflowExecutionMetadata,
        output_ref: WorkflowRef,
    ) -> None:
        """Record a finished task's output and unblock its downstream tasks.

        A task may either return a continuation (a sub-workflow that is merged
        into this execution state) or a plain value.
        """
        state = self._state
        state.task_retries.pop(task_id, None)
        if metadata.is_output_workflow:  # The task returns a continuation
            sub_workflow_state: WorkflowExecutionState = await output_ref.ref
            # init the context just for "sub_workflow_state"
            sub_workflow_state.init_context(state.task_context[task_id])
            state.merge_state(sub_workflow_state)
            # build up runtime dependency
            continuation_task_id = sub_workflow_state.output_task_id
            state.append_continuation(task_id, continuation_task_id)
            # Migrate callbacks - all continuation callbacks are moved
            # under the root of continuation, so when the continuation
            # completes, all callbacks in the continuation can be triggered.
            if continuation_task_id in self._task_done_callbacks:
                self._task_done_callbacks[
                    state.continuation_root[continuation_task_id]
                ].extend(self._task_done_callbacks.pop(continuation_task_id))
            state.construct_scheduling_plan(sub_workflow_state.output_task_id)
        else:  # The task returns a normal object
            # Outputs of a continuation are recorded under its root task.
            target_task_id = state.continuation_root.get(task_id, task_id)
            state.output_map[target_task_id] = output_ref
            if state.tasks[task_id].options.checkpoint:
                state.checkpoint_map[target_task_id] = WorkflowRef(task_id)
            state.done_tasks.add(target_task_id)
            # TODO(suquark): cleanup callbacks when a result is set?
            if target_task_id in self._task_done_callbacks:
                for callback in self._task_done_callbacks[target_task_id]:
                    callback.set_result(output_ref)
            for m in state.reference_set[target_task_id]:
                # we ensure that each reference corresponds to a pending input
                state.pending_input_set[m].remove(target_task_id)
                if not state.pending_input_set[m]:
                    # All inputs of 'm' resolved: it becomes runnable.
                    state.append_frontier_to_run(m)
397
+
398
+ def _broadcast_exception(self, err: Exception):
399
+ for _, futures in self._task_done_callbacks.items():
400
+ for fut in futures:
401
+ if not fut.done():
402
+ fut.set_exception(err)
403
+
404
+ def get_task_output_async(self, task_id: Optional[TaskID]) -> asyncio.Future:
405
+ """Get the output of a task asynchronously.
406
+
407
+ Args:
408
+ task_id: The ID of task the callback associates with.
409
+
410
+ Returns:
411
+ A callback in the form of a future that associates with the task.
412
+ """
413
+ state = self._state
414
+ if self._task_done_callbacks[task_id]:
415
+ return self._task_done_callbacks[task_id][0]
416
+
417
+ fut = asyncio.Future()
418
+ task_id = state.continuation_root.get(task_id, task_id)
419
+ output = state.get_input(task_id)
420
+ if output is not None:
421
+ fut.set_result(output)
422
+ elif task_id in state.done_tasks:
423
+ fut.set_exception(
424
+ ValueError(
425
+ f"Task '{task_id}' is done but neither in memory or in storage "
426
+ "could we find its output. It could because its in memory "
427
+ "output has been garbage collected and the task did not"
428
+ "checkpoint its output."
429
+ )
430
+ )
431
+ else:
432
+ self._task_done_callbacks[task_id].append(fut)
433
+ return fut
.venv/lib/python3.11/site-packages/ray/workflow/workflow_state.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+
3
+ from collections import deque, defaultdict
4
+ import dataclasses
5
+ from dataclasses import field
6
+ import logging
7
+ from typing import List, Dict, Optional, Set, Deque, Callable
8
+
9
+ import ray
10
+ from ray.workflow.common import (
11
+ TaskID,
12
+ WorkflowRef,
13
+ WorkflowTaskRuntimeOptions,
14
+ )
15
+ from ray.workflow.workflow_context import WorkflowTaskContext
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
@dataclasses.dataclass
class TaskExecutionMetadata:
    """Timing and size statistics recorded for one executed workflow task."""

    # Wall-clock time (seconds) at which the task was submitted.
    submit_time: Optional[float] = None
    # Wall-clock time (seconds) at which the task finished.
    finish_time: Optional[float] = None
    # Size of the task output; units not established here.
    output_size: Optional[int] = None

    @property
    def duration(self):
        """Elapsed time between submission and completion."""
        return self.finish_time - self.submit_time
29
+
30
+
31
@dataclasses.dataclass
class Task:
    """Data class for a workflow task."""

    # Unique ID of the task within the workflow.
    task_id: str
    # Runtime options (retries, checkpointing, etc.) of the task.
    options: WorkflowTaskRuntimeOptions
    # Arbitrary user-provided metadata attached to the task.
    user_metadata: Dict
    # The callable executed by the task; optional.
    func_body: Optional[Callable]

    def to_dict(self) -> Dict:
        """Serialize the task description (excluding the function body)."""
        return {
            "task_id": self.task_id,
            "task_options": self.options.to_dict(),
            "user_metadata": self.user_metadata,
        }
46
+
47
+
48
@dataclasses.dataclass
class WorkflowExecutionState:
    """The execution state of a workflow. This dataclass helps with observation
    and debugging.

    It tracks (1) the dependency / continuation structure of the workflow DAG,
    (2) per-task properties and runtime metadata, (3) in-memory outputs and
    on-storage checkpoints, and (4) the scheduling frontier of tasks that are
    ready to run or currently running.
    """

    # -------------------------------- dependencies -------------------------------- #

    # The mapping from all tasks to immediately upstream tasks.
    upstream_dependencies: Dict[TaskID, List[TaskID]] = field(default_factory=dict)
    # A reverse mapping of the above. The dependency mapping from tasks to
    # immediately downstream tasks.
    downstream_dependencies: Dict[TaskID, List[TaskID]] = field(
        default_factory=lambda: defaultdict(list)
    )
    # The mapping from a task to its immediate continuation.
    next_continuation: Dict[TaskID, TaskID] = field(default_factory=dict)
    # The reversed mapping from continuation to its immediate task.
    prev_continuation: Dict[TaskID, TaskID] = field(default_factory=dict)
    # The mapping from a task to its latest continuation. The latest continuation is
    # a task that returns a value instead of a continuation.
    latest_continuation: Dict[TaskID, TaskID] = field(default_factory=dict)
    # The mapping from a task to the root of the continuation, i.e. the initial task
    # that generates the lineage of continuation.
    continuation_root: Dict[TaskID, TaskID] = field(default_factory=dict)

    # ------------------------------- task properties ------------------------------- #

    # Workflow tasks.
    tasks: Dict[TaskID, Task] = field(default_factory=dict)

    # The arguments for the task.
    task_input_args: Dict[TaskID, ray.ObjectRef] = field(default_factory=dict)
    # The context of the task.
    task_context: Dict[TaskID, WorkflowTaskContext] = field(default_factory=dict)
    # The execution metadata of a task.
    task_execution_metadata: Dict[TaskID, TaskExecutionMetadata] = field(
        default_factory=dict
    )
    # Number of retries already performed for each task (defaults to 0).
    task_retries: Dict[TaskID, int] = field(default_factory=lambda: defaultdict(int))

    # ------------------------------ object management ------------------------------ #

    # Set of references to upstream outputs.
    reference_set: Dict[TaskID, Set[TaskID]] = field(
        default_factory=lambda: defaultdict(set)
    )
    # The set of pending inputs of a task. We are able to run the task
    # when it becomes empty.
    pending_input_set: Dict[TaskID, Set[TaskID]] = field(default_factory=dict)
    # The map from a task to its in-memory outputs. Normally it is the ObjectRef
    # returned by the underlying Ray task. Things are different for continuation:
    # because the true output of a continuation is created by the last task in
    # the continuation lineage, so all other tasks in the continuation points
    # to the output of the last task instead of the output of themselves.
    output_map: Dict[TaskID, WorkflowRef] = field(default_factory=dict)
    # The map from a task to its in-storage checkpoints. Normally it is the checkpoint
    # created by the underlying Ray task. For continuations, the semantics is similar
    # to 'output_map'.
    checkpoint_map: Dict[TaskID, WorkflowRef] = field(default_factory=dict)
    # Outputs that are free (no reference to this output in the workflow) and
    # can be garbage collected.
    free_outputs: Set[TaskID] = field(default_factory=set)

    # -------------------------------- scheduling -------------------------------- #

    # The frontier that is ready to run.
    frontier_to_run: Deque[TaskID] = field(default_factory=deque)
    # The set of frontier tasks to run. This field helps deduplicate tasks or
    # look up task quickly. It contains the same elements as 'frontier_to_run',
    # they act like a 'DequeSet' when combined.
    frontier_to_run_set: Set[TaskID] = field(default_factory=set)
    # The frontier that is running.
    running_frontier: Dict[asyncio.Future, WorkflowRef] = field(default_factory=dict)
    # The set of running frontier. This field helps deduplicate tasks or
    # look up task quickly. It contains the same elements as 'running_frontier',
    # they act like a dict but its values are in a set when combined.
    running_frontier_set: Set[TaskID] = field(default_factory=set)
    # The set of completed tasks. They are tasks are actually executed with the state,
    # so inspected during recovery does not count.
    #
    # Normally, a task will be added in 'done_tasks' immediately after its completion.
    # However, a task that is the root of continuations (i.e. it returns a continuation
    # but itself is not a continuation) is only added to 'done_tasks' when all its
    # continuation completes. We do not add its continuations in 'done_tasks' because
    # we indicate their completion from the continuation structure - if a continuation
    # is appended to a previous continuation, then the previous continuation must
    # already complete; if the task that is the root of all continuation completes,
    # then all its continuations would complete.
    done_tasks: Set[TaskID] = field(default_factory=set)

    # -------------------------------- external -------------------------------- #

    # The ID of the output task.
    output_task_id: Optional[TaskID] = None

    def get_input(self, task_id: TaskID) -> Optional[WorkflowRef]:
        """Get the input. It checks memory first and storage later. It returns None if
        the input does not exist.
        """
        # In-memory output takes precedence over the on-storage checkpoint.
        return self.output_map.get(task_id, self.checkpoint_map.get(task_id))

    def pop_frontier_to_run(self) -> Optional[TaskID]:
        """Pop one task to run from the frontier queue.

        Returns:
            The task ID, or None if the queue is empty.
        """
        try:
            t = self.frontier_to_run.popleft()
            # Keep the mirror set in sync with the queue.
            self.frontier_to_run_set.remove(t)
            return t
        except IndexError:
            # EAFP: 'popleft' raises IndexError on an empty deque.
            return None

    def append_frontier_to_run(self, task_id: TaskID) -> None:
        """Insert one task to the frontier queue."""
        # Skip tasks that are already queued or currently running.
        if (
            task_id not in self.frontier_to_run_set
            and task_id not in self.running_frontier_set
        ):
            self.frontier_to_run.append(task_id)
            self.frontier_to_run_set.add(task_id)

    def add_dependencies(self, task_id: TaskID, in_dependencies: List[TaskID]) -> None:
        """Add dependencies between a task and it input dependencies."""
        # Record both directions so the graph can be walked either way.
        self.upstream_dependencies[task_id] = in_dependencies
        for in_task_id in in_dependencies:
            self.downstream_dependencies[in_task_id].append(task_id)

    def pop_running_frontier(self, fut: asyncio.Future) -> WorkflowRef:
        """Pop a task from the running frontier.

        Raises:
            KeyError: If 'fut' is not in the running frontier.
        """
        ref = self.running_frontier.pop(fut)
        self.running_frontier_set.remove(ref.task_id)
        return ref

    def insert_running_frontier(self, fut: asyncio.Future, ref: WorkflowRef) -> None:
        """Insert a task to the running frontier."""
        self.running_frontier[fut] = ref
        self.running_frontier_set.add(ref.task_id)

    def append_continuation(
        self, task_id: TaskID, continuation_task_id: TaskID
    ) -> None:
        """Append continuation to a task."""
        # If 'task_id' is itself part of a continuation lineage, reuse its
        # root; otherwise 'task_id' becomes the root of the new lineage.
        continuation_root = self.continuation_root.get(task_id, task_id)
        self.prev_continuation[continuation_task_id] = task_id
        self.next_continuation[task_id] = continuation_task_id
        self.continuation_root[continuation_task_id] = continuation_root
        self.latest_continuation[continuation_root] = continuation_task_id

    def merge_state(self, state: "WorkflowExecutionState") -> None:
        """Merge with another execution state.

        Only the dependency, task and output/checkpoint mappings are merged;
        scheduling fields (frontiers, done tasks, etc.) are left untouched.
        """
        self.upstream_dependencies.update(state.upstream_dependencies)
        self.downstream_dependencies.update(state.downstream_dependencies)
        self.task_input_args.update(state.task_input_args)
        self.tasks.update(state.tasks)
        self.task_context.update(state.task_context)
        self.output_map.update(state.output_map)
        self.checkpoint_map.update(state.checkpoint_map)

    def construct_scheduling_plan(self, task_id: TaskID) -> None:
        """Analyze upstream dependencies of a task to construct the scheduling plan."""
        if self.get_input(task_id) is not None:
            # This case corresponds to the scenario that the task is a
            # checkpoint or ref.
            return

        # BFS over the upstream dependency graph starting from 'task_id'.
        visited_nodes = set()
        dag_visit_queue = deque([task_id])
        while dag_visit_queue:
            tid = dag_visit_queue.popleft()
            if tid in visited_nodes:
                continue
            visited_nodes.add(tid)
            self.pending_input_set[tid] = set()
            for in_task_id in self.upstream_dependencies[tid]:
                self.reference_set[in_task_id].add(tid)
                # All upstream deps should already complete here,
                # so we just check their checkpoints.
                task_input = self.get_input(in_task_id)
                if task_input is None:
                    self.pending_input_set[tid].add(in_task_id)
                    dag_visit_queue.append(in_task_id)
            if tid in self.latest_continuation:
                if self.pending_input_set[tid]:
                    raise ValueError(
                        "A task that already returns a continuation cannot be pending."
                    )
                # construct continuations, as they are not directly connected to
                # the DAG dependency
                self.construct_scheduling_plan(self.latest_continuation[tid])
            elif not self.pending_input_set[tid]:
                # No pending inputs: the task is ready to run.
                self.append_frontier_to_run(tid)

    def init_context(self, context: WorkflowTaskContext) -> None:
        """Initialize the context of all tasks."""
        for task_id, task in self.tasks.items():
            options = task.options
            # 'setdefault' preserves any context that was merged in earlier
            # (e.g. from a sub-workflow's state).
            self.task_context.setdefault(
                task_id,
                dataclasses.replace(
                    context,
                    task_id=task_id,
                    creator_task_id=context.task_id,
                    checkpoint=options.checkpoint,
                    catch_exceptions=options.catch_exceptions,
                ),
            )
.venv/lib/python3.11/site-packages/ray/workflow/workflow_state_from_dag.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, List, Optional
2
+ import re
3
+ import unicodedata
4
+
5
+ import ray
6
+ from ray.workflow.common import WORKFLOW_OPTIONS
7
+ from ray.dag import DAGNode, FunctionNode, InputNode
8
+ from ray.dag.input_node import InputAttributeNode, DAGInputData
9
+ from ray import cloudpickle
10
+ from ray._private import signature
11
+ from ray._private.client_mode_hook import client_mode_should_convert
12
+ from ray.workflow import serialization_context
13
+ from ray.workflow.common import (
14
+ TaskType,
15
+ WorkflowTaskRuntimeOptions,
16
+ WorkflowRef,
17
+ validate_user_metadata,
18
+ )
19
+ from ray.workflow import workflow_context
20
+ from ray.workflow.workflow_state import WorkflowExecutionState, Task
21
+
22
+
23
def get_module(f):
    """Return the module name of *f*, or a placeholder for anonymous callables."""
    return getattr(f, "__module__", "__anonymous_module__")
25
+
26
+
27
def get_qualname(f):
    """Return the qualified name of *f*, or a placeholder for anonymous callables."""
    return getattr(f, "__qualname__", "__anonymous_func__")
29
+
30
+
31
def slugify(value: str, allow_unicode=False) -> str:
    """Sanitize *value* into a slug.

    Adapted from
    https://github.com/django/django/blob/master/django/utils/text.py

    Convert to ASCII if 'allow_unicode' is False. Remove characters that are
    not alphanumerics, underscores, dots or hyphens, strip leading/trailing
    whitespace, and collapse runs of hyphens or whitespace into a single
    hyphen.
    """
    if allow_unicode:
        text = unicodedata.normalize("NFKC", value)
    else:
        # Transliterate to ASCII, silently dropping code points that do not
        # decompose into ASCII characters.
        text = (
            unicodedata.normalize("NFKD", value)
            .encode("ascii", "ignore")
            .decode("ascii")
        )
    # Keep only word characters, dots and hyphens.
    text = re.sub(r"[^\w.\-]", "", text).strip()
    # Collapse hyphen/whitespace runs into a single hyphen.
    return re.sub(r"[-\s]+", "-", text)
49
+
50
+
51
class _DelayedDeserialization:
    """Carrier that postpones cloudpickle deserialization until it is unpickled."""

    def __init__(self, serialized: bytes):
        self._payload = serialized

    def __reduce__(self):
        # Unpickling an instance triggers the real 'cloudpickle.loads' call.
        return cloudpickle.loads, (self._payload,)
57
+
58
+
59
class _SerializationContextPreservingWrapper:
    """Workaround for preserving the workflow serialization context in client mode.

    The payload is eagerly cloudpickled here; unpickling the wrapper yields a
    '_DelayedDeserialization' object, so the real deserialization only happens
    on the worker (under the proper serialization context) instead of on the
    Ray client server.
    """

    def __init__(self, obj: Any):
        self._payload = cloudpickle.dumps(obj)

    def __reduce__(self):
        # Delay deserialization to the actual worker rather than the
        # Ray client server.
        return _DelayedDeserialization, (self._payload,)
70
+
71
+
72
def workflow_state_from_dag(
    dag_node: DAGNode, input_context: Optional[DAGInputData], workflow_id: str
):
    """
    Transform a Ray DAG to a workflow. Map FunctionNode to workflow task with
    the workflow decorator.

    Args:
        dag_node: The DAG to be converted to a workflow.
        input_context: The input data that wraps variables for the input node of the DAG.
        workflow_id: The ID of the workflow.

    Returns:
        The 'WorkflowExecutionState' built from the DAG, with 'output_task_id'
        set to the task converted from 'dag_node'.

    Raises:
        TypeError: If 'dag_node' is not a FunctionNode or an unsupported node
            type is encountered.
        ValueError: If a task is configured with more than one return value.
    """
    if not isinstance(dag_node, FunctionNode):
        raise TypeError("Currently workflow does not support classes as DAG inputs.")

    state = WorkflowExecutionState()

    # TODO(suquark): remove this cyclic importing later by changing the way of
    # task ID assignment.
    from ray.workflow.workflow_access import get_management_actor

    mgr = get_management_actor()
    context = workflow_context.get_workflow_task_context()

    def _node_visitor(node: Any) -> Any:
        # Convert one DAG node: FunctionNodes become workflow tasks recorded
        # in 'state' (returning a WorkflowRef placeholder); input nodes
        # resolve to their data; other plain objects pass through unchanged.
        if isinstance(node, FunctionNode):
            bound_options = node._bound_options.copy()
            num_returns = bound_options.get("num_returns", 1)
            if num_returns is None:  # ray could use `None` as default value
                num_returns = 1
            if num_returns > 1:
                raise ValueError("Workflow task can only have one return.")

            workflow_options = bound_options.get("_metadata", {}).get(
                WORKFLOW_OPTIONS, {}
            )

            # If checkpoint option is not specified, inherit checkpoint
            # options from context (i.e. checkpoint options of the outer
            # task). If it is still not specified, it's True by default.
            checkpoint = workflow_options.get("checkpoint", None)
            if checkpoint is None:
                checkpoint = context.checkpoint if context is not None else True
            # When it returns a nested workflow, catch_exception
            # should be passed recursively.
            catch_exceptions = workflow_options.get("catch_exceptions", None)
            if catch_exceptions is None:
                if node.get_stable_uuid() == dag_node.get_stable_uuid():
                    # 'catch_exception' context should be passed down to
                    # its direct continuation task.
                    # In this case, the direct continuation is the output node.
                    catch_exceptions = (
                        context.catch_exceptions if context is not None else False
                    )
                else:
                    catch_exceptions = False

            # We do not need to check the validness of bound options, because
            # Ray option has already checked them for us.
            max_retries = bound_options.get("max_retries", 3)
            retry_exceptions = bound_options.get("retry_exceptions", False)

            task_options = WorkflowTaskRuntimeOptions(
                task_type=TaskType.FUNCTION,
                catch_exceptions=catch_exceptions,
                retry_exceptions=retry_exceptions,
                max_retries=max_retries,
                checkpoint=checkpoint,
                ray_options=bound_options,
            )

            # 'workflow_refs' is filled in by the serialization context as a
            # side effect of serializing the task arguments below.
            workflow_refs: List[WorkflowRef] = []
            with serialization_context.workflow_args_serialization_context(
                workflow_refs
            ):
                _func_signature = signature.extract_signature(node._body)
                flattened_args = signature.flatten_args(
                    _func_signature, node._bound_args, node._bound_kwargs
                )
                # NOTE: When calling 'ray.put', we trigger python object
                # serialization. Under our serialization context,
                # Workflows are separated from the arguments,
                # leaving a placeholder object with all other python objects.
                # Then we put the placeholder object to object store,
                # so it won't be mutated later. This guarantees correct
                # semantics. See "tests/test_variable_mutable.py" as
                # an example.
                if client_mode_should_convert():
                    # Handle client mode. The Ray client would serialize and
                    # then deserialize objects in the Ray client server. When
                    # the object is being deserialized, the serialization context
                    # will be missing, resulting in failures. Here we protect the
                    # object from deserialization in client server, and we make sure
                    # the 'real' deserialization happens under the serialization
                    # context later.
                    flattened_args = _SerializationContextPreservingWrapper(
                        flattened_args
                    )
                # Set the owner of the objects to the actor so that even the driver
                # exits, these objects are still available.
                input_placeholder: ray.ObjectRef = ray.put(flattened_args, _owner=mgr)

            orig_task_id = workflow_options.get("task_id", None)
            if orig_task_id is None:
                # Default task ID base: "<module>.<slugified qualname>".
                orig_task_id = (
                    f"{get_module(node._body)}.{slugify(get_qualname(node._body))}"
                )

            # The management actor turns the base name into the final task ID.
            task_id = ray.get(mgr.gen_task_id.remote(workflow_id, orig_task_id))
            state.add_dependencies(task_id, [s.task_id for s in workflow_refs])
            state.task_input_args[task_id] = input_placeholder

            user_metadata = workflow_options.get("metadata", {})

            validate_user_metadata(user_metadata)
            state.tasks[task_id] = Task(
                task_id=task_id,
                options=task_options,
                user_metadata=user_metadata,
                func_body=node._body,
            )
            return WorkflowRef(task_id)

        if isinstance(node, InputAttributeNode):
            return node._execute_impl()  # get data from input node
        if isinstance(node, InputNode):
            return input_context  # replace input node with input data
        if not isinstance(node, DAGNode):
            return node  # return normal objects
        raise TypeError(f"Unsupported DAG node: {node}")

    output_workflow_ref = dag_node.apply_recursive(_node_visitor)
    state.output_task_id = output_workflow_ref.task_id
    return state
.venv/lib/python3.11/site-packages/ray/workflow/workflow_state_from_storage.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ from collections import deque
3
+
4
+ from ray.workflow import serialization
5
+ from ray.workflow.common import TaskID, WorkflowRef
6
+ from ray.workflow.exceptions import WorkflowTaskNotRecoverableError
7
+ from ray.workflow import workflow_storage
8
+ from ray.workflow.workflow_state import WorkflowExecutionState, Task
9
+
10
+
11
def workflow_state_from_storage(
    workflow_id: str, task_id: Optional[TaskID]
) -> WorkflowExecutionState:
    """Try to construct a workflow (task) that recovers the workflow task.
    If the workflow task already has an output checkpointing file, we return
    the workflow task id instead.

    Args:
        workflow_id: The ID of the workflow.
        task_id: The ID of the output task. If None, it will be the entrypoint of
            the workflow.

    Returns:
        A workflow that recovers the task, or the output of the task
        if it has been checkpointed.

    Raises:
        WorkflowTaskNotRecoverableError: If a task in the lineage cannot be
            recovered from storage.
    """
    reader = workflow_storage.WorkflowStorage(workflow_id)
    if task_id is None:
        task_id = reader.get_entrypoint_task_id()

    # Construct the workflow execution state.
    # (The constructor already records 'output_task_id'; no separate
    # assignment is needed.)
    state = WorkflowExecutionState(output_task_id=task_id)

    visited_tasks = set()
    dag_visit_queue = deque([task_id])
    with serialization.objectref_cache():
        # BFS over the checkpointed task graph, rebuilding the in-memory state.
        while dag_visit_queue:
            task_id: TaskID = dag_visit_queue.popleft()
            if task_id in visited_tasks:
                continue
            visited_tasks.add(task_id)
            r = reader.inspect_task(task_id)
            if not r.is_recoverable():
                raise WorkflowTaskNotRecoverableError(task_id)
            if r.output_object_valid:
                # The output is already checkpointed: point the whole
                # continuation lineage (if any) at the checkpoint instead of
                # re-running the task.
                target = state.continuation_root.get(task_id, task_id)
                state.checkpoint_map[target] = WorkflowRef(task_id)
                continue
            if isinstance(r.output_task_id, str):
                # no input dependencies here because the task has already
                # returned a continuation
                state.upstream_dependencies[task_id] = []
                state.append_continuation(task_id, r.output_task_id)
                dag_visit_queue.append(r.output_task_id)
                continue
            # transfer task info to state
            state.add_dependencies(task_id, r.workflow_refs)
            state.task_input_args[task_id] = reader.load_task_args(task_id)
            # TODO(suquark): although not necessary, but for completeness,
            # we may also load name and metadata.
            state.tasks[task_id] = Task(
                # Record the real task ID (previously an empty string) so the
                # in-memory task matches its key, consistent with
                # 'workflow_state_from_dag'.
                task_id=task_id,
                options=r.task_options,
                user_metadata={},
                func_body=reader.load_task_func_body(task_id),
            )

            dag_visit_queue.extend(r.workflow_refs)

    return state
.venv/lib/python3.11/site-packages/ray/workflow/workflow_storage.py ADDED
@@ -0,0 +1,880 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module is higher-level abstraction of storage directly used by
3
+ workflows.
4
+ """
5
+
6
+ import json
7
+ import logging
8
+ import os
9
+ import time
10
+ from dataclasses import dataclass
11
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple
12
+
13
+ import ray
14
+ from ray import cloudpickle
15
+ from ray._private import storage
16
+ from ray.types import ObjectRef
17
+ from ray.workflow.common import (
18
+ TaskID,
19
+ WorkflowStatus,
20
+ WorkflowTaskRuntimeOptions,
21
+ )
22
+ from ray.workflow.exceptions import WorkflowNotFoundError
23
+ from ray.workflow import workflow_context
24
+ from ray.workflow import serialization
25
+ from ray.workflow import serialization_context
26
+ from ray.workflow.workflow_state import WorkflowExecutionState
27
+ from ray.workflow.storage import DataLoadError, DataSaveError, KeyNotFoundError
28
+
29
logger = logging.getLogger(__name__)

# Shape of flattened task arguments: (positional args, keyword args).
ArgsType = Tuple[List[Any], Dict[str, Any]]  # args and kwargs

# constants used for keys
WORKFLOW_ROOT = "workflows"  # The workflow root directory under global Ray storage.
OBJECTS_DIR = "objects"
STEPS_DIR = "tasks"
# Per-task metadata and checkpoint artifact filenames.
STEP_INPUTS_METADATA = "inputs.json"
STEP_USER_METADATA = "user_task_metadata.json"
STEP_PRERUN_METADATA = "pre_task_metadata.json"
STEP_POSTRUN_METADATA = "post_task_metadata.json"
STEP_OUTPUTS_METADATA = "outputs.json"
STEP_ARGS = "args.pkl"
STEP_OUTPUT = "output.pkl"
STEP_EXCEPTION = "exception.pkl"
STEP_FUNC_BODY = "func_body.pkl"
CLASS_BODY = "class_body.pkl"
# Per-workflow metadata filenames.
WORKFLOW_META = "workflow_meta.json"
WORKFLOW_USER_METADATA = "user_run_metadata.json"
WORKFLOW_PRERUN_METADATA = "pre_run_metadata.json"
WORKFLOW_POSTRUN_METADATA = "post_run_metadata.json"
WORKFLOW_PROGRESS = "progress.json"
# Directories used by the status-indexing protocol.
WORKFLOW_STATUS_DIR = "__status__"
WORKFLOW_STATUS_DIRTY_DIR = "dirty"
# Without this counter, we're going to scan all tasks to get the number of
# tasks with a given name. This can be very expensive if there are too
# many duplicates.
DUPLICATE_NAME_COUNTER = "duplicate_name_counter"
58
+
59
+
60
@dataclass
class TaskInspectResult:
    """Result of inspecting the storage checkpoints of one workflow task."""

    # The task output checkpoint exists and valid. If this field
    # is set, we do not set all other fields below.
    output_object_valid: bool = False
    # The ID of the task that could contain the output checkpoint of this
    # task. If this field is set, we do not set all other fields below.
    output_task_id: Optional[TaskID] = None
    # The task input arguments checkpoint exists and valid.
    args_valid: bool = False
    # The task function body checkpoint exists and valid.
    func_body_valid: bool = False
    # The dynamically referenced workflows in the input of the workflow.
    workflow_refs: Optional[List[str]] = None
    # The options of the workflow task.
    task_options: Optional[WorkflowTaskRuntimeOptions] = None
    # task throw exception
    task_raised_exception: bool = False

    def is_recoverable(self) -> bool:
        """Whether the task can be recovered from what exists in storage.

        True when the output is checkpointed, when a continuation may hold the
        output, or when the arguments, referenced workflows and function body
        are all available for re-execution.
        """
        # Fix: coerce with bool() — 'output_task_id' is a string, so without
        # the coercion this method could return a TaskID (or None) despite
        # being annotated '-> bool'. Truthiness is unchanged.
        return bool(
            self.output_object_valid
            or self.output_task_id
            or (
                self.args_valid
                and self.workflow_refs is not None
                and self.func_body_valid
            )
        )
89
+
90
+
91
class WorkflowIndexingStorage:
    """Access and maintain the indexing of workflow status.

    It runs a protocol that guarantees we can recover from any interrupted
    status updating. This protocol is **not thread-safe** for updating the
    status of the same workflow, currently it is executed by workflow management
    actor with a single thread.

    Here is how the protocol works:

    Update the status of a workflow
    1. Load workflow status from workflow data. If it is the same as the new status,
       return.
    2. Check if the workflow status updating is dirty. If it is, fix the
       workflow status; otherwise, mark the workflow status updating dirty.
    3. Update status in the workflow metadata.
    4. Insert the workflow ID key in the status indexing directory of the new status.
    5. Delete the workflow ID key in the status indexing directory of
       the previous status.
    6. Remove the workflow status updating dirty mark.

    Load a status of a workflow
    1. Read the status of the workflow from the workflow metadata.
    2. Return the status.

    List the status of all workflows
    1. Get status of all workflows by listing workflow ID keys in each workflow
       status indexing directory.
    2. List all workflows with dirty updating status. Get their status from
       workflow data. Override the status of the corresponding workflow.
    3. Return all the status.
    """

    def __init__(self):
        # Client rooted at the shared workflow directory (covers all
        # workflows, unlike the per-workflow storage client).
        self._storage = storage.get_client(WORKFLOW_ROOT)

    def update_workflow_status(self, workflow_id: str, status: WorkflowStatus):
        """Update the status of the workflow.
        Try fixing indexing if workflow status updating was marked dirty.

        This method is NOT thread-safe. It is handled by the workflow management actor.
        """
        prev_status = self.load_workflow_status(workflow_id)
        if prev_status != status:
            # Try fixing indexing if workflow status updating was marked dirty.
            if (
                self._storage.get_info(self._key_workflow_status_dirty(workflow_id))
                is not None
            ):
                # This means the previous status update failed. Fix it.
                # Re-insert the committed status key and drop every other
                # status key so the index matches the metadata again.
                self._storage.put(
                    self._key_workflow_with_status(workflow_id, prev_status), b""
                )
                for s in WorkflowStatus:
                    if s != prev_status:
                        self._storage.delete(
                            self._key_workflow_with_status(workflow_id, s)
                        )
            else:
                # Mark dirty first, so an interrupted update can be detected
                # and repaired by the branch above on the next update.
                self._storage.put(self._key_workflow_status_dirty(workflow_id), b"")
            # Transactional update of workflow status
            self._storage.put(
                self._key_workflow_metadata(workflow_id),
                json.dumps({"status": status.value}).encode(),
            )
            self._storage.put(self._key_workflow_with_status(workflow_id, status), b"")
            if prev_status is not WorkflowStatus.NONE:
                self._storage.delete(
                    self._key_workflow_with_status(workflow_id, prev_status)
                )
            self._storage.delete(self._key_workflow_status_dirty(workflow_id))

    def load_workflow_status(self, workflow_id: str) -> WorkflowStatus:
        """Load the committed workflow status.

        Returns:
            The stored status, or 'WorkflowStatus.NONE' if the workflow has
            no metadata in storage.
        """
        raw_data = self._storage.get(self._key_workflow_metadata(workflow_id))
        if raw_data is not None:
            metadata = json.loads(raw_data)
            return WorkflowStatus(metadata["status"])
        return WorkflowStatus.NONE

    def list_workflow(
        self, status_filter: Optional[Set[WorkflowStatus]] = None
    ) -> List[Tuple[str, WorkflowStatus]]:
        """List workflow status. Override status of the workflows whose status updating
        were marked dirty with the workflow status from workflow metadata.

        Args:
            status_filter: If given, only returns workflow with that status. This can
                be a single status or set of statuses.

        Raises:
            TypeError: If 'status_filter' is neither None nor a set.
            ValueError: If the filter contains 'WorkflowStatus.NONE'.
        """
        if status_filter is None:
            status_filter = set(WorkflowStatus)
            status_filter.discard(WorkflowStatus.NONE)
        elif not isinstance(status_filter, set):
            raise TypeError("'status_filter' should either be 'None' or a set.")
        elif WorkflowStatus.NONE in status_filter:
            raise ValueError("'WorkflowStatus.NONE' is not a valid filter value.")

        results = {}
        for status in status_filter:
            try:
                # empty string points the key to the dir
                for p in self._storage.list(self._key_workflow_with_status("", status)):
                    workflow_id = p.base_name
                    results[workflow_id] = status
            except FileNotFoundError:
                # No workflow is indexed under this status yet.
                pass
        # Get "correct" status of workflows
        try:
            for p in self._storage.list(self._key_workflow_status_dirty("")):
                workflow_id = p.base_name
                # overwrite status
                results.pop(workflow_id, None)
                status = self.load_workflow_status(workflow_id)
                if status in status_filter:
                    results[workflow_id] = status
        except FileNotFoundError:
            # No dirty marks exist.
            pass
        return list(results.items())

    def delete_workflow_status(self, workflow_id: str):
        """Delete status indexing for the workflow."""
        for status in WorkflowStatus:
            self._storage.delete(self._key_workflow_with_status(workflow_id, status))
        self._storage.delete(self._key_workflow_status_dirty(workflow_id))

    def _key_workflow_with_status(self, workflow_id: str, status: WorkflowStatus):
        """A key whose existence marks the status of the workflow."""
        return os.path.join(WORKFLOW_STATUS_DIR, status.value, workflow_id)

    def _key_workflow_status_dirty(self, workflow_id: str):
        """A key marks the workflow status dirty, because it is under change."""
        return os.path.join(WORKFLOW_STATUS_DIR, WORKFLOW_STATUS_DIRTY_DIR, workflow_id)

    def _key_workflow_metadata(self, workflow_id: str):
        # Path of the per-workflow metadata file ("<workflow_id>/workflow_meta.json").
        return os.path.join(workflow_id, WORKFLOW_META)
227
+
228
+
229
+ class WorkflowStorage:
230
+ """Access workflow in storage. This is a higher-level abstraction,
231
+ which does not care about the underlying storage implementation."""
232
+
233
    def __init__(self, workflow_id: str):
        # Per-workflow storage client rooted at "<WORKFLOW_ROOT>/<workflow_id>".
        self._storage = storage.get_client(os.path.join(WORKFLOW_ROOT, workflow_id))
        # Status indexing spans all workflows, so it uses its own client.
        self._status_storage = WorkflowIndexingStorage()
        self._workflow_id = workflow_id
237
+
238
+ def load_task_output(self, task_id: TaskID) -> Any:
239
+ """Load the output of the workflow task from checkpoint.
240
+
241
+ Args:
242
+ task_id: ID of the workflow task.
243
+
244
+ Returns:
245
+ Output of the workflow task.
246
+ """
247
+
248
+ tasks = [
249
+ self._get(self._key_task_output(task_id), no_exception=True),
250
+ self._get(self._key_task_exception(task_id), no_exception=True),
251
+ ]
252
+ (output_ret, output_err), (exception_ret, exception_err) = tasks
253
+ # When we have output, always return output first
254
+ if output_err is None:
255
+ return output_ret
256
+
257
+ # When we don't have output, check exception
258
+ if exception_err is None:
259
+ raise exception_ret
260
+
261
+ # In this case, there is no such task
262
+ raise output_err
263
+
264
+ def save_workflow_execution_state(
265
+ self, creator_task_id: TaskID, state: WorkflowExecutionState
266
+ ) -> None:
267
+ """Save a workflow execution state.
268
+ Typically, the state is translated from a Ray DAG.
269
+
270
+ Args:
271
+ creator_task_id: The ID of the task that creates the state.
272
+ state: The state converted from the DAG.
273
+ """
274
+ assert creator_task_id != state.output_task_id
275
+
276
+ for task_id, task in state.tasks.items():
277
+ # TODO (Alex): Handle the json case better?
278
+ metadata = {
279
+ **task.to_dict(),
280
+ "workflow_refs": state.upstream_dependencies[task_id],
281
+ }
282
+ self._put(self._key_task_input_metadata(task_id), metadata, True)
283
+ # TODO(suquark): The task user metadata duplicates.
284
+ self._put(
285
+ self._key_task_user_metadata(task_id),
286
+ task.user_metadata,
287
+ True,
288
+ )
289
+ workflow_id = self._workflow_id
290
+ serialization.dump_to_storage(
291
+ self._key_task_function_body(task_id),
292
+ task.func_body,
293
+ workflow_id,
294
+ self,
295
+ )
296
+ with serialization_context.workflow_args_keeping_context():
297
+ # TODO(suquark): in the future we should write to storage directly
298
+ # with plasma store object in memory.
299
+ args_obj = ray.get(state.task_input_args[task_id])
300
+ serialization.dump_to_storage(
301
+ self._key_task_args(task_id),
302
+ args_obj,
303
+ workflow_id,
304
+ self,
305
+ )
306
+
307
+ # Finally, point to the output ID of the DAG. The DAG is a continuation
308
+ # of the creator task.
309
+ self._put(
310
+ self._key_task_output_metadata(creator_task_id),
311
+ {"output_task_id": state.output_task_id},
312
+ True,
313
+ )
314
+
315
    def save_task_output(
        self,
        task_id: TaskID,
        ret: Any,
        *,
        exception: Optional[Exception],
    ) -> None:
        """When a workflow task returns,
        1. If the returned object is a workflow, this means we are a nested
           workflow. We save the output metadata that points to the workflow.
        2. Otherwise, checkpoint the output.

        Args:
            task_id: The ID of the workflow task. If it is an empty string,
                it means we are in the workflow job driver process.
            ret: The returned object from a workflow task.
            exception: This task should throw exception.
        """
        if exception is None:
            # This workflow task returns a object.
            # Resolve an ObjectRef to its concrete value before writing.
            ret = ray.get(ret) if isinstance(ret, ray.ObjectRef) else ret
            serialization.dump_to_storage(
                self._key_task_output(task_id),
                ret,
                self._workflow_id,
                storage=self,
            )
            # tasks.append(self._put(self._key_task_output(task_id), ret))
            # TODO (yic): Delete exception file
        else:
            # Exactly one of output/exception is checkpointed per attempt.
            assert ret is None
            serialization.dump_to_storage(
                self._key_task_exception(task_id),
                exception,
                self._workflow_id,
                storage=self,
            )
            # tasks.append(
            #     self._put(self._key_task_exception(task_id), exception))

        # Finish checkpointing.
        # TODO(suquark): batching all tasks above.
358
+ def load_task_func_body(self, task_id: TaskID) -> Callable:
359
+ """Load the function body of the workflow task.
360
+
361
+ Args:
362
+ task_id: ID of the workflow task.
363
+
364
+ Returns:
365
+ A callable function.
366
+ """
367
+ return self._get(self._key_task_function_body(task_id))
368
+
369
+ def gen_task_id(self, task_name: str) -> int:
370
+ def _gen_task_id():
371
+ key = self._key_num_tasks_with_name(task_name)
372
+ try:
373
+ val = self._get(key, True)
374
+ self._put(key, val + 1, True)
375
+ return val + 1
376
+ except KeyNotFoundError:
377
+ self._put(key, 0, True)
378
+ return 0
379
+
380
+ return _gen_task_id()
381
+
382
    def load_task_args(self, task_id: TaskID) -> ray.ObjectRef:
        """Load the input arguments of the workflow task. This must be
        done under a serialization context, otherwise the arguments would
        not be reconstructed successfully.

        Args:
            task_id: ID of the workflow task.

        Returns:
            An object ref of the input args.
        """
        # The keeping-context ensures nested workflow refs inside the args
        # survive deserialization instead of being resolved eagerly.
        with serialization_context.workflow_args_keeping_context():
            x = self._get(self._key_task_args(task_id))
            return ray.put(x)
397
+ def save_object_ref(self, obj_ref: ray.ObjectRef) -> None:
398
+ """Save the object ref.
399
+
400
+ Args:
401
+ obj_ref: The object reference
402
+
403
+ Returns:
404
+ None
405
+ """
406
+ return self._save_object_ref(obj_ref)
407
+
408
    def load_object_ref(self, object_id: str) -> ray.ObjectRef:
        """Load the input object ref.

        Args:
            object_id: The hex ObjectID.

        Returns:
            The object ref.
        """

        def _load_obj_ref() -> ray.ObjectRef:
            data = self._get(self._key_obj_id(object_id))
            # _put_obj_ref wraps the value so we get a ref-to-a-ref, which
            # plain ray.put cannot produce.
            ref = _put_obj_ref.remote((data,))
            return ref

        return _load_obj_ref()
425
    def update_continuation_output_link(
        self, continuation_root_id: TaskID, latest_continuation_task_id: TaskID
    ) -> None:
        """Update the link of the continuation output. The link points
        to the ID of the latest finished continuation task.

        Args:
            continuation_root_id: The ID of the task that returns all later
                continuations.
            latest_continuation_task_id: The ID of the latest finished
                continuation task.
        """
        try:
            metadata = self._get(
                self._key_task_output_metadata(continuation_root_id), True
            )
        except KeyNotFoundError:
            # This is because we skipped checkpointing of the
            # task [id=continuation_root_id]. Return a dummy
            # metadata instead.
            metadata = {}
        # Only rewrite the link when it actually changes (avoids a redundant
        # storage write for repeated updates with the same continuation).
        if latest_continuation_task_id != metadata.get(
            "output_task_id"
        ) and latest_continuation_task_id != metadata.get("dynamic_output_task_id"):
            metadata["dynamic_output_task_id"] = latest_continuation_task_id
            self._put(
                self._key_task_output_metadata(continuation_root_id), metadata, True
            )
454
+ def _locate_output_task_id(self, task_id: TaskID) -> str:
455
+ metadata = self._get(self._key_task_output_metadata(task_id), True)
456
+ return metadata.get("dynamic_output_task_id") or metadata["output_task_id"]
457
+
458
+ def get_entrypoint_task_id(self) -> TaskID:
459
+ """Load the entrypoint task ID of the workflow.
460
+
461
+ Returns:
462
+ The ID of the entrypoint task.
463
+ """
464
+ # empty TaskID represents the workflow driver
465
+ try:
466
+ return self._locate_output_task_id("")
467
+ except Exception as e:
468
+ raise ValueError(
469
+ "Fail to get entrypoint task ID from workflow"
470
+ f"[id={self._workflow_id}]"
471
+ ) from e
472
+
473
    def _locate_output_in_storage(self, task_id: TaskID) -> Optional[TaskID]:
        """Follow output links from *task_id* to the task holding a checkpoint.

        Walks the chain of continuation/output links until a task with no
        further link is reached; returns its ID if it has a valid output
        checkpoint, otherwise None.
        """
        result = self.inspect_task(task_id)
        # output_task_id is a string while there is another link to follow.
        while isinstance(result.output_task_id, str):
            task_id = result.output_task_id
            result = self.inspect_task(result.output_task_id)
        if result.output_object_valid:
            return task_id
        return None
482
    def inspect_output(self, task_id: TaskID) -> Optional[TaskID]:
        """Get the actual checkpointed output for a task, represented by the ID of
        the task that actually keeps the checkpoint.

        Raises:
            ValueError: The workflow does not exist or the workflow state is not valid.

        Args:
            task_id: The ID of the task we are looking for its checkpoint.
                If None, the workflow's entrypoint task is used.

        Returns:
            The ID of the task that actually keeps the checkpoint.
            'None' if the checkpoint does not exist.
        """
        # Validate the workflow is in a state where outputs are meaningful.
        status = self.load_workflow_status()
        if status == WorkflowStatus.NONE:
            raise ValueError(f"No such workflow '{self._workflow_id}'")
        if status == WorkflowStatus.CANCELED:
            raise ValueError(f"Workflow {self._workflow_id} is canceled")
        # For resumable workflow, the workflow result is not ready.
        # It has to be resumed first.
        if status == WorkflowStatus.RESUMABLE:
            raise ValueError(
                f"Workflow {self._workflow_id} is in resumable status, please resume it"
            )
        if task_id is None:
            task_id = self.get_entrypoint_task_id()
        return self._locate_output_in_storage(task_id)
511
    def inspect_task(self, task_id: TaskID) -> TaskInspectResult:
        """
        Get the status of a workflow task. The status indicates whether
        the workflow task can be recovered etc.

        Args:
            task_id: The ID of a workflow task

        Returns:
            The status of the task.
        """
        # Thin public wrapper over the storage-scanning implementation.
        return self._inspect_task(task_id)
524
    def _inspect_task(self, task_id: TaskID) -> TaskInspectResult:
        """Inspect a task's directory and classify what is recoverable.

        Precedence: a checkpointed output wins, then an output link to
        another task, then whatever inputs are available for re-execution.
        """
        items = self._scan(self._key_task_prefix(task_id), ignore_errors=True)
        keys = set(items)
        # does this task contains output checkpoint file?
        if STEP_OUTPUT in keys:
            return TaskInspectResult(output_object_valid=True)
        # do we know where the output comes from?
        if STEP_OUTPUTS_METADATA in keys:
            output_task_id = self._locate_output_task_id(task_id)
            return TaskInspectResult(output_task_id=output_task_id)

        # read inputs metadata
        try:
            metadata = self._get(self._key_task_input_metadata(task_id), True)
            return TaskInspectResult(
                args_valid=(STEP_ARGS in keys),
                func_body_valid=(STEP_FUNC_BODY in keys),
                workflow_refs=metadata["workflow_refs"],
                task_options=WorkflowTaskRuntimeOptions.from_dict(
                    metadata["task_options"]
                ),
                task_raised_exception=(STEP_EXCEPTION in keys),
            )
        except Exception:
            # Metadata missing/corrupt: report only what the key scan proves.
            return TaskInspectResult(
                args_valid=(STEP_ARGS in keys),
                func_body_valid=(STEP_FUNC_BODY in keys),
                task_raised_exception=(STEP_EXCEPTION in keys),
            )
554
+ def _save_object_ref(self, identifier: str, obj_ref: ray.ObjectRef):
555
+ data = ray.get(obj_ref)
556
+ self._put(self._key_obj_id(identifier), data)
557
+
558
+ def load_actor_class_body(self) -> type:
559
+ """Load the class body of the virtual actor.
560
+
561
+ Raises:
562
+ DataLoadError: if we fail to load the class body.
563
+ """
564
+ return self._get(self._key_class_body())
565
+
566
+ def save_actor_class_body(self, cls: type) -> None:
567
+ """Save the class body of the virtual actor.
568
+
569
+ Args:
570
+ cls: The class body used by the virtual actor.
571
+
572
+ Raises:
573
+ DataSaveError: if we fail to save the class body.
574
+ """
575
+ self._put(self._key_class_body(), cls)
576
+
577
+ def save_task_prerun_metadata(self, task_id: TaskID, metadata: Dict[str, Any]):
578
+ """Save pre-run metadata of the current task.
579
+
580
+ Args:
581
+ task_id: ID of the workflow task.
582
+ metadata: pre-run metadata of the current task.
583
+
584
+ Raises:
585
+ DataSaveError: if we fail to save the pre-run metadata.
586
+ """
587
+
588
+ self._put(self._key_task_prerun_metadata(task_id), metadata, True)
589
+
590
+ def save_task_postrun_metadata(self, task_id: TaskID, metadata: Dict[str, Any]):
591
+ """Save post-run metadata of the current task.
592
+
593
+ Args:
594
+ task_id: ID of the workflow task.
595
+ metadata: post-run metadata of the current task.
596
+
597
+ Raises:
598
+ DataSaveError: if we fail to save the post-run metadata.
599
+ """
600
+
601
+ self._put(self._key_task_postrun_metadata(task_id), metadata, True)
602
+
603
+ def save_workflow_user_metadata(self, metadata: Dict[str, Any]):
604
+ """Save user metadata of the current workflow.
605
+
606
+ Args:
607
+ metadata: user metadata of the current workflow.
608
+
609
+ Raises:
610
+ DataSaveError: if we fail to save the user metadata.
611
+ """
612
+
613
+ self._put(self._key_workflow_user_metadata(), metadata, True)
614
+
615
    def load_task_metadata(self, task_id: TaskID) -> Dict[str, Any]:
        """Load the metadata of the given task.

        Returns:
            The metadata of the given task, with pre/post-run stats merged
            under the "stats" key.

        Raises:
            ValueError: if the workflow or the task does not exist.
        """

        def _load_task_metadata():
            # Distinguish "unknown workflow" from "unknown task" for a
            # clearer error message.
            if not self._scan(self._key_task_prefix(task_id), ignore_errors=True):
                if not self._scan("", ignore_errors=True):
                    raise ValueError(
                        "No such workflow_id '{}'".format(self._workflow_id)
                    )
                else:
                    raise ValueError(
                        "No such task_id '{}' in workflow '{}'".format(
                            task_id, self._workflow_id
                        )
                    )

            # Each _get returns (value, error); errors are tolerated and the
            # missing pieces default to {} below.
            tasks = [
                self._get(self._key_task_input_metadata(task_id), True, True),
                self._get(self._key_task_prerun_metadata(task_id), True, True),
                self._get(self._key_task_postrun_metadata(task_id), True, True),
            ]

            (
                (input_metadata, _),
                (prerun_metadata, _),
                (postrun_metadata, _),
            ) = tasks

            input_metadata = input_metadata or {}
            prerun_metadata = prerun_metadata or {}
            postrun_metadata = postrun_metadata or {}

            metadata = input_metadata
            metadata["stats"] = {**prerun_metadata, **postrun_metadata}

            return metadata

        return _load_task_metadata()
658
    def load_workflow_metadata(self) -> Dict[str, Any]:
        """Load the metadata of the current workflow.

        Returns:
            The metadata of the current workflow, with user metadata under
            "user_metadata" and pre/post-run stats merged under "stats".

        Raises:
            ValueError: if the workflow does not exist.
        """

        def _load_workflow_metadata():
            if not self._scan("", ignore_errors=True):
                raise ValueError("No such workflow_id '{}'".format(self._workflow_id))

            # Each _get returns (value, error); missing pieces default to {}.
            tasks = [
                self._get(self._key_workflow_metadata(), True, True),
                self._get(self._key_workflow_user_metadata(), True, True),
                self._get(self._key_workflow_prerun_metadata(), True, True),
                self._get(self._key_workflow_postrun_metadata(), True, True),
            ]

            (
                (status_metadata, _),
                (user_metadata, _),
                (prerun_metadata, _),
                (postrun_metadata, _),
            ) = tasks

            status_metadata = status_metadata or {}
            user_metadata = user_metadata or {}
            prerun_metadata = prerun_metadata or {}
            postrun_metadata = postrun_metadata or {}

            metadata = status_metadata
            metadata["user_metadata"] = user_metadata
            metadata["stats"] = {**prerun_metadata, **postrun_metadata}

            return metadata

        return _load_workflow_metadata()
696
    def list_workflow(
        self, status_filter: Optional[Set[WorkflowStatus]] = None
    ) -> List[Tuple[str, WorkflowStatus]]:
        """List all workflows matching a given status filter.

        Args:
            status_filter: If given, only returns workflow with that status. This can
                be a single status or set of statuses.

        Returns:
            A list of (workflow_id, status) pairs.
        """
        # Delegates to the status index, which tracks all workflows.
        return self._status_storage.list_workflow(status_filter)
707
    def delete_workflow(self) -> None:
        """Delete this workflow's status index entry and all of its data.

        Raises:
            WorkflowNotFoundError: if no data exists for this workflow.
        """
        # TODO (Alex): There's a race condition here if someone tries to
        # start the workflow between these ops.
        self._status_storage.delete_workflow_status(self._workflow_id)
        found = self._storage.delete_dir("")
        # TODO (Alex): Different file systems seem to have different
        # behavior when deleting a prefix that doesn't exist, so we may
        # need to catch a broader class of exceptions.

        if not found:
            raise WorkflowNotFoundError(self._workflow_id)
719
+ def update_workflow_status(self, status: WorkflowStatus):
720
+ """Update the status of the workflow.
721
+ This method is NOT thread-safe. It is handled by the workflow management actor.
722
+ """
723
+ self._status_storage.update_workflow_status(self._workflow_id, status)
724
+ if status == WorkflowStatus.RUNNING:
725
+ self._put(
726
+ self._key_workflow_prerun_metadata(), {"start_time": time.time()}, True
727
+ )
728
+ elif status in (WorkflowStatus.SUCCESSFUL, WorkflowStatus.FAILED):
729
+ self._put(
730
+ self._key_workflow_postrun_metadata(), {"end_time": time.time()}, True
731
+ )
732
+
733
    def load_workflow_status(self):
        """Load workflow status. If we find the previous status updating failed,
        fix it with redo-log transaction recovery."""
        # Recovery logic lives in the status index, not here.
        return self._status_storage.load_workflow_status(self._workflow_id)
738
    def _put(self, key: str, data: Any, is_json: bool = False) -> str:
        """Serialize and put an object in the object store.

        Args:
            key: The key of the object.
            data: The data to be stored.
            is_json: If true, json encode the data, otherwise pickle it.

        Returns:
            The key the data was stored under.

        Raises:
            DataSaveError: if serialization or the storage write fails.
        """
        # TODO(suquark): Currently put to file is not atomic -- you can get a partial
        # file. This could fail workflow recovery.
        try:
            if not is_json:
                # Non-JSON data goes through the workflow serializer, which
                # handles object refs embedded in the value.
                serialization.dump_to_storage(
                    key, data, self._workflow_id, storage=self
                )
            else:
                serialized_data = json.dumps(data).encode()
                self._storage.put(key, serialized_data)
        except Exception as e:
            raise DataSaveError from e

        return key
761
    def _get(self, key: str, is_json: bool = False, no_exception: bool = False) -> Any:
        """Read and deserialize the value stored under *key*.

        Args:
            key: The key of the object.
            is_json: If true, json decode the data, otherwise unpickle it.
            no_exception: If true, return a (value, error) pair instead of
                raising.

        Returns:
            The value, or (value, error) when no_exception is set.

        Raises:
            KeyNotFoundError: if the key does not exist (and no_exception is
                False).
            DataLoadError: if deserialization or the storage read fails.
        """
        err = None
        ret = None
        try:
            unmarshaled = self._storage.get(key)
            if unmarshaled is None:
                raise KeyNotFoundError
            if is_json:
                ret = json.loads(unmarshaled.decode())
            else:
                ret = cloudpickle.loads(unmarshaled)
        except KeyNotFoundError as e:
            # Missing key is reported as-is; other failures are wrapped.
            err = e
        except Exception as e:
            err = DataLoadError()
            err.__cause__ = e

        if no_exception:
            return (ret, err)
        elif err is None:
            return ret
        else:
            raise err
785
+ def _scan(self, prefix: str, ignore_errors: bool = False) -> List[str]:
786
+ try:
787
+ return [p.base_name for p in self._storage.list(prefix)]
788
+ except Exception as e:
789
+ if ignore_errors:
790
+ return []
791
+ raise e
792
+
793
+ def _exists(self, key: str) -> bool:
794
+ return self._storage.get_info(key) is not None
795
+
796
+ # The following functions are helper functions to get the key
797
+ # for a specific fields
798
+
799
+ def _key_task_input_metadata(self, task_id):
800
+ return os.path.join(STEPS_DIR, task_id, STEP_INPUTS_METADATA)
801
+
802
+ def _key_task_user_metadata(self, task_id):
803
+ return os.path.join(STEPS_DIR, task_id, STEP_USER_METADATA)
804
+
805
+ def _key_task_prerun_metadata(self, task_id):
806
+ return os.path.join(STEPS_DIR, task_id, STEP_PRERUN_METADATA)
807
+
808
+ def _key_task_postrun_metadata(self, task_id):
809
+ return os.path.join(STEPS_DIR, task_id, STEP_POSTRUN_METADATA)
810
+
811
+ def _key_task_output(self, task_id):
812
+ return os.path.join(STEPS_DIR, task_id, STEP_OUTPUT)
813
+
814
+ def _key_task_exception(self, task_id):
815
+ return os.path.join(STEPS_DIR, task_id, STEP_EXCEPTION)
816
+
817
+ def _key_task_output_metadata(self, task_id):
818
+ return os.path.join(STEPS_DIR, task_id, STEP_OUTPUTS_METADATA)
819
+
820
+ def _key_task_function_body(self, task_id):
821
+ return os.path.join(STEPS_DIR, task_id, STEP_FUNC_BODY)
822
+
823
+ def _key_task_args(self, task_id):
824
+ return os.path.join(STEPS_DIR, task_id, STEP_ARGS)
825
+
826
+ def _key_obj_id(self, object_id):
827
+ return os.path.join(OBJECTS_DIR, object_id)
828
+
829
+ def _key_task_prefix(self, task_id):
830
+ return os.path.join(STEPS_DIR, task_id, "")
831
+
832
+ def _key_class_body(self):
833
+ return os.path.join(CLASS_BODY)
834
+
835
+ def _key_workflow_metadata(self):
836
+ return os.path.join(WORKFLOW_META)
837
+
838
+ def _key_workflow_user_metadata(self):
839
+ return os.path.join(WORKFLOW_USER_METADATA)
840
+
841
+ def _key_workflow_prerun_metadata(self):
842
+ return os.path.join(WORKFLOW_PRERUN_METADATA)
843
+
844
+ def _key_workflow_postrun_metadata(self):
845
+ return os.path.join(WORKFLOW_POSTRUN_METADATA)
846
+
847
+ def _key_num_tasks_with_name(self, task_name):
848
+ return os.path.join(DUPLICATE_NAME_COUNTER, task_name)
849
+
850
+
851
def get_workflow_storage(workflow_id: Optional[str] = None) -> WorkflowStorage:
    """Get the storage for the workflow.

    Args:
        workflow_id: The ID of the storage. Defaults to the workflow ID of
            the current workflow task context.

    Returns:
        A workflow storage.
    """
    if workflow_id is not None:
        return WorkflowStorage(workflow_id)
    context = workflow_context.get_workflow_task_context()
    return WorkflowStorage(context.workflow_id)
865
def _load_object_ref(paths: List[str], wf_storage: WorkflowStorage) -> ObjectRef:
    """Load a stored value via a zero-CPU remote task.

    NOTE(review): this forwards *paths* (a list) to WorkflowStorage._get,
    which elsewhere in this file is called with a single key string —
    confirm the storage client accepts a list key or whether this helper
    is dead/legacy code.
    """

    @ray.remote(num_cpus=0)
    def load_ref(paths: List[str], wf_storage: WorkflowStorage):
        return wf_storage._get(paths)

    return load_ref.remote(paths, wf_storage)
873
@ray.remote(num_cpus=0)
def _put_obj_ref(ref: Tuple[ObjectRef]):
    """
    Return a ref to an object ref. (This can't be done with
    `ray.put(obj_ref)`).

    The inner ref is wrapped in a 1-tuple so Ray does not resolve it when
    passing it into this task; the task's own return ref then points at it.
    """
    return ref[0]
.venv/lib/python3.11/site-packages/torchgen/__pycache__/model.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37a50a2f6a1b9d5981382db3935480f1eb48babfd975e7cc5008363f657ed26e
3
+ size 120846
.venv/lib/python3.11/site-packages/torchgen/api/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/torchgen/api/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (185 Bytes). View file