diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/autoscaler_pb2_grpc.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/autoscaler_pb2_grpc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4dafdc5301757fefe34f265d14886cd26758dfd Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/autoscaler_pb2_grpc.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/common_pb2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/common_pb2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1e0d103b5a0b4ba3b3ad1ef3cf70f6d80d6f4cb Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/common_pb2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/core_worker_pb2_grpc.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/core_worker_pb2_grpc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffcfc7a44a07368c8d5e11d91296f265a60dddc2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/core_worker_pb2_grpc.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/dependency_pb2_grpc.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/dependency_pb2_grpc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fa6369d5dba1bba851fad126c9ec4010348a8f9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/dependency_pb2_grpc.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/event_pb2_grpc.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/event_pb2_grpc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05ef5b7e67f15a6d7f0588003e152cb014065fe9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/event_pb2_grpc.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_actor_data_pb2_grpc.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_actor_data_pb2_grpc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b2ee418738805e88eaa6269fa8941428965e2d2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_actor_data_pb2_grpc.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_driver_job_event_pb2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_driver_job_event_pb2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ae4d2e300e85e8c5fefac7673c2e4f6447c0bad Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_driver_job_event_pb2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_driver_job_event_pb2_grpc.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_driver_job_event_pb2_grpc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68555daa5b40102f60166f91705a361808be7a6f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_driver_job_event_pb2_grpc.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_event_pb2_grpc.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_event_pb2_grpc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..077f3dae8330174989517f99c4809d767d970ee2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_event_pb2_grpc.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_node_data_pb2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_node_data_pb2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55b6d48c015a34821585ae09a41e8e88d221aafb Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_node_data_pb2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_node_data_pb2_grpc.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_node_data_pb2_grpc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..adf316dd0b580a09262611f520a7760f05acb15a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_node_data_pb2_grpc.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_runtime_env_pb2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_runtime_env_pb2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5da7c3c13b71ed5aabacacb772fb20522df2ec42 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_runtime_env_pb2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_runtime_env_pb2_grpc.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_runtime_env_pb2_grpc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c6e1d10f51357e4ed36d2a8fd3c666854971f22 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_runtime_env_pb2_grpc.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_submission_job_event_pb2_grpc.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_submission_job_event_pb2_grpc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d70b04027b9fbafbeff7c8df9f9c984444f15806 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_submission_job_event_pb2_grpc.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_task_event_pb2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_task_event_pb2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb2b501aad0a192da7a71cd1c3d74528aeae249c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_task_event_pb2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_task_event_pb2_grpc.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_task_event_pb2_grpc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37f370f7427abd064b50cad110634ff645163878 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/export_task_event_pb2_grpc.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/gcs_service_pb2.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/gcs_service_pb2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..097dc3b76cc8c5f921b80b670c7c032cd9ebe104 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/gcs_service_pb2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/instance_manager_pb2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/instance_manager_pb2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..175bd0be60ec06da350a9aeb4ec08b9276bc64df Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/instance_manager_pb2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/logging_pb2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/logging_pb2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfae203ca45d09f03bae7d61b65e7444e94bbf8f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/logging_pb2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/logging_pb2_grpc.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/logging_pb2_grpc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8d20bf4103aace6ffb6994b06d439bd01c5ac8f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/logging_pb2_grpc.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/metrics_pb2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/metrics_pb2.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..adcc5b23faf758df7bb6b37add558a0c787d9618 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/metrics_pb2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/metrics_pb2_grpc.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/metrics_pb2_grpc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e1048bef69a2a954278c439cfe9f60d1298766b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/metrics_pb2_grpc.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/node_manager_pb2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/node_manager_pb2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb0a6fa477fffb077d6302371efbe1c7f9fa4467 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/node_manager_pb2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/ray_client_pb2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/ray_client_pb2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b64439202599302b93ce5ad2ffd50f9c712a80e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/ray_client_pb2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/ray_client_pb2_grpc.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/ray_client_pb2_grpc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a93c7edda360274c4e02ba5d1d2e66309d122b3c Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/ray_client_pb2_grpc.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/reporter_pb2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/reporter_pb2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4be38fd7f077bbfa418d4ea3a9eeda55add00592 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/reporter_pb2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/runtime_env_agent_pb2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/runtime_env_agent_pb2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d0d9f3b6c682fe3e3cba1feed8b6f30a851672d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/runtime_env_agent_pb2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/runtime_env_agent_pb2_grpc.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/runtime_env_agent_pb2_grpc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd687a3ca8f5fd068823eaa97a25fc9ae3328dad Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/runtime_env_agent_pb2_grpc.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/runtime_env_common_pb2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/runtime_env_common_pb2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e10924e5548296156504f9182f317b47189d876 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/runtime_env_common_pb2.cpython-311.pyc differ diff 
--git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/runtime_env_common_pb2_grpc.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/runtime_env_common_pb2_grpc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..498d3df71db70921805c157a9657387711e6843a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/runtime_env_common_pb2_grpc.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/usage_pb2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/usage_pb2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbaa9b5d4ecdcccf0c425d4943d9c0f51dc03ac8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/generated/__pycache__/usage_pb2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/array/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/array/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d2eb16108a9e22da9ff8c3ef1dfd6713a2851b8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/array/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/array/distributed/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/array/distributed/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a28acfb533c6ac5b7b73f0a8d8fd1b946f24bab1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/array/distributed/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/array/distributed/__pycache__/linalg.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/ray/experimental/array/distributed/__pycache__/linalg.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..acce6f532fe0007aeafa7fd27fa7b59963843c86 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/array/distributed/__pycache__/linalg.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/array/distributed/__pycache__/random.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/array/distributed/__pycache__/random.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a960b84d886757fdf88a6ac9626b07a0f9ec292a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/array/distributed/__pycache__/random.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/array/remote/__init__.py b/.venv/lib/python3.11/site-packages/ray/experimental/array/remote/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aa6f731798ec76a0d4557ac4e8f725e9990213b4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/experimental/array/remote/__init__.py @@ -0,0 +1,44 @@ +from . 
import linalg, random +from .core import ( + add, + copy, + diag, + dot, + eye, + hstack, + ones, + shape, + subarray, + subtract, + sum, + sum_list, + transpose, + tril, + triu, + vstack, + zeros, + zeros_like, +) + +__all__ = [ + "random", + "linalg", + "zeros", + "zeros_like", + "ones", + "eye", + "dot", + "vstack", + "hstack", + "subarray", + "copy", + "tril", + "triu", + "diag", + "transpose", + "add", + "subtract", + "sum", + "shape", + "sum_list", +] diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/array/remote/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/array/remote/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfed240bee27c8f16f49007229d7a78145d64bbc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/array/remote/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/array/remote/__pycache__/core.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/array/remote/__pycache__/core.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ffc59feb64ccc540af776f96c8a21e8d2f7c846 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/array/remote/__pycache__/core.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/array/remote/__pycache__/linalg.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/array/remote/__pycache__/linalg.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..840c4a19ec33001e0bcdbcb4068ef74c52f57888 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/array/remote/__pycache__/linalg.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/array/remote/__pycache__/random.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/ray/experimental/array/remote/__pycache__/random.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d46fb658cdd1417786c538a396fbe9fa75e5d31b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/array/remote/__pycache__/random.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/array/remote/core.py b/.venv/lib/python3.11/site-packages/ray/experimental/array/remote/core.py new file mode 100644 index 0000000000000000000000000000000000000000..bf22131ed067cb6cb96bb760a2762bea07b546d4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/experimental/array/remote/core.py @@ -0,0 +1,99 @@ +import numpy as np + +import ray + + +@ray.remote +def zeros(shape, dtype_name="float", order="C"): + return np.zeros(shape, dtype=np.dtype(dtype_name), order=order) + + +@ray.remote +def zeros_like(a, dtype_name="None", order="K", subok=True): + dtype_val = None if dtype_name == "None" else np.dtype(dtype_name) + return np.zeros_like(a, dtype=dtype_val, order=order, subok=subok) + + +@ray.remote +def ones(shape, dtype_name="float", order="C"): + return np.ones(shape, dtype=np.dtype(dtype_name), order=order) + + +@ray.remote +def eye(N, M=-1, k=0, dtype_name="float"): + M = N if M == -1 else M + return np.eye(N, M=M, k=k, dtype=np.dtype(dtype_name)) + + +@ray.remote +def dot(a, b): + return np.dot(a, b) + + +@ray.remote +def vstack(*xs): + return np.vstack(xs) + + +@ray.remote +def hstack(*xs): + return np.hstack(xs) + + +# TODO(rkn): Instead of this, consider implementing slicing. +# TODO(rkn): Be consistent about using "index" versus "indices". 
+@ray.remote +def subarray(a, lower_indices, upper_indices): + idx = tuple(slice(l, u) for (l, u) in zip(lower_indices, upper_indices)) + return a[idx] + + +@ray.remote +def copy(a, order="K"): + return np.copy(a, order=order) + + +@ray.remote +def tril(m, k=0): + return np.tril(m, k=k) + + +@ray.remote +def triu(m, k=0): + return np.triu(m, k=k) + + +@ray.remote +def diag(v, k=0): + return np.diag(v, k=k) + + +@ray.remote +def transpose(a, axes=None): + axes = None if (axes == [] or axes is None) else axes + return np.transpose(a, axes=axes) + + +@ray.remote +def add(x1, x2): + return np.add(x1, x2) + + +@ray.remote +def subtract(x1, x2): + return np.subtract(x1, x2) + + +@ray.remote +def sum(x, axis=-1): + return np.sum(x, axis=axis if axis != -1 else None) + + +@ray.remote +def shape(a): + return np.shape(a) + + +@ray.remote +def sum_list(*xs): + return np.sum(xs, axis=0) diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/array/remote/linalg.py b/.venv/lib/python3.11/site-packages/ray/experimental/array/remote/linalg.py new file mode 100644 index 0000000000000000000000000000000000000000..d3104e92f70d80afd44e8edc413b01e50e927d24 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/experimental/array/remote/linalg.py @@ -0,0 +1,126 @@ +import numpy as np + +import ray + +__all__ = [ + "matrix_power", + "solve", + "tensorsolve", + "tensorinv", + "inv", + "cholesky", + "eigvals", + "eigvalsh", + "pinv", + "slogdet", + "det", + "svd", + "eig", + "eigh", + "lstsq", + "norm", + "qr", + "cond", + "matrix_rank", + "multi_dot", +] + + +@ray.remote +def matrix_power(M, n): + return np.linalg.matrix_power(M, n) + + +@ray.remote +def solve(a, b): + return np.linalg.solve(a, b) + + +@ray.remote(num_returns=2) +def tensorsolve(a): + raise NotImplementedError + + +@ray.remote(num_returns=2) +def tensorinv(a): + raise NotImplementedError + + +@ray.remote +def inv(a): + return np.linalg.inv(a) + + +@ray.remote +def cholesky(a): + return np.linalg.cholesky(a) + 
+ +@ray.remote +def eigvals(a): + return np.linalg.eigvals(a) + + +@ray.remote +def eigvalsh(a): + raise NotImplementedError + + +@ray.remote +def pinv(a): + return np.linalg.pinv(a) + + +@ray.remote +def slogdet(a): + raise NotImplementedError + + +@ray.remote +def det(a): + return np.linalg.det(a) + + +@ray.remote(num_returns=3) +def svd(a): + return np.linalg.svd(a) + + +@ray.remote(num_returns=2) +def eig(a): + return np.linalg.eig(a) + + +@ray.remote(num_returns=2) +def eigh(a): + return np.linalg.eigh(a) + + +@ray.remote(num_returns=4) +def lstsq(a, b): + return np.linalg.lstsq(a) + + +@ray.remote +def norm(x): + return np.linalg.norm(x) + + +@ray.remote(num_returns=2) +def qr(a): + return np.linalg.qr(a) + + +@ray.remote +def cond(x): + return np.linalg.cond(x) + + +@ray.remote +def matrix_rank(M): + return np.linalg.matrix_rank(M) + + +@ray.remote +def multi_dot(*a): + raise NotImplementedError diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/array/remote/random.py b/.venv/lib/python3.11/site-packages/ray/experimental/array/remote/random.py new file mode 100644 index 0000000000000000000000000000000000000000..892b4cc4a87819555086fdaee91a63c5a2c2cf43 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/experimental/array/remote/random.py @@ -0,0 +1,8 @@ +import numpy as np + +import ray + + +@ray.remote +def normal(shape): + return np.random.normal(size=shape) diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/dynamic_resources.py b/.venv/lib/python3.11/site-packages/ray/experimental/dynamic_resources.py new file mode 100644 index 0000000000000000000000000000000000000000..bb7bd3948095e911d4a8d1bbb82e13db305e0c4a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/experimental/dynamic_resources.py @@ -0,0 +1,7 @@ +def set_resource(resource_name, capacity, node_id=None): + raise DeprecationWarning( + "Dynamic custom resources are deprecated. 
Consider using placement " + "groups instead (docs.ray.io/en/master/placement-group.html). You " + "can also specify resources at Ray start time with the 'resources' " + "field in the cluster autoscaler." + ) diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/job/__init__.py b/.venv/lib/python3.11/site-packages/ray/experimental/job/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/job/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/job/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f20a8b827ecfbdd02babc95e3b65b374be1390ca Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/job/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/job/example_job/demo_script.py b/.venv/lib/python3.11/site-packages/ray/experimental/job/example_job/demo_script.py new file mode 100644 index 0000000000000000000000000000000000000000..8bcc79611e26d404a7f7d0dbd86b22f1424420c9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/experimental/job/example_job/demo_script.py @@ -0,0 +1,77 @@ +# Regular ray application that user wrote and runs on local cluster. 
+# intermediate status are dumped to GCS +import argparse +import time + +import ray +import ray.experimental.internal_kv as ray_kv + + +@ray.remote +class StepActor: + def __init__(self, interval_s=1, total_steps=3): + self.interval_s = interval_s + self.stopped = False + self.current_step = 1 + self.total_steps = total_steps + + worker = ray._private.worker.global_worker + worker_id = worker.core_worker.get_actor_id() + ray_kv._internal_kv_put(f"JOB:{worker_id}", self.current_step, overwrite=True) + + def run(self): + worker = ray._private.worker.global_worker + worker_id = worker.core_worker.get_actor_id() + + while self.current_step <= self.total_steps: + if not self.stopped: + print( + f"Sleeping {self.interval_s} secs to executing " + f"step {self.current_step}" + ) + time.sleep(self.interval_s) + self.current_step += 1 + ray_kv._internal_kv_put( + f"JOB:{worker_id}", self.current_step, overwrite=True + ) + else: + print("Stop called or reached final step.") + break + + self.stopped = True + ray_kv._internal_kv_put(f"JOB:{worker_id}", "DONE", overwrite=True) + return "DONE" + + def get_step(self): + return self.current_step + + def stop(self): + self.stopped = True + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--interval-s", + required=False, + type=int, + default=1, + help="Address to use to connect to Ray", + ) + parser.add_argument( + "--total-steps", + required=False, + type=int, + default=3, + help="Password for connecting to Redis", + ) + args, _ = parser.parse_known_args() + + ray.init() + step_actor = StepActor.remote( + interval_s=args.interval_s, total_steps=args.total_steps + ) + ref = step_actor.run.remote() + print(ray.get([ref])) + job_key = ray_kv._internal_kv_list("JOB:")[0] + print(f"{job_key}, {ray_kv._internal_kv_get(job_key)}") diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/multiprocessing/__init__.py 
b/.venv/lib/python3.11/site-packages/ray/experimental/multiprocessing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f9b741f1ebfc72812a214fe4c52b2894bd965602 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/experimental/multiprocessing/__init__.py @@ -0,0 +1,5 @@ +from multiprocessing import TimeoutError + +from .pool import Pool + +__all__ = ["Pool", "TimeoutError"] diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/multiprocessing/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/multiprocessing/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bca209a17dc05f552c44e4de170f2b8db030879 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/multiprocessing/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/multiprocessing/__pycache__/pool.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/multiprocessing/__pycache__/pool.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44aad3021a48a30af1d99417dbe94f51c33dc534 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/multiprocessing/__pycache__/pool.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/multiprocessing/pool.py b/.venv/lib/python3.11/site-packages/ray/experimental/multiprocessing/pool.py new file mode 100644 index 0000000000000000000000000000000000000000..bf0e442ff04b11a66baa1ac496d1acf4f6d55299 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/experimental/multiprocessing/pool.py @@ -0,0 +1,5 @@ +from ray.util import multiprocessing + + +class Pool(multiprocessing.Pool): + pass # moved to util package diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/packaging/__init__.py 
b/.venv/lib/python3.11/site-packages/ray/experimental/packaging/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/packaging/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/packaging/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eca26e060e21043bf0d3ce53ea87ffc887085adc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/packaging/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/packaging/__pycache__/load_package.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/experimental/packaging/__pycache__/load_package.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b1a37f80321d726854b844904cc2a771db102ec Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/experimental/packaging/__pycache__/load_package.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/experimental/packaging/load_package.py b/.venv/lib/python3.11/site-packages/ray/experimental/packaging/load_package.py new file mode 100644 index 0000000000000000000000000000000000000000..f5d883043b3398b3fb9c24a98a994ddf0b7fe994 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/experimental/packaging/load_package.py @@ -0,0 +1,235 @@ +"""Support for loading code packages into Ray at runtime. + +Ray packages allow developers to define self-contained code modules that can +be imported reproducibly into any Ray cluster. Each package can define its own +runtime environment, which can include: + - Different versions of code (e.g., from different git commits). + - Different Python libraries (e.g., conda environments, pip dependencies). + - Different Docker container images. 
+ +You can run this file for an example of loading a "hello world" package. +""" + +import hashlib +import importlib.util +import os +import re +import subprocess +import tempfile + +import yaml + +import ray + + +def load_package(config_path: str) -> "_RuntimePackage": + """Load the code package given its config path. + + Args: + config_path: The path to the configuration YAML that defines + the package. For documentation on the packaging format, see the + example YAML in ``example_pkg/ray_pkg.yaml``. + + Examples: + + .. code-block :: python + + # Load from local. + my_pkg = load_package("~/path/to/my_pkg.yaml") + + # Load from GitHub. + my_pkg = ray.util.load_package( + "https://raw.githubusercontent.com/user/repo/refspec" + "/path/to/package/my_pkg.yaml") + + # Inspect the package runtime env. + print(my_pkg._runtime_env) + {"conda": {...}, + "docker": "anyscale-ml/ray-ml:nightly-py38-cpu", + "working_dir": "https://github.com/demo/foo/blob/v3.0/project/"} + + # Run remote functions from the package. + my_pkg.my_func.remote(1, 2) + + # Create actors from the package. + actor = my_pkg.MyActor.remote(3, 4) + + # Create new remote funcs in the same env as a package. + @ray.remote(runtime_env=my_pkg._runtime_env) + def f(): ... + """ + + from ray._private.runtime_env.packaging import ( + get_uri_for_directory, + upload_package_if_needed, + ) + + config_path = _download_from_github_if_needed(config_path) + + if not os.path.exists(config_path): + raise ValueError("Config file does not exist: {}".format(config_path)) + + # TODO(ekl) validate schema? + config = yaml.safe_load(open(config_path).read()) + base_dir = os.path.abspath(os.path.dirname(config_path)) + runtime_env = config["runtime_env"] + + # Autofill working directory by uploading to GCS storage. + if "working_dir" not in runtime_env: + pkg_uri = get_uri_for_directory(base_dir, excludes=[]) + + def do_register_package(): + # TODO(ekl) does this get garbage collected correctly with the + # current job id? 
+ upload_package_if_needed(pkg_uri, _pkg_tmp(), base_dir) + + if ray.is_initialized(): + do_register_package() + else: + ray._private.worker._post_init_hooks.append(do_register_package) + runtime_env["working_dir"] = pkg_uri + + # Autofill conda config. + conda_yaml = os.path.join(base_dir, "conda.yaml") + if os.path.exists(conda_yaml): + if "conda" in runtime_env: + raise ValueError("Both conda.yaml and conda: section found in package") + runtime_env["conda"] = yaml.safe_load(open(conda_yaml).read()) + + pkg = _RuntimePackage( + name=config["name"], + desc=config["description"], + interface_file=os.path.join(base_dir, config["interface_file"]), + runtime_env=runtime_env, + ) + return pkg + + +def _download_from_github_if_needed(config_path: str) -> str: + """Resolve a GitHub raw link to the config file to a local path. + + If the user specifies a GitHub raw URL, download the repo specified at + that particular URL locally. This lets us treat YAMLs linked from GitHub + the same as local files. + """ + if config_path.startswith("http"): + if "github" not in config_path: + raise ValueError("Only GitHub URLs are supported by load_package().") + if "raw.githubusercontent.com" not in config_path: + raise ValueError("GitHub URL must start with raw.githubusercontent.com") + URL_FORMAT = ".*raw.githubusercontent.com/([^/]*)/([^/]*)/([^/]*)/(.*)" + match = re.match(URL_FORMAT, config_path) + if not match: + raise ValueError("GitHub URL must be of format {}".format(URL_FORMAT)) + gh_user = match.group(1) + gh_repo = match.group(2) + gh_branch = match.group(3) + gh_subdir = match.group(4) + + # Compute the cache key based on the URL. + hasher = hashlib.sha1() + hasher.update(config_path.encode("utf-8")) + config_key = hasher.hexdigest() + final_path = os.path.join(_pkg_tmp(), "github_snapshot_{}".format(config_key)) + + # Only download the repo if needed. 
+ if not os.path.exists(final_path): + tmp = tempfile.mkdtemp(prefix="github_{}".format(gh_repo), dir=_pkg_tmp()) + subprocess.check_call( + [ + "curl", + "--fail", + "-L", + "https://github.com/{}/{}/tarball/{}".format( + gh_user, gh_repo, gh_branch + ), + "--output", + tmp + ".tar.gz", + ] + ) + subprocess.check_call( + ["tar", "xzf", tmp + ".tar.gz", "-C", tmp, "--strip-components=1"] + ) + os.rename(tmp, final_path) + return os.path.join(final_path, gh_subdir) + + return config_path + + +class _RuntimePackage: + """Represents a Ray package loaded via ``load_package()``. + + This class provides access to the symbols defined by the interface file of + the package (e.g., remote functions and actor definitions). You can also + access the raw runtime env defined by the package via ``pkg._runtime_env``. + """ + + def __init__(self, name: str, desc: str, interface_file: str, runtime_env: dict): + self._name = name + self._description = desc + self._interface_file = interface_file + self._runtime_env = runtime_env + _validate_interface_file(self._interface_file) + + spec = importlib.util.spec_from_file_location(self._name, self._interface_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + self._module = module + + for symbol in dir(self._module): + if not symbol.startswith("_"): + value = getattr(self._module, symbol) + if isinstance(value, ray.remote_function.RemoteFunction) or isinstance( + value, ray.actor.ActorClass + ): + setattr(self, symbol, value.options(runtime_env=runtime_env)) + + def __repr__(self): + return "ray._RuntimePackage(module={}, runtime_env={})".format( + self._module, self._runtime_env + ) + + +def _validate_interface_file(interface_file: str): + if not os.path.exists(interface_file): + raise ValueError("Interface file does not exist: {}".format(interface_file)) + for line in open(interface_file): + line = line.replace("\n", "") + if line.startswith("import ") or line.startswith("from "): + if line != 
"import ray" and "noqa" not in line: + raise ValueError( + "Interface files are only allowed to import `ray` " + "at top-level, found `{}`. Please either remove or " + "change this into a lazy import. To unsafely allow " + "this import, add `# noqa` to the line " + "in question.".format(line) + ) + + +def _pkg_tmp(): + tmp = "/tmp/ray/packaging" + os.makedirs(tmp, exist_ok=True) + return tmp + + +if __name__ == "__main__": + ray.init() + + print("-> Testing load local") + pkg = load_package("./example_pkg/ray_pkg.yaml") + print("-> Loaded package", pkg) + print("-> Package symbols", [x for x in dir(pkg) if not x.startswith("_")]) + print("-> Testing actor call") + a = pkg.MyActor.remote() + print(ray.get(a.f.remote())) + print("-> Testing method call") + print(ray.get(pkg.my_func.remote())) + + print("-> Testing load from github") + pkg2 = load_package( + "http://raw.githubusercontent.com/ray-project/ray/master/" + "python/ray/experimental/packaging/example_pkg/ray_pkg.yaml" + ) + print("-> Loaded package", pkg2) + print("-> Testing method call") + print(ray.get(pkg2.my_func.remote())) diff --git a/.venv/lib/python3.11/site-packages/ray/serve/_private/api.py b/.venv/lib/python3.11/site-packages/ray/serve/_private/api.py new file mode 100644 index 0000000000000000000000000000000000000000..954d71cac10d49d4a33e1cfc8e26ed9bbb543cd7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/serve/_private/api.py @@ -0,0 +1,294 @@ +import inspect +import logging +from types import FunctionType +from typing import Any, Dict, Union + +import ray +from ray._private.pydantic_compat import is_subclass_of_base_model +from ray._private.resource_spec import HEAD_NODE_RESOURCE_NAME +from ray._private.usage import usage_lib +from ray.actor import ActorHandle +from ray.serve._private.client import ServeControllerClient +from ray.serve._private.constants import ( + CONTROLLER_MAX_CONCURRENCY, + HTTP_PROXY_TIMEOUT, + RAY_SERVE_ENABLE_TASK_EVENTS, + SERVE_CONTROLLER_NAME, + 
SERVE_LOGGER_NAME, + SERVE_NAMESPACE, +) +from ray.serve._private.controller import ServeController +from ray.serve.config import HTTPOptions, gRPCOptions +from ray.serve.context import _get_global_client, _set_global_client +from ray.serve.deployment import Application +from ray.serve.exceptions import RayServeException +from ray.serve.schema import LoggingConfig + +logger = logging.getLogger(SERVE_LOGGER_NAME) + + +def _check_http_options( + client: ServeControllerClient, http_options: Union[dict, HTTPOptions] +) -> None: + if http_options: + client_http_options = client.http_config + new_http_options = ( + http_options + if isinstance(http_options, HTTPOptions) + else HTTPOptions.parse_obj(http_options) + ) + different_fields = [] + all_http_option_fields = new_http_options.__dict__ + for field in all_http_option_fields: + if getattr(new_http_options, field) != getattr(client_http_options, field): + different_fields.append(field) + + if len(different_fields): + logger.warning( + "The new client HTTP config differs from the existing one " + f"in the following fields: {different_fields}. " + "The new HTTP config is ignored." + ) + + +def _start_controller( + http_options: Union[None, dict, HTTPOptions] = None, + grpc_options: Union[None, dict, gRPCOptions] = None, + global_logging_config: Union[None, dict, LoggingConfig] = None, + **kwargs, +) -> ActorHandle: + """Start Ray Serve controller. + + The function makes sure controller is ready to start deploying apps + after it returns. + + Parameters are same as ray.serve._private.api.serve_start(). + + Returns: controller actor handle. + """ + + # Initialize ray if needed. 
+ ray._private.worker.global_worker._filter_logs_by_job = False + if not ray.is_initialized(): + ray.init(namespace=SERVE_NAMESPACE) + + # Legacy http proxy actor check + http_deprecated_args = ["http_host", "http_port", "http_middlewares"] + for key in http_deprecated_args: + if key in kwargs: + raise ValueError( + f"{key} is deprecated, please use serve.start(http_options=" + f'{{"{key}": {kwargs[key]}}}) instead.' + ) + + if isinstance(http_options, dict): + http_options = HTTPOptions.parse_obj(http_options) + if http_options is None: + http_options = HTTPOptions() + + if isinstance(grpc_options, dict): + grpc_options = gRPCOptions(**grpc_options) + + if global_logging_config is None: + global_logging_config = LoggingConfig() + elif isinstance(global_logging_config, dict): + global_logging_config = LoggingConfig(**global_logging_config) + + controller = ServeController.options( + num_cpus=0, + name=SERVE_CONTROLLER_NAME, + lifetime="detached", + max_restarts=-1, + max_task_retries=-1, + resources={HEAD_NODE_RESOURCE_NAME: 0.001}, + namespace=SERVE_NAMESPACE, + max_concurrency=CONTROLLER_MAX_CONCURRENCY, + enable_task_events=RAY_SERVE_ENABLE_TASK_EVENTS, + ).remote( + http_config=http_options, + grpc_options=grpc_options, + global_logging_config=global_logging_config, + ) + + proxy_handles = ray.get(controller.get_proxies.remote()) + if len(proxy_handles) > 0: + try: + ray.get( + [handle.ready.remote() for handle in proxy_handles.values()], + timeout=HTTP_PROXY_TIMEOUT, + ) + except ray.exceptions.GetTimeoutError: + raise TimeoutError( + f"HTTP proxies not available after {HTTP_PROXY_TIMEOUT}s." + ) + return controller + + +async def serve_start_async( + http_options: Union[None, dict, HTTPOptions] = None, + grpc_options: Union[None, dict, gRPCOptions] = None, + global_logging_config: Union[None, dict, LoggingConfig] = None, + **kwargs, +) -> ServeControllerClient: + """Initialize a serve instance asynchronously. + + This function is not thread-safe. 
The caller should maintain the async lock in order + to start the serve instance asynchronously. + + This function has the same functionality as ray.serve._private.api.serve_start(). + + Parameters & Returns are same as ray.serve._private.api.serve_start(). + """ + + usage_lib.record_library_usage("serve") + + try: + client = _get_global_client(_health_check_controller=True) + logger.info( + f'Connecting to existing Serve app in namespace "{SERVE_NAMESPACE}".' + " New http options will not be applied." + ) + if http_options: + _check_http_options(client, http_options) + return client + except RayServeException: + pass + + controller = ( + await ray.remote(_start_controller) + .options(num_cpus=0) + .remote(http_options, grpc_options, global_logging_config, **kwargs) + ) + + client = ServeControllerClient( + controller, + ) + _set_global_client(client) + logger.info(f'Started Serve in namespace "{SERVE_NAMESPACE}".') + return client + + +def serve_start( + http_options: Union[None, dict, HTTPOptions] = None, + grpc_options: Union[None, dict, gRPCOptions] = None, + global_logging_config: Union[None, dict, LoggingConfig] = None, + **kwargs, +) -> ServeControllerClient: + """Initialize a serve instance. + + By default, the instance will be scoped to the lifetime of the returned + Client object (or when the script exits). This is + only relevant if connecting to a long-running Ray cluster (e.g., with + ray.init(address="auto") or ray.init("ray://")). + + Args: + http_options (Optional[Dict, serve.HTTPOptions]): Configuration options + for HTTP proxy. You can pass in a dictionary or HTTPOptions object + with fields: + + - host(str, None): Host for HTTP servers to listen on. Defaults to + "127.0.0.1". To expose Serve publicly, you probably want to set + this to "0.0.0.0". + - port(int): Port for HTTP server. Defaults to 8000. + - root_path(str): Root path to mount the serve application + (for example, "/serve"). All deployment routes will be prefixed + with this path. 
Defaults to "". + - middlewares(list): A list of Starlette middlewares that will be + applied to the HTTP servers in the cluster. Defaults to []. + - location(str, serve.config.DeploymentMode): The deployment + location of HTTP servers: + + - "HeadOnly": start one HTTP server on the head node. Serve + assumes the head node is the node you executed serve.start + on. This is the default. + - "EveryNode": start one HTTP server per node. + - "NoServer" or None: disable HTTP server. + - num_cpus (int): The number of CPU cores to reserve for each + internal Serve HTTP proxy actor. Defaults to 0. + grpc_options: [Experimental] Configuration options for gRPC proxy. + You can pass in a gRPCOptions object with fields: + + - port(int): Port for gRPC server. Defaults to 9000. + - grpc_servicer_functions(list): List of import paths for gRPC + `add_servicer_to_server` functions to add to Serve's gRPC proxy. + Default empty list, meaning not to start the gRPC server. + """ + + usage_lib.record_library_usage("serve") + + try: + client = _get_global_client(_health_check_controller=True) + logger.info( + f'Connecting to existing Serve app in namespace "{SERVE_NAMESPACE}".' + " New http options will not be applied." + ) + if http_options: + _check_http_options(client, http_options) + return client + except RayServeException: + pass + + controller = _start_controller( + http_options, grpc_options, global_logging_config, **kwargs + ) + + client = ServeControllerClient( + controller, + ) + _set_global_client(client) + logger.info(f'Started Serve in namespace "{SERVE_NAMESPACE}".') + return client + + +def call_user_app_builder_with_args_if_necessary( + builder: Union[Application, FunctionType], + args: Dict[str, Any], +) -> Application: + """Calls a user-provided function that returns Serve application. + + If an Application object is passed, this is a no-op. 
+ + Else, we validate the signature of the function, convert the args dictionary to + the user-annotated Pydantic model if provided, and call the function. + + The output of the function is returned (must be an Application). + """ + if isinstance(builder, Application): + if len(args) > 0: + raise ValueError( + "Arguments can only be passed to an application builder function, " + "not an already built application." + ) + return builder + elif not isinstance(builder, FunctionType): + raise TypeError( + "Expected a built Serve application or an application builder function " + f"but got: {type(builder)}." + ) + + # Check that the builder only takes a single argument. + # TODO(edoakes): we may want to loosen this to allow optional kwargs in the future. + signature = inspect.signature(builder) + if len(signature.parameters) != 1: + raise TypeError( + "Application builder functions should take exactly one parameter, " + "a dictionary containing the passed arguments." + ) + + # If the sole argument to the builder is a pydantic model, convert the args dict to + # that model. This will perform standard pydantic validation (e.g., raise an + # exception if required fields are missing). + param = signature.parameters[list(signature.parameters.keys())[0]] + if inspect.isclass(param.annotation) and is_subclass_of_base_model( + param.annotation + ): + args = param.annotation.parse_obj(args) + + app = builder(args) + if not isinstance(app, Application): + raise TypeError( + "Application builder functions must return an `Application` returned " + f"`from `Deployment.bind()`, but got: {type(app)}." 
+ ) + + return app diff --git a/.venv/lib/python3.11/site-packages/ray/serve/_private/application_state.py b/.venv/lib/python3.11/site-packages/ray/serve/_private/application_state.py new file mode 100644 index 0000000000000000000000000000000000000000..490ad98a987fa5a454ea2ad54a3c78e5e33998e9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/serve/_private/application_state.py @@ -0,0 +1,1329 @@ +import json +import logging +import os +import time +import traceback +from copy import deepcopy +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Callable, Dict, List, Optional, Tuple + +import ray +from ray import cloudpickle +from ray._private.utils import import_attr +from ray.exceptions import RuntimeEnvSetupError +from ray.serve._private.build_app import BuiltApplication, build_app +from ray.serve._private.common import ( + DeploymentID, + DeploymentStatus, + DeploymentStatusInfo, + DeploymentStatusTrigger, + EndpointInfo, + TargetCapacityDirection, +) +from ray.serve._private.config import DeploymentConfig +from ray.serve._private.constants import RAY_SERVE_ENABLE_TASK_EVENTS, SERVE_LOGGER_NAME +from ray.serve._private.deploy_utils import ( + deploy_args_to_deployment_info, + get_app_code_version, + get_deploy_args, +) +from ray.serve._private.deployment_info import DeploymentInfo +from ray.serve._private.deployment_state import DeploymentStateManager +from ray.serve._private.endpoint_state import EndpointState +from ray.serve._private.logging_utils import configure_component_logger +from ray.serve._private.storage.kv_store import KVStoreBase +from ray.serve._private.usage import ServeUsageTag +from ray.serve._private.utils import ( + DEFAULT, + check_obj_ref_ready_nowait, + override_runtime_envs_except_env_vars, + validate_route_prefix, +) +from ray.serve.config import AutoscalingConfig +from ray.serve.exceptions import RayServeException +from ray.serve.generated.serve_pb2 import ApplicationStatus as 
ApplicationStatusProto +from ray.serve.generated.serve_pb2 import ( + ApplicationStatusInfo as ApplicationStatusInfoProto, +) +from ray.serve.generated.serve_pb2 import DeploymentLanguage +from ray.serve.generated.serve_pb2 import ( + DeploymentStatusInfoList as DeploymentStatusInfoListProto, +) +from ray.serve.generated.serve_pb2 import StatusOverview as StatusOverviewProto +from ray.serve.schema import ( + APIType, + ApplicationStatus, + DeploymentDetails, + LoggingConfig, + ServeApplicationSchema, +) +from ray.types import ObjectRef + +logger = logging.getLogger(SERVE_LOGGER_NAME) + +CHECKPOINT_KEY = "serve-application-state-checkpoint" + + +class BuildAppStatus(Enum): + """Status of the build application task.""" + + NO_TASK_IN_PROGRESS = 1 + IN_PROGRESS = 2 + SUCCEEDED = 3 + FAILED = 4 + + +@dataclass +class BuildAppTaskInfo: + """Stores info on the current in-progress build app task. + + We use a class instead of only storing the task object ref because + when a new config is deployed, there can be an outdated in-progress + build app task. We attach the code version to the task info to + distinguish outdated build app tasks. 
+ """ + + obj_ref: ObjectRef + code_version: str + config: ServeApplicationSchema + target_capacity: Optional[float] + target_capacity_direction: Optional[TargetCapacityDirection] + finished: bool + + +@dataclass(eq=True) +class ApplicationStatusInfo: + status: ApplicationStatus + message: str = "" + deployment_timestamp: float = 0 + + def debug_string(self): + return json.dumps(asdict(self), indent=4) + + def to_proto(self): + return ApplicationStatusInfoProto( + status=f"APPLICATION_STATUS_{self.status.name}", + message=self.message, + deployment_timestamp=self.deployment_timestamp, + ) + + @classmethod + def from_proto(cls, proto: ApplicationStatusInfoProto): + status = ApplicationStatusProto.Name(proto.status)[len("APPLICATION_STATUS_") :] + return cls( + status=ApplicationStatus(status), + message=proto.message, + deployment_timestamp=proto.deployment_timestamp, + ) + + +@dataclass(eq=True) +class StatusOverview: + app_status: ApplicationStatusInfo + name: str = "" + deployment_statuses: List[DeploymentStatusInfo] = field(default_factory=list) + + def debug_string(self): + return json.dumps(asdict(self), indent=4) + + def get_deployment_status(self, name: str) -> Optional[DeploymentStatusInfo]: + """Get a deployment's status by name. + + Args: + name: Deployment's name. + + Return (Optional[DeploymentStatusInfo]): Status with a name matching + the argument, if one exists. Otherwise, returns None. 
+ """ + + for deployment_status in self.deployment_statuses: + if name == deployment_status.name: + return deployment_status + + return None + + def to_proto(self): + # Create a protobuf for the Serve Application info + app_status_proto = self.app_status.to_proto() + + # Create protobufs for all individual deployment statuses + deployment_status_protos = map( + lambda status: status.to_proto(), self.deployment_statuses + ) + + # Create a protobuf list containing all the deployment status protobufs + deployment_status_proto_list = DeploymentStatusInfoListProto() + deployment_status_proto_list.deployment_status_infos.extend( + deployment_status_protos + ) + + # Return protobuf encapsulating application and deployment protos + return StatusOverviewProto( + name=self.name, + app_status=app_status_proto, + deployment_statuses=deployment_status_proto_list, + ) + + @classmethod + def from_proto(cls, proto: StatusOverviewProto) -> "StatusOverview": + # Recreate Serve Application info + app_status = ApplicationStatusInfo.from_proto(proto.app_status) + + # Recreate deployment statuses + deployment_statuses = [] + for info_proto in proto.deployment_statuses.deployment_status_infos: + deployment_statuses.append(DeploymentStatusInfo.from_proto(info_proto)) + + # Recreate StatusInfo + return cls( + app_status=app_status, + deployment_statuses=deployment_statuses, + name=proto.name, + ) + + +@dataclass +class ApplicationTargetState: + """Defines target state of application. + + Target state can become inconsistent if the code version doesn't + match that of the config. When that happens, a new build app task + should be kicked off to reconcile the inconsistency. + + deployment_infos: map of deployment name to deployment info. This is + - None if a config was deployed but the app hasn't finished + building yet, + - An empty dict if the app is deleting. + code_version: Code version of all deployments in target state. None + if application was deployed through serve.run. 
+ config: application config deployed by user. None if application was + deployed through serve.run. + target_capacity: the target_capacity to use when adjusting num_replicas. + target_capacity_direction: the scale direction to use when + running the Serve autoscaler. + deleting: whether the application is being deleted. + """ + + deployment_infos: Optional[Dict[str, DeploymentInfo]] + code_version: Optional[str] + config: Optional[ServeApplicationSchema] + target_capacity: Optional[float] + target_capacity_direction: Optional[TargetCapacityDirection] + deleting: bool + api_type: APIType + + +class ApplicationState: + """Manage single application states with all operations""" + + def __init__( + self, + name: str, + deployment_state_manager: DeploymentStateManager, + endpoint_state: EndpointState, + save_checkpoint_func: Callable, + logging_config: LoggingConfig, + ): + """ + Args: + name: Application name. + deployment_state_manager: State manager for all deployments + in the cluster. + endpoint_state: State manager for endpoints in the system. + save_checkpoint_func: Function that can be called to write + a checkpoint of the application state. This should be + called in self._set_target_state() before actually + setting the target state so that the controller can + properly recover application states if it crashes. 
+ """ + + self._name = name + self._status_msg = "" + self._deployment_state_manager = deployment_state_manager + self._endpoint_state = endpoint_state + self._route_prefix: Optional[str] = None + self._docs_path: Optional[str] = None + self._ingress_deployment_name: Optional[str] = None + + self._status: ApplicationStatus = ApplicationStatus.DEPLOYING + self._deployment_timestamp = time.time() + + self._build_app_task_info: Optional[BuildAppTaskInfo] = None + # Before a deploy app task finishes, we don't know what the + # target deployments are, so set deployment_infos=None + self._target_state: ApplicationTargetState = ApplicationTargetState( + deployment_infos=None, + code_version=None, + config=None, + target_capacity=None, + target_capacity_direction=None, + deleting=False, + api_type=APIType.UNKNOWN, + ) + self._save_checkpoint_func = save_checkpoint_func + self._logging_config = logging_config + + @property + def route_prefix(self) -> Optional[str]: + return self._route_prefix + + @property + def docs_path(self) -> Optional[str]: + return self._docs_path + + @property + def status(self) -> ApplicationStatus: + """Status of the application. + + DEPLOYING: The build task is still running, or the deployments + have started deploying but aren't healthy yet. + RUNNING: All deployments are healthy. + DEPLOY_FAILED: The build task failed or one or more deployments + became unhealthy in the process of deploying + UNHEALTHY: While the application was running, one or more + deployments transition from healthy to unhealthy. + DELETING: Application and its deployments are being deleted. 
+ """ + return self._status + + @property + def deployment_timestamp(self) -> float: + return self._deployment_timestamp + + @property + def target_deployments(self) -> List[str]: + """List of target deployment names in application.""" + if self._target_state.deployment_infos is None: + return [] + return list(self._target_state.deployment_infos.keys()) + + @property + def ingress_deployment(self) -> Optional[str]: + return self._ingress_deployment_name + + @property + def api_type(self) -> APIType: + return self._target_state.api_type + + def recover_target_state_from_checkpoint( + self, checkpoint_data: ApplicationTargetState + ): + logger.info( + f"Recovering target state for application '{self._name}' from checkpoint." + ) + self._set_target_state( + checkpoint_data.deployment_infos, + api_type=checkpoint_data.api_type, + code_version=checkpoint_data.code_version, + target_config=checkpoint_data.config, + target_capacity=checkpoint_data.target_capacity, + target_capacity_direction=checkpoint_data.target_capacity_direction, + deleting=checkpoint_data.deleting, + ) + + def _set_target_state( + self, + deployment_infos: Optional[Dict[str, DeploymentInfo]], + *, + api_type: APIType, + code_version: Optional[str], + target_config: Optional[ServeApplicationSchema], + target_capacity: Optional[float] = None, + target_capacity_direction: Optional[TargetCapacityDirection] = None, + deleting: bool = False, + ): + """Set application target state. 
+ + While waiting for build task to finish, this should be + (None, False) + When build task has finished and during normal operation, this should be + (target_deployments, False) + When a request to delete the application has been received, this should be + ({}, True) + """ + if deleting: + self._update_status(ApplicationStatus.DELETING) + else: + self._update_status(ApplicationStatus.DEPLOYING) + + if deployment_infos is None: + self._ingress_deployment_name = None + else: + for name, info in deployment_infos.items(): + if info.ingress: + self._ingress_deployment_name = name + + target_state = ApplicationTargetState( + deployment_infos, + code_version, + target_config, + target_capacity, + target_capacity_direction, + deleting, + api_type=api_type, + ) + + # Checkpoint ahead, so that if the controller crashes before we + # write to the target state, the target state will be recovered + # after the controller recovers + self._save_checkpoint_func(writeahead_checkpoints={self._name: target_state}) + # Set target state + self._target_state = target_state + + def _set_target_state_deleting(self): + """Set target state to deleting. + + Wipes the target deployment infos, code version, and config. + """ + self._set_target_state( + deployment_infos=dict(), + api_type=self._target_state.api_type, + code_version=None, + target_config=None, + deleting=True, + ) + + def _clear_target_state_and_store_config( + self, + target_config: Optional[ServeApplicationSchema], + ): + """Clears the target state and stores the config. + + NOTE: this currently assumes that this method is *only* called when managing + apps deployed with the declarative API. 
+ """ + self._set_target_state( + deployment_infos=None, + api_type=APIType.DECLARATIVE, + code_version=None, + target_config=target_config, + deleting=False, + ) + + def _delete_deployment(self, name): + id = DeploymentID(name=name, app_name=self._name) + self._endpoint_state.delete_endpoint(id) + self._deployment_state_manager.delete_deployment(id) + + def delete(self): + """Delete the application""" + if self._status != ApplicationStatus.DELETING: + logger.info( + f"Deleting app '{self._name}'.", + extra={"log_to_stderr": False}, + ) + self._set_target_state_deleting() + + def is_deleted(self) -> bool: + """Check whether the application is already deleted. + + For an application to be considered deleted, the target state has to be set to + deleting and all deployments have to be deleted. + """ + return self._target_state.deleting and len(self._get_live_deployments()) == 0 + + def apply_deployment_info( + self, + deployment_name: str, + deployment_info: DeploymentInfo, + ) -> None: + """Deploys a deployment in the application.""" + route_prefix = deployment_info.route_prefix + if route_prefix is not None and not route_prefix.startswith("/"): + raise RayServeException( + f'Invalid route prefix "{route_prefix}", it must start with "/"' + ) + + deployment_id = DeploymentID(name=deployment_name, app_name=self._name) + + self._deployment_state_manager.deploy(deployment_id, deployment_info) + + if deployment_info.route_prefix is not None: + config = deployment_info.deployment_config + self._endpoint_state.update_endpoint( + deployment_id, + # The current meaning of the "is_cross_language" field is ambiguous. + # We will work on optimizing and removing this field in the future. + # Instead of using the "is_cross_language" field, we will directly + # compare if the replica is Python, which will assist the Python + # router in determining if the replica invocation is a cross-language + # operation. 
+ EndpointInfo( + route=deployment_info.route_prefix, + app_is_cross_language=config.deployment_language + != DeploymentLanguage.PYTHON, + ), + ) + else: + self._endpoint_state.delete_endpoint(deployment_id) + + def deploy_app(self, deployment_infos: Dict[str, DeploymentInfo]): + """(Re-)deploy the application from list of deployment infos. + + This function should only be called to deploy an app from an + imperative API (i.e., `serve.run` or Java API). + + Raises: RayServeException if there is more than one route prefix + or docs path. + """ + + # Check routes are unique in deployment infos + self._route_prefix, self._docs_path = self._check_routes(deployment_infos) + + self._set_target_state( + deployment_infos=deployment_infos, + api_type=APIType.IMPERATIVE, + code_version=None, + target_config=None, + target_capacity=None, + target_capacity_direction=None, + ) + + def apply_app_config( + self, + config: ServeApplicationSchema, + target_capacity: Optional[float], + target_capacity_direction: Optional[TargetCapacityDirection], + deployment_time: float, + ) -> None: + """Apply the config to the application. + + If the code version matches that of the current live deployments + then it only applies the updated config to the deployment state + manager. If the code version doesn't match, this will re-build + the application. + + This function should only be called to (re-)deploy an app from + the declarative API (i.e., through the REST API). + """ + + self._deployment_timestamp = deployment_time + + config_version = get_app_code_version(config) + if config_version == self._target_state.code_version: + try: + overrided_infos = override_deployment_info( + self._target_state.deployment_infos, + config, + ) + self._check_routes(overrided_infos) + self._set_target_state( + # Code version doesn't change. + code_version=self._target_state.code_version, + api_type=APIType.DECLARATIVE, + # Everything else must reflect the new config. 
+ deployment_infos=overrided_infos, + target_config=config, + target_capacity=target_capacity, + target_capacity_direction=target_capacity_direction, + ) + except (TypeError, ValueError, RayServeException): + self._clear_target_state_and_store_config(config) + self._update_status( + ApplicationStatus.DEPLOY_FAILED, traceback.format_exc() + ) + except Exception: + self._clear_target_state_and_store_config(config) + self._update_status( + ApplicationStatus.DEPLOY_FAILED, + ( + f"Unexpected error occured while applying config for " + f"application '{self._name}': \n{traceback.format_exc()}" + ), + ) + else: + # If there is an in progress build task, cancel it. + if self._build_app_task_info and not self._build_app_task_info.finished: + logger.info( + f"Received new config for application '{self._name}'. " + "Cancelling previous request." + ) + ray.cancel(self._build_app_task_info.obj_ref) + + # Halt reconciliation of target deployments. A new target state + # will be set once the new app has finished building. 
+ self._clear_target_state_and_store_config(config) + + # Record telemetry for container runtime env feature + if self._target_state.config.runtime_env.get( + "container" + ) or self._target_state.config.runtime_env.get("image_uri"): + ServeUsageTag.APP_CONTAINER_RUNTIME_ENV_USED.record("1") + + # Kick off new build app task + logger.info(f"Importing and building app '{self._name}'.") + build_app_obj_ref = build_serve_application.options( + runtime_env=config.runtime_env, + enable_task_events=RAY_SERVE_ENABLE_TASK_EVENTS, + ).remote( + config.import_path, + config_version, + config.name, + config.args, + self._logging_config, + ) + self._build_app_task_info = BuildAppTaskInfo( + obj_ref=build_app_obj_ref, + code_version=config_version, + config=config, + target_capacity=target_capacity, + target_capacity_direction=target_capacity_direction, + finished=False, + ) + + def _get_live_deployments(self) -> List[str]: + return self._deployment_state_manager.get_deployments_in_application(self._name) + + def _determine_app_status(self) -> Tuple[ApplicationStatus, str]: + """Check deployment statuses and target state, and determine the + corresponding application status. + + Returns: + Status (ApplicationStatus): + RUNNING: all deployments are healthy or autoscaling. + DEPLOYING: there is one or more updating deployments, + and there are no unhealthy deployments. + DEPLOY_FAILED: one or more deployments became unhealthy + while the application was deploying. + UNHEALTHY: one or more deployments became unhealthy + while the application was running. + DELETING: the application is being deleted. + Error message (str): + Non-empty string if status is DEPLOY_FAILED or UNHEALTHY + """ + + if self._target_state.deleting: + return ApplicationStatus.DELETING, "" + + # Get the lowest rank, i.e. highest priority, deployment status info object + # The deployment status info with highest priority determines the corresponding + # application status to set. 
+ deployment_statuses = self.get_deployments_statuses() + lowest_rank_status = min(deployment_statuses, key=lambda info: info.rank) + if lowest_rank_status.status == DeploymentStatus.DEPLOY_FAILED: + failed_deployments = [ + s.name + for s in deployment_statuses + if s.status == DeploymentStatus.DEPLOY_FAILED + ] + return ( + ApplicationStatus.DEPLOY_FAILED, + f"Failed to update the deployments {failed_deployments}.", + ) + elif lowest_rank_status.status == DeploymentStatus.UNHEALTHY: + unhealthy_deployment_names = [ + s.name + for s in deployment_statuses + if s.status == DeploymentStatus.UNHEALTHY + ] + return ( + ApplicationStatus.UNHEALTHY, + f"The deployments {unhealthy_deployment_names} are UNHEALTHY.", + ) + elif lowest_rank_status.status == DeploymentStatus.UPDATING: + return ApplicationStatus.DEPLOYING, "" + elif ( + lowest_rank_status.status + in [DeploymentStatus.UPSCALING, DeploymentStatus.DOWNSCALING] + and lowest_rank_status.status_trigger + == DeploymentStatusTrigger.CONFIG_UPDATE_STARTED + ): + return ApplicationStatus.DEPLOYING, "" + else: + return ApplicationStatus.RUNNING, "" + + def _reconcile_build_app_task(self) -> Tuple[Optional[Dict], BuildAppStatus, str]: + """If necessary, reconcile the in-progress build task. + + Returns: + Deploy arguments (Dict[str, DeploymentInfo]): + The deploy arguments returned from the build app task + and their code version. + Status (BuildAppStatus): + NO_TASK_IN_PROGRESS: There is no build task to reconcile. + SUCCEEDED: Task finished successfully. + FAILED: An error occurred during execution of build app task + IN_PROGRESS: Task hasn't finished yet. 
+ Error message (str): + Non-empty string if status is DEPLOY_FAILED or UNHEALTHY + """ + if self._build_app_task_info is None or self._build_app_task_info.finished: + return None, BuildAppStatus.NO_TASK_IN_PROGRESS, "" + + if not check_obj_ref_ready_nowait(self._build_app_task_info.obj_ref): + return None, BuildAppStatus.IN_PROGRESS, "" + + # Retrieve build app task result + self._build_app_task_info.finished = True + try: + args, err = ray.get(self._build_app_task_info.obj_ref) + if err is None: + logger.info(f"Imported and built app '{self._name}' successfully.") + else: + return ( + None, + BuildAppStatus.FAILED, + f"Deploying app '{self._name}' failed with exception:\n{err}", + ) + except RuntimeEnvSetupError: + error_msg = ( + f"Runtime env setup for app '{self._name}' failed:\n" + + traceback.format_exc() + ) + return None, BuildAppStatus.FAILED, error_msg + except Exception: + error_msg = ( + f"Unexpected error occured while deploying application " + f"'{self._name}': \n{traceback.format_exc()}" + ) + return None, BuildAppStatus.FAILED, error_msg + + # Convert serialized deployment args (returned by build app task) + # to deployment infos and apply option overrides from config + try: + deployment_infos = { + params["deployment_name"]: deploy_args_to_deployment_info( + **params, app_name=self._name + ) + for params in args + } + overrided_infos = override_deployment_info( + deployment_infos, self._build_app_task_info.config + ) + self._route_prefix, self._docs_path = self._check_routes(overrided_infos) + return overrided_infos, BuildAppStatus.SUCCEEDED, "" + except (TypeError, ValueError, RayServeException): + return None, BuildAppStatus.FAILED, traceback.format_exc() + except Exception: + error_msg = ( + f"Unexpected error occured while applying config for application " + f"'{self._name}': \n{traceback.format_exc()}" + ) + return None, BuildAppStatus.FAILED, error_msg + + def _check_routes( + self, deployment_infos: Dict[str, DeploymentInfo] + ) -> 
Tuple[str, str]: + """Check route prefixes and docs paths of deployments in app. + + There should only be one non-null route prefix. If there is one, + set it as the application route prefix. This function must be + run every control loop iteration because the target config could + be updated without kicking off a new task. + + Returns: tuple of route prefix, docs path. + Raises: RayServeException if more than one route prefix or docs + path is found among deployments. + """ + num_route_prefixes = 0 + num_docs_paths = 0 + route_prefix = None + docs_path = None + for info in deployment_infos.values(): + # Update route prefix of application, which may be updated + # through a redeployed config. + if info.route_prefix is not None: + route_prefix = info.route_prefix + num_route_prefixes += 1 + if info.docs_path is not None: + docs_path = info.docs_path + num_docs_paths += 1 + + if num_route_prefixes > 1: + raise RayServeException( + f'Found multiple route prefixes from application "{self._name}",' + " Please specify only one route prefix for the application " + "to avoid this issue." + ) + # NOTE(zcin) This will not catch multiple FastAPI deployments in the application + # if user sets the docs path to None in their FastAPI app. + if num_docs_paths > 1: + raise RayServeException( + f'Found multiple deployments in application "{self._name}" that have ' + "a docs path. This may be due to using multiple FastAPI deployments " + "in your application. Please only include one deployment with a docs " + "path in your application to avoid this issue." + ) + + return route_prefix, docs_path + + def _reconcile_target_deployments(self) -> None: + """Reconcile target deployments in application target state. + + Ensure each deployment is running on up-to-date info, and + remove outdated deployments from the application. 
+ """ + + # Set target state for each deployment + for deployment_name, info in self._target_state.deployment_infos.items(): + deploy_info = deepcopy(info) + + # Apply the target capacity information to the deployment info. + deploy_info.set_target_capacity( + new_target_capacity=self._target_state.target_capacity, + new_target_capacity_direction=( + self._target_state.target_capacity_direction + ), + ) + + # Apply the application logging config to the deployment logging config + # if it is not set. + if ( + self._target_state.config + and self._target_state.config.logging_config + and deploy_info.deployment_config.logging_config is None + ): + deploy_info.deployment_config.logging_config = ( + self._target_state.config.logging_config + ) + self.apply_deployment_info(deployment_name, deploy_info) + + # Delete outdated deployments + for deployment_name in self._get_live_deployments(): + if deployment_name not in self.target_deployments: + self._delete_deployment(deployment_name) + + def update(self) -> bool: + """Attempts to reconcile this application to match its target state. + + Updates the application status and status message based on the + current state of the system. + + Returns: + A boolean indicating whether the application is ready to be + deleted. + """ + + infos, task_status, msg = self._reconcile_build_app_task() + if task_status == BuildAppStatus.SUCCEEDED: + self._set_target_state( + deployment_infos=infos, + code_version=self._build_app_task_info.code_version, + api_type=self._target_state.api_type, + target_config=self._build_app_task_info.config, + target_capacity=self._build_app_task_info.target_capacity, + target_capacity_direction=( + self._build_app_task_info.target_capacity_direction + ), + ) + elif task_status == BuildAppStatus.FAILED: + self._update_status(ApplicationStatus.DEPLOY_FAILED, msg) + + # Only reconcile deployments when the build app task is finished. 
If + # it's not finished, we don't know what the target list of deployments + # is, so we don't perform any reconciliation. + if self._target_state.deployment_infos is not None: + self._reconcile_target_deployments() + status, status_msg = self._determine_app_status() + self._update_status(status, status_msg) + + # Check if app is ready to be deleted + if self._target_state.deleting: + return self.is_deleted() + return False + + def get_checkpoint_data(self) -> ApplicationTargetState: + return self._target_state + + def get_deployments_statuses(self) -> List[DeploymentStatusInfo]: + """Return all deployment status information""" + deployments = [ + DeploymentID(name=deployment, app_name=self._name) + for deployment in self.target_deployments + ] + return self._deployment_state_manager.get_deployment_statuses(deployments) + + def get_application_status_info(self) -> ApplicationStatusInfo: + """Return the application status information""" + return ApplicationStatusInfo( + self._status, + message=self._status_msg, + deployment_timestamp=self._deployment_timestamp, + ) + + def list_deployment_details(self) -> Dict[str, DeploymentDetails]: + """Gets detailed info on all live deployments in this application. + (Does not include deleted deployments.) + + Returns: + A dictionary of deployment infos. The set of deployment info returned + may not be the full list of deployments that are part of the application. + This can happen when the application is still deploying and bringing up + deployments, or when the application is deleting and some deployments have + been deleted. 
+ """ + details = { + deployment_name: self._deployment_state_manager.get_deployment_details( + DeploymentID(name=deployment_name, app_name=self._name) + ) + for deployment_name in self.target_deployments + } + return {k: v for k, v in details.items() if v is not None} + + def _update_status(self, status: ApplicationStatus, status_msg: str = "") -> None: + if ( + status_msg + and status + in [ + ApplicationStatus.DEPLOY_FAILED, + ApplicationStatus.UNHEALTHY, + ] + and status_msg != self._status_msg + ): + logger.error(status_msg) + + self._status = status + self._status_msg = status_msg + + +class ApplicationStateManager: + def __init__( + self, + deployment_state_manager: DeploymentStateManager, + endpoint_state: EndpointState, + kv_store: KVStoreBase, + logging_config: LoggingConfig, + ): + self._deployment_state_manager = deployment_state_manager + self._endpoint_state = endpoint_state + self._kv_store = kv_store + self._logging_config = logging_config + self._application_states: Dict[str, ApplicationState] = dict() + self._recover_from_checkpoint() + + def _recover_from_checkpoint(self): + checkpoint = self._kv_store.get(CHECKPOINT_KEY) + if checkpoint is not None: + application_state_info = cloudpickle.loads(checkpoint) + + for app_name, checkpoint_data in application_state_info.items(): + app_state = ApplicationState( + app_name, + self._deployment_state_manager, + self._endpoint_state, + self._save_checkpoint_func, + self._logging_config, + ) + app_state.recover_target_state_from_checkpoint(checkpoint_data) + self._application_states[app_name] = app_state + + def delete_app(self, name: str) -> None: + """Delete application by name""" + if name not in self._application_states: + return + self._application_states[name].delete() + + def deploy_app(self, name: str, deployment_args: List[Dict]) -> None: + """Deploy the specified app to the list of deployment arguments. 
+ + This function should only be called if the app is being deployed + through serve.run instead of from a config. + + Args: + name: application name + deployment_args_list: arguments for deploying a list of deployments. + + Raises: + RayServeException: If the list of deployments is trying to + use a route prefix that is already used by another application + """ + + # Make sure route_prefix is not being used by other application. + live_route_prefixes: Dict[str, str] = { + app_state.route_prefix: app_name + for app_name, app_state in self._application_states.items() + if app_state.route_prefix is not None + and not app_state.status == ApplicationStatus.DELETING + and name != app_name + } + + for deploy_param in deployment_args: + deploy_app_prefix = deploy_param.get("route_prefix", None) + if deploy_app_prefix is None: + continue + + app_name = live_route_prefixes.get(deploy_app_prefix) + if app_name is not None: + raise RayServeException( + f"Prefix {deploy_app_prefix} is being used by application " + f'"{app_name}". Failed to deploy application "{name}".' + ) + + if name not in self._application_states: + self._application_states[name] = ApplicationState( + name, + self._deployment_state_manager, + self._endpoint_state, + self._save_checkpoint_func, + self._logging_config, + ) + ServeUsageTag.NUM_APPS.record(str(len(self._application_states))) + + deployment_infos = { + params["deployment_name"]: deploy_args_to_deployment_info( + **params, app_name=name + ) + for params in deployment_args + } + self._application_states[name].deploy_app(deployment_infos) + + def apply_app_configs( + self, + app_configs: List[ServeApplicationSchema], + *, + deployment_time: float = 0, + target_capacity: Optional[float] = None, + target_capacity_direction: Optional[TargetCapacityDirection] = None, + ): + """Declaratively apply the list of application configs. + + The applications will be reconciled to match the target state of the config. 
+ + Any applications previously deployed declaratively that are *not* present in + the list will be deleted. + """ + for app_config in app_configs: + if app_config.name not in self._application_states: + logger.info(f"Deploying new app '{app_config.name}'.") + self._application_states[app_config.name] = ApplicationState( + app_config.name, + self._deployment_state_manager, + endpoint_state=self._endpoint_state, + save_checkpoint_func=self._save_checkpoint_func, + logging_config=self._logging_config, + ) + + self._application_states[app_config.name].apply_app_config( + app_config, + target_capacity, + target_capacity_direction, + deployment_time=deployment_time, + ) + + # Delete all apps that were previously deployed via the declarative API + # but are not in the config being applied. + existing_apps = { + name + for name, app_state in self._application_states.items() + if app_state.api_type == APIType.DECLARATIVE + } + apps_in_config = {app_config.name for app_config in app_configs} + for app_to_delete in existing_apps - apps_in_config: + self.delete_app(app_to_delete) + + ServeUsageTag.NUM_APPS.record(str(len(self._application_states))) + + def get_deployments(self, app_name: str) -> List[str]: + """Return all deployment names by app name""" + if app_name not in self._application_states: + return [] + return self._application_states[app_name].target_deployments + + def get_deployments_statuses(self, app_name: str) -> List[DeploymentStatusInfo]: + """Return all deployment statuses by app name""" + if app_name not in self._application_states: + return [] + return self._application_states[app_name].get_deployments_statuses() + + def get_app_status(self, name: str) -> ApplicationStatus: + if name not in self._application_states: + return ApplicationStatus.NOT_STARTED + + return self._application_states[name].status + + def get_app_status_info(self, name: str) -> ApplicationStatusInfo: + if name not in self._application_states: + return ApplicationStatusInfo( + 
ApplicationStatus.NOT_STARTED, + message=f"Application {name} doesn't exist", + deployment_timestamp=0, + ) + return self._application_states[name].get_application_status_info() + + def get_docs_path(self, app_name: str) -> Optional[str]: + return self._application_states[app_name].docs_path + + def get_route_prefix(self, name: str) -> Optional[str]: + return self._application_states[name].route_prefix + + def get_ingress_deployment_name(self, name: str) -> Optional[str]: + if name not in self._application_states: + return None + + return self._application_states[name].ingress_deployment + + def get_app_source(self, name: str) -> APIType: + return self._application_states[name].api_type + + def list_app_statuses(self) -> Dict[str, ApplicationStatusInfo]: + """Return a dictionary with {app name: application info}""" + return { + name: self._application_states[name].get_application_status_info() + for name in self._application_states + } + + def list_deployment_details(self, name: str) -> Dict[str, DeploymentDetails]: + """Gets detailed info on all deployments in specified application.""" + if name not in self._application_states: + return {} + return self._application_states[name].list_deployment_details() + + def update(self): + """Update each application state""" + apps_to_be_deleted = [] + for name, app in self._application_states.items(): + ready_to_be_deleted = app.update() + if ready_to_be_deleted: + apps_to_be_deleted.append(name) + logger.debug(f"Application '{name}' deleted successfully.") + + if len(apps_to_be_deleted) > 0: + for app_name in apps_to_be_deleted: + del self._application_states[app_name] + ServeUsageTag.NUM_APPS.record(str(len(self._application_states))) + + def shutdown(self) -> None: + for app_state in self._application_states.values(): + app_state.delete() + + self._kv_store.delete(CHECKPOINT_KEY) + + def is_ready_for_shutdown(self) -> bool: + """Return whether all applications have shut down. 
+ + Iterate through all application states and check if all their applications + are deleted. + """ + return all( + app_state.is_deleted() for app_state in self._application_states.values() + ) + + def _save_checkpoint_func( + self, *, writeahead_checkpoints: Optional[Dict[str, ApplicationTargetState]] + ) -> None: + """Write a checkpoint of all application states.""" + + application_state_info = { + app_name: app_state.get_checkpoint_data() + for app_name, app_state in self._application_states.items() + } + + if writeahead_checkpoints is not None: + application_state_info.update(writeahead_checkpoints) + + self._kv_store.put( + CHECKPOINT_KEY, + cloudpickle.dumps(application_state_info), + ) + + +@ray.remote(num_cpus=0, max_calls=1) +def build_serve_application( + import_path: str, + code_version: str, + name: str, + args: Dict, + logging_config: LoggingConfig, +) -> Tuple[Optional[List[Dict]], Optional[str]]: + """Import and build a Serve application. + + Args: + import_path: import path to top-level bound deployment. + code_version: code version inferred from app config. All + deployment versions are set to this code version. + name: application name. If specified, application will be deployed + without removing existing applications. + args: Arguments to be passed to the application builder. + logging_config: the logging config for the build app task. + Returns: + Deploy arguments: a list of deployment arguments if application + was built successfully, otherwise None. + Error message: a string if an error was raised, otherwise None. + """ + configure_component_logger( + component_name="controller", + component_id=f"build_{name}_{os.getpid()}", + logging_config=logging_config, + ) + + try: + from ray.serve._private.api import call_user_app_builder_with_args_if_necessary + + # Import and build the application. 
+ args_info_str = f" with arguments {args}" if args else "" + logger.info(f"Importing application '{name}'{args_info_str}.") + + app = call_user_app_builder_with_args_if_necessary( + import_attr(import_path), args + ) + + deploy_args_list = [] + built_app: BuiltApplication = build_app( + app, + name=name, + default_runtime_env=ray.get_runtime_context().runtime_env, + ) + for deployment in built_app.deployments: + is_ingress = deployment.name == built_app.ingress_deployment_name + deploy_args_list.append( + get_deploy_args( + name=deployment._name, + replica_config=deployment._replica_config, + ingress=is_ingress, + deployment_config=deployment._deployment_config, + version=code_version, + route_prefix="/" if is_ingress else None, + docs_path=deployment._docs_path, + ) + ) + return deploy_args_list, None + except KeyboardInterrupt: + # Error is raised when this task is canceled with ray.cancel(), which + # happens when deploy_apps() is called. + logger.info( + "Existing config deployment request terminated because of keyboard " + "interrupt." + ) + return None, None + except Exception: + logger.error( + f"Exception importing application '{name}'.\n{traceback.format_exc()}" + ) + return None, traceback.format_exc() + + +def override_deployment_info( + deployment_infos: Dict[str, DeploymentInfo], + override_config: Optional[ServeApplicationSchema], +) -> Dict[str, DeploymentInfo]: + """Override deployment infos with options from app config. + + Args: + app_name: application name + deployment_infos: deployment info loaded from code + override_config: application config deployed by user with + options to override those loaded from code. + + Returns: the updated deployment infos. + + Raises: + ValueError: If config options have invalid values. + TypeError: If config options have invalid types. 
+ """ + + deployment_infos = deepcopy(deployment_infos) + if override_config is None: + return deployment_infos + + config_dict = override_config.dict(exclude_unset=True) + deployment_override_options = config_dict.get("deployments", []) + + # Override options for each deployment listed in the config. + for options in deployment_override_options: + if "max_ongoing_requests" in options: + options["max_ongoing_requests"] = options.get("max_ongoing_requests") + + deployment_name = options["name"] + if deployment_name not in deployment_infos: + raise ValueError( + f"Got config override for nonexistent deployment '{deployment_name}'" + ) + + info = deployment_infos[deployment_name] + original_options = info.deployment_config.dict() + original_options["user_configured_option_names"].update(set(options)) + + # Override `max_ongoing_requests` and `autoscaling_config` if + # `num_replicas="auto"` + if options.get("num_replicas") == "auto": + options["num_replicas"] = None + + new_config = AutoscalingConfig.default().dict() + # If `autoscaling_config` is specified, its values override + # the default `num_replicas="auto"` configuration + autoscaling_config = ( + options.get("autoscaling_config") + or info.deployment_config.autoscaling_config + ) + if autoscaling_config: + new_config.update(autoscaling_config) + + options["autoscaling_config"] = AutoscalingConfig(**new_config) + + ServeUsageTag.AUTO_NUM_REPLICAS_USED.record("1") + + # What to pass to info.update + override_options = dict() + + # Merge app-level and deployment-level runtime_envs. + replica_config = info.replica_config + app_runtime_env = override_config.runtime_env + if "ray_actor_options" in options: + # If specified, get ray_actor_options from config + override_actor_options = options.pop("ray_actor_options", {}) + else: + # Otherwise, get options from application code (and default to {} + # if the code sets options to None). 
+ override_actor_options = replica_config.ray_actor_options or {} + + override_placement_group_bundles = options.pop( + "placement_group_bundles", replica_config.placement_group_bundles + ) + override_placement_group_strategy = options.pop( + "placement_group_strategy", replica_config.placement_group_strategy + ) + + override_max_replicas_per_node = options.pop( + "max_replicas_per_node", replica_config.max_replicas_per_node + ) + + # Record telemetry for container runtime env feature at deployment level + if override_actor_options.get("runtime_env") and ( + override_actor_options["runtime_env"].get("container") + or override_actor_options["runtime_env"].get("image_uri") + ): + ServeUsageTag.DEPLOYMENT_CONTAINER_RUNTIME_ENV_USED.record("1") + + merged_env = override_runtime_envs_except_env_vars( + app_runtime_env, override_actor_options.get("runtime_env", {}) + ) + override_actor_options.update({"runtime_env": merged_env}) + + replica_config.update( + ray_actor_options=override_actor_options, + placement_group_bundles=override_placement_group_bundles, + placement_group_strategy=override_placement_group_strategy, + max_replicas_per_node=override_max_replicas_per_node, + ) + override_options["replica_config"] = replica_config + + # Override deployment config options + options.pop("name", None) + original_options.update(options) + override_options["deployment_config"] = DeploymentConfig(**original_options) + deployment_infos[deployment_name] = info.update(**override_options) + + deployment_config = deployment_infos[deployment_name].deployment_config + if ( + deployment_config.autoscaling_config is not None + and deployment_config.max_ongoing_requests + < deployment_config.autoscaling_config.get_target_ongoing_requests() + ): + logger.warning( + "Autoscaling will never happen, " + "because 'max_ongoing_requests' is less than " + "'target_ongoing_requests' now." 
+ ) + + # Overwrite ingress route prefix + app_route_prefix = config_dict.get("route_prefix", DEFAULT.VALUE) + validate_route_prefix(app_route_prefix) + for deployment in list(deployment_infos.values()): + if ( + app_route_prefix is not DEFAULT.VALUE + and deployment.route_prefix is not None + ): + deployment.route_prefix = app_route_prefix + + return deployment_infos diff --git a/.venv/lib/python3.11/site-packages/ray/serve/_private/build_app.py b/.venv/lib/python3.11/site-packages/ray/serve/_private/build_app.py new file mode 100644 index 0000000000000000000000000000000000000000..5064cd95539791762d1863f4bdb0764918742eac --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/serve/_private/build_app.py @@ -0,0 +1,230 @@ +import logging +from copy import deepcopy +from dataclasses import dataclass +from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar + +from ray.dag.py_obj_scanner import _PyObjScanner +from ray.serve._private.constants import SERVE_LOGGER_NAME +from ray.serve.deployment import Application, Deployment +from ray.serve.handle import DeploymentHandle + +logger = logging.getLogger(SERVE_LOGGER_NAME) + +K = TypeVar("K") +V = TypeVar("V") + + +class IDDict(dict, Generic[K, V]): + """Dictionary that uses id() for keys instead of hash(). + + This is necessary because Application objects aren't hashable and we want each + instance to map to a unique key. + """ + + def __getitem__(self, key: K) -> V: + if not isinstance(key, int): + key = id(key) + return super().__getitem__(key) + + def __setitem__(self, key: K, value: V): + if not isinstance(key, int): + key = id(key) + return super().__setitem__(key, value) + + def __delitem__(self, key: K): + if not isinstance(key, int): + key = id(key) + return super().__delitem__(key) + + def __contains__(self, key: object): + if not isinstance(key, int): + key = id(key) + return super().__contains__(key) + + +@dataclass(frozen=True) +class BuiltApplication: + # Name of the application. 
+ name: str + # Name of the application's 'ingress' deployment + # (the one exposed over gRPC/HTTP/handle). + ingress_deployment_name: str + # List of unique deployments comprising the app. + deployments: List[Deployment] + # Dict[name, DeploymentHandle] mapping deployment names to the handles that replaced + # them in other deployments' init args/kwargs. + deployment_handles: Dict[str, DeploymentHandle] + + +def _make_deployment_handle_default( + deployment: Deployment, app_name: str +) -> DeploymentHandle: + return DeploymentHandle( + deployment.name, + app_name=app_name, + ) + + +def build_app( + app: Application, + *, + name: str, + default_runtime_env: Optional[Dict[str, Any]] = None, + make_deployment_handle: Optional[ + Callable[[Deployment, str], DeploymentHandle] + ] = None, +) -> BuiltApplication: + """Builds the application into a list of finalized deployments. + + The following transformations are made: + - Application objects in constructor args/kwargs are converted to + DeploymentHandles for injection at runtime. + - Name conflicts from deployments that use the same class are handled + by appending a monotonically increasing suffix (e.g., SomeClass_1). 
+ + Returns: BuiltApplication + """ + if make_deployment_handle is None: + make_deployment_handle = _make_deployment_handle_default + + handles = IDDict() + deployment_names = IDDict() + deployments = _build_app_recursive( + app, + app_name=name, + handles=handles, + deployment_names=deployment_names, + default_runtime_env=default_runtime_env, + make_deployment_handle=make_deployment_handle, + ) + return BuiltApplication( + name=name, + ingress_deployment_name=app._bound_deployment.name, + deployments=deployments, + deployment_handles={ + deployment_names[app]: handle for app, handle in handles.items() + }, + ) + + +def _build_app_recursive( + app: Application, + *, + app_name: str, + deployment_names: IDDict[Application, str], + handles: IDDict[Application, DeploymentHandle], + default_runtime_env: Optional[Dict[str, Any]] = None, + make_deployment_handle: Callable[[Deployment, str], DeploymentHandle], +) -> List[Deployment]: + """Recursively traverses the graph of Application objects. + + Each Application will have an associated DeploymentHandle created that will replace + it in any occurrences in other Applications' args or kwargs. + + Also collects a list of the unique Applications encountered and returns them as + deployable Deployment objects. + """ + # This application has already been encountered. + # There's no need to recurse into its child args and we don't want to create + # a duplicate entry for it in the list of deployments. + if app in handles: + return [] + + deployments = [] + scanner = _PyObjScanner(source_type=Application) + try: + # Recursively traverse any Application objects bound to init args/kwargs. 
+ child_apps = scanner.find_nodes( + (app._bound_deployment.init_args, app._bound_deployment.init_kwargs) + ) + for child_app in child_apps: + deployments.extend( + _build_app_recursive( + child_app, + app_name=app_name, + handles=handles, + deployment_names=deployment_names, + make_deployment_handle=make_deployment_handle, + default_runtime_env=default_runtime_env, + ) + ) + + # Replace Application objects with their corresponding DeploymentHandles. + new_init_args, new_init_kwargs = scanner.replace_nodes(handles) + final_deployment = app._bound_deployment.options( + name=_get_unique_deployment_name_memoized(app, deployment_names), + _init_args=new_init_args, + _init_kwargs=new_init_kwargs, + ) + final_deployment = _set_default_runtime_env( + final_deployment, default_runtime_env + ) + + # Create the DeploymentHandle that will be used to replace this application + # in the arguments of its parent(s). + handles[app] = make_deployment_handle( + final_deployment, + app_name, + ) + + return deployments + [final_deployment] + finally: + scanner.clear() + + +def _set_default_runtime_env( + d: Deployment, default_runtime_env: Optional[Dict[str, Any]] +) -> Deployment: + """Configures the deployment with the provided default runtime_env. + + If the deployment does not have a runtime_env configured, the default will be set. + + If it does have a runtime_env configured but that runtime_env does not have a + working_dir, only the working_dir field will be set. + + Else the deployment's runtime_env will be left untouched. 
+ """ + if not default_runtime_env: + return d + + ray_actor_options = deepcopy(d.ray_actor_options or {}) + default_working_dir = default_runtime_env.get("working_dir", None) + if "runtime_env" not in ray_actor_options: + ray_actor_options["runtime_env"] = default_runtime_env + elif default_working_dir is not None: + ray_actor_options["runtime_env"].setdefault("working_dir", default_working_dir) + + return d.options(ray_actor_options=ray_actor_options) + + +def _get_unique_deployment_name_memoized( + app: Application, deployment_names: IDDict[Application, str] +) -> str: + """Generates a name for the deployment. + + This is used to handle collisions when the user does not specify a name + explicitly, so typically we'd use the class name as the default. + + In that case, we append a monotonically increasing suffix to the name, e.g., + Deployment, then Deployment_1, then Deployment_2, ... + + Names are memoized in the `deployment_names` dict, which should be passed to + subsequent calls to this function. + """ + if app in deployment_names: + return deployment_names[app] + + idx = 1 + name = app._bound_deployment.name + while name in deployment_names.values(): + name = f"{app._bound_deployment.name}_{idx}" + idx += 1 + + if idx != 1: + logger.warning( + "There are multiple deployments with the same name " + f"'{app._bound_deployment.name}'. Renaming one to '{name}'." 
+ ) + + deployment_names[app] = name + return name diff --git a/.venv/lib/python3.11/site-packages/ray/serve/_private/cluster_node_info_cache.py b/.venv/lib/python3.11/site-packages/ray/serve/_private/cluster_node_info_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..7e09ddbab9c08818249a7c19ee6315da76e61d9c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/serve/_private/cluster_node_info_cache.py @@ -0,0 +1,102 @@ +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Set, Tuple, Union + +import ray +from ray._raylet import GcsClient +from ray.serve._private.constants import RAY_GCS_RPC_TIMEOUT_S + + +class ClusterNodeInfoCache(ABC): + """Provide access to cached node information in the cluster.""" + + def __init__(self, gcs_client: GcsClient): + self._gcs_client = gcs_client + self._cached_alive_nodes = None + self._cached_node_labels = dict() + self._cached_total_resources_per_node = dict() + self._cached_available_resources_per_node = dict() + + def update(self): + """Update the cache by fetching latest node information from GCS. + + This should be called once in each update cycle. + Within an update cycle, everyone will see the same + cached node info avoiding any potential issues + caused by inconsistent node info seen by different components. + """ + nodes = self._gcs_client.get_all_node_info(timeout=RAY_GCS_RPC_TIMEOUT_S) + alive_nodes = [ + (node_id.hex(), node.node_name) + for (node_id, node) in nodes.items() + if node.state == ray.core.generated.gcs_pb2.GcsNodeInfo.ALIVE + ] + + # Sort on NodeID to ensure the ordering is deterministic across the cluster. 
+ sorted(alive_nodes) + self._cached_alive_nodes = alive_nodes + self._cached_node_labels = { + node_id.hex(): dict(node.labels) for (node_id, node) in nodes.items() + } + + # Node resources + self._cached_total_resources_per_node = { + node_id.hex(): dict(node.resources_total) + for (node_id, node) in nodes.items() + } + + self._cached_available_resources_per_node = ( + ray._private.state.available_resources_per_node() + ) + + def get_alive_nodes(self) -> List[Tuple[str, str]]: + """Get IDs and info for all live nodes in the cluster. + + Returns a list of (node_id: str, node_info: Dict). The node_id can be + passed into the Ray SchedulingPolicy API. + """ + return self._cached_alive_nodes + + def get_total_resources_per_node(self) -> Dict[str, Dict]: + """Get total resources for alive nodes.""" + return self._cached_total_resources_per_node + + def get_alive_node_ids(self) -> Set[str]: + """Get IDs of all live nodes in the cluster.""" + return {node_id for node_id, _ in self.get_alive_nodes()} + + @abstractmethod + def get_draining_nodes(self) -> Dict[str, int]: + """Get draining nodes in the cluster and their deadlines.""" + raise NotImplementedError + + @abstractmethod + def get_node_az(self, node_id: str) -> Optional[str]: + """Get availability zone of a node.""" + raise NotImplementedError + + def get_active_node_ids(self) -> Set[str]: + """Get IDs of all active nodes in the cluster. + + A node is active if it's schedulable for new tasks and actors. + """ + return self.get_alive_node_ids() - set(self.get_draining_nodes()) + + def get_available_resources_per_node(self) -> Dict[str, Union[float, Dict]]: + """Get available resources per node. + + Returns a map from (node_id -> Dict of resources). 
+ """ + + return self._cached_available_resources_per_node + + +class DefaultClusterNodeInfoCache(ClusterNodeInfoCache): + def __init__(self, gcs_client: GcsClient): + super().__init__(gcs_client) + + def get_draining_nodes(self) -> Dict[str, int]: + return dict() + + def get_node_az(self, node_id: str) -> Optional[str]: + """Get availability zone of a node.""" + return None diff --git a/.venv/lib/python3.11/site-packages/ray/serve/_private/config.py b/.venv/lib/python3.11/site-packages/ray/serve/_private/config.py new file mode 100644 index 0000000000000000000000000000000000000000..37a6b3662ab723a44b426f0f9784f84a8823cc62 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/serve/_private/config.py @@ -0,0 +1,722 @@ +import inspect +import json +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union + +from google.protobuf.descriptor import FieldDescriptor +from google.protobuf.message import Message + +from ray import cloudpickle +from ray._private import ray_option_utils +from ray._private.pydantic_compat import ( + BaseModel, + Field, + NonNegativeFloat, + NonNegativeInt, + PositiveFloat, + PositiveInt, + validator, +) +from ray._private.serialization import pickle_dumps +from ray._private.utils import resources_from_ray_options +from ray.serve._private.constants import ( + DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_S, + DEFAULT_GRACEFUL_SHUTDOWN_WAIT_LOOP_S, + DEFAULT_HEALTH_CHECK_PERIOD_S, + DEFAULT_HEALTH_CHECK_TIMEOUT_S, + DEFAULT_MAX_ONGOING_REQUESTS, + MAX_REPLICAS_PER_NODE_MAX_VALUE, +) +from ray.serve._private.utils import DEFAULT, DeploymentOptionUpdateType +from ray.serve.config import AutoscalingConfig +from ray.serve.generated.serve_pb2 import AutoscalingConfig as AutoscalingConfigProto +from ray.serve.generated.serve_pb2 import DeploymentConfig as DeploymentConfigProto +from ray.serve.generated.serve_pb2 import DeploymentLanguage +from ray.serve.generated.serve_pb2 import EncodingType as EncodingTypeProto +from 
def _needs_pickle(deployment_language: DeploymentLanguage, is_cross_language: bool):
    """From Serve client API's perspective, decide whether pickling is needed."""
    # Only two combinations use cloudpickle-based serialization; everything
    # else (e.g. Java client, or Python client with a same-language Java
    # deployment) is handled without pickling.
    if deployment_language == DeploymentLanguage.PYTHON and not is_cross_language:
        # Python client deploying Python replicas.
        return True
    elif deployment_language == DeploymentLanguage.JAVA and is_cross_language:
        # Python client deploying Java replicas,
        # using xlang serialization via cloudpickle.
        return True
    else:
        return False


def _proto_to_dict(proto: Message) -> Dict:
    """Recursively convert a protobuf into a Python dictionary.

    This is an alternative to protobuf's `MessageToDict`. Unlike
    `MessageToDict`, this function doesn't add an extra base64
    encoding to bytes when constructing a json response.
    """
    data = {}
    # Fill data with non-empty fields.
    # `ListFields()` yields only fields that are set; unset scalar fields are
    # backfilled with defaults in the loop below.
    for field, value in proto.ListFields():
        # Recursively call if the field is another protobuf.
        # NOTE(review): a repeated message field would pass a container (not a
        # single message) to the recursive call — presumably this is only used
        # with singular message fields; confirm against callers.
        if field.type == FieldDescriptor.TYPE_MESSAGE:
            data[field.name] = _proto_to_dict(value)
        else:
            data[field.name] = value

    # Fill data default values.
    for field in proto.DESCRIPTOR.fields:
        if (
            field.name not in data  # skip the fields that are already set
            and field.type != FieldDescriptor.TYPE_MESSAGE  # skip nested messages
            and not field.containing_oneof  # skip optional fields
        ):
            data[field.name] = field.default_value

    return data
Defaults to 5. + max_queued_requests: Maximum number of requests to this deployment that will be + queued at each *caller* (proxy or DeploymentHandle). Once this limit is + reached, subsequent requests will raise a BackPressureError (for handles) or + return an HTTP 503 status code (for HTTP requests). Defaults to -1 (no + limit). + user_config: Arguments to pass to the reconfigure + method of the deployment. The reconfigure method is called if + user_config is not None. Must be JSON-serializable. + graceful_shutdown_wait_loop_s: Duration + that deployment replicas wait until there is no more work to + be done before shutting down. + graceful_shutdown_timeout_s: Controller waits for this duration + to forcefully kill the replica for shutdown. + health_check_period_s: Frequency at which the controller health + checks replicas. + health_check_timeout_s: Timeout that the controller waits for a + response from the replica's health check before marking it + unhealthy. + autoscaling_config: Autoscaling configuration. + logging_config: Configuration for deployment logs. + user_configured_option_names: The names of options manually + configured by the user. 
+ """ + + num_replicas: Optional[NonNegativeInt] = Field( + default=1, update_type=DeploymentOptionUpdateType.LightWeight + ) + max_ongoing_requests: PositiveInt = Field( + default=DEFAULT_MAX_ONGOING_REQUESTS, + update_type=DeploymentOptionUpdateType.NeedsActorReconfigure, + ) + max_queued_requests: int = Field( + default=-1, + update_type=DeploymentOptionUpdateType.LightWeight, + ) + user_config: Any = Field( + default=None, update_type=DeploymentOptionUpdateType.NeedsActorReconfigure + ) + + graceful_shutdown_timeout_s: NonNegativeFloat = Field( + default=DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_S, + update_type=DeploymentOptionUpdateType.NeedsReconfigure, + ) + graceful_shutdown_wait_loop_s: NonNegativeFloat = Field( + default=DEFAULT_GRACEFUL_SHUTDOWN_WAIT_LOOP_S, + update_type=DeploymentOptionUpdateType.NeedsActorReconfigure, + ) + + health_check_period_s: PositiveFloat = Field( + default=DEFAULT_HEALTH_CHECK_PERIOD_S, + update_type=DeploymentOptionUpdateType.NeedsReconfigure, + ) + health_check_timeout_s: PositiveFloat = Field( + default=DEFAULT_HEALTH_CHECK_TIMEOUT_S, + update_type=DeploymentOptionUpdateType.NeedsReconfigure, + ) + + autoscaling_config: Optional[AutoscalingConfig] = Field( + default=None, update_type=DeploymentOptionUpdateType.NeedsActorReconfigure + ) + + # This flag is used to let replica know they are deployed from + # a different language. + is_cross_language: bool = False + + # This flag is used to let controller know which language does + # the deploymnent use. 
+ deployment_language: Any = DeploymentLanguage.PYTHON + + version: Optional[str] = Field( + default=None, + update_type=DeploymentOptionUpdateType.HeavyWeight, + ) + + logging_config: Optional[dict] = Field( + default=None, + update_type=DeploymentOptionUpdateType.NeedsActorReconfigure, + ) + + # Contains the names of deployment options manually set by the user + user_configured_option_names: Set[str] = set() + + class Config: + validate_assignment = True + arbitrary_types_allowed = True + + @validator("user_config", always=True) + def user_config_json_serializable(cls, v): + if isinstance(v, bytes): + return v + if v is not None: + try: + json.dumps(v) + except TypeError as e: + raise ValueError(f"user_config is not JSON-serializable: {str(e)}.") + + return v + + @validator("logging_config", always=True) + def logging_config_valid(cls, v): + if v is None: + return v + if not isinstance(v, dict): + raise TypeError( + f"Got invalid type '{type(v)}' for logging_config. " + "Expected a dictionary." + ) + # Handle default value + from ray.serve.schema import LoggingConfig + + v = LoggingConfig(**v).dict() + + return v + + @validator("max_queued_requests", always=True) + def validate_max_queued_requests(cls, v): + if not isinstance(v, int): + raise TypeError("max_queued_requests must be an integer.") + + if v < 1 and v != -1: + raise ValueError( + "max_queued_requests must be -1 (no limit) or a positive integer." 
+ ) + + return v + + def needs_pickle(self): + return _needs_pickle(self.deployment_language, self.is_cross_language) + + def to_proto(self): + data = self.dict() + if data.get("user_config") is not None: + if self.needs_pickle(): + data["user_config"] = cloudpickle.dumps(data["user_config"]) + if data.get("autoscaling_config"): + data["autoscaling_config"] = AutoscalingConfigProto( + **data["autoscaling_config"] + ) + if data.get("logging_config"): + if "encoding" in data["logging_config"]: + data["logging_config"]["encoding"] = EncodingTypeProto.Value( + data["logging_config"]["encoding"] + ) + + data["logging_config"] = LoggingConfigProto(**data["logging_config"]) + data["user_configured_option_names"] = list( + data["user_configured_option_names"] + ) + return DeploymentConfigProto(**data) + + def to_proto_bytes(self): + return self.to_proto().SerializeToString() + + @classmethod + def from_proto(cls, proto: DeploymentConfigProto): + data = _proto_to_dict(proto) + if "user_config" in data: + if data["user_config"] != b"": + deployment_language = ( + data["deployment_language"] + if "deployment_language" in data + else DeploymentLanguage.PYTHON + ) + is_cross_language = ( + data["is_cross_language"] if "is_cross_language" in data else False + ) + needs_pickle = _needs_pickle(deployment_language, is_cross_language) + if needs_pickle: + data["user_config"] = cloudpickle.loads(proto.user_config) + else: + data["user_config"] = proto.user_config + else: + data["user_config"] = None + if "autoscaling_config" in data: + if not data["autoscaling_config"].get("upscale_smoothing_factor"): + data["autoscaling_config"]["upscale_smoothing_factor"] = None + if not data["autoscaling_config"].get("downscale_smoothing_factor"): + data["autoscaling_config"]["downscale_smoothing_factor"] = None + if not data["autoscaling_config"].get("upscaling_factor"): + data["autoscaling_config"]["upscaling_factor"] = None + if not data["autoscaling_config"].get("downscaling_factor"): + 
data["autoscaling_config"]["downscaling_factor"] = None + if not data["autoscaling_config"].get("target_ongoing_requests"): + data["autoscaling_config"]["target_ongoing_requests"] = None + data["autoscaling_config"] = AutoscalingConfig(**data["autoscaling_config"]) + if "version" in data: + if data["version"] == "": + data["version"] = None + if "user_configured_option_names" in data: + data["user_configured_option_names"] = set( + data["user_configured_option_names"] + ) + if "logging_config" in data: + if "encoding" in data["logging_config"]: + data["logging_config"]["encoding"] = EncodingTypeProto.Name( + data["logging_config"]["encoding"] + ) + + return cls(**data) + + @classmethod + def from_proto_bytes(cls, proto_bytes: bytes): + proto = DeploymentConfigProto.FromString(proto_bytes) + return cls.from_proto(proto) + + @classmethod + def from_default(cls, **kwargs): + """Creates a default DeploymentConfig and overrides it with kwargs. + + Ignores any kwargs set to DEFAULT.VALUE. + + Raises: + TypeError: when a keyword that's not an argument to the class is + passed in. + """ + + config = cls() + valid_config_options = set(config.dict().keys()) + + # Friendly error if a non-DeploymentConfig kwarg was passed in + for key, val in kwargs.items(): + if key not in valid_config_options: + raise TypeError( + f'Got invalid Deployment config option "{key}" ' + f"(with value {val}) as keyword argument. All Deployment " + "config options must come from this list: " + f"{list(valid_config_options)}." + ) + + kwargs = {key: val for key, val in kwargs.items() if val != DEFAULT.VALUE} + + for key, val in kwargs.items(): + config.__setattr__(key, val) + + return config + + +def handle_num_replicas_auto( + max_ongoing_requests: Union[int, DEFAULT], + autoscaling_config: Optional[Union[Dict, AutoscalingConfig, DEFAULT]], +): + """Return modified `max_ongoing_requests` and `autoscaling_config` + for when num_replicas="auto". 
+ + If `autoscaling_config` is unspecified, returns the modified value + AutoscalingConfig.default(). + If it is specified, the specified fields in `autoscaling_config` + override that of AutoscalingConfig.default(). + """ + + if autoscaling_config in [DEFAULT.VALUE, None]: + # If autoscaling config wasn't specified, use default + # configuration + autoscaling_config = AutoscalingConfig.default() + else: + # If autoscaling config was specified, values specified in + # autoscaling config overrides the default configuration + default_config = AutoscalingConfig.default().dict(exclude_unset=True) + autoscaling_config = ( + autoscaling_config + if isinstance(autoscaling_config, dict) + else autoscaling_config.dict(exclude_unset=True) + ) + default_config.update(autoscaling_config) + autoscaling_config = AutoscalingConfig(**default_config) + + return max_ongoing_requests, autoscaling_config + + +class ReplicaConfig: + """Internal datastructure wrapping config options for a deployment's replicas. + + Provides five main properties (see property docstrings for more info): + deployment_def: the code, or a reference to the code, that this + replica should run. + init_args: the deployment_def's init_args. + init_kwargs: the deployment_def's init_kwargs. + ray_actor_options: the Ray actor options to pass into the replica's + actor. + resource_dict: contains info on this replica's actor's resource needs. + + Offers a serialized equivalent (e.g. serialized_deployment_def) for + deployment_def, init_args, and init_kwargs. Deserializes these properties + when they're first accessed, if they were not passed in directly through + create(). + + Use the classmethod create() to make a ReplicaConfig with the deserialized + properties. + + Note: overwriting or setting any property after the ReplicaConfig has been + constructed is currently undefined behavior. The config's fields should not + be modified externally after it is created. 
+ """ + + def __init__( + self, + deployment_def_name: str, + serialized_deployment_def: bytes, + serialized_init_args: bytes, + serialized_init_kwargs: bytes, + ray_actor_options: Dict, + placement_group_bundles: Optional[List[Dict[str, float]]] = None, + placement_group_strategy: Optional[str] = None, + max_replicas_per_node: Optional[int] = None, + needs_pickle: bool = True, + ): + """Construct a ReplicaConfig with serialized properties. + + All parameters are required. See classmethod create() for defaults. + """ + self.deployment_def_name = deployment_def_name + + # Store serialized versions of code properties. + self.serialized_deployment_def = serialized_deployment_def + self.serialized_init_args = serialized_init_args + self.serialized_init_kwargs = serialized_init_kwargs + + # Deserialize properties when first accessed. See @property methods. + self._deployment_def = None + self._init_args = None + self._init_kwargs = None + + # Configure ray_actor_options. These are the Ray options ultimately + # passed into the replica's actor when it's created. + self.ray_actor_options = ray_actor_options + + self.placement_group_bundles = placement_group_bundles + self.placement_group_strategy = placement_group_strategy + + self.max_replicas_per_node = max_replicas_per_node + + self._validate() + + # Create resource_dict. This contains info about the replica's resource + # needs. It does NOT set the replica's resource usage. That's done by + # the ray_actor_options. + self.resource_dict = resources_from_ray_options(self.ray_actor_options) + self.needs_pickle = needs_pickle + + def _validate(self): + self._validate_ray_actor_options() + self._validate_placement_group_options() + self._validate_max_replicas_per_node() + + if ( + self.max_replicas_per_node is not None + and self.placement_group_bundles is not None + ): + raise ValueError( + "Setting max_replicas_per_node is not allowed when " + "placement_group_bundles is provided." 
+ ) + + def update( + self, + ray_actor_options: dict, + placement_group_bundles: Optional[List[Dict[str, float]]] = None, + placement_group_strategy: Optional[str] = None, + max_replicas_per_node: Optional[int] = None, + ): + self.ray_actor_options = ray_actor_options + + self.placement_group_bundles = placement_group_bundles + self.placement_group_strategy = placement_group_strategy + + self.max_replicas_per_node = max_replicas_per_node + + self._validate() + + self.resource_dict = resources_from_ray_options(self.ray_actor_options) + + @classmethod + def create( + cls, + deployment_def: Union[Callable, str], + init_args: Optional[Tuple[Any]] = None, + init_kwargs: Optional[Dict[Any, Any]] = None, + ray_actor_options: Optional[Dict] = None, + placement_group_bundles: Optional[List[Dict[str, float]]] = None, + placement_group_strategy: Optional[str] = None, + max_replicas_per_node: Optional[int] = None, + deployment_def_name: Optional[str] = None, + ): + """Create a ReplicaConfig from deserialized parameters.""" + + if not callable(deployment_def) and not isinstance(deployment_def, str): + raise TypeError("@serve.deployment must be called on a class or function.") + + if not (init_args is None or isinstance(init_args, (tuple, list))): + raise TypeError("init_args must be a tuple.") + + if not (init_kwargs is None or isinstance(init_kwargs, dict)): + raise TypeError("init_kwargs must be a dict.") + + if inspect.isfunction(deployment_def): + if init_args: + raise ValueError("init_args not supported for function deployments.") + elif init_kwargs: + raise ValueError("init_kwargs not supported for function deployments.") + + if not isinstance(deployment_def, (Callable, str)): + raise TypeError( + f'Got invalid type "{type(deployment_def)}" for ' + "deployment_def. Expected deployment_def to be a " + "class, function, or string." 
+ ) + # Set defaults + if init_args is None: + init_args = () + if init_kwargs is None: + init_kwargs = {} + if ray_actor_options is None: + ray_actor_options = {} + if deployment_def_name is None: + if isinstance(deployment_def, str): + deployment_def_name = deployment_def + else: + deployment_def_name = deployment_def.__name__ + + config = cls( + deployment_def_name, + pickle_dumps( + deployment_def, + f"Could not serialize the deployment {repr(deployment_def)}", + ), + pickle_dumps(init_args, "Could not serialize the deployment init args"), + pickle_dumps(init_kwargs, "Could not serialize the deployment init kwargs"), + ray_actor_options, + placement_group_bundles, + placement_group_strategy, + max_replicas_per_node, + ) + + config._deployment_def = deployment_def + config._init_args = init_args + config._init_kwargs = init_kwargs + + return config + + def _validate_ray_actor_options(self): + if not isinstance(self.ray_actor_options, dict): + raise TypeError( + f'Got invalid type "{type(self.ray_actor_options)}" for ' + "ray_actor_options. Expected a dictionary." + ) + # Please keep this in sync with the docstring for the ray_actor_options + # kwarg in api.py. + allowed_ray_actor_options = { + # Resource options + "accelerator_type", + "memory", + "num_cpus", + "num_gpus", + "resources", + # Other options + "runtime_env", + } + + for option in self.ray_actor_options: + if option not in allowed_ray_actor_options: + raise ValueError( + f"Specifying '{option}' in ray_actor_options is not allowed. 
" + f"Allowed options: {allowed_ray_actor_options}" + ) + ray_option_utils.validate_actor_options(self.ray_actor_options, in_options=True) + + # Set Serve replica defaults + if self.ray_actor_options.get("num_cpus") is None: + self.ray_actor_options["num_cpus"] = 1 + + def _validate_max_replicas_per_node(self) -> None: + if self.max_replicas_per_node is None: + return + if not isinstance(self.max_replicas_per_node, int): + raise TypeError( + f"Get invalid type '{type(self.max_replicas_per_node)}' for " + "max_replicas_per_node. Expected None or an integer " + f"in the range of [1, {MAX_REPLICAS_PER_NODE_MAX_VALUE}]." + ) + if ( + self.max_replicas_per_node < 1 + or self.max_replicas_per_node > MAX_REPLICAS_PER_NODE_MAX_VALUE + ): + raise ValueError( + f"Invalid max_replicas_per_node {self.max_replicas_per_node}. " + "Valid values are None or an integer " + f"in the range of [1, {MAX_REPLICAS_PER_NODE_MAX_VALUE}]." + ) + + def _validate_placement_group_options(self) -> None: + if self.placement_group_strategy is not None: + if self.placement_group_bundles is None: + raise ValueError( + "If `placement_group_strategy` is provided, " + "`placement_group_bundles` must also be provided." + ) + + if self.placement_group_bundles is not None: + validate_placement_group( + bundles=self.placement_group_bundles, + strategy=self.placement_group_strategy or "PACK", + lifetime="detached", + ) + + resource_error_prefix = ( + "When using `placement_group_bundles`, the replica actor " + "will be placed in the first bundle, so the resource " + "requirements for the actor must be a subset of the first " + "bundle." + ) + + first_bundle = self.placement_group_bundles[0] + + # Validate that the replica actor fits in the first bundle. 
    @property
    def deployment_def(self) -> Union[Callable, str]:
        """The code, or a reference to the code, that this replica runs.

        For Python replicas, this can be one of the following:
        - Function (Callable)
        - Class (Callable)
        - Import path (str)

        For Java replicas, this can be one of the following:
        - Class path (str)
        """
        # Deserialized lazily on first access and cached on the instance.
        if self._deployment_def is None:
            if self.needs_pickle:
                # Python payload: round-tripped through cloudpickle.
                self._deployment_def = cloudpickle.loads(self.serialized_deployment_def)
            else:
                # Cross-language payload: a UTF-8 encoded class/import path.
                self._deployment_def = self.serialized_deployment_def.decode(
                    encoding="utf-8"
                )

        return self._deployment_def
+ """ + if self._init_args is None: + if self.needs_pickle: + self._init_args = cloudpickle.loads(self.serialized_init_args) + else: + self._init_args = self.serialized_init_args + + return self._init_args + + @property + def init_kwargs(self) -> Optional[Tuple[Any]]: + """The init_kwargs for a Python class. + + This property is only meaningful if deployment_def is a Python class. + Otherwise, it is None. + """ + + if self._init_kwargs is None: + self._init_kwargs = cloudpickle.loads(self.serialized_init_kwargs) + + return self._init_kwargs + + @classmethod + def from_proto(cls, proto: ReplicaConfigProto, needs_pickle: bool = True): + return ReplicaConfig( + proto.deployment_def_name, + proto.deployment_def, + proto.init_args if proto.init_args != b"" else None, + proto.init_kwargs if proto.init_kwargs != b"" else None, + json.loads(proto.ray_actor_options), + json.loads(proto.placement_group_bundles) + if proto.placement_group_bundles + else None, + proto.placement_group_strategy + if proto.placement_group_strategy != "" + else None, + proto.max_replicas_per_node if proto.max_replicas_per_node else None, + needs_pickle, + ) + + @classmethod + def from_proto_bytes(cls, proto_bytes: bytes, needs_pickle: bool = True): + proto = ReplicaConfigProto.FromString(proto_bytes) + return cls.from_proto(proto, needs_pickle) + + def to_proto(self): + return ReplicaConfigProto( + deployment_def_name=self.deployment_def_name, + deployment_def=self.serialized_deployment_def, + init_args=self.serialized_init_args, + init_kwargs=self.serialized_init_kwargs, + ray_actor_options=json.dumps(self.ray_actor_options), + placement_group_bundles=json.dumps(self.placement_group_bundles) + if self.placement_group_bundles is not None + else "", + placement_group_strategy=self.placement_group_strategy, + max_replicas_per_node=self.max_replicas_per_node + if self.max_replicas_per_node is not None + else 0, + ) + + def to_proto_bytes(self): + return self.to_proto().SerializeToString() diff --git 
a/.venv/lib/python3.11/site-packages/ray/serve/_private/constants.py b/.venv/lib/python3.11/site-packages/ray/serve/_private/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..ff7c292a5d9450bc6c3214ed4dad14af639ab83f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/serve/_private/constants.py @@ -0,0 +1,390 @@ +import os + +#: Logger used by serve components +SERVE_LOGGER_NAME = "ray.serve" + +#: Actor name used to register controller +SERVE_CONTROLLER_NAME = "SERVE_CONTROLLER_ACTOR" + +#: Actor name used to register HTTP proxy actor +SERVE_PROXY_NAME = "SERVE_PROXY_ACTOR" + +#: Ray namespace used for all Serve actors +SERVE_NAMESPACE = "serve" + +#: HTTP Host +DEFAULT_HTTP_HOST = os.environ.get("RAY_SERVE_DEFAULT_HTTP_HOST", "127.0.0.1") + +#: HTTP Port +DEFAULT_HTTP_PORT = int(os.environ.get("RAY_SERVE_DEFAULT_HTTP_PORT", 8000)) + +#: Uvicorn timeout_keep_alive Config +DEFAULT_UVICORN_KEEP_ALIVE_TIMEOUT_S = 5 + +#: gRPC Port +DEFAULT_GRPC_PORT = int(os.environ.get("RAY_SERVE_DEFAULT_GRPC_PORT", 9000)) + +#: Default Serve application name +SERVE_DEFAULT_APP_NAME = "default" + +#: Max concurrency +ASYNC_CONCURRENCY = int(1e6) + +# How long to sleep between control loop cycles on the controller. +CONTROL_LOOP_INTERVAL_S = float(os.getenv("RAY_SERVE_CONTROL_LOOP_INTERVAL_S", 0.1)) +assert CONTROL_LOOP_INTERVAL_S >= 0, ( + f"Got unexpected value {CONTROL_LOOP_INTERVAL_S} for " + "RAY_SERVE_CONTROL_LOOP_INTERVAL_S environment variable. " + "RAY_SERVE_CONTROL_LOOP_INTERVAL_S cannot be negative." +) + +#: Max time to wait for HTTP proxy in `serve.start()`. +HTTP_PROXY_TIMEOUT = 60 + +#: Max retry count for allowing failures in replica constructor. +#: If no replicas at target version is running by the time we're at +#: max construtor retry count, deploy() is considered failed. 
+#: By default we set threshold as min(num_replicas * 3, this value) +MAX_DEPLOYMENT_CONSTRUCTOR_RETRY_COUNT = int( + os.environ.get("MAX_DEPLOYMENT_CONSTRUCTOR_RETRY_COUNT", "20") +) + +#: Default histogram buckets for latency tracker. +DEFAULT_LATENCY_BUCKET_MS = [ + 1, + 2, + 5, + 10, + 20, + 50, + 100, + 200, + 300, + 400, + 500, + 1000, + 2000, + # 5 seconds + 5000, + # 10 seconds + 10000, + # 60 seconds + 60000, + # 2min + 120000, + # 5 min + 300000, + # 10 min + 600000, +] + +#: Name of deployment health check method implemented by user. +HEALTH_CHECK_METHOD = "check_health" + +#: Name of deployment reconfiguration method implemented by user. +RECONFIGURE_METHOD = "reconfigure" + +SERVE_ROOT_URL_ENV_KEY = "RAY_SERVE_ROOT_URL" + +#: Limit the number of cached handles because each handle has long poll +#: overhead. See https://github.com/ray-project/ray/issues/18980 +MAX_CACHED_HANDLES = int(os.getenv("MAX_CACHED_HANDLES", 100)) +assert MAX_CACHED_HANDLES > 0, ( + f"Got unexpected value {MAX_CACHED_HANDLES} for " + "MAX_CACHED_HANDLES environment variable. " + "MAX_CACHED_HANDLES must be positive." +) + +#: Because ServeController will accept one long poll request per handle, its +#: concurrency needs to scale as O(num_handles) +CONTROLLER_MAX_CONCURRENCY = int(os.getenv("CONTROLLER_MAX_CONCURRENCY", 15_000)) +assert CONTROLLER_MAX_CONCURRENCY > 0, ( + f"Got unexpected value {CONTROLLER_MAX_CONCURRENCY} for " + "CONTROLLER_MAX_CONCURRENCY environment variable. " + "CONTROLLER_MAX_CONCURRENCY must be positive." 
+) + +DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_S = 20 +DEFAULT_GRACEFUL_SHUTDOWN_WAIT_LOOP_S = 2 +DEFAULT_HEALTH_CHECK_PERIOD_S = 10 +DEFAULT_HEALTH_CHECK_TIMEOUT_S = 30 +DEFAULT_MAX_ONGOING_REQUESTS = 5 +DEFAULT_TARGET_ONGOING_REQUESTS = 2 + +# HTTP Proxy health check configs +PROXY_HEALTH_CHECK_TIMEOUT_S = ( + float(os.environ.get("RAY_SERVE_PROXY_HEALTH_CHECK_TIMEOUT_S", "10")) or 10 +) +PROXY_HEALTH_CHECK_PERIOD_S = ( + float(os.environ.get("RAY_SERVE_PROXY_HEALTH_CHECK_PERIOD_S", "10")) or 10 +) +PROXY_READY_CHECK_TIMEOUT_S = ( + float(os.environ.get("RAY_SERVE_PROXY_READY_CHECK_TIMEOUT_S", "5")) or 5 +) + +# Number of times in a row that a HTTP proxy must fail the health check before +# being marked unhealthy. +PROXY_HEALTH_CHECK_UNHEALTHY_THRESHOLD = 3 + +# The minimum drain period for a HTTP proxy. +PROXY_MIN_DRAINING_PERIOD_S = ( + float(os.environ.get("RAY_SERVE_PROXY_MIN_DRAINING_PERIOD_S", "30")) or 30 +) +# The time in seconds that the http proxy state waits before +# rechecking whether the proxy actor is drained or not. +PROXY_DRAIN_CHECK_PERIOD_S = 5 + +#: Number of times in a row that a replica must fail the health check before +#: being marked unhealthy. +REPLICA_HEALTH_CHECK_UNHEALTHY_THRESHOLD = 3 + +# The time in seconds that the Serve client waits before rechecking deployment state +CLIENT_POLLING_INTERVAL_S: float = 1 + +# The time in seconds that the Serve client waits before checking if +# deployment has been created +CLIENT_CHECK_CREATION_POLLING_INTERVAL_S: float = 0.1 + +# Handle metric push interval. 
(This interval will affect the cold start time period) +HANDLE_METRIC_PUSH_INTERVAL_S = float( + os.environ.get("RAY_SERVE_HANDLE_METRIC_PUSH_INTERVAL_S", "10") +) + +# Timeout for GCS internal KV service +RAY_SERVE_KV_TIMEOUT_S = float(os.environ.get("RAY_SERVE_KV_TIMEOUT_S", "0")) or None + +# Timeout for GCS RPC request +RAY_GCS_RPC_TIMEOUT_S = 3.0 + +# Maximum duration to wait until broadcasting a long poll update if there are +# still replicas in the RECOVERING state. +RECOVERING_LONG_POLL_BROADCAST_TIMEOUT_S = 10.0 + +# Minimum duration to wait until broadcasting model IDs. +PUSH_MULTIPLEXED_MODEL_IDS_INTERVAL_S = 0.1 + + +# Deprecation message for V1 migrations. +MIGRATION_MESSAGE = ( + "See https://docs.ray.io/en/latest/serve/index.html for more information." +) + +# Environment variable name for to specify the encoding of the log messages +RAY_SERVE_LOG_ENCODING = os.environ.get("RAY_SERVE_LOG_ENCODING", "TEXT") + +# Jsonify the log messages. This constant is deprecated and will be removed in the +# future. Use RAY_SERVE_LOG_ENCODING or 'LoggingConfig' to enable json format. +RAY_SERVE_ENABLE_JSON_LOGGING = os.environ.get("RAY_SERVE_ENABLE_JSON_LOGGING") == "1" + +# Setting RAY_SERVE_LOG_TO_STDERR=0 will disable logging to the stdout and stderr. +# Also, redirect them to serve's log files. +RAY_SERVE_LOG_TO_STDERR = os.environ.get("RAY_SERVE_LOG_TO_STDERR", "1") == "1" + +# Logging format attributes +SERVE_LOG_REQUEST_ID = "request_id" +SERVE_LOG_ROUTE = "route" +SERVE_LOG_APPLICATION = "application" +SERVE_LOG_DEPLOYMENT = "deployment" +SERVE_LOG_REPLICA = "replica" +SERVE_LOG_COMPONENT = "component_name" +SERVE_LOG_COMPONENT_ID = "component_id" +SERVE_LOG_MESSAGE = "message" +# This is a reserved for python logging module attribute, it should not be changed. 
+SERVE_LOG_LEVEL_NAME = "levelname" +SERVE_LOG_TIME = "asctime" + +# Logging format with record key to format string dict +SERVE_LOG_RECORD_FORMAT = { + SERVE_LOG_REQUEST_ID: "%(request_id)s", + SERVE_LOG_APPLICATION: "%(application)s", + SERVE_LOG_MESSAGE: "-- %(message)s", + SERVE_LOG_LEVEL_NAME: "%(levelname)s", + SERVE_LOG_TIME: "%(asctime)s", +} + +# There are some attributes that we only use internally or don't provide values to the +# users. Adding to this set will remove them from structured logs. +SERVE_LOG_UNWANTED_ATTRS = { + "serve_access_log", + "task_id", + "job_id", +} + +SERVE_LOG_EXTRA_FIELDS = "ray_serve_extra_fields" + +# Serve HTTP request header key for routing requests. +SERVE_MULTIPLEXED_MODEL_ID = "serve_multiplexed_model_id" + +# Feature flag to turn on node locality routing for proxies. On by default. +RAY_SERVE_PROXY_PREFER_LOCAL_NODE_ROUTING = ( + os.environ.get("RAY_SERVE_PROXY_PREFER_LOCAL_NODE_ROUTING", "1") == "1" +) + +# Feature flag to turn on AZ locality routing for proxies. On by default. +RAY_SERVE_PROXY_PREFER_LOCAL_AZ_ROUTING = ( + os.environ.get("RAY_SERVE_PROXY_PREFER_LOCAL_AZ_ROUTING", "1") == "1" +) + +# Serve HTTP proxy callback import path. +RAY_SERVE_HTTP_PROXY_CALLBACK_IMPORT_PATH = os.environ.get( + "RAY_SERVE_HTTP_PROXY_CALLBACK_IMPORT_PATH", None +) +# Serve controller callback import path. +RAY_SERVE_CONTROLLER_CALLBACK_IMPORT_PATH = os.environ.get( + "RAY_SERVE_CONTROLLER_CALLBACK_IMPORT_PATH", None +) + +# How often autoscaling metrics are recorded on Serve replicas. +RAY_SERVE_REPLICA_AUTOSCALING_METRIC_RECORD_PERIOD_S = 0.5 + +# How often autoscaling metrics are recorded on Serve handles. +RAY_SERVE_HANDLE_AUTOSCALING_METRIC_RECORD_PERIOD_S = 0.5 + +# Serve multiplexed matching timeout. +# This is the timeout for the matching process of multiplexed requests. To avoid +# thundering herd problem, the timeout value will be randomed between this value +# and this value * 2. The unit is second. 
+# If the matching process takes longer than the timeout, the request will be +# fallen to the default routing strategy. +RAY_SERVE_MULTIPLEXED_MODEL_ID_MATCHING_TIMEOUT_S = float( + os.environ.get("RAY_SERVE_MULTIPLEXED_MODEL_ID_MATCHING_TIMEOUT_S", "1") +) + +# Enable memray in all Serve actors. +RAY_SERVE_ENABLE_MEMORY_PROFILING = ( + os.environ.get("RAY_SERVE_ENABLE_MEMORY_PROFILING", "0") == "1" +) + +# Enable cProfile in all Serve actors. +RAY_SERVE_ENABLE_CPU_PROFILING = ( + os.environ.get("RAY_SERVE_ENABLE_CPU_PROFILING", "0") == "1" +) + +# Max value allowed for max_replicas_per_node option. +# TODO(jjyao) the <= 100 limitation is an artificial one +# and is due to the fact that Ray core only supports resource +# precision up to 0.0001. +# This limitation should be lifted in the long term. +MAX_REPLICAS_PER_NODE_MAX_VALUE = 100 + +# Argument name for passing in the gRPC context into a replica. +GRPC_CONTEXT_ARG_NAME = "grpc_context" + +# Whether or not to forcefully kill replicas that fail health checks. +RAY_SERVE_FORCE_STOP_UNHEALTHY_REPLICAS = ( + os.environ.get("RAY_SERVE_FORCE_STOP_UNHEALTHY_REPLICAS", "0") == "1" +) + +# Initial deadline for queue length responses in the router. +RAY_SERVE_QUEUE_LENGTH_RESPONSE_DEADLINE_S = float( + os.environ.get("RAY_SERVE_QUEUE_LENGTH_RESPONSE_DEADLINE_S", 0.1) +) + +# Maximum deadline for queue length responses in the router (in backoff). +RAY_SERVE_MAX_QUEUE_LENGTH_RESPONSE_DEADLINE_S = float( + os.environ.get("RAY_SERVE_MAX_QUEUE_LENGTH_RESPONSE_DEADLINE_S", 1.0) +) + +# Feature flag for caching queue lengths for faster routing in each handle. +RAY_SERVE_ENABLE_QUEUE_LENGTH_CACHE = ( + os.environ.get("RAY_SERVE_ENABLE_QUEUE_LENGTH_CACHE", "1") == "1" +) + +# Feature flag for strictly enforcing max_ongoing_requests (replicas will reject +# requests). 
+RAY_SERVE_ENABLE_STRICT_MAX_ONGOING_REQUESTS = ( + os.environ.get("RAY_SERVE_ENABLE_STRICT_MAX_ONGOING_REQUESTS", "0") == "1" + # Strict enforcement path must be enabled for the queue length cache. + or RAY_SERVE_ENABLE_QUEUE_LENGTH_CACHE +) + +# Length of time to respect entries in the queue length cache when scheduling requests. +RAY_SERVE_QUEUE_LENGTH_CACHE_TIMEOUT_S = float( + os.environ.get("RAY_SERVE_QUEUE_LENGTH_CACHE_TIMEOUT_S", 10.0) +) + +# The default autoscaling policy to use if none is specified. +DEFAULT_AUTOSCALING_POLICY = "ray.serve.autoscaling_policy:default_autoscaling_policy" + +# Feature flag to enable collecting all queued and ongoing request +# metrics at handles instead of replicas. ON by default. +RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE = ( + os.environ.get("RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE", "1") == "1" +) + +RAY_SERVE_MIN_HANDLE_METRICS_TIMEOUT_S = float( + os.environ.get("RAY_SERVE_MIN_HANDLE_METRICS_TIMEOUT_S", 10.0) +) + +# Feature flag to always run a proxy on the head node even if it has no replicas. +RAY_SERVE_ALWAYS_RUN_PROXY_ON_HEAD_NODE = ( + os.environ.get("RAY_SERVE_ALWAYS_RUN_PROXY_ON_HEAD_NODE", "1") == "1" +) + + +# Default is 2GiB, the max for a signed int. +RAY_SERVE_GRPC_MAX_MESSAGE_SIZE = int( + os.environ.get("RAY_SERVE_GRPC_MAX_MESSAGE_SIZE", (2 * 1024 * 1024 * 1024) - 1) +) + +# Default options passed when constructing gRPC servers. +DEFAULT_GRPC_SERVER_OPTIONS = [ + ("grpc.max_send_message_length", RAY_SERVE_GRPC_MAX_MESSAGE_SIZE), + ("grpc.max_receive_message_length", RAY_SERVE_GRPC_MAX_MESSAGE_SIZE), +] + +# Feature flag to eagerly start replacement replicas. This means new +# replicas will start before waiting for old replicas to fully stop. +RAY_SERVE_EAGERLY_START_REPLACEMENT_REPLICAS = ( + os.environ.get("RAY_SERVE_EAGERLY_START_REPLACEMENT_REPLICAS", "1") == "1" +) + +# Timeout for gracefully shutting down metrics pusher, e.g. 
in routers or replicas +METRICS_PUSHER_GRACEFUL_SHUTDOWN_TIMEOUT_S = 10 + +# Feature flag to set `enable_task_events=True` on Serve-managed actors. +RAY_SERVE_ENABLE_TASK_EVENTS = ( + os.environ.get("RAY_SERVE_ENABLE_TASK_EVENTS", "0") == "1" +) + +# Use compact instead of spread scheduling strategy +RAY_SERVE_USE_COMPACT_SCHEDULING_STRATEGY = ( + os.environ.get("RAY_SERVE_USE_COMPACT_SCHEDULING_STRATEGY", "0") == "1" +) + +# Feature flag to always override local_testing_mode to True in serve.run. +# This is used for internal testing to avoid passing the flag to every invocation. +RAY_SERVE_FORCE_LOCAL_TESTING_MODE = ( + os.environ.get("RAY_SERVE_FORCE_LOCAL_TESTING_MODE", "0") == "1" +) + +# Run sync methods defined in the replica in a thread pool by default. +RAY_SERVE_RUN_SYNC_IN_THREADPOOL = ( + os.environ.get("RAY_SERVE_RUN_SYNC_IN_THREADPOOL", "0") == "1" +) + +RAY_SERVE_RUN_SYNC_IN_THREADPOOL_WARNING = ( + "Calling sync method '{method_name}' directly on the " + "asyncio loop. In a future version, sync methods will be run in a " + "threadpool by default. Ensure your sync methods are thread safe " + "or keep the existing behavior by making them `async def`. Opt " + "into the new behavior by setting " + "RAY_SERVE_RUN_SYNC_IN_THREADPOOL=1." +) + +# Feature flag to turn off GC optimizations in the proxy (in case there is a +# memory leak or negative performance impact). +RAY_SERVE_ENABLE_PROXY_GC_OPTIMIZATIONS = ( + os.environ.get("RAY_SERVE_ENABLE_PROXY_GC_OPTIMIZATIONS", "1") == "1" +) + +# Used for gc.set_threshold() when proxy GC optimizations are enabled. +RAY_SERVE_PROXY_GC_THRESHOLD = int( + os.environ.get("RAY_SERVE_PROXY_GC_THRESHOLD", "10000") +) + +# Interval at which cached metrics will be exported using the Ray metric API. +# Set to `0` to disable caching entirely. 
+RAY_SERVE_METRICS_EXPORT_INTERVAL_MS = int( + os.environ.get("RAY_SERVE_METRICS_EXPORT_INTERVAL_MS", "100") +) diff --git a/.venv/lib/python3.11/site-packages/ray/serve/_private/default_impl.py b/.venv/lib/python3.11/site-packages/ray/serve/_private/default_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..bf4aa6318ffa9f60990266247aabf031c99eef9c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/serve/_private/default_impl.py @@ -0,0 +1,213 @@ +from typing import Callable, Optional, Tuple + +import ray +from ray._raylet import GcsClient +from ray.serve._private.cluster_node_info_cache import ( + ClusterNodeInfoCache, + DefaultClusterNodeInfoCache, +) +from ray.serve._private.common import ( + CreatePlacementGroupRequest, + DeploymentHandleSource, + DeploymentID, + EndpointInfo, + RequestMetadata, + RequestProtocol, +) +from ray.serve._private.constants import ( + RAY_SERVE_ENABLE_QUEUE_LENGTH_CACHE, + RAY_SERVE_ENABLE_STRICT_MAX_ONGOING_REQUESTS, + RAY_SERVE_PROXY_PREFER_LOCAL_AZ_ROUTING, + RAY_SERVE_PROXY_PREFER_LOCAL_NODE_ROUTING, +) +from ray.serve._private.deployment_scheduler import ( + DefaultDeploymentScheduler, + DeploymentScheduler, +) +from ray.serve._private.grpc_util import gRPCGenericServer +from ray.serve._private.handle_options import DynamicHandleOptions, InitHandleOptions +from ray.serve._private.replica_scheduler import PowerOfTwoChoicesReplicaScheduler +from ray.serve._private.replica_scheduler.replica_wrapper import RunningReplica +from ray.serve._private.router import Router, SingletonThreadRouter +from ray.serve._private.utils import ( + generate_request_id, + get_current_actor_id, + get_head_node_id, + inside_ray_client_context, + resolve_deployment_response, +) +from ray.util.placement_group import PlacementGroup + +# NOTE: Please read carefully before changing! 
+# +# These methods are common extension points, therefore these should be +# changed as a Developer API, ie methods should not be renamed, have their +# API modified w/o substantial enough justification + + +def create_cluster_node_info_cache(gcs_client: GcsClient) -> ClusterNodeInfoCache: + return DefaultClusterNodeInfoCache(gcs_client) + + +CreatePlacementGroupFn = Callable[[CreatePlacementGroupRequest], PlacementGroup] + + +def _default_create_placement_group( + request: CreatePlacementGroupRequest, +) -> PlacementGroup: + return ray.util.placement_group( + request.bundles, + request.strategy, + _soft_target_node_id=request.target_node_id, + name=request.name, + lifetime="detached", + ) + + +def create_deployment_scheduler( + cluster_node_info_cache: ClusterNodeInfoCache, + head_node_id_override: Optional[str] = None, + create_placement_group_fn_override: Optional[CreatePlacementGroupFn] = None, +) -> DeploymentScheduler: + head_node_id = head_node_id_override or get_head_node_id() + return DefaultDeploymentScheduler( + cluster_node_info_cache, + head_node_id, + create_placement_group_fn=create_placement_group_fn_override + or _default_create_placement_group, + ) + + +def create_replica_impl(**kwargs): + from ray.serve._private.replica import Replica + + return Replica(**kwargs) + + +def create_dynamic_handle_options(**kwargs): + return DynamicHandleOptions(**kwargs) + + +def create_init_handle_options(**kwargs): + return InitHandleOptions.create(**kwargs) + + +def get_request_metadata(init_options, handle_options): + _request_context = ray.serve.context._get_serve_request_context() + + request_protocol = RequestProtocol.UNDEFINED + if init_options and init_options._source == DeploymentHandleSource.PROXY: + if _request_context.is_http_request: + request_protocol = RequestProtocol.HTTP + elif _request_context.grpc_context: + request_protocol = RequestProtocol.GRPC + + return RequestMetadata( + request_id=_request_context.request_id + if 
_request_context.request_id + else generate_request_id(), + internal_request_id=_request_context._internal_request_id + if _request_context._internal_request_id + else generate_request_id(), + call_method=handle_options.method_name, + route=_request_context.route, + app_name=_request_context.app_name, + multiplexed_model_id=handle_options.multiplexed_model_id, + is_streaming=handle_options.stream, + _request_protocol=request_protocol, + grpc_context=_request_context.grpc_context, + _by_reference=True, + ) + + +def _get_node_id_and_az() -> Tuple[str, Optional[str]]: + node_id = ray.get_runtime_context().get_node_id() + try: + cluster_node_info_cache = create_cluster_node_info_cache( + GcsClient(address=ray.get_runtime_context().gcs_address) + ) + cluster_node_info_cache.update() + az = cluster_node_info_cache.get_node_az(node_id) + except Exception: + az = None + + return node_id, az + + +# Interface definition for create_router. +CreateRouterCallable = Callable[[str, DeploymentID, InitHandleOptions], Router] + + +def create_router( + handle_id: str, + deployment_id: DeploymentID, + handle_options: InitHandleOptions, +) -> Router: + # NOTE(edoakes): this is lazy due to a nasty circular import that should be fixed. 
+ from ray.serve.context import _get_global_client + + actor_id = get_current_actor_id() + node_id, availability_zone = _get_node_id_and_az() + controller_handle = _get_global_client()._controller + is_inside_ray_client_context = inside_ray_client_context() + + replica_scheduler = PowerOfTwoChoicesReplicaScheduler( + deployment_id, + handle_options._source, + handle_options._prefer_local_routing, + RAY_SERVE_PROXY_PREFER_LOCAL_AZ_ROUTING, + node_id, + actor_id, + ray.get_runtime_context().current_actor + if ray.get_runtime_context().get_actor_id() + else None, + availability_zone, + # Streaming ObjectRefGenerators are not supported in Ray Client + use_replica_queue_len_cache=( + not is_inside_ray_client_context and RAY_SERVE_ENABLE_QUEUE_LENGTH_CACHE + ), + create_replica_wrapper_func=lambda r: RunningReplica(r), + ) + + return SingletonThreadRouter( + controller_handle=controller_handle, + deployment_id=deployment_id, + handle_id=handle_id, + self_actor_id=actor_id, + handle_source=handle_options._source, + replica_scheduler=replica_scheduler, + # Streaming ObjectRefGenerators are not supported in Ray Client + enable_strict_max_ongoing_requests=( + not is_inside_ray_client_context + and RAY_SERVE_ENABLE_STRICT_MAX_ONGOING_REQUESTS + ), + resolve_request_arg_func=resolve_deployment_response, + ) + + +def add_grpc_address(grpc_server: gRPCGenericServer, server_address: str): + """Helper function to add an address to a gRPC server.""" + grpc_server.add_insecure_port(server_address) + + +def get_proxy_handle(endpoint: DeploymentID, info: EndpointInfo): + # NOTE(zcin): needs to be lazy import due to a circular dependency. + # We should not be importing from application_state in context. 
+ from ray.serve.context import _get_global_client + + client = _get_global_client() + handle = client.get_handle(endpoint.name, endpoint.app_name, check_exists=True) + + # NOTE(zcin): It's possible that a handle is already initialized + # if a deployment with the same name and application name was + # deleted, then redeployed later. However this is not an issue since + # we initialize all handles with the same init options. + if not handle.is_initialized: + # NOTE(zcin): since the router is eagerly initialized here, the + # proxy will receive the replica set from the controller early. + handle._init( + _prefer_local_routing=RAY_SERVE_PROXY_PREFER_LOCAL_NODE_ROUTING, + _source=DeploymentHandleSource.PROXY, + ) + + return handle.options(stream=not info.app_is_cross_language) diff --git a/.venv/lib/python3.11/site-packages/ray/serve/_private/deployment_node.py b/.venv/lib/python3.11/site-packages/ray/serve/_private/deployment_node.py new file mode 100644 index 0000000000000000000000000000000000000000..375c49e830c4788d973c555e0ede3d5f1c883730 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/serve/_private/deployment_node.py @@ -0,0 +1,56 @@ +from typing import Any, Dict, List, Optional, Tuple + +from ray.dag import DAGNode +from ray.dag.format_utils import get_dag_node_str +from ray.serve.deployment import Deployment +from ray.serve.handle import DeploymentHandle + + +class DeploymentNode(DAGNode): + """Represents a deployment node in a DAG authored Ray DAG API.""" + + def __init__( + self, + # For serve structured deployment, deployment body can be import path + # to the class or function instead. + deployment: Deployment, + app_name: str, + deployment_init_args: Tuple[Any], + deployment_init_kwargs: Dict[str, Any], + ray_actor_options: Dict[str, Any], + other_args_to_resolve: Optional[Dict[str, Any]] = None, + ): + # Assign instance variables in base class constructor. 
+ super().__init__( + deployment_init_args, + deployment_init_kwargs, + ray_actor_options, + other_args_to_resolve=other_args_to_resolve, + ) + self._app_name = app_name + self._deployment = deployment + self._deployment_handle = DeploymentHandle( + self._deployment.name, self._app_name + ) + + def _copy_impl( + self, + new_args: List[Any], + new_kwargs: Dict[str, Any], + new_options: Dict[str, Any], + new_other_args_to_resolve: Dict[str, Any], + ): + return DeploymentNode( + self._deployment, + self._app_name, + new_args, + new_kwargs, + new_options, + other_args_to_resolve=new_other_args_to_resolve, + ) + + def __str__(self) -> str: + return get_dag_node_str(self, str(self._deployment)) + + def get_deployment_name(self): + return self._deployment.name diff --git a/.venv/lib/python3.11/site-packages/ray/serve/_private/deployment_state.py b/.venv/lib/python3.11/site-packages/ray/serve/_private/deployment_state.py new file mode 100644 index 0000000000000000000000000000000000000000..9a4b354adb3c8aaa476268d253cee0d9b2379974 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/serve/_private/deployment_state.py @@ -0,0 +1,2837 @@ +import json +import logging +import math +import os +import random +import time +import traceback +from collections import defaultdict +from copy import copy +from dataclasses import dataclass +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Set, Tuple + +import ray +from ray import ObjectRef, cloudpickle +from ray._private import ray_constants +from ray.actor import ActorHandle +from ray.exceptions import RayActorError, RayError, RayTaskError, RuntimeEnvSetupError +from ray.serve import metrics +from ray.serve._private import default_impl +from ray.serve._private.autoscaling_state import AutoscalingStateManager +from ray.serve._private.cluster_node_info_cache import ClusterNodeInfoCache +from ray.serve._private.common import ( + DeploymentID, + DeploymentStatus, + DeploymentStatusInfo, + 
DeploymentStatusInternalTrigger, + DeploymentStatusTrigger, + DeploymentTargetInfo, + Duration, + MultiplexedReplicaInfo, + ReplicaID, + ReplicaState, + RunningReplicaInfo, +) +from ray.serve._private.config import DeploymentConfig +from ray.serve._private.constants import ( + MAX_DEPLOYMENT_CONSTRUCTOR_RETRY_COUNT, + RAY_SERVE_EAGERLY_START_REPLACEMENT_REPLICAS, + RAY_SERVE_ENABLE_TASK_EVENTS, + RAY_SERVE_FORCE_STOP_UNHEALTHY_REPLICAS, + RAY_SERVE_USE_COMPACT_SCHEDULING_STRATEGY, + REPLICA_HEALTH_CHECK_UNHEALTHY_THRESHOLD, + SERVE_LOGGER_NAME, + SERVE_NAMESPACE, +) +from ray.serve._private.deployment_info import DeploymentInfo +from ray.serve._private.deployment_scheduler import ( + DeploymentDownscaleRequest, + DeploymentScheduler, + ReplicaSchedulingRequest, + ReplicaSchedulingRequestStatus, + SpreadDeploymentSchedulingPolicy, +) +from ray.serve._private.long_poll import LongPollHost, LongPollNamespace +from ray.serve._private.storage.kv_store import KVStoreBase +from ray.serve._private.usage import ServeUsageTag +from ray.serve._private.utils import ( + JavaActorHandleProxy, + check_obj_ref_ready_nowait, + get_capacity_adjusted_num_replicas, + get_random_string, + msgpack_deserialize, + msgpack_serialize, +) +from ray.serve._private.version import DeploymentVersion +from ray.serve.generated.serve_pb2 import DeploymentLanguage +from ray.serve.schema import ( + DeploymentDetails, + ReplicaDetails, + _deployment_info_to_schema, +) +from ray.util.placement_group import PlacementGroup + +logger = logging.getLogger(SERVE_LOGGER_NAME) + + +class ReplicaStartupStatus(Enum): + PENDING_ALLOCATION = 1 + PENDING_INITIALIZATION = 2 + SUCCEEDED = 3 + FAILED = 4 + + +class ReplicaHealthCheckResponse(Enum): + NONE = 1 + SUCCEEDED = 2 + APP_FAILURE = 3 + ACTOR_CRASHED = 4 + + +@dataclass +class DeploymentTargetState: + """The current goal state for a deployment. + + info: contains the information needed to initialize a replica. 
+ target_num_replicas: the number of replicas to run. This should already + be adjusted by the target_capacity. + version: the goal version of the deployment. + deleting: whether the deployment is being deleted. + """ + + info: Optional[DeploymentInfo] + target_num_replicas: int + version: Optional[DeploymentVersion] + deleting: bool + + @classmethod + def default(cls) -> "DeploymentTargetState": + return cls(None, -1, None, False) + + @classmethod + def create( + cls, + info: DeploymentInfo, + target_num_replicas: int, + *, + deleting: bool = False, + ) -> "DeploymentTargetState": + if deleting: + if target_num_replicas != 0: + raise ValueError( + "target_num_replicas must be 0 when setting target state " + f"to deleting. Got {target_num_replicas} instead." + ) + + version = DeploymentVersion( + info.version, + deployment_config=info.deployment_config, + ray_actor_options=info.replica_config.ray_actor_options, + placement_group_bundles=info.replica_config.placement_group_bundles, + placement_group_strategy=info.replica_config.placement_group_strategy, + max_replicas_per_node=info.replica_config.max_replicas_per_node, + ) + + return cls(info, target_num_replicas, version, deleting) + + def is_scaled_copy_of(self, other_target_state: "DeploymentTargetState") -> bool: + """Checks if this target state is a scaled copy of another target state. + + A target state is a scaled copy of another target state if all + configurable info is identical, other than target_num_replicas. + + Returns: True if this target state contains a non-None DeploymentInfo + and is a scaled copy of the other target state. 
+ """ + + if other_target_state.info is None: + return False + + return all( + [ + self.info.replica_config.ray_actor_options + == other_target_state.info.replica_config.ray_actor_options, + self.info.replica_config.placement_group_bundles + == other_target_state.info.replica_config.placement_group_bundles, + self.info.replica_config.placement_group_strategy + == other_target_state.info.replica_config.placement_group_strategy, + self.info.replica_config.max_replicas_per_node + == other_target_state.info.replica_config.max_replicas_per_node, + self.info.deployment_config.dict(exclude={"num_replicas"}) + == other_target_state.info.deployment_config.dict( + exclude={"num_replicas"} + ), + # TODO(zcin): version can be None, this is from an outdated codepath. + # We should remove outdated code, so version can never be None. + self.version, + self.version == other_target_state.version, + ] + ) + + +@dataclass +class DeploymentStateUpdateResult: + deleted: bool + any_replicas_recovering: bool + upscale: List[ReplicaSchedulingRequest] + downscale: Optional[DeploymentDownscaleRequest] + + +CHECKPOINT_KEY = "serve-deployment-state-checkpoint" +SLOW_STARTUP_WARNING_S = int(os.environ.get("SERVE_SLOW_STARTUP_WARNING_S", 30)) +SLOW_STARTUP_WARNING_PERIOD_S = int( + os.environ.get("SERVE_SLOW_STARTUP_WARNING_PERIOD_S", 30) +) + +ALL_REPLICA_STATES = list(ReplicaState) +_SCALING_LOG_ENABLED = os.environ.get("SERVE_ENABLE_SCALING_LOG", "0") != "0" + + +def print_verbose_scaling_log(): + assert _SCALING_LOG_ENABLED + + log_path = "/tmp/ray/session_latest/logs/monitor.log" + last_n_lines = 50 + autoscaler_log_last_n_lines = [] + if os.path.exists(log_path): + with open(log_path) as f: + autoscaler_log_last_n_lines = f.readlines()[-last_n_lines:] + + debug_info = { + "nodes": ray.nodes(), + "available_resources": ray.available_resources(), + "total_resources": ray.cluster_resources(), + "autoscaler_logs": autoscaler_log_last_n_lines, + } + logger.error(f"Scaling 
information\n{json.dumps(debug_info, indent=2)}") + + +class ActorReplicaWrapper: + """Wraps a Ray actor for a deployment replica. + + This is primarily defined so that we can mock out actual Ray operations + for unit testing. + + *All Ray API calls should be made here, not in DeploymentState.* + """ + + def __init__( + self, + replica_id: ReplicaID, + version: DeploymentVersion, + ): + self._replica_id = replica_id + self._deployment_id = replica_id.deployment_id + self._actor_name = replica_id.to_full_id_str() + + # Populated in either self.start() or self.recover() + self._allocated_obj_ref: ObjectRef = None + self._ready_obj_ref: ObjectRef = None + + self._actor_resources: Dict[str, float] = None + # If the replica is being started, this will be the true version + # If the replica is being recovered, this will be the target + # version, which may be inconsistent with the actual replica + # version. If so, the actual version will be updated later after + # recover() and check_ready() + self._version: DeploymentVersion = version + self._healthy: bool = True + self._health_check_ref: Optional[ObjectRef] = None + self._last_health_check_time: float = 0.0 + self._consecutive_health_check_failures = 0 + self._initialization_latency_s: Optional[float] = None + self._port: Optional[int] = None + + # Populated in `on_scheduled` or `recover`. + self._actor_handle: ActorHandle = None + self._placement_group: PlacementGroup = None + + # Populated after replica is allocated. + self._pid: int = None + self._actor_id: str = None + self._worker_id: str = None + self._node_id: str = None + self._node_ip: str = None + self._log_file_path: str = None + + # Populated in self.stop(). 
        # In-flight ref for the graceful-shutdown call, if one has been issued.
        self._graceful_shutdown_ref: ObjectRef = None

        # todo: will be confused with deployment_config.is_cross_language
        self._is_cross_language = False
        self._deployment_is_cross_language = False

    @property
    def replica_id(self) -> str:
        return self._replica_id

    @property
    def deployment_name(self) -> str:
        return self._deployment_id.name

    @property
    def app_name(self) -> str:
        return self._deployment_id.app_name

    @property
    def is_cross_language(self) -> bool:
        return self._is_cross_language

    @property
    def actor_handle(self) -> Optional[ActorHandle]:
        # Lazily fetch the handle by actor name; None if the actor is gone.
        if not self._actor_handle:
            try:
                self._actor_handle = ray.get_actor(
                    self._actor_name, namespace=SERVE_NAMESPACE
                )
            except ValueError:
                self._actor_handle = None

        if self._is_cross_language:
            # Java handles are wrapped in a proxy; return the raw handle.
            assert isinstance(self._actor_handle, JavaActorHandleProxy)
            return self._actor_handle.handle

        return self._actor_handle

    @property
    def placement_group_bundles(self) -> Optional[List[Dict[str, float]]]:
        if not self._placement_group:
            return None

        return self._placement_group.bundle_specs

    @property
    def version(self) -> DeploymentVersion:
        """Replica version. This can be incorrect during state recovery.

        If the controller crashes and the deployment state is being
        recovered, this will temporarily be the deployment-wide target
        version, which may be inconsistent with the actual version
        running on the replica actor. If so, the actual version will be
        updated when the replica transitions from RECOVERING -> RUNNING
        """
        return self._version

    @property
    def deployment_config(self) -> DeploymentConfig:
        """Deployment config. This can return an incorrect config during state recovery.

        If the controller hasn't yet recovered the up-to-date version
        from the running replica actor, this property will return the
        current target config for the deployment.
        """
        return self._version.deployment_config

    @property
    def max_ongoing_requests(self) -> int:
        return self.deployment_config.max_ongoing_requests

    @property
    def max_queued_requests(self) -> int:
        return self.deployment_config.max_queued_requests

    @property
    def graceful_shutdown_timeout_s(self) -> float:
        return self.deployment_config.graceful_shutdown_timeout_s

    @property
    def health_check_period_s(self) -> float:
        return self.deployment_config.health_check_period_s

    @property
    def health_check_timeout_s(self) -> float:
        return self.deployment_config.health_check_timeout_s

    @property
    def pid(self) -> Optional[int]:
        """Returns the pid of the actor, None if not started."""
        return self._pid

    @property
    def actor_id(self) -> Optional[str]:
        """Returns the actor id, None if not started."""
        return self._actor_id

    @property
    def worker_id(self) -> Optional[str]:
        """Returns the worker id, None if not started."""
        return self._worker_id

    @property
    def node_id(self) -> Optional[str]:
        """Returns the node id of the actor, None if not placed."""
        return self._node_id

    @property
    def node_ip(self) -> Optional[str]:
        """Returns the node ip of the actor, None if not placed."""
        return self._node_ip

    @property
    def log_file_path(self) -> Optional[str]:
        """Returns the relative log file path of the actor, None if not placed."""
        return self._log_file_path

    @property
    def initialization_latency_s(self) -> Optional[float]:
        """Returns the initialization latency for the replica actor.

        Returns None if the replica hasn't started yet.

        Note: this value isn't checkpointed, so if the controller restarts,
        this value goes back to None.
        """

        return self._initialization_latency_s

    def start(self, deployment_info: DeploymentInfo) -> ReplicaSchedulingRequest:
        """Start the current DeploymentReplica instance.

        The replica will be in the STARTING and PENDING_ALLOCATION states
        until the deployment scheduler schedules the underlying actor.
        """
        self._actor_resources = deployment_info.replica_config.resource_dict
        # it is currently not possible to create a placement group
        # with no resources (https://github.com/ray-project/ray/issues/20401)
        self._deployment_is_cross_language = (
            deployment_info.deployment_config.is_cross_language
        )

        logger.info(
            f"Starting {self.replica_id}.",
            extra={"log_to_stderr": False},
        )

        actor_def = deployment_info.actor_def
        if (
            deployment_info.deployment_config.deployment_language
            == DeploymentLanguage.PYTHON
        ):
            if deployment_info.replica_config.serialized_init_args is None:
                serialized_init_args = cloudpickle.dumps(())
            else:
                # Cross-language init args arrive msgpack-encoded; re-pickle
                # them for the Python replica. Otherwise pass through as-is.
                serialized_init_args = (
                    cloudpickle.dumps(
                        msgpack_deserialize(
                            deployment_info.replica_config.serialized_init_args
                        )
                    )
                    if self._deployment_is_cross_language
                    else deployment_info.replica_config.serialized_init_args
                )
            init_args = (
                self.replica_id,
                cloudpickle.dumps(deployment_info.replica_config.deployment_def)
                if self._deployment_is_cross_language
                else deployment_info.replica_config.serialized_deployment_def,
                serialized_init_args,
                deployment_info.replica_config.serialized_init_kwargs
                if deployment_info.replica_config.serialized_init_kwargs
                else cloudpickle.dumps({}),
                deployment_info.deployment_config.to_proto_bytes(),
                self._version,
            )
        # TODO(simon): unify the constructor arguments across language
        elif (
            deployment_info.deployment_config.deployment_language
            == DeploymentLanguage.JAVA
        ):
            self._is_cross_language = True
            actor_def = ray.cross_language.java_actor_class(
                "io.ray.serve.replica.RayServeWrappedReplica"
            )
            init_args = (
                # String deploymentName,
                self.deployment_name,
                # String replicaID,
                self.replica_id.to_full_id_str(),
                # String deploymentDef
                deployment_info.replica_config.deployment_def_name,
                # byte[] initArgsbytes
                msgpack_serialize(
                    cloudpickle.loads(
                        deployment_info.replica_config.serialized_init_args
                    )
                )
                if self._deployment_is_cross_language
                else deployment_info.replica_config.serialized_init_args,
                # byte[] deploymentConfigBytes,
                deployment_info.deployment_config.to_proto_bytes(),
                # byte[] deploymentVersionBytes,
                self._version.to_proto().SerializeToString(),
                # String controllerName
                # String appName
                self.app_name,
            )

        actor_options = {
            "name": self._actor_name,
            "namespace": SERVE_NAMESPACE,
            "lifetime": "detached",
            "enable_task_events": RAY_SERVE_ENABLE_TASK_EVENTS,
        }
        actor_options.update(deployment_info.replica_config.ray_actor_options)

        # A replica's default `max_concurrency` value can prevent it from
        # respecting the configured `max_ongoing_requests`. To avoid this
        # unintentional behavior, use `max_ongoing_requests` to override
        # the Actor's `max_concurrency` if it is larger.
        if (
            deployment_info.deployment_config.max_ongoing_requests
            > ray_constants.DEFAULT_MAX_CONCURRENCY_ASYNC
        ):
            actor_options[
                "max_concurrency"
            ] = deployment_info.deployment_config.max_ongoing_requests

        return ReplicaSchedulingRequest(
            replica_id=self.replica_id,
            actor_def=actor_def,
            actor_resources=self._actor_resources,
            actor_options=actor_options,
            actor_init_args=init_args,
            placement_group_bundles=(
                deployment_info.replica_config.placement_group_bundles
            ),
            placement_group_strategy=(
                deployment_info.replica_config.placement_group_strategy
            ),
            max_replicas_per_node=(
                deployment_info.replica_config.max_replicas_per_node
            ),
            on_scheduled=self.on_scheduled,
        )

    def on_scheduled(
        self,
        actor_handle: ActorHandle,
        placement_group: Optional[PlacementGroup] = None,
    ):
        """Callback invoked by the scheduler once the actor is placed.

        Kicks off the allocation and initialization checks whose object
        refs are polled later by `check_ready`.
        """
        self._actor_handle = actor_handle
        self._placement_group = placement_group

        # Perform auto method name translation for java handles.
        # See https://github.com/ray-project/ray/issues/21474
        deployment_config = copy(self._version.deployment_config)
        deployment_config.user_config = self._format_user_config(
            deployment_config.user_config
        )
        if self._is_cross_language:
            self._actor_handle = JavaActorHandleProxy(self._actor_handle)
            self._allocated_obj_ref = self._actor_handle.is_allocated.remote()
            self._ready_obj_ref = self._actor_handle.is_initialized.remote(
                deployment_config.to_proto_bytes()
            )
        else:
            self._allocated_obj_ref = self._actor_handle.is_allocated.remote()
            replica_ready_check_func = self._actor_handle.initialize_and_get_metadata
            self._ready_obj_ref = replica_ready_check_func.remote(
                deployment_config,
                # Ensure that `is_allocated` will execute
                # before `initialize_and_get_metadata`,
                # because `initialize_and_get_metadata` runs
                # user code that could block the replica
                # asyncio loop. If that happens before `is_allocated` is executed,
                # the `is_allocated` call won't be able to run.
                self._allocated_obj_ref,
            )

    def _format_user_config(self, user_config: Any):
        # For cross-language deployments, convert user_config between
        # msgpack bytes (Java side) and Python objects (Python side).
        temp = copy(user_config)
        if user_config is not None and self._deployment_is_cross_language:
            if self._is_cross_language:
                temp = msgpack_serialize(temp)
            else:
                temp = msgpack_deserialize(temp)
        return temp

    def reconfigure(self, version: DeploymentVersion) -> bool:
        """
        Update replica version. Also, updates the deployment config on the actor
        behind this DeploymentReplica instance if necessary.

        Returns: whether the actor is being updated.
        """
        updating = False
        if self._version.requires_actor_reconfigure(version):
            # Call into replica actor reconfigure() with updated user config and
            # graceful_shutdown_wait_loop_s
            updating = True
            deployment_config = copy(version.deployment_config)
            deployment_config.user_config = self._format_user_config(
                deployment_config.user_config
            )
            self._ready_obj_ref = self._actor_handle.reconfigure.remote(
                deployment_config
            )

        self._version = version
        return updating

    def recover(self) -> bool:
        """Recover replica version from a live replica actor.

        When controller dies, the deployment state loses the info on the version that's
        running on each individual replica actor, so as part of the recovery process, we
        need to recover the version that is running on the replica actor.

        Also confirm that actor is allocated and initialized before marking as running.

        Returns: False if the replica actor is no longer alive; the
        actor could have been killed in the time between when the
        controller fetching all Serve actors in the cluster and when
        the controller tries to recover it. Otherwise, return True.
        """
        logger.info(f"Recovering {self.replica_id}.")
        try:
            self._actor_handle = ray.get_actor(
                self._actor_name, namespace=SERVE_NAMESPACE
            )
        except ValueError:
            logger.warning(
                f"Failed to get handle to replica {self._actor_name} "
                "during controller recovery. Marking as dead."
            )
            return False

        try:
            self._placement_group = ray.util.get_placement_group(
                self._actor_name,
            )
        except ValueError:
            # ValueError is raised if the placement group does not exist.
            self._placement_group = None

        # Re-fetch initialization proof
        self._allocated_obj_ref = self._actor_handle.is_allocated.remote()

        # Running actor handle already has all info needed, thus successful
        # starting simply means retrieving replica version hash from actor
        if self._is_cross_language:
            self._ready_obj_ref = self._actor_handle.check_health.remote()
        else:
            self._ready_obj_ref = (
                self._actor_handle.initialize_and_get_metadata.remote()
            )

        return True

    def check_ready(self) -> Tuple[ReplicaStartupStatus, Optional[str]]:
        """
        Check if current replica has started by making ray API calls on
        relevant actor / object ref.

        Replica initialization calls __init__(), reconfigure(), and check_health().

        Returns:
            state (ReplicaStartupStatus):
                PENDING_ALLOCATION: replica is waiting for a worker to start
                PENDING_INITIALIZATION: replica initialization hasn't finished.
                FAILED: replica initialization failed.
                SUCCEEDED: replica initialization succeeded.
            error_msg:
                None: for PENDING_ALLOCATION, PENDING_INITIALIZATION or SUCCEEDED states
                str: for FAILED state
        """

        # Check whether the replica has been allocated.
        if self._allocated_obj_ref is None or not check_obj_ref_ready_nowait(
            self._allocated_obj_ref
        ):
            return ReplicaStartupStatus.PENDING_ALLOCATION, None

        if not self._is_cross_language:
            try:
                # The allocation ref resolves to the replica's runtime
                # placement metadata.
                (
                    self._pid,
                    self._actor_id,
                    self._worker_id,
                    self._node_id,
                    self._node_ip,
                    self._log_file_path,
                ) = ray.get(self._allocated_obj_ref)
            except RayTaskError as e:
                logger.exception(
                    f"Exception in {self._replica_id}, the replica will be stopped."
                )
                return ReplicaStartupStatus.FAILED, str(e.as_instanceof_cause())
            except RuntimeEnvSetupError as e:
                msg = f"Exception when allocating {self._replica_id}: {str(e)}"
                logger.exception(msg)
                return ReplicaStartupStatus.FAILED, msg
            except Exception:
                msg = (
                    f"Exception when allocating {self._replica_id}:\n"
                    + traceback.format_exc()
                )
                logger.exception(msg)
                return ReplicaStartupStatus.FAILED, msg

        # Check whether replica initialization has completed.
        replica_ready = check_obj_ref_ready_nowait(self._ready_obj_ref)
        # In case of deployment constructor failure, ray.get will help to
        # surface exception to each update() cycle.
        if not replica_ready:
            return ReplicaStartupStatus.PENDING_INITIALIZATION, None
        else:
            try:
                # TODO(simon): fully implement reconfigure for Java replicas.
                if self._is_cross_language:
                    return ReplicaStartupStatus.SUCCEEDED, None

                # todo: The replica's userconfig which the java client created
                # is different from the controller's userconfig
                if not self._deployment_is_cross_language:
                    # This should only update version if the replica is being recovered.
                    # If this is checking on a replica that is newly started, this
                    # should return a version that is identical to what's already stored
                    (
                        _,
                        self._version,
                        self._initialization_latency_s,
                        self._port,
                    ) = ray.get(self._ready_obj_ref)
            except RayTaskError as e:
                logger.exception(
                    f"Exception in {self._replica_id}, the replica will be stopped."
                )
                # NOTE(zcin): we should use str(e) instead of traceback.format_exc()
                # here because the full details of the error is not displayed properly
                # with traceback.format_exc().
                return ReplicaStartupStatus.FAILED, str(e.as_instanceof_cause())
            except Exception as e:
                logger.exception(
                    f"Exception in {self._replica_id}, the replica will be stopped."
                )
                return ReplicaStartupStatus.FAILED, repr(e)

        return ReplicaStartupStatus.SUCCEEDED, None

    @property
    def actor_resources(self) -> Optional[Dict[str, float]]:
        return self._actor_resources

    @property
    def available_resources(self) -> Dict[str, float]:
        return ray.available_resources()

    def graceful_stop(self) -> Duration:
        """Request the actor to exit gracefully.

        Returns the timeout after which to kill the actor.
        """
        try:
            handle = ray.get_actor(self._actor_name, namespace=SERVE_NAMESPACE)
            if self._is_cross_language:
                handle = JavaActorHandleProxy(handle)
            self._graceful_shutdown_ref = handle.perform_graceful_shutdown.remote()
        except ValueError:
            # ValueError thrown from ray.get_actor means actor has already been deleted.
            pass

        return self.graceful_shutdown_timeout_s

    def check_stopped(self) -> bool:
        """Check if the actor has exited."""
        try:
            handle = ray.get_actor(self._actor_name, namespace=SERVE_NAMESPACE)
            stopped = check_obj_ref_ready_nowait(self._graceful_shutdown_ref)
            if stopped:
                try:
                    # Surface (and log) any error raised during graceful shutdown.
                    ray.get(self._graceful_shutdown_ref)
                except Exception:
                    logger.exception(
                        "Exception when trying to gracefully shutdown replica:\n"
                        + traceback.format_exc()
                    )

                ray.kill(handle, no_restart=True)
        except ValueError:
            # ValueError thrown from ray.get_actor means actor has already been deleted.
            stopped = True
        finally:
            # Remove the placement group both if the actor has already been deleted or
            # it was just killed above.
            if stopped and self._placement_group is not None:
                ray.util.remove_placement_group(self._placement_group)

        return stopped

    def _check_active_health_check(self) -> ReplicaHealthCheckResponse:
        """Check the active health check (if any).

        self._health_check_ref will be reset to `None` when the active health
        check is deemed to have succeeded or failed. This method *does not*
        start a new health check, that's up to the caller.

        Returns:
            - NONE if there's no active health check, or it hasn't returned
              yet and the timeout is not up.
            - SUCCEEDED if the active health check succeeded.
            - APP_FAILURE if the active health check failed (or didn't return
              before the timeout).
            - ACTOR_CRASHED if the underlying actor crashed.
        """
        if self._health_check_ref is None:
            # There is no outstanding health check.
            response = ReplicaHealthCheckResponse.NONE
        elif check_obj_ref_ready_nowait(self._health_check_ref):
            # Object ref is ready, ray.get it to check for exceptions.
            try:
                ray.get(self._health_check_ref)
                # Health check succeeded without exception.
                response = ReplicaHealthCheckResponse.SUCCEEDED
            except RayActorError:
                # Health check failed due to actor crashing.
                response = ReplicaHealthCheckResponse.ACTOR_CRASHED
            except RayError as e:
                # Health check failed due to application-level exception.
                logger.warning(f"Health check for {self._replica_id} failed: {e}")
                response = ReplicaHealthCheckResponse.APP_FAILURE
        elif time.time() - self._last_health_check_time > self.health_check_timeout_s:
            # Health check hasn't returned and the timeout is up, consider it failed.
            logger.warning(
                "Didn't receive health check response for replica "
                f"{self._replica_id} after "
                f"{self.health_check_timeout_s}s, marking it unhealthy."
            )
            response = ReplicaHealthCheckResponse.APP_FAILURE
        else:
            # Health check hasn't returned and the timeout isn't up yet.
            response = ReplicaHealthCheckResponse.NONE

        if response is not ReplicaHealthCheckResponse.NONE:
            self._health_check_ref = None

        return response

    def _should_start_new_health_check(self) -> bool:
        """Determines if a new health check should be kicked off.

        A health check will be started if:
            1) There is not already an active health check.
            2) It has been more than health_check_period_s since the
               previous health check was *started*.

        This assumes that self._health_check_ref is reset to `None` when an
        active health check succeeds or fails (due to returning or timeout).
        """
        if self._health_check_ref is not None:
            # There's already an active health check.
            return False

        # If there's no active health check, kick off another and reset
        # the timer if it's been long enough since the last health
        # check. Add some randomness to avoid synchronizing across all
        # replicas.
        time_since_last = time.time() - self._last_health_check_time
        randomized_period = self.health_check_period_s * random.uniform(0.9, 1.1)
        return time_since_last > randomized_period

    def check_health(self) -> bool:
        """Check if the actor is healthy.

        self._healthy should *only* be modified in this method.

        This is responsible for:
            1) Checking the outstanding health check (if any).
            2) Determining the replica health based on the health check results.
            3) Kicking off a new health check if needed.
        """
        response: ReplicaHealthCheckResponse = self._check_active_health_check()
        if response is ReplicaHealthCheckResponse.NONE:
            # No info; don't update replica health.
            pass
        elif response is ReplicaHealthCheckResponse.SUCCEEDED:
            # Health check succeeded. Reset the consecutive failure counter
            # and mark the replica healthy.
            if self._consecutive_health_check_failures > 0:
                logger.info(
                    f"{self._replica_id} passed the health check after "
                    f"{self._consecutive_health_check_failures} consecutive failures."
                )
            self._consecutive_health_check_failures = 0
            self._healthy = True
        elif response is ReplicaHealthCheckResponse.APP_FAILURE:
            # Health check failed. If it has failed more than N times in a row,
            # mark the replica unhealthy.
            self._consecutive_health_check_failures += 1
            if (
                self._consecutive_health_check_failures
                >= REPLICA_HEALTH_CHECK_UNHEALTHY_THRESHOLD
            ):
                logger.warning(
                    f"Replica {self._replica_id} failed the health "
                    f"check {self._consecutive_health_check_failures} "
                    "times in a row, marking it unhealthy."
                )
                self._healthy = False
        elif response is ReplicaHealthCheckResponse.ACTOR_CRASHED:
            # Actor crashed, mark the replica unhealthy immediately.
            logger.warning(
                f"Actor for {self._replica_id} crashed, marking "
                "it unhealthy immediately."
            )
            self._healthy = False
        else:
            assert False, f"Unknown response type: {response}."

        if self._should_start_new_health_check():
            self._last_health_check_time = time.time()
            self._health_check_ref = self._actor_handle.check_health.remote()

        return self._healthy

    def force_stop(self):
        """Force the actor to exit without shutting down gracefully."""
        try:
            ray.kill(ray.get_actor(self._actor_name, namespace=SERVE_NAMESPACE))
        except ValueError:
            # Actor already deleted; nothing to kill.
            pass


class DeploymentReplica:
    """Manages state transitions for deployment replicas.

    This is basically a checkpointable lightweight state machine.
    """

    def __init__(
        self,
        replica_id: ReplicaID,
        version: DeploymentVersion,
    ):
        self._replica_id = replica_id
        # Wrapper around the underlying Ray actor for this replica.
        self._actor = ActorReplicaWrapper(replica_id, version)
        self._start_time = None
        self._actor_details = ReplicaDetails(
            actor_name=replica_id.to_full_id_str(),
            replica_id=self._replica_id.unique_id,
            state=ReplicaState.STARTING,
            start_time_s=0,
        )
        self._multiplexed_model_ids: List = []

    def get_running_replica_info(
        self, cluster_node_info_cache: ClusterNodeInfoCache
    ) -> RunningReplicaInfo:
        """Build the info struct broadcast to routers for a running replica."""
        return RunningReplicaInfo(
            replica_id=self._replica_id,
            node_id=self.actor_node_id,
            node_ip=self._actor.node_ip,
            availability_zone=cluster_node_info_cache.get_node_az(self.actor_node_id),
            actor_handle=self._actor.actor_handle,
            max_ongoing_requests=self._actor.max_ongoing_requests,
            is_cross_language=self._actor.is_cross_language,
            multiplexed_model_ids=self.multiplexed_model_ids,
            port=self._actor._port,
        )

    def record_multiplexed_model_ids(self, multiplexed_model_ids: List[str]):
        """Record the multiplexed model ids for this replica."""
        self._multiplexed_model_ids = multiplexed_model_ids

    @property
    def multiplexed_model_ids(self) -> List[str]:
        return self._multiplexed_model_ids

    @property
    def actor_details(self) -> ReplicaDetails:
        return self._actor_details

    @property
    def replica_id(self) -> ReplicaID:
        return self._replica_id

    @property
    def deployment_name(self) -> str:
        return self._replica_id.deployment_id.name

    @property
    def app_name(self) -> str:
        return self._replica_id.deployment_id.app_name

    @property
    def version(self):
        return self._actor.version

    @property
    def actor_id(self) -> str:
        return self._actor.actor_id

    @property
    def actor_handle(self) -> ActorHandle:
        return self._actor.actor_handle

    @property
    def actor_node_id(self) -> Optional[str]:
        """Returns the node id of the actor, None if not placed."""
        return self._actor.node_id

    @property
    def actor_pid(self) -> Optional[int]:
        """Returns the pid of the actor, None if not placed."""
        return self._actor.pid

    @property
    def initialization_latency_s(self) -> Optional[float]:
        """Returns how long the replica took to initialize."""

        return self._actor.initialization_latency_s

    def start(self, deployment_info: DeploymentInfo) -> ReplicaSchedulingRequest:
        """
        Start a new actor for current DeploymentReplica instance.
        """
        replica_scheduling_request = self._actor.start(deployment_info)
        self._start_time = time.time()
        self.update_actor_details(start_time_s=self._start_time)
        return replica_scheduling_request

    def reconfigure(self, version: DeploymentVersion) -> bool:
        """
        Update replica version. Also, updates the deployment config on the actor
        behind this DeploymentReplica instance if necessary.

        Returns: whether the actor is being updated.
        """
        return self._actor.reconfigure(version)

    def recover(self) -> bool:
        """
        Recover states in DeploymentReplica instance by fetching running actor
        status

        Returns: False if the replica is no longer alive at the time
            when this method is called.
        """
        # If replica is no longer alive
        if not self._actor.recover():
            return False

        self._start_time = time.time()
        self.update_actor_details(start_time_s=self._start_time)
        return True

    def check_started(
        self,
    ) -> Tuple[ReplicaStartupStatus, Optional[str], Optional[float]]:
        """Check if the replica has started. If so, transition to RUNNING.

        Should handle the case where the replica has already stopped.

        Returns:
            status: Most recent state of replica by
                querying actor obj ref
        """
        is_ready = self._actor.check_ready()
        # Refresh the externally-visible details from whatever the actor
        # reported during allocation.
        self.update_actor_details(
            pid=self._actor.pid,
            node_id=self._actor.node_id,
            node_ip=self._actor.node_ip,
            actor_id=self._actor.actor_id,
            worker_id=self._actor.worker_id,
            log_file_path=self._actor.log_file_path,
        )

        return is_ready

    def stop(self, graceful: bool = True) -> None:
        """Stop the replica.

        Should handle the case where the replica is already stopped.
        """
        state = self._actor_details.state
        logger.info(
            f"Stopping {self.replica_id} (currently {state}).",
            extra={"log_to_stderr": False},
        )
        timeout_s = self._actor.graceful_stop()
        if not graceful:
            # Non-graceful stop: skip the grace period entirely.
            timeout_s = 0
        self._shutdown_deadline = time.time() + timeout_s

    def check_stopped(self) -> bool:
        """Check if the replica has finished stopping."""
        if self._actor.check_stopped():
            return True

        timeout_passed = time.time() >= self._shutdown_deadline
        if timeout_passed:
            # Graceful period passed, kill it forcefully.
            # This will be called repeatedly until the replica shuts down.
            logger.info(
                f"{self.replica_id} did not shut down after grace "
                "period, force-killing it. "
            )

            self._actor.force_stop()
        return False

    def check_health(self) -> bool:
        """Check if the replica is healthy.

        Returns `True` if the replica is healthy, else `False`.
        """
        return self._actor.check_health()

    def update_state(self, state: ReplicaState) -> None:
        """Updates state in actor details."""
        self.update_actor_details(state=state)

    def update_actor_details(self, **kwargs) -> None:
        # Rebuild the (immutable) ReplicaDetails with the given overrides.
        details_kwargs = self._actor_details.dict()
        details_kwargs.update(kwargs)
        self._actor_details = ReplicaDetails(**details_kwargs)

    def resource_requirements(self) -> Tuple[str, str]:
        """Returns required and currently available resources.

        Only resources with nonzero requirements will be included in the
        required dict and only resources in the required dict will be
        included in the available dict (filtered for relevance).
        """
        if self._actor.actor_resources is None:
            return "UNKNOWN", "UNKNOWN"

        if self._actor.placement_group_bundles is not None:
            required = self._actor.placement_group_bundles
        else:
            required = {
                k: v
                for k, v in self._actor.actor_resources.items()
                if v is not None and v > 0
            }

        available = {
            k: v for k, v in self._actor.available_resources.items() if k in required
        }

        # Use json.dumps() instead of str() here to avoid double-quoting keys
        # when dumping these objects. See
        # https://github.com/ray-project/ray/issues/26210 for the issue.
        return json.dumps(required), json.dumps(available)


class ReplicaStateContainer:
    """Container for mapping ReplicaStates to lists of DeploymentReplicas."""

    def __init__(self):
        self._replicas: Dict[ReplicaState, List[DeploymentReplica]] = defaultdict(list)

    def add(self, state: ReplicaState, replica: DeploymentReplica):
        """Add the provided replica under the provided state.

        Args:
            state: state to add the replica under.
            replica: replica to add.
        """
        assert isinstance(state, ReplicaState), f"Type: {type(state)}"
        replica.update_state(state)
        self._replicas[state].append(replica)

    def get(
        self, states: Optional[List[ReplicaState]] = None
    ) -> List[DeploymentReplica]:
        """Get all replicas of the given states.

        This does not remove them from the container. Replicas are returned
        in order of state as passed in.

        Args:
            states: states to consider. If not specified, all replicas
                are considered.
        """
        if states is None:
            states = ALL_REPLICA_STATES

        assert isinstance(states, list)

        return sum((self._replicas[state] for state in states), [])

    def pop(
        self,
        exclude_version: Optional[DeploymentVersion] = None,
        states: Optional[List[ReplicaState]] = None,
        max_replicas: Optional[int] = math.inf,
    ) -> List[DeploymentReplica]:
        """Get and remove all replicas of the given states.

        This removes the replicas from the container. Replicas are returned
        in order of state as passed in.

        Args:
            exclude_version: if specified, replicas of the
                provided version will *not* be removed.
            states: states to consider. If not specified, all replicas
                are considered.
            max_replicas: max number of replicas to return. If not
                specified, will pop all replicas matching the criteria.
        """
        # NOTE(review): default is float("inf") despite the Optional[int]
        # annotation — the `==` comparison below never matches inf, which is
        # what makes "pop all" work.
        if states is None:
            states = ALL_REPLICA_STATES

        assert exclude_version is None or isinstance(exclude_version, DeploymentVersion)
        assert isinstance(states, list)

        replicas = []
        for state in states:
            popped = []
            remaining = []

            for replica in self._replicas[state]:
                if len(replicas) + len(popped) == max_replicas:
                    remaining.append(replica)
                elif exclude_version is not None and replica.version == exclude_version:
                    remaining.append(replica)
                else:
                    popped.append(replica)

            self._replicas[state] = remaining
            replicas.extend(popped)

        return replicas

    def count(
        self,
        exclude_version: Optional[DeploymentVersion] = None,
        version: Optional[DeploymentVersion] = None,
        states: Optional[List[ReplicaState]] = None,
    ):
        """Get the total count of replicas of the given states.

        Args:
            exclude_version: version to exclude. If not
                specified, all versions are considered.
            version: version to filter to. If not specified,
                all versions are considered.
            states: states to consider. If not specified, all replicas
                are considered.

        Raises:
            ValueError: if both `version` and `exclude_version` are provided.
        """
        if states is None:
            states = ALL_REPLICA_STATES
        assert isinstance(states, list)
        assert exclude_version is None or isinstance(exclude_version, DeploymentVersion)
        assert version is None or isinstance(version, DeploymentVersion)
        if exclude_version is None and version is None:
            return sum(len(self._replicas[state]) for state in states)
        elif exclude_version is None and version is not None:
            return sum(
                len(list(filter(lambda r: r.version == version, self._replicas[state])))
                for state in states
            )
        elif exclude_version is not None and version is None:
            return sum(
                len(
                    list(
                        filter(
                            lambda r: r.version != exclude_version,
                            self._replicas[state],
                        )
                    )
                )
                for state in states
            )
        else:
            raise ValueError(
                "Only one of `version` or `exclude_version` may be provided."
            )

    def __str__(self):
        return str(self._replicas)

    def __repr__(self):
        return repr(self._replicas)


class DeploymentState:
    """Manages the target state and replicas for a single deployment."""

    # Whether unhealthy replicas are force-killed instead of gracefully stopped.
    FORCE_STOP_UNHEALTHY_REPLICAS = RAY_SERVE_FORCE_STOP_UNHEALTHY_REPLICAS

    def __init__(
        self,
        id: DeploymentID,
        long_poll_host: LongPollHost,
        deployment_scheduler: DeploymentScheduler,
        cluster_node_info_cache: ClusterNodeInfoCache,
        autoscaling_state_manager: AutoscalingStateManager,
        _save_checkpoint_func: Callable,
    ):
        self._id = id
        self._long_poll_host: LongPollHost = long_poll_host
        self._deployment_scheduler = deployment_scheduler
        self._cluster_node_info_cache = cluster_node_info_cache
        self._autoscaling_state_manager = autoscaling_state_manager
        self._save_checkpoint_func = _save_checkpoint_func

        # Each time we set a new deployment goal, we're trying to save new
        # DeploymentInfo and bring current deployment to meet new status.
        self._target_state: DeploymentTargetState = DeploymentTargetState.default()

        self._prev_startup_warning: float = time.time()
        self._replica_constructor_retry_counter: int = 0
        self._replica_constructor_error_msg: Optional[str] = None

        self._replicas: ReplicaStateContainer = ReplicaStateContainer()
        self._curr_status_info: DeploymentStatusInfo = DeploymentStatusInfo(
            self._id.name,
            DeploymentStatus.UPDATING,
            DeploymentStatusTrigger.CONFIG_UPDATE_STARTED,
        )

        # Per-replica average ongoing request counts (keyed by replica id).
        self.replica_average_ongoing_requests: Dict[str, float] = dict()

        self.health_check_gauge = metrics.Gauge(
            "serve_deployment_replica_healthy",
            description=(
                "Tracks whether this deployment replica is healthy. 1 means "
                "healthy, 0 means unhealthy."
            ),
            tag_keys=("deployment", "replica", "application"),
        )

        # Whether the multiplexed model ids have been updated since the last
        # time we checked.
        self._multiplexed_model_ids_updated = False

        # Caches of the last long-poll broadcasts, used to avoid re-notifying
        # when nothing has changed.
        self._last_broadcasted_running_replica_infos: List[RunningReplicaInfo] = []
        self._last_broadcasted_availability: bool = True
        self._last_broadcasted_deployment_config = None

    def should_autoscale(self) -> bool:
        """
        Check if the deployment is under autoscaling
        """
        return self._id in self._autoscaling_state_manager._autoscaling_states

    def get_checkpoint_data(self) -> DeploymentTargetState:
        """
        Return deployment's target state submitted by user's deployment call.
        Should be persisted and outlive current ray cluster.
        """
        return self._target_state

    def recover_target_state_from_checkpoint(
        self, target_state_checkpoint: DeploymentTargetState
    ):
        """Restore the target state from a checkpoint after controller restart."""
        logger.info(f"Recovering target state for {self._id} from checkpoint.")
        self._target_state = target_state_checkpoint
        self._deployment_scheduler.on_deployment_deployed(
            self._id, self._target_state.info.replica_config
        )
        if self._target_state.info.deployment_config.autoscaling_config:
            self._autoscaling_state_manager.register_deployment(
                self._id,
                self._target_state.info,
                self._target_state.target_num_replicas,
            )

    def recover_current_state_from_replica_actor_names(
        self, replica_actor_names: List[str]
    ):
        """Recover deployment state from live replica actors found in the cluster."""

        assert self._target_state is not None, (
            "Target state should be recovered successfully first before "
            "recovering current state from replica actor names."
        )
        logger.info(
            f"Recovering current state for {self._id} "
            f"from {len(replica_actor_names)} live actors."
        )
        # All current states use default value, only attach running replicas.
        for replica_actor_name in replica_actor_names:
            replica_id = ReplicaID.from_full_id_str(replica_actor_name)
            new_deployment_replica = DeploymentReplica(
                replica_id,
                self._target_state.version,
            )
            # If replica is no longer alive, simply don't add it to the
            # deployment state manager to track.
            if not new_deployment_replica.recover():
                logger.warning(f"{replica_id} died before controller could recover it.")
                continue

            self._replicas.add(ReplicaState.RECOVERING, new_deployment_replica)
            self._deployment_scheduler.on_replica_recovering(replica_id)
            logger.debug(f"RECOVERING {replica_id}.")

        # TODO(jiaodong): this currently halts all traffic in the cluster
        # briefly because we will broadcast a replica update with everything in
        # RECOVERING. We should have a grace period where we recover the state
        # of the replicas before doing this update.

    @property
    def target_info(self) -> DeploymentInfo:
        return self._target_state.info

    @property
    def target_version(self) -> DeploymentVersion:
        return self._target_state.version

    @property
    def target_num_replicas(self) -> int:
        return self._target_state.target_num_replicas

    @property
    def curr_status_info(self) -> DeploymentStatusInfo:
        return self._curr_status_info

    @property
    def deployment_name(self) -> str:
        return self._id.name

    @property
    def app_name(self) -> str:
        return self._id.app_name

    @property
    def _failed_to_start_threshold(self) -> int:
        # Cap constructor retries at a constant, scaled down for small
        # deployments (3x the target replica count).
        return min(
            MAX_DEPLOYMENT_CONSTRUCTOR_RETRY_COUNT,
            self._target_state.target_num_replicas * 3,
        )

    @property
    def is_failed(self) -> bool:
        """Whether the deployment failed to deploy."""
        return self._curr_status_info.status == DeploymentStatus.DEPLOY_FAILED

    def get_alive_replica_actor_ids(self) -> Set[str]:
        return {replica.actor_id for replica in self._replicas.get()}

    def get_running_replica_ids(self) -> List[ReplicaID]:
        return [
            replica.replica_id
            for replica in self._replicas.get(
                [ReplicaState.RUNNING, ReplicaState.PENDING_MIGRATION]
            )
        ]

    def get_running_replica_infos(self) -> List[RunningReplicaInfo]:
        return [
            replica.get_running_replica_info(self._cluster_node_info_cache)
            for replica in self._replicas.get(
                [ReplicaState.RUNNING, ReplicaState.PENDING_MIGRATION]
            )
        ]

    def get_num_running_replicas(self, version: DeploymentVersion = None) -> int:
        return self._replicas.count(states=[ReplicaState.RUNNING], version=version)

    def get_active_node_ids(self) -> Set[str]:
        """Get the node ids of all running replicas in this deployment.

        This is used to determine which node has replicas. Only nodes with replicas and
        head node should have active proxies.
        """
        active_states = [
            ReplicaState.STARTING,
            ReplicaState.UPDATING,
            ReplicaState.RECOVERING,
            ReplicaState.RUNNING,
            # NOTE(zcin): We still want a proxy to run on a draining
            # node before all the replicas are migrated.
            ReplicaState.PENDING_MIGRATION,
        ]
        return {
            replica.actor_node_id
            for replica in self._replicas.get(active_states)
            if replica.actor_node_id is not None
        }

    def list_replica_details(self) -> List[ReplicaDetails]:
        return [replica.actor_details for replica in self._replicas.get()]

    def broadcast_running_replicas_if_changed(self) -> None:
        """Broadcasts the set of running replicas over long poll if it has changed.

        Keeps an in-memory record of the last set of running replicas that was broadcast
        to determine if it has changed.

        The set will also be broadcast if any replicas have an updated set of
        multiplexed model IDs.
        """
        running_replica_infos = self.get_running_replica_infos()
        is_available = not self.is_failed

        running_replicas_changed = (
            set(self._last_broadcasted_running_replica_infos)
            != set(running_replica_infos)
            or self._multiplexed_model_ids_updated
        )
        availability_changed = is_available != self._last_broadcasted_availability
        if not running_replicas_changed and not availability_changed:
            return

        deployment_metadata = DeploymentTargetInfo(
            is_available=is_available,
            running_replicas=running_replica_infos,
        )
        self._long_poll_host.notify_changed(
            {
                (
                    LongPollNamespace.DEPLOYMENT_TARGETS,
                    self._id,
                ): deployment_metadata,
                # NOTE(zcin): notify changed for Java routers. Since Java only
                # supports 1.x API, there is no concept of applications in Java,
                # so the key should remain a string describing the deployment
                # name. If there are no Java routers, this is a no-op.
                (
                    LongPollNamespace.DEPLOYMENT_TARGETS,
                    self._id.name,
                ): deployment_metadata,
            }
        )
        self._last_broadcasted_running_replica_infos = running_replica_infos
        self._last_broadcasted_availability = is_available
        self._multiplexed_model_ids_updated = False

    def broadcast_deployment_config_if_changed(self) -> None:
        """Broadcasts the deployment config over long poll if it has changed.

        Keeps an in-memory record of the last config that was broadcast to determine
        if it has changed.
        """
        current_deployment_config = self._target_state.info.deployment_config
        if self._last_broadcasted_deployment_config == current_deployment_config:
            return

        self._long_poll_host.notify_changed(
            {(LongPollNamespace.DEPLOYMENT_CONFIG, self._id): current_deployment_config}
        )

        self._last_broadcasted_deployment_config = current_deployment_config

    def _set_target_state_deleting(self) -> None:
        """Set the target state for the deployment to be deleted."""

        # We must write ahead the target state in case of GCS failure (we don't
        # want to set the target state, then fail because we can't checkpoint it).
        target_state = DeploymentTargetState.create(
            info=self._target_state.info,
            target_num_replicas=0,
            deleting=True,
        )
        self._save_checkpoint_func(writeahead_checkpoints={self._id: target_state})

        self._target_state = target_state
        self._curr_status_info = self._curr_status_info.handle_transition(
            trigger=DeploymentStatusInternalTrigger.DELETE
        )
        logger.info(
            f"Deleting {self._id}",
            extra={"log_to_stderr": False},
        )

    def _set_target_state(
        self,
        target_info: DeploymentInfo,
        target_num_replicas: int,
    ) -> None:
        """Set the target state for the deployment to the provided info.

        Args:
            target_info: The info with which to set the target state.
            target_num_replicas: The number of replicas that this deployment
                should attempt to run.
            status_trigger: The driver that triggered this change of state.
+ """ + + # We must write ahead the target state in case of GCS failure (we don't + # want to set the target state, then fail because we can't checkpoint it). + new_target_state = DeploymentTargetState.create( + target_info, target_num_replicas, deleting=False + ) + self._save_checkpoint_func(writeahead_checkpoints={self._id: new_target_state}) + + if self._target_state.version == new_target_state.version: + # Record either num replica or autoscaling config lightweight update + if ( + self._target_state.version.deployment_config.autoscaling_config + != new_target_state.version.deployment_config.autoscaling_config + ): + ServeUsageTag.AUTOSCALING_CONFIG_LIGHTWEIGHT_UPDATED.record("True") + elif ( + self._target_state.version.deployment_config.num_replicas + != new_target_state.version.deployment_config.num_replicas + ): + ServeUsageTag.NUM_REPLICAS_LIGHTWEIGHT_UPDATED.record("True") + + self._target_state = new_target_state + + def deploy(self, deployment_info: DeploymentInfo) -> bool: + """Deploy the deployment. + + If the deployment already exists with the same version, config, + target_capacity, and target_capacity_direction, + this method returns False. + + Returns: + bool: Whether or not the deployment is being updated. + """ + + curr_deployment_info = self._target_state.info + if curr_deployment_info is not None: + # Redeploying should not reset the deployment's start time. 
+ if not self._target_state.deleting: + deployment_info.start_time_ms = curr_deployment_info.start_time_ms + + deployment_settings_changed = ( + self._target_state.deleting + or curr_deployment_info.deployment_config + != deployment_info.deployment_config + or curr_deployment_info.replica_config.ray_actor_options + != deployment_info.replica_config.ray_actor_options + or deployment_info.version is None + or curr_deployment_info.version != deployment_info.version + ) + target_capacity_changed = ( + curr_deployment_info.target_capacity != deployment_info.target_capacity + or curr_deployment_info.target_capacity_direction + != deployment_info.target_capacity_direction + ) + else: + deployment_settings_changed = True + target_capacity_changed = True + + # Exit early if the deployment info hasn't changed. Ensures this method + # is idempotent. + if not deployment_settings_changed and not target_capacity_changed: + return False + + if deployment_info.deployment_config.autoscaling_config: + target_num_replicas = self._autoscaling_state_manager.register_deployment( + self._id, deployment_info, self._target_state.target_num_replicas + ) + else: + self._autoscaling_state_manager.deregister_deployment(self._id) + target_num_replicas = get_capacity_adjusted_num_replicas( + deployment_info.deployment_config.num_replicas, + deployment_info.target_capacity, + ) + + old_target_state = self._target_state + self._set_target_state(deployment_info, target_num_replicas=target_num_replicas) + self._deployment_scheduler.on_deployment_deployed( + self._id, deployment_info.replica_config + ) + + # Determine if the updated target state simply scales the current state. 
+ if self._target_state.is_scaled_copy_of(old_target_state): + old_num = old_target_state.target_num_replicas + new_num = self._target_state.target_num_replicas + + if new_num > old_num: + self._curr_status_info = self._curr_status_info.handle_transition( + trigger=DeploymentStatusInternalTrigger.MANUALLY_INCREASE_NUM_REPLICAS, # noqa: E501 + message=f"Upscaling from {old_num} to {new_num} replicas.", + ) + elif new_num < old_num: + self._curr_status_info = self._curr_status_info.handle_transition( + trigger=DeploymentStatusInternalTrigger.MANUALLY_DECREASE_NUM_REPLICAS, # noqa: E501 + message=f"Downscaling from {old_num} to {new_num} replicas.", + ) + else: + # Otherwise, the deployment configuration has actually been updated. + self._curr_status_info = self._curr_status_info.handle_transition( + trigger=DeploymentStatusInternalTrigger.CONFIG_UPDATE + ) + + logger.info( + f"Deploying new version of {self._id} " + f"(initial target replicas: {target_num_replicas})." + ) + self._replica_constructor_retry_counter = 0 + return True + + def autoscale(self) -> int: + """Autoscale the deployment based on metrics.""" + + if self._target_state.deleting: + return + + decision_num_replicas = self._autoscaling_state_manager.get_target_num_replicas( + deployment_id=self._id, + curr_target_num_replicas=self._target_state.target_num_replicas, + ) + + if ( + decision_num_replicas is None + or decision_num_replicas == self._target_state.target_num_replicas + ): + return + + new_info = copy(self._target_state.info) + new_info.version = self._target_state.version.code_version + + old_num = self._target_state.target_num_replicas + self._set_target_state(new_info, decision_num_replicas) + + # The deployment should only transition to UPSCALING/DOWNSCALING + # if it's within the autoscaling bounds + if not self._autoscaling_state_manager.is_within_bounds( + self._id, + self._replicas.count( + states=[ReplicaState.RUNNING], version=self._target_state.version + ), + ): + return + + 
curr_stats_str = ( + f"Current ongoing requests: " + f"{self._autoscaling_state_manager.get_total_num_requests(self._id):.2f}, " + f"current running replicas: " + f"{self._replicas.count(states=[ReplicaState.RUNNING])}." + ) + new_num = self._target_state.target_num_replicas + if new_num > old_num: + logger.info( + f"Upscaling {self._id} from {old_num} to {new_num} replicas. " + f"{curr_stats_str}" + ) + self._curr_status_info = self._curr_status_info.handle_transition( + trigger=DeploymentStatusInternalTrigger.AUTOSCALE_UP, + message=f"Upscaling from {old_num} to {new_num} replicas.", + ) + elif new_num < old_num: + logger.info( + f"Downscaling {self._id} from {old_num} to {new_num} replicas. " + f"{curr_stats_str}" + ) + self._curr_status_info = self._curr_status_info.handle_transition( + trigger=DeploymentStatusInternalTrigger.AUTOSCALE_DOWN, + message=f"Downscaling from {old_num} to {new_num} replicas.", + ) + + def delete(self) -> None: + if not self._target_state.deleting: + self._set_target_state_deleting() + + def _stop_or_update_outdated_version_replicas(self, max_to_stop=math.inf) -> bool: + """Stop or update replicas with outdated versions. + + Stop replicas with versions that require the actor to be restarted, and + reconfigure replicas that require refreshing deployment config values. + + Args: + max_to_stop: max number of replicas to stop, by default, + it will stop all replicas with outdated version. + """ + replicas_to_update = self._replicas.pop( + exclude_version=self._target_state.version, + states=[ + ReplicaState.STARTING, + ReplicaState.PENDING_MIGRATION, + ReplicaState.RUNNING, + ], + ) + replicas_changed = False + code_version_changes = 0 + reconfigure_changes = 0 + for replica in replicas_to_update: + if (code_version_changes + reconfigure_changes) >= max_to_stop: + self._replicas.add(replica.actor_details.state, replica) + # If the new version requires the actors to be restarted, stop the replica. 
+ # A new one with the correct version will be started later as part of the + # normal scale-up process. + elif replica.version.requires_actor_restart(self._target_state.version): + code_version_changes += 1 + # If the replica is still `STARTING`, we don't need to go through the + # graceful stop period. + graceful_stop = replica.actor_details.state == ReplicaState.RUNNING + self._stop_replica(replica, graceful_stop=graceful_stop) + replicas_changed = True + # Otherwise, only lightweight options in deployment config is a mismatch, so + # we update it dynamically without restarting the replica. + elif replica.actor_details.state == ReplicaState.RUNNING: + reconfigure_changes += 1 + if replica.version.requires_long_poll_broadcast( + self._target_state.version + ): + replicas_changed = True + actor_updating = replica.reconfigure(self._target_state.version) + if actor_updating: + self._replicas.add(ReplicaState.UPDATING, replica) + else: + self._replicas.add(ReplicaState.RUNNING, replica) + # We don't allow going from STARTING, PENDING_MIGRATION to UPDATING. + else: + self._replicas.add(replica.actor_details.state, replica) + + if code_version_changes > 0: + logger.info( + f"Stopping {code_version_changes} replicas of {self._id} " + "with outdated versions." + ) + + if reconfigure_changes > 0: + logger.info( + f"Updating {reconfigure_changes} replicas of {self._id} " + "with outdated deployment configs." + ) + # Record user config lightweight update + ServeUsageTag.USER_CONFIG_LIGHTWEIGHT_UPDATED.record("True") + + return replicas_changed + + def _check_and_stop_outdated_version_replicas(self) -> bool: + """Stops replicas with outdated versions to implement rolling updates. + + This includes both explicit code version updates and changes to the + user_config. + + Returns whether any replicas were stopped. + """ + # Short circuit if target replicas is 0 (the deployment is being + # deleted) because this will be handled in the main loop. 
+ if self._target_state.target_num_replicas == 0: + return False + + # We include STARTING and UPDATING replicas here + # because if there are replicas still pending startup, we may as well + # terminate them and start new version replicas instead. + old_running_replicas = self._replicas.count( + exclude_version=self._target_state.version, + states=[ + ReplicaState.STARTING, + ReplicaState.UPDATING, + ReplicaState.RUNNING, + ], + ) + old_stopping_replicas = self._replicas.count( + exclude_version=self._target_state.version, states=[ReplicaState.STOPPING] + ) + new_running_replicas = self._replicas.count( + version=self._target_state.version, states=[ReplicaState.RUNNING] + ) + + # If the deployment is currently scaling down, let the scale down + # complete before doing a rolling update. + if ( + self._target_state.target_num_replicas + < old_running_replicas + old_stopping_replicas + ): + return False + + # The number of replicas that are currently in transition between + # an old version and the new version. Note that we cannot directly + # count the number of stopping replicas because once replicas finish + # stopping, they are removed from the data structure. + pending_replicas = ( + self._target_state.target_num_replicas + - new_running_replicas + - old_running_replicas + ) + + # Maximum number of replicas that can be updating at any given time. + # There should never be more than rollout_size old replicas stopping + # or rollout_size new replicas starting. + rollout_size = max(int(0.2 * self._target_state.target_num_replicas), 1) + max_to_stop = max(rollout_size - pending_replicas, 0) + + return self._stop_or_update_outdated_version_replicas(max_to_stop) + + def scale_deployment_replicas( + self, + ) -> Tuple[List[ReplicaSchedulingRequest], DeploymentDownscaleRequest]: + """Scale the given deployment to the number of replicas.""" + + assert ( + self._target_state.target_num_replicas >= 0 + ), "Target number of replicas must be greater than or equal to 0." 
+ + upscale = [] + downscale = None + + self._check_and_stop_outdated_version_replicas() + + current_replicas = self._replicas.count( + states=[ReplicaState.STARTING, ReplicaState.UPDATING, ReplicaState.RUNNING] + ) + recovering_replicas = self._replicas.count(states=[ReplicaState.RECOVERING]) + + delta_replicas = ( + self._target_state.target_num_replicas + - current_replicas + - recovering_replicas + ) + if delta_replicas == 0: + return (upscale, downscale) + + elif delta_replicas > 0: + to_add = delta_replicas + if not RAY_SERVE_EAGERLY_START_REPLACEMENT_REPLICAS: + # Don't ever exceed target_num_replicas. + stopping_replicas = self._replicas.count(states=[ReplicaState.STOPPING]) + to_add = max(delta_replicas - stopping_replicas, 0) + + if ( + to_add > 0 + and self._replica_constructor_retry_counter + < self._failed_to_start_threshold + ): + logger.info(f"Adding {to_add} replica{'s' * (to_add>1)} to {self._id}.") + for _ in range(to_add): + replica_id = ReplicaID(get_random_string(), deployment_id=self._id) + new_deployment_replica = DeploymentReplica( + replica_id, + self._target_state.version, + ) + upscale.append( + new_deployment_replica.start(self._target_state.info) + ) + + self._replicas.add(ReplicaState.STARTING, new_deployment_replica) + + elif delta_replicas < 0: + to_remove = -delta_replicas + removed_replicas = f"{to_remove} replica{'s' if to_remove > 1 else ''}" + logger.info(f"Removing {removed_replicas} from {self._id}.") + downscale = DeploymentDownscaleRequest( + deployment_id=self._id, num_to_stop=to_remove + ) + + return upscale, downscale + + def check_curr_status(self) -> Tuple[bool, bool]: + """Check the current deployment status. + + Checks the difference between the target vs. running replica count for + the target version. + + This will update the current deployment status depending on the state + of the replicas. + + Returns (deleted, any_replicas_recovering). 
+ """ + # TODO(edoakes): we could make this more efficient in steady-state by + # having a "healthy" flag that gets flipped if an update or replica + # failure happens. + + target_version = self._target_state.version + + any_replicas_recovering = ( + self._replicas.count(states=[ReplicaState.RECOVERING]) > 0 + ) + all_running_replica_cnt = self._replicas.count(states=[ReplicaState.RUNNING]) + running_at_target_version_replica_cnt = self._replicas.count( + states=[ReplicaState.RUNNING], version=target_version + ) + + failed_to_start_count = self._replica_constructor_retry_counter + + # Got to make a call to complete current deploy() goal after + # start failure threshold reached, while we might still have + # pending replicas in current goal. + if ( + failed_to_start_count >= self._failed_to_start_threshold + and self._failed_to_start_threshold != 0 + ): + if running_at_target_version_replica_cnt > 0: + # At least one RUNNING replica at target state, partial + # success; We can stop tracking constructor failures and + # leave it to the controller to fully scale to target + # number of replicas and only return as completed once + # reached target replica count + self._replica_constructor_retry_counter = -1 + else: + self._curr_status_info = self._curr_status_info.handle_transition( + trigger=DeploymentStatusInternalTrigger.REPLICA_STARTUP_FAILED, + message=( + f"The deployment failed to start {failed_to_start_count} times " + "in a row. This may be due to a problem with its " + "constructor or initial health check failing. See " + "controller logs for details. Error:\n" + f"{self._replica_constructor_error_msg}" + ), + ) + return False, any_replicas_recovering + + # If we have pending ops, the current goal is *not* ready. + if ( + self._replicas.count( + states=[ + ReplicaState.STARTING, + ReplicaState.UPDATING, + ReplicaState.RECOVERING, + ReplicaState.STOPPING, + ] + ) + == 0 + ): + # Check for deleting and a non-zero number of deployments. 
+ if self._target_state.deleting and all_running_replica_cnt == 0: + return True, any_replicas_recovering + + if ( + self._target_state.target_num_replicas + == running_at_target_version_replica_cnt + and running_at_target_version_replica_cnt == all_running_replica_cnt + ): + self._curr_status_info = self._curr_status_info.handle_transition( + trigger=DeploymentStatusInternalTrigger.HEALTHY + ) + return False, any_replicas_recovering + + return False, any_replicas_recovering + + def _check_startup_replicas( + self, original_state: ReplicaState, stop_on_slow=False + ) -> List[Tuple[DeploymentReplica, ReplicaStartupStatus]]: + """ + Common helper function for startup actions tracking and status + transition: STARTING, UPDATING and RECOVERING. + + Args: + stop_on_slow: If we consider a replica failed upon observing it's + slow to reach running state. + """ + slow_replicas = [] + for replica in self._replicas.pop(states=[original_state]): + start_status, error_msg = replica.check_started() + if start_status == ReplicaStartupStatus.SUCCEEDED: + # This replica should be now be added to handle's replica + # set. + self._replicas.add(ReplicaState.RUNNING, replica) + self._deployment_scheduler.on_replica_running( + replica.replica_id, replica.actor_node_id + ) + + # Log the startup latency. + e2e_replica_start_latency = time.time() - replica._start_time + replica_startup_message = ( + f"{replica.replica_id} started successfully " + f"on node '{replica.actor_node_id}' after " + f"{e2e_replica_start_latency:.1f}s (PID: {replica.actor_pid})." + ) + if replica.initialization_latency_s is not None: + # This condition should always be True. The initialization + # latency is only None before the replica has initialized. + replica_startup_message += ( + " Replica constructor, " + "reconfigure method, and initial health check took " + f"{replica.initialization_latency_s:.1f}s." 
+ ) + logger.info(replica_startup_message, extra={"log_to_stderr": False}) + + elif start_status == ReplicaStartupStatus.FAILED: + # Replica reconfigure (deploy / upgrade) failed + self.record_replica_startup_failure(error_msg) + self._stop_replica(replica) + elif start_status in [ + ReplicaStartupStatus.PENDING_ALLOCATION, + ReplicaStartupStatus.PENDING_INITIALIZATION, + ]: + is_slow = time.time() - replica._start_time > SLOW_STARTUP_WARNING_S + if is_slow: + slow_replicas.append((replica, start_status)) + + # Does it make sense to stop replicas in PENDING_ALLOCATION + # state? + if is_slow and stop_on_slow: + self._stop_replica(replica, graceful_stop=False) + else: + self._replicas.add(original_state, replica) + + return slow_replicas + + def record_replica_startup_failure(self, error_msg: str): + """Record that a replica failed to start.""" + + if self._replica_constructor_retry_counter >= 0: + # Increase startup failure counter if we're tracking it + self._replica_constructor_retry_counter += 1 + self._replica_constructor_error_msg = error_msg + + retrying_msg = "Retrying" + if self._failed_to_start_threshold != 0: + remaining_retries = max( + self._failed_to_start_threshold + - self._replica_constructor_retry_counter, + 0, + ) + retrying_msg += f" {remaining_retries} more time(s)" + + message = ( + f"A replica failed to start with exception. {retrying_msg}. Error:\n" + f"{error_msg}" + ) + self._curr_status_info = self._curr_status_info.update_message(message) + + def stop_replicas(self, replicas_to_stop) -> None: + for replica in self._replicas.pop(): + if replica.replica_id in replicas_to_stop: + self._stop_replica(replica) + else: + self._replicas.add(replica.actor_details.state, replica) + + def _stop_replica(self, replica: DeploymentReplica, graceful_stop=True): + """Stop replica + 1. Stop the replica. + 2. Change the replica into stopping state. + 3. Set the health replica stats to 0. 
+ """ + logger.debug(f"Adding STOPPING to replica: {replica.replica_id}.") + replica.stop(graceful=graceful_stop) + self._replicas.add(ReplicaState.STOPPING, replica) + self._deployment_scheduler.on_replica_stopping(replica.replica_id) + self.health_check_gauge.set( + 0, + tags={ + "deployment": self.deployment_name, + "replica": replica.replica_id.unique_id, + "application": self.app_name, + }, + ) + + def check_and_update_replicas(self): + """ + Check current state of all DeploymentReplica being tracked, and compare + with state container from previous update() cycle to see if any state + transition happened. + """ + + for replica in self._replicas.pop( + states=[ReplicaState.RUNNING, ReplicaState.PENDING_MIGRATION] + ): + if replica.check_health(): + self._replicas.add(replica.actor_details.state, replica) + self.health_check_gauge.set( + 1, + tags={ + "deployment": self.deployment_name, + "replica": replica.replica_id.unique_id, + "application": self.app_name, + }, + ) + else: + logger.warning( + f"Replica {replica.replica_id} failed health check, stopping it." + ) + self.health_check_gauge.set( + 0, + tags={ + "deployment": self.deployment_name, + "replica": replica.replica_id.unique_id, + "application": self.app_name, + }, + ) + self._stop_replica( + replica, graceful_stop=not self.FORCE_STOP_UNHEALTHY_REPLICAS + ) + # If this is a replica of the target version, the deployment + # enters the "UNHEALTHY" status until the replica is + # recovered or a new deploy happens. + if replica.version == self._target_state.version: + self._curr_status_info = self._curr_status_info.handle_transition( + trigger=DeploymentStatusInternalTrigger.HEALTH_CHECK_FAILED, + message="A replica's health check failed. 
This " + "deployment will be UNHEALTHY until the replica " + "recovers or a new deploy happens.", + ) + + slow_start_replicas = [] + slow_start = self._check_startup_replicas(ReplicaState.STARTING) + slow_update = self._check_startup_replicas(ReplicaState.UPDATING) + slow_recover = self._check_startup_replicas( + ReplicaState.RECOVERING, stop_on_slow=True + ) + + slow_start_replicas = slow_start + slow_update + slow_recover + + if ( + len(slow_start_replicas) + and time.time() - self._prev_startup_warning > SLOW_STARTUP_WARNING_PERIOD_S + ): + pending_allocation = [] + pending_initialization = [] + + for replica, startup_status in slow_start_replicas: + if startup_status == ReplicaStartupStatus.PENDING_ALLOCATION: + pending_allocation.append(replica) + if startup_status == ReplicaStartupStatus.PENDING_INITIALIZATION: + pending_initialization.append(replica) + + if len(pending_allocation) > 0: + required, available = pending_allocation[0].resource_requirements() + message = ( + f"Deployment '{self.deployment_name}' in application " + f"'{self.app_name}' has {len(pending_allocation)} replicas that " + f"have taken more than {SLOW_STARTUP_WARNING_S}s to be scheduled. " + "This may be due to waiting for the cluster to auto-scale or for a " + "runtime environment to be installed. " + f"Resources required for each replica: {required}, " + f"total resources available: {available}. " + "Use `ray status` for more details." + ) + logger.warning(message) + if _SCALING_LOG_ENABLED: + print_verbose_scaling_log() + # If status is UNHEALTHY, leave the status and message as is. + # The issue that caused the deployment to be unhealthy should be + # prioritized over this resource availability issue. 
+ if self._curr_status_info.status not in [ + DeploymentStatus.UNHEALTHY, + DeploymentStatus.DEPLOY_FAILED, + ]: + self._curr_status_info = self._curr_status_info.update_message( + message + ) + + if len(pending_initialization) > 0: + message = ( + f"Deployment '{self.deployment_name}' in application " + f"'{self.app_name}' has {len(pending_initialization)} replicas " + f"that have taken more than {SLOW_STARTUP_WARNING_S}s to " + "initialize.\n" + "This may be caused by a slow __init__ or reconfigure method." + ) + logger.warning(message) + # If status is UNHEALTHY, leave the status and message as is. + # The issue that caused the deployment to be unhealthy should be + # prioritized over this resource availability issue. + if self._curr_status_info.status not in [ + DeploymentStatus.UNHEALTHY, + DeploymentStatus.DEPLOY_FAILED, + ]: + self._curr_status_info = self._curr_status_info.update_message( + message + ) + + self._prev_startup_warning = time.time() + + for replica in self._replicas.pop(states=[ReplicaState.STOPPING]): + stopped = replica.check_stopped() + if not stopped: + self._replicas.add(ReplicaState.STOPPING, replica) + else: + logger.info(f"{replica.replica_id} is stopped.") + self._autoscaling_state_manager.on_replica_stopped(replica.replica_id) + + def _choose_pending_migration_replicas_to_stop( + self, + replicas: List[DeploymentReplica], + deadlines: Dict[str, int], + min_replicas_to_stop: int, + ) -> Tuple[List[DeploymentReplica], List[DeploymentReplica]]: + """Returns a partition of replicas to stop and to keep. + + Args: + replicas: The current list of replicas pending migration. + deadlines: The current draining node deadlines. + min_replicas_to_stop: The minimum number of replicas to stop. 
+ """ + to_stop = list() + remaining = list() + + # Stop replicas whose deadline is up + for replica in replicas: + assert replica.actor_node_id in deadlines + + curr_timestamp_ms = time.time() * 1000 + timeout_ms = replica._actor.graceful_shutdown_timeout_s * 1000 + if curr_timestamp_ms >= deadlines[replica.actor_node_id] - timeout_ms: + to_stop.append(replica) + else: + remaining.append(replica) + + # Stop excess PENDING_MIGRATION replicas when new "replacement" + # replicas have transitioned to RUNNING. The replicas with the + # earliest deadlines should be chosen greedily. + remaining.sort(key=lambda r: deadlines[r.actor_node_id]) + num_excess = min_replicas_to_stop - len(to_stop) + + if num_excess > 0: + to_stop.extend(remaining[:num_excess]) + remaining = remaining[num_excess:] + + return to_stop, remaining + + def migrate_replicas_on_draining_nodes(self, draining_nodes: Dict[str, int]): + # Move replicas back to running if they are no longer on a draining node. + # If this causes the number of replicas to exceed the target state, + # they will be scaled down because `scale_deployment_replicas` is called on + # each deployment after this + for replica in self._replicas.pop(states=[ReplicaState.PENDING_MIGRATION]): + if replica.actor_node_id not in draining_nodes: + self._replicas.add(ReplicaState.RUNNING, replica) + else: + self._replicas.add(ReplicaState.PENDING_MIGRATION, replica) + + # Migrate replicas on draining nodes + for replica in self._replicas.pop( + states=[ReplicaState.UPDATING, ReplicaState.RUNNING, ReplicaState.STARTING] + ): + if replica.actor_node_id in draining_nodes: + # For RUNNING replicas, migrate them safely by starting + # a replacement replica first. + if replica.actor_details.state == ReplicaState.RUNNING: + logger.info( + f"Migrating {replica.replica_id} from draining node " + f"'{replica.actor_node_id}'. A new replica will be created on " + "another node." 
+ ) + self._replicas.add(ReplicaState.PENDING_MIGRATION, replica) + # For replicas that are STARTING or UPDATING, might as + # well terminate them immediately to allow replacement + # replicas to start. Otherwise we need to wait for them + # to transition to RUNNING before starting migration. + else: + self._stop_replica(replica, graceful_stop=True) + else: + self._replicas.add(replica.actor_details.state, replica) + + num_running = self._replicas.count(states=[ReplicaState.RUNNING]) + num_draining = self._replicas.count(states=[ReplicaState.PENDING_MIGRATION]) + num_pending_migration_replicas_to_stop = ( + num_running + num_draining - self._target_state.target_num_replicas + ) + + ( + replicas_to_stop, + replicas_to_keep, + ) = self._choose_pending_migration_replicas_to_stop( + self._replicas.pop(states=[ReplicaState.PENDING_MIGRATION]), + draining_nodes, + num_pending_migration_replicas_to_stop, + ) + for replica in replicas_to_stop: + logger.info( + f"Stopping {replica.replica_id} " + f"on draining node {replica.actor_node_id}." + ) + self._stop_replica(replica, graceful_stop=True) + + for replica in replicas_to_keep: + self._replicas.add(ReplicaState.PENDING_MIGRATION, replica) + + def record_multiplexed_model_ids( + self, replica_id: ReplicaID, multiplexed_model_ids: List[str] + ) -> None: + """Records the multiplexed model IDs of a replica. + + Args: + replica_name: Name of the replica. + multiplexed_model_ids: List of model IDs that replica is serving. 
+ """ + # Find the replica + for replica in self._replicas.get(): + if replica.replica_id == replica_id: + replica.record_multiplexed_model_ids(multiplexed_model_ids) + self._multiplexed_model_ids_updated = True + return + + logger.warning(f"{replica_id} not found.") + + def _stop_one_running_replica_for_testing(self): + running_replicas = self._replicas.pop(states=[ReplicaState.RUNNING]) + replica_to_stop = running_replicas.pop() + replica_to_stop.stop(graceful=False) + self._replicas.add(ReplicaState.STOPPING, replica_to_stop) + for replica in running_replicas: + self._replicas.add(ReplicaState.RUNNING, replica) + + +class DeploymentStateManager: + """Manages all state for deployments in the system. + + This class is *not* thread safe, so any state-modifying methods should be + called with a lock held. + """ + + def __init__( + self, + kv_store: KVStoreBase, + long_poll_host: LongPollHost, + all_current_actor_names: List[str], + all_current_placement_group_names: List[str], + cluster_node_info_cache: ClusterNodeInfoCache, + autoscaling_state_manager: AutoscalingStateManager, + head_node_id_override: Optional[str] = None, + create_placement_group_fn_override: Optional[Callable] = None, + ): + self._kv_store = kv_store + self._long_poll_host = long_poll_host + self._cluster_node_info_cache = cluster_node_info_cache + self._deployment_scheduler = default_impl.create_deployment_scheduler( + cluster_node_info_cache, + head_node_id_override, + create_placement_group_fn_override, + ) + self._autoscaling_state_manager = autoscaling_state_manager + + self._deployment_states: Dict[DeploymentID, DeploymentState] = dict() + + self._recover_from_checkpoint( + all_current_actor_names, all_current_placement_group_names + ) + + def _create_deployment_state(self, deployment_id): + self._deployment_scheduler.on_deployment_created( + deployment_id, SpreadDeploymentSchedulingPolicy() + ) + + return DeploymentState( + deployment_id, + self._long_poll_host, + 
self._deployment_scheduler, + self._cluster_node_info_cache, + self._autoscaling_state_manager, + self._save_checkpoint_func, + ) + + def _map_actor_names_to_deployment( + self, all_current_actor_names: List[str] + ) -> Dict[str, List[str]]: + """ + Given a list of all actor names queried from current ray cluster, + map them to corresponding deployments. + + Example: + Args: + [A#zxc123, B#xcv234, A#qwe234] + Returns: + { + A: [A#zxc123, A#qwe234] + B: [B#xcv234] + } + """ + all_replica_names = [ + actor_name + for actor_name in all_current_actor_names + if ReplicaID.is_full_id_str(actor_name) + ] + deployment_to_current_replicas = defaultdict(list) + if len(all_replica_names) > 0: + for replica_name in all_replica_names: + replica_id = ReplicaID.from_full_id_str(replica_name) + deployment_to_current_replicas[replica_id.deployment_id].append( + replica_name + ) + + return deployment_to_current_replicas + + def _detect_and_remove_leaked_placement_groups( + self, + all_current_actor_names: List[str], + all_current_placement_group_names: List[str], + ): + """Detect and remove any placement groups not associated with a replica. + + This can happen under certain rare circumstances: + - The controller creates a placement group then crashes before creating + the associated replica actor. + - While the controller is down, a replica actor crashes but its placement + group still exists. + + In both of these (or any other unknown cases), we simply need to remove the + leaked placement groups. + """ + leaked_pg_names = [] + for pg_name in all_current_placement_group_names: + if ( + ReplicaID.is_full_id_str(pg_name) + and pg_name not in all_current_actor_names + ): + leaked_pg_names.append(pg_name) + + if len(leaked_pg_names) > 0: + logger.warning( + f"Detected leaked placement groups: {leaked_pg_names}. " + "The placement groups will be removed. This can happen in rare " + "circumstances when the controller crashes and should not cause any " + "issues. 
If this happens repeatedly, please file an issue on GitHub." + ) + + for leaked_pg_name in leaked_pg_names: + try: + pg = ray.util.get_placement_group(leaked_pg_name) + ray.util.remove_placement_group(pg) + except Exception: + logger.exception( + f"Failed to remove leaked placement group {leaked_pg_name}." + ) + + def _recover_from_checkpoint( + self, + all_current_actor_names: List[str], + all_current_placement_group_names: List[str], + ): + """ + Recover from checkpoint upon controller failure with all actor names + found in current cluster. + + Each deployment resumes target state from checkpoint if available. + + For current state it will prioritize reconstructing from current + actor names found that matches deployment tag if applicable. + """ + self._detect_and_remove_leaked_placement_groups( + all_current_actor_names, + all_current_placement_group_names, + ) + + deployment_to_current_replicas = self._map_actor_names_to_deployment( + all_current_actor_names + ) + checkpoint = self._kv_store.get(CHECKPOINT_KEY) + if checkpoint is not None: + deployment_state_info = cloudpickle.loads(checkpoint) + + for deployment_id, checkpoint_data in deployment_state_info.items(): + deployment_state = self._create_deployment_state(deployment_id) + deployment_state.recover_target_state_from_checkpoint(checkpoint_data) + if len(deployment_to_current_replicas[deployment_id]) > 0: + deployment_state.recover_current_state_from_replica_actor_names( # noqa: E501 + deployment_to_current_replicas[deployment_id] + ) + self._deployment_states[deployment_id] = deployment_state + + def shutdown(self): + """ + Shutdown all running replicas by notifying the controller, and leave + it to the controller event loop to take actions afterwards. + + Once shutdown signal is received, it will also prevent any new + deployments or replicas from being created. + + One can send multiple shutdown signals but won't effectively make any + difference compare to calling it once. 
+ """ + + for deployment_state in self._deployment_states.values(): + deployment_state.delete() + + # TODO(jiaodong): This might not be 100% safe since we deleted + # everything without ensuring all shutdown goals are completed + # yet. Need to address in follow-up PRs. + self._kv_store.delete(CHECKPOINT_KEY) + + # TODO(jiaodong): Need to add some logic to prevent new replicas + # from being created once shutdown signal is sent. + + def is_ready_for_shutdown(self) -> bool: + """Return whether all deployments are shutdown. + + Check there are no deployment states and no checkpoints. + """ + return ( + len(self._deployment_states) == 0 + and self._kv_store.get(CHECKPOINT_KEY) is None + ) + + def _save_checkpoint_func( + self, *, writeahead_checkpoints: Optional[Dict[str, Tuple]] + ) -> None: + """Write a checkpoint of all deployment states. + By default, this checkpoints the current in-memory state of each + deployment. However, these can be overwritten by passing + `writeahead_checkpoints` in order to checkpoint an update before + applying it to the in-memory state. 
+ """ + + deployment_state_info = { + deployment_id: deployment_state.get_checkpoint_data() + for deployment_id, deployment_state in self._deployment_states.items() + } + + if writeahead_checkpoints is not None: + deployment_state_info.update(writeahead_checkpoints) + + self._kv_store.put( + CHECKPOINT_KEY, + cloudpickle.dumps(deployment_state_info), + ) + + def get_running_replica_infos( + self, + ) -> Dict[DeploymentID, List[RunningReplicaInfo]]: + return { + id: deployment_state.get_running_replica_infos() + for id, deployment_state in self._deployment_states.items() + } + + def get_deployment_infos(self) -> Dict[DeploymentID, DeploymentInfo]: + infos: Dict[DeploymentID, DeploymentInfo] = {} + for deployment_id, deployment_state in self._deployment_states.items(): + infos[deployment_id] = deployment_state.target_info + + return infos + + def get_deployment(self, deployment_id: DeploymentID) -> Optional[DeploymentInfo]: + if deployment_id in self._deployment_states: + return self._deployment_states[deployment_id].target_info + else: + return None + + def get_deployment_details(self, id: DeploymentID) -> Optional[DeploymentDetails]: + """Gets detailed info on a deployment. + + Returns: + DeploymentDetails: if the deployment is live. + None: if the deployment is deleted. 
+ """ + statuses = self.get_deployment_statuses([id]) + if len(statuses) == 0: + return None + else: + status_info = statuses[0] + deployment_state = self._deployment_states[id] + return DeploymentDetails( + name=id.name, + status=status_info.status, + status_trigger=status_info.status_trigger, + message=status_info.message, + deployment_config=_deployment_info_to_schema( + id.name, self.get_deployment(id) + ), + target_num_replicas=deployment_state._target_state.target_num_replicas, + replicas=deployment_state.list_replica_details(), + ) + + def get_deployment_statuses( + self, ids: Optional[List[DeploymentID]] = None + ) -> List[DeploymentStatusInfo]: + """ + Return the statuses of the deployments with the given `ids`. + If `ids` is `None`, returns the status of all deployments. + """ + if ids is None: + # fast path for returning all deployments, + # avoids checking `if ids is None` in a loop + return [ + state.curr_status_info for state in self._deployment_states.values() + ] + else: + statuses = [] + for id in ids: + state = self._deployment_states.get(id) + if state is not None: + statuses.append(state.curr_status_info) + return statuses + + def get_alive_replica_actor_ids(self) -> Set[str]: + alive_replica_actor_ids = set() + for ds in self._deployment_states.values(): + alive_replica_actor_ids |= ds.get_alive_replica_actor_ids() + + return alive_replica_actor_ids + + def deploy( + self, + deployment_id: DeploymentID, + deployment_info: DeploymentInfo, + ) -> bool: + """Deploy the deployment. + + If the deployment already exists with the same version and config, + this is a no-op and returns False. + + Returns: + bool: Whether or not the deployment is being updated. 
+ """ + if deployment_id not in self._deployment_states: + self._deployment_states[deployment_id] = self._create_deployment_state( + deployment_id + ) + self._record_deployment_usage() + + return self._deployment_states[deployment_id].deploy(deployment_info) + + def get_deployments_in_application(self, app_name: str) -> List[str]: + """Return list of deployment names in application.""" + + deployments = [] + for deployment_id in self._deployment_states: + if deployment_id.app_name == app_name: + deployments.append(deployment_id.name) + + return deployments + + def delete_deployment(self, id: DeploymentID): + # This method must be idempotent. We should validate that the + # specified deployment exists on the client. + if id in self._deployment_states: + self._deployment_states[id].delete() + + def update(self) -> bool: + """Updates the state of all deployments to match their goal state. + + Returns True if any of the deployments have replicas in the RECOVERING state. + """ + + deleted_ids = [] + any_recovering = False + upscales: Dict[DeploymentID, List[ReplicaSchedulingRequest]] = {} + downscales: Dict[DeploymentID, DeploymentDownscaleRequest] = {} + + # STEP 1: Update current state + for deployment_state in self._deployment_states.values(): + if deployment_state.should_autoscale(): + deployment_state.autoscale() + + deployment_state.check_and_update_replicas() + + # STEP 2: Check current status + for deployment_state in self._deployment_states.values(): + deployment_state.check_curr_status() + + # STEP 3: Drain nodes + draining_nodes = self._cluster_node_info_cache.get_draining_nodes() + allow_new_compaction = len(draining_nodes) == 0 and all( + ds.curr_status_info.status == DeploymentStatus.HEALTHY + # TODO(zcin): Make sure that status should never be healthy if + # the number of running replicas at target version is not at + # target number, so we can remove this defensive check. 
+ and ds.get_num_running_replicas(ds.target_version) == ds.target_num_replicas + # To be extra conservative, only actively compact if there + # are no non-running replicas + and len(ds._replicas.get()) == ds.target_num_replicas + for ds in self._deployment_states.values() + ) + if RAY_SERVE_USE_COMPACT_SCHEDULING_STRATEGY: + # Tuple of target node to compact, and its draining deadline + node_info: Optional[ + Tuple[str, float] + ] = self._deployment_scheduler.get_node_to_compact( + allow_new_compaction=allow_new_compaction + ) + if node_info: + target_node_id, deadline = node_info + draining_nodes = {target_node_id: deadline} + + for deployment_id, deployment_state in self._deployment_states.items(): + deployment_state.migrate_replicas_on_draining_nodes(draining_nodes) + + # STEP 4: Scale replicas + for deployment_id, deployment_state in self._deployment_states.items(): + upscale, downscale = deployment_state.scale_deployment_replicas() + + if upscale: + upscales[deployment_id] = upscale + if downscale: + downscales[deployment_id] = downscale + + # STEP 5: Update status + for deployment_id, deployment_state in self._deployment_states.items(): + deleted, any_replicas_recovering = deployment_state.check_curr_status() + + if deleted: + deleted_ids.append(deployment_id) + any_recovering |= any_replicas_recovering + + # STEP 6: Schedule all STARTING replicas and stop all STOPPING replicas + deployment_to_replicas_to_stop = self._deployment_scheduler.schedule( + upscales, downscales + ) + for deployment_id, replicas_to_stop in deployment_to_replicas_to_stop.items(): + self._deployment_states[deployment_id].stop_replicas(replicas_to_stop) + for deployment_id, scheduling_requests in upscales.items(): + self._handle_scheduling_request_failures(deployment_id, scheduling_requests) + + # STEP 7: Broadcast long poll information + for deployment_id, deployment_state in self._deployment_states.items(): + deployment_state.broadcast_running_replicas_if_changed() + 
deployment_state.broadcast_deployment_config_if_changed() + if deployment_state.should_autoscale(): + self._autoscaling_state_manager.update_running_replica_ids( + deployment_id=deployment_id, + running_replicas=deployment_state.get_running_replica_ids(), + ) + + # STEP 8: Cleanup + for deployment_id in deleted_ids: + self._deployment_scheduler.on_deployment_deleted(deployment_id) + self._autoscaling_state_manager.deregister_deployment(deployment_id) + del self._deployment_states[deployment_id] + + if len(deleted_ids): + self._record_deployment_usage() + + return any_recovering + + def _handle_scheduling_request_failures( + self, + deployment_id: DeploymentID, + scheduling_requests: List[ReplicaSchedulingRequest], + ): + """Updates internal datastructures when replicas fail to be scheduled.""" + failed_replicas: List[ReplicaID] = [] + for scheduling_request in scheduling_requests: + if ( + scheduling_request.status + == ReplicaSchedulingRequestStatus.PLACEMENT_GROUP_CREATION_FAILED + ): + failed_replicas.append(scheduling_request.replica_id) + self._deployment_states[deployment_id].record_replica_startup_failure( + "Replica scheduling failed. Failed to create a placement " + f"group for replica {scheduling_request.replica_id}. " + "See Serve controller logs for more details." + ) + elif ( + scheduling_request.status + == ReplicaSchedulingRequestStatus.ACTOR_CREATION_FAILED + ): + failed_replicas.append(scheduling_request.replica_id) + self._deployment_states[deployment_id].record_replica_startup_failure( + "Replica scheduling failed. Failed to create an actor " + f"for replica {scheduling_request.replica_id}. " + "See Serve controller logs for more details." 
+ ) + if failed_replicas: + self._deployment_states[deployment_id].stop_replicas(failed_replicas) + + def _record_deployment_usage(self): + ServeUsageTag.NUM_DEPLOYMENTS.record(str(len(self._deployment_states))) + + num_gpu_deployments = 0 + for deployment_state in self._deployment_states.values(): + if ( + deployment_state.target_info is not None + and deployment_state.target_info.replica_config is not None + and deployment_state.target_info.replica_config.ray_actor_options + is not None + and ( + deployment_state.target_info.replica_config.ray_actor_options.get( + "num_gpus", 0 + ) + > 0 + ) + ): + num_gpu_deployments += 1 + ServeUsageTag.NUM_GPU_DEPLOYMENTS.record(str(num_gpu_deployments)) + + def record_multiplexed_replica_info(self, info: MultiplexedReplicaInfo): + """ + Record multiplexed model ids for a multiplexed replica. + + Args: + info: Multiplexed replica info including deployment name, + replica tag and model ids. + """ + deployment_id = info.replica_id.deployment_id + if deployment_id not in self._deployment_states: + app_msg = f" in application '{deployment_id.app_name}'" + logger.error( + f"Deployment '{deployment_id.name}'{app_msg} not found in state " + "manager." + ) + return + self._deployment_states[deployment_id].record_multiplexed_model_ids( + info.replica_id, info.model_ids + ) + + def get_active_node_ids(self) -> Set[str]: + """Return set of node ids with running replicas of any deployment. + + This is used to determine which node has replicas. Only nodes with replicas and + head node should have active proxies. 
+ """ + node_ids = set() + for deployment_state in self._deployment_states.values(): + node_ids.update(deployment_state.get_active_node_ids()) + return node_ids diff --git a/.venv/lib/python3.11/site-packages/ray/serve/_private/handle_options.py b/.venv/lib/python3.11/site-packages/ray/serve/_private/handle_options.py new file mode 100644 index 0000000000000000000000000000000000000000..d0438f4ec3d7af2c57e143587868b3a057e51978 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/serve/_private/handle_options.py @@ -0,0 +1,72 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass, fields + +import ray +from ray.serve._private.common import DeploymentHandleSource +from ray.serve._private.utils import DEFAULT + + +@dataclass(frozen=True) +class InitHandleOptionsBase(ABC): + """Init options for each ServeHandle instance. + + These fields can be set by calling `.init()` on a handle before + sending the first request. + """ + + _prefer_local_routing: bool = False + _source: DeploymentHandleSource = DeploymentHandleSource.UNKNOWN + + @classmethod + @abstractmethod + def create(cls, **kwargs) -> "InitHandleOptionsBase": + raise NotImplementedError + + +@dataclass(frozen=True) +class InitHandleOptions(InitHandleOptionsBase): + @classmethod + def create(cls, **kwargs) -> "InitHandleOptions": + for k in list(kwargs.keys()): + if kwargs[k] == DEFAULT.VALUE: + # Use default value + del kwargs[k] + + # Detect replica source for handles + if ( + "_source" not in kwargs + and ray.serve.context._get_internal_replica_context() is not None + ): + kwargs["_source"] = DeploymentHandleSource.REPLICA + + return cls(**kwargs) + + +@dataclass(frozen=True) +class DynamicHandleOptionsBase(ABC): + """Dynamic options for each ServeHandle instance. + + These fields can be changed by calling `.options()` on a handle. 
+ """ + + method_name: str = "__call__" + multiplexed_model_id: str = "" + stream: bool = False + + @abstractmethod + def copy_and_update(self, **kwargs) -> "DynamicHandleOptionsBase": + pass + + +@dataclass(frozen=True) +class DynamicHandleOptions(DynamicHandleOptionsBase): + def copy_and_update(self, **kwargs) -> "DynamicHandleOptions": + new_kwargs = {} + + for f in fields(self): + if f.name not in kwargs or kwargs[f.name] == DEFAULT.VALUE: + new_kwargs[f.name] = getattr(self, f.name) + else: + new_kwargs[f.name] = kwargs[f.name] + + return DynamicHandleOptions(**new_kwargs) diff --git a/.venv/lib/python3.11/site-packages/ray/serve/_private/http_util.py b/.venv/lib/python3.11/site-packages/ray/serve/_private/http_util.py new file mode 100644 index 0000000000000000000000000000000000000000..0fd4a1ac84332e86998d5e65fa02ca726c78a936 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/serve/_private/http_util.py @@ -0,0 +1,540 @@ +import asyncio +import inspect +import json +import logging +import pickle +import socket +from collections import deque +from dataclasses import dataclass +from typing import Any, Awaitable, Callable, List, Optional, Tuple, Type + +import starlette +from fastapi.encoders import jsonable_encoder +from starlette.types import ASGIApp, Message, Receive, Scope, Send +from uvicorn.config import Config +from uvicorn.lifespan.on import LifespanOn + +from ray._private.pydantic_compat import IS_PYDANTIC_2 +from ray.serve._private.common import RequestMetadata +from ray.serve._private.constants import SERVE_LOGGER_NAME +from ray.serve._private.utils import serve_encoders +from ray.serve.exceptions import RayServeException + +logger = logging.getLogger(SERVE_LOGGER_NAME) + + +@dataclass(frozen=True) +class ASGIArgs: + scope: Scope + receive: Receive + send: Send + + def to_args_tuple(self) -> Tuple[Scope, Receive, Send]: + return (self.scope, self.receive, self.send) + + def to_starlette_request(self) -> starlette.requests.Request: + return 
starlette.requests.Request( + *self.to_args_tuple(), + ) + + +def make_buffered_asgi_receive(serialized_body: bytes) -> Receive: + """Returns an ASGI receiver that returns the provided buffered body.""" + + # Simulates receiving HTTP body from TCP socket. In reality, the body has + # already been streamed in chunks and stored in serialized_body. + received = False + + async def mock_receive(): + nonlocal received + + # If the request has already been received, starlette will keep polling + # for HTTP disconnect. We will pause forever. The coroutine should be + # cancelled by starlette after the response has been sent. + if received: + block_forever = asyncio.Event() + await block_forever.wait() + + received = True + return {"body": serialized_body, "type": "http.request", "more_body": False} + + return mock_receive + + +def convert_object_to_asgi_messages( + obj: Optional[Any] = None, status_code: int = 200 +) -> List[Message]: + """Serializes the provided object and converts it to ASGI messages. + + These ASGI messages can be sent via an ASGI `send` interface to comprise an HTTP + response. + """ + body = None + content_type = None + if obj is None: + body = b"" + content_type = b"text/plain" + elif isinstance(obj, bytes): + body = obj + content_type = b"text/plain" + elif isinstance(obj, str): + body = obj.encode("utf-8") + content_type = b"text/plain; charset=utf-8" + else: + # `separators=(",", ":")` will remove all whitespaces between separators in the + # json string and return a minimized json string. This helps to reduce the size + # of the response similar to Starlette's JSONResponse. 
+ body = json.dumps( + jsonable_encoder(obj, custom_encoder=serve_encoders), + separators=(",", ":"), + ).encode() + content_type = b"application/json" + + return [ + { + "type": "http.response.start", + "status": status_code, + "headers": [[b"content-type", content_type]], + }, + {"type": "http.response.body", "body": body}, + ] + + +class Response: + """ASGI compliant response class. + + It is expected to be called in async context and pass along + `scope, receive, send` as in ASGI spec. + + >>> from ray.serve.http_util import Response # doctest: +SKIP + >>> scope, receive = ... # doctest: +SKIP + >>> await Response({"k": "v"}).send(scope, receive, send) # doctest: +SKIP + """ + + def __init__(self, content=None, status_code=200): + """Construct a HTTP Response based on input type. + + Args: + content: Any JSON serializable object. + status_code (int, optional): Default status code is 200. + """ + self._messages = convert_object_to_asgi_messages( + obj=content, + status_code=status_code, + ) + + async def send(self, scope, receive, send): + for message in self._messages: + await send(message) + + +async def receive_http_body(scope, receive, send): + body_buffer = [] + more_body = True + while more_body: + message = await receive() + assert message["type"] == "http.request" + + more_body = message["more_body"] + body_buffer.append(message["body"]) + + return b"".join(body_buffer) + + +class MessageQueue(Send): + """Queue enables polling for received or sent messages. + + Implements the ASGI `Send` interface. + + This class: + - Is *NOT* thread safe and should only be accessed from a single asyncio + event loop. + - Assumes a single consumer of the queue (concurrent calls to + `get_messages_nowait` and `wait_for_message` is undefined behavior). + """ + + def __init__(self): + self._message_queue = deque() + self._new_message_event = asyncio.Event() + self._closed = False + self._error = None + + def close(self): + """Close the queue, rejecting new messages. 
+ + Once the queue is closed, existing messages will be returned from + `get_messages_nowait` and subsequent calls to `wait_for_message` will + always return immediately. + """ + self._closed = True + self._new_message_event.set() + + def set_error(self, e: BaseException): + self._error = e + + def put_nowait(self, message: Message): + self._message_queue.append(message) + self._new_message_event.set() + + async def __call__(self, message: Message): + """Send a message, putting it on the queue. + + `RuntimeError` is raised if the queue has been closed using `.close()`. + """ + if self._closed: + raise RuntimeError("New messages cannot be sent after the queue is closed.") + + self.put_nowait(message) + + async def wait_for_message(self): + """Wait until at least one new message is available. + + If a message is available, this method will return immediately on each call + until `get_messages_nowait` is called. + + After the queue is closed using `.close()`, this will always return + immediately. + """ + if not self._closed: + await self._new_message_event.wait() + + def get_messages_nowait(self) -> List[Message]: + """Returns all messages that are currently available (non-blocking). + + At least one message will be present if `wait_for_message` had previously + returned and a subsequent call to `wait_for_message` blocks until at + least one new message is available. + """ + messages = [] + while len(self._message_queue) > 0: + messages.append(self._message_queue.popleft()) + + self._new_message_event.clear() + return messages + + async def get_one_message(self) -> Message: + """This blocks until a message is ready. + + This method should not be used together with get_messages_nowait. + Please use either `get_one_message` or `get_messages_nowait`. + + Raises: + StopAsyncIteration: if the queue is closed and there are no + more messages. + Exception (self._error): if there are no more messages in + the queue and an error has been set. 
+ """ + + if self._error: + raise self._error + + await self._new_message_event.wait() + + if len(self._message_queue) > 0: + msg = self._message_queue.popleft() + + if len(self._message_queue) == 0 and not self._closed: + self._new_message_event.clear() + + return msg + elif len(self._message_queue) == 0 and self._error: + raise self._error + elif len(self._message_queue) == 0 and self._closed: + raise StopAsyncIteration + + +class ASGIReceiveProxy: + """Proxies ASGI receive from an actor. + + The `receive_asgi_messages` callback will be called repeatedly to fetch messages + until a disconnect message is received. + """ + + def __init__( + self, + scope: Scope, + request_metadata: RequestMetadata, + receive_asgi_messages: Callable[[RequestMetadata], Awaitable[bytes]], + ): + self._type = scope["type"] # Either 'http' or 'websocket'. + self._queue = asyncio.Queue() + self._request_metadata = request_metadata + self._receive_asgi_messages = receive_asgi_messages + self._disconnect_message = None + + def _get_default_disconnect_message(self) -> Message: + """Return the appropriate disconnect message based on the connection type. + + HTTP ASGI spec: + https://asgi.readthedocs.io/en/latest/specs/www.html#disconnect-receive-event + + WS ASGI spec: + https://asgi.readthedocs.io/en/latest/specs/www.html#disconnect-receive-event-ws + """ + if self._type == "websocket": + return { + "type": "websocket.disconnect", + # 1005 is the default disconnect code according to the ASGI spec. + "code": 1005, + } + else: + return {"type": "http.disconnect"} + + async def fetch_until_disconnect(self): + """Fetch messages repeatedly until a disconnect message is received. + + If a disconnect message is received, this function exits and returns it. + + If an exception occurs, it will be raised on the next __call__ and no more + messages will be received. 
+ """ + while True: + try: + pickled_messages = await self._receive_asgi_messages( + self._request_metadata + ) + for message in pickle.loads(pickled_messages): + self._queue.put_nowait(message) + + if message["type"] in {"http.disconnect", "websocket.disconnect"}: + self._disconnect_message = message + return + except KeyError: + # KeyError can be raised if the request is no longer active in the proxy + # (i.e., the user disconnects). This is expected behavior and we should + # not log an error: https://github.com/ray-project/ray/issues/43290. + message = self._get_default_disconnect_message() + self._queue.put_nowait(message) + self._disconnect_message = message + return + except Exception as e: + # Raise unexpected exceptions in the next `__call__`. + self._queue.put_nowait(e) + return + + async def __call__(self) -> Message: + """Return the next message once available. + + This will repeatedly return a disconnect message once it's been received. + """ + if self._queue.empty() and self._disconnect_message is not None: + return self._disconnect_message + + message = await self._queue.get() + if isinstance(message, Exception): + raise message + + return message + + +def make_fastapi_class_based_view(fastapi_app, cls: Type) -> None: + """Transform the `cls`'s methods and class annotations to FastAPI routes. + + Modified from + https://github.com/dmontagu/fastapi-utils/blob/master/fastapi_utils/cbv.py + + Usage: + >>> from fastapi import FastAPI + >>> app = FastAPI() # doctest: +SKIP + >>> class A: # doctest: +SKIP + ... @app.route("/{i}") # doctest: +SKIP + ... def func(self, i: int) -> str: # doctest: +SKIP + ... return self.dep + i # doctest: +SKIP + >>> # just running the app won't work, here. + >>> make_fastapi_class_based_view(app, A) # doctest: +SKIP + >>> # now app can be run properly + """ + # Delayed import to prevent ciruclar imports in workers. 
+ from fastapi import APIRouter, Depends + from fastapi.routing import APIRoute, APIWebSocketRoute + + def get_current_servable_instance(): + from ray import serve + + return serve.get_replica_context().servable_object + + # Find all the class method routes + class_method_routes = [ + route + for route in fastapi_app.routes + if + # User defined routes must all be APIRoute or APIWebSocketRoute. + isinstance(route, (APIRoute, APIWebSocketRoute)) + # We want to find the route that's bound to the `cls`. + # NOTE(simon): we can't use `route.endpoint in inspect.getmembers(cls)` + # because the FastAPI supports different routes for the methods with + # same name. See #17559. + and (cls.__qualname__ in route.endpoint.__qualname__) + ] + + # Modify these routes and mount it to a new APIRouter. + # We need to to this (instead of modifying in place) because we want to use + # the laster fastapi_app.include_router to re-run the dependency analysis + # for each routes. + new_router = APIRouter() + for route in class_method_routes: + fastapi_app.routes.remove(route) + + # This block just adds a default values to the self parameters so that + # FastAPI knows to inject the object when calling the route. + # Before: def method(self, i): ... + # After: def method(self=Depends(...), *, i):... + old_endpoint = route.endpoint + old_signature = inspect.signature(old_endpoint) + old_parameters = list(old_signature.parameters.values()) + if len(old_parameters) == 0: + # TODO(simon): make it more flexible to support no arguments. + raise RayServeException( + "Methods in FastAPI class-based view must have ``self`` as " + "their first argument." + ) + old_self_parameter = old_parameters[0] + new_self_parameter = old_self_parameter.replace( + default=Depends(get_current_servable_instance) + ) + new_parameters = [new_self_parameter] + [ + # Make the rest of the parameters keyword only because + # the first argument is no longer positional. 
def set_socket_reuse_port(sock: socket.socket) -> bool:
    """Mutate a socket object to allow multiple process listening on the same port.

    Returns:
        success: whether the setting was successful.
    """
    # Raw option number from the Linux kernel headers, used as a fallback when
    # the Python build was compiled without the SO_REUSEPORT constant (e.g.
    # some conda distributions expose it at runtime but not at build time).
    # https://github.com/torvalds/linux/blob/master/tools/include/uapi/asm-generic/socket.h#L27
    LINUX_SO_REUSEPORT = 15
    try:
        # Allow multiple processes to bind the same address/port; the kernel
        # evenly load-balances among the listeners. Note: Linux only.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        reuse_port_opt = getattr(socket, "SO_REUSEPORT", LINUX_SO_REUSEPORT)
        sock.setsockopt(socket.SOL_SOCKET, reuse_port_opt, 1)
        return True
    except Exception as e:
        logger.debug(
            f"Setting SO_REUSEPORT failed because of {e}. SO_REUSEPORT is disabled."
        )
        return False
def validate_http_proxy_callback_return(
    middlewares: Any,
) -> "list[starlette.middleware.Middleware]":
    """Validate the return value of HTTP proxy callback.

    Middlewares should be a list of Starlette middlewares. If it is None, we
    will treat it as an empty list. If it is not a list, we will raise an
    error. If it is a list, we will check if all the items in the list are
    Starlette middlewares.

    Args:
        middlewares: the raw value returned by the user-provided callback.

    Returns:
        The validated list of middlewares ([] when the callback returned None).

    Raises:
        ValueError: if the value is not a list, or a list item is not a
            Starlette middleware.
    """
    # NOTE: the original return annotation was the list literal
    # `[starlette.middleware.Middleware]`, which is not a valid type and was
    # evaluated at import time; it is now a proper (deferred) annotation.
    if middlewares is None:
        return []
    if not isinstance(middlewares, list):
        raise ValueError(
            "HTTP proxy callback must return a list of Starlette middlewares."
        )
    # All middlewares must be Starlette middlewares.
    # https://www.starlette.io/middleware/#using-pure-asgi-middleware
    for middleware in middlewares:
        # `isinstance` replaces the original `issubclass(type(...), ...)`,
        # which is equivalent but non-idiomatic.
        if not isinstance(middleware, starlette.middleware.Middleware):
            raise ValueError(
                "HTTP proxy callback must return a list of Starlette middlewares, "
                f"instead got {type(middleware)} type item in the list."
            )
    return middlewares
+ """ + + def __init__( + self, + component_name: str, + component_id: str, + component_type: Optional[ServeComponentType] = None, + ): + self.component_name = component_name + self.component_id = component_id + self.component_type = component_type + + def filter(self, record: logging.LogRecord) -> bool: + """Add component attributes to the log record. + + Note: the filter doesn't do any filtering, it only adds the component + attributes. + """ + if self.component_type and self.component_type == ServeComponentType.REPLICA: + setattr(record, SERVE_LOG_DEPLOYMENT, self.component_name) + setattr(record, SERVE_LOG_REPLICA, self.component_id) + setattr(record, SERVE_LOG_COMPONENT, self.component_type) + else: + setattr(record, SERVE_LOG_COMPONENT, self.component_name) + setattr(record, SERVE_LOG_COMPONENT_ID, self.component_id) + + return True + + +class ServeContextFilter(logging.Filter): + """Serve context filter. + + The filter will add the route, request id, app name to the log record. + + Note: the filter doesn't do any filtering, it only adds the serve request context + attributes. + """ + + def filter(self, record): + request_context = ray.serve.context._get_serve_request_context() + if request_context.route: + setattr(record, SERVE_LOG_ROUTE, request_context.route) + if request_context.request_id: + setattr(record, SERVE_LOG_REQUEST_ID, request_context.request_id) + if request_context.app_name: + setattr(record, SERVE_LOG_APPLICATION, request_context.app_name) + return True + + +class ServeLogAttributeRemovalFilter(logging.Filter): + """Serve log attribute removal filter. + + The filter will remove unwanted attributes on the log record so they won't be + included in the structured logs. + + Note: the filter doesn't do any filtering, it only removes unwanted attributes. 
+ """ + + def filter(self, record): + for key in SERVE_LOG_UNWANTED_ATTRS: + if hasattr(record, key): + delattr(record, key) + + return True + + +class ServeFormatter(logging.Formatter): + """Serve Logging Formatter + + The formatter will generate the log format on the fly based on the field of record. + """ + + COMPONENT_LOG_FMT = f"%({SERVE_LOG_LEVEL_NAME})s %({SERVE_LOG_TIME})s {{{SERVE_LOG_COMPONENT}}} {{{SERVE_LOG_COMPONENT_ID}}} " # noqa:E501 + + def __init__( + self, + component_name: str, + component_id: str, + ): + self.component_log_fmt = ServeFormatter.COMPONENT_LOG_FMT.format( + component_name=component_name, component_id=component_id + ) + + def format(self, record: logging.LogRecord) -> str: + """Format the log record into the format string. + + Args: + record: The log record to be formatted. + + Returns: + The formatted log record in string format. + """ + record_format = self.component_log_fmt + record_formats_attrs = [] + if SERVE_LOG_REQUEST_ID in record.__dict__: + record_formats_attrs.append(SERVE_LOG_RECORD_FORMAT[SERVE_LOG_REQUEST_ID]) + record_formats_attrs.append(SERVE_LOG_RECORD_FORMAT[SERVE_LOG_MESSAGE]) + record_format += " ".join(record_formats_attrs) + + # create a formatter using the format string + formatter = logging.Formatter(record_format) + + # format the log record using the formatter + return formatter.format(record) + + +def access_log_msg(*, method: str, route: str, status: str, latency_ms: float): + """Returns a formatted message for an HTTP or ServeHandle access log.""" + return f"{method} {route} {status} {latency_ms:.1f}ms" + + +def log_to_stderr_filter(record: logging.LogRecord) -> bool: + """Filters log records based on a parameter in the `extra` dictionary.""" + if not hasattr(record, "log_to_stderr") or record.log_to_stderr is None: + return True + + return record.log_to_stderr + + +def log_access_log_filter(record: logging.LogRecord) -> bool: + """Filters ray serve access log based on 'serve_access_log' key in `extra` 
dict.""" + if not hasattr(record, "serve_access_log") or record.serve_access_log is None: + return True + + return not record.serve_access_log + + +def get_component_logger_file_path() -> Optional[str]: + """Returns the relative file path for the Serve logger, if it exists. + + If a logger was configured through configure_component_logger() for the Serve + component that's calling this function, this returns the location of the log file + relative to the ray logs directory. + """ + logger = logging.getLogger(SERVE_LOGGER_NAME) + for handler in logger.handlers: + if isinstance(handler, logging.handlers.RotatingFileHandler): + absolute_path = handler.baseFilename + ray_logs_dir = ray._private.worker._global_node.get_logs_dir_path() + if absolute_path.startswith(ray_logs_dir): + return absolute_path[len(ray_logs_dir) :] + + +class StreamToLogger(object): + """ + Fake file-like stream object that redirects writes to a logger instance. + + This comes from https://stackoverflow.com/a/36296215 directly. + """ + + def __init__(self, logger: logging.Logger, log_level: int, original_object: Any): + self._logger = logger + self._log_level = log_level + self._original_object = original_object + self._linebuf = "" + + def __getattr__(self, attr: str) -> Any: + # getting attributes from the original object + return getattr(self._original_object, attr) + + @staticmethod + def get_stacklevel() -> int: + """Rewind stack to get the stacklevel for the user code. + + Going from the back of the traceback and traverse until it's no longer in + logging_utils.py or site-packages. 
+ """ + reverse_traces = traceback.extract_stack()[::-1] + for index, trace in enumerate(reverse_traces): + if ( + "logging_utils.py" not in trace.filename + and "site-packages" not in trace.filename + ): + return index + return 1 + + def write(self, buf: str): + temp_linebuf = self._linebuf + buf + self._linebuf = "" + for line in temp_linebuf.splitlines(True): + # From the io.TextIOWrapper docs: + # On output, if newline is None, any '\n' characters written + # are translated to the system default line separator. + # By default sys.stdout.write() expects '\n' newlines and then + # translates them so this is still cross-platform. + if line[-1] == "\n": + self._logger.log( + self._log_level, + line.rstrip(), + stacklevel=self.get_stacklevel(), + ) + else: + self._linebuf += line + + def flush(self): + if self._linebuf != "": + self._logger.log( + self._log_level, + self._linebuf.rstrip(), + stacklevel=self.get_stacklevel(), + ) + self._linebuf = "" + + def isatty(self) -> bool: + return True + + +def redirected_print(*objects, sep=" ", end="\n", file=None, flush=False): + """Implement python's print function to redirect logs to Serve's logger. + + If the file is set to anything other than stdout, stderr, or None, call the + builtin print. Else, construct the message and redirect to Serve's logger. + + See https://docs.python.org/3/library/functions.html#print + """ + if file not in [sys.stdout, sys.stderr, None]: + return buildin_print(objects, sep=sep, end=end, file=file, flush=flush) + + serve_logger = logging.getLogger(SERVE_LOGGER_NAME) + message = sep.join(map(str, objects)) + end + # We monkey patched print function, so this is always at stack level 2. 
+ serve_logger.log(logging.INFO, message, stacklevel=2) + + +def configure_component_logger( + *, + component_name: str, + component_id: str, + logging_config: LoggingConfig, + component_type: Optional[ServeComponentType] = None, + max_bytes: Optional[int] = None, + backup_count: Optional[int] = None, + stream_handler_only: bool = False, +): + """Configure a logger to be used by a Serve component. + + The logger will log using a standard format to make components identifiable + using the provided name and unique ID for this instance (e.g., replica ID). + + This logger will *not* propagate its log messages to the parent logger(s). + """ + logger = logging.getLogger(SERVE_LOGGER_NAME) + logger.propagate = False + logger.setLevel(logging_config.log_level) + logger.handlers.clear() + + # Only add stream handler if RAY_SERVE_LOG_TO_STDERR is True or if + # `stream_handler_only` is set to True. + if RAY_SERVE_LOG_TO_STDERR or stream_handler_only: + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(ServeFormatter(component_name, component_id)) + stream_handler.addFilter(log_to_stderr_filter) + stream_handler.addFilter(ServeContextFilter()) + logger.addHandler(stream_handler) + + # Skip setting up file handler and stdout/stderr redirect if `stream_handler_only` + # is set to True. Logger such as default serve logger can be configured outside the + # context of a Serve component, we don't want those logs to redirect into serve's + # logger and log files. 
+ if stream_handler_only: + return + + if logging_config.logs_dir: + logs_dir = logging_config.logs_dir + else: + logs_dir = get_serve_logs_dir() + os.makedirs(logs_dir, exist_ok=True) + + if max_bytes is None: + max_bytes = ray._private.worker._global_node.max_bytes + if backup_count is None: + backup_count = ray._private.worker._global_node.backup_count + + log_file_name = get_component_file_name( + component_name=component_name, + component_id=component_id, + component_type=component_type, + suffix=".log", + ) + + file_handler = logging.handlers.RotatingFileHandler( + os.path.join(logs_dir, log_file_name), + maxBytes=max_bytes, + backupCount=backup_count, + ) + if RAY_SERVE_ENABLE_JSON_LOGGING: + logger.warning( + "'RAY_SERVE_ENABLE_JSON_LOGGING' is deprecated, please use " + "'LoggingConfig' to enable json format." + ) + if RAY_SERVE_ENABLE_JSON_LOGGING or logging_config.encoding == EncodingType.JSON: + file_handler.addFilter(CoreContextFilter()) + file_handler.addFilter(ServeContextFilter()) + file_handler.addFilter( + ServeComponentFilter(component_name, component_id, component_type) + ) + file_handler.setFormatter(JSONFormatter()) + else: + file_handler.setFormatter(ServeFormatter(component_name, component_id)) + + if logging_config.enable_access_log is False: + file_handler.addFilter(log_access_log_filter) + + # Remove unwanted attributes from the log record. + file_handler.addFilter(ServeLogAttributeRemovalFilter()) + + # Redirect print, stdout, and stderr to Serve logger, only when it's on the replica. 
def configure_component_memory_profiler(
    component_name: str,
    component_id: str,
    component_type: Optional[ServeComponentType] = None,
):
    """Configures the memory logger for this component.

    Does nothing if RAY_SERVE_ENABLE_MEMORY_PROFILING is disabled.
    """
    if not RAY_SERVE_ENABLE_MEMORY_PROFILING:
        return

    logger = logging.getLogger(SERVE_LOGGER_NAME)

    try:
        import memray

        logs_dir = get_serve_logs_dir()

        # memray requires a file that does not exist yet, so after an actor
        # restart probe _memray_0.bin, _memray_1.bin, ... until a free name
        # is found.
        attempt = 0
        while True:
            memray_file_path = os.path.join(
                logs_dir,
                get_component_file_name(
                    component_name=component_name,
                    component_id=component_id,
                    component_type=component_type,
                    suffix=f"_memray_{attempt}.bin",
                ),
            )
            if not os.path.exists(memray_file_path):
                break
            attempt += 1

        # Memray usually tracks the memory usage of only a block of code
        # within a context manager. We explicitly call __enter__ here
        # instead of using a context manager to track memory usage across
        # all of the caller's code instead.
        memray.Tracker(memray_file_path, native_traces=True).__enter__()

        logger.info(
            "RAY_SERVE_ENABLE_MEMORY_PROFILING is enabled. Started a "
            "memray tracker on this actor. Tracker file located at "
            f'"{memray_file_path}"'
        )

    except ImportError:
        logger.warning(
            "RAY_SERVE_ENABLE_MEMORY_PROFILING is enabled, but memray "
            "is not installed. No memory profiling is happening. "
            "`pip install memray` to enable memory profiling."
        )
class LoggingContext:
    """
    Context manager to manage logging behaviors within a particular block, such as:
    1) Overriding logging level

    Source (python3 official documentation)
    https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging # noqa: E501
    """

    def __init__(self, logger, level=None):
        self.logger = logger
        self.level = level

    def __enter__(self):
        # A None level means "leave the logger untouched".
        if self.level is None:
            return
        self.old_level = self.logger.level
        self.logger.setLevel(self.level)

    def __exit__(self, et, ev, tb):
        if self.level is None:
            return
        # Restore the level captured on entry, even if the block raised.
        self.logger.setLevel(self.old_level)
class LongPollNamespace(Enum):
    """Keys under which the long poll host publishes state snapshots."""

    DEPLOYMENT_TARGETS = auto()
    ROUTE_TABLE = auto()
    GLOBAL_LOGGING_CONFIG = auto()
    DEPLOYMENT_CONFIG = auto()

    def __repr__(self):
        # e.g. "LongPollNamespace.ROUTE_TABLE" instead of the default
        # "<LongPollNamespace.ROUTE_TABLE: 2>".
        return f"{self.__class__.__name__}.{self.name}"
+# async def update_state(updated_object: Any): +# do_something(updated_object) +UpdateStateCallable = Callable[[Any], None] +KeyType = Union[str, LongPollNamespace, Tuple[LongPollNamespace, str]] + + +class LongPollState(Enum): + TIME_OUT = auto() + + +class LongPollClient: + """The asynchronous long polling client. + + Args: + host_actor: handle to actor embedding LongPollHost. + key_listeners: a dictionary mapping keys to + callbacks to be called on state update for the corresponding keys. + call_in_event_loop: an asyncio event loop + to post the callback into. + """ + + def __init__( + self, + host_actor, + key_listeners: Dict[KeyType, UpdateStateCallable], + call_in_event_loop: AbstractEventLoop, + ) -> None: + # We used to allow this to be optional, but due to Ray Client issue + # we now enforce all long poll client to post callback to event loop + # See https://github.com/ray-project/ray/issues/20971 + assert call_in_event_loop is not None + + self.host_actor = host_actor + self.key_listeners = key_listeners + self.event_loop = call_in_event_loop + self.snapshot_ids: Dict[KeyType, int] = { + # The initial snapshot id for each key is < 0, + # but real snapshot keys in the long poll host are always >= 0, + # so this will always trigger an initial update. + key: -1 + for key in self.key_listeners.keys() + } + self.is_running = True + + self._poll_next() + + def stop(self) -> None: + """Stop the long poll client after the next RPC returns.""" + self.is_running = False + + def add_key_listeners( + self, key_listeners: Dict[KeyType, UpdateStateCallable] + ) -> None: + """Add more key listeners to the client. + The new listeners will only be included in the *next* long poll request; + the current request will continue with the existing listeners. + + If a key is already in the client, the new listener will replace the old one, + but the snapshot ID will be preserved, so the new listener will only be called + on the *next* update to that key. 
+ """ + # Only initialize snapshot ids for *new* keys. + self.snapshot_ids.update( + {key: -1 for key in key_listeners.keys() if key not in self.key_listeners} + ) + self.key_listeners.update(key_listeners) + + def _on_callback_completed(self, trigger_at: int): + """Called after a single callback is completed. + + When the total number of callback completed equals to trigger_at, + _poll_next() will be called. This is designed to make sure we only + _poll_next() after all the state callbacks completed. This is a + way to serialize the callback invocations between object versions. + """ + self._callbacks_processed_count += 1 + if self._callbacks_processed_count == trigger_at: + self._poll_next() + + def _poll_next(self): + """Poll the update. The callback is expected to scheduler another + _poll_next call. + """ + if not self.is_running: + return + + self._callbacks_processed_count = 0 + self._current_ref = self.host_actor.listen_for_change.remote(self.snapshot_ids) + self._current_ref._on_completed(lambda update: self._process_update(update)) + + def _schedule_to_event_loop(self, callback): + # Schedule the next iteration only if the loop is running. + # The event loop might not be running if users used a cached + # version across loops. + if self.event_loop.is_running(): + self.event_loop.call_soon_threadsafe(callback) + else: + logger.error("The event loop is closed, shutting down long poll client.") + self.is_running = False + + def _process_update(self, updates: Dict[str, UpdatedObject]): + if isinstance(updates, (ray.exceptions.RayActorError)): + # This can happen during shutdown where the controller is + # intentionally killed, the client should just gracefully + # exit. + logger.debug("LongPollClient failed to connect to host. 
Shutting down.") + self.is_running = False + return + + if isinstance(updates, ConnectionError): + logger.warning("LongPollClient connection failed, shutting down.") + self.is_running = False + return + + if isinstance(updates, (ray.exceptions.RayTaskError)): + # Some error happened in the controller. It could be a bug or + # some undesired state. + logger.error("LongPollHost errored\n" + updates.traceback_str) + # We must call this in event loop so it works in Ray Client. + # See https://github.com/ray-project/ray/issues/20971 + self._schedule_to_event_loop(self._poll_next) + return + + if updates == LongPollState.TIME_OUT: + logger.debug("LongPollClient polling timed out. Retrying.") + self._schedule_to_event_loop(self._poll_next) + return + + logger.debug( + f"LongPollClient {self} received updates for keys: " + f"{list(updates.keys())}.", + extra={"log_to_stderr": False}, + ) + if not updates: # no updates, no callbacks to run, just poll again + self._schedule_to_event_loop(self._poll_next) + for key, update in updates.items(): + self.snapshot_ids[key] = update.snapshot_id + callback = self.key_listeners[key] + + # Bind the parameters because closures are late-binding. + # https://docs.python-guide.org/writing/gotchas/#late-binding-closures # noqa: E501 + def chained(callback=callback, arg=update.object_snapshot): + callback(arg) + self._on_callback_completed(trigger_at=len(updates)) + + self._schedule_to_event_loop(chained) + + +class LongPollHost: + """The server side object that manages long pulling requests. + + The desired use case is to embed this in an Ray actor. Client will be + expected to call actor.listen_for_change.remote(...). On the host side, + you can call host.notify_changed({key: object}) to update the state and + potentially notify whoever is polling for these values. + + Internally, we use snapshot_ids for each object to identify client with + outdated object and immediately return the result. 
If the client has the + up-to-date version, then the listen_for_change call will only return when + the object is updated. + """ + + def __init__( + self, + listen_for_change_request_timeout_s: Tuple[ + int, int + ] = LISTEN_FOR_CHANGE_REQUEST_TIMEOUT_S, + ): + # Map object_key -> int + self.snapshot_ids: Dict[KeyType, int] = {} + # Map object_key -> object + self.object_snapshots: Dict[KeyType, Any] = {} + # Map object_key -> set(asyncio.Event waiting for updates) + self.notifier_events: DefaultDict[KeyType, Set[asyncio.Event]] = defaultdict( + set + ) + + self._listen_for_change_request_timeout_s = listen_for_change_request_timeout_s + self.transmission_counter = metrics.Counter( + "serve_long_poll_host_transmission_counter", + description="The number of times the long poll host transmits data.", + tag_keys=("namespace_or_state",), + ) + + def _get_num_notifier_events(self, key: Optional[KeyType] = None): + """Used for testing.""" + if key is not None: + return len(self.notifier_events[key]) + else: + return sum(len(events) for events in self.notifier_events.values()) + + def _count_send( + self, timeout_or_data: Union[LongPollState, Dict[KeyType, UpdatedObject]] + ): + """Helper method that tracks the data sent by listen_for_change. + + Records number of times long poll host sends data in the + ray_serve_long_poll_host_send_counter metric. + """ + + if isinstance(timeout_or_data, LongPollState): + # The only LongPollState is TIME_OUT– the long poll + # connection has timed out. + self.transmission_counter.inc( + value=1, tags={"namespace_or_state": "TIMEOUT"} + ) + else: + data = timeout_or_data + for key in data.keys(): + self.transmission_counter.inc( + value=1, tags={"namespace_or_state": str(key)} + ) + + async def listen_for_change( + self, + keys_to_snapshot_ids: Dict[KeyType, int], + ) -> Union[LongPollState, Dict[KeyType, UpdatedObject]]: + """Listen for changed objects. + + This method will return a dictionary of updated objects. 
It returns + immediately if any of the snapshot_ids are outdated, + otherwise it will block until there's an update. + """ + # If there are no keys to listen for, + # just wait for a short time to provide backpressure, + # then return an empty update. + if not keys_to_snapshot_ids: + await sleep(1) + + updated_objects = {} + self._count_send(updated_objects) + return updated_objects + + # If there are any keys with outdated snapshot ids, + # return their updated values immediately. + updated_objects = {} + for key, client_snapshot_id in keys_to_snapshot_ids.items(): + try: + existing_id = self.snapshot_ids[key] + except KeyError: + # The caller may ask for keys that we don't know about (yet), + # just ignore them. + # This can happen when, for example, + # a deployment handle is manually created for an app + # that hasn't been deployed yet (by bypassing the safety checks). + continue + + if existing_id != client_snapshot_id: + updated_objects[key] = UpdatedObject( + self.object_snapshots[key], existing_id + ) + if len(updated_objects) > 0: + self._count_send(updated_objects) + return updated_objects + + # Otherwise, register asyncio events to be waited. + async_task_to_events = {} + async_task_to_watched_keys = {} + for key in keys_to_snapshot_ids.keys(): + # Create a new asyncio event for this key. + event = asyncio.Event() + + # Make sure future caller of notify_changed will unblock this + # asyncio Event. 
+ self.notifier_events[key].add(event) + + task = get_or_create_event_loop().create_task(event.wait()) + async_task_to_events[task] = event + async_task_to_watched_keys[task] = key + + done, not_done = await asyncio.wait( + async_task_to_watched_keys.keys(), + return_when=asyncio.FIRST_COMPLETED, + timeout=random.uniform(*self._listen_for_change_request_timeout_s), + ) + + for task in not_done: + task.cancel() + try: + event = async_task_to_events[task] + self.notifier_events[async_task_to_watched_keys[task]].remove(event) + except KeyError: + # Because we use `FIRST_COMPLETED` above, a task in `not_done` may + # actually have had its event removed in `notify_changed`. + pass + + if len(done) == 0: + self._count_send(LongPollState.TIME_OUT) + return LongPollState.TIME_OUT + else: + updated_objects = {} + for task in done: + updated_object_key = async_task_to_watched_keys[task] + updated_objects[updated_object_key] = UpdatedObject( + self.object_snapshots[updated_object_key], + self.snapshot_ids[updated_object_key], + ) + self._count_send(updated_objects) + return updated_objects + + async def listen_for_change_java( + self, + keys_to_snapshot_ids_bytes: bytes, + ) -> bytes: + """Listen for changed objects. only call by java proxy/router now. + Args: + keys_to_snapshot_ids_bytes (Dict[str, int]): the protobuf bytes of + keys_to_snapshot_ids (Dict[str, int]). 
+ """ + request_proto = LongPollRequest.FromString(keys_to_snapshot_ids_bytes) + keys_to_snapshot_ids = { + self._parse_xlang_key(xlang_key): snapshot_id + for xlang_key, snapshot_id in request_proto.keys_to_snapshot_ids.items() + } + keys_to_updated_objects = await self.listen_for_change(keys_to_snapshot_ids) + return self._listen_result_to_proto_bytes(keys_to_updated_objects) + + def _parse_poll_namespace(self, name: str): + if name == LongPollNamespace.ROUTE_TABLE.name: + return LongPollNamespace.ROUTE_TABLE + elif name == LongPollNamespace.DEPLOYMENT_TARGETS.name: + return LongPollNamespace.DEPLOYMENT_TARGETS + else: + return name + + def _parse_xlang_key(self, xlang_key: str) -> KeyType: + if xlang_key is None: + raise ValueError("func _parse_xlang_key: xlang_key is None") + if xlang_key.startswith("(") and xlang_key.endswith(")"): + fields = xlang_key[1:-1].split(",") + if len(fields) == 2: + enum_field = self._parse_poll_namespace(fields[0].strip()) + if isinstance(enum_field, LongPollNamespace): + return enum_field, fields[1].strip() + else: + return self._parse_poll_namespace(xlang_key) + raise ValueError("can not parse key type from xlang_key {}".format(xlang_key)) + + def _build_xlang_key(self, key: KeyType) -> str: + if isinstance(key, tuple): + return "(" + key[0].name + "," + key[1] + ")" + elif isinstance(key, LongPollNamespace): + return key.name + else: + return key + + def _object_snapshot_to_proto_bytes( + self, key: KeyType, object_snapshot: Any + ) -> bytes: + if key == LongPollNamespace.ROUTE_TABLE: + # object_snapshot is Dict[DeploymentID, EndpointInfo] + # NOTE(zcin): the endpoint dictionary broadcasted to Java + # HTTP proxies should use string as key because Java does + # not yet support 2.x or applications + xlang_endpoints = { + str(endpoint_tag): EndpointInfoProto(route=endpoint_info.route) + for endpoint_tag, endpoint_info in object_snapshot.items() + } + return EndpointSet(endpoints=xlang_endpoints).SerializeToString() + elif 
isinstance(key, tuple) and key[0] == LongPollNamespace.DEPLOYMENT_TARGETS: + # object_snapshot.running_replicas is List[RunningReplicaInfo] + actor_name_list = [ + replica_info.replica_id.to_full_id_str() + for replica_info in object_snapshot.running_replicas + ] + return DeploymentTargetInfo( + replica_names=actor_name_list, + is_available=object_snapshot.is_available, + ).SerializeToString() + else: + return str.encode(str(object_snapshot)) + + def _listen_result_to_proto_bytes( + self, keys_to_updated_objects: Dict[KeyType, UpdatedObject] + ) -> bytes: + xlang_keys_to_updated_objects = { + self._build_xlang_key(key): UpdatedObjectProto( + snapshot_id=updated_object.snapshot_id, + object_snapshot=self._object_snapshot_to_proto_bytes( + key, updated_object.object_snapshot + ), + ) + for key, updated_object in keys_to_updated_objects.items() + } + data = { + "updated_objects": xlang_keys_to_updated_objects, + } + proto = LongPollResult(**data) + return proto.SerializeToString() + + def notify_changed(self, updates: Mapping[KeyType, Any]) -> None: + """ + Update the current snapshot of some objects + and notify any long poll clients. + """ + for object_key, updated_object in updates.items(): + try: + self.snapshot_ids[object_key] += 1 + except KeyError: + # Initial snapshot id must be >= 0, so that the long poll client + # can send a negative initial snapshot id to get a fast update. 
+ # They should also be randomized; see + # https://github.com/ray-project/ray/pull/45881#discussion_r1645243485 + self.snapshot_ids[object_key] = random.randint(0, 1_000_000) + self.object_snapshots[object_key] = updated_object + logger.debug(f"LongPollHost: Notify change for key {object_key}.") + + for event in self.notifier_events.pop(object_key, set()): + event.set() diff --git a/.venv/lib/python3.11/site-packages/ray/serve/_private/metrics_utils.py b/.venv/lib/python3.11/site-packages/ray/serve/_private/metrics_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..14efb553ca0981fb9a487800ab47857c2e7cc749 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/serve/_private/metrics_utils.py @@ -0,0 +1,216 @@ +import asyncio +import bisect +import logging +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Callable, DefaultDict, Dict, Hashable, List, Optional + +from ray.serve._private.constants import ( + METRICS_PUSHER_GRACEFUL_SHUTDOWN_TIMEOUT_S, + SERVE_LOGGER_NAME, +) + +logger = logging.getLogger(SERVE_LOGGER_NAME) + + +@dataclass +class _MetricsTask: + task_func: Callable + interval_s: float + + +class MetricsPusher: + """Periodically runs registered asyncio tasks.""" + + def __init__( + self, + *, + async_sleep: Optional[Callable[[int], None]] = None, + ): + self._async_sleep = async_sleep or asyncio.sleep + self._tasks: Dict[str, _MetricsTask] = dict() + self._async_tasks: Dict[str, asyncio.Task] = dict() + + # The event needs to be lazily initialized because this class may be constructed + # on the main thread but its methods called on a separate asyncio loop. 
+ self._stop_event: Optional[asyncio.Event] = None + + @property + def stop_event(self) -> asyncio.Event: + if self._stop_event is None: + self._stop_event = asyncio.Event() + + return self._stop_event + + def start(self): + self.stop_event.clear() + + async def metrics_task(self, name: str): + """Periodically runs `task_func` every `interval_s` until `stop_event` is set. + + If `task_func` raises an error, an exception will be logged. + """ + + wait_for_stop_event = asyncio.create_task(self.stop_event.wait()) + while True: + if wait_for_stop_event.done(): + return + + try: + self._tasks[name].task_func() + except Exception as e: + logger.exception(f"Failed to run metrics task '{name}': {e}") + + sleep_task = asyncio.create_task( + self._async_sleep(self._tasks[name].interval_s) + ) + await asyncio.wait( + [sleep_task, wait_for_stop_event], + return_when=asyncio.FIRST_COMPLETED, + ) + + if not sleep_task.done(): + sleep_task.cancel() + + def register_or_update_task( + self, + name: str, + task_func: Callable, + interval_s: int, + ) -> None: + """Register a task under the provided name, or update it. + + This method is idempotent - if a task is already registered with + the specified name, it will update it with the most recent info. + """ + + self._tasks[name] = _MetricsTask(task_func, interval_s) + if name not in self._async_tasks or self._async_tasks[name].done(): + self._async_tasks[name] = asyncio.create_task(self.metrics_task(name)) + + def stop_tasks(self): + self.stop_event.set() + self._tasks.clear() + self._async_tasks.clear() + + async def graceful_shutdown(self): + """Shutdown metrics pusher gracefully. + + This method will ensure idempotency of shutdown call. 
+ """ + + self.stop_event.set() + if self._async_tasks: + await asyncio.wait( + list(self._async_tasks.values()), + timeout=METRICS_PUSHER_GRACEFUL_SHUTDOWN_TIMEOUT_S, + ) + + self._tasks.clear() + self._async_tasks.clear() + + +@dataclass(order=True) +class TimeStampedValue: + timestamp: float + value: float = field(compare=False) + + +class InMemoryMetricsStore: + """A very simple, in memory time series database""" + + def __init__(self): + self.data: DefaultDict[Hashable, List[TimeStampedValue]] = defaultdict(list) + + def add_metrics_point(self, data_points: Dict[Hashable, float], timestamp: float): + """Push new data points to the store. + + Args: + data_points: dictionary containing the metrics values. The + key should uniquely identify this time series + and to be used to perform aggregation. + timestamp: the unix epoch timestamp the metrics are + collected at. + """ + for name, value in data_points.items(): + # Using in-sort to insert while maintaining sorted ordering. + bisect.insort(a=self.data[name], x=TimeStampedValue(timestamp, value)) + + def prune_keys_and_compact_data(self, start_timestamp_s: float): + """Prune keys and compact data that are outdated. + + For keys that haven't had new data recorded after the timestamp, + remove them from the database. + For keys that have, compact the datapoints that were recorded + before the timestamp. 
+ """ + for key, datapoints in list(self.data.items()): + if len(datapoints) == 0 or datapoints[-1].timestamp < start_timestamp_s: + del self.data[key] + else: + self.data[key] = self._get_datapoints(key, start_timestamp_s) + + def _get_datapoints( + self, key: Hashable, window_start_timestamp_s: float + ) -> List[float]: + """Get all data points given key after window_start_timestamp_s""" + + datapoints = self.data[key] + + idx = bisect.bisect( + a=datapoints, + x=TimeStampedValue( + timestamp=window_start_timestamp_s, value=0 # dummy value + ), + ) + return datapoints[idx:] + + def window_average( + self, key: Hashable, window_start_timestamp_s: float, do_compact: bool = True + ) -> Optional[float]: + """Perform a window average operation for metric `key` + + Args: + key: the metric name. + window_start_timestamp_s: the unix epoch timestamp for the + start of the window. The computed average will use all datapoints + from this timestamp until now. + do_compact: whether or not to delete the datapoints that's + before `window_start_timestamp_s` to save memory. Default is + true. + Returns: + The average of all the datapoints for the key on and after time + window_start_timestamp_s, or None if there are no such points. + """ + points_after_idx = self._get_datapoints(key, window_start_timestamp_s) + + if do_compact: + self.data[key] = points_after_idx + + if len(points_after_idx) == 0: + return + return sum(point.value for point in points_after_idx) / len(points_after_idx) + + def max( + self, key: Hashable, window_start_timestamp_s: float, do_compact: bool = True + ): + """Perform a max operation for metric `key`. + + Args: + key: the metric name. + window_start_timestamp_s: the unix epoch timestamp for the + start of the window. The computed average will use all datapoints + from this timestamp until now. + do_compact: whether or not to delete the datapoints that's + before `window_start_timestamp_s` to save memory. Default is + true. 
+ Returns: + Max value of the data points for the key on and after time + window_start_timestamp_s, or None if there are no such points. + """ + points_after_idx = self._get_datapoints(key, window_start_timestamp_s) + + if do_compact: + self.data[key] = points_after_idx + + return max((point.value for point in points_after_idx), default=None) diff --git a/.venv/lib/python3.11/site-packages/ray/serve/_private/proxy.py b/.venv/lib/python3.11/site-packages/ray/serve/_private/proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..2e445a36f16f31cf7c3118ac719e7f657125aaf2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/serve/_private/proxy.py @@ -0,0 +1,1538 @@ +import asyncio +import gc +import json +import logging +import os +import pickle +import socket +import time +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple + +import grpc +import starlette +import starlette.routing +import uvicorn +from packaging import version +from starlette.datastructures import MutableHeaders +from starlette.middleware import Middleware +from starlette.types import Receive + +import ray +from ray._private.utils import get_or_create_event_loop +from ray.exceptions import RayActorError, RayTaskError +from ray.serve._private.common import ( + DeploymentID, + EndpointInfo, + NodeId, + ReplicaID, + RequestMetadata, + RequestProtocol, +) +from ray.serve._private.constants import ( + DEFAULT_LATENCY_BUCKET_MS, + DEFAULT_UVICORN_KEEP_ALIVE_TIMEOUT_S, + PROXY_MIN_DRAINING_PERIOD_S, + RAY_SERVE_ENABLE_PROXY_GC_OPTIMIZATIONS, + RAY_SERVE_HTTP_PROXY_CALLBACK_IMPORT_PATH, + RAY_SERVE_PROXY_GC_THRESHOLD, + SERVE_CONTROLLER_NAME, + SERVE_LOGGER_NAME, + SERVE_MULTIPLEXED_MODEL_ID, + SERVE_NAMESPACE, +) +from ray.serve._private.default_impl import add_grpc_address, get_proxy_handle +from ray.serve._private.grpc_util import DummyServicer, gRPCGenericServer +from ray.serve._private.http_util import ( + 
MessageQueue, + convert_object_to_asgi_messages, + receive_http_body, + set_socket_reuse_port, + validate_http_proxy_callback_return, +) +from ray.serve._private.logging_utils import ( + access_log_msg, + configure_component_cpu_profiler, + configure_component_logger, + configure_component_memory_profiler, + get_component_logger_file_path, +) +from ray.serve._private.long_poll import LongPollClient, LongPollNamespace +from ray.serve._private.proxy_request_response import ( + ASGIProxyRequest, + HandlerMetadata, + ProxyRequest, + ResponseGenerator, + ResponseHandlerInfo, + ResponseStatus, + gRPCProxyRequest, +) +from ray.serve._private.proxy_response_generator import ProxyResponseGenerator +from ray.serve._private.proxy_router import ProxyRouter +from ray.serve._private.usage import ServeUsageTag +from ray.serve._private.utils import ( + call_function_from_import_path, + generate_request_id, + get_head_node_id, +) +from ray.serve.config import gRPCOptions +from ray.serve.exceptions import BackPressureError, DeploymentUnavailableError +from ray.serve.generated.serve_pb2 import HealthzResponse, ListApplicationsResponse +from ray.serve.generated.serve_pb2_grpc import add_RayServeAPIServiceServicer_to_server +from ray.serve.handle import DeploymentHandle +from ray.serve.schema import LoggingConfig +from ray.util import metrics + +logger = logging.getLogger(SERVE_LOGGER_NAME) + +HTTP_REQUEST_MAX_RETRIES = int(os.environ.get("RAY_SERVE_HTTP_REQUEST_MAX_RETRIES", 10)) +assert HTTP_REQUEST_MAX_RETRIES >= 0, ( + f"Got unexpected value {HTTP_REQUEST_MAX_RETRIES} for " + "RAY_SERVE_HTTP_REQUEST_MAX_RETRIES environment variable. " + "RAY_SERVE_HTTP_REQUEST_MAX_RETRIES cannot be negative." 
+) + +TIMEOUT_ERROR_CODE = "timeout" +DISCONNECT_ERROR_CODE = "disconnection" +SOCKET_REUSE_PORT_ENABLED = ( + os.environ.get("SERVE_SOCKET_REUSE_PORT_ENABLED", "1") == "1" +) + +RAY_SERVE_HTTP_KEEP_ALIVE_TIMEOUT_S = int( + os.environ.get("RAY_SERVE_HTTP_KEEP_ALIVE_TIMEOUT_S", 0) +) +# TODO (shrekris-anyscale): Deprecate SERVE_REQUEST_PROCESSING_TIMEOUT_S env var +RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S = ( + float(os.environ.get("RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S", 0)) + or float(os.environ.get("SERVE_REQUEST_PROCESSING_TIMEOUT_S", 0)) + or None +) +# Controls whether Ray Serve is operating in debug-mode switching off some +# of the performance optimizations to make troubleshooting easier +RAY_SERVE_DEBUG_MODE = bool(os.environ.get("RAY_SERVE_DEBUG_MODE", 0)) + +if os.environ.get("SERVE_REQUEST_PROCESSING_TIMEOUT_S") is not None: + logger.warning( + "The `SERVE_REQUEST_PROCESSING_TIMEOUT_S` environment variable has " + "been deprecated. Please set `request_timeout_s` in your Serve config's " + "`http_options` field instead. `SERVE_REQUEST_PROCESSING_TIMEOUT_S` will be " + "ignored in future versions. See: https://docs.ray.io/en/releases-2.5.1/serve/a" + "pi/doc/ray.serve.schema.HTTPOptionsSchema.html#ray.serve.schema.HTTPOptionsSch" + "ema.request_timeout_s" + ) + + +INITIAL_BACKOFF_PERIOD_SEC = 0.05 +MAX_BACKOFF_PERIOD_SEC = 5 + +HEALTHY_MESSAGE = "success" +DRAINING_MESSAGE = "This node is being drained." + + +class GenericProxy(ABC): + """This class is served as the base class for different types of proxies. + It contains all the common setup and methods required for running a proxy. 
+ + The proxy subclass need to implement the following methods: + - `protocol()` + - `not_found_response()` + - `routes_response()` + - `health_response()` + - `setup_request_context_and_handle()` + - `send_request_to_replica()` + """ + + def __init__( + self, + node_id: NodeId, + node_ip_address: str, + is_head: bool, + proxy_router: ProxyRouter, + request_timeout_s: Optional[float] = None, + ): + self.request_timeout_s = request_timeout_s + if self.request_timeout_s is not None and self.request_timeout_s < 0: + self.request_timeout_s = None + + self._node_id = node_id + self._is_head = is_head + + self.proxy_router = proxy_router + self.request_counter = metrics.Counter( + f"serve_num_{self.protocol.lower()}_requests", + description=f"The number of {self.protocol} requests processed.", + tag_keys=("route", "method", "application", "status_code"), + ) + + self.request_error_counter = metrics.Counter( + f"serve_num_{self.protocol.lower()}_error_requests", + description=f"The number of errored {self.protocol} responses.", + tag_keys=( + "route", + "error_code", + "method", + "application", + ), + ) + + self.deployment_request_error_counter = metrics.Counter( + f"serve_num_deployment_{self.protocol.lower()}_error_requests", + description=( + f"The number of errored {self.protocol} " + "responses returned by each deployment." + ), + tag_keys=( + "deployment", + "error_code", + "method", + "route", + "application", + ), + ) + + self.processing_latency_tracker = metrics.Histogram( + f"serve_{self.protocol.lower()}_request_latency_ms", + description=( + f"The end-to-end latency of {self.protocol} requests " + f"(measured from the Serve {self.protocol} proxy)." 
+ ), + boundaries=DEFAULT_LATENCY_BUCKET_MS, + tag_keys=( + "method", + "route", + "application", + "status_code", + ), + ) + + self.num_ongoing_requests_gauge = metrics.Gauge( + name=f"serve_num_ongoing_{self.protocol.lower()}_requests", + description=f"The number of ongoing requests in this {self.protocol} " + "proxy.", + tag_keys=("node_id", "node_ip_address"), + ).set_default_tags( + { + "node_id": node_id, + "node_ip_address": node_ip_address, + } + ) + + # `self._ongoing_requests` is used to count the number of ongoing requests + self._ongoing_requests = 0 + # The time when the node starts to drain. + # The node is not draining if it's None. + self._draining_start_time: Optional[float] = None + + getattr(ServeUsageTag, f"{self.protocol.upper()}_PROXY_USED").record("1") + + @property + @abstractmethod + def protocol(self) -> RequestProtocol: + """Protocol used in the proxy. + + Each proxy needs to implement its own logic for setting up the protocol. + """ + raise NotImplementedError + + def _is_draining(self) -> bool: + """Whether is proxy actor is in the draining status or not.""" + return self._draining_start_time is not None + + def is_drained(self): + """Check whether the proxy actor is drained or not. + + A proxy actor is drained if it has no ongoing requests + AND it has been draining for more than + `PROXY_MIN_DRAINING_PERIOD_S` seconds. + """ + if not self._is_draining(): + return False + + return (not self._ongoing_requests) and ( + (time.time() - self._draining_start_time) > PROXY_MIN_DRAINING_PERIOD_S + ) + + def update_draining(self, draining: bool): + """Update the draining status of the proxy. + + This is called by the proxy state manager + to drain or un-drain the proxy actor. 
+ """ + + if draining and (not self._is_draining()): + logger.info( + f"Start to drain the proxy actor on node {self._node_id}.", + extra={"log_to_stderr": False}, + ) + self._draining_start_time = time.time() + if (not draining) and self._is_draining(): + logger.info( + f"Stop draining the proxy actor on node {self._node_id}.", + extra={"log_to_stderr": False}, + ) + self._draining_start_time = None + + @abstractmethod + async def not_found_response( + self, proxy_request: ProxyRequest + ) -> ResponseGenerator: + raise NotImplementedError + + @abstractmethod + async def routes_response( + self, *, healthy: bool, message: str + ) -> ResponseGenerator: + raise NotImplementedError + + @abstractmethod + async def health_response( + self, *, healthy: bool, message: str + ) -> ResponseGenerator: + raise NotImplementedError + + def _ongoing_requests_start(self): + """Ongoing requests start. + + The current autoscale logic can downscale nodes with ongoing requests if the + node doesn't have replicas and has no primary copies of objects in the object + store. The counter and the dummy object reference will help to keep the node + alive while draining requests, so they are not dropped unintentionally. + """ + self._ongoing_requests += 1 + self.num_ongoing_requests_gauge.set(self._ongoing_requests) + + def _ongoing_requests_end(self): + """Ongoing requests end. + + Decrement the ongoing request counter and drop the dummy object reference + signaling that the node can be downscaled safely. + """ + self._ongoing_requests -= 1 + self.num_ongoing_requests_gauge.set(self._ongoing_requests) + + def _get_health_or_routes_reponse( + self, proxy_request: ProxyRequest + ) -> ResponseHandlerInfo: + """Get the response handler for system health and route endpoints. + + If the proxy is draining or has not yet received a route table update from the + controller, both will return a non-OK status. 
+ """ + router_ready_for_traffic, router_msg = self.proxy_router.ready_for_traffic( + self._is_head + ) + if self._is_draining(): + healthy = False + message = DRAINING_MESSAGE + elif not router_ready_for_traffic: + healthy = False + message = router_msg + else: + healthy = True + message = HEALTHY_MESSAGE + + if proxy_request.is_health_request: + response_generator = self.health_response(healthy=healthy, message=message) + else: + assert proxy_request.is_route_request + response_generator = self.routes_response(healthy=healthy, message=message) + + return ResponseHandlerInfo( + response_generator=response_generator, + metadata=HandlerMetadata( + route=proxy_request.route_path, + ), + should_record_access_log=False, + should_increment_ongoing_requests=False, + ) + + def _get_response_handler_info( + self, proxy_request: ProxyRequest + ) -> ResponseHandlerInfo: + if proxy_request.is_health_request or proxy_request.is_route_request: + return self._get_health_or_routes_reponse(proxy_request) + + matched_route = None + if self.protocol == RequestProtocol.HTTP: + matched_route = self.proxy_router.match_route(proxy_request.route_path) + elif self.protocol == RequestProtocol.GRPC: + matched_route = self.proxy_router.get_handle_for_endpoint( + proxy_request.route_path + ) + + if matched_route is None: + return ResponseHandlerInfo( + response_generator=self.not_found_response(proxy_request), + metadata=HandlerMetadata( + # Don't include the invalid route prefix because it can blow up our + # metrics' cardinality. + # See: https://github.com/ray-project/ray/issues/47999 + route="", + ), + should_record_access_log=True, + should_increment_ongoing_requests=False, + ) + else: + route_prefix, handle, app_is_cross_language = matched_route + # Modify the path and root path so that reverse lookups and redirection + # work as expected. We do this here instead of in replicas so it can be + # changed without restarting the replicas. 
+ route_path = proxy_request.route_path + if route_prefix != "/" and self.protocol == RequestProtocol.HTTP: + assert not route_prefix.endswith("/") + proxy_request.set_root_path(proxy_request.root_path + route_prefix) + # NOTE(edoakes): starlette<0.33.0 expected the ASGI 'root_prefix' + # to be stripped from the 'path', which wasn't technically following + # the standard. See https://github.com/encode/starlette/pull/2352. + if version.parse(starlette.__version__) < version.parse("0.33.0"): + proxy_request.set_path(route_path.replace(route_prefix, "", 1)) + + # NOTE(edoakes): we use the route_prefix instead of the full HTTP path + # for logs & metrics to avoid high cardinality. + # See: https://github.com/ray-project/ray/issues/47999 + logs_and_metrics_route = ( + route_prefix + if self.protocol == RequestProtocol.HTTP + else handle.deployment_id.app_name + ) + internal_request_id = generate_request_id() + handle, request_id = self.setup_request_context_and_handle( + app_name=handle.deployment_id.app_name, + handle=handle, + route=logs_and_metrics_route, + proxy_request=proxy_request, + internal_request_id=internal_request_id, + ) + + response_generator = self.send_request_to_replica( + request_id=request_id, + internal_request_id=internal_request_id, + handle=handle, + proxy_request=proxy_request, + app_is_cross_language=app_is_cross_language, + ) + + return ResponseHandlerInfo( + response_generator=response_generator, + metadata=HandlerMetadata( + application_name=handle.deployment_id.app_name, + deployment_name=handle.deployment_id.name, + route=logs_and_metrics_route, + ), + should_record_access_log=True, + should_increment_ongoing_requests=True, + ) + + async def proxy_request(self, proxy_request: ProxyRequest) -> ResponseGenerator: + """Wrapper for proxy request. + + This method is served as common entry point by the proxy. It handles the + routing, including routes and health checks, ongoing request counter, + and metrics. 
+ """ + assert proxy_request.request_type in {"http", "websocket", "grpc"} + + response_handler_info = self._get_response_handler_info(proxy_request) + + start_time = time.time() + if response_handler_info.should_increment_ongoing_requests: + self._ongoing_requests_start() + + try: + # The final message yielded must always be the `ResponseStatus`. + status: Optional[ResponseStatus] = None + async for message in response_handler_info.response_generator: + if isinstance(message, ResponseStatus): + status = message + + yield message + + assert status is not None and isinstance(status, ResponseStatus) + finally: + # If anything during the request failed, we still want to ensure the ongoing + # request counter is decremented. + if response_handler_info.should_increment_ongoing_requests: + self._ongoing_requests_end() + + latency_ms = (time.time() - start_time) * 1000.0 + if response_handler_info.should_record_access_log: + request_context = ray.serve.context._get_serve_request_context() + logger.info( + access_log_msg( + method=proxy_request.method, + route=request_context.route, + status=str(status.code), + latency_ms=latency_ms, + ), + extra={"log_to_stderr": False, "serve_access_log": True}, + ) + + self.request_counter.inc( + tags={ + "route": response_handler_info.metadata.route, + "method": proxy_request.method, + "application": response_handler_info.metadata.application_name, + "status_code": str(status.code), + } + ) + + self.processing_latency_tracker.observe( + latency_ms, + tags={ + "route": response_handler_info.metadata.route, + "method": proxy_request.method, + "application": response_handler_info.metadata.application_name, + "status_code": str(status.code), + }, + ) + if status.is_error: + self.request_error_counter.inc( + tags={ + "route": response_handler_info.metadata.route, + "method": proxy_request.method, + "application": response_handler_info.metadata.application_name, + "error_code": str(status.code), + } + ) + 
self.deployment_request_error_counter.inc( + tags={ + "route": response_handler_info.metadata.route, + "method": proxy_request.method, + "application": response_handler_info.metadata.application_name, + "error_code": str(status.code), + "deployment": response_handler_info.metadata.deployment_name, + } + ) + + @abstractmethod + def setup_request_context_and_handle( + self, + app_name: str, + handle: DeploymentHandle, + route: str, + proxy_request: ProxyRequest, + internal_request_id: str, + ) -> Tuple[DeploymentHandle, str]: + """Setup the request context and handle for the request. + + Each proxy needs to implement its own logic for setting up the request context + and handle. + """ + raise NotImplementedError + + @abstractmethod + async def send_request_to_replica( + self, + request_id: str, + internal_request_id: str, + handle: DeploymentHandle, + proxy_request: ProxyRequest, + app_is_cross_language: bool = False, + ) -> ResponseGenerator: + """Send the request to the replica and handle streaming response. + + Each proxy needs to implement its own logic for sending the request and + handling the streaming response. + """ + raise NotImplementedError + + +class gRPCProxy(GenericProxy): + """This class is meant to be instantiated and run by an gRPC server. + + This is the servicer class for the gRPC server. It implements `unary_unary` + as the entry point for unary gRPC request and `unary_stream` as the entry + point for streaming gRPC request. + """ + + @property + def protocol(self) -> RequestProtocol: + return RequestProtocol.GRPC + + async def not_found_response( + self, proxy_request: ProxyRequest + ) -> ResponseGenerator: + if not proxy_request.app_name: + application_message = "Application metadata not set." + else: + application_message = f"Application '{proxy_request.app_name}' not found." + not_found_message = ( + f"{application_message} Ping " + "/ray.serve.RayServeAPIService/ListApplications for available applications." 
+ ) + + yield ResponseStatus( + code=grpc.StatusCode.NOT_FOUND, + message=not_found_message, + is_error=True, + ) + + async def routes_response( + self, *, healthy: bool, message: str + ) -> ResponseGenerator: + yield ListApplicationsResponse( + application_names=[ + endpoint.app_name for endpoint in self.proxy_router.endpoints + ], + ).SerializeToString() + + yield ResponseStatus( + code=grpc.StatusCode.OK if healthy else grpc.StatusCode.UNAVAILABLE, + message=message, + is_error=not healthy, + ) + + async def health_response(self, *, healthy: bool, message) -> ResponseGenerator: + yield HealthzResponse(message=message).SerializeToString() + yield ResponseStatus( + code=grpc.StatusCode.OK if healthy else grpc.StatusCode.UNAVAILABLE, + message=message, + is_error=not healthy, + ) + + def service_handler_factory(self, service_method: str, stream: bool) -> Callable: + def set_grpc_code_and_details( + context: grpc._cython.cygrpc._ServicerContext, status: ResponseStatus + ): + # Only the latest code and details will take effect. If the user already + # set them to a truthy value in the context, skip setting them with Serve's + # default values. By default, if nothing is set, the code is 0 and the + # details is "", which both are falsy. So if the user did not set them or + # if they're explicitly set to falsy values, such as None, Serve will + # continue to set them with our default values. + if not context.code(): + context.set_code(status.code) + if not context.details(): + context.set_details(status.message) + + async def unary_unary( + request_proto: Any, context: grpc._cython.cygrpc._ServicerContext + ) -> bytes: + """Entry point of the gRPC proxy unary request. + + This method is called by the gRPC server when a unary request is received. + It wraps the request in a ProxyRequest object and calls proxy_request. + The return value is serialized user defined protobuf bytes. 
+ """ + proxy_request = gRPCProxyRequest( + request_proto=request_proto, + context=context, + service_method=service_method, + stream=False, + ) + + status = None + response = None + async for message in self.proxy_request(proxy_request=proxy_request): + if isinstance(message, ResponseStatus): + status = message + else: + response = message + + set_grpc_code_and_details(context, status) + + return response + + async def unary_stream( + request_proto: Any, context: grpc._cython.cygrpc._ServicerContext + ) -> Generator[bytes, None, None]: + """Entry point of the gRPC proxy streaming request. + + This method is called by the gRPC server when a streaming request is + received. It wraps the request in a ProxyRequest object and calls + proxy_request. The return value is a generator of serialized user defined + protobuf bytes. + """ + proxy_request = gRPCProxyRequest( + request_proto=request_proto, + context=context, + service_method=service_method, + stream=True, + ) + + status = None + async for message in self.proxy_request(proxy_request=proxy_request): + if isinstance(message, ResponseStatus): + status = message + else: + yield message + + set_grpc_code_and_details(context, status) + + return unary_stream if stream else unary_unary + + def setup_request_context_and_handle( + self, + app_name: str, + handle: DeploymentHandle, + route: str, + proxy_request: ProxyRequest, + internal_request_id: str, + ) -> Tuple[DeploymentHandle, str]: + """Setup request context and handle for the request. + + Unpack gRPC request metadata and extract info to set up request context and + handle. 
+ """ + multiplexed_model_id = proxy_request.multiplexed_model_id + request_id = proxy_request.request_id + if not request_id: + request_id = generate_request_id() + proxy_request.request_id = request_id + + handle = handle.options( + stream=proxy_request.stream, + multiplexed_model_id=multiplexed_model_id, + method_name=proxy_request.method_name, + ) + + request_context_info = { + "route": route, + "request_id": request_id, + "_internal_request_id": internal_request_id, + "app_name": app_name, + "multiplexed_model_id": multiplexed_model_id, + "grpc_context": proxy_request.ray_serve_grpc_context, + } + ray.serve.context._serve_request_context.set( + ray.serve.context._RequestContext(**request_context_info) + ) + proxy_request.send_request_id(request_id=request_id) + return handle, request_id + + async def send_request_to_replica( + self, + request_id: str, + internal_request_id: str, + handle: DeploymentHandle, + proxy_request: ProxyRequest, + app_is_cross_language: bool = False, + ) -> ResponseGenerator: + response_generator = ProxyResponseGenerator( + handle.remote(proxy_request.serialized_replica_arg()), + timeout_s=self.request_timeout_s, + ) + + try: + async for context, result in response_generator: + context._set_on_grpc_context(proxy_request.context) + yield result + + status = ResponseStatus(code=grpc.StatusCode.OK) + except TimeoutError: + message = f"Request timed out after {self.request_timeout_s}s." + logger.warning(message) + status = ResponseStatus( + code=grpc.StatusCode.DEADLINE_EXCEEDED, + is_error=True, + message=message, + ) + except asyncio.CancelledError: + message = f"Client for request {request_id} disconnected." 
+ logger.info(message) + status = ResponseStatus( + code=grpc.StatusCode.CANCELLED, + is_error=True, + message=message, + ) + except BackPressureError as e: + status = ResponseStatus( + code=grpc.StatusCode.UNAVAILABLE, + is_error=True, + message=e.message, + ) + except Exception as e: + if isinstance(e, (RayActorError, RayTaskError)): + logger.warning(f"Request failed: {e}", extra={"log_to_stderr": False}) + else: + logger.exception("Request failed due to unexpected error.") + status = ResponseStatus( + code=grpc.StatusCode.INTERNAL, + is_error=True, + message=str(e), + ) + + # The status code should always be set. + assert status is not None + yield status + + +class HTTPProxy(GenericProxy): + """This class is meant to be instantiated and run by an ASGI HTTP server.""" + + def __init__( + self, + node_id: NodeId, + node_ip_address: str, + is_head: bool, + proxy_router: ProxyRouter, + self_actor_name: str, + request_timeout_s: Optional[float] = None, + ): + super().__init__( + node_id, + node_ip_address, + is_head, + proxy_router, + request_timeout_s=request_timeout_s, + ) + self.self_actor_name = self_actor_name + self.asgi_receive_queues: Dict[str, MessageQueue] = dict() + + @property + def protocol(self) -> RequestProtocol: + return RequestProtocol.HTTP + + async def not_found_response( + self, proxy_request: ProxyRequest + ) -> ResponseGenerator: + status_code = 404 + for message in convert_object_to_asgi_messages( + f"Path '{proxy_request.path}' not found. 
" + "Ping http://.../-/routes for available routes.", + status_code=status_code, + ): + yield message + + yield ResponseStatus(code=status_code, is_error=True) + + async def routes_response( + self, *, healthy: bool, message: str + ) -> ResponseGenerator: + status_code = 200 if healthy else 503 + if healthy: + response = dict() + for endpoint, info in self.proxy_router.endpoints.items(): + # For 2.x deployments, return {route -> app name} + if endpoint.app_name: + response[info.route] = endpoint.app_name + # Keep compatibility with 1.x deployments. + else: + response[info.route] = endpoint.name + else: + response = message + + for asgi_message in convert_object_to_asgi_messages( + response, + status_code=status_code, + ): + yield asgi_message + + yield ResponseStatus( + code=status_code, + message=message, + is_error=not healthy, + ) + + async def health_response( + self, *, healthy: bool, message: str = "" + ) -> ResponseGenerator: + status_code = 200 if healthy else 503 + for asgi_message in convert_object_to_asgi_messages( + message, + status_code=status_code, + ): + yield asgi_message + + yield ResponseStatus( + code=status_code, + is_error=not healthy, + message=message, + ) + + async def receive_asgi_messages( + self, request_metadata: RequestMetadata + ) -> ResponseGenerator: + queue = self.asgi_receive_queues.get(request_metadata.internal_request_id, None) + if queue is None: + raise KeyError(f"Request ID {request_metadata.request_id} not found.") + + await queue.wait_for_message() + return queue.get_messages_nowait() + + async def __call__(self, scope, receive, send): + """Implements the ASGI protocol. + + See details at: + https://asgi.readthedocs.io/en/latest/specs/index.html. 
+ """ + proxy_request = ASGIProxyRequest(scope=scope, receive=receive, send=send) + async for message in self.proxy_request(proxy_request): + if not isinstance(message, ResponseStatus): + await send(message) + + async def proxy_asgi_receive( + self, receive: Receive, queue: MessageQueue + ) -> Optional[int]: + """Proxies the `receive` interface, placing its messages into the queue. + + Once a disconnect message is received, the call exits and `receive` is no longer + called. + + For HTTP messages, `None` is always returned. + For websocket messages, the disconnect code is returned if a disconnect code is + received. + """ + try: + while True: + msg = await receive() + await queue(msg) + + if msg["type"] == "http.disconnect": + return None + + if msg["type"] == "websocket.disconnect": + return msg["code"] + finally: + # Close the queue so any subsequent calls to fetch messages return + # immediately: https://github.com/ray-project/ray/issues/38368. + queue.close() + + def setup_request_context_and_handle( + self, + app_name: str, + handle: DeploymentHandle, + route: str, + proxy_request: ProxyRequest, + internal_request_id: str, + ) -> Tuple[DeploymentHandle, str]: + """Setup request context and handle for the request. + + Unpack HTTP request headers and extract info to set up request context and + handle. 
+ """ + request_context_info = { + "route": route, + "app_name": app_name, + "_internal_request_id": internal_request_id, + "is_http_request": True, + } + for key, value in proxy_request.headers: + if key.decode() == SERVE_MULTIPLEXED_MODEL_ID: + multiplexed_model_id = value.decode() + handle = handle.options(multiplexed_model_id=multiplexed_model_id) + request_context_info["multiplexed_model_id"] = multiplexed_model_id + if key.decode() == "x-request-id": + request_context_info["request_id"] = value.decode() + ray.serve.context._serve_request_context.set( + ray.serve.context._RequestContext(**request_context_info) + ) + return handle, request_context_info["request_id"] + + async def _format_handle_arg_for_java( + self, + proxy_request: ProxyRequest, + ) -> bytes: + """Convert an HTTP request to the Java-accepted format (single byte string).""" + query_string = proxy_request.scope.get("query_string") + http_body_bytes = await receive_http_body( + proxy_request.scope, proxy_request.receive, proxy_request.send + ) + if query_string: + arg = query_string.decode().split("=", 1)[1] + else: + arg = http_body_bytes.decode() + + return arg + + async def send_request_to_replica( + self, + request_id: str, + internal_request_id: str, + handle: DeploymentHandle, + proxy_request: ProxyRequest, + app_is_cross_language: bool = False, + ) -> ResponseGenerator: + """Send the request to the replica and yield its response messages. + + The yielded values will be ASGI messages until the final one, which will be + the status code. + """ + if app_is_cross_language: + handle_arg_bytes = await self._format_handle_arg_for_java(proxy_request) + # Response is returned as raw bytes, convert it to ASGI messages. + result_callback = convert_object_to_asgi_messages + else: + handle_arg_bytes = proxy_request.serialized_replica_arg( + proxy_actor_name=self.self_actor_name, + ) + # Messages are returned as pickled dictionaries. 
+ result_callback = pickle.loads + + # Proxy the receive interface by placing the received messages on a queue. + # The downstream replica must call back into `receive_asgi_messages` on this + # actor to receive the messages. + receive_queue = MessageQueue() + self.asgi_receive_queues[internal_request_id] = receive_queue + proxy_asgi_receive_task = get_or_create_event_loop().create_task( + self.proxy_asgi_receive(proxy_request.receive, receive_queue) + ) + + response_generator = ProxyResponseGenerator( + handle.remote(handle_arg_bytes), + timeout_s=self.request_timeout_s, + disconnected_task=proxy_asgi_receive_task, + result_callback=result_callback, + ) + + status: Optional[ResponseStatus] = None + response_started = False + expecting_trailers = False + try: + async for asgi_message_batch in response_generator: + # See the ASGI spec for message details: + # https://asgi.readthedocs.io/en/latest/specs/www.html. + for asgi_message in asgi_message_batch: + if asgi_message["type"] == "http.response.start": + # HTTP responses begin with exactly one + # "http.response.start" message containing the "status" + # field. Other response types (e.g., WebSockets) may not. + status_code = str(asgi_message["status"]) + status = ResponseStatus( + code=status_code, + is_error=not status_code.startswith("2"), + ) + expecting_trailers = asgi_message.get("trailers", False) + elif asgi_message["type"] == "websocket.accept": + # Websocket code explicitly handles client disconnects, + # so let the ASGI disconnect message propagate instead of + # cancelling the handler. + response_generator.stop_checking_for_disconnect() + elif ( + asgi_message["type"] == "http.response.body" + and not asgi_message.get("more_body", False) + and not expecting_trailers + ): + # If the body is completed and we aren't expecting trailers, the + # response is done so we should stop listening for disconnects. 
+ response_generator.stop_checking_for_disconnect() + elif asgi_message["type"] == "http.response.trailers": + # If we are expecting trailers, the response is only done when + # the trailers message has been sent. + if not asgi_message.get("more_trailers", False): + response_generator.stop_checking_for_disconnect() + elif asgi_message["type"] in [ + "websocket.close", + "websocket.disconnect", + ]: + status_code = str(asgi_message["code"]) + status = ResponseStatus( + code=status_code, + # All status codes are considered errors aside from: + # 1000 (CLOSE_NORMAL), 1001 (CLOSE_GOING_AWAY). + is_error=status_code not in ["1000", "1001"], + ) + response_generator.stop_checking_for_disconnect() + + yield asgi_message + response_started = True + except TimeoutError: + status = ResponseStatus( + code=TIMEOUT_ERROR_CODE, + is_error=True, + ) + logger.warning(f"Request timed out after {self.request_timeout_s}s.") + # We should only send timeout response if we have not sent + # any messages to the client yet. Header (including status code) + # messages can only be sent once. + if not response_started: + for message in convert_object_to_asgi_messages( + f"Request {request_id} timed out after {self.request_timeout_s}s.", + status_code=408, + ): + yield message + except asyncio.CancelledError: + status = ResponseStatus( + code=DISCONNECT_ERROR_CODE, + is_error=True, + ) + logger.info( + f"Client for request {request_id} disconnected, cancelling request." 
+ ) + except (BackPressureError, DeploymentUnavailableError) as e: + status_code = 503 + status = ResponseStatus( + code=status_code, + is_error=True, + message=e.message, + ) + for message in convert_object_to_asgi_messages( + e.message, + status_code=status_code, + ): + yield message + except Exception as e: + if isinstance(e, (RayActorError, RayTaskError)): + logger.warning(f"Request failed: {e}", extra={"log_to_stderr": False}) + else: + logger.exception("Request failed due to unexpected error.") + status = ResponseStatus( + code="500", + is_error=True, + ) + + finally: + # For websocket connection, queue receive task is done when receiving + # disconnect message from client. + receive_client_disconnect_msg = False + if not proxy_asgi_receive_task.done(): + proxy_asgi_receive_task.cancel() + else: + receive_client_disconnect_msg = True + + # If the server disconnects, status_code can be set above from the + # disconnect message. + # If client disconnects, the disconnect code comes from + # a client message via the receive interface. + if status is None and proxy_request.request_type == "websocket": + if receive_client_disconnect_msg: + # The disconnect message is sent from the client. + status = ResponseStatus( + code=str(proxy_asgi_receive_task.result()), + is_error=True, + ) + else: + # The server disconnect without sending a disconnect message + # (otherwise the `status` would be set). + status = ResponseStatus( + code="1000", # [Sihan] is there a better code for this? + is_error=True, + ) + + del self.asgi_receive_queues[internal_request_id] + + # The status code should always be set. + assert status is not None + yield status + + +class RequestIdMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + headers = MutableHeaders(scope=scope) + if "x-request-id" not in headers: + # If X-Request-ID is not set, we + # generate a new request ID. 
+ request_id = generate_request_id() + headers.append("x-request-id", request_id) + elif "x-request-id" in headers: + request_id = headers["x-request-id"] + + async def send_with_request_id(message: Dict): + if message["type"] == "http.response.start": + headers = MutableHeaders(scope=message) + headers.append("X-Request-ID", request_id) + if message["type"] == "websocket.accept": + message["X-Request-ID"] = request_id + await send(message) + + await self.app(scope, receive, send_with_request_id) + + +@ray.remote(num_cpus=0) +class ProxyActor: + def __init__( + self, + host: str, + port: int, + root_path: str, + node_ip_address: str, + node_id: NodeId, + logging_config: LoggingConfig, + request_timeout_s: Optional[float] = None, + http_middlewares: Optional[List["starlette.middleware.Middleware"]] = None, + keep_alive_timeout_s: int = DEFAULT_UVICORN_KEEP_ALIVE_TIMEOUT_S, + grpc_options: Optional[gRPCOptions] = None, + long_poll_client: Optional[LongPollClient] = None, + ): # noqa: F821 + self.grpc_options = grpc_options or gRPCOptions() + self.host = host + self.port = port + self.grpc_port = self.grpc_options.port + self.root_path = root_path + self.keep_alive_timeout_s = ( + RAY_SERVE_HTTP_KEEP_ALIVE_TIMEOUT_S or keep_alive_timeout_s + ) + self._uvicorn_server = None + self.node_ip_address = node_ip_address + + self.http_setup_complete = asyncio.Event() + self.grpc_setup_complete = asyncio.Event() + + self.long_poll_client = long_poll_client or LongPollClient( + ray.get_actor(SERVE_CONTROLLER_NAME, namespace=SERVE_NAMESPACE), + { + LongPollNamespace.GLOBAL_LOGGING_CONFIG: self._update_logging_config, + LongPollNamespace.ROUTE_TABLE: self._update_routes_in_proxies, + }, + call_in_event_loop=get_or_create_event_loop(), + ) + + configure_component_logger( + component_name="proxy", + component_id=node_ip_address, + logging_config=logging_config, + ) + + startup_msg = f"Proxy starting on node {node_id} (HTTP port: {port}" + if self.should_start_grpc_service(): + 
startup_msg += f", gRPC port: {self.grpc_options.port})." + else: + startup_msg += ")." + logger.info(startup_msg) + logger.debug( + f"Configure Proxy actor {ray.get_runtime_context().get_actor_id()} " + f"logger with logging config: {logging_config}" + ) + + configure_component_memory_profiler( + component_name="proxy", component_id=node_ip_address + ) + self.cpu_profiler, self.cpu_profiler_log = configure_component_cpu_profiler( + component_name="proxy", component_id=node_ip_address + ) + + if http_middlewares is None: + http_middlewares = [Middleware(RequestIdMiddleware)] + else: + http_middlewares.append(Middleware(RequestIdMiddleware)) + + if RAY_SERVE_HTTP_PROXY_CALLBACK_IMPORT_PATH: + logger.info( + "Calling user-provided callback from import path " + f"'{RAY_SERVE_HTTP_PROXY_CALLBACK_IMPORT_PATH}'." + ) + middlewares = validate_http_proxy_callback_return( + call_function_from_import_path( + RAY_SERVE_HTTP_PROXY_CALLBACK_IMPORT_PATH + ) + ) + + http_middlewares.extend(middlewares) + + is_head = node_id == get_head_node_id() + self.proxy_router = ProxyRouter(get_proxy_handle) + self.http_proxy = HTTPProxy( + node_id=node_id, + node_ip_address=node_ip_address, + is_head=is_head, + self_actor_name=ray.get_runtime_context().get_actor_name(), + proxy_router=self.proxy_router, + request_timeout_s=( + request_timeout_s or RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S + ), + ) + self.grpc_proxy = ( + gRPCProxy( + node_id=node_id, + node_ip_address=node_ip_address, + is_head=is_head, + proxy_router=self.proxy_router, + request_timeout_s=( + request_timeout_s or RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S + ), + ) + if self.should_start_grpc_service() + else None + ) + + self.wrapped_http_proxy = self.http_proxy + + for middleware in http_middlewares: + if version.parse(starlette.__version__) < version.parse("0.35.0"): + self.wrapped_http_proxy = middleware.cls( + self.wrapped_http_proxy, **middleware.options + ) + else: + # In starlette >= 0.35.0, middleware.options does not 
exist: + # https://github.com/encode/starlette/pull/2381. + self.wrapped_http_proxy = middleware.cls( + self.wrapped_http_proxy, + *middleware.args, + **middleware.kwargs, + ) + + # Start running the HTTP server on the event loop. + # This task should be running forever. We track it in case of failure. + self.running_task_http = get_or_create_event_loop().create_task( + self.run_http_server() + ) + + # Start running the gRPC server on the event loop. + # This task should be running forever. We track it in case of failure. + self.running_task_grpc = get_or_create_event_loop().create_task( + self.run_grpc_server() + ) + + _configure_gc_options() + + def _update_routes_in_proxies(self, endpoints: Dict[DeploymentID, EndpointInfo]): + self.proxy_router.update_routes(endpoints) + + def _update_logging_config(self, logging_config: LoggingConfig): + configure_component_logger( + component_name="proxy", + component_id=self.node_ip_address, + logging_config=logging_config, + ) + + def _get_logging_config(self) -> Tuple: + """Get the logging configuration (for testing purposes).""" + log_file_path = None + for handler in logger.handlers: + if isinstance(handler, logging.handlers.RotatingFileHandler): + log_file_path = handler.baseFilename + return log_file_path + + def _dump_ingress_replicas_for_testing(self, route: str) -> Set[ReplicaID]: + _, handle, _ = self.http_proxy.proxy_router.match_route(route) + return handle._router._asyncio_router._replica_scheduler._replica_id_set + + def should_start_grpc_service(self) -> bool: + """Determine whether gRPC service should be started. + + gRPC service will only be started if a valid port is provided and if the + servicer functions are passed. + """ + return self.grpc_port > 0 and len(self.grpc_options.grpc_servicer_functions) > 0 + + async def ready(self): + """Returns when both HTTP and gRPC proxies are ready to serve traffic. + Or throw exception when either proxy is not able to serve traffic. 
+ """ + http_setup_complete_wait_task = get_or_create_event_loop().create_task( + self.http_setup_complete.wait() + ) + grpc_setup_complete_wait_task = get_or_create_event_loop().create_task( + self.grpc_setup_complete.wait() + ) + + waiting_tasks_http = [ + # Either the HTTP setup has completed. + # The event is set inside self.run_http_server. + http_setup_complete_wait_task, + # Or self.run_http_server errored. + self.running_task_http, + ] + done_set_http, _ = await asyncio.wait( + waiting_tasks_http, + return_when=asyncio.FIRST_COMPLETED, + ) + waiting_tasks_grpc = [ + # Either the gRPC setup has completed. + # The event is set inside self.run_grpc_server. + grpc_setup_complete_wait_task, + # Or self.run_grpc_server errored. + self.running_task_grpc, + ] + done_set_grpc, _ = await asyncio.wait( + waiting_tasks_grpc, + return_when=asyncio.FIRST_COMPLETED, + ) + + # Return metadata, or re-throw the exception from self.running_task_http and + # self.running_task_grpc. + if self.http_setup_complete.is_set() and self.grpc_setup_complete.is_set(): + # NOTE(zcin): We need to convert the metadata to a json string because + # of cross-language scenarios. Java can't deserialize a Python tuple. + return json.dumps( + [ + ray.get_runtime_context().get_worker_id(), + get_component_logger_file_path(), + ] + ) + else: + proxy_error = None + if not self.http_setup_complete.is_set(): + try: + await done_set_http.pop() + except Exception as e: + logger.exception(e) + proxy_error = e + if not self.grpc_setup_complete.is_set(): + try: + await done_set_grpc.pop() + except Exception as e: + logger.exception(e) + proxy_error = e + raise proxy_error + + async def run_http_server(self): + sock = socket.socket() + if SOCKET_REUSE_PORT_ENABLED: + set_socket_reuse_port(sock) + try: + sock.bind((self.host, self.port)) + except OSError: + # The OS failed to bind a socket to the given host and port. + raise ValueError( + f"Failed to bind Ray Serve HTTP proxy to '{self.host}:{self.port}'. 
" + "Please make sure your http-host and http-port are specified correctly." + ) + + # NOTE: We have to use lower level uvicorn Config and Server + # class because we want to run the server as a coroutine. The only + # alternative is to call uvicorn.run which is blocking. + config = uvicorn.Config( + self.wrapped_http_proxy, + host=self.host, + port=self.port, + loop=_determine_target_loop(), + root_path=self.root_path, + lifespan="off", + log_level="warning", + access_log=False, + timeout_keep_alive=self.keep_alive_timeout_s, + ) + self._uvicorn_server = uvicorn.Server(config=config) + # TODO(edoakes): we need to override install_signal_handlers here + # because the existing implementation fails if it isn't running in + # the main thread and uvicorn doesn't expose a way to configure it. + self._uvicorn_server.install_signal_handlers = lambda: None + + logger.debug( + "Starting HTTP server on node: " + f"{ray.get_runtime_context().get_node_id()} " + f"listening on port {self.port}" + ) + + self.http_setup_complete.set() + await self._uvicorn_server.serve(sockets=[sock]) + + async def run_grpc_server(self): + if not self.should_start_grpc_service(): + return self.grpc_setup_complete.set() + + grpc_server = gRPCGenericServer( + service_handler_factory=self.grpc_proxy.service_handler_factory, + ) + + add_grpc_address(grpc_server, f"[::]:{self.grpc_port}") + + # Dummy servicer is used to be callable for the gRPC server. Serve have a + # custom gRPC server implementation to redirect calls into gRPCProxy. + # See: ray/serve/_private/grpc_util.py + dummy_servicer = DummyServicer() + + # Add Ray Serve gRPC service and methods (e.g. ListApplications and Healthz). + add_RayServeAPIServiceServicer_to_server(dummy_servicer, grpc_server) + + # Iterate through each of user provided gRPC servicer functions and add user + # defined services and methods. 
+ for grpc_servicer_function in self.grpc_options.grpc_servicer_func_callable: + grpc_servicer_function(dummy_servicer, grpc_server) + + await grpc_server.start() + logger.debug( + "Starting gRPC server on node: " + f"{ray.get_runtime_context().get_node_id()} " + f"listening on port {self.grpc_port}" + ) + self.grpc_setup_complete.set() + await grpc_server.wait_for_termination() + + async def update_draining(self, draining: bool, _after: Optional[Any] = None): + """Update the draining status of the HTTP and gRPC proxies. + + Unused `_after` argument is for scheduling: passing an ObjectRef + allows delaying this call until after the `_after` call has returned. + """ + + self.http_proxy.update_draining(draining) + if self.grpc_proxy: + self.grpc_proxy.update_draining(draining) + + async def is_drained(self, _after: Optional[Any] = None): + """Check whether both HTTP and gRPC proxies are drained or not. + + Unused `_after` argument is for scheduling: passing an ObjectRef + allows delaying this call until after the `_after` call has returned. + """ + + return self.http_proxy.is_drained() and ( + self.grpc_proxy is None or self.grpc_proxy.is_drained() + ) + + async def check_health(self): + """No-op method to check on the health of the HTTP Proxy. + + Make sure the async event loop is not blocked. + """ + logger.debug("Received health check.", extra={"log_to_stderr": False}) + + def pong(self): + """Called by the replica to initialize its handle to the proxy.""" + pass + + async def receive_asgi_messages(self, request_metadata: RequestMetadata) -> bytes: + """Get ASGI messages for the provided `request_metadata`. + + After the proxy has stopped receiving messages for this `request_metadata`, + this will always return immediately. + + Raises `KeyError` if this request ID is not found. This will happen when the + request is no longer being handled (e.g., the user disconnects). 
+ """ + return pickle.dumps( + await self.http_proxy.receive_asgi_messages(request_metadata) + ) + + def _save_cpu_profile_data(self) -> str: + """Saves CPU profiling data, if CPU profiling is enabled. + + Logs a warning if CPU profiling is disabled. + """ + + if self.cpu_profiler is not None: + import marshal + + self.cpu_profiler.snapshot_stats() + with open(self.cpu_profiler_log, "wb") as f: + marshal.dump(self.cpu_profiler.stats, f) + logger.info(f'Saved CPU profile data to file "{self.cpu_profiler_log}"') + return self.cpu_profiler_log + else: + logger.error( + "Attempted to save CPU profile data, but failed because no " + "CPU profiler was running! Enable CPU profiling by enabling " + "the RAY_SERVE_ENABLE_CPU_PROFILING env var." + ) + + async def _uvicorn_keep_alive(self) -> Optional[int]: + """Get the keep alive timeout used for the running uvicorn server. + + Return the timeout_keep_alive config used on the uvicorn server if it's running. + If the server is not running, return None. + """ + if self._uvicorn_server: + return self._uvicorn_server.config.timeout_keep_alive + + +def _determine_target_loop(): + """We determine target loop based on whether RAY_SERVE_DEBUG_MODE is enabled: + + - RAY_SERVE_DEBUG_MODE=0 (default): we use "uvloop" (Cython) providing + high-performance, native implementation of the event-loop + + - RAY_SERVE_DEBUG_MODE=1: we fall back to "asyncio" (pure Python) event-loop + implementation that is considerably slower than "uvloop", + but provides for easy access to the source implementation + """ + if RAY_SERVE_DEBUG_MODE: + return "asyncio" + else: + return "uvloop" + + +def _configure_gc_options(): + if not RAY_SERVE_ENABLE_PROXY_GC_OPTIMIZATIONS: + return + + # Collect any objects that exist already and exclude them from future GC. + gc.collect(2) + gc.freeze() + + # Tune the GC threshold to run less frequently (default is 700). 
+ gc.set_threshold(RAY_SERVE_PROXY_GC_THRESHOLD) diff --git a/.venv/lib/python3.11/site-packages/ray/serve/_private/proxy_request_response.py b/.venv/lib/python3.11/site-packages/ray/serve/_private/proxy_request_response.py new file mode 100644 index 0000000000000000000000000000000000000000..922621f6242487bf842da4a2c3109564a4edf8a6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/serve/_private/proxy_request_response.py @@ -0,0 +1,199 @@ +import logging +import pickle +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, AsyncIterator, List, Tuple, Union + +import grpc +from starlette.types import Receive, Scope, Send + +from ray.serve._private.common import StreamingHTTPRequest, gRPCRequest +from ray.serve._private.constants import SERVE_LOGGER_NAME +from ray.serve._private.utils import DEFAULT +from ray.serve.grpc_util import RayServegRPCContext + +logger = logging.getLogger(SERVE_LOGGER_NAME) + + +class ProxyRequest(ABC): + """Base ProxyRequest class to use in the common interface among proxies""" + + @property + @abstractmethod + def request_type(self) -> str: + raise NotImplementedError + + @property + @abstractmethod + def method(self) -> str: + raise NotImplementedError + + @property + @abstractmethod + def route_path(self) -> str: + raise NotImplementedError + + @property + @abstractmethod + def is_route_request(self) -> bool: + raise NotImplementedError + + @property + @abstractmethod + def is_health_request(self) -> bool: + raise NotImplementedError + + +class ASGIProxyRequest(ProxyRequest): + """ProxyRequest implementation to wrap ASGI scope, receive, and send.""" + + def __init__(self, scope: Scope, receive: Receive, send: Send): + self.scope = scope + self.receive = receive + self.send = send + + @property + def request_type(self) -> str: + return self.scope.get("type", "") + + @property + def method(self) -> str: + # WebSocket messages don't have a 'method' field. 
+ return self.scope.get("method", "WS").upper() + + @property + def route_path(self) -> str: + return self.scope.get("path", "")[len(self.root_path) :] + + @property + def is_route_request(self) -> bool: + return self.route_path == "/-/routes" + + @property + def is_health_request(self) -> bool: + return self.route_path == "/-/healthz" + + @property + def client(self) -> str: + return self.scope.get("client", "") + + @property + def root_path(self) -> str: + return self.scope.get("root_path", "") + + @property + def path(self) -> str: + return self.scope.get("path", "") + + @property + def headers(self) -> List[Tuple[bytes, bytes]]: + return self.scope.get("headers", []) + + def set_path(self, path: str): + self.scope["path"] = path + + def set_root_path(self, root_path: str): + self.scope["root_path"] = root_path + + def serialized_replica_arg(self, proxy_actor_name: str) -> bytes: + # NOTE(edoakes): it's important that the request is sent as raw bytes to + # skip the Ray cloudpickle serialization codepath for performance. + return pickle.dumps( + StreamingHTTPRequest( + asgi_scope=self.scope, + proxy_actor_name=proxy_actor_name, + ) + ) + + +class gRPCProxyRequest(ProxyRequest): + """ProxyRequest implementation to wrap gRPC request protobuf and metadata.""" + + def __init__( + self, + request_proto: Any, + context: grpc._cython.cygrpc._ServicerContext, + service_method: str, + stream: bool, + ): + self._request_proto = request_proto + self.context = context + self.service_method = service_method + self.stream = stream + self.app_name = "" + self.request_id = None + self.method_name = "__call__" + self.multiplexed_model_id = DEFAULT.VALUE + # ray_serve_grpc_context is a class implemented by us to be able to serialize + # the object and pass it into the deployment. 
+ self.ray_serve_grpc_context = RayServegRPCContext(context) + self.setup_variables() + + def setup_variables(self): + if not self.is_route_request and not self.is_health_request: + service_method_split = self.service_method.split("/") + self.method_name = service_method_split[-1] + for key, value in self.context.invocation_metadata(): + if key == "application": + self.app_name = value + elif key == "request_id": + self.request_id = value + elif key == "multiplexed_model_id": + self.multiplexed_model_id = value + + @property + def request_type(self) -> str: + return "grpc" + + @property + def method(self) -> str: + return self.service_method + + @property + def route_path(self) -> str: + return self.app_name + + @property + def is_route_request(self) -> bool: + return self.service_method == "/ray.serve.RayServeAPIService/ListApplications" + + @property + def is_health_request(self) -> bool: + return self.service_method == "/ray.serve.RayServeAPIService/Healthz" + + def send_request_id(self, request_id: str): + # Setting the trailing metadata on the ray_serve_grpc_context object, so it's + # not overriding the ones set from the user and will be sent back to the + # client altogether. + self.ray_serve_grpc_context.set_trailing_metadata([("request_id", request_id)]) + + def serialized_replica_arg(self) -> bytes: + # NOTE(edoakes): it's important that the request is sent as raw bytes to + # skip the Ray cloudpickle serialization codepath for performance. + return pickle.dumps(gRPCRequest(user_request_proto=self._request_proto)) + + +@dataclass(frozen=True) +class ResponseStatus: + code: Union[str, grpc.StatusCode] # Must be convertible to a string. + is_error: bool = False + message: str = "" + + +# Yields protocol-specific messages followed by a final `ResponseStatus`. 
+ResponseGenerator = AsyncIterator[Union[Any, ResponseStatus]] + + +@dataclass(frozen=True) +class HandlerMetadata: + application_name: str = "" + deployment_name: str = "" + route: str = "" + + +@dataclass(frozen=True) +class ResponseHandlerInfo: + response_generator: ResponseGenerator + metadata: HandlerMetadata + should_record_access_log: bool + should_increment_ongoing_requests: bool diff --git a/.venv/lib/python3.11/site-packages/ray/serve/_private/replica_result.py b/.venv/lib/python3.11/site-packages/ray/serve/_private/replica_result.py new file mode 100644 index 0000000000000000000000000000000000000000..c9627d5590ac175840c7b3b60508a4f22b707bb1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/serve/_private/replica_result.py @@ -0,0 +1,188 @@ +import inspect +import threading +import time +from abc import ABC, abstractmethod +from functools import wraps +from typing import Callable, Coroutine, Optional, Union + +import ray +from ray.serve._private.common import RequestMetadata +from ray.serve._private.utils import calculate_remaining_timeout +from ray.serve.exceptions import RequestCancelledError + + +class ReplicaResult(ABC): + @abstractmethod + def get(self, timeout_s: Optional[float]): + raise NotImplementedError + + @abstractmethod + async def get_async(self): + raise NotImplementedError + + @abstractmethod + def __next__(self): + raise NotImplementedError + + @abstractmethod + async def __anext__(self): + raise NotImplementedError + + @abstractmethod + def add_done_callback(self, callback: Callable): + raise NotImplementedError + + @abstractmethod + def cancel(self): + raise NotImplementedError + + @abstractmethod + def to_object_ref(self, timeout_s: Optional[float]) -> ray.ObjectRef: + raise NotImplementedError + + @abstractmethod + async def to_object_ref_async(self) -> ray.ObjectRef: + raise NotImplementedError + + @abstractmethod + def to_object_ref_gen(self) -> ray.ObjectRefGenerator: + # NOTE(edoakes): there is only a sync version of 
this method because it + # does not block like `to_object_ref` (so there's also no timeout argument). + raise NotImplementedError + + +class ActorReplicaResult(ReplicaResult): + def __init__( + self, + obj_ref_or_gen: Union[ray.ObjectRef, ray.ObjectRefGenerator], + metadata: RequestMetadata, + ): + self._obj_ref: Optional[ray.ObjectRef] = None + self._obj_ref_gen: Optional[ray.ObjectRefGenerator] = None + self._is_streaming: bool = metadata.is_streaming + self._request_id: str = metadata.request_id + self._object_ref_or_gen_sync_lock = threading.Lock() + + if isinstance(obj_ref_or_gen, ray.ObjectRefGenerator): + self._obj_ref_gen = obj_ref_or_gen + else: + self._obj_ref = obj_ref_or_gen + + if self._is_streaming: + assert ( + self._obj_ref_gen is not None + ), "An ObjectRefGenerator must be passed for streaming requests." + + def _process_response(f: Union[Callable, Coroutine]): + @wraps(f) + def wrapper(self, *args, **kwargs): + try: + return f(self, *args, **kwargs) + except ray.exceptions.TaskCancelledError: + raise RequestCancelledError(self._request_id) + + @wraps(f) + async def async_wrapper(self, *args, **kwargs): + try: + return await f(self, *args, **kwargs) + except ray.exceptions.TaskCancelledError: + raise RequestCancelledError(self._request_id) + + if inspect.iscoroutinefunction(f): + return async_wrapper + else: + return wrapper + + @_process_response + def get(self, timeout_s: Optional[float]): + assert ( + not self._is_streaming + ), "get() can only be called on a unary ActorReplicaResult." + + start_time_s = time.time() + object_ref = self.to_object_ref(timeout_s=timeout_s) + remaining_timeout_s = calculate_remaining_timeout( + timeout_s=timeout_s, + start_time_s=start_time_s, + curr_time_s=time.time(), + ) + return ray.get(object_ref, timeout=remaining_timeout_s) + + @_process_response + async def get_async(self): + assert ( + not self._is_streaming + ), "get_async() can only be called on a unary ActorReplicaResult." 
+ + return await (await self.to_object_ref_async()) + + @_process_response + def __next__(self): + assert ( + self._is_streaming + ), "next() can only be called on a streaming ActorReplicaResult." + + next_obj_ref = self._obj_ref_gen.__next__() + return ray.get(next_obj_ref) + + @_process_response + async def __anext__(self): + assert ( + self._is_streaming + ), "__anext__() can only be called on a streaming ActorReplicaResult." + + next_obj_ref = await self._obj_ref_gen.__anext__() + return await next_obj_ref + + def add_done_callback(self, callback: Callable): + if self._obj_ref_gen is not None: + self._obj_ref_gen.completed()._on_completed(callback) + else: + self._obj_ref._on_completed(callback) + + def cancel(self): + if self._obj_ref_gen is not None: + ray.cancel(self._obj_ref_gen) + else: + ray.cancel(self._obj_ref) + + def to_object_ref(self, *, timeout_s: Optional[float] = None) -> ray.ObjectRef: + assert ( + not self._is_streaming + ), "to_object_ref can only be called on a unary ReplicaActorResult." + + # NOTE(edoakes): this section needs to be guarded with a lock and the resulting + # object ref cached in order to avoid calling `__next__()` to + # resolve to the underlying object ref more than once. + # See: https://github.com/ray-project/ray/issues/43879. + with self._object_ref_or_gen_sync_lock: + if self._obj_ref is None: + obj_ref = self._obj_ref_gen._next_sync(timeout_s=timeout_s) + if obj_ref.is_nil(): + raise TimeoutError("Timed out resolving to ObjectRef.") + + self._obj_ref = obj_ref + + return self._obj_ref + + async def to_object_ref_async(self) -> ray.ObjectRef: + assert ( + not self._is_streaming + ), "to_object_ref_async can only be called on a unary ReplicaActorResult." + + # NOTE(edoakes): this section needs to be guarded with a lock and the resulting + # object ref cached in order to avoid calling `__anext__()` to + # resolve to the underlying object ref more than once. + # See: https://github.com/ray-project/ray/issues/43879. 
+ with self._object_ref_or_gen_sync_lock: + if self._obj_ref is None: + self._obj_ref = await self._obj_ref_gen.__anext__() + + return self._obj_ref + + def to_object_ref_gen(self) -> ray.ObjectRefGenerator: + assert ( + self._is_streaming + ), "to_object_ref_gen can only be called on a streaming ReplicaActorResult." + + return self._obj_ref_gen diff --git a/.venv/lib/python3.11/site-packages/ray/serve/_private/router.py b/.venv/lib/python3.11/site-packages/ray/serve/_private/router.py new file mode 100644 index 0000000000000000000000000000000000000000..2462c000b271f62526095212690a637b802febe8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/serve/_private/router.py @@ -0,0 +1,777 @@ +import asyncio +import concurrent.futures +import logging +import threading +import time +import weakref +from abc import ABC, abstractmethod +from asyncio import AbstractEventLoop +from collections import defaultdict +from collections.abc import MutableMapping +from contextlib import contextmanager +from functools import lru_cache, partial +from typing import Any, Coroutine, DefaultDict, Dict, List, Optional, Tuple, Union + +import ray +from ray.actor import ActorHandle +from ray.exceptions import ActorDiedError, ActorUnavailableError, RayError +from ray.serve._private.common import ( + DeploymentHandleSource, + DeploymentID, + DeploymentTargetInfo, + ReplicaID, + RequestMetadata, + RunningReplicaInfo, +) +from ray.serve._private.config import DeploymentConfig +from ray.serve._private.constants import ( + HANDLE_METRIC_PUSH_INTERVAL_S, + RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE, + RAY_SERVE_HANDLE_AUTOSCALING_METRIC_RECORD_PERIOD_S, + SERVE_LOGGER_NAME, +) +from ray.serve._private.long_poll import LongPollClient, LongPollNamespace +from ray.serve._private.metrics_utils import InMemoryMetricsStore, MetricsPusher +from ray.serve._private.replica_result import ReplicaResult +from ray.serve._private.replica_scheduler import PendingRequest, ReplicaScheduler +from 
ray.serve._private.utils import generate_request_id, resolve_deployment_response +from ray.serve.config import AutoscalingConfig +from ray.serve.exceptions import BackPressureError, DeploymentUnavailableError +from ray.util import metrics + +logger = logging.getLogger(SERVE_LOGGER_NAME) + + +QUEUED_REQUESTS_KEY = "queued" + + +class RouterMetricsManager: + """Manages metrics for the router.""" + + PUSH_METRICS_TO_CONTROLLER_TASK_NAME = "push_metrics_to_controller" + RECORD_METRICS_TASK_NAME = "record_metrics" + + def __init__( + self, + deployment_id: DeploymentID, + handle_id: str, + self_actor_id: str, + handle_source: DeploymentHandleSource, + controller_handle: ActorHandle, + router_requests_counter: metrics.Counter, + queued_requests_gauge: metrics.Gauge, + running_requests_gauge: metrics.Gauge, + ): + self._handle_id = handle_id + self._deployment_id = deployment_id + self._self_actor_id = self_actor_id + self._handle_source = handle_source + self._controller_handle = controller_handle + + # Exported metrics + self.num_router_requests = router_requests_counter + self.num_router_requests.set_default_tags( + { + "deployment": deployment_id.name, + "application": deployment_id.app_name, + "handle": self._handle_id, + "actor_id": self._self_actor_id, + } + ) + + self.num_queued_requests = 0 + self.num_queued_requests_gauge = queued_requests_gauge + self.num_queued_requests_gauge.set_default_tags( + { + "deployment": deployment_id.name, + "application": deployment_id.app_name, + "handle": self._handle_id, + "actor_id": self._self_actor_id, + } + ) + self.num_queued_requests_gauge.set(0) + + # Track queries sent to replicas for the autoscaling algorithm. 
+ self.num_requests_sent_to_replicas: DefaultDict[ReplicaID, int] = defaultdict( + int + ) + self.num_running_requests_gauge = running_requests_gauge + self.num_running_requests_gauge.set_default_tags( + { + "deployment": deployment_id.name, + "application": deployment_id.app_name, + "handle": self._handle_id, + "actor_id": self._self_actor_id, + } + ) + # We use Ray object ref callbacks to update state when tracking + # number of requests running on replicas. The callbacks will be + # called from a C++ thread into the router's async event loop, + # so non-atomic read and write operations need to be guarded by + # this thread-safe lock. + self._queries_lock = threading.Lock() + # Regularly aggregate and push autoscaling metrics to controller + self.metrics_pusher = MetricsPusher() + self.metrics_store = InMemoryMetricsStore() + # The config for the deployment this router sends requests to will be broadcast + # by the controller. That means it is not available until we get the first + # update. This includes an optional autoscaling config. + self._deployment_config: Optional[DeploymentConfig] = None + # Track whether the metrics manager has been shutdown + self._shutdown: bool = False + + @contextmanager + def wrap_request_assignment(self, request_meta: RequestMetadata): + max_queued_requests = ( + self._deployment_config.max_queued_requests + if self._deployment_config is not None + else -1 + ) + if ( + max_queued_requests != -1 + and self.num_queued_requests >= max_queued_requests + ): + e = BackPressureError( + num_queued_requests=self.num_queued_requests, + max_queued_requests=max_queued_requests, + ) + logger.warning(e.message) + raise e + + try: + self.inc_num_total_requests(request_meta.route) + self.inc_num_queued_requests() + + yield + finally: + # If the request is disconnected before assignment, this coroutine + # gets cancelled by the caller and an asyncio.CancelledError is + # raised. 
The finally block ensures that num_queued_requests + # is correctly decremented in this case. + self.dec_num_queued_requests() + + def update_running_replicas(self, running_replicas: List[RunningReplicaInfo]): + """Prune list of replica ids in self.num_queries_sent_to_replicas. + + We want to avoid self.num_queries_sent_to_replicas from growing + in memory as the deployment upscales and downscales over time. + """ + + running_replica_set = {replica.replica_id for replica in running_replicas} + with self._queries_lock: + self.num_requests_sent_to_replicas = defaultdict( + int, + { + id: self.num_requests_sent_to_replicas[id] + for id, num_queries in self.num_requests_sent_to_replicas.items() + if num_queries or id in running_replica_set + }, + ) + + @property + def autoscaling_config(self) -> Optional[AutoscalingConfig]: + if self._deployment_config is None: + return None + + return self._deployment_config.autoscaling_config + + def update_deployment_config( + self, deployment_config: DeploymentConfig, curr_num_replicas: int + ): + """Update the config for the deployment this router sends requests to.""" + + if self._shutdown: + return + + self._deployment_config = deployment_config + + # Start the metrics pusher if autoscaling is enabled. + autoscaling_config = self.autoscaling_config + if autoscaling_config: + self.metrics_pusher.start() + # Optimization for autoscaling cold start time. If there are + # currently 0 replicas for the deployment, and there is at + # least one queued request on this router, then immediately + # push handle metric to the controller. 
+ if self.should_send_scaled_to_zero_optimized_push(curr_num_replicas): + self.push_autoscaling_metrics_to_controller() + + if RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE: + # Record number of queued + ongoing requests at regular + # intervals into the in-memory metrics store + self.metrics_pusher.register_or_update_task( + self.RECORD_METRICS_TASK_NAME, + self._add_autoscaling_metrics_point, + min( + RAY_SERVE_HANDLE_AUTOSCALING_METRIC_RECORD_PERIOD_S, + autoscaling_config.metrics_interval_s, + ), + ) + # Push metrics to the controller periodically. + self.metrics_pusher.register_or_update_task( + self.PUSH_METRICS_TO_CONTROLLER_TASK_NAME, + self.push_autoscaling_metrics_to_controller, + autoscaling_config.metrics_interval_s, + ) + else: + self.metrics_pusher.register_or_update_task( + self.PUSH_METRICS_TO_CONTROLLER_TASK_NAME, + self.push_autoscaling_metrics_to_controller, + HANDLE_METRIC_PUSH_INTERVAL_S, + ) + + else: + if self.metrics_pusher: + self.metrics_pusher.stop_tasks() + + def inc_num_total_requests(self, route: str): + self.num_router_requests.inc(tags={"route": route}) + + def inc_num_queued_requests(self): + self.num_queued_requests += 1 + self.num_queued_requests_gauge.set(self.num_queued_requests) + + def dec_num_queued_requests(self): + self.num_queued_requests -= 1 + self.num_queued_requests_gauge.set(self.num_queued_requests) + + def inc_num_running_requests_for_replica(self, replica_id: ReplicaID): + with self._queries_lock: + self.num_requests_sent_to_replicas[replica_id] += 1 + self.num_running_requests_gauge.set( + sum(self.num_requests_sent_to_replicas.values()) + ) + + def dec_num_running_requests_for_replica(self, replica_id: ReplicaID): + with self._queries_lock: + self.num_requests_sent_to_replicas[replica_id] -= 1 + self.num_running_requests_gauge.set( + sum(self.num_requests_sent_to_replicas.values()) + ) + + def should_send_scaled_to_zero_optimized_push(self, curr_num_replicas: int) -> bool: + return ( + self.autoscaling_config 
is not None + and curr_num_replicas == 0 + and self.num_queued_requests > 0 + ) + + def push_autoscaling_metrics_to_controller(self): + """Pushes queued and running request metrics to the controller. + + These metrics are used by the controller for autoscaling. + """ + + self._controller_handle.record_handle_metrics.remote( + send_timestamp=time.time(), + deployment_id=self._deployment_id, + handle_id=self._handle_id, + actor_id=self._self_actor_id, + handle_source=self._handle_source, + **self._get_aggregated_requests(), + ) + + def _add_autoscaling_metrics_point(self): + """Adds metrics point for queued and running requests at replicas. + + Also prunes keys in the in memory metrics store with outdated datapoints. + """ + + timestamp = time.time() + self.metrics_store.add_metrics_point( + {QUEUED_REQUESTS_KEY: self.num_queued_requests}, timestamp + ) + if RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE: + self.metrics_store.add_metrics_point( + self.num_requests_sent_to_replicas, timestamp + ) + + # Prevent in memory metrics store memory from growing + start_timestamp = time.time() - self.autoscaling_config.look_back_period_s + self.metrics_store.prune_keys_and_compact_data(start_timestamp) + + def _get_aggregated_requests(self): + running_requests = dict() + if RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE and self.autoscaling_config: + look_back_period = self.autoscaling_config.look_back_period_s + running_requests = { + replica_id: self.metrics_store.window_average( + replica_id, time.time() - look_back_period + ) + # If data hasn't been recorded yet, return current + # number of queued and ongoing requests. 
+ or num_requests + for replica_id, num_requests in self.num_requests_sent_to_replicas.items() # noqa: E501 + } + + return { + "queued_requests": self.num_queued_requests, + "running_requests": running_requests, + } + + async def shutdown(self): + """Shutdown metrics manager gracefully.""" + + if self.metrics_pusher: + await self.metrics_pusher.graceful_shutdown() + + self._shutdown = True + + +class Router(ABC): + @abstractmethod + def running_replicas_populated(self) -> bool: + pass + + @abstractmethod + def assign_request( + self, + request_meta: RequestMetadata, + *request_args, + **request_kwargs, + ) -> concurrent.futures.Future[ReplicaResult]: + pass + + @abstractmethod + def shutdown(self) -> concurrent.futures.Future: + pass + + +class AsyncioRouter: + def __init__( + self, + controller_handle: ActorHandle, + deployment_id: DeploymentID, + handle_id: str, + self_actor_id: str, + handle_source: DeploymentHandleSource, + event_loop: asyncio.BaseEventLoop, + replica_scheduler: Optional[ReplicaScheduler], + enable_strict_max_ongoing_requests: bool, + resolve_request_arg_func: Coroutine = resolve_deployment_response, + ): + """Used to assign requests to downstream replicas for a deployment. + + The scheduling behavior is delegated to a ReplicaScheduler; this is a thin + wrapper that adds metrics and logging. + """ + + self._event_loop = event_loop + self.deployment_id = deployment_id + self._enable_strict_max_ongoing_requests = enable_strict_max_ongoing_requests + # By default, deployment is available unless we receive news + # otherwise through a long poll broadcast from the controller. + self._deployment_available = True + + self._replica_scheduler: ReplicaScheduler = replica_scheduler + self._resolve_request_arg_func = resolve_request_arg_func + + # Flipped to `True` once the router has received a non-empty + # replica set at least once. 
+ self._running_replicas_populated: bool = False + + # Initializing `self._metrics_manager` before `self.long_poll_client` is + # necessary to avoid race condition where `self.update_deployment_config()` + # might be called before `self._metrics_manager` instance is created. + self._metrics_manager = RouterMetricsManager( + deployment_id, + handle_id, + self_actor_id, + handle_source, + controller_handle, + metrics.Counter( + "serve_num_router_requests", + description="The number of requests processed by the router.", + tag_keys=("deployment", "route", "application", "handle", "actor_id"), + ), + metrics.Gauge( + "serve_deployment_queued_queries", + description=( + "The current number of queries to this deployment waiting" + " to be assigned to a replica." + ), + tag_keys=("deployment", "application", "handle", "actor_id"), + ), + metrics.Gauge( + "serve_num_ongoing_requests_at_replicas", + description=( + "The current number of requests to this deployment that " + "have been submitted to a replica." + ), + tag_keys=("deployment", "application", "handle", "actor_id"), + ), + ) + + # The Router needs to stay informed about changes to the target deployment's + # running replicas and deployment config. We do this via the long poll system. + # However, for efficiency, we don't want to create a LongPollClient for every + # DeploymentHandle, so we use a shared LongPollClient that all Routers + # register themselves with. But first, the router needs to get a fast initial + # update so that it can start serving requests, which we do with a dedicated + # LongPollClient that stops running once the shared client takes over. 
+ + self.long_poll_client = LongPollClient( + controller_handle, + { + ( + LongPollNamespace.DEPLOYMENT_TARGETS, + deployment_id, + ): self.update_deployment_targets, + ( + LongPollNamespace.DEPLOYMENT_CONFIG, + deployment_id, + ): self.update_deployment_config, + }, + call_in_event_loop=self._event_loop, + ) + + shared = SharedRouterLongPollClient.get_or_create( + controller_handle, self._event_loop + ) + shared.register(self) + + def running_replicas_populated(self) -> bool: + return self._running_replicas_populated + + def update_deployment_targets(self, deployment_target_info: DeploymentTargetInfo): + self._deployment_available = deployment_target_info.is_available + + running_replicas = deployment_target_info.running_replicas + self._replica_scheduler.update_running_replicas(running_replicas) + self._metrics_manager.update_running_replicas(running_replicas) + + if running_replicas: + self._running_replicas_populated = True + + def update_deployment_config(self, deployment_config: DeploymentConfig): + self._metrics_manager.update_deployment_config( + deployment_config, + curr_num_replicas=len(self._replica_scheduler.curr_replicas), + ) + + async def _resolve_request_arguments( + self, + request_metadata: RequestMetadata, + request_args: Tuple[Any], + request_kwargs: Dict[str, Any], + ) -> Tuple[Tuple[Any], Dict[str, Any]]: + """Asynchronously resolve and replace top-level request args and kwargs.""" + new_args = list(request_args) + new_kwargs = request_kwargs.copy() + + # Map from index -> task for resolving positional arg + resolve_arg_tasks = {} + for i, obj in enumerate(request_args): + task = await self._resolve_request_arg_func(obj, request_metadata) + if task is not None: + resolve_arg_tasks[i] = task + + # Map from key -> task for resolving key-word arg + resolve_kwarg_tasks = {} + for k, obj in request_kwargs.items(): + task = await self._resolve_request_arg_func(obj, request_metadata) + if task is not None: + resolve_kwarg_tasks[k] = task + + # Gather 
all argument resolution tasks concurrently. + if resolve_arg_tasks or resolve_kwarg_tasks: + all_tasks = list(resolve_arg_tasks.values()) + list( + resolve_kwarg_tasks.values() + ) + await asyncio.wait(all_tasks) + + # Update new args and new kwargs with resolved arguments + for index, task in resolve_arg_tasks.items(): + new_args[index] = task.result() + for key, task in resolve_kwarg_tasks.items(): + new_kwargs[key] = task.result() + + # Return new args and new kwargs + return new_args, new_kwargs + + def _process_finished_request( + self, + replica_id: ReplicaID, + parent_request_id: str, + response_id: str, + result: Union[Any, RayError], + ): + self._metrics_manager.dec_num_running_requests_for_replica(replica_id) + if isinstance(result, ActorDiedError): + # Replica has died but controller hasn't notified the router yet. + # Don't consider this replica for requests in the future, and retry + # scheduling request. + self._replica_scheduler.on_replica_actor_died(replica_id) + logger.warning( + f"{replica_id} will not be considered for future " + "requests because it has died." + ) + elif isinstance(result, ActorUnavailableError): + # There are network issues, or replica has died but GCS is down so + # ActorUnavailableError will be raised until GCS recovers. For the + # time being, invalidate the cache entry so that we don't try to + # send requests to this replica without actively probing, and retry + # scheduling request. + self._replica_scheduler.on_replica_actor_unavailable(replica_id) + logger.warning( + f"Request failed because {replica_id} is temporarily unavailable." + ) + + async def schedule_and_send_request( + self, pr: PendingRequest + ) -> Tuple[ReplicaResult, ReplicaID]: + """Choose a replica for the request and send it. + + This will block indefinitely if no replicas are available to handle the + request, so it's up to the caller to time out or cancel the request. 
+ """ + r = await self._replica_scheduler.choose_replica_for_request(pr) + + # If the queue len cache is disabled or we're sending a request to Java, + # then directly send the query and hand the response back. The replica will + # never reject requests in this code path. + if not self._enable_strict_max_ongoing_requests or r.is_cross_language: + result, _ = await r.send_request(pr, with_rejection=False) + return result, r.replica_id + + while True: + result = None + try: + result, queue_info = await r.send_request(pr, with_rejection=True) + self._replica_scheduler.on_new_queue_len_info(r.replica_id, queue_info) + if queue_info.accepted: + return result, r.replica_id + except asyncio.CancelledError: + # NOTE(edoakes): this is not strictly necessary because there are + # currently no `await` statements between getting the ref and returning, + # but I'm adding it defensively. + if result is not None: + result.cancel() + + raise + except ActorDiedError: + # Replica has died but controller hasn't notified the router yet. + # Don't consider this replica for requests in the future, and retry + # scheduling request. + self._replica_scheduler.on_replica_actor_died(r.replica_id) + logger.warning( + f"{r.replica_id} will not be considered for future " + "requests because it has died." + ) + except ActorUnavailableError: + # There are network issues, or replica has died but GCS is down so + # ActorUnavailableError will be raised until GCS recovers. For the + # time being, invalidate the cache entry so that we don't try to + # send requests to this replica without actively probing, and retry + # scheduling request. + self._replica_scheduler.on_replica_actor_unavailable(r.replica_id) + logger.warning(f"{r.replica_id} is temporarily unavailable.") + + # If the replica rejects the request, retry the scheduling process. The + # request will be placed on the front of the queue to avoid tail latencies. 
+ # TODO(edoakes): this retry procedure is not perfect because it'll reset the + # process of choosing candidates replicas (i.e., for locality-awareness). + r = await self._replica_scheduler.choose_replica_for_request( + pr, is_retry=True + ) + + async def assign_request( + self, + request_meta: RequestMetadata, + *request_args, + **request_kwargs, + ) -> ReplicaResult: + """Assign a request to a replica and return the resulting object_ref.""" + + if not self._deployment_available: + raise DeploymentUnavailableError(self.deployment_id) + + response_id = generate_request_id() + assign_request_task = asyncio.current_task() + ray.serve.context._add_request_pending_assignment( + request_meta.internal_request_id, response_id, assign_request_task + ) + assign_request_task.add_done_callback( + lambda _: ray.serve.context._remove_request_pending_assignment( + request_meta.internal_request_id, response_id + ) + ) + + with self._metrics_manager.wrap_request_assignment(request_meta): + # Optimization: if there are currently zero replicas for a deployment, + # push handle metric to controller to allow for fast cold start time. 
+ if self._metrics_manager.should_send_scaled_to_zero_optimized_push( + curr_num_replicas=len(self._replica_scheduler.curr_replicas) + ): + self._metrics_manager.push_autoscaling_metrics_to_controller() + + replica_result = None + try: + request_args, request_kwargs = await self._resolve_request_arguments( + request_meta, request_args, request_kwargs + ) + replica_result, replica_id = await self.schedule_and_send_request( + PendingRequest( + args=list(request_args), + kwargs=request_kwargs, + metadata=request_meta, + ), + ) + + # Keep track of requests that have been sent out to replicas + if RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE: + _request_context = ray.serve.context._get_serve_request_context() + request_id: str = _request_context.request_id + self._metrics_manager.inc_num_running_requests_for_replica( + replica_id + ) + callback = partial( + self._process_finished_request, + replica_id, + request_id, + response_id, + ) + replica_result.add_done_callback(callback) + + return replica_result + except asyncio.CancelledError: + # NOTE(edoakes): this is not strictly necessary because + # there are currently no `await` statements between + # getting the ref and returning, but I'm adding it defensively. + if replica_result is not None: + replica_result.cancel() + + raise + + async def shutdown(self): + await self._metrics_manager.shutdown() + + +class SingletonThreadRouter(Router): + """Wrapper class that runs an AsyncioRouter on a separate thread. + + The motivation for this is to avoid user code blocking the event loop and + preventing the router from making progress. + + Maintains a singleton event loop running in a daemon thread that is shared by + all AsyncioRouters. + """ + + _asyncio_loop: Optional[asyncio.AbstractEventLoop] = None + _asyncio_loop_creation_lock = threading.Lock() + + def __init__(self, **passthrough_kwargs): + assert ( + "event_loop" not in passthrough_kwargs + ), "SingletonThreadRouter manages the router event loop." 
+ + self._asyncio_router = AsyncioRouter( + event_loop=self._get_singleton_asyncio_loop(), **passthrough_kwargs + ) + + @classmethod + def _get_singleton_asyncio_loop(cls) -> asyncio.AbstractEventLoop: + """Get singleton asyncio loop running in a daemon thread. + + This method is thread safe. + """ + with cls._asyncio_loop_creation_lock: + if cls._asyncio_loop is None: + cls._asyncio_loop = asyncio.new_event_loop() + thread = threading.Thread( + daemon=True, + target=cls._asyncio_loop.run_forever, + ) + thread.start() + + return cls._asyncio_loop + + def running_replicas_populated(self) -> bool: + return self._asyncio_router.running_replicas_populated() + + def assign_request( + self, + request_meta: RequestMetadata, + *request_args, + **request_kwargs, + ) -> concurrent.futures.Future[ReplicaResult]: + return asyncio.run_coroutine_threadsafe( + self._asyncio_router.assign_request( + request_meta, *request_args, **request_kwargs + ), + loop=self._asyncio_loop, + ) + + def shutdown(self) -> concurrent.futures.Future: + return asyncio.run_coroutine_threadsafe( + self._asyncio_router.shutdown(), loop=self._asyncio_loop + ) + + +class SharedRouterLongPollClient: + def __init__(self, controller_handle: ActorHandle, event_loop: AbstractEventLoop): + self.controller_handler = controller_handle + + # We use a WeakSet to store the Routers so that we don't prevent them + # from being garbage-collected. 
+ self.routers: MutableMapping[ + DeploymentID, weakref.WeakSet[AsyncioRouter] + ] = defaultdict(weakref.WeakSet) + + # Creating the LongPollClient implicitly starts it + self.long_poll_client = LongPollClient( + controller_handle, + key_listeners={}, + call_in_event_loop=event_loop, + ) + + @classmethod + @lru_cache(maxsize=None) + def get_or_create( + cls, controller_handle: ActorHandle, event_loop: AbstractEventLoop + ) -> "SharedRouterLongPollClient": + shared = cls(controller_handle=controller_handle, event_loop=event_loop) + logger.info(f"Started {shared}.") + return shared + + def update_deployment_targets( + self, + deployment_target_info: DeploymentTargetInfo, + deployment_id: DeploymentID, + ) -> None: + for router in self.routers[deployment_id]: + router.update_deployment_targets(deployment_target_info) + router.long_poll_client.stop() + + def update_deployment_config( + self, deployment_config: DeploymentConfig, deployment_id: DeploymentID + ) -> None: + for router in self.routers[deployment_id]: + router.update_deployment_config(deployment_config) + router.long_poll_client.stop() + + def register(self, router: AsyncioRouter) -> None: + self.routers[router.deployment_id].add(router) + + # Remove the entries for any deployment ids that no longer have any routers. + # The WeakSets will automatically lose track of Routers that get GC'd, + # but the outer dict will keep the key around, so we need to clean up manually. + # Note the list(...) to avoid mutating self.routers while iterating over it. + for deployment_id, routers in list(self.routers.items()): + if not routers: + self.routers.pop(deployment_id) + + # Register the new listeners on the long poll client. + # Some of these listeners may already exist, but it's safe to add them again. 
+ key_listeners = { + (LongPollNamespace.DEPLOYMENT_TARGETS, deployment_id): partial( + self.update_deployment_targets, deployment_id=deployment_id + ) + for deployment_id in self.routers.keys() + } | { + (LongPollNamespace.DEPLOYMENT_CONFIG, deployment_id): partial( + self.update_deployment_config, deployment_id=deployment_id + ) + for deployment_id in self.routers.keys() + } + self.long_poll_client.add_key_listeners(key_listeners) diff --git a/.venv/lib/python3.11/site-packages/ray/serve/_private/test_utils.py b/.venv/lib/python3.11/site-packages/ray/serve/_private/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..219ccd908e2b052165235f5a116d4c24de129b33 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/serve/_private/test_utils.py @@ -0,0 +1,701 @@ +import asyncio +import datetime +import os +import threading +import time +from contextlib import asynccontextmanager +from copy import copy, deepcopy +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import grpc +import requests +from starlette.requests import Request + +import ray +import ray.util.state as state_api +from ray import serve +from ray.actor import ActorHandle +from ray.serve._private.client import ServeControllerClient +from ray.serve._private.common import ( + CreatePlacementGroupRequest, + DeploymentID, + DeploymentStatus, + RequestProtocol, +) +from ray.serve._private.constants import SERVE_DEFAULT_APP_NAME, SERVE_NAMESPACE +from ray.serve._private.deployment_state import ALL_REPLICA_STATES, ReplicaState +from ray.serve._private.proxy import DRAINING_MESSAGE +from ray.serve._private.usage import ServeUsageTag +from ray.serve._private.utils import TimerBase +from ray.serve.context import _get_global_client +from ray.serve.generated import serve_pb2, serve_pb2_grpc +from ray.serve.schema import ApplicationStatus + +TELEMETRY_ROUTE_PREFIX = "/telemetry" +STORAGE_ACTOR_NAME = "storage" + + +class MockTimer(TimerBase): + def __init__(self, 
start_time: Optional[float] = None): + self._lock = threading.Lock() + self.reset(start_time=start_time) + + def reset(self, start_time: Optional[float] = None): + if start_time is None: + start_time = time.time() + self._curr = start_time + + def time(self) -> float: + return self._curr + + def advance(self, by: float): + with self._lock: + self._curr += by + + def realistic_sleep(self, amt: float): + with self._lock: + self._curr += amt + 0.001 + + +class MockAsyncTimer: + def __init__(self, start_time: Optional[float] = 0): + self.reset(start_time=start_time) + self._num_sleepers = 0 + + def reset(self, start_time: 0): + self._curr = start_time + + def time(self) -> float: + return self._curr + + async def sleep(self, amt: float): + self._num_sleepers += 1 + end = self._curr + amt + + # Give up the event loop + while self._curr < end: + await asyncio.sleep(0) + + self._num_sleepers -= 1 + + def advance(self, amt: float): + self._curr += amt + + def num_sleepers(self): + return self._num_sleepers + + +class MockKVStore: + def __init__(self): + self.store = dict() + + def put(self, key: str, val: Any) -> bool: + if not isinstance(key, str): + raise TypeError("key must be a string, got: {}.".format(type(key))) + self.store[key] = val + return True + + def get(self, key: str) -> Any: + if not isinstance(key, str): + raise TypeError("key must be a string, got: {}.".format(type(key))) + return self.store.get(key, None) + + def delete(self, key: str) -> bool: + if not isinstance(key, str): + raise TypeError("key must be a string, got: {}.".format(type(key))) + + if key in self.store: + del self.store[key] + return True + + return False + + +class MockClusterNodeInfoCache: + def __init__(self): + self.alive_node_ids = set() + self.total_resources_per_node = dict() + self.available_resources_per_node = dict() + self.draining_nodes = dict() + self.node_labels = dict() + + def get_alive_node_ids(self): + return self.alive_node_ids + + def get_draining_nodes(self): + return 
self.draining_nodes + + def get_active_node_ids(self): + return self.alive_node_ids - set(self.draining_nodes) + + def get_node_az(self, node_id): + return None + + def get_available_resources_per_node(self): + return self.available_resources_per_node + + def get_total_resources_per_node(self): + return self.total_resources_per_node + + def add_node(self, node_id: str, resources: Dict = None, labels: Dict = None): + self.alive_node_ids.add(node_id) + self.total_resources_per_node[node_id] = deepcopy(resources) or {} + self.available_resources_per_node[node_id] = deepcopy(resources) or {} + self.node_labels[node_id] = labels or {} + + def set_available_resources_per_node(self, node_id: str, resources: Dict): + self.available_resources_per_node[node_id] = deepcopy(resources) + + +class FakeRemoteFunction: + def remote(self): + pass + + +class MockActorHandle: + def __init__(self, **kwargs): + self._options = kwargs + self._actor_id = "fake_id" + self.initialize_and_get_metadata_called = False + self.is_allocated_called = False + + @property + def initialize_and_get_metadata(self): + self.initialize_and_get_metadata_called = True + # return a mock object so that we can call `remote()` on it. 
+ return FakeRemoteFunction() + + @property + def is_allocated(self): + self.is_allocated_called = True + return FakeRemoteFunction() + + +class MockActorClass: + def __init__(self): + self._init_args = () + self._options = dict() + + def options(self, **kwargs): + res = copy(self) + + for k, v in kwargs.items(): + res._options[k] = v + + return res + + def remote(self, *args) -> MockActorHandle: + return MockActorHandle(init_args=args, **self._options) + + +class MockPlacementGroup: + def __init__(self, request: CreatePlacementGroupRequest): + self._bundles = request.bundles + self._strategy = request.strategy + self._soft_target_node_id = request.target_node_id + self._name = request.name + self._lifetime = "detached" + + +class MockDeploymentHandle: + def __init__(self, deployment_name: str, app_name: str = SERVE_DEFAULT_APP_NAME): + self._deployment_name = deployment_name + self._app_name = app_name + self._protocol = RequestProtocol.UNDEFINED + self._running_replicas_populated = False + self._initialized = False + + def is_initialized(self): + return self._initialized + + def _init(self): + if self._initialized: + raise RuntimeError("already initialized") + + self._initialized = True + + def options(self, *args, **kwargs): + return self + + def __eq__(self, dep: Tuple[str]): + other_deployment_name, other_app_name = dep + return ( + self._deployment_name == other_deployment_name + and self._app_name == other_app_name + ) + + def _set_request_protocol(self, protocol: RequestProtocol): + self._protocol = protocol + + def _get_or_create_router(self): + pass + + def running_replicas_populated(self) -> bool: + return self._running_replicas_populated + + def set_running_replicas_populated(self, val: bool): + self._running_replicas_populated = val + + +@serve.deployment +class GetPID: + def __call__(self): + return os.getpid() + + +get_pid_entrypoint = GetPID.bind() + + +def check_ray_stopped(): + try: + requests.get("http://localhost:52365/api/ray/version") + return 
False + except Exception: + return True + + +def check_ray_started(): + return requests.get("http://localhost:52365/api/ray/version").status_code == 200 + + +def check_deployment_status( + name: str, expected_status: DeploymentStatus, app_name=SERVE_DEFAULT_APP_NAME +) -> bool: + app_status = serve.status().applications[app_name] + assert app_status.deployments[name].status == expected_status + return True + + +def get_num_alive_replicas( + deployment_name: str, app_name: str = SERVE_DEFAULT_APP_NAME +) -> int: + """Get the replicas currently running for the given deployment.""" + + dep_id = DeploymentID(name=deployment_name, app_name=app_name) + actors = state_api.list_actors( + filters=[ + ("class_name", "=", dep_id.to_replica_actor_class_name()), + ("state", "=", "ALIVE"), + ] + ) + return len(actors) + + +def check_num_replicas_gte( + name: str, target: int, app_name: str = SERVE_DEFAULT_APP_NAME +) -> int: + """Check if num replicas is >= target.""" + + assert get_num_alive_replicas(name, app_name) >= target + return True + + +def check_num_replicas_eq( + name: str, target: int, app_name: str = SERVE_DEFAULT_APP_NAME +) -> int: + """Check if num replicas is == target.""" + + assert get_num_alive_replicas(name, app_name) == target + return True + + +def check_num_replicas_lte( + name: str, target: int, app_name: str = SERVE_DEFAULT_APP_NAME +) -> int: + """Check if num replicas is <= target.""" + + assert get_num_alive_replicas(name, app_name) <= target + return True + + +def check_apps_running(apps: List): + status = serve.status() + + for app_name in apps: + assert status.applications[app_name].status == ApplicationStatus.RUNNING + + return True + + +def check_replica_counts( + controller: ActorHandle, + deployment_id: DeploymentID, + total: Optional[int] = None, + by_state: Optional[List[Tuple[ReplicaState, int, Callable]]] = None, +): + """Uses _dump_replica_states_for_testing to check replica counts. 
+ + Args: + controller: A handle to the Serve controller. + deployment_id: The deployment to check replica counts for. + total: The total number of expected replicas for the deployment. + by_state: A list of tuples of the form + (replica state, number of replicas, filter function). + Used for more fine grained checks. + """ + replicas = ray.get( + controller._dump_replica_states_for_testing.remote(deployment_id) + ) + + if total is not None: + replica_counts = { + state: len(replicas.get([state])) + for state in ALL_REPLICA_STATES + if replicas.get([state]) + } + assert replicas.count() == total, replica_counts + + if by_state is not None: + for state, count, check in by_state: + assert isinstance(state, ReplicaState) + assert isinstance(count, int) and count >= 0 + if check: + filtered = {r for r in replicas.get(states=[state]) if check(r)} + curr_count = len(filtered) + else: + curr_count = replicas.count(states=[state]) + msg = f"Expected {count} for state {state} but got {curr_count}." + assert curr_count == count, msg + + return True + + +@ray.remote(name=STORAGE_ACTOR_NAME, namespace=SERVE_NAMESPACE, num_cpus=0) +class TelemetryStorage: + def __init__(self): + self.reports_received = 0 + self.current_report = dict() + + def store_report(self, report: Dict) -> None: + self.reports_received += 1 + self.current_report = report + + def get_report(self) -> Dict: + return self.current_report + + def get_reports_received(self) -> int: + return self.reports_received + + +@serve.deployment(ray_actor_options={"num_cpus": 0}) +class TelemetryReceiver: + def __init__(self): + self.storage = ray.get_actor(name=STORAGE_ACTOR_NAME, namespace=SERVE_NAMESPACE) + + async def __call__(self, request: Request) -> bool: + report = await request.json() + ray.get(self.storage.store_report.remote(report)) + return True + + +receiver_app = TelemetryReceiver.bind() + + +def start_telemetry_app(): + """Start a telemetry Serve app. 
+ + Ray should be initialized before calling this method. + + NOTE: If you're running the TelemetryReceiver Serve app to check telemetry, + remember that the receiver itself is counted in the telemetry. E.g. if you + deploy a Serve app other than the receiver, the number of apps in the + cluster is 2- not 1– since the receiver is also running. + + Returns a handle to a TelemetryStorage actor. You can use this actor + to access the latest telemetry reports. + """ + + storage = TelemetryStorage.remote() + serve.run(receiver_app, name="telemetry", route_prefix=TELEMETRY_ROUTE_PREFIX) + return storage + + +def check_telemetry( + tag: ServeUsageTag, expected: Any, storage_actor_name: str = STORAGE_ACTOR_NAME +): + storage_handle = ray.get_actor(storage_actor_name, namespace=SERVE_NAMESPACE) + report = ray.get(storage_handle.get_report.remote()) + print(report["extra_usage_tags"]) + assert tag.get_value_from_report(report) == expected + return True + + +def ping_grpc_list_applications(channel, app_names, test_draining=False): + import pytest + + stub = serve_pb2_grpc.RayServeAPIServiceStub(channel) + request = serve_pb2.ListApplicationsRequest() + if test_draining: + with pytest.raises(grpc.RpcError) as exception_info: + _, _ = stub.ListApplications.with_call(request=request) + rpc_error = exception_info.value + assert rpc_error.code() == grpc.StatusCode.UNAVAILABLE + assert rpc_error.details() == DRAINING_MESSAGE + else: + response, call = stub.ListApplications.with_call(request=request) + assert call.code() == grpc.StatusCode.OK + assert response.application_names == app_names + return True + + +def ping_grpc_healthz(channel, test_draining=False): + import pytest + + stub = serve_pb2_grpc.RayServeAPIServiceStub(channel) + request = serve_pb2.HealthzRequest() + if test_draining: + with pytest.raises(grpc.RpcError) as exception_info: + _, _ = stub.Healthz.with_call(request=request) + rpc_error = exception_info.value + assert rpc_error.code() == 
grpc.StatusCode.UNAVAILABLE + assert rpc_error.details() == DRAINING_MESSAGE + else: + response, call = stub.Healthz.with_call(request=request) + assert call.code() == grpc.StatusCode.OK + assert response.message == "success" + + +def ping_grpc_call_method(channel, app_name, test_not_found=False): + import pytest + + stub = serve_pb2_grpc.UserDefinedServiceStub(channel) + request = serve_pb2.UserDefinedMessage(name="foo", num=30, foo="bar") + metadata = (("application", app_name),) + if test_not_found: + with pytest.raises(grpc.RpcError) as exception_info: + _, _ = stub.__call__.with_call(request=request, metadata=metadata) + rpc_error = exception_info.value + assert rpc_error.code() == grpc.StatusCode.NOT_FOUND, rpc_error.code() + assert f"Application '{app_name}' not found." in rpc_error.details() + else: + response, call = stub.__call__.with_call(request=request, metadata=metadata) + assert call.code() == grpc.StatusCode.OK, call.code() + assert response.greeting == "Hello foo from bar", response.greeting + + +def ping_grpc_another_method(channel, app_name): + stub = serve_pb2_grpc.UserDefinedServiceStub(channel) + request = serve_pb2.UserDefinedMessage(name="foo", num=30, foo="bar") + metadata = (("application", app_name),) + response = stub.Method1(request=request, metadata=metadata) + assert response.greeting == "Hello foo from method1" + + +def ping_grpc_model_multiplexing(channel, app_name): + stub = serve_pb2_grpc.UserDefinedServiceStub(channel) + request = serve_pb2.UserDefinedMessage(name="foo", num=30, foo="bar") + multiplexed_model_id = "999" + metadata = ( + ("application", app_name), + ("multiplexed_model_id", multiplexed_model_id), + ) + response = stub.Method2(request=request, metadata=metadata) + assert ( + response.greeting + == f"Method2 called model, loading model: {multiplexed_model_id}" + ) + + +def ping_grpc_streaming(channel, app_name): + stub = serve_pb2_grpc.UserDefinedServiceStub(channel) + request = 
serve_pb2.UserDefinedMessage(name="foo", num=30, foo="bar") + metadata = (("application", app_name),) + responses = stub.Streaming(request=request, metadata=metadata) + for idx, response in enumerate(responses): + assert response.greeting == f"{idx}: Hello foo from bar" + + +def ping_fruit_stand(channel, app_name): + stub = serve_pb2_grpc.FruitServiceStub(channel) + request = serve_pb2.FruitAmounts(orange=4, apple=8) + metadata = (("application", app_name),) + response = stub.FruitStand(request=request, metadata=metadata) + assert response.costs == 32 + + +@asynccontextmanager +async def send_signal_on_cancellation(signal_actor: ActorHandle): + cancelled = False + try: + yield + await asyncio.sleep(100) + except asyncio.CancelledError: + cancelled = True + # Clear the context var to avoid Ray recursively cancelling this method call. + ray._raylet.async_task_id.set(None) + await signal_actor.send.remote() + + if not cancelled: + raise RuntimeError( + "CancelledError wasn't raised during `send_signal_on_cancellation` block" + ) + + +class FakeGrpcContext: + def __init__(self): + self._auth_context = {"key": "value"} + self._invocation_metadata = [("key", "value")] + self._peer = "peer" + self._peer_identities = b"peer_identities" + self._peer_identity_key = "peer_identity_key" + self._code = None + self._details = None + self._trailing_metadata = [] + self._invocation_metadata = [] + + def auth_context(self): + return self._auth_context + + def code(self): + return self._code + + def details(self): + return self._details + + def peer(self): + return self._peer + + def peer_identities(self): + return self._peer_identities + + def peer_identity_key(self): + return self._peer_identity_key + + def trailing_metadata(self): + return self._trailing_metadata + + def set_code(self, code): + self._code = code + + def set_details(self, details): + self._details = details + + def set_trailing_metadata(self, trailing_metadata): + self._trailing_metadata = trailing_metadata + + 
def invocation_metadata(self): + return self._invocation_metadata + + +class FakeGauge: + def __init__(self, name: str = None, tag_keys: Tuple[str] = None): + self.name = name + self.values = dict() + + self.tags = tag_keys or () + self.default_tags = dict() + + def set_default_tags(self, tags: Dict[str, str]): + for key, tag in tags.items(): + assert key in self.tags + self.default_tags[key] = tag + + def set(self, value: Union[int, float], tags: Dict[str, str] = None): + merged_tags = self.default_tags.copy() + merged_tags.update(tags or {}) + assert set(merged_tags.keys()) == set(self.tags) + + d = self.values + for tag in self.tags[:-1]: + tag_value = merged_tags[tag] + if tag_value not in d: + d[tag_value] = dict() + d = d[tag_value] + + d[merged_tags[self.tags[-1]]] = value + + def get_value(self, tags: Dict[str, str]): + value = self.values + for tag in self.tags: + tag_value = tags[tag] + value = value.get(tag_value) + if value is None: + return + + return value + + +class FakeCounter: + def __init__(self, name: str = None, tag_keys: Tuple[str] = None): + self.name = name + self.counts = dict() + + self.tags = tag_keys or () + self.default_tags = dict() + + def set_default_tags(self, tags: Dict[str, str]): + for key, tag in tags.items(): + assert key in self.tags + self.default_tags[key] = tag + + def inc(self, value: Union[int, float] = 1.0, tags: Dict[str, str] = None): + merged_tags = self.default_tags.copy() + merged_tags.update(tags or {}) + assert set(merged_tags.keys()) == set(self.tags) + + d = self.counts + for tag in self.tags[:-1]: + tag_value = merged_tags[tag] + if tag_value not in d: + d[tag_value] = dict() + d = d[tag_value] + + key = merged_tags[self.tags[-1]] + d[key] = d.get(key, 0) + value + + def get_count(self, tags: Dict[str, str]) -> int: + value = self.counts + for tag in self.tags: + tag_value = tags[tag] + value = value.get(tag_value) + if value is None: + return + + return value + + def get_tags(self): + return self.tags + + 
+@ray.remote +def get_node_id(): + return ray.get_runtime_context().get_node_id() + + +def check_num_alive_nodes(target: int): + alive_nodes = [node for node in ray.nodes() if node["Alive"]] + assert len(alive_nodes) == target + return True + + +def get_deployment_details( + deployment_name: str, + app_name: str = SERVE_DEFAULT_APP_NAME, + _client: ServeControllerClient = None, +): + client = _client or _get_global_client() + details = client.get_serve_details() + return details["applications"][app_name]["deployments"][deployment_name] + + +@ray.remote +class Counter: + def __init__(self, target: int): + self.count = 0 + self.target = target + self.ready_event = asyncio.Event() + + def inc(self): + self.count += 1 + if self.count == self.target: + self.ready_event.set() + + async def wait(self): + await self.ready_event.wait() + + +def tlog(s: str, level: str = "INFO"): + """Convenient logging method for testing.""" + + now = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3] + print(f"[{level}] {now} {s}") diff --git a/.venv/lib/python3.11/site-packages/ray/serve/_private/version.py b/.venv/lib/python3.11/site-packages/ray/serve/_private/version.py new file mode 100644 index 0000000000000000000000000000000000000000..08b53acba0f0cc127b6fffa9379704c3558c1afe --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/serve/_private/version.py @@ -0,0 +1,200 @@ +import json +import logging +from copy import deepcopy +from typing import Any, Dict, List, Optional +from zlib import crc32 + +from ray._private.pydantic_compat import BaseModel +from ray.serve._private.config import DeploymentConfig +from ray.serve._private.utils import DeploymentOptionUpdateType, get_random_string +from ray.serve.config import AutoscalingConfig +from ray.serve.generated.serve_pb2 import DeploymentVersion as DeploymentVersionProto + +logger = logging.getLogger("ray.serve") + + +class DeploymentVersion: + def __init__( + self, + code_version: Optional[str], + deployment_config: 
DeploymentConfig, + ray_actor_options: Optional[Dict], + placement_group_bundles: Optional[List[Dict[str, float]]] = None, + placement_group_strategy: Optional[str] = None, + max_replicas_per_node: Optional[int] = None, + ): + if code_version is not None and not isinstance(code_version, str): + raise TypeError(f"code_version must be str, got {type(code_version)}.") + if code_version is None: + self.code_version = get_random_string() + else: + self.code_version = code_version + + # Options for this field may be mutated over time, so any logic that uses this + # should access this field directly. + self.deployment_config = deployment_config + self.ray_actor_options = ray_actor_options + self.placement_group_bundles = placement_group_bundles + self.placement_group_strategy = placement_group_strategy + self.max_replicas_per_node = max_replicas_per_node + self.compute_hashes() + + @classmethod + def from_deployment_version(cls, deployment_version, deployment_config): + version_copy = deepcopy(deployment_version) + version_copy.deployment_config = deployment_config + version_copy.compute_hashes() + return version_copy + + def __hash__(self) -> int: + return self._hash + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, DeploymentVersion): + return False + return self._hash == other._hash + + def requires_actor_restart(self, new_version): + """Determines whether the new version requires actors of the current version to + be restarted. + """ + return ( + self.code_version != new_version.code_version + or self.ray_actor_options_hash != new_version.ray_actor_options_hash + or self.placement_group_options_hash + != new_version.placement_group_options_hash + or self.max_replicas_per_node != new_version.max_replicas_per_node + ) + + def requires_actor_reconfigure(self, new_version): + """Determines whether the new version requires calling reconfigure() on the + replica actor. 
+ """ + return self.reconfigure_actor_hash != new_version.reconfigure_actor_hash + + def requires_long_poll_broadcast(self, new_version): + """Determines whether lightweightly updating an existing replica to the new + version requires broadcasting through long poll that the running replicas has + changed. + """ + return ( + self.deployment_config.max_ongoing_requests + != new_version.deployment_config.max_ongoing_requests + ) + + def compute_hashes(self): + # If these change, the controller will rolling upgrade existing replicas. + serialized_ray_actor_options = _serialize(self.ray_actor_options or {}) + self.ray_actor_options_hash = crc32(serialized_ray_actor_options) + combined_placement_group_options = {} + if self.placement_group_bundles is not None: + combined_placement_group_options["bundles"] = self.placement_group_bundles + if self.placement_group_strategy is not None: + combined_placement_group_options["strategy"] = self.placement_group_strategy + serialized_placement_group_options = _serialize( + combined_placement_group_options + ) + self.placement_group_options_hash = crc32(serialized_placement_group_options) + + # If this changes, DeploymentReplica.reconfigure() will call reconfigure on the + # actual replica actor + self.reconfigure_actor_hash = crc32( + self._get_serialized_options( + [DeploymentOptionUpdateType.NeedsActorReconfigure] + ) + ) + + # Used by __eq__ in deployment state to either reconfigure the replicas or + # stop and restart them + self._hash = crc32( + self.code_version.encode("utf-8") + + serialized_ray_actor_options + + serialized_placement_group_options + + str(self.max_replicas_per_node).encode("utf-8") + + self._get_serialized_options( + [ + DeploymentOptionUpdateType.NeedsReconfigure, + DeploymentOptionUpdateType.NeedsActorReconfigure, + ] + ) + ) + + def to_proto(self) -> bytes: + # TODO(simon): enable cross language user config + return DeploymentVersionProto( + code_version=self.code_version, + 
deployment_config=self.deployment_config.to_proto(), + ray_actor_options=json.dumps(self.ray_actor_options), + placement_group_bundles=json.dumps(self.placement_group_bundles) + if self.placement_group_bundles is not None + else "", + placement_group_strategy=self.placement_group_strategy + if self.placement_group_strategy is not None + else "", + max_replicas_per_node=self.max_replicas_per_node + if self.max_replicas_per_node is not None + else 0, + ) + + @classmethod + def from_proto(cls, proto: DeploymentVersionProto): + return DeploymentVersion( + proto.code_version, + DeploymentConfig.from_proto(proto.deployment_config), + json.loads(proto.ray_actor_options), + placement_group_bundles=( + json.loads(proto.placement_group_bundles) + if proto.placement_group_bundles + else None + ), + placement_group_version=( + proto.placement_group_version if proto.placement_group_version else None + ), + max_replicas_per_node=( + proto.max_replicas_per_node if proto.max_replicas_per_node else None + ), + ) + + def _get_serialized_options( + self, update_types: List[DeploymentOptionUpdateType] + ) -> bytes: + """Returns a serialized dictionary containing fields of a deployment config that + should prompt a deployment version update. + """ + reconfigure_dict = {} + # TODO(aguo): Once we only support pydantic 2, we can remove this if check. + # In pydantic 2.0, `__fields__` has been renamed to `model_fields`. 
+ fields = ( + self.deployment_config.model_fields + if hasattr(self.deployment_config, "model_fields") + else self.deployment_config.__fields__ + ) + for option_name, field in fields.items(): + option_weight = field.field_info.extra.get("update_type") + if option_weight in update_types: + reconfigure_dict[option_name] = getattr( + self.deployment_config, option_name + ) + # If autoscaling config was changed, only broadcast to + # replicas if metrics_interval_s or look_back_period_s + # was changed, because the rest of the fields are only + # used in deployment state manager + if isinstance(reconfigure_dict[option_name], AutoscalingConfig): + reconfigure_dict[option_name] = reconfigure_dict[option_name].dict( + include={"metrics_interval_s", "look_back_period_s"} + ) + elif isinstance(reconfigure_dict[option_name], BaseModel): + reconfigure_dict[option_name] = reconfigure_dict[option_name].dict() + + if ( + isinstance(self.deployment_config.user_config, bytes) + and "user_config" in reconfigure_dict + ): + del reconfigure_dict["user_config"] + return self.deployment_config.user_config + _serialize(reconfigure_dict) + + return _serialize(reconfigure_dict) + + +def _serialize(json_object): + return str.encode(json.dumps(json_object, sort_keys=True))