diff --git a/.gitattributes b/.gitattributes index 2180227b7a0e6b7d31765bb94941cf79afa1d08b..19e8357d01e7258ab53b5ddbd60ae495f1031d94 100644 --- a/.gitattributes +++ b/.gitattributes @@ -157,3 +157,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_ .venv/lib/python3.11/site-packages/numpy/ma/tests/__pycache__/test_extras.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/ray/serve/_private/__pycache__/deployment_state.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/xgrammar/xgrammar_bindings.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/ray/_raylet.so filter=lfs diff=lfs merge=lfs -text diff --git a/.venv/lib/python3.11/site-packages/ray/_raylet.so b/.venv/lib/python3.11/site-packages/ray/_raylet.so new file mode 100644 index 0000000000000000000000000000000000000000..0c81c8b23a01d1a7f035f41beb0fc6be9175a9e0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_raylet.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e5516ec5efa37efb034ca0fa6c8403331a430101356bdf9829cded6351510037 +size 35971224 diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/agent_manager_pb2_grpc.py b/.venv/lib/python3.11/site-packages/ray/core/generated/agent_manager_pb2_grpc.py new file mode 100644 index 0000000000000000000000000000000000000000..2daafffebfc817aefe8fcb96eaec25e65b3903e8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/core/generated/agent_manager_pb2_grpc.py @@ -0,0 +1,4 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
+"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/core_worker_pb2.py b/.venv/lib/python3.11/site-packages/ray/core/generated/core_worker_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..ae156278509e978685c62359f8c05e004ac0bd70 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/core/generated/core_worker_pb2.py @@ -0,0 +1,542 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: src/ray/protobuf/core_worker.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from . import common_pb2 as src_dot_ray_dot_protobuf_dot_common__pb2 +from . 
import pubsub_pb2 as src_dot_ray_dot_protobuf_dot_pubsub__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\"src/ray/protobuf/core_worker.proto\x12\x07ray.rpc\x1a\x1dsrc/ray/protobuf/common.proto\x1a\x1dsrc/ray/protobuf/pubsub.proto\"0\n\x0f\x41\x63tiveObjectIDs\x12\x1d\n\nobject_ids\x18\x01 \x03(\x0cR\tobjectIds\"\xfc\x05\n\x0b\x41\x63torHandle\x12\x19\n\x08\x61\x63tor_id\x18\x01 \x01(\x0cR\x07\x61\x63torId\x12\x19\n\x08owner_id\x18\x02 \x01(\x0cR\x07ownerId\x12\x35\n\rowner_address\x18\x03 \x01(\x0b\x32\x10.ray.rpc.AddressR\x0cownerAddress\x12&\n\x0f\x63reation_job_id\x18\x04 \x01(\x0cR\rcreationJobId\x12\x38\n\x0e\x61\x63tor_language\x18\x05 \x01(\x0e\x32\x11.ray.rpc.LanguageR\ractorLanguage\x12q\n\'actor_creation_task_function_descriptor\x18\x06 \x01(\x0b\x32\x1b.ray.rpc.FunctionDescriptorR#actorCreationTaskFunctionDescriptor\x12!\n\x0c\x61\x63tor_cursor\x18\x07 \x01(\x0cR\x0b\x61\x63torCursor\x12%\n\x0e\x65xtension_data\x18\x08 \x01(\x0cR\rextensionData\x12(\n\x10max_task_retries\x18\t \x01(\x03R\x0emaxTaskRetries\x12\x12\n\x04name\x18\n \x01(\tR\x04name\x12#\n\rray_namespace\x18\x0b \x01(\tR\x0crayNamespace\x12/\n\x14\x65xecute_out_of_order\x18\x0c \x01(\x08R\x11\x65xecuteOutOfOrder\x12*\n\x11max_pending_calls\x18\r \x01(\x05R\x0fmaxPendingCalls\x12,\n\x12\x65nable_task_events\x18\x0e \x01(\x08R\x10\x65nableTaskEvents\x12\x38\n\x06labels\x18\x0f \x03(\x0b\x32 .ray.rpc.ActorHandle.LabelsEntryR\x06labels\x1a\x39\n\x0bLabelsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\x93\x02\n\x0fPushTaskRequest\x12,\n\x12intended_worker_id\x18\x01 \x01(\x0cR\x10intendedWorkerId\x12.\n\ttask_spec\x18\x02 \x01(\x0b\x32\x11.ray.rpc.TaskSpecR\x08taskSpec\x12\'\n\x0fsequence_number\x18\x03 \x01(\x03R\x0esequenceNumber\x12\x33\n\x16\x63lient_processed_up_to\x18\x04 \x01(\x03R\x13\x63lientProcessedUpTo\x12\x44\n\x10resource_mapping\x18\x05 
\x03(\x0b\x32\x19.ray.rpc.ResourceMapEntryR\x0fresourceMapping\"\x87\x05\n\rPushTaskReply\x12<\n\x0ereturn_objects\x18\x01 \x03(\x0b\x32\x15.ray.rpc.ReturnObjectR\rreturnObjects\x12K\n\x16\x64ynamic_return_objects\x18\x02 \x03(\x0b\x32\x15.ray.rpc.ReturnObjectR\x14\x64ynamicReturnObjects\x12%\n\x0eworker_exiting\x18\x03 \x01(\x08R\rworkerExiting\x12\x42\n\rborrowed_refs\x18\x04 \x03(\x0b\x32\x1d.ray.rpc.ObjectReferenceCountR\x0c\x62orrowedRefs\x12,\n\x12is_retryable_error\x18\x05 \x01(\x08R\x10isRetryableError\x12\x30\n\x14is_application_error\x18\x06 \x01(\x08R\x12isApplicationError\x12?\n\x1cwas_cancelled_before_running\x18\x07 \x01(\x08R\x19wasCancelledBeforeRunning\x12+\n\x0f\x61\x63tor_repr_name\x18\x08 \x01(\tH\x00R\ractorReprName\x88\x01\x01\x12\x30\n\x14task_execution_error\x18\t \x01(\tR\x12taskExecutionError\x12l\n\x1estreaming_generator_return_ids\x18\n \x03(\x0b\x32\'.ray.rpc.StreamingGeneratorReturnIdInfoR\x1bstreamingGeneratorReturnIdsB\x12\n\x10_actor_repr_name\"g\n%DirectActorCallArgWaitCompleteRequest\x12,\n\x12intended_worker_id\x18\x01 \x01(\x0cR\x10intendedWorkerId\x12\x10\n\x03tag\x18\x02 \x01(\x03R\x03tag\"%\n#DirectActorCallArgWaitCompleteReply\"]\n\x16GetObjectStatusRequest\x12&\n\x0fowner_worker_id\x18\x01 \x01(\x0cR\rownerWorkerId\x12\x1b\n\tobject_id\x18\x02 \x01(\x0cR\x08objectId\"\x85\x01\n\tRayObject\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\x12\x1a\n\x08metadata\x18\x02 \x01(\x0cR\x08metadata\x12H\n\x13nested_inlined_refs\x18\x03 \x03(\x0b\x32\x18.ray.rpc.ObjectReferenceR\x11nestedInlinedRefs\"\xfc\x01\n\x14GetObjectStatusReply\x12\x42\n\x06status\x18\x01 \x01(\x0e\x32*.ray.rpc.GetObjectStatusReply.ObjectStatusR\x06status\x12*\n\x06object\x18\x02 \x01(\x0b\x32\x12.ray.rpc.RayObjectR\x06object\x12\x19\n\x08node_ids\x18\x03 \x03(\x0cR\x07nodeIds\x12\x1f\n\x0bobject_size\x18\x04 
\x01(\x04R\nobjectSize\"8\n\x0cObjectStatus\x12\x0b\n\x07\x43REATED\x10\x00\x12\x10\n\x0cOUT_OF_SCOPE\x10\x01\x12\t\n\x05\x46REED\x10\x02\"h\n\x1dWaitForActorRefDeletedRequest\x12,\n\x12intended_worker_id\x18\x01 \x01(\x0cR\x10intendedWorkerId\x12\x19\n\x08\x61\x63tor_id\x18\x02 \x01(\x0cR\x07\x61\x63torId\"\x1d\n\x1bWaitForActorRefDeletedReply\"\xc0\x01\n UpdateObjectLocationBatchRequest\x12,\n\x12intended_worker_id\x18\x01 \x01(\x0cR\x10intendedWorkerId\x12\x17\n\x07node_id\x18\x02 \x01(\x0cR\x06nodeId\x12U\n\x17object_location_updates\x18\x03 \x03(\x0b\x32\x1d.ray.rpc.ObjectLocationUpdateR\x15objectLocationUpdates\" \n\x1eUpdateObjectLocationBatchReply\"w\n\x1bObjectSpilledLocationUpdate\x12\x1f\n\x0bspilled_url\x18\x03 \x01(\tR\nspilledUrl\x12\x37\n\x18spilled_to_local_storage\x18\x04 \x01(\x08R\x15spilledToLocalStorage\"\xe6\x02\n\x14ObjectLocationUpdate\x12\x1b\n\tobject_id\x18\x01 \x01(\x0cR\x08objectId\x12^\n\x16plasma_location_update\x18\x02 \x01(\x0e\x32#.ray.rpc.ObjectPlasmaLocationUpdateH\x00R\x14plasmaLocationUpdate\x88\x01\x01\x12\x61\n\x17spilled_location_update\x18\x03 \x01(\x0b\x32$.ray.rpc.ObjectSpilledLocationUpdateH\x01R\x15spilledLocationUpdate\x88\x01\x01\x12&\n\x0cgenerator_id\x18\x04 \x01(\x0cH\x02R\x0bgeneratorId\x88\x01\x01\x42\x19\n\x17_plasma_location_updateB\x1a\n\x18_spilled_location_updateB\x0f\n\r_generator_id\"m\n\x1eGetObjectLocationsOwnerRequest\x12,\n\x12intended_worker_id\x18\x01 \x01(\x0cR\x10intendedWorkerId\x12\x1d\n\nobject_ids\x18\x02 \x03(\x0cR\tobjectIds\"|\n\x1cGetObjectLocationsOwnerReply\x12\\\n\x15object_location_infos\x18\x01 \x03(\x0b\x32(.ray.rpc.WorkerObjectLocationsPubMessageR\x13objectLocationInfos\"\x98\x01\n\x10KillActorRequest\x12*\n\x11intended_actor_id\x18\x01 \x01(\x0cR\x0fintendedActorId\x12\x1d\n\nforce_kill\x18\x02 \x01(\x08R\tforceKill\x12\x39\n\x0b\x64\x65\x61th_cause\x18\x03 
\x01(\x0b\x32\x18.ray.rpc.ActorDeathCauseR\ndeathCause\"\x10\n\x0eKillActorReply\"\xa4\x01\n\x11\x43\x61ncelTaskRequest\x12(\n\x10intended_task_id\x18\x01 \x01(\x0cR\x0eintendedTaskId\x12\x1d\n\nforce_kill\x18\x02 \x01(\x08R\tforceKill\x12\x1c\n\trecursive\x18\x03 \x01(\x08R\trecursive\x12(\n\x10\x63\x61ller_worker_id\x18\x04 \x01(\x0cR\x0e\x63\x61llerWorkerId\"t\n\x0f\x43\x61ncelTaskReply\x12\x34\n\x16requested_task_running\x18\x01 \x01(\x08R\x14requestedTaskRunning\x12+\n\x11\x61ttempt_succeeded\x18\x02 \x01(\x08R\x10\x61ttemptSucceeded\"\x80\x01\n\x17RemoteCancelTaskRequest\x12(\n\x10remote_object_id\x18\x01 \x01(\x0cR\x0eremoteObjectId\x12\x1d\n\nforce_kill\x18\x02 \x01(\x08R\tforceKill\x12\x1c\n\trecursive\x18\x03 \x01(\x08R\trecursive\"\x17\n\x15RemoteCancelTaskReply\"\xca\x01\n\x19GetCoreWorkerStatsRequest\x12,\n\x12intended_worker_id\x18\x01 \x01(\x0cR\x10intendedWorkerId\x12.\n\x13include_memory_info\x18\x02 \x01(\x08R\x11includeMemoryInfo\x12*\n\x11include_task_info\x18\x03 \x01(\x08R\x0fincludeTaskInfo\x12\x19\n\x05limit\x18\x04 \x01(\x03H\x00R\x05limit\x88\x01\x01\x42\x08\n\x06_limit\"\xf9\x01\n\x17GetCoreWorkerStatsReply\x12\x44\n\x11\x63ore_worker_stats\x18\x01 \x01(\x0b\x32\x18.ray.rpc.CoreWorkerStatsR\x0f\x63oreWorkerStats\x12M\n\x17owned_task_info_entries\x18\x02 \x03(\x0b\x32\x16.ray.rpc.TaskInfoEntryR\x14ownedTaskInfoEntries\x12(\n\x10running_task_ids\x18\x03 \x03(\x0cR\x0erunningTaskIds\x12\x1f\n\x0btasks_total\x18\x04 \x01(\x03R\ntasksTotal\"E\n\x0eLocalGCRequest\x12\x33\n\x16triggered_by_global_gc\x18\x01 \x01(\x08R\x13triggeredByGlobalGc\"\x0e\n\x0cLocalGCReply\"7\n\x18PlasmaObjectReadyRequest\x12\x1b\n\tobject_id\x18\x01 \x01(\x0cR\x08objectId\"\x18\n\x16PlasmaObjectReadyReply\"T\n\x14\x44\x65leteObjectsRequest\x12\x1d\n\nobject_ids\x18\x01 \x03(\x0cR\tobjectIds\x12\x1d\n\nlocal_only\x18\x02 \x01(\x08R\tlocalOnly\"\x14\n\x12\x44\x65leteObjectsReply\"\xa6\x01\n\x13SpillObjectsRequest\x12I\n\x14object_refs_to_spill\x18\x01 
\x03(\x0b\x32\x18.ray.rpc.ObjectReferenceR\x11objectRefsToSpill\x12\x44\n\x0e\x64\x65lete_request\x18\x02 \x01(\x0b\x32\x1d.ray.rpc.DeleteObjectsRequestR\rdeleteRequest\"C\n\x11SpillObjectsReply\x12.\n\x13spilled_objects_url\x18\x01 \x03(\tR\x11spilledObjectsUrl\"\x81\x01\n\x1cRestoreSpilledObjectsRequest\x12.\n\x13spilled_objects_url\x18\x01 \x03(\tR\x11spilledObjectsUrl\x12\x31\n\x15object_ids_to_restore\x18\x02 \x03(\x0cR\x12objectIdsToRestore\"N\n\x1aRestoreSpilledObjectsReply\x12\x30\n\x14\x62ytes_restored_total\x18\x01 \x01(\x03R\x12\x62ytesRestoredTotal\"M\n\x1b\x44\x65leteSpilledObjectsRequest\x12.\n\x13spilled_objects_url\x18\x01 \x03(\tR\x11spilledObjectsUrl\"\x1b\n\x19\x44\x65leteSpilledObjectsReply\",\n\x0b\x45xitRequest\x12\x1d\n\nforce_exit\x18\x01 \x01(\x08R\tforceExit\"%\n\tExitReply\x12\x18\n\x07success\x18\x01 \x01(\x08R\x07success\"\xe4\x01\n\x18\x41ssignObjectOwnerRequest\x12\x1b\n\tobject_id\x18\x01 \x01(\x0cR\x08objectId\x12\x1f\n\x0bobject_size\x18\x02 \x01(\x04R\nobjectSize\x12\x30\n\x14\x63ontained_object_ids\x18\x03 \x03(\x0cR\x12\x63ontainedObjectIds\x12;\n\x10\x62orrower_address\x18\x04 \x01(\x0b\x32\x10.ray.rpc.AddressR\x0f\x62orrowerAddress\x12\x1b\n\tcall_site\x18\x05 \x01(\tR\x08\x63\x61llSite\"\x18\n\x16\x41ssignObjectOwnerReply\"\x1f\n\x1dRayletNotifyGCSRestartRequest\"\x1d\n\x1bRayletNotifyGCSRestartReply\"\x18\n\x16NumPendingTasksRequest\"B\n\x14NumPendingTasksReply\x12*\n\x11num_pending_tasks\x18\x01 \x01(\x03R\x0fnumPendingTasks\"\x8c\x02\n!ReportGeneratorItemReturnsRequest\x12K\n\x16\x64ynamic_return_objects\x18\x01 \x03(\x0b\x32\x15.ray.rpc.ReturnObjectR\x14\x64ynamicReturnObjects\x12\x31\n\x0bworker_addr\x18\x02 \x01(\x0b\x32\x10.ray.rpc.AddressR\nworkerAddr\x12\x1d\n\nitem_index\x18\x03 \x01(\x03R\titemIndex\x12!\n\x0cgenerator_id\x18\x05 \x01(\x0cR\x0bgeneratorId\x12%\n\x0e\x61ttempt_number\x18\x06 \x01(\x04R\rattemptNumber\"\\\n\x1fReportGeneratorItemReturnsReply\x12\x39\n\x19total_num_object_consumed\x18\x01 
\x01(\x03R\x16totalNumObjectConsumed\"\x99\x01\n\"RegisterMutableObjectReaderRequest\x12(\n\x10writer_object_id\x18\x01 \x01(\x0cR\x0ewriterObjectId\x12\x1f\n\x0bnum_readers\x18\x02 \x01(\x03R\nnumReaders\x12(\n\x10reader_object_id\x18\x03 \x01(\x0cR\x0ereaderObjectId\"\"\n RegisterMutableObjectReaderReply*4\n\x1aObjectPlasmaLocationUpdate\x12\t\n\x05\x41\x44\x44\x45\x44\x10\x00\x12\x0b\n\x07REMOVED\x10\x01\x32\xf7\x10\n\x11\x43oreWorkerService\x12\x66\n\x16RayletNotifyGCSRestart\x12&.ray.rpc.RayletNotifyGCSRestartRequest\x1a$.ray.rpc.RayletNotifyGCSRestartReply\x12<\n\x08PushTask\x12\x18.ray.rpc.PushTaskRequest\x1a\x16.ray.rpc.PushTaskReply\x12~\n\x1e\x44irectActorCallArgWaitComplete\x12..ray.rpc.DirectActorCallArgWaitCompleteRequest\x1a,.ray.rpc.DirectActorCallArgWaitCompleteReply\x12Q\n\x0fGetObjectStatus\x12\x1f.ray.rpc.GetObjectStatusRequest\x1a\x1d.ray.rpc.GetObjectStatusReply\x12\x66\n\x16WaitForActorRefDeleted\x12&.ray.rpc.WaitForActorRefDeletedRequest\x1a$.ray.rpc.WaitForActorRefDeletedReply\x12W\n\x11PubsubLongPolling\x12!.ray.rpc.PubsubLongPollingRequest\x1a\x1f.ray.rpc.PubsubLongPollingReply\x12r\n\x1aReportGeneratorItemReturns\x12*.ray.rpc.ReportGeneratorItemReturnsRequest\x1a(.ray.rpc.ReportGeneratorItemReturnsReply\x12Z\n\x12PubsubCommandBatch\x12\".ray.rpc.PubsubCommandBatchRequest\x1a .ray.rpc.PubsubCommandBatchReply\x12o\n\x19UpdateObjectLocationBatch\x12).ray.rpc.UpdateObjectLocationBatchRequest\x1a\'.ray.rpc.UpdateObjectLocationBatchReply\x12i\n\x17GetObjectLocationsOwner\x12\'.ray.rpc.GetObjectLocationsOwnerRequest\x1a%.ray.rpc.GetObjectLocationsOwnerReply\x12?\n\tKillActor\x12\x19.ray.rpc.KillActorRequest\x1a\x17.ray.rpc.KillActorReply\x12\x42\n\nCancelTask\x12\x1a.ray.rpc.CancelTaskRequest\x1a\x18.ray.rpc.CancelTaskReply\x12T\n\x10RemoteCancelTask\x12 .ray.rpc.RemoteCancelTaskRequest\x1a\x1e.ray.rpc.RemoteCancelTaskReply\x12Z\n\x12GetCoreWorkerStats\x12\".ray.rpc.GetCoreWorkerStatsRequest\x1a 
.ray.rpc.GetCoreWorkerStatsReply\x12\x39\n\x07LocalGC\x12\x17.ray.rpc.LocalGCRequest\x1a\x15.ray.rpc.LocalGCReply\x12K\n\rDeleteObjects\x12\x1d.ray.rpc.DeleteObjectsRequest\x1a\x1b.ray.rpc.DeleteObjectsReply\x12H\n\x0cSpillObjects\x12\x1c.ray.rpc.SpillObjectsRequest\x1a\x1a.ray.rpc.SpillObjectsReply\x12\x63\n\x15RestoreSpilledObjects\x12%.ray.rpc.RestoreSpilledObjectsRequest\x1a#.ray.rpc.RestoreSpilledObjectsReply\x12`\n\x14\x44\x65leteSpilledObjects\x12$.ray.rpc.DeleteSpilledObjectsRequest\x1a\".ray.rpc.DeleteSpilledObjectsReply\x12W\n\x11PlasmaObjectReady\x12!.ray.rpc.PlasmaObjectReadyRequest\x1a\x1f.ray.rpc.PlasmaObjectReadyReply\x12\x30\n\x04\x45xit\x12\x14.ray.rpc.ExitRequest\x1a\x12.ray.rpc.ExitReply\x12W\n\x11\x41ssignObjectOwner\x12!.ray.rpc.AssignObjectOwnerRequest\x1a\x1f.ray.rpc.AssignObjectOwnerReply\x12Q\n\x0fNumPendingTasks\x12\x1f.ray.rpc.NumPendingTasksRequest\x1a\x1d.ray.rpc.NumPendingTasksReply\x12u\n\x1bRegisterMutableObjectReader\x12+.ray.rpc.RegisterMutableObjectReaderRequest\x1a).ray.rpc.RegisterMutableObjectReaderReplyb\x06proto3') + +_OBJECTPLASMALOCATIONUPDATE = DESCRIPTOR.enum_types_by_name['ObjectPlasmaLocationUpdate'] +ObjectPlasmaLocationUpdate = enum_type_wrapper.EnumTypeWrapper(_OBJECTPLASMALOCATIONUPDATE) +ADDED = 0 +REMOVED = 1 + + +_ACTIVEOBJECTIDS = DESCRIPTOR.message_types_by_name['ActiveObjectIDs'] +_ACTORHANDLE = DESCRIPTOR.message_types_by_name['ActorHandle'] +_ACTORHANDLE_LABELSENTRY = _ACTORHANDLE.nested_types_by_name['LabelsEntry'] +_PUSHTASKREQUEST = DESCRIPTOR.message_types_by_name['PushTaskRequest'] +_PUSHTASKREPLY = DESCRIPTOR.message_types_by_name['PushTaskReply'] +_DIRECTACTORCALLARGWAITCOMPLETEREQUEST = DESCRIPTOR.message_types_by_name['DirectActorCallArgWaitCompleteRequest'] +_DIRECTACTORCALLARGWAITCOMPLETEREPLY = DESCRIPTOR.message_types_by_name['DirectActorCallArgWaitCompleteReply'] +_GETOBJECTSTATUSREQUEST = DESCRIPTOR.message_types_by_name['GetObjectStatusRequest'] +_RAYOBJECT = 
DESCRIPTOR.message_types_by_name['RayObject'] +_GETOBJECTSTATUSREPLY = DESCRIPTOR.message_types_by_name['GetObjectStatusReply'] +_WAITFORACTORREFDELETEDREQUEST = DESCRIPTOR.message_types_by_name['WaitForActorRefDeletedRequest'] +_WAITFORACTORREFDELETEDREPLY = DESCRIPTOR.message_types_by_name['WaitForActorRefDeletedReply'] +_UPDATEOBJECTLOCATIONBATCHREQUEST = DESCRIPTOR.message_types_by_name['UpdateObjectLocationBatchRequest'] +_UPDATEOBJECTLOCATIONBATCHREPLY = DESCRIPTOR.message_types_by_name['UpdateObjectLocationBatchReply'] +_OBJECTSPILLEDLOCATIONUPDATE = DESCRIPTOR.message_types_by_name['ObjectSpilledLocationUpdate'] +_OBJECTLOCATIONUPDATE = DESCRIPTOR.message_types_by_name['ObjectLocationUpdate'] +_GETOBJECTLOCATIONSOWNERREQUEST = DESCRIPTOR.message_types_by_name['GetObjectLocationsOwnerRequest'] +_GETOBJECTLOCATIONSOWNERREPLY = DESCRIPTOR.message_types_by_name['GetObjectLocationsOwnerReply'] +_KILLACTORREQUEST = DESCRIPTOR.message_types_by_name['KillActorRequest'] +_KILLACTORREPLY = DESCRIPTOR.message_types_by_name['KillActorReply'] +_CANCELTASKREQUEST = DESCRIPTOR.message_types_by_name['CancelTaskRequest'] +_CANCELTASKREPLY = DESCRIPTOR.message_types_by_name['CancelTaskReply'] +_REMOTECANCELTASKREQUEST = DESCRIPTOR.message_types_by_name['RemoteCancelTaskRequest'] +_REMOTECANCELTASKREPLY = DESCRIPTOR.message_types_by_name['RemoteCancelTaskReply'] +_GETCOREWORKERSTATSREQUEST = DESCRIPTOR.message_types_by_name['GetCoreWorkerStatsRequest'] +_GETCOREWORKERSTATSREPLY = DESCRIPTOR.message_types_by_name['GetCoreWorkerStatsReply'] +_LOCALGCREQUEST = DESCRIPTOR.message_types_by_name['LocalGCRequest'] +_LOCALGCREPLY = DESCRIPTOR.message_types_by_name['LocalGCReply'] +_PLASMAOBJECTREADYREQUEST = DESCRIPTOR.message_types_by_name['PlasmaObjectReadyRequest'] +_PLASMAOBJECTREADYREPLY = DESCRIPTOR.message_types_by_name['PlasmaObjectReadyReply'] +_DELETEOBJECTSREQUEST = DESCRIPTOR.message_types_by_name['DeleteObjectsRequest'] +_DELETEOBJECTSREPLY = 
DESCRIPTOR.message_types_by_name['DeleteObjectsReply'] +_SPILLOBJECTSREQUEST = DESCRIPTOR.message_types_by_name['SpillObjectsRequest'] +_SPILLOBJECTSREPLY = DESCRIPTOR.message_types_by_name['SpillObjectsReply'] +_RESTORESPILLEDOBJECTSREQUEST = DESCRIPTOR.message_types_by_name['RestoreSpilledObjectsRequest'] +_RESTORESPILLEDOBJECTSREPLY = DESCRIPTOR.message_types_by_name['RestoreSpilledObjectsReply'] +_DELETESPILLEDOBJECTSREQUEST = DESCRIPTOR.message_types_by_name['DeleteSpilledObjectsRequest'] +_DELETESPILLEDOBJECTSREPLY = DESCRIPTOR.message_types_by_name['DeleteSpilledObjectsReply'] +_EXITREQUEST = DESCRIPTOR.message_types_by_name['ExitRequest'] +_EXITREPLY = DESCRIPTOR.message_types_by_name['ExitReply'] +_ASSIGNOBJECTOWNERREQUEST = DESCRIPTOR.message_types_by_name['AssignObjectOwnerRequest'] +_ASSIGNOBJECTOWNERREPLY = DESCRIPTOR.message_types_by_name['AssignObjectOwnerReply'] +_RAYLETNOTIFYGCSRESTARTREQUEST = DESCRIPTOR.message_types_by_name['RayletNotifyGCSRestartRequest'] +_RAYLETNOTIFYGCSRESTARTREPLY = DESCRIPTOR.message_types_by_name['RayletNotifyGCSRestartReply'] +_NUMPENDINGTASKSREQUEST = DESCRIPTOR.message_types_by_name['NumPendingTasksRequest'] +_NUMPENDINGTASKSREPLY = DESCRIPTOR.message_types_by_name['NumPendingTasksReply'] +_REPORTGENERATORITEMRETURNSREQUEST = DESCRIPTOR.message_types_by_name['ReportGeneratorItemReturnsRequest'] +_REPORTGENERATORITEMRETURNSREPLY = DESCRIPTOR.message_types_by_name['ReportGeneratorItemReturnsReply'] +_REGISTERMUTABLEOBJECTREADERREQUEST = DESCRIPTOR.message_types_by_name['RegisterMutableObjectReaderRequest'] +_REGISTERMUTABLEOBJECTREADERREPLY = DESCRIPTOR.message_types_by_name['RegisterMutableObjectReaderReply'] +_GETOBJECTSTATUSREPLY_OBJECTSTATUS = _GETOBJECTSTATUSREPLY.enum_types_by_name['ObjectStatus'] +ActiveObjectIDs = _reflection.GeneratedProtocolMessageType('ActiveObjectIDs', (_message.Message,), { + 'DESCRIPTOR' : _ACTIVEOBJECTIDS, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # 
@@protoc_insertion_point(class_scope:ray.rpc.ActiveObjectIDs) + }) +_sym_db.RegisterMessage(ActiveObjectIDs) + +ActorHandle = _reflection.GeneratedProtocolMessageType('ActorHandle', (_message.Message,), { + + 'LabelsEntry' : _reflection.GeneratedProtocolMessageType('LabelsEntry', (_message.Message,), { + 'DESCRIPTOR' : _ACTORHANDLE_LABELSENTRY, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.ActorHandle.LabelsEntry) + }) + , + 'DESCRIPTOR' : _ACTORHANDLE, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.ActorHandle) + }) +_sym_db.RegisterMessage(ActorHandle) +_sym_db.RegisterMessage(ActorHandle.LabelsEntry) + +PushTaskRequest = _reflection.GeneratedProtocolMessageType('PushTaskRequest', (_message.Message,), { + 'DESCRIPTOR' : _PUSHTASKREQUEST, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.PushTaskRequest) + }) +_sym_db.RegisterMessage(PushTaskRequest) + +PushTaskReply = _reflection.GeneratedProtocolMessageType('PushTaskReply', (_message.Message,), { + 'DESCRIPTOR' : _PUSHTASKREPLY, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.PushTaskReply) + }) +_sym_db.RegisterMessage(PushTaskReply) + +DirectActorCallArgWaitCompleteRequest = _reflection.GeneratedProtocolMessageType('DirectActorCallArgWaitCompleteRequest', (_message.Message,), { + 'DESCRIPTOR' : _DIRECTACTORCALLARGWAITCOMPLETEREQUEST, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.DirectActorCallArgWaitCompleteRequest) + }) +_sym_db.RegisterMessage(DirectActorCallArgWaitCompleteRequest) + +DirectActorCallArgWaitCompleteReply = _reflection.GeneratedProtocolMessageType('DirectActorCallArgWaitCompleteReply', (_message.Message,), { + 'DESCRIPTOR' : _DIRECTACTORCALLARGWAITCOMPLETEREPLY, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # 
@@protoc_insertion_point(class_scope:ray.rpc.DirectActorCallArgWaitCompleteReply) + }) +_sym_db.RegisterMessage(DirectActorCallArgWaitCompleteReply) + +GetObjectStatusRequest = _reflection.GeneratedProtocolMessageType('GetObjectStatusRequest', (_message.Message,), { + 'DESCRIPTOR' : _GETOBJECTSTATUSREQUEST, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.GetObjectStatusRequest) + }) +_sym_db.RegisterMessage(GetObjectStatusRequest) + +RayObject = _reflection.GeneratedProtocolMessageType('RayObject', (_message.Message,), { + 'DESCRIPTOR' : _RAYOBJECT, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.RayObject) + }) +_sym_db.RegisterMessage(RayObject) + +GetObjectStatusReply = _reflection.GeneratedProtocolMessageType('GetObjectStatusReply', (_message.Message,), { + 'DESCRIPTOR' : _GETOBJECTSTATUSREPLY, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.GetObjectStatusReply) + }) +_sym_db.RegisterMessage(GetObjectStatusReply) + +WaitForActorRefDeletedRequest = _reflection.GeneratedProtocolMessageType('WaitForActorRefDeletedRequest', (_message.Message,), { + 'DESCRIPTOR' : _WAITFORACTORREFDELETEDREQUEST, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.WaitForActorRefDeletedRequest) + }) +_sym_db.RegisterMessage(WaitForActorRefDeletedRequest) + +WaitForActorRefDeletedReply = _reflection.GeneratedProtocolMessageType('WaitForActorRefDeletedReply', (_message.Message,), { + 'DESCRIPTOR' : _WAITFORACTORREFDELETEDREPLY, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.WaitForActorRefDeletedReply) + }) +_sym_db.RegisterMessage(WaitForActorRefDeletedReply) + +UpdateObjectLocationBatchRequest = _reflection.GeneratedProtocolMessageType('UpdateObjectLocationBatchRequest', (_message.Message,), { + 'DESCRIPTOR' : 
_UPDATEOBJECTLOCATIONBATCHREQUEST, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.UpdateObjectLocationBatchRequest) + }) +_sym_db.RegisterMessage(UpdateObjectLocationBatchRequest) + +UpdateObjectLocationBatchReply = _reflection.GeneratedProtocolMessageType('UpdateObjectLocationBatchReply', (_message.Message,), { + 'DESCRIPTOR' : _UPDATEOBJECTLOCATIONBATCHREPLY, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.UpdateObjectLocationBatchReply) + }) +_sym_db.RegisterMessage(UpdateObjectLocationBatchReply) + +ObjectSpilledLocationUpdate = _reflection.GeneratedProtocolMessageType('ObjectSpilledLocationUpdate', (_message.Message,), { + 'DESCRIPTOR' : _OBJECTSPILLEDLOCATIONUPDATE, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.ObjectSpilledLocationUpdate) + }) +_sym_db.RegisterMessage(ObjectSpilledLocationUpdate) + +ObjectLocationUpdate = _reflection.GeneratedProtocolMessageType('ObjectLocationUpdate', (_message.Message,), { + 'DESCRIPTOR' : _OBJECTLOCATIONUPDATE, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.ObjectLocationUpdate) + }) +_sym_db.RegisterMessage(ObjectLocationUpdate) + +GetObjectLocationsOwnerRequest = _reflection.GeneratedProtocolMessageType('GetObjectLocationsOwnerRequest', (_message.Message,), { + 'DESCRIPTOR' : _GETOBJECTLOCATIONSOWNERREQUEST, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.GetObjectLocationsOwnerRequest) + }) +_sym_db.RegisterMessage(GetObjectLocationsOwnerRequest) + +GetObjectLocationsOwnerReply = _reflection.GeneratedProtocolMessageType('GetObjectLocationsOwnerReply', (_message.Message,), { + 'DESCRIPTOR' : _GETOBJECTLOCATIONSOWNERREPLY, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.GetObjectLocationsOwnerReply) + }) 
+_sym_db.RegisterMessage(GetObjectLocationsOwnerReply) + +KillActorRequest = _reflection.GeneratedProtocolMessageType('KillActorRequest', (_message.Message,), { + 'DESCRIPTOR' : _KILLACTORREQUEST, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.KillActorRequest) + }) +_sym_db.RegisterMessage(KillActorRequest) + +KillActorReply = _reflection.GeneratedProtocolMessageType('KillActorReply', (_message.Message,), { + 'DESCRIPTOR' : _KILLACTORREPLY, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.KillActorReply) + }) +_sym_db.RegisterMessage(KillActorReply) + +CancelTaskRequest = _reflection.GeneratedProtocolMessageType('CancelTaskRequest', (_message.Message,), { + 'DESCRIPTOR' : _CANCELTASKREQUEST, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.CancelTaskRequest) + }) +_sym_db.RegisterMessage(CancelTaskRequest) + +CancelTaskReply = _reflection.GeneratedProtocolMessageType('CancelTaskReply', (_message.Message,), { + 'DESCRIPTOR' : _CANCELTASKREPLY, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.CancelTaskReply) + }) +_sym_db.RegisterMessage(CancelTaskReply) + +RemoteCancelTaskRequest = _reflection.GeneratedProtocolMessageType('RemoteCancelTaskRequest', (_message.Message,), { + 'DESCRIPTOR' : _REMOTECANCELTASKREQUEST, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.RemoteCancelTaskRequest) + }) +_sym_db.RegisterMessage(RemoteCancelTaskRequest) + +RemoteCancelTaskReply = _reflection.GeneratedProtocolMessageType('RemoteCancelTaskReply', (_message.Message,), { + 'DESCRIPTOR' : _REMOTECANCELTASKREPLY, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.RemoteCancelTaskReply) + }) +_sym_db.RegisterMessage(RemoteCancelTaskReply) + +GetCoreWorkerStatsRequest = 
_reflection.GeneratedProtocolMessageType('GetCoreWorkerStatsRequest', (_message.Message,), { + 'DESCRIPTOR' : _GETCOREWORKERSTATSREQUEST, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.GetCoreWorkerStatsRequest) + }) +_sym_db.RegisterMessage(GetCoreWorkerStatsRequest) + +GetCoreWorkerStatsReply = _reflection.GeneratedProtocolMessageType('GetCoreWorkerStatsReply', (_message.Message,), { + 'DESCRIPTOR' : _GETCOREWORKERSTATSREPLY, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.GetCoreWorkerStatsReply) + }) +_sym_db.RegisterMessage(GetCoreWorkerStatsReply) + +LocalGCRequest = _reflection.GeneratedProtocolMessageType('LocalGCRequest', (_message.Message,), { + 'DESCRIPTOR' : _LOCALGCREQUEST, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.LocalGCRequest) + }) +_sym_db.RegisterMessage(LocalGCRequest) + +LocalGCReply = _reflection.GeneratedProtocolMessageType('LocalGCReply', (_message.Message,), { + 'DESCRIPTOR' : _LOCALGCREPLY, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.LocalGCReply) + }) +_sym_db.RegisterMessage(LocalGCReply) + +PlasmaObjectReadyRequest = _reflection.GeneratedProtocolMessageType('PlasmaObjectReadyRequest', (_message.Message,), { + 'DESCRIPTOR' : _PLASMAOBJECTREADYREQUEST, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.PlasmaObjectReadyRequest) + }) +_sym_db.RegisterMessage(PlasmaObjectReadyRequest) + +PlasmaObjectReadyReply = _reflection.GeneratedProtocolMessageType('PlasmaObjectReadyReply', (_message.Message,), { + 'DESCRIPTOR' : _PLASMAOBJECTREADYREPLY, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.PlasmaObjectReadyReply) + }) +_sym_db.RegisterMessage(PlasmaObjectReadyReply) + +DeleteObjectsRequest = 
_reflection.GeneratedProtocolMessageType('DeleteObjectsRequest', (_message.Message,), { + 'DESCRIPTOR' : _DELETEOBJECTSREQUEST, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.DeleteObjectsRequest) + }) +_sym_db.RegisterMessage(DeleteObjectsRequest) + +DeleteObjectsReply = _reflection.GeneratedProtocolMessageType('DeleteObjectsReply', (_message.Message,), { + 'DESCRIPTOR' : _DELETEOBJECTSREPLY, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.DeleteObjectsReply) + }) +_sym_db.RegisterMessage(DeleteObjectsReply) + +SpillObjectsRequest = _reflection.GeneratedProtocolMessageType('SpillObjectsRequest', (_message.Message,), { + 'DESCRIPTOR' : _SPILLOBJECTSREQUEST, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.SpillObjectsRequest) + }) +_sym_db.RegisterMessage(SpillObjectsRequest) + +SpillObjectsReply = _reflection.GeneratedProtocolMessageType('SpillObjectsReply', (_message.Message,), { + 'DESCRIPTOR' : _SPILLOBJECTSREPLY, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.SpillObjectsReply) + }) +_sym_db.RegisterMessage(SpillObjectsReply) + +RestoreSpilledObjectsRequest = _reflection.GeneratedProtocolMessageType('RestoreSpilledObjectsRequest', (_message.Message,), { + 'DESCRIPTOR' : _RESTORESPILLEDOBJECTSREQUEST, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.RestoreSpilledObjectsRequest) + }) +_sym_db.RegisterMessage(RestoreSpilledObjectsRequest) + +RestoreSpilledObjectsReply = _reflection.GeneratedProtocolMessageType('RestoreSpilledObjectsReply', (_message.Message,), { + 'DESCRIPTOR' : _RESTORESPILLEDOBJECTSREPLY, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.RestoreSpilledObjectsReply) + }) +_sym_db.RegisterMessage(RestoreSpilledObjectsReply) + 
+DeleteSpilledObjectsRequest = _reflection.GeneratedProtocolMessageType('DeleteSpilledObjectsRequest', (_message.Message,), { + 'DESCRIPTOR' : _DELETESPILLEDOBJECTSREQUEST, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.DeleteSpilledObjectsRequest) + }) +_sym_db.RegisterMessage(DeleteSpilledObjectsRequest) + +DeleteSpilledObjectsReply = _reflection.GeneratedProtocolMessageType('DeleteSpilledObjectsReply', (_message.Message,), { + 'DESCRIPTOR' : _DELETESPILLEDOBJECTSREPLY, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.DeleteSpilledObjectsReply) + }) +_sym_db.RegisterMessage(DeleteSpilledObjectsReply) + +ExitRequest = _reflection.GeneratedProtocolMessageType('ExitRequest', (_message.Message,), { + 'DESCRIPTOR' : _EXITREQUEST, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.ExitRequest) + }) +_sym_db.RegisterMessage(ExitRequest) + +ExitReply = _reflection.GeneratedProtocolMessageType('ExitReply', (_message.Message,), { + 'DESCRIPTOR' : _EXITREPLY, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.ExitReply) + }) +_sym_db.RegisterMessage(ExitReply) + +AssignObjectOwnerRequest = _reflection.GeneratedProtocolMessageType('AssignObjectOwnerRequest', (_message.Message,), { + 'DESCRIPTOR' : _ASSIGNOBJECTOWNERREQUEST, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.AssignObjectOwnerRequest) + }) +_sym_db.RegisterMessage(AssignObjectOwnerRequest) + +AssignObjectOwnerReply = _reflection.GeneratedProtocolMessageType('AssignObjectOwnerReply', (_message.Message,), { + 'DESCRIPTOR' : _ASSIGNOBJECTOWNERREPLY, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.AssignObjectOwnerReply) + }) +_sym_db.RegisterMessage(AssignObjectOwnerReply) + +RayletNotifyGCSRestartRequest = 
_reflection.GeneratedProtocolMessageType('RayletNotifyGCSRestartRequest', (_message.Message,), { + 'DESCRIPTOR' : _RAYLETNOTIFYGCSRESTARTREQUEST, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.RayletNotifyGCSRestartRequest) + }) +_sym_db.RegisterMessage(RayletNotifyGCSRestartRequest) + +RayletNotifyGCSRestartReply = _reflection.GeneratedProtocolMessageType('RayletNotifyGCSRestartReply', (_message.Message,), { + 'DESCRIPTOR' : _RAYLETNOTIFYGCSRESTARTREPLY, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.RayletNotifyGCSRestartReply) + }) +_sym_db.RegisterMessage(RayletNotifyGCSRestartReply) + +NumPendingTasksRequest = _reflection.GeneratedProtocolMessageType('NumPendingTasksRequest', (_message.Message,), { + 'DESCRIPTOR' : _NUMPENDINGTASKSREQUEST, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.NumPendingTasksRequest) + }) +_sym_db.RegisterMessage(NumPendingTasksRequest) + +NumPendingTasksReply = _reflection.GeneratedProtocolMessageType('NumPendingTasksReply', (_message.Message,), { + 'DESCRIPTOR' : _NUMPENDINGTASKSREPLY, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.NumPendingTasksReply) + }) +_sym_db.RegisterMessage(NumPendingTasksReply) + +ReportGeneratorItemReturnsRequest = _reflection.GeneratedProtocolMessageType('ReportGeneratorItemReturnsRequest', (_message.Message,), { + 'DESCRIPTOR' : _REPORTGENERATORITEMRETURNSREQUEST, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.ReportGeneratorItemReturnsRequest) + }) +_sym_db.RegisterMessage(ReportGeneratorItemReturnsRequest) + +ReportGeneratorItemReturnsReply = _reflection.GeneratedProtocolMessageType('ReportGeneratorItemReturnsReply', (_message.Message,), { + 'DESCRIPTOR' : _REPORTGENERATORITEMRETURNSREPLY, + '__module__' : 
'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.ReportGeneratorItemReturnsReply) + }) +_sym_db.RegisterMessage(ReportGeneratorItemReturnsReply) + +RegisterMutableObjectReaderRequest = _reflection.GeneratedProtocolMessageType('RegisterMutableObjectReaderRequest', (_message.Message,), { + 'DESCRIPTOR' : _REGISTERMUTABLEOBJECTREADERREQUEST, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.RegisterMutableObjectReaderRequest) + }) +_sym_db.RegisterMessage(RegisterMutableObjectReaderRequest) + +RegisterMutableObjectReaderReply = _reflection.GeneratedProtocolMessageType('RegisterMutableObjectReaderReply', (_message.Message,), { + 'DESCRIPTOR' : _REGISTERMUTABLEOBJECTREADERREPLY, + '__module__' : 'src.ray.protobuf.core_worker_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.RegisterMutableObjectReaderReply) + }) +_sym_db.RegisterMessage(RegisterMutableObjectReaderReply) + +_COREWORKERSERVICE = DESCRIPTOR.services_by_name['CoreWorkerService'] +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _ACTORHANDLE_LABELSENTRY._options = None + _ACTORHANDLE_LABELSENTRY._serialized_options = b'8\001' + _OBJECTPLASMALOCATIONUPDATE._serialized_start=6533 + _OBJECTPLASMALOCATIONUPDATE._serialized_end=6585 + _ACTIVEOBJECTIDS._serialized_start=109 + _ACTIVEOBJECTIDS._serialized_end=157 + _ACTORHANDLE._serialized_start=160 + _ACTORHANDLE._serialized_end=924 + _ACTORHANDLE_LABELSENTRY._serialized_start=867 + _ACTORHANDLE_LABELSENTRY._serialized_end=924 + _PUSHTASKREQUEST._serialized_start=927 + _PUSHTASKREQUEST._serialized_end=1202 + _PUSHTASKREPLY._serialized_start=1205 + _PUSHTASKREPLY._serialized_end=1852 + _DIRECTACTORCALLARGWAITCOMPLETEREQUEST._serialized_start=1854 + _DIRECTACTORCALLARGWAITCOMPLETEREQUEST._serialized_end=1957 + _DIRECTACTORCALLARGWAITCOMPLETEREPLY._serialized_start=1959 + _DIRECTACTORCALLARGWAITCOMPLETEREPLY._serialized_end=1996 + 
_GETOBJECTSTATUSREQUEST._serialized_start=1998 + _GETOBJECTSTATUSREQUEST._serialized_end=2091 + _RAYOBJECT._serialized_start=2094 + _RAYOBJECT._serialized_end=2227 + _GETOBJECTSTATUSREPLY._serialized_start=2230 + _GETOBJECTSTATUSREPLY._serialized_end=2482 + _GETOBJECTSTATUSREPLY_OBJECTSTATUS._serialized_start=2426 + _GETOBJECTSTATUSREPLY_OBJECTSTATUS._serialized_end=2482 + _WAITFORACTORREFDELETEDREQUEST._serialized_start=2484 + _WAITFORACTORREFDELETEDREQUEST._serialized_end=2588 + _WAITFORACTORREFDELETEDREPLY._serialized_start=2590 + _WAITFORACTORREFDELETEDREPLY._serialized_end=2619 + _UPDATEOBJECTLOCATIONBATCHREQUEST._serialized_start=2622 + _UPDATEOBJECTLOCATIONBATCHREQUEST._serialized_end=2814 + _UPDATEOBJECTLOCATIONBATCHREPLY._serialized_start=2816 + _UPDATEOBJECTLOCATIONBATCHREPLY._serialized_end=2848 + _OBJECTSPILLEDLOCATIONUPDATE._serialized_start=2850 + _OBJECTSPILLEDLOCATIONUPDATE._serialized_end=2969 + _OBJECTLOCATIONUPDATE._serialized_start=2972 + _OBJECTLOCATIONUPDATE._serialized_end=3330 + _GETOBJECTLOCATIONSOWNERREQUEST._serialized_start=3332 + _GETOBJECTLOCATIONSOWNERREQUEST._serialized_end=3441 + _GETOBJECTLOCATIONSOWNERREPLY._serialized_start=3443 + _GETOBJECTLOCATIONSOWNERREPLY._serialized_end=3567 + _KILLACTORREQUEST._serialized_start=3570 + _KILLACTORREQUEST._serialized_end=3722 + _KILLACTORREPLY._serialized_start=3724 + _KILLACTORREPLY._serialized_end=3740 + _CANCELTASKREQUEST._serialized_start=3743 + _CANCELTASKREQUEST._serialized_end=3907 + _CANCELTASKREPLY._serialized_start=3909 + _CANCELTASKREPLY._serialized_end=4025 + _REMOTECANCELTASKREQUEST._serialized_start=4028 + _REMOTECANCELTASKREQUEST._serialized_end=4156 + _REMOTECANCELTASKREPLY._serialized_start=4158 + _REMOTECANCELTASKREPLY._serialized_end=4181 + _GETCOREWORKERSTATSREQUEST._serialized_start=4184 + _GETCOREWORKERSTATSREQUEST._serialized_end=4386 + _GETCOREWORKERSTATSREPLY._serialized_start=4389 + _GETCOREWORKERSTATSREPLY._serialized_end=4638 + 
_LOCALGCREQUEST._serialized_start=4640 + _LOCALGCREQUEST._serialized_end=4709 + _LOCALGCREPLY._serialized_start=4711 + _LOCALGCREPLY._serialized_end=4725 + _PLASMAOBJECTREADYREQUEST._serialized_start=4727 + _PLASMAOBJECTREADYREQUEST._serialized_end=4782 + _PLASMAOBJECTREADYREPLY._serialized_start=4784 + _PLASMAOBJECTREADYREPLY._serialized_end=4808 + _DELETEOBJECTSREQUEST._serialized_start=4810 + _DELETEOBJECTSREQUEST._serialized_end=4894 + _DELETEOBJECTSREPLY._serialized_start=4896 + _DELETEOBJECTSREPLY._serialized_end=4916 + _SPILLOBJECTSREQUEST._serialized_start=4919 + _SPILLOBJECTSREQUEST._serialized_end=5085 + _SPILLOBJECTSREPLY._serialized_start=5087 + _SPILLOBJECTSREPLY._serialized_end=5154 + _RESTORESPILLEDOBJECTSREQUEST._serialized_start=5157 + _RESTORESPILLEDOBJECTSREQUEST._serialized_end=5286 + _RESTORESPILLEDOBJECTSREPLY._serialized_start=5288 + _RESTORESPILLEDOBJECTSREPLY._serialized_end=5366 + _DELETESPILLEDOBJECTSREQUEST._serialized_start=5368 + _DELETESPILLEDOBJECTSREQUEST._serialized_end=5445 + _DELETESPILLEDOBJECTSREPLY._serialized_start=5447 + _DELETESPILLEDOBJECTSREPLY._serialized_end=5474 + _EXITREQUEST._serialized_start=5476 + _EXITREQUEST._serialized_end=5520 + _EXITREPLY._serialized_start=5522 + _EXITREPLY._serialized_end=5559 + _ASSIGNOBJECTOWNERREQUEST._serialized_start=5562 + _ASSIGNOBJECTOWNERREQUEST._serialized_end=5790 + _ASSIGNOBJECTOWNERREPLY._serialized_start=5792 + _ASSIGNOBJECTOWNERREPLY._serialized_end=5816 + _RAYLETNOTIFYGCSRESTARTREQUEST._serialized_start=5818 + _RAYLETNOTIFYGCSRESTARTREQUEST._serialized_end=5849 + _RAYLETNOTIFYGCSRESTARTREPLY._serialized_start=5851 + _RAYLETNOTIFYGCSRESTARTREPLY._serialized_end=5880 + _NUMPENDINGTASKSREQUEST._serialized_start=5882 + _NUMPENDINGTASKSREQUEST._serialized_end=5906 + _NUMPENDINGTASKSREPLY._serialized_start=5908 + _NUMPENDINGTASKSREPLY._serialized_end=5974 + _REPORTGENERATORITEMRETURNSREQUEST._serialized_start=5977 + _REPORTGENERATORITEMRETURNSREQUEST._serialized_end=6245 + 
_REPORTGENERATORITEMRETURNSREPLY._serialized_start=6247 + _REPORTGENERATORITEMRETURNSREPLY._serialized_end=6339 + _REGISTERMUTABLEOBJECTREADERREQUEST._serialized_start=6342 + _REGISTERMUTABLEOBJECTREADERREQUEST._serialized_end=6495 + _REGISTERMUTABLEOBJECTREADERREPLY._serialized_start=6497 + _REGISTERMUTABLEOBJECTREADERREPLY._serialized_end=6531 + _COREWORKERSERVICE._serialized_start=6588 + _COREWORKERSERVICE._serialized_end=8755 +# @@protoc_insertion_point(module_scope) diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/export_actor_data_pb2_grpc.py b/.venv/lib/python3.11/site-packages/ray/core/generated/export_actor_data_pb2_grpc.py new file mode 100644 index 0000000000000000000000000000000000000000..2daafffebfc817aefe8fcb96eaec25e65b3903e8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/core/generated/export_actor_data_pb2_grpc.py @@ -0,0 +1,4 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/reporter_pb2_grpc.py b/.venv/lib/python3.11/site-packages/ray/core/generated/reporter_pb2_grpc.py new file mode 100644 index 0000000000000000000000000000000000000000..8017fe5867c243e74438eec02fa74f06ccf5323e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/core/generated/reporter_pb2_grpc.py @@ -0,0 +1,259 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from . import reporter_pb2 as src_dot_ray_dot_protobuf_dot_reporter__pb2 + + +class ReporterServiceStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.ReportOCMetrics = channel.unary_unary( + '/ray.rpc.ReporterService/ReportOCMetrics', + request_serializer=src_dot_ray_dot_protobuf_dot_reporter__pb2.ReportOCMetricsRequest.SerializeToString, + response_deserializer=src_dot_ray_dot_protobuf_dot_reporter__pb2.ReportOCMetricsReply.FromString, + ) + self.GetTraceback = channel.unary_unary( + '/ray.rpc.ReporterService/GetTraceback', + request_serializer=src_dot_ray_dot_protobuf_dot_reporter__pb2.GetTracebackRequest.SerializeToString, + response_deserializer=src_dot_ray_dot_protobuf_dot_reporter__pb2.GetTracebackReply.FromString, + ) + self.CpuProfiling = channel.unary_unary( + '/ray.rpc.ReporterService/CpuProfiling', + request_serializer=src_dot_ray_dot_protobuf_dot_reporter__pb2.CpuProfilingRequest.SerializeToString, + response_deserializer=src_dot_ray_dot_protobuf_dot_reporter__pb2.CpuProfilingReply.FromString, + ) + self.MemoryProfiling = channel.unary_unary( + '/ray.rpc.ReporterService/MemoryProfiling', + request_serializer=src_dot_ray_dot_protobuf_dot_reporter__pb2.MemoryProfilingRequest.SerializeToString, + response_deserializer=src_dot_ray_dot_protobuf_dot_reporter__pb2.MemoryProfilingReply.FromString, + ) + + +class ReporterServiceServicer(object): + """Missing associated documentation comment in .proto file.""" + + def ReportOCMetrics(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetTraceback(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CpuProfiling(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + 
context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def MemoryProfiling(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_ReporterServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'ReportOCMetrics': grpc.unary_unary_rpc_method_handler( + servicer.ReportOCMetrics, + request_deserializer=src_dot_ray_dot_protobuf_dot_reporter__pb2.ReportOCMetricsRequest.FromString, + response_serializer=src_dot_ray_dot_protobuf_dot_reporter__pb2.ReportOCMetricsReply.SerializeToString, + ), + 'GetTraceback': grpc.unary_unary_rpc_method_handler( + servicer.GetTraceback, + request_deserializer=src_dot_ray_dot_protobuf_dot_reporter__pb2.GetTracebackRequest.FromString, + response_serializer=src_dot_ray_dot_protobuf_dot_reporter__pb2.GetTracebackReply.SerializeToString, + ), + 'CpuProfiling': grpc.unary_unary_rpc_method_handler( + servicer.CpuProfiling, + request_deserializer=src_dot_ray_dot_protobuf_dot_reporter__pb2.CpuProfilingRequest.FromString, + response_serializer=src_dot_ray_dot_protobuf_dot_reporter__pb2.CpuProfilingReply.SerializeToString, + ), + 'MemoryProfiling': grpc.unary_unary_rpc_method_handler( + servicer.MemoryProfiling, + request_deserializer=src_dot_ray_dot_protobuf_dot_reporter__pb2.MemoryProfilingRequest.FromString, + response_serializer=src_dot_ray_dot_protobuf_dot_reporter__pb2.MemoryProfilingReply.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'ray.rpc.ReporterService', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. 
+class ReporterService(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def ReportOCMetrics(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/ray.rpc.ReporterService/ReportOCMetrics', + src_dot_ray_dot_protobuf_dot_reporter__pb2.ReportOCMetricsRequest.SerializeToString, + src_dot_ray_dot_protobuf_dot_reporter__pb2.ReportOCMetricsReply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetTraceback(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/ray.rpc.ReporterService/GetTraceback', + src_dot_ray_dot_protobuf_dot_reporter__pb2.GetTracebackRequest.SerializeToString, + src_dot_ray_dot_protobuf_dot_reporter__pb2.GetTracebackReply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def CpuProfiling(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/ray.rpc.ReporterService/CpuProfiling', + src_dot_ray_dot_protobuf_dot_reporter__pb2.CpuProfilingRequest.SerializeToString, + src_dot_ray_dot_protobuf_dot_reporter__pb2.CpuProfilingReply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def MemoryProfiling(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + 
insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/ray.rpc.ReporterService/MemoryProfiling', + src_dot_ray_dot_protobuf_dot_reporter__pb2.MemoryProfilingRequest.SerializeToString, + src_dot_ray_dot_protobuf_dot_reporter__pb2.MemoryProfilingReply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + +class LogServiceStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.ListLogs = channel.unary_unary( + '/ray.rpc.LogService/ListLogs', + request_serializer=src_dot_ray_dot_protobuf_dot_reporter__pb2.ListLogsRequest.SerializeToString, + response_deserializer=src_dot_ray_dot_protobuf_dot_reporter__pb2.ListLogsReply.FromString, + ) + self.StreamLog = channel.unary_stream( + '/ray.rpc.LogService/StreamLog', + request_serializer=src_dot_ray_dot_protobuf_dot_reporter__pb2.StreamLogRequest.SerializeToString, + response_deserializer=src_dot_ray_dot_protobuf_dot_reporter__pb2.StreamLogReply.FromString, + ) + + +class LogServiceServicer(object): + """Missing associated documentation comment in .proto file.""" + + def ListLogs(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def StreamLog(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_LogServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'ListLogs': grpc.unary_unary_rpc_method_handler( + servicer.ListLogs, + 
request_deserializer=src_dot_ray_dot_protobuf_dot_reporter__pb2.ListLogsRequest.FromString, + response_serializer=src_dot_ray_dot_protobuf_dot_reporter__pb2.ListLogsReply.SerializeToString, + ), + 'StreamLog': grpc.unary_stream_rpc_method_handler( + servicer.StreamLog, + request_deserializer=src_dot_ray_dot_protobuf_dot_reporter__pb2.StreamLogRequest.FromString, + response_serializer=src_dot_ray_dot_protobuf_dot_reporter__pb2.StreamLogReply.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'ray.rpc.LogService', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class LogService(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def ListLogs(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/ray.rpc.LogService/ListLogs', + src_dot_ray_dot_protobuf_dot_reporter__pb2.ListLogsRequest.SerializeToString, + src_dot_ray_dot_protobuf_dot_reporter__pb2.ListLogsReply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def StreamLog(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/ray.rpc.LogService/StreamLog', + src_dot_ray_dot_protobuf_dot_reporter__pb2.StreamLogRequest.SerializeToString, + src_dot_ray_dot_protobuf_dot_reporter__pb2.StreamLogReply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/usage_pb2.py 
b/.venv/lib/python3.11/site-packages/ray/core/generated/usage_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..f114ec18a77a3652ed55bb9d980e6f1fccbe919d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/core/generated/usage_pb2.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: src/ray/protobuf/usage.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1csrc/ray/protobuf/usage.proto\x12\tray.usage*\x9b\x14\n\x06TagKey\x12\n\n\x06_TEST1\x10\x00\x12\n\n\x06_TEST2\x10\x01\x12\x13\n\x0fRLLIB_FRAMEWORK\x10\x02\x12\x13\n\x0fRLLIB_ALGORITHM\x10\x03\x12\x15\n\x11RLLIB_NUM_WORKERS\x10\x04\x12\x15\n\x11SERVE_API_VERSION\x10\x05\x12\x19\n\x15SERVE_NUM_DEPLOYMENTS\x10\x06\x12\x0f\n\x0bGCS_STORAGE\x10\x07\x12\x1d\n\x19SERVE_NUM_GPU_DEPLOYMENTS\x10\x08\x12\x16\n\x12SERVE_FASTAPI_USED\x10\t\x12\x19\n\x15SERVE_DAG_DRIVER_USED\x10\n\x12\x1b\n\x17SERVE_HTTP_ADAPTER_USED\x10\x0b\x12\x1b\n\x17SERVE_GRPC_INGRESS_USED\x10\x0c\x12\x1a\n\x16SERVE_REST_API_VERSION\x10\r\x12\x12\n\x0eSERVE_NUM_APPS\x10\x0e\x12*\n&SERVE_NUM_REPLICAS_LIGHTWEIGHT_UPDATED\x10\x0f\x12)\n%SERVE_USER_CONFIG_LIGHTWEIGHT_UPDATED\x10\x10\x12\x30\n,SERVE_AUTOSCALING_CONFIG_LIGHTWEIGHT_UPDATED\x10\x11\x12#\n\x1fSERVE_RAY_SERVE_HANDLE_API_USED\x10\x12\x12(\n$SERVE_RAY_SERVE_SYNC_HANDLE_API_USED\x10\x13\x12$\n 
SERVE_DEPLOYMENT_HANDLE_API_USED\x10\x14\x12\x32\n.SERVE_DEPLOYMENT_HANDLE_TO_OBJECT_REF_API_USED\x10\x15\x12\x1e\n\x1aSERVE_MULTIPLEXED_API_USED\x10\x16\x12\x19\n\x15SERVE_HTTP_PROXY_USED\x10\x17\x12\x19\n\x15SERVE_GRPC_PROXY_USED\x10\x18\x12\x19\n\x15SERVE_STATUS_API_USED\x10\x19\x12!\n\x1dSERVE_GET_APP_HANDLE_API_USED\x10\x1a\x12(\n$SERVE_GET_DEPLOYMENT_HANDLE_API_USED\x10\x1b\x12(\n$SERVE_APP_CONTAINER_RUNTIME_ENV_USED\x10\x1c\x12/\n+SERVE_DEPLOYMENT_CONTAINER_RUNTIME_ENV_USED\x10\x1d\x12\x1e\n\x1aSERVE_NUM_NODE_COMPACTIONS\x10\x1e\x12 \n\x1cSERVE_AUTO_NUM_REPLICAS_USED\x10\x1f\x12\x1e\n\x1a\x43ORE_STATE_API_LIST_ACTORS\x10\x64\x12\x1d\n\x19\x43ORE_STATE_API_LIST_TASKS\x10\x65\x12\x1c\n\x18\x43ORE_STATE_API_LIST_JOBS\x10\x66\x12\x1d\n\x19\x43ORE_STATE_API_LIST_NODES\x10g\x12(\n$CORE_STATE_API_LIST_PLACEMENT_GROUPS\x10h\x12\x1f\n\x1b\x43ORE_STATE_API_LIST_WORKERS\x10i\x12\x1f\n\x1b\x43ORE_STATE_API_LIST_OBJECTS\x10j\x12$\n CORE_STATE_API_LIST_RUNTIME_ENVS\x10k\x12&\n\"CORE_STATE_API_LIST_CLUSTER_EVENTS\x10l\x12\x1c\n\x18\x43ORE_STATE_API_LIST_LOGS\x10m\x12\x1a\n\x16\x43ORE_STATE_API_GET_LOG\x10n\x12\"\n\x1e\x43ORE_STATE_API_SUMMARIZE_TASKS\x10o\x12#\n\x1f\x43ORE_STATE_API_SUMMARIZE_ACTORS\x10p\x12$\n 
CORE_STATE_API_SUMMARIZE_OBJECTS\x10q\x12\x13\n\x0e\x44\x41SHBOARD_USED\x10\xc8\x01\x12)\n$DASHBOARD_METRICS_PROMETHEUS_ENABLED\x10\xc9\x01\x12&\n!DASHBOARD_METRICS_GRAFANA_ENABLED\x10\xca\x01\x12\x13\n\x0ePG_NUM_CREATED\x10\xac\x02\x12\x16\n\x11\x41\x43TOR_NUM_CREATED\x10\xad\x02\x12\x1e\n\x19WORKER_CRASH_SYSTEM_ERROR\x10\xae\x02\x12\x15\n\x10WORKER_CRASH_OOM\x10\xaf\x02\x12\x19\n\x14RAY_GET_TIMEOUT_ZERO\x10\xb0\x02\x12\x1d\n\x18NUM_ACTOR_CREATION_TASKS\x10\xb1\x02\x12\x14\n\x0fNUM_ACTOR_TASKS\x10\xb2\x02\x12\x15\n\x10NUM_NORMAL_TASKS\x10\xb3\x02\x12\x10\n\x0bNUM_DRIVERS\x10\xb4\x02\x12\"\n\x1d\x45XPERIMENTAL_STATE_API_IMPORT\x10\xb5\x02\x12\x17\n\x12\x41UTOSCALER_VERSION\x10\xb6\x02\x12\x15\n\x10\x44\x41TA_LOGICAL_OPS\x10\x90\x03\x12\x10\n\x0b\x41IR_TRAINER\x10\xf4\x03\x12\x12\n\rTUNE_SEARCHER\x10\xf5\x03\x12\x13\n\x0eTUNE_SCHEDULER\x10\xf6\x03\x12\x11\n\x0c\x41IR_ENV_VARS\x10\xf7\x03\x12%\n AIR_SETUP_WANDB_INTEGRATION_USED\x10\xf8\x03\x12&\n!AIR_SETUP_MLFLOW_INTEGRATION_USED\x10\xf9\x03\x12\x12\n\rAIR_CALLBACKS\x10\xfa\x03\x12\x1e\n\x19\x41IR_STORAGE_CONFIGURATION\x10\xfb\x03\x12\x13\n\x0e\x41IR_ENTRYPOINT\x10\xfc\x03\x12\x1b\n\x16TRAIN_TORCH_GET_DEVICE\x10\xfd\x03\x12\x1e\n\x19TRAIN_TORCH_PREPARE_MODEL\x10\xfe\x03\x12#\n\x1eTRAIN_TORCH_PREPARE_DATALOADER\x10\xff\x03\x12$\n\x1fTRAIN_LIGHTNING_PREPARE_TRAINER\x10\x80\x04\x12+\n&TRAIN_LIGHTNING_RAYTRAINREPORTCALLBACK\x10\x81\x04\x12#\n\x1eTRAIN_LIGHTNING_RAYDDPSTRATEGY\x10\x82\x04\x12$\n\x1fTRAIN_LIGHTNING_RAYFSDPSTRATEGY\x10\x83\x04\x12)\n$TRAIN_LIGHTNING_RAYDEEPSPEEDSTRATEGY\x10\x84\x04\x12,\n\'TRAIN_LIGHTNING_RAYLIGHTNINGENVIRONMENT\x10\x85\x04\x12\'\n\"TRAIN_TRANSFORMERS_PREPARE_TRAINER\x10\x86\x04\x12.\n)TRAIN_TRANSFORMERS_RAYTRAINREPORTCALLBACK\x10\x87\x04\x12\x1c\n\x17TRAIN_TORCH_GET_DEVICES\x10\x88\x04\x62\x06proto3') + +_TAGKEY = DESCRIPTOR.enum_types_by_name['TagKey'] +TagKey = enum_type_wrapper.EnumTypeWrapper(_TAGKEY) +_TEST1 = 0 +_TEST2 = 1 +RLLIB_FRAMEWORK = 2 +RLLIB_ALGORITHM = 3 +RLLIB_NUM_WORKERS 
= 4 +SERVE_API_VERSION = 5 +SERVE_NUM_DEPLOYMENTS = 6 +GCS_STORAGE = 7 +SERVE_NUM_GPU_DEPLOYMENTS = 8 +SERVE_FASTAPI_USED = 9 +SERVE_DAG_DRIVER_USED = 10 +SERVE_HTTP_ADAPTER_USED = 11 +SERVE_GRPC_INGRESS_USED = 12 +SERVE_REST_API_VERSION = 13 +SERVE_NUM_APPS = 14 +SERVE_NUM_REPLICAS_LIGHTWEIGHT_UPDATED = 15 +SERVE_USER_CONFIG_LIGHTWEIGHT_UPDATED = 16 +SERVE_AUTOSCALING_CONFIG_LIGHTWEIGHT_UPDATED = 17 +SERVE_RAY_SERVE_HANDLE_API_USED = 18 +SERVE_RAY_SERVE_SYNC_HANDLE_API_USED = 19 +SERVE_DEPLOYMENT_HANDLE_API_USED = 20 +SERVE_DEPLOYMENT_HANDLE_TO_OBJECT_REF_API_USED = 21 +SERVE_MULTIPLEXED_API_USED = 22 +SERVE_HTTP_PROXY_USED = 23 +SERVE_GRPC_PROXY_USED = 24 +SERVE_STATUS_API_USED = 25 +SERVE_GET_APP_HANDLE_API_USED = 26 +SERVE_GET_DEPLOYMENT_HANDLE_API_USED = 27 +SERVE_APP_CONTAINER_RUNTIME_ENV_USED = 28 +SERVE_DEPLOYMENT_CONTAINER_RUNTIME_ENV_USED = 29 +SERVE_NUM_NODE_COMPACTIONS = 30 +SERVE_AUTO_NUM_REPLICAS_USED = 31 +CORE_STATE_API_LIST_ACTORS = 100 +CORE_STATE_API_LIST_TASKS = 101 +CORE_STATE_API_LIST_JOBS = 102 +CORE_STATE_API_LIST_NODES = 103 +CORE_STATE_API_LIST_PLACEMENT_GROUPS = 104 +CORE_STATE_API_LIST_WORKERS = 105 +CORE_STATE_API_LIST_OBJECTS = 106 +CORE_STATE_API_LIST_RUNTIME_ENVS = 107 +CORE_STATE_API_LIST_CLUSTER_EVENTS = 108 +CORE_STATE_API_LIST_LOGS = 109 +CORE_STATE_API_GET_LOG = 110 +CORE_STATE_API_SUMMARIZE_TASKS = 111 +CORE_STATE_API_SUMMARIZE_ACTORS = 112 +CORE_STATE_API_SUMMARIZE_OBJECTS = 113 +DASHBOARD_USED = 200 +DASHBOARD_METRICS_PROMETHEUS_ENABLED = 201 +DASHBOARD_METRICS_GRAFANA_ENABLED = 202 +PG_NUM_CREATED = 300 +ACTOR_NUM_CREATED = 301 +WORKER_CRASH_SYSTEM_ERROR = 302 +WORKER_CRASH_OOM = 303 +RAY_GET_TIMEOUT_ZERO = 304 +NUM_ACTOR_CREATION_TASKS = 305 +NUM_ACTOR_TASKS = 306 +NUM_NORMAL_TASKS = 307 +NUM_DRIVERS = 308 +EXPERIMENTAL_STATE_API_IMPORT = 309 +AUTOSCALER_VERSION = 310 +DATA_LOGICAL_OPS = 400 +AIR_TRAINER = 500 +TUNE_SEARCHER = 501 +TUNE_SCHEDULER = 502 +AIR_ENV_VARS = 503 +AIR_SETUP_WANDB_INTEGRATION_USED = 504 
+AIR_SETUP_MLFLOW_INTEGRATION_USED = 505 +AIR_CALLBACKS = 506 +AIR_STORAGE_CONFIGURATION = 507 +AIR_ENTRYPOINT = 508 +TRAIN_TORCH_GET_DEVICE = 509 +TRAIN_TORCH_PREPARE_MODEL = 510 +TRAIN_TORCH_PREPARE_DATALOADER = 511 +TRAIN_LIGHTNING_PREPARE_TRAINER = 512 +TRAIN_LIGHTNING_RAYTRAINREPORTCALLBACK = 513 +TRAIN_LIGHTNING_RAYDDPSTRATEGY = 514 +TRAIN_LIGHTNING_RAYFSDPSTRATEGY = 515 +TRAIN_LIGHTNING_RAYDEEPSPEEDSTRATEGY = 516 +TRAIN_LIGHTNING_RAYLIGHTNINGENVIRONMENT = 517 +TRAIN_TRANSFORMERS_PREPARE_TRAINER = 518 +TRAIN_TRANSFORMERS_RAYTRAINREPORTCALLBACK = 519 +TRAIN_TORCH_GET_DEVICES = 520 + + +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _TAGKEY._serialized_start=44 + _TAGKEY._serialized_end=2631 +# @@protoc_insertion_point(module_scope) diff --git a/.venv/lib/python3.11/site-packages/ray/core/generated/usage_pb2_grpc.py b/.venv/lib/python3.11/site-packages/ray/core/generated/usage_pb2_grpc.py new file mode 100644 index 0000000000000000000000000000000000000000..2daafffebfc817aefe8fcb96eaec25e65b3903e8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/core/generated/usage_pb2_grpc.py @@ -0,0 +1,4 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
+"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + diff --git a/.venv/lib/python3.11/site-packages/ray/core/src/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/src/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7211a7a661381ab303c858242e59c5e2b47ec2b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/src/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/src/plasma/__init__.py b/.venv/lib/python3.11/site-packages/ray/core/src/plasma/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/core/src/plasma/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/src/plasma/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a82f10832c3334e4c373f7042d5059bfd5db3227 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/src/plasma/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/src/ray/__init__.py b/.venv/lib/python3.11/site-packages/ray/core/src/ray/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/core/src/ray/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/src/ray/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bf328704c04015d68c9bddffd3a818964a54aa7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/src/ray/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/core/src/ray/raylet/__init__.py 
b/.venv/lib/python3.11/site-packages/ray/core/src/ray/raylet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/core/src/ray/raylet/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/core/src/ray/raylet/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c967fb4a148f570ff1665ffdbc196849dad1879f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/core/src/ray/raylet/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/__pycache__/install_and_start_prometheus.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/__pycache__/install_and_start_prometheus.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6efb11746030a546b118fc55d484205b0fc86803 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/__pycache__/install_and_start_prometheus.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/__init__.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..757d19d9f820c65274d31bb73be576804e7c8275 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/__pycache__/__init__.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/__pycache__/common.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/__pycache__/common.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b497494215e4905a2ee62aa5e3154a718248ac8e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/__pycache__/common.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/__pycache__/data_dashboard_panels.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/__pycache__/data_dashboard_panels.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9f1af18e341c65d71c888dae37d6448c3a14e75 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/__pycache__/data_dashboard_panels.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/__pycache__/default_dashboard_panels.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/__pycache__/default_dashboard_panels.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c35a0d557e5de29cf232a9f2a8ad1645a75e18d9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/__pycache__/default_dashboard_panels.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/__pycache__/serve_dashboard_panels.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/__pycache__/serve_dashboard_panels.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e714f588fa3077cf2356e39ab586396fc8e6d36b Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/__pycache__/serve_dashboard_panels.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/__pycache__/serve_deployment_dashboard_panels.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/__pycache__/serve_deployment_dashboard_panels.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc629425fbe2d744ab0a0567f3abdfbda7f31371 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/__pycache__/serve_deployment_dashboard_panels.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/data_grafana_dashboard_base.json b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/data_grafana_dashboard_base.json new file mode 100644 index 0000000000000000000000000000000000000000..961b87a128adaed185a9d99def7e71c10284e46c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/data_grafana_dashboard_base.json @@ -0,0 +1,147 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": "-- Grafana --", + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "type": "dashboard" + } + ] + }, + "editable": true, + "gnetId": null, + "graphTooltip": 0, + "iteration": 1667344411089, + "links": [], + "panels": [], + "refresh": false, + "schemaVersion": 27, + "style": "dark", + "tags": [], + "templating": { + "list": [ + { + "current": { + "selected": false + }, + "description": "Filter queries of a specific Prometheus type.", + "hide": 2, + "includeAll": false, + "multi": false, + "name": "datasource", + "options": [], + "query": "prometheus", + "refresh": 1, + "regex": "", + "skipUrlSync": false, + "type": "datasource" + }, + { + "allValue": ".+", + 
"current": { + "selected": false + }, + "datasource": "${datasource}", + "definition": "label_values(ray_data_allocated_bytes{{{global_filters}}}, SessionName)", + "description": "Filter queries to specific ray sessions.", + "error": null, + "hide": 0, + "includeAll": true, + "label": null, + "multi": false, + "name": "SessionName", + "options": [], + "query": { + "query": "label_values(ray_data_allocated_bytes{{{global_filters}}}, SessionName)", + "refId": "StandardVariableQuery" + }, + "refresh": 2, + "regex": "", + "skipUrlSync": false, + "sort": 2, + "tagValuesQuery": "", + "tags": [], + "tagsQuery": "", + "type": "query", + "useTags": false + }, + { + "allValue": ".+", + "current": { + "selected": true, + "text": [ + "All" + ], + "value": [ + "$__all" + ] + }, + "datasource": "${datasource}", + "definition": "label_values(ray_data_allocated_bytes{{{global_filters}}}, dataset)", + "description": null, + "error": null, + "hide": 0, + "includeAll": true, + "label": null, + "multi": true, + "name": "DatasetID", + "options": [], + "query": { + "query": "label_values(ray_data_allocated_bytes{{{global_filters}}}, dataset)", + "refId": "Prometheus-Dataset-Variable-Query" + }, + "refresh": 2, + "regex": "", + "skipUrlSync": false, + "sort": 0, + "tagValuesQuery": "", + "tags": [], + "tagsQuery": "", + "type": "query", + "useTags": false + }, + { + "current": { + "selected": false + }, + "datasource": "${datasource}", + "definition": "label_values(ray_node_network_receive_speed{{{global_filters}}}, ray_io_cluster)", + "description": "Filter queries to specific Ray clusters for KubeRay. When ingesting metrics across multiple ray clusters, the ray_io_cluster label should be set per cluster. 
For KubeRay users, this is done automaticaly with Prometheus PodMonitor.", + "error": null, + "hide": 0, + "includeAll": false, + "label": null, + "multi": false, + "name": "Cluster", + "options": [], + "query": { + "query": "label_values(ray_node_network_receive_speed{{{global_filters}}}, ray_io_cluster)", + "refId": "StandardVariableQuery" + }, + "refresh": 2, + "regex": "", + "skipUrlSync": false, + "sort": 2, + "tagValuesQuery": "", + "tags": [], + "tagsQuery": "", + "type": "query", + "useTags": false + } + ] + }, + "rayMeta": ["excludesSystemRoutes"], + "time": { + "from": "now-30m", + "to": "now" + }, + "timepicker": {}, + "timezone": "", + "title": "Data Dashboard", + "uid": "rayDataDashboard", + "version": 1 + } diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/default_grafana_dashboard_base.json b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/default_grafana_dashboard_base.json new file mode 100644 index 0000000000000000000000000000000000000000..b26c0cdb9ccbd8f23066b3fe429cd68403105726 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/metrics/dashboards/default_grafana_dashboard_base.json @@ -0,0 +1,142 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": "-- Grafana --", + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "type": "dashboard" + } + ] + }, + "editable": true, + "gnetId": null, + "graphTooltip": 0, + "iteration": 1667344411089, + "links": [], + "panels": [], + "refresh": false, + "schemaVersion": 27, + "style": "dark", + "tags": [], + "templating": { + "list": [ + { + "current": { + "selected": false + }, + "description": "Filter queries of a specific Prometheus type.", + "hide": 2, + "includeAll": false, + "multi": false, + "name": "datasource", + "options": [], + "query": "prometheus", + "refresh": 1, + "regex": "", + "skipUrlSync": false, + "type": "datasource" + }, + 
{ + "allValue": ".+", + "current": { + "selected": false + }, + "datasource": "${datasource}", + "definition": "label_values(ray_node_network_receive_speed{{{global_filters}}}, SessionName)", + "description": "Filter queries to specific ray sessions.", + "error": null, + "hide": 0, + "includeAll": true, + "label": null, + "multi": false, + "name": "SessionName", + "options": [], + "query": { + "query": "label_values(ray_node_network_receive_speed{{{global_filters}}}, SessionName)", + "refId": "StandardVariableQuery" + }, + "refresh": 2, + "regex": "", + "skipUrlSync": false, + "sort": 2, + "tagValuesQuery": "", + "tags": [], + "tagsQuery": "", + "type": "query", + "useTags": false + }, + { + "allValue": ".+", + "current": { + "selected": true, + "text": ["All"], + "value": ["$__all"] + }, + "datasource": "${datasource}", + "definition": "label_values(ray_node_network_receive_speed{{SessionName=~\"$SessionName\",{global_filters}}}, instance)", + "description": null, + "error": null, + "hide": 0, + "includeAll": true, + "label": null, + "multi": true, + "name": "Instance", + "options": [], + "query": { + "query": "label_values(ray_node_network_receive_speed{{SessionName=~\"$SessionName\",{global_filters}}}, instance)", + "refId": "Prometheus-Instance-Variable-Query" + }, + "refresh": 2, + "regex": "", + "skipUrlSync": false, + "sort": 0, + "tagValuesQuery": "", + "tags": [], + "tagsQuery": "", + "type": "query", + "useTags": false + }, + { + "current": { + "selected": false + }, + "datasource": "${datasource}", + "definition": "label_values(ray_node_network_receive_speed{{{global_filters}}}, ray_io_cluster)", + "description": "Filter queries to specific Ray clusters for KubeRay. When ingesting metrics across multiple ray clusters, the ray_io_cluster label should be set per cluster. 
For KubeRay users, this is done automaticaly with Prometheus PodMonitor.", + "error": null, + "hide": 0, + "includeAll": false, + "label": null, + "multi": false, + "name": "Cluster", + "options": [], + "query": { + "query": "label_values(ray_node_network_receive_speed{{{global_filters}}}, ray_io_cluster)", + "refId": "StandardVariableQuery" + }, + "refresh": 2, + "regex": "", + "skipUrlSync": false, + "sort": 2, + "tagValuesQuery": "", + "tags": [], + "tagsQuery": "", + "type": "query", + "useTags": false + } + ] + }, + "time": { + "from": "now-30m", + "to": "now" + }, + "timepicker": {}, + "timezone": "", + "title": "Default Dashboard", + "uid": "rayDefaultDashboard", + "version": 4 +} diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/__init__.py b/.venv/lib/python3.11/site-packages/ray/util/client/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fd253024ce8d662c62b5e1357aac1f4968bbb751 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/client/__init__.py @@ -0,0 +1,302 @@ +import logging +import os +import threading +from typing import Any, Dict, List, Optional, Tuple + +import ray._private.ray_constants as ray_constants +from ray._private.client_mode_hook import ( + _explicitly_disable_client_mode, + _explicitly_enable_client_mode, +) +from ray._private.ray_logging import setup_logger +from ray.job_config import JobConfig +from ray.util.annotations import DeveloperAPI +from ray._private.utils import check_version_info + + +logger = logging.getLogger(__name__) + + +class _ClientContext: + def __init__(self): + from ray.util.client.api import _ClientAPI + + self.api = _ClientAPI() + self.client_worker = None + self._server = None + self._connected_with_init = False + self._inside_client_test = False + + def connect( + self, + conn_str: str, + job_config: JobConfig = None, + secure: bool = False, + metadata: List[Tuple[str, str]] = None, + connection_retries: int = 3, + namespace: str = None, + *, + 
ignore_version: bool = False, + _credentials: Optional["grpc.ChannelCredentials"] = None, # noqa: F821 + ray_init_kwargs: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """Connect the Ray Client to a server. + + Args: + conn_str: Connection string, in the form "[host]:port" + job_config: The job config of the server. + secure: Whether to use a TLS secured gRPC channel + metadata: gRPC metadata to send on connect + connection_retries: number of connection attempts to make + ignore_version: whether to ignore Python or Ray version mismatches. + This should only be used for debugging purposes. + + Returns: + Dictionary of connection info, e.g., {"num_clients": 1}. + """ + # Delay imports until connect to avoid circular imports. + from ray.util.client.worker import Worker + + if self.client_worker is not None: + if self._connected_with_init: + return + raise Exception("ray.init() called, but ray client is already connected") + if not self._inside_client_test: + # If we're calling a client connect specifically and we're not + # currently in client mode, ensure we are. + _explicitly_enable_client_mode() + if namespace is not None: + job_config = job_config or JobConfig() + job_config.set_ray_namespace(namespace) + + logging_level = ray_constants.LOGGER_LEVEL + logging_format = ray_constants.LOGGER_FORMAT + + if ray_init_kwargs is None: + ray_init_kwargs = {} + + # NOTE(architkulkarni): env_hook is not supported with Ray Client. 
+ ray_init_kwargs["_skip_env_hook"] = True + + if ray_init_kwargs.get("logging_level") is not None: + logging_level = ray_init_kwargs["logging_level"] + if ray_init_kwargs.get("logging_format") is not None: + logging_format = ray_init_kwargs["logging_format"] + + setup_logger(logging_level, logging_format) + + try: + self.client_worker = Worker( + conn_str, + secure=secure, + _credentials=_credentials, + metadata=metadata, + connection_retries=connection_retries, + ) + self.api.worker = self.client_worker + self.client_worker._server_init(job_config, ray_init_kwargs) + conn_info = self.client_worker.connection_info() + self._check_versions(conn_info, ignore_version) + self._register_serializers() + return conn_info + except Exception: + self.disconnect() + raise + + def _register_serializers(self): + """Register the custom serializer addons at the client side. + + The server side should have already registered the serializers via + regular worker's serialization_context mechanism. + """ + import ray.util.serialization_addons + from ray.util.serialization import StandaloneSerializationContext + + ctx = StandaloneSerializationContext() + ray.util.serialization_addons.apply(ctx) + + def _check_versions(self, conn_info: Dict[str, Any], ignore_version: bool) -> None: + # conn_info has "python_version" and "ray_version" so it can be used to compare. + ignore_version = ignore_version or ("RAY_IGNORE_VERSION_MISMATCH" in os.environ) + check_version_info( + conn_info, + "Ray Client", + raise_on_mismatch=not ignore_version, + python_version_match_level="minor", + ) + + def disconnect(self): + """Disconnect the Ray Client.""" + from ray.util.client.api import _ClientAPI + + if self.client_worker is not None: + self.client_worker.close() + self.api = _ClientAPI() + self.client_worker = None + + # remote can be called outside of a connection, which is why it + # exists on the same API layer as connect() itself. 
+ def remote(self, *args, **kwargs): + """remote is the hook stub passed on to replace `ray.remote`. + + This sets up remote functions or actors, as the decorator, + but does not execute them. + + Args: + args: opaque arguments + kwargs: opaque keyword arguments + """ + return self.api.remote(*args, **kwargs) + + def __getattr__(self, key: str): + if self.is_connected(): + return getattr(self.api, key) + elif key in ["is_initialized", "_internal_kv_initialized"]: + # Client is not connected, thus Ray is not considered initialized. + return lambda: False + else: + raise Exception( + "Ray Client is not connected. Please connect by calling `ray.init`." + ) + + def is_connected(self) -> bool: + if self.client_worker is None: + return False + return self.client_worker.is_connected() + + def init(self, *args, **kwargs): + if self._server is not None: + raise Exception("Trying to start two instances of ray via client") + import ray.util.client.server.server as ray_client_server + + server_handle, address_info = ray_client_server.init_and_serve( + "127.0.0.1:50051", *args, **kwargs + ) + self._server = server_handle.grpc_server + self.connect("127.0.0.1:50051") + self._connected_with_init = True + return address_info + + def shutdown(self, _exiting_interpreter=False): + self.disconnect() + import ray.util.client.server.server as ray_client_server + + if self._server is None: + return + ray_client_server.shutdown_with_server(self._server, _exiting_interpreter) + self._server = None + + +# All connected context will be put here +# This struct will be guarded by a lock for thread safety +_all_contexts = set() +_lock = threading.Lock() + +# This is the default context which is used when allow_multiple is not True +_default_context = _ClientContext() + + +@DeveloperAPI +class RayAPIStub: + """This class stands in as the replacement API for the `import ray` module. + + Much like the ray module, this mostly delegates the work to the + _client_worker. 
As parts of the ray API are covered, they are piped through + here or on the client worker API. + """ + + def __init__(self): + self._cxt = threading.local() + self._cxt.handler = _default_context + self._inside_client_test = False + + def get_context(self): + try: + return self._cxt.__getattribute__("handler") + except AttributeError: + self._cxt.handler = _default_context + return self._cxt.handler + + def set_context(self, cxt): + old_cxt = self.get_context() + if cxt is None: + self._cxt.handler = _ClientContext() + else: + self._cxt.handler = cxt + return old_cxt + + def is_default(self): + return self.get_context() == _default_context + + def connect(self, *args, **kw_args): + self.get_context()._inside_client_test = self._inside_client_test + conn = self.get_context().connect(*args, **kw_args) + global _lock, _all_contexts + with _lock: + _all_contexts.add(self._cxt.handler) + return conn + + def disconnect(self, *args, **kw_args): + global _lock, _all_contexts, _default_context + with _lock: + if _default_context == self.get_context(): + for cxt in _all_contexts: + cxt.disconnect(*args, **kw_args) + _all_contexts = set() + else: + self.get_context().disconnect(*args, **kw_args) + if self.get_context() in _all_contexts: + _all_contexts.remove(self.get_context()) + if len(_all_contexts) == 0: + _explicitly_disable_client_mode() + + def remote(self, *args, **kwargs): + return self.get_context().remote(*args, **kwargs) + + def __getattr__(self, name): + return self.get_context().__getattr__(name) + + def is_connected(self, *args, **kwargs): + return self.get_context().is_connected(*args, **kwargs) + + def init(self, *args, **kwargs): + ret = self.get_context().init(*args, **kwargs) + global _lock, _all_contexts + with _lock: + _all_contexts.add(self._cxt.handler) + return ret + + def shutdown(self, *args, **kwargs): + global _lock, _all_contexts + with _lock: + if _default_context == self.get_context(): + for cxt in _all_contexts: + cxt.shutdown(*args, 
**kwargs) + _all_contexts = set() + else: + self.get_context().shutdown(*args, **kwargs) + if self.get_context() in _all_contexts: + _all_contexts.remove(self.get_context()) + if len(_all_contexts) == 0: + _explicitly_disable_client_mode() + + +ray = RayAPIStub() + + +@DeveloperAPI +def num_connected_contexts(): + """Return the number of client connections active.""" + global _lock, _all_contexts + with _lock: + return len(_all_contexts) + + +# Someday we might add methods in this module so that someone who +# tries to `import ray_client as ray` -- as a module, instead of +# `from ray_client import ray` -- as the API stub +# still gets expected functionality. This is the way the ray package +# worked in the past. +# +# This really calls for PEP 562: https://www.python.org/dev/peps/pep-0562/ +# But until Python 3.6 is EOL, here we are. diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d50fc8f49227f5dec287e9d74357c8e9f6ea25a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/api.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/api.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..659452440ba91caa427b5badf8ac565fda8a182d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/api.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/client_app.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/client_app.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10fe9db8fea97712baa9dd6a78d8c265e3807be4 Binary files 
/dev/null and b/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/client_app.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/client_pickler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/client_pickler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be10db9d2b763ed86a9e8b5669ec10307bb5bc16 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/client_pickler.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/common.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/common.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e38afd3de5b88b471c90e9e16f14ad9e5ec4185 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/common.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/dataclient.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/dataclient.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90defb8a53d6ba5c6422e7e0f5c374249c084b2a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/dataclient.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/logsclient.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/logsclient.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bdb1a93696fc59aa94feaedad95daa650a7b4927 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/logsclient.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/options.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/options.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aaea47530313ecb65277a90459fc2d0adbf32ae0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/options.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/ray_client_helpers.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/ray_client_helpers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a542600e0b3694ec405aa7ef4a446766f194a967 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/ray_client_helpers.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/runtime_context.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/runtime_context.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5fa61af3b3129139b34c009b1229a5247c96000 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/runtime_context.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/worker.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/worker.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..219a77cc9292defb7fbfd655bb14d7de510f3857 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/client/__pycache__/worker.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/client_pickler.py b/.venv/lib/python3.11/site-packages/ray/util/client/client_pickler.py new file mode 100644 index 0000000000000000000000000000000000000000..4971c0e11f96b2497601e88c265b31b8e0a542a5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/client/client_pickler.py @@ 
-0,0 +1,178 @@ +"""Implements the client side of the client/server pickling protocol. + +All ray client client/server data transfer happens through this pickling +protocol. The model is as follows: + + * All Client objects (eg ClientObjectRef) always live on the client and + are never represented in the server + * All Ray objects (eg, ray.ObjectRef) always live on the server and are + never returned to the client + * In order to translate between these two references, PickleStub tuples + are generated as persistent ids in the data blobs during the pickling + and unpickling of these objects. + +The PickleStubs have just enough information to find or generate their +associated partner object on either side. + +This also has the advantage of avoiding predefined pickle behavior for ray +objects, which may include ray internal reference counting. + +ClientPickler dumps things from the client into the appropriate stubs +ServerUnpickler loads stubs from the server into their client counterparts. +""" + +import io + +from typing import NamedTuple +from typing import Any +from typing import Dict +from typing import Optional + +import ray.cloudpickle as cloudpickle +from ray.util.client import RayAPIStub +from ray.util.client.common import ClientObjectRef +from ray.util.client.common import ClientActorHandle +from ray.util.client.common import ClientActorRef +from ray.util.client.common import ClientActorClass +from ray.util.client.common import ClientRemoteFunc +from ray.util.client.common import ClientRemoteMethod +from ray.util.client.common import OptionWrapper +from ray.util.client.common import InProgressSentinel +import ray.core.generated.ray_client_pb2 as ray_client_pb2 + +import pickle # noqa: F401 + + +# NOTE(barakmich): These PickleStubs are really close to +# the data for an execution, with no arguments. Combine the two? 
+class PickleStub( + NamedTuple( + "PickleStub", + [ + ("type", str), + ("client_id", str), + ("ref_id", bytes), + ("name", Optional[str]), + ("baseline_options", Optional[Dict]), + ], + ) +): + def __reduce__(self): + # PySpark's namedtuple monkey patch breaks compatibility with + # cloudpickle. Thus we revert this patch here if it exists. + return object.__reduce__(self) + + +class ClientPickler(cloudpickle.CloudPickler): + def __init__(self, client_id, *args, **kwargs): + super().__init__(*args, **kwargs) + self.client_id = client_id + + def persistent_id(self, obj): + if isinstance(obj, RayAPIStub): + return PickleStub( + type="Ray", + client_id=self.client_id, + ref_id=b"", + name=None, + baseline_options=None, + ) + elif isinstance(obj, ClientObjectRef): + return PickleStub( + type="Object", + client_id=self.client_id, + ref_id=obj.id, + name=None, + baseline_options=None, + ) + elif isinstance(obj, ClientActorHandle): + return PickleStub( + type="Actor", + client_id=self.client_id, + ref_id=obj._actor_id.id, + name=None, + baseline_options=None, + ) + elif isinstance(obj, ClientRemoteFunc): + if obj._ref is None: + obj._ensure_ref() + if type(obj._ref) is InProgressSentinel: + return PickleStub( + type="RemoteFuncSelfReference", + client_id=self.client_id, + ref_id=obj._client_side_ref.id, + name=None, + baseline_options=None, + ) + return PickleStub( + type="RemoteFunc", + client_id=self.client_id, + ref_id=obj._ref.id, + name=None, + baseline_options=obj._options, + ) + elif isinstance(obj, ClientActorClass): + if obj._ref is None: + obj._ensure_ref() + if type(obj._ref) is InProgressSentinel: + return PickleStub( + type="RemoteActorSelfReference", + client_id=self.client_id, + ref_id=obj._client_side_ref.id, + name=None, + baseline_options=None, + ) + return PickleStub( + type="RemoteActor", + client_id=self.client_id, + ref_id=obj._ref.id, + name=None, + baseline_options=obj._options, + ) + elif isinstance(obj, ClientRemoteMethod): + return PickleStub( + 
type="RemoteMethod", + client_id=self.client_id, + ref_id=obj._actor_handle.actor_ref.id, + name=obj._method_name, + baseline_options=None, + ) + elif isinstance(obj, OptionWrapper): + raise NotImplementedError("Sending a partial option is unimplemented") + return None + + +class ServerUnpickler(pickle.Unpickler): + def persistent_load(self, pid): + assert isinstance(pid, PickleStub) + if pid.type == "Object": + return ClientObjectRef(pid.ref_id) + elif pid.type == "Actor": + return ClientActorHandle(ClientActorRef(pid.ref_id)) + else: + raise NotImplementedError("Being passed back an unknown stub") + + +def dumps_from_client(obj: Any, client_id: str, protocol=None) -> bytes: + with io.BytesIO() as file: + cp = ClientPickler(client_id, file, protocol=protocol) + cp.dump(obj) + return file.getvalue() + + +def loads_from_server( + data: bytes, *, fix_imports=True, encoding="ASCII", errors="strict" +) -> Any: + if isinstance(data, str): + raise TypeError("Can't load pickle from unicode string") + file = io.BytesIO(data) + return ServerUnpickler( + file, fix_imports=fix_imports, encoding=encoding, errors=errors + ).load() + + +def convert_to_arg(val: Any, client_id: str) -> ray_client_pb2.Arg: + out = ray_client_pb2.Arg() + out.local = ray_client_pb2.Arg.Locality.INTERNED + out.data = dumps_from_client(val, client_id) + return out diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/logsclient.py b/.venv/lib/python3.11/site-packages/ray/util/client/logsclient.py new file mode 100644 index 0000000000000000000000000000000000000000..b4d9a6af992855cf44061c50690a7d5a3575920d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/client/logsclient.py @@ -0,0 +1,136 @@ +"""This file implements a threaded stream controller to return logs back from +the ray clientserver. 
+""" +import sys +import logging +import queue +import threading +import time +import grpc + +from typing import TYPE_CHECKING + +import ray.core.generated.ray_client_pb2 as ray_client_pb2 +import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc + +from ray.util.debug import log_once + +if TYPE_CHECKING: + from ray.util.client.worker import Worker + +logger = logging.getLogger(__name__) +# TODO(barakmich): Running a logger in a logger causes loopback. +# The client logger need its own root -- possibly this one. +# For the moment, let's just not propogate beyond this point. +logger.propagate = False + + +class LogstreamClient: + def __init__(self, client_worker: "Worker", metadata: list): + """Initializes a thread-safe log stream over a Ray Client gRPC channel. + + Args: + client_worker: The Ray Client worker that manages this client + metadata: metadata to pass to gRPC requests + """ + self.client_worker = client_worker + self._metadata = metadata + self.request_queue = queue.Queue() + self.log_thread = self._start_logthread() + self.log_thread.start() + self.last_req = None + + def _start_logthread(self) -> threading.Thread: + return threading.Thread(target=self._log_main, args=(), daemon=True) + + def _log_main(self) -> None: + reconnecting = False + while not self.client_worker._in_shutdown: + if reconnecting: + # Refresh queue and retry last request + self.request_queue = queue.Queue() + if self.last_req: + self.request_queue.put(self.last_req) + stub = ray_client_pb2_grpc.RayletLogStreamerStub(self.client_worker.channel) + try: + log_stream = stub.Logstream( + iter(self.request_queue.get, None), metadata=self._metadata + ) + except ValueError: + # Trying to use the stub on a cancelled channel will raise + # ValueError. This should only happen when the data client + # is attempting to reset the connection -- sleep and try + # again. 
+ time.sleep(0.5) + continue + try: + for record in log_stream: + if record.level < 0: + self.stdstream(level=record.level, msg=record.msg) + self.log(level=record.level, msg=record.msg) + return + except grpc.RpcError as e: + reconnecting = self._process_rpc_error(e) + if not reconnecting: + return + + def _process_rpc_error(self, e: grpc.RpcError) -> bool: + """ + Processes RPC errors that occur while reading from data stream. + Returns True if the error can be recovered from, False otherwise. + """ + if self.client_worker._can_reconnect(e): + if log_once("lost_reconnect_logs"): + logger.warning( + "Log channel is reconnecting. Logs produced while " + "the connection was down can be found on the head " + "node of the cluster in " + "`ray_client_server_[port].out`" + ) + logger.debug("Log channel dropped, retrying.") + time.sleep(0.5) + return True + logger.debug("Shutting down log channel.") + if not self.client_worker._in_shutdown: + logger.exception("Unexpected exception:") + return False + + def log(self, level: int, msg: str): + """Log the message from the log stream. + By default, calls logger.log but this can be overridden. + + Args: + level: The loglevel of the received log message + msg: The content of the message + """ + logger.log(level=level, msg=msg) + + def stdstream(self, level: int, msg: str): + """Log the stdout/stderr entry from the log stream. + By default, calls print but this can be overridden. 
+ + Args: + level: The loglevel of the received log message + msg: The content of the message + """ + print_file = sys.stderr if level == -2 else sys.stdout + print(msg, file=print_file, end="") + + def set_logstream_level(self, level: int): + logger.setLevel(level) + req = ray_client_pb2.LogSettingsRequest() + req.enabled = True + req.loglevel = level + self.request_queue.put(req) + self.last_req = req + + def close(self) -> None: + self.request_queue.put(None) + if self.log_thread is not None: + self.log_thread.join() + + def disable_logs(self) -> None: + req = ray_client_pb2.LogSettingsRequest() + req.enabled = False + self.request_queue.put(req) + self.last_req = req diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/server/__init__.py b/.venv/lib/python3.11/site-packages/ray/util/client/server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..37c7767bb0cf1430118b672b9ace20dac6eded17 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/client/server/__init__.py @@ -0,0 +1 @@ +from ray.util.client.server.server import serve # noqa diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/server/__main__.py b/.venv/lib/python3.11/site-packages/ray/util/client/server/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..b582d767cd21e8778ac909d7a0ab8eb8c4f53450 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/client/server/__main__.py @@ -0,0 +1,4 @@ +if __name__ == "__main__": + from ray.util.client.server.server import main + + main() diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/server/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/client/server/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6867aa53cd2a65da6676064df9e6614b1fd935e Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/util/client/server/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/server/__pycache__/__main__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/client/server/__pycache__/__main__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47b0daece83f52741022e5ad6ffcdc69c5905d96 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/client/server/__pycache__/__main__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/server/__pycache__/dataservicer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/client/server/__pycache__/dataservicer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..311a410fc8e6d529340af8401236dcb58128f4bb Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/client/server/__pycache__/dataservicer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/server/__pycache__/logservicer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/client/server/__pycache__/logservicer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ee5cd8a5a7bbae87056da2925502acb07319409 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/client/server/__pycache__/logservicer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/server/__pycache__/proxier.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/client/server/__pycache__/proxier.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18cf6ded822185711845ee084c56e90db0964e09 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/client/server/__pycache__/proxier.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/ray/util/client/server/__pycache__/server.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/client/server/__pycache__/server.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35a0881999de0594f2e6beffb9a610dbe94a5bb5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/client/server/__pycache__/server.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/server/__pycache__/server_pickler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/client/server/__pycache__/server_pickler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd1bede8adf407245e1fa796e5c00018c89090ba Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/client/server/__pycache__/server_pickler.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/server/__pycache__/server_stubs.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/client/server/__pycache__/server_stubs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22ee13b94109d45e3b90297cf3430bcc624ca755 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/client/server/__pycache__/server_stubs.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/server/dataservicer.py b/.venv/lib/python3.11/site-packages/ray/util/client/server/dataservicer.py new file mode 100644 index 0000000000000000000000000000000000000000..9ce816856e4df366b6cf7fc4544d9d93ddf7d28a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/client/server/dataservicer.py @@ -0,0 +1,416 @@ +from collections import defaultdict +from ray.util.client.server.server_pickler import loads_from_client +import ray +import logging +import grpc +from queue import Queue +import sys + +from typing import Any, Dict, Iterator, TYPE_CHECKING, Union +from threading 
import Event, Lock, Thread +import time + +import ray.core.generated.ray_client_pb2 as ray_client_pb2 +import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc +from ray.util.client.common import ( + CLIENT_SERVER_MAX_THREADS, + _propagate_error_in_context, + OrderedResponseCache, +) +from ray.util.debug import log_once +from ray._private.client_mode_hook import disable_client_hook + +if TYPE_CHECKING: + from ray.util.client.server.server import RayletServicer + +logger = logging.getLogger(__name__) + +QUEUE_JOIN_SECONDS = 10 + + +def _get_reconnecting_from_context(context: Any) -> bool: + """ + Get `reconnecting` from gRPC metadata, or False if missing. + """ + metadata = {k: v for k, v in context.invocation_metadata()} + val = metadata.get("reconnecting") + if val is None or val not in ("True", "False"): + logger.error( + f'Client connecting with invalid value for "reconnecting": {val}, ' + "This may be because you have a mismatched client and server " + "version." + ) + return False + return val == "True" + + +def _should_cache(req: ray_client_pb2.DataRequest) -> bool: + """ + Returns True if the response should to the given request should be cached, + false otherwise. At the moment the only requests we do not cache are: + - asynchronous gets: These arrive out of order. Skipping caching here + is fine, since repeating an async get is idempotent + - acks: Repeating acks is idempotent + - clean up requests: Also idempotent, and client has likely already + wrapped up the data connection by this point. 
+ - puts: We should only cache when we receive the final chunk, since + any earlier chunks won't generate a response + - tasks: We should only cache when we receive the final chunk, + since any earlier chunks won't generate a response + """ + req_type = req.WhichOneof("type") + if req_type == "get" and req.get.asynchronous: + return False + if req_type == "put": + return req.put.chunk_id == req.put.total_chunks - 1 + if req_type == "task": + return req.task.chunk_id == req.task.total_chunks - 1 + return req_type not in ("acknowledge", "connection_cleanup") + + +def fill_queue( + grpc_input_generator: Iterator[ray_client_pb2.DataRequest], + output_queue: "Queue[Union[ray_client_pb2.DataRequest, ray_client_pb2.DataResponse]]", # noqa: E501 +) -> None: + """ + Pushes incoming requests to a shared output_queue. + """ + try: + for req in grpc_input_generator: + output_queue.put(req) + except grpc.RpcError as e: + logger.debug( + "closing dataservicer reader thread " + f"grpc error reading request_iterator: {e}" + ) + finally: + # Set the sentinel value for the output_queue + output_queue.put(None) + + +class ChunkCollector: + """ + Helper class for collecting chunks from PutObject or ClientTask messages + """ + + def __init__(self): + self.curr_req_id = None + self.last_seen_chunk_id = -1 + self.data = bytearray() + + def add_chunk( + self, + req: ray_client_pb2.DataRequest, + chunk: Union[ray_client_pb2.PutRequest, ray_client_pb2.ClientTask], + ): + if self.curr_req_id is not None and self.curr_req_id != req.req_id: + raise RuntimeError( + "Expected to receive a chunk from request with id " + f"{self.curr_req_id}, but found {req.req_id} instead." + ) + self.curr_req_id = req.req_id + next_chunk = self.last_seen_chunk_id + 1 + if chunk.chunk_id < next_chunk: + # Repeated chunk, ignore + return + if chunk.chunk_id > next_chunk: + raise RuntimeError( + f"A chunk {chunk.chunk_id} of request {req.req_id} was " + "received out of order." 
+ ) + elif chunk.chunk_id == self.last_seen_chunk_id + 1: + self.data.extend(chunk.data) + self.last_seen_chunk_id = chunk.chunk_id + return chunk.chunk_id + 1 == chunk.total_chunks + + def reset(self): + self.curr_req_id = None + self.last_seen_chunk_id = -1 + self.data = bytearray() + + +class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer): + def __init__(self, basic_service: "RayletServicer"): + self.basic_service = basic_service + self.clients_lock = Lock() + self.num_clients = 0 # guarded by self.clients_lock + # dictionary mapping client_id's to the last time they connected + self.client_last_seen: Dict[str, float] = {} + # dictionary mapping client_id's to their reconnect grace periods + self.reconnect_grace_periods: Dict[str, float] = {} + # dictionary mapping client_id's to their response cache + self.response_caches: Dict[str, OrderedResponseCache] = defaultdict( + OrderedResponseCache + ) + # stopped event, useful for signals that the server is shut down + self.stopped = Event() + # Helper for collecting chunks from PutObject calls. Assumes that + # that put requests from different objects aren't interleaved. + self.put_request_chunk_collector = ChunkCollector() + # Helper for collecting chunks from ClientTask calls. Assumes that + # schedule requests from different remote calls aren't interleaved. 
+ self.client_task_chunk_collector = ChunkCollector() + + def Datapath(self, request_iterator, context): + start_time = time.time() + # set to True if client shuts down gracefully + cleanup_requested = False + metadata = {k: v for k, v in context.invocation_metadata()} + client_id = metadata.get("client_id") + if client_id is None: + logger.error("Client connecting with no client_id") + return + logger.debug(f"New data connection from client {client_id}: ") + accepted_connection = self._init(client_id, context, start_time) + response_cache = self.response_caches[client_id] + # Set to False if client requests a reconnect grace period of 0 + reconnect_enabled = True + if not accepted_connection: + return + try: + request_queue = Queue() + queue_filler_thread = Thread( + target=fill_queue, daemon=True, args=(request_iterator, request_queue) + ) + queue_filler_thread.start() + """For non `async get` requests, this loop yields immediately + For `async get` requests, this loop: + 1) does not yield, it just continues + 2) When the result is ready, it yields + """ + for req in iter(request_queue.get, None): + if isinstance(req, ray_client_pb2.DataResponse): + # Early shortcut if this is the result of an async get. 
+ yield req + continue + + assert isinstance(req, ray_client_pb2.DataRequest) + if _should_cache(req) and reconnect_enabled: + cached_resp = response_cache.check_cache(req.req_id) + if isinstance(cached_resp, Exception): + # Cache state is invalid, raise exception + raise cached_resp + if cached_resp is not None: + yield cached_resp + continue + + resp = None + req_type = req.WhichOneof("type") + if req_type == "init": + resp_init = self.basic_service.Init(req.init) + resp = ray_client_pb2.DataResponse( + init=resp_init, + ) + with self.clients_lock: + self.reconnect_grace_periods[ + client_id + ] = req.init.reconnect_grace_period + if req.init.reconnect_grace_period == 0: + reconnect_enabled = False + + elif req_type == "get": + if req.get.asynchronous: + get_resp = self.basic_service._async_get_object( + req.get, client_id, req.req_id, request_queue + ) + if get_resp is None: + # Skip sending a response for this request and + # continue to the next requst. The response for + # this request will be sent when the object is + # ready. 
+ continue + else: + get_resp = self.basic_service._get_object(req.get, client_id) + resp = ray_client_pb2.DataResponse(get=get_resp) + elif req_type == "put": + if not self.put_request_chunk_collector.add_chunk(req, req.put): + # Put request still in progress + continue + put_resp = self.basic_service._put_object( + self.put_request_chunk_collector.data, + req.put.client_ref_id, + client_id, + req.put.owner_id, + ) + self.put_request_chunk_collector.reset() + resp = ray_client_pb2.DataResponse(put=put_resp) + elif req_type == "release": + released = [] + for rel_id in req.release.ids: + rel = self.basic_service.release(client_id, rel_id) + released.append(rel) + resp = ray_client_pb2.DataResponse( + release=ray_client_pb2.ReleaseResponse(ok=released) + ) + elif req_type == "connection_info": + resp = ray_client_pb2.DataResponse( + connection_info=self._build_connection_response() + ) + elif req_type == "prep_runtime_env": + with self.clients_lock: + resp_prep = self.basic_service.PrepRuntimeEnv( + req.prep_runtime_env + ) + resp = ray_client_pb2.DataResponse(prep_runtime_env=resp_prep) + elif req_type == "connection_cleanup": + cleanup_requested = True + cleanup_resp = ray_client_pb2.ConnectionCleanupResponse() + resp = ray_client_pb2.DataResponse(connection_cleanup=cleanup_resp) + elif req_type == "acknowledge": + # Clean up acknowledged cache entries + response_cache.cleanup(req.acknowledge.req_id) + continue + elif req_type == "task": + with self.clients_lock: + task = req.task + if not self.client_task_chunk_collector.add_chunk(req, task): + # Not all serialized arguments have arrived + continue + arglist, kwargs = loads_from_client( + self.client_task_chunk_collector.data, self.basic_service + ) + self.client_task_chunk_collector.reset() + resp_ticket = self.basic_service.Schedule( + req.task, arglist, kwargs, context + ) + resp = ray_client_pb2.DataResponse(task_ticket=resp_ticket) + del arglist + del kwargs + elif req_type == "terminate": + with 
self.clients_lock: + response = self.basic_service.Terminate(req.terminate, context) + resp = ray_client_pb2.DataResponse(terminate=response) + elif req_type == "list_named_actors": + with self.clients_lock: + response = self.basic_service.ListNamedActors( + req.list_named_actors + ) + resp = ray_client_pb2.DataResponse(list_named_actors=response) + else: + raise Exception( + f"Unreachable code: Request type " + f"{req_type} not handled in Datapath" + ) + resp.req_id = req.req_id + if _should_cache(req) and reconnect_enabled: + response_cache.update_cache(req.req_id, resp) + yield resp + except Exception as e: + logger.exception("Error in data channel:") + recoverable = _propagate_error_in_context(e, context) + invalid_cache = response_cache.invalidate(e) + if not recoverable or invalid_cache: + context.set_code(grpc.StatusCode.FAILED_PRECONDITION) + # Connection isn't recoverable, skip cleanup + cleanup_requested = True + finally: + logger.debug(f"Stream is broken with client {client_id}") + queue_filler_thread.join(QUEUE_JOIN_SECONDS) + if queue_filler_thread.is_alive(): + logger.error( + "Queue filler thread failed to join before timeout: {}".format( + QUEUE_JOIN_SECONDS + ) + ) + cleanup_delay = self.reconnect_grace_periods.get(client_id) + if not cleanup_requested and cleanup_delay is not None: + logger.debug( + "Cleanup wasn't requested, delaying cleanup by" + f"{cleanup_delay} seconds." + ) + # Delay cleanup, since client may attempt a reconnect + # Wait on the "stopped" event in case the grpc server is + # stopped and we can clean up earlier. + self.stopped.wait(timeout=cleanup_delay) + else: + logger.debug("Cleanup was requested, cleaning up immediately.") + with self.clients_lock: + if client_id not in self.client_last_seen: + logger.debug("Connection already cleaned up.") + # Some other connection has already cleaned up this + # this client's session. This can happen if the client + # reconnects and then gracefully shut's down immediately. 
+ return + last_seen = self.client_last_seen[client_id] + if last_seen > start_time: + # The client successfully reconnected and updated + # last seen some time during the grace period + logger.debug("Client reconnected, skipping cleanup") + return + # Either the client shut down gracefully, or the client + # failed to reconnect within the grace period. Clean up + # the connection. + self.basic_service.release_all(client_id) + del self.client_last_seen[client_id] + if client_id in self.reconnect_grace_periods: + del self.reconnect_grace_periods[client_id] + if client_id in self.response_caches: + del self.response_caches[client_id] + self.num_clients -= 1 + logger.debug( + f"Removed client {client_id}, " f"remaining={self.num_clients}" + ) + + # It's important to keep the Ray shutdown + # within this locked context or else Ray could hang. + # NOTE: it is strange to start ray in server.py but shut it + # down here. Consider consolidating ray lifetime management. + with disable_client_hook(): + if self.num_clients == 0: + logger.debug("Shutting down ray.") + ray.shutdown() + + def _init(self, client_id: str, context: Any, start_time: float): + """ + Checks if resources allow for another client. + Returns a boolean indicating if initialization was successful. + """ + with self.clients_lock: + reconnecting = _get_reconnecting_from_context(context) + threshold = int(CLIENT_SERVER_MAX_THREADS / 2) + if self.num_clients >= threshold: + logger.warning( + f"[Data Servicer]: Num clients {self.num_clients} " + f"has reached the threshold {threshold}. " + f"Rejecting client: {client_id}. " + ) + if log_once("client_threshold"): + logger.warning( + "You can configure the client connection " + "threshold by setting the " + "RAY_CLIENT_SERVER_MAX_THREADS env var " + f"(currently set to {CLIENT_SERVER_MAX_THREADS})." 
+ ) + context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED) + return False + if reconnecting and client_id not in self.client_last_seen: + # Client took too long to reconnect, session has been + # cleaned up. + context.set_code(grpc.StatusCode.NOT_FOUND) + context.set_details( + "Attempted to reconnect to a session that has already " + "been cleaned up." + ) + return False + if client_id in self.client_last_seen: + logger.debug(f"Client {client_id} has reconnected.") + else: + self.num_clients += 1 + logger.debug( + f"Accepted data connection from {client_id}. " + f"Total clients: {self.num_clients}" + ) + self.client_last_seen[client_id] = start_time + return True + + def _build_connection_response(self): + with self.clients_lock: + cur_num_clients = self.num_clients + return ray_client_pb2.ConnectionInfoResponse( + num_clients=cur_num_clients, + python_version="{}.{}.{}".format( + sys.version_info[0], sys.version_info[1], sys.version_info[2] + ), + ray_version=ray.__version__, + ray_commit=ray.__commit__, + ) diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/server/logservicer.py b/.venv/lib/python3.11/site-packages/ray/util/client/server/logservicer.py new file mode 100644 index 0000000000000000000000000000000000000000..764e6c82c65347b17ef9f76735f1d33636e27de9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/client/server/logservicer.py @@ -0,0 +1,125 @@ +"""This file responds to log stream requests and forwards logs +with its handler. 
+""" +import io +import logging +import queue +import threading +import uuid + +import grpc + +import ray.core.generated.ray_client_pb2 as ray_client_pb2 +import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc +from ray._private.ray_logging import global_worker_stdstream_dispatcher +from ray._private.worker import print_worker_logs +from ray.util.client.common import CLIENT_SERVER_MAX_THREADS + +logger = logging.getLogger(__name__) + + +class LogstreamHandler(logging.Handler): + def __init__(self, queue, level): + super().__init__() + self.queue = queue + self.level = level + + def emit(self, record: logging.LogRecord): + logdata = ray_client_pb2.LogData() + logdata.msg = record.getMessage() + logdata.level = record.levelno + logdata.name = record.name + self.queue.put(logdata) + + +class StdStreamHandler: + def __init__(self, queue): + self.queue = queue + self.id = str(uuid.uuid4()) + + def handle(self, data): + logdata = ray_client_pb2.LogData() + logdata.level = -2 if data["is_err"] else -1 + logdata.name = "stderr" if data["is_err"] else "stdout" + with io.StringIO() as file: + print_worker_logs(data, file) + logdata.msg = file.getvalue() + self.queue.put(logdata) + + def register_global(self): + global_worker_stdstream_dispatcher.add_handler(self.id, self.handle) + + def unregister_global(self): + global_worker_stdstream_dispatcher.remove_handler(self.id) + + +def log_status_change_thread(log_queue, request_iterator): + std_handler = StdStreamHandler(log_queue) + current_handler = None + root_logger = logging.getLogger("ray") + default_level = root_logger.getEffectiveLevel() + try: + for req in request_iterator: + if current_handler is not None: + root_logger.setLevel(default_level) + root_logger.removeHandler(current_handler) + std_handler.unregister_global() + if not req.enabled: + current_handler = None + continue + current_handler = LogstreamHandler(log_queue, req.loglevel) + std_handler.register_global() + 
root_logger.addHandler(current_handler) + root_logger.setLevel(req.loglevel) + except grpc.RpcError as e: + logger.debug(f"closing log thread " f"grpc error reading request_iterator: {e}") + finally: + if current_handler is not None: + root_logger.setLevel(default_level) + root_logger.removeHandler(current_handler) + std_handler.unregister_global() + log_queue.put(None) + + +class LogstreamServicer(ray_client_pb2_grpc.RayletLogStreamerServicer): + def __init__(self): + super().__init__() + self.num_clients = 0 + self.client_lock = threading.Lock() + + def Logstream(self, request_iterator, context): + initialized = False + with self.client_lock: + threshold = CLIENT_SERVER_MAX_THREADS / 2 + if self.num_clients + 1 >= threshold: + context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED) + logger.warning( + f"Logstream: Num clients {self.num_clients} has reached " + f"the threshold {threshold}. Rejecting new connection." + ) + return + self.num_clients += 1 + initialized = True + logger.info( + "New logs connection established. 
" f"Total clients: {self.num_clients}" + ) + log_queue = queue.Queue() + thread = threading.Thread( + target=log_status_change_thread, + args=(log_queue, request_iterator), + daemon=True, + ) + thread.start() + try: + queue_iter = iter(log_queue.get, None) + for record in queue_iter: + if record is None: + break + yield record + except grpc.RpcError as e: + logger.debug(f"Closing log channel: {e}") + finally: + thread.join() + with self.client_lock: + if initialized: + self.num_clients -= 1 diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/server/proxier.py b/.venv/lib/python3.11/site-packages/ray/util/client/server/proxier.py new file mode 100644 index 0000000000000000000000000000000000000000..6f350ebb27a7750ecbeae22cccb3a0ff27ed53c1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/client/server/proxier.py @@ -0,0 +1,871 @@ +import atexit +import json +import logging +import socket +import sys +import time +import traceback +from concurrent import futures +from dataclasses import dataclass +from itertools import chain +import urllib +from threading import Event, Lock, RLock, Thread +from typing import Callable, Dict, List, Optional, Tuple + +import grpc + +# Import psutil after ray so the packaged version is used. 
+import psutil + +import ray +import ray.core.generated.agent_manager_pb2 as agent_manager_pb2 +import ray.core.generated.ray_client_pb2 as ray_client_pb2 +import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc +import ray.core.generated.runtime_env_agent_pb2 as runtime_env_agent_pb2 +from ray._private.client_mode_hook import disable_client_hook +from ray._raylet import GcsClient +from ray._private.parameter import RayParams +from ray._private.runtime_env.context import RuntimeEnvContext +from ray._private.services import ProcessInfo, start_ray_client_server +from ray._private.tls_utils import add_port_to_grpc_server +from ray._private.utils import detect_fate_sharing_support +from ray.cloudpickle.compat import pickle +from ray.job_config import JobConfig +from ray.util.client.common import ( + CLIENT_SERVER_MAX_THREADS, + GRPC_OPTIONS, + ClientServerHandle, + _get_client_id_from_context, + _propagate_error_in_context, +) +from ray.util.client.server.dataservicer import _get_reconnecting_from_context + +logger = logging.getLogger(__name__) + +CHECK_PROCESS_INTERVAL_S = 30 + +MIN_SPECIFIC_SERVER_PORT = 23000 +MAX_SPECIFIC_SERVER_PORT = 24000 + +CHECK_CHANNEL_TIMEOUT_S = 30 + +LOGSTREAM_RETRIES = 5 +LOGSTREAM_RETRY_INTERVAL_SEC = 2 + + +@dataclass +class SpecificServer: + port: int + process_handle_future: futures.Future + channel: "grpc._channel.Channel" + + def is_ready(self) -> bool: + """Check if the server is ready or not (doesn't block).""" + return self.process_handle_future.done() + + def wait_ready(self, timeout: Optional[float] = None) -> None: + """ + Wait for the server to actually start up. + """ + res = self.process_handle_future.result(timeout=timeout) + if res is None: + # This is only set to none when server creation specifically fails. 
+ raise RuntimeError("Server startup failed.") + + def poll(self) -> Optional[int]: + """Check if the process has exited.""" + try: + proc = self.process_handle_future.result(timeout=0.1) + if proc is not None: + return proc.process.poll() + except futures.TimeoutError: + return + + def kill(self) -> None: + """Try to send a KILL signal to the process.""" + try: + proc = self.process_handle_future.result(timeout=0.1) + if proc is not None: + proc.process.kill() + except futures.TimeoutError: + # Server has not been started yet. + pass + + def set_result(self, proc: Optional[ProcessInfo]) -> None: + """Set the result of the internal future if it is currently unset.""" + if not self.is_ready(): + self.process_handle_future.set_result(proc) + + +def _match_running_client_server(command: List[str]) -> bool: + """ + Detects if the main process in the given command is the RayClient Server. + This works by ensuring that the the first three arguments are similar to: + -m ray.util.client.server + """ + flattened = " ".join(command) + rejoined = flattened.split() + if len(rejoined) < 3: + return False + return rejoined[1:3] == ["-m", "ray.util.client.server"] + + +class ProxyManager: + def __init__( + self, + address: Optional[str], + runtime_env_agent_address: str, + *, + session_dir: Optional[str] = None, + redis_username: Optional[str] = None, + redis_password: Optional[str] = None, + runtime_env_agent_port: int = 0, + ): + self.servers: Dict[str, SpecificServer] = dict() + self.server_lock = RLock() + self._address = address + self._redis_username = redis_username + self._redis_password = redis_password + self._free_ports: List[int] = list( + range(MIN_SPECIFIC_SERVER_PORT, MAX_SPECIFIC_SERVER_PORT) + ) + + self._runtime_env_agent_address = runtime_env_agent_address + + self._check_thread = Thread(target=self._check_processes, daemon=True) + self._check_thread.start() + + self.fate_share = bool(detect_fate_sharing_support()) + self._node: Optional[ray._private.node.Node] 
= None + atexit.register(self._cleanup) + + def _get_unused_port(self) -> int: + """ + Search for a port in _free_ports that is unused. + """ + with self.server_lock: + num_ports = len(self._free_ports) + for _ in range(num_ports): + port = self._free_ports.pop(0) + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + s.bind(("", port)) + except OSError: + self._free_ports.append(port) + continue + finally: + s.close() + return port + raise RuntimeError("Unable to succeed in selecting a random port.") + + @property + def address(self) -> str: + """ + Returns the provided Ray bootstrap address, or creates a new cluster. + """ + if self._address: + return self._address + # Start a new, locally scoped cluster. + connection_tuple = ray.init() + self._address = connection_tuple["address"] + self._session_dir = connection_tuple["session_dir"] + return self._address + + @property + def node(self) -> ray._private.node.Node: + """Gets a 'ray.Node' object for this node (the head node). + If it does not already exist, one is created using the bootstrap + address. + """ + if self._node: + return self._node + ray_params = RayParams(gcs_address=self.address) + + self._node = ray._private.node.Node( + ray_params, + head=False, + shutdown_at_exit=False, + spawn_reaper=False, + connect_only=True, + ) + + return self._node + + def create_specific_server(self, client_id: str) -> SpecificServer: + """ + Create, but not start a SpecificServer for a given client. This + method must be called once per client. 
+ """ + with self.server_lock: + assert ( + self.servers.get(client_id) is None + ), f"Server already created for Client: {client_id}" + port = self._get_unused_port() + server = SpecificServer( + port=port, + process_handle_future=futures.Future(), + channel=ray._private.utils.init_grpc_channel( + f"127.0.0.1:{port}", options=GRPC_OPTIONS + ), + ) + self.servers[client_id] = server + return server + + def _create_runtime_env( + self, + serialized_runtime_env: str, + runtime_env_config: str, + specific_server: SpecificServer, + ): + """Increase the runtime_env reference by sending an RPC to the agent. + + Includes retry logic to handle the case when the agent is + temporarily unreachable (e.g., hasn't been started up yet). + """ + logger.info( + f"Increasing runtime env reference for " + f"ray_client_server_{specific_server.port}." + f"Serialized runtime env is {serialized_runtime_env}." + ) + + assert ( + len(self._runtime_env_agent_address) > 0 + ), "runtime_env_agent_address not set" + + create_env_request = runtime_env_agent_pb2.GetOrCreateRuntimeEnvRequest( + serialized_runtime_env=serialized_runtime_env, + runtime_env_config=runtime_env_config, + job_id=f"ray_client_server_{specific_server.port}".encode("utf-8"), + source_process="client_server", + ) + + retries = 0 + max_retries = 5 + wait_time_s = 0.5 + last_exception = None + while retries <= max_retries: + try: + url = urllib.parse.urljoin( + self._runtime_env_agent_address, "/get_or_create_runtime_env" + ) + data = create_env_request.SerializeToString() + req = urllib.request.Request(url, data=data, method="POST") + req.add_header("Content-Type", "application/octet-stream") + response = urllib.request.urlopen(req, timeout=None) + response_data = response.read() + r = runtime_env_agent_pb2.GetOrCreateRuntimeEnvReply() + r.ParseFromString(response_data) + + if r.status == agent_manager_pb2.AgentRpcStatus.AGENT_RPC_STATUS_OK: + return r.serialized_runtime_env_context + elif ( + r.status == 
agent_manager_pb2.AgentRpcStatus.AGENT_RPC_STATUS_FAILED + ): + raise RuntimeError( + "Failed to create runtime_env for Ray client " + f"server, it is caused by:\n{r.error_message}" + ) + else: + assert False, f"Unknown status: {r.status}." + except urllib.error.URLError as e: + last_exception = e + logger.warning( + f"GetOrCreateRuntimeEnv request failed: {e}. " + f"Retrying after {wait_time_s}s. " + f"{max_retries-retries} retries remaining." + ) + + # Exponential backoff. + time.sleep(wait_time_s) + retries += 1 + wait_time_s *= 2 + + raise TimeoutError( + f"GetOrCreateRuntimeEnv request failed after {max_retries} attempts." + f" Last exception: {last_exception}" + ) + + def start_specific_server(self, client_id: str, job_config: JobConfig) -> bool: + """ + Start up a RayClient Server for an incoming client to + communicate with. Returns whether creation was successful. + """ + specific_server = self._get_server_for_client(client_id) + assert specific_server, f"Server has not been created for: {client_id}" + + output, error = self.node.get_log_file_handles( + f"ray_client_server_{specific_server.port}", unique=True + ) + + serialized_runtime_env = job_config._get_serialized_runtime_env() + runtime_env_config = job_config._get_proto_runtime_env_config() + if not serialized_runtime_env or serialized_runtime_env == "{}": + # TODO(edoakes): can we just remove this case and always send it + # to the agent? 
+ serialized_runtime_env_context = RuntimeEnvContext().serialize() + else: + serialized_runtime_env_context = self._create_runtime_env( + serialized_runtime_env=serialized_runtime_env, + runtime_env_config=runtime_env_config, + specific_server=specific_server, + ) + + proc = start_ray_client_server( + self.address, + self.node.node_ip_address, + specific_server.port, + stdout_file=output, + stderr_file=error, + fate_share=self.fate_share, + server_type="specific-server", + serialized_runtime_env_context=serialized_runtime_env_context, + redis_username=self._redis_username, + redis_password=self._redis_password, + ) + + # Wait for the process being run transitions from the shim process + # to the actual RayClient Server. + pid = proc.process.pid + if sys.platform != "win32": + psutil_proc = psutil.Process(pid) + else: + psutil_proc = None + # Don't use `psutil` on Win32 + while psutil_proc is not None: + if proc.process.poll() is not None: + logger.error(f"SpecificServer startup failed for client: {client_id}") + break + cmd = psutil_proc.cmdline() + if _match_running_client_server(cmd): + break + logger.debug("Waiting for Process to reach the actual client server.") + time.sleep(0.5) + specific_server.set_result(proc) + logger.info( + f"SpecificServer started on port: {specific_server.port} " + f"with PID: {pid} for client: {client_id}" + ) + return proc.process.poll() is None + + def _get_server_for_client(self, client_id: str) -> Optional[SpecificServer]: + with self.server_lock: + client = self.servers.get(client_id) + if client is None: + logger.error(f"Unable to find channel for client: {client_id}") + return client + + def has_channel(self, client_id: str) -> bool: + server = self._get_server_for_client(client_id) + if server is None: + return False + + return server.is_ready() + + def get_channel( + self, + client_id: str, + ) -> Optional["grpc._channel.Channel"]: + """ + Find the gRPC Channel for the given client_id. 
This will block until + the server process has started. + """ + server = self._get_server_for_client(client_id) + if server is None: + return None + # Wait for the SpecificServer to become ready. + server.wait_ready() + try: + grpc.channel_ready_future(server.channel).result( + timeout=CHECK_CHANNEL_TIMEOUT_S + ) + return server.channel + except grpc.FutureTimeoutError: + logger.exception(f"Timeout waiting for channel for {client_id}") + return None + + def _check_processes(self): + """ + Keeps the internal servers dictionary up-to-date with running servers. + """ + while True: + with self.server_lock: + for client_id, specific_server in list(self.servers.items()): + if specific_server.poll() is not None: + logger.info( + f"Specific server {client_id} is no longer running" + f", freeing its port {specific_server.port}" + ) + del self.servers[client_id] + # Port is available to use again. + self._free_ports.append(specific_server.port) + + time.sleep(CHECK_PROCESS_INTERVAL_S) + + def _cleanup(self) -> None: + """ + Forcibly kill all spawned RayClient Servers. This ensures cleanup + for platforms where fate sharing is not supported. 
+ """ + for server in self.servers.values(): + server.kill() + + +class RayletServicerProxy(ray_client_pb2_grpc.RayletDriverServicer): + def __init__(self, ray_connect_handler: Callable, proxy_manager: ProxyManager): + self.proxy_manager = proxy_manager + self.ray_connect_handler = ray_connect_handler + + def _call_inner_function( + self, request, context, method: str + ) -> Optional[ray_client_pb2_grpc.RayletDriverStub]: + client_id = _get_client_id_from_context(context) + chan = self.proxy_manager.get_channel(client_id) + if not chan: + logger.error(f"Channel for Client: {client_id} not found!") + context.set_code(grpc.StatusCode.NOT_FOUND) + return None + + stub = ray_client_pb2_grpc.RayletDriverStub(chan) + try: + metadata = [("client_id", client_id)] + if context: + metadata = context.invocation_metadata() + return getattr(stub, method)(request, metadata=metadata) + except Exception as e: + # Error while proxying -- propagate the error's context to user + logger.exception(f"Proxying call to {method} failed!") + _propagate_error_in_context(e, context) + + def _has_channel_for_request(self, context): + client_id = _get_client_id_from_context(context) + return self.proxy_manager.has_channel(client_id) + + def Init(self, request, context=None) -> ray_client_pb2.InitResponse: + return self._call_inner_function(request, context, "Init") + + def KVPut(self, request, context=None) -> ray_client_pb2.KVPutResponse: + """Proxies internal_kv.put. + + This is used by the working_dir code to upload to the GCS before + ray.init is called. In that case (if we don't have a server yet) + we directly make the internal KV call from the proxier. + + Otherwise, we proxy the call to the downstream server as usual. 
+ """ + if self._has_channel_for_request(context): + return self._call_inner_function(request, context, "KVPut") + + with disable_client_hook(): + already_exists = ray.experimental.internal_kv._internal_kv_put( + request.key, request.value, overwrite=request.overwrite + ) + return ray_client_pb2.KVPutResponse(already_exists=already_exists) + + def KVGet(self, request, context=None) -> ray_client_pb2.KVGetResponse: + """Proxies internal_kv.get. + + This is used by the working_dir code to upload to the GCS before + ray.init is called. In that case (if we don't have a server yet) + we directly make the internal KV call from the proxier. + + Otherwise, we proxy the call to the downstream server as usual. + """ + if self._has_channel_for_request(context): + return self._call_inner_function(request, context, "KVGet") + + with disable_client_hook(): + value = ray.experimental.internal_kv._internal_kv_get(request.key) + return ray_client_pb2.KVGetResponse(value=value) + + def KVDel(self, request, context=None) -> ray_client_pb2.KVDelResponse: + """Proxies internal_kv.delete. + + This is used by the working_dir code to upload to the GCS before + ray.init is called. In that case (if we don't have a server yet) + we directly make the internal KV call from the proxier. + + Otherwise, we proxy the call to the downstream server as usual. + """ + if self._has_channel_for_request(context): + return self._call_inner_function(request, context, "KVDel") + + with disable_client_hook(): + ray.experimental.internal_kv._internal_kv_del(request.key) + return ray_client_pb2.KVDelResponse() + + def KVList(self, request, context=None) -> ray_client_pb2.KVListResponse: + """Proxies internal_kv.list. + + This is used by the working_dir code to upload to the GCS before + ray.init is called. In that case (if we don't have a server yet) + we directly make the internal KV call from the proxier. + + Otherwise, we proxy the call to the downstream server as usual. 
+ """ + if self._has_channel_for_request(context): + return self._call_inner_function(request, context, "KVList") + + with disable_client_hook(): + keys = ray.experimental.internal_kv._internal_kv_list(request.prefix) + return ray_client_pb2.KVListResponse(keys=keys) + + def KVExists(self, request, context=None) -> ray_client_pb2.KVExistsResponse: + """Proxies internal_kv.exists. + + This is used by the working_dir code to upload to the GCS before + ray.init is called. In that case (if we don't have a server yet) + we directly make the internal KV call from the proxier. + + Otherwise, we proxy the call to the downstream server as usual. + """ + if self._has_channel_for_request(context): + return self._call_inner_function(request, context, "KVExists") + + with disable_client_hook(): + exists = ray.experimental.internal_kv._internal_kv_exists(request.key) + return ray_client_pb2.KVExistsResponse(exists=exists) + + def PinRuntimeEnvURI( + self, request, context=None + ) -> ray_client_pb2.ClientPinRuntimeEnvURIResponse: + """Proxies internal_kv.pin_runtime_env_uri. + + This is used by the working_dir code to upload to the GCS before + ray.init is called. In that case (if we don't have a server yet) + we directly make the internal KV call from the proxier. + + Otherwise, we proxy the call to the downstream server as usual. 
+ """ + if self._has_channel_for_request(context): + return self._call_inner_function(request, context, "PinRuntimeEnvURI") + + with disable_client_hook(): + ray.experimental.internal_kv._pin_runtime_env_uri( + request.uri, expiration_s=request.expiration_s + ) + return ray_client_pb2.ClientPinRuntimeEnvURIResponse() + + def ListNamedActors( + self, request, context=None + ) -> ray_client_pb2.ClientListNamedActorsResponse: + return self._call_inner_function(request, context, "ListNamedActors") + + def ClusterInfo(self, request, context=None) -> ray_client_pb2.ClusterInfoResponse: + + # NOTE: We need to respond to the PING request here to allow the client + # to continue with connecting. + if request.type == ray_client_pb2.ClusterInfoType.PING: + resp = ray_client_pb2.ClusterInfoResponse(json=json.dumps({})) + return resp + return self._call_inner_function(request, context, "ClusterInfo") + + def Terminate(self, req, context=None): + return self._call_inner_function(req, context, "Terminate") + + def GetObject(self, request, context=None): + try: + yield from self._call_inner_function(request, context, "GetObject") + except Exception as e: + # Error while iterating over response from GetObject stream + logger.exception("Proxying call to GetObject failed!") + _propagate_error_in_context(e, context) + + def PutObject( + self, request: ray_client_pb2.PutRequest, context=None + ) -> ray_client_pb2.PutResponse: + return self._call_inner_function(request, context, "PutObject") + + def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse: + return self._call_inner_function(request, context, "WaitObject") + + def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket: + return self._call_inner_function(task, context, "Schedule") + + +def ray_client_server_env_prep(job_config: JobConfig) -> JobConfig: + return job_config + + +def prepare_runtime_init_req( + init_request: ray_client_pb2.DataRequest, +) -> Tuple[ray_client_pb2.DataRequest, 
JobConfig]: + """ + Extract JobConfig and possibly mutate InitRequest before it is passed to + the specific RayClient Server. + """ + init_type = init_request.WhichOneof("type") + assert init_type == "init", ( + "Received initial message of type " f"{init_type}, not 'init'." + ) + req = init_request.init + job_config = JobConfig() + if req.job_config: + job_config = pickle.loads(req.job_config) + new_job_config = ray_client_server_env_prep(job_config) + modified_init_req = ray_client_pb2.InitRequest( + job_config=pickle.dumps(new_job_config), + ray_init_kwargs=init_request.init.ray_init_kwargs, + reconnect_grace_period=init_request.init.reconnect_grace_period, + ) + + init_request.init.CopyFrom(modified_init_req) + return (init_request, new_job_config) + + +class RequestIteratorProxy: + def __init__(self, request_iterator): + self.request_iterator = request_iterator + + def __iter__(self): + return self + + def __next__(self): + try: + return next(self.request_iterator) + except grpc.RpcError as e: + # To stop proxying already CANCLLED request stream gracefully, + # we only translate the exact grpc.RpcError to StopIteration, + # not its subsclasses. 
ex: grpc._Rendezvous + # https://github.com/grpc/grpc/blob/v1.43.0/src/python/grpcio/grpc/_server.py#L353-L354 + # This fixes the https://github.com/ray-project/ray/issues/23865 + if type(e) is not grpc.RpcError: + raise e # re-raise other grpc exceptions + logger.exception( + "Stop iterating cancelled request stream with the following exception:" + ) + raise StopIteration + + +class DataServicerProxy(ray_client_pb2_grpc.RayletDataStreamerServicer): + def __init__(self, proxy_manager: ProxyManager): + self.num_clients = 0 + # dictionary mapping client_id's to the last time they connected + self.clients_last_seen: Dict[str, float] = {} + self.reconnect_grace_periods: Dict[str, float] = {} + self.clients_lock = Lock() + self.proxy_manager = proxy_manager + self.stopped = Event() + + def modify_connection_info_resp( + self, init_resp: ray_client_pb2.DataResponse + ) -> ray_client_pb2.DataResponse: + """ + Modify the `num_clients` returned the ConnectionInfoResponse because + individual SpecificServers only have **one** client. 
+ """ + init_type = init_resp.WhichOneof("type") + if init_type != "connection_info": + return init_resp + modified_resp = ray_client_pb2.DataResponse() + modified_resp.CopyFrom(init_resp) + with self.clients_lock: + modified_resp.connection_info.num_clients = self.num_clients + return modified_resp + + def Datapath(self, request_iterator, context): + request_iterator = RequestIteratorProxy(request_iterator) + cleanup_requested = False + start_time = time.time() + client_id = _get_client_id_from_context(context) + if client_id == "": + return + reconnecting = _get_reconnecting_from_context(context) + + if reconnecting: + with self.clients_lock: + if client_id not in self.clients_last_seen: + # Client took too long to reconnect, session has already + # been cleaned up + context.set_code(grpc.StatusCode.NOT_FOUND) + context.set_details( + "Attempted to reconnect a session that has already " + "been cleaned up" + ) + return + self.clients_last_seen[client_id] = start_time + server = self.proxy_manager._get_server_for_client(client_id) + channel = self.proxy_manager.get_channel(client_id) + # iterator doesn't need modification on reconnect + new_iter = request_iterator + else: + # Create Placeholder *before* reading the first request. + server = self.proxy_manager.create_specific_server(client_id) + with self.clients_lock: + self.clients_last_seen[client_id] = start_time + self.num_clients += 1 + + try: + if not reconnecting: + logger.info(f"New data connection from client {client_id}: ") + init_req = next(request_iterator) + with self.clients_lock: + self.reconnect_grace_periods[ + client_id + ] = init_req.init.reconnect_grace_period + try: + modified_init_req, job_config = prepare_runtime_init_req(init_req) + if not self.proxy_manager.start_specific_server( + client_id, job_config + ): + logger.error( + f"Server startup failed for client: {client_id}, " + f"using JobConfig: {job_config}!" + ) + raise RuntimeError( + "Starting Ray client server failed. 
See " + f"ray_client_server_{server.port}.err for " + "detailed logs." + ) + channel = self.proxy_manager.get_channel(client_id) + if channel is None: + logger.error(f"Channel not found for {client_id}") + raise RuntimeError( + "Proxy failed to Connect to backend! Check " + "`ray_client_server.err` and " + f"`ray_client_server_{server.port}.err` on the " + "head node of the cluster for the relevant logs. " + "By default these are located at " + "/tmp/ray/session_latest/logs." + ) + except Exception: + init_resp = ray_client_pb2.DataResponse( + init=ray_client_pb2.InitResponse( + ok=False, msg=traceback.format_exc() + ) + ) + init_resp.req_id = init_req.req_id + yield init_resp + return None + + new_iter = chain([modified_init_req], request_iterator) + + stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel) + metadata = [("client_id", client_id), ("reconnecting", str(reconnecting))] + resp_stream = stub.Datapath(new_iter, metadata=metadata) + for resp in resp_stream: + resp_type = resp.WhichOneof("type") + if resp_type == "connection_cleanup": + # Specific server is skipping cleanup, proxier should too + cleanup_requested = True + yield self.modify_connection_info_resp(resp) + except Exception as e: + logger.exception("Proxying Datapath failed!") + # Propogate error through context + recoverable = _propagate_error_in_context(e, context) + if not recoverable: + # Client shouldn't attempt to recover, clean up connection + cleanup_requested = True + finally: + cleanup_delay = self.reconnect_grace_periods.get(client_id) + if not cleanup_requested and cleanup_delay is not None: + # Delay cleanup, since client may attempt a reconnect + # Wait on stopped event in case the server closes and we + # can clean up earlier + self.stopped.wait(timeout=cleanup_delay) + with self.clients_lock: + if client_id not in self.clients_last_seen: + logger.info(f"{client_id} not found. 
Skipping clean up.") + # Connection has already been cleaned up + return + last_seen = self.clients_last_seen[client_id] + logger.info( + f"{client_id} last started stream at {last_seen}. Current " + f"stream started at {start_time}." + ) + if last_seen > start_time: + logger.info("Client reconnected. Skipping cleanup.") + # Client has reconnected, don't clean up + return + logger.debug(f"Client detached: {client_id}") + self.num_clients -= 1 + del self.clients_last_seen[client_id] + if client_id in self.reconnect_grace_periods: + del self.reconnect_grace_periods[client_id] + server.set_result(None) + + +class LogstreamServicerProxy(ray_client_pb2_grpc.RayletLogStreamerServicer): + def __init__(self, proxy_manager: ProxyManager): + super().__init__() + self.proxy_manager = proxy_manager + + def Logstream(self, request_iterator, context): + request_iterator = RequestIteratorProxy(request_iterator) + client_id = _get_client_id_from_context(context) + if client_id == "": + return + logger.debug(f"New logstream connection from client {client_id}: ") + + channel = None + # We need to retry a few times because the LogClient *may* connect + # Before the DataClient has finished connecting. + for i in range(LOGSTREAM_RETRIES): + channel = self.proxy_manager.get_channel(client_id) + + if channel is not None: + break + logger.warning(f"Retrying Logstream connection. {i+1} attempts failed.") + time.sleep(LOGSTREAM_RETRY_INTERVAL_SEC) + + if channel is None: + context.set_code(grpc.StatusCode.NOT_FOUND) + context.set_details( + "Logstream proxy failed to connect. Channel for client " + f"{client_id} not found." 
+ ) + return None + + stub = ray_client_pb2_grpc.RayletLogStreamerStub(channel) + + resp_stream = stub.Logstream( + request_iterator, metadata=[("client_id", client_id)] + ) + try: + for resp in resp_stream: + yield resp + except Exception: + logger.exception("Proxying Logstream failed!") + + +def serve_proxier( + connection_str: str, + address: Optional[str], + *, + redis_username: Optional[str] = None, + redis_password: Optional[str] = None, + session_dir: Optional[str] = None, + runtime_env_agent_address: Optional[str] = None, +): + # Initialize internal KV to be used to upload and download working_dir + # before calling ray.init within the RayletServicers. + # NOTE(edoakes): redis_address and redis_password should only be None in + # tests. + if address is not None: + gcs_cli = GcsClient(address=address) + ray.experimental.internal_kv._initialize_internal_kv(gcs_cli) + + server = grpc.server( + futures.ThreadPoolExecutor(max_workers=CLIENT_SERVER_MAX_THREADS), + options=GRPC_OPTIONS, + ) + proxy_manager = ProxyManager( + address, + session_dir=session_dir, + redis_username=redis_username, + redis_password=redis_password, + runtime_env_agent_address=runtime_env_agent_address, + ) + task_servicer = RayletServicerProxy(None, proxy_manager) + data_servicer = DataServicerProxy(proxy_manager) + logs_servicer = LogstreamServicerProxy(proxy_manager) + ray_client_pb2_grpc.add_RayletDriverServicer_to_server(task_servicer, server) + ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server(data_servicer, server) + ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server(logs_servicer, server) + add_port_to_grpc_server(server, connection_str) + server.start() + return ClientServerHandle( + task_servicer=task_servicer, + data_servicer=data_servicer, + logs_servicer=logs_servicer, + grpc_server=server, + ) diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/server/server.py b/.venv/lib/python3.11/site-packages/ray/util/client/server/server.py new file mode 
100644 index 0000000000000000000000000000000000000000..e4f7939fa3fc694ef7633c8ccedb3c797c758118 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/client/server/server.py @@ -0,0 +1,948 @@ +import base64 +import functools +import gc +import inspect +import json +import logging +import math +import pickle +import queue +import threading +import time +from collections import defaultdict +from concurrent import futures +from typing import Any, Callable, Dict, List, Optional, Set, Union + +import grpc + +import ray +import ray._private.state +import ray.core.generated.ray_client_pb2 as ray_client_pb2 +import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc +from ray import cloudpickle +from ray._private import ray_constants +from ray._private.client_mode_hook import disable_client_hook +from ray._raylet import GcsClient +from ray._private.ray_constants import env_integer +from ray._private.ray_logging import setup_logger +from ray._private.services import canonicalize_bootstrap_address_or_die +from ray._private.tls_utils import add_port_to_grpc_server +from ray.job_config import JobConfig +from ray.util.client.common import ( + CLIENT_SERVER_MAX_THREADS, + GRPC_OPTIONS, + OBJECT_TRANSFER_CHUNK_SIZE, + ClientServerHandle, + ResponseCache, +) +from ray.util.client.server.dataservicer import DataServicer +from ray.util.client.server.logservicer import LogstreamServicer +from ray.util.client.server.proxier import serve_proxier +from ray.util.client.server.server_pickler import dumps_from_server, loads_from_client +from ray.util.client.server.server_stubs import current_server + +logger = logging.getLogger(__name__) + +TIMEOUT_FOR_SPECIFIC_SERVER_S = env_integer("TIMEOUT_FOR_SPECIFIC_SERVER_S", 30) + + +def _use_response_cache(func): + """ + Decorator for gRPC stubs. Before calling the real stubs, checks if there's + an existing entry in the caches. If there is, then return the cached + entry. 
Otherwise, call the real function and use the real cache + """ + + @functools.wraps(func) + def wrapper(self, request, context): + metadata = dict(context.invocation_metadata()) + expected_ids = ("client_id", "thread_id", "req_id") + if any(i not in metadata for i in expected_ids): + # Missing IDs, skip caching and call underlying stub directly + return func(self, request, context) + + # Get relevant IDs to check cache + client_id = metadata["client_id"] + thread_id = metadata["thread_id"] + req_id = int(metadata["req_id"]) + + # Check if response already cached + response_cache = self.response_caches[client_id] + cached_entry = response_cache.check_cache(thread_id, req_id) + if cached_entry is not None: + if isinstance(cached_entry, Exception): + # Original call errored, propogate error + context.set_code(grpc.StatusCode.FAILED_PRECONDITION) + context.set_details(str(cached_entry)) + raise cached_entry + return cached_entry + + try: + # Response wasn't cached, call underlying stub and cache result + resp = func(self, request, context) + except Exception as e: + # Unexpected error in underlying stub -- update cache and + # propagate to user through context + response_cache.update_cache(thread_id, req_id, e) + context.set_code(grpc.StatusCode.FAILED_PRECONDITION) + context.set_details(str(e)) + raise + response_cache.update_cache(thread_id, req_id, resp) + return resp + + return wrapper + + +class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): + def __init__(self, ray_connect_handler: Callable): + """Construct a raylet service + + Args: + ray_connect_handler: Function to connect to ray cluster + """ + # Stores client_id -> (ref_id -> ObjectRef) + self.object_refs: Dict[str, Dict[bytes, ray.ObjectRef]] = defaultdict(dict) + # Stores client_id -> (client_ref_id -> ref_id (in self.object_refs)) + self.client_side_ref_map: Dict[str, Dict[bytes, bytes]] = defaultdict(dict) + self.function_refs = {} + self.actor_refs: Dict[bytes, ray.ActorHandle] = {} + 
self.actor_owners: Dict[str, Set[bytes]] = defaultdict(set) + self.registered_actor_classes = {} + self.named_actors = set() + self.state_lock = threading.Lock() + self.ray_connect_handler = ray_connect_handler + self.response_caches: Dict[str, ResponseCache] = defaultdict(ResponseCache) + + def Init( + self, request: ray_client_pb2.InitRequest, context=None + ) -> ray_client_pb2.InitResponse: + if request.job_config: + job_config = pickle.loads(request.job_config) + job_config._client_job = True + else: + job_config = None + current_job_config = None + with disable_client_hook(): + if ray.is_initialized(): + worker = ray._private.worker.global_worker + current_job_config = worker.core_worker.get_job_config() + else: + extra_kwargs = json.loads(request.ray_init_kwargs or "{}") + try: + self.ray_connect_handler(job_config, **extra_kwargs) + except Exception as e: + logger.exception("Running Ray Init failed:") + return ray_client_pb2.InitResponse( + ok=False, + msg="Call to `ray.init()` on the server " f"failed with: {e}", + ) + if job_config is None: + return ray_client_pb2.InitResponse(ok=True) + + # NOTE(edoakes): this code should not be necessary anymore because we + # only allow a single client/job per server. There is an existing test + # that tests the behavior of multiple clients with the same job config + # connecting to one server (test_client_init.py::test_num_clients), + # so I'm leaving it here for now. + job_config = job_config._get_proto_job_config() + # If the server has been initialized, we need to compare whether the + # runtime env is compatible. 
+ if current_job_config: + job_uris = set(job_config.runtime_env_info.uris.working_dir_uri) + job_uris.update(job_config.runtime_env_info.uris.py_modules_uris) + current_job_uris = set( + current_job_config.runtime_env_info.uris.working_dir_uri + ) + current_job_uris.update( + current_job_config.runtime_env_info.uris.py_modules_uris + ) + if job_uris != current_job_uris and len(job_uris) > 0: + return ray_client_pb2.InitResponse( + ok=False, + msg="Runtime environment doesn't match " + f"request one {job_config.runtime_env_info.uris} " + f"current one {current_job_config.runtime_env_info.uris}", + ) + return ray_client_pb2.InitResponse(ok=True) + + @_use_response_cache + def KVPut(self, request, context=None) -> ray_client_pb2.KVPutResponse: + try: + with disable_client_hook(): + already_exists = ray.experimental.internal_kv._internal_kv_put( + request.key, + request.value, + overwrite=request.overwrite, + namespace=request.namespace, + ) + except Exception as e: + return_exception_in_context(e, context) + already_exists = False + return ray_client_pb2.KVPutResponse(already_exists=already_exists) + + def KVGet(self, request, context=None) -> ray_client_pb2.KVGetResponse: + try: + with disable_client_hook(): + value = ray.experimental.internal_kv._internal_kv_get( + request.key, namespace=request.namespace + ) + except Exception as e: + return_exception_in_context(e, context) + value = b"" + return ray_client_pb2.KVGetResponse(value=value) + + @_use_response_cache + def KVDel(self, request, context=None) -> ray_client_pb2.KVDelResponse: + try: + with disable_client_hook(): + deleted_num = ray.experimental.internal_kv._internal_kv_del( + request.key, + del_by_prefix=request.del_by_prefix, + namespace=request.namespace, + ) + except Exception as e: + return_exception_in_context(e, context) + deleted_num = 0 + return ray_client_pb2.KVDelResponse(deleted_num=deleted_num) + + def KVList(self, request, context=None) -> ray_client_pb2.KVListResponse: + try: + with 
disable_client_hook(): + keys = ray.experimental.internal_kv._internal_kv_list( + request.prefix, namespace=request.namespace + ) + except Exception as e: + return_exception_in_context(e, context) + keys = [] + return ray_client_pb2.KVListResponse(keys=keys) + + def KVExists(self, request, context=None) -> ray_client_pb2.KVExistsResponse: + try: + with disable_client_hook(): + exists = ray.experimental.internal_kv._internal_kv_exists( + request.key, namespace=request.namespace + ) + except Exception as e: + return_exception_in_context(e, context) + exists = False + return ray_client_pb2.KVExistsResponse(exists=exists) + + def ListNamedActors( + self, request, context=None + ) -> ray_client_pb2.ClientListNamedActorsResponse: + with disable_client_hook(): + actors = ray.util.list_named_actors(all_namespaces=request.all_namespaces) + + return ray_client_pb2.ClientListNamedActorsResponse( + actors_json=json.dumps(actors) + ) + + def ClusterInfo(self, request, context=None) -> ray_client_pb2.ClusterInfoResponse: + resp = ray_client_pb2.ClusterInfoResponse() + resp.type = request.type + if request.type == ray_client_pb2.ClusterInfoType.CLUSTER_RESOURCES: + with disable_client_hook(): + resources = ray.cluster_resources() + # Normalize resources into floats + # (the function may return values that are ints) + float_resources = {k: float(v) for k, v in resources.items()} + resp.resource_table.CopyFrom( + ray_client_pb2.ClusterInfoResponse.ResourceTable(table=float_resources) + ) + elif request.type == ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES: + with disable_client_hook(): + resources = ray.available_resources() + # Normalize resources into floats + # (the function may return values that are ints) + float_resources = {k: float(v) for k, v in resources.items()} + resp.resource_table.CopyFrom( + ray_client_pb2.ClusterInfoResponse.ResourceTable(table=float_resources) + ) + elif request.type == ray_client_pb2.ClusterInfoType.RUNTIME_CONTEXT: + ctx = 
ray_client_pb2.ClusterInfoResponse.RuntimeContext() + with disable_client_hook(): + rtc = ray.get_runtime_context() + ctx.job_id = ray._private.utils.hex_to_binary(rtc.get_job_id()) + ctx.node_id = ray._private.utils.hex_to_binary(rtc.get_node_id()) + ctx.namespace = rtc.namespace + ctx.capture_client_tasks = ( + rtc.should_capture_child_tasks_in_placement_group + ) + ctx.gcs_address = rtc.gcs_address + ctx.runtime_env = rtc.get_runtime_env_string() + resp.runtime_context.CopyFrom(ctx) + else: + with disable_client_hook(): + resp.json = self._return_debug_cluster_info(request, context) + return resp + + def _return_debug_cluster_info(self, request, context=None) -> str: + """Handle ClusterInfo requests that only return a json blob.""" + data = None + if request.type == ray_client_pb2.ClusterInfoType.NODES: + data = ray.nodes() + elif request.type == ray_client_pb2.ClusterInfoType.IS_INITIALIZED: + data = ray.is_initialized() + elif request.type == ray_client_pb2.ClusterInfoType.TIMELINE: + data = ray.timeline() + elif request.type == ray_client_pb2.ClusterInfoType.PING: + data = {} + elif request.type == ray_client_pb2.ClusterInfoType.DASHBOARD_URL: + data = {"dashboard_url": ray._private.worker.get_dashboard_url()} + else: + raise TypeError("Unsupported cluster info type") + return json.dumps(data) + + def release(self, client_id: str, id: bytes) -> bool: + with self.state_lock: + if client_id in self.object_refs: + if id in self.object_refs[client_id]: + logger.debug(f"Releasing object {id.hex()} for {client_id}") + del self.object_refs[client_id][id] + return True + + if client_id in self.actor_owners: + if id in self.actor_owners[client_id]: + logger.debug(f"Releasing actor {id.hex()} for {client_id}") + self.actor_owners[client_id].remove(id) + if self._can_remove_actor_ref(id): + logger.debug(f"Deleting reference to actor {id.hex()}") + del self.actor_refs[id] + return True + + return False + + def release_all(self, client_id): + with self.state_lock: + 
self._release_objects(client_id) + self._release_actors(client_id) + # NOTE: Try to actually dereference the object and actor refs. + # Otherwise dereferencing will happen later, which may run concurrently + # with ray.shutdown() and will crash the process. The crash is a bug + # that should be fixed eventually. + gc.collect() + + def _can_remove_actor_ref(self, actor_id_bytes): + no_owner = not any( + actor_id_bytes in actor_list for actor_list in self.actor_owners.values() + ) + return no_owner and actor_id_bytes not in self.named_actors + + def _release_objects(self, client_id): + if client_id not in self.object_refs: + logger.debug(f"Releasing client with no references: {client_id}") + return + count = len(self.object_refs[client_id]) + del self.object_refs[client_id] + if client_id in self.client_side_ref_map: + del self.client_side_ref_map[client_id] + if client_id in self.response_caches: + del self.response_caches[client_id] + logger.debug(f"Released all {count} objects for client {client_id}") + + def _release_actors(self, client_id): + if client_id not in self.actor_owners: + logger.debug(f"Releasing client with no actors: {client_id}") + return + + count = 0 + actors_to_remove = self.actor_owners.pop(client_id) + for id_bytes in actors_to_remove: + count += 1 + if self._can_remove_actor_ref(id_bytes): + logger.debug(f"Deleting reference to actor {id_bytes.hex()}") + del self.actor_refs[id_bytes] + + logger.debug(f"Released all {count} actors for client: {client_id}") + + @_use_response_cache + def Terminate(self, req, context=None): + if req.WhichOneof("terminate_type") == "task_object": + try: + object_ref = self.object_refs[req.client_id][req.task_object.id] + with disable_client_hook(): + ray.cancel( + object_ref, + force=req.task_object.force, + recursive=req.task_object.recursive, + ) + except Exception as e: + return_exception_in_context(e, context) + elif req.WhichOneof("terminate_type") == "actor": + try: + actor_ref = 
self.actor_refs[req.actor.id] + with disable_client_hook(): + ray.kill(actor_ref, no_restart=req.actor.no_restart) + except Exception as e: + return_exception_in_context(e, context) + else: + raise RuntimeError( + "Client requested termination without providing a valid " + "terminate_type" + ) + return ray_client_pb2.TerminateResponse(ok=True) + + def _async_get_object( + self, + request: ray_client_pb2.GetRequest, + client_id: str, + req_id: int, + result_queue: queue.Queue, + context=None, + ) -> Optional[ray_client_pb2.GetResponse]: + """Attempts to schedule a callback to push the GetResponse to the + main loop when the desired object is ready. If there is some failure + in scheduling, a GetResponse will be immediately returned. + """ + if len(request.ids) != 1: + raise ValueError( + "Async get() must have exactly 1 Object ID. " f"Actual: {request}" + ) + rid = request.ids[0] + ref = self.object_refs[client_id].get(rid, None) + if not ref: + return ray_client_pb2.GetResponse( + valid=False, + error=cloudpickle.dumps( + ValueError( + f"ClientObjectRef with id {rid} not found for " + f"client {client_id}" + ) + ), + ) + try: + logger.debug("async get: %s" % ref) + with disable_client_hook(): + + def send_get_response(result: Any) -> None: + """Pushes GetResponses to the main DataPath loop to send + to the client. 
This is called when the object is ready + on the server side.""" + try: + serialized = dumps_from_server(result, client_id, self) + total_size = len(serialized) + assert total_size > 0, "Serialized object cannot be zero bytes" + total_chunks = math.ceil( + total_size / OBJECT_TRANSFER_CHUNK_SIZE + ) + for chunk_id in range(request.start_chunk_id, total_chunks): + start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE + end = min( + total_size, (chunk_id + 1) * OBJECT_TRANSFER_CHUNK_SIZE + ) + get_resp = ray_client_pb2.GetResponse( + valid=True, + data=serialized[start:end], + chunk_id=chunk_id, + total_chunks=total_chunks, + total_size=total_size, + ) + chunk_resp = ray_client_pb2.DataResponse( + get=get_resp, req_id=req_id + ) + result_queue.put(chunk_resp) + except Exception as exc: + get_resp = ray_client_pb2.GetResponse( + valid=False, error=cloudpickle.dumps(exc) + ) + resp = ray_client_pb2.DataResponse(get=get_resp, req_id=req_id) + result_queue.put(resp) + + ref._on_completed(send_get_response) + return None + + except Exception as e: + return ray_client_pb2.GetResponse(valid=False, error=cloudpickle.dumps(e)) + + def GetObject(self, request: ray_client_pb2.GetRequest, context): + metadata = dict(context.invocation_metadata()) + client_id = metadata.get("client_id") + if client_id is None: + yield ray_client_pb2.GetResponse( + valid=False, + error=cloudpickle.dumps( + ValueError("client_id is not specified in request metadata") + ), + ) + else: + yield from self._get_object(request, client_id) + + def _get_object(self, request: ray_client_pb2.GetRequest, client_id: str): + objectrefs = [] + for rid in request.ids: + ref = self.object_refs[client_id].get(rid, None) + if ref: + objectrefs.append(ref) + else: + yield ray_client_pb2.GetResponse( + valid=False, + error=cloudpickle.dumps( + ValueError( + f"ClientObjectRef {rid} is not found for client " + f"{client_id}" + ) + ), + ) + return + try: + logger.debug("get: %s" % objectrefs) + with disable_client_hook(): + 
items = ray.get(objectrefs, timeout=request.timeout) + except Exception as e: + yield ray_client_pb2.GetResponse(valid=False, error=cloudpickle.dumps(e)) + return + serialized = dumps_from_server(items, client_id, self) + total_size = len(serialized) + assert total_size > 0, "Serialized object cannot be zero bytes" + total_chunks = math.ceil(total_size / OBJECT_TRANSFER_CHUNK_SIZE) + for chunk_id in range(request.start_chunk_id, total_chunks): + start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE + end = min(total_size, (chunk_id + 1) * OBJECT_TRANSFER_CHUNK_SIZE) + yield ray_client_pb2.GetResponse( + valid=True, + data=serialized[start:end], + chunk_id=chunk_id, + total_chunks=total_chunks, + total_size=total_size, + ) + + def PutObject( + self, request: ray_client_pb2.PutRequest, context=None + ) -> ray_client_pb2.PutResponse: + """gRPC entrypoint for unary PutObject""" + return self._put_object( + request.data, request.client_ref_id, "", request.owner_id, context + ) + + def _put_object( + self, + data: Union[bytes, bytearray], + client_ref_id: bytes, + client_id: str, + owner_id: bytes, + context=None, + ): + """Put an object in the cluster with ray.put() via gRPC. + + Args: + data: Pickled data. Can either be bytearray if this is called + from the dataservicer, or bytes if called from PutObject. + client_ref_id: The id associated with this object on the client. + client_id: The client who owns this data, for tracking when to + delete this reference. + owner_id: The owner id of the object. + context: gRPC context. 
+ """ + try: + obj = loads_from_client(data, self) + + if owner_id: + owner = self.actor_refs[owner_id] + else: + owner = None + with disable_client_hook(): + objectref = ray.put(obj, _owner=owner) + except Exception as e: + logger.exception("Put failed:") + return ray_client_pb2.PutResponse( + id=b"", valid=False, error=cloudpickle.dumps(e) + ) + + self.object_refs[client_id][objectref.binary()] = objectref + if len(client_ref_id) > 0: + self.client_side_ref_map[client_id][client_ref_id] = objectref.binary() + logger.debug("put: %s" % objectref) + return ray_client_pb2.PutResponse(id=objectref.binary(), valid=True) + + def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse: + object_refs = [] + for rid in request.object_ids: + if rid not in self.object_refs[request.client_id]: + raise Exception( + "Asking for a ref not associated with this client: %s" % str(rid) + ) + object_refs.append(self.object_refs[request.client_id][rid]) + num_returns = request.num_returns + timeout = request.timeout + try: + with disable_client_hook(): + ready_object_refs, remaining_object_refs = ray.wait( + object_refs, + num_returns=num_returns, + timeout=timeout if timeout != -1 else None, + ) + except Exception as e: + # TODO(ameer): improve exception messages. 
+ logger.error(f"Exception {e}") + return ray_client_pb2.WaitResponse(valid=False) + logger.debug( + "wait: %s %s" % (str(ready_object_refs), str(remaining_object_refs)) + ) + ready_object_ids = [ + ready_object_ref.binary() for ready_object_ref in ready_object_refs + ] + remaining_object_ids = [ + remaining_object_ref.binary() + for remaining_object_ref in remaining_object_refs + ] + return ray_client_pb2.WaitResponse( + valid=True, + ready_object_ids=ready_object_ids, + remaining_object_ids=remaining_object_ids, + ) + + def Schedule( + self, + task: ray_client_pb2.ClientTask, + arglist: List[Any], + kwargs: Dict[str, Any], + context=None, + ) -> ray_client_pb2.ClientTaskTicket: + logger.debug( + "schedule: %s %s" + % (task.name, ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)) + ) + try: + with disable_client_hook(): + if task.type == ray_client_pb2.ClientTask.FUNCTION: + result = self._schedule_function(task, arglist, kwargs, context) + elif task.type == ray_client_pb2.ClientTask.ACTOR: + result = self._schedule_actor(task, arglist, kwargs, context) + elif task.type == ray_client_pb2.ClientTask.METHOD: + result = self._schedule_method(task, arglist, kwargs, context) + elif task.type == ray_client_pb2.ClientTask.NAMED_ACTOR: + result = self._schedule_named_actor(task, context) + else: + raise NotImplementedError( + "Unimplemented Schedule task type: %s" + % ray_client_pb2.ClientTask.RemoteExecType.Name(task.type) + ) + result.valid = True + return result + except Exception as e: + logger.debug("Caught schedule exception", exc_info=True) + return ray_client_pb2.ClientTaskTicket( + valid=False, error=cloudpickle.dumps(e) + ) + + def _schedule_method( + self, + task: ray_client_pb2.ClientTask, + arglist: List[Any], + kwargs: Dict[str, Any], + context=None, + ) -> ray_client_pb2.ClientTaskTicket: + actor_handle = self.actor_refs.get(task.payload_id) + if actor_handle is None: + raise Exception("Can't run an actor the server doesn't have a handle for") + 
method = getattr(actor_handle, task.name) + opts = decode_options(task.options) + if opts is not None: + method = method.options(**opts) + output = method.remote(*arglist, **kwargs) + ids = self.unify_and_track_outputs(output, task.client_id) + return ray_client_pb2.ClientTaskTicket(return_ids=ids) + + def _schedule_actor( + self, + task: ray_client_pb2.ClientTask, + arglist: List[Any], + kwargs: Dict[str, Any], + context=None, + ) -> ray_client_pb2.ClientTaskTicket: + remote_class = self.lookup_or_register_actor( + task.payload_id, task.client_id, decode_options(task.baseline_options) + ) + opts = decode_options(task.options) + if opts is not None: + remote_class = remote_class.options(**opts) + with current_server(self): + actor = remote_class.remote(*arglist, **kwargs) + self.actor_refs[actor._actor_id.binary()] = actor + self.actor_owners[task.client_id].add(actor._actor_id.binary()) + return ray_client_pb2.ClientTaskTicket(return_ids=[actor._actor_id.binary()]) + + def _schedule_function( + self, + task: ray_client_pb2.ClientTask, + arglist: List[Any], + kwargs: Dict[str, Any], + context=None, + ) -> ray_client_pb2.ClientTaskTicket: + remote_func = self.lookup_or_register_func( + task.payload_id, task.client_id, decode_options(task.baseline_options) + ) + opts = decode_options(task.options) + if opts is not None: + remote_func = remote_func.options(**opts) + with current_server(self): + output = remote_func.remote(*arglist, **kwargs) + ids = self.unify_and_track_outputs(output, task.client_id) + return ray_client_pb2.ClientTaskTicket(return_ids=ids) + + def _schedule_named_actor( + self, task: ray_client_pb2.ClientTask, context=None + ) -> ray_client_pb2.ClientTaskTicket: + assert len(task.payload_id) == 0 + # Convert empty string back to None. 
+ actor = ray.get_actor(task.name, task.namespace or None) + bin_actor_id = actor._actor_id.binary() + if bin_actor_id not in self.actor_refs: + self.actor_refs[bin_actor_id] = actor + self.actor_owners[task.client_id].add(bin_actor_id) + self.named_actors.add(bin_actor_id) + return ray_client_pb2.ClientTaskTicket(return_ids=[actor._actor_id.binary()]) + + def lookup_or_register_func( + self, id: bytes, client_id: str, options: Optional[Dict] + ) -> ray.remote_function.RemoteFunction: + with disable_client_hook(): + if id not in self.function_refs: + funcref = self.object_refs[client_id][id] + func = ray.get(funcref) + if not inspect.isfunction(func): + raise Exception( + "Attempting to register function that isn't a function." + ) + if options is None or len(options) == 0: + self.function_refs[id] = ray.remote(func) + else: + self.function_refs[id] = ray.remote(**options)(func) + return self.function_refs[id] + + def lookup_or_register_actor( + self, id: bytes, client_id: str, options: Optional[Dict] + ): + with disable_client_hook(): + if id not in self.registered_actor_classes: + actor_class_ref = self.object_refs[client_id][id] + actor_class = ray.get(actor_class_ref) + if not inspect.isclass(actor_class): + raise Exception("Attempting to schedule actor that isn't a class.") + if options is None or len(options) == 0: + reg_class = ray.remote(actor_class) + else: + reg_class = ray.remote(**options)(actor_class) + self.registered_actor_classes[id] = reg_class + + return self.registered_actor_classes[id] + + def unify_and_track_outputs(self, output, client_id): + if output is None: + outputs = [] + elif isinstance(output, list): + outputs = output + else: + outputs = [output] + for out in outputs: + if out.binary() in self.object_refs[client_id]: + logger.warning(f"Already saw object_ref {out}") + self.object_refs[client_id][out.binary()] = out + return [out.binary() for out in outputs] + + +def return_exception_in_context(err, context): + if context is not None: 
+ context.set_details(encode_exception(err)) + # Note: https://grpc.github.io/grpc/core/md_doc_statuscodes.html + # ABORTED used here since it should never be generated by the + # grpc lib -- this way we know the error was generated by ray logic + context.set_code(grpc.StatusCode.ABORTED) + + +def encode_exception(exception) -> str: + data = cloudpickle.dumps(exception) + return base64.standard_b64encode(data).decode() + + +def decode_options(options: ray_client_pb2.TaskOptions) -> Optional[Dict[str, Any]]: + if not options.pickled_options: + return None + opts = pickle.loads(options.pickled_options) + assert isinstance(opts, dict) + + return opts + + +def serve(connection_str, ray_connect_handler=None): + def default_connect_handler( + job_config: JobConfig = None, **ray_init_kwargs: Dict[str, Any] + ): + with disable_client_hook(): + if not ray.is_initialized(): + return ray.init(job_config=job_config, **ray_init_kwargs) + + ray_connect_handler = ray_connect_handler or default_connect_handler + server = grpc.server( + futures.ThreadPoolExecutor( + max_workers=CLIENT_SERVER_MAX_THREADS, + thread_name_prefix="ray_client_server", + ), + options=GRPC_OPTIONS, + ) + task_servicer = RayletServicer(ray_connect_handler) + data_servicer = DataServicer(task_servicer) + logs_servicer = LogstreamServicer() + ray_client_pb2_grpc.add_RayletDriverServicer_to_server(task_servicer, server) + ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server(data_servicer, server) + ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server(logs_servicer, server) + add_port_to_grpc_server(server, connection_str) + current_handle = ClientServerHandle( + task_servicer=task_servicer, + data_servicer=data_servicer, + logs_servicer=logs_servicer, + grpc_server=server, + ) + server.start() + return current_handle + + +def init_and_serve(connection_str, *args, **kwargs): + with disable_client_hook(): + # Disable client mode inside the worker's environment + info = ray.init(*args, **kwargs) + 
+ def ray_connect_handler(job_config=None, **ray_init_kwargs): + # Ray client will disconnect from ray when + # num_clients == 0. + if ray.is_initialized(): + return info + else: + return ray.init(job_config=job_config, *args, **kwargs) + + server_handle = serve(connection_str, ray_connect_handler=ray_connect_handler) + return (server_handle, info) + + +def shutdown_with_server(server, _exiting_interpreter=False): + server.stop(1) + with disable_client_hook(): + ray.shutdown(_exiting_interpreter) + + +def create_ray_handler(address, redis_password, redis_username=None): + def ray_connect_handler(job_config: JobConfig = None, **ray_init_kwargs): + if address: + if redis_password: + ray.init( + address=address, + _redis_username=redis_username, + _redis_password=redis_password, + job_config=job_config, + **ray_init_kwargs, + ) + else: + ray.init(address=address, job_config=job_config, **ray_init_kwargs) + else: + ray.init(job_config=job_config, **ray_init_kwargs) + + return ray_connect_handler + + +def try_create_gcs_client(address: Optional[str]) -> Optional[GcsClient]: + """ + Try to create a gcs client based on the the command line args or by + autodetecting a running Ray cluster. 
+ """ + address = canonicalize_bootstrap_address_or_die(address) + return GcsClient(address=address) + + +def main(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--host", type=str, default="0.0.0.0", help="Host IP to bind to" + ) + parser.add_argument("-p", "--port", type=int, default=10001, help="Port to bind to") + parser.add_argument( + "--mode", + type=str, + choices=["proxy", "legacy", "specific-server"], + default="proxy", + ) + parser.add_argument( + "--address", required=False, type=str, help="Address to use to connect to Ray" + ) + parser.add_argument( + "--redis-username", + required=False, + type=str, + help="username for connecting to Redis", + ) + parser.add_argument( + "--redis-password", + required=False, + type=str, + help="Password for connecting to Redis", + ) + parser.add_argument( + "--runtime-env-agent-address", + required=False, + type=str, + default=None, + help="The port to use for connecting to the runtime_env_agent.", + ) + args, _ = parser.parse_known_args() + setup_logger(ray_constants.LOGGER_LEVEL, ray_constants.LOGGER_FORMAT) + + ray_connect_handler = create_ray_handler( + args.address, args.redis_password, args.redis_username + ) + + hostport = "%s:%d" % (args.host, args.port) + logger.info(f"Starting Ray Client server on {hostport}, args {args}") + if args.mode == "proxy": + server = serve_proxier( + hostport, + args.address, + redis_username=args.redis_username, + redis_password=args.redis_password, + runtime_env_agent_address=args.runtime_env_agent_address, + ) + else: + server = serve(hostport, ray_connect_handler) + + try: + idle_checks_remaining = TIMEOUT_FOR_SPECIFIC_SERVER_S + while True: + health_report = { + "time": time.time(), + } + + try: + if not ray.experimental.internal_kv._internal_kv_initialized(): + gcs_client = try_create_gcs_client(args.address) + ray.experimental.internal_kv._initialize_internal_kv(gcs_client) + ray.experimental.internal_kv._internal_kv_put( + 
"ray_client_server", + json.dumps(health_report), + namespace=ray_constants.KV_NAMESPACE_HEALTHCHECK, + ) + except Exception as e: + logger.error( + f"[{args.mode}] Failed to put health check " f"on {args.address}" + ) + logger.exception(e) + + time.sleep(1) + if args.mode == "specific-server": + if server.data_servicer.num_clients > 0: + idle_checks_remaining = TIMEOUT_FOR_SPECIFIC_SERVER_S + else: + idle_checks_remaining -= 1 + if idle_checks_remaining == 0: + raise KeyboardInterrupt() + if ( + idle_checks_remaining % 5 == 0 + and idle_checks_remaining != TIMEOUT_FOR_SPECIFIC_SERVER_S + ): + logger.info(f"{idle_checks_remaining} idle checks before shutdown.") + + except KeyboardInterrupt: + server.stop(0) + + +if __name__ == "__main__": + main() diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/server/server_pickler.py b/.venv/lib/python3.11/site-packages/ray/util/client/server/server_pickler.py new file mode 100644 index 0000000000000000000000000000000000000000..a0d91f400baa159649dda4088a9f6eba970ad257 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/client/server/server_pickler.py @@ -0,0 +1,124 @@ +"""Implements the client side of the client/server pickling protocol. + +These picklers are aware of the server internals and can find the +references held for the client within the server. + +More discussion about the client/server pickling protocol can be found in: + + ray/util/client/client_pickler.py + +ServerPickler dumps ray objects from the server into the appropriate stubs. +ClientUnpickler loads stubs from the client and finds their associated handle +in the server instance. 
+""" +import io +import ray + +from typing import Any +from typing import TYPE_CHECKING + +from ray._private.client_mode_hook import disable_client_hook +import ray.cloudpickle as cloudpickle +from ray.util.client.client_pickler import PickleStub +from ray.util.client.server.server_stubs import ClientReferenceActor +from ray.util.client.server.server_stubs import ClientReferenceFunction + +if TYPE_CHECKING: + from ray.util.client.server.server import RayletServicer + +import pickle # noqa: F401 + + +class ServerPickler(cloudpickle.CloudPickler): + def __init__(self, client_id: str, server: "RayletServicer", *args, **kwargs): + super().__init__(*args, **kwargs) + self.client_id = client_id + self.server = server + + def persistent_id(self, obj): + if isinstance(obj, ray.ObjectRef): + obj_id = obj.binary() + if obj_id not in self.server.object_refs[self.client_id]: + # We're passing back a reference, probably inside a reference. + # Let's hold onto it. + self.server.object_refs[self.client_id][obj_id] = obj + return PickleStub( + type="Object", + client_id=self.client_id, + ref_id=obj_id, + name=None, + baseline_options=None, + ) + elif isinstance(obj, ray.actor.ActorHandle): + actor_id = obj._actor_id.binary() + if actor_id not in self.server.actor_refs: + # We're passing back a handle, probably inside a reference. 
+ self.server.actor_refs[actor_id] = obj + if actor_id not in self.server.actor_owners[self.client_id]: + self.server.actor_owners[self.client_id].add(actor_id) + return PickleStub( + type="Actor", + client_id=self.client_id, + ref_id=obj._actor_id.binary(), + name=None, + baseline_options=None, + ) + return None + + +class ClientUnpickler(pickle.Unpickler): + def __init__(self, server, *args, **kwargs): + super().__init__(*args, **kwargs) + self.server = server + + def persistent_load(self, pid): + assert isinstance(pid, PickleStub) + if pid.type == "Ray": + return ray + elif pid.type == "Object": + return self.server.object_refs[pid.client_id][pid.ref_id] + elif pid.type == "Actor": + return self.server.actor_refs[pid.ref_id] + elif pid.type == "RemoteFuncSelfReference": + return ClientReferenceFunction(pid.client_id, pid.ref_id) + elif pid.type == "RemoteFunc": + return self.server.lookup_or_register_func( + pid.ref_id, pid.client_id, pid.baseline_options + ) + elif pid.type == "RemoteActorSelfReference": + return ClientReferenceActor(pid.client_id, pid.ref_id) + elif pid.type == "RemoteActor": + return self.server.lookup_or_register_actor( + pid.ref_id, pid.client_id, pid.baseline_options + ) + elif pid.type == "RemoteMethod": + actor = self.server.actor_refs[pid.ref_id] + return getattr(actor, pid.name) + else: + raise NotImplementedError("Uncovered client data type") + + +def dumps_from_server( + obj: Any, client_id: str, server_instance: "RayletServicer", protocol=None +) -> bytes: + with io.BytesIO() as file: + sp = ServerPickler(client_id, server_instance, file, protocol=protocol) + sp.dump(obj) + return file.getvalue() + + +def loads_from_client( + data: bytes, + server_instance: "RayletServicer", + *, + fix_imports=True, + encoding="ASCII", + errors="strict" +) -> Any: + with disable_client_hook(): + if isinstance(data, str): + raise TypeError("Can't load pickle from unicode string") + file = io.BytesIO(data) + return ClientUnpickler( + server_instance, 
file, fix_imports=fix_imports, encoding=encoding + ).load() diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/server/server_stubs.py b/.venv/lib/python3.11/site-packages/ray/util/client/server/server_stubs.py new file mode 100644 index 0000000000000000000000000000000000000000..e19cbb3134a418bfdd8e1e33a04aaaac39f76654 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/client/server/server_stubs.py @@ -0,0 +1,67 @@ +from contextlib import contextmanager +from abc import ABC +from abc import abstractmethod + +_current_server = None + + +@contextmanager +def current_server(r): + global _current_server + remote = _current_server + _current_server = r + try: + yield + finally: + _current_server = remote + + +class ClientReferenceSentinel(ABC): + def __init__(self, client_id, id): + self.client_id = client_id + self.id = id + + def __reduce__(self): + remote_obj = self.get_remote_obj() + if remote_obj is None: + return (self.__class__, (self.client_id, self.id)) + return (identity, (remote_obj,)) + + @abstractmethod + def get_remote_obj(self): + pass + + def get_real_ref_from_server(self): + global _current_server + if _current_server is None: + return None + client_map = _current_server.client_side_ref_map.get(self.client_id, None) + if client_map is None: + return None + return client_map.get(self.id, None) + + +class ClientReferenceActor(ClientReferenceSentinel): + def get_remote_obj(self): + global _current_server + real_ref_id = self.get_real_ref_from_server() + if real_ref_id is None: + return None + return _current_server.lookup_or_register_actor( + real_ref_id, self.client_id, None + ) + + +class ClientReferenceFunction(ClientReferenceSentinel): + def get_remote_obj(self): + global _current_server + real_ref_id = self.get_real_ref_from_server() + if real_ref_id is None: + return None + return _current_server.lookup_or_register_func( + real_ref_id, self.client_id, None + ) + + +def identity(x): + return x diff --git 
a/.venv/lib/python3.11/site-packages/ray/util/collective/__init__.py b/.venv/lib/python3.11/site-packages/ray/util/collective/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9a76e27b8e3e61229d502036a8c46e0bf5b8e150 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/collective/__init__.py @@ -0,0 +1,51 @@ +from ray.util.collective.collective import ( + nccl_available, + gloo_available, + is_group_initialized, + init_collective_group, + destroy_collective_group, + create_collective_group, + get_rank, + get_collective_group_size, + allreduce, + allreduce_multigpu, + barrier, + reduce, + reduce_multigpu, + broadcast, + broadcast_multigpu, + allgather, + allgather_multigpu, + reducescatter, + reducescatter_multigpu, + send, + send_multigpu, + recv, + recv_multigpu, +) + +__all__ = [ + "nccl_available", + "gloo_available", + "is_group_initialized", + "init_collective_group", + "destroy_collective_group", + "create_collective_group", + "get_rank", + "get_collective_group_size", + "allreduce", + "allreduce_multigpu", + "barrier", + "reduce", + "reduce_multigpu", + "broadcast", + "broadcast_multigpu", + "allgather", + "allgather_multigpu", + "reducescatter", + "reducescatter_multigpu", + "send", + "send_multigpu", + "recv", + "recv_multigpu", +] diff --git a/.venv/lib/python3.11/site-packages/ray/util/collective/collective.py b/.venv/lib/python3.11/site-packages/ray/util/collective/collective.py new file mode 100644 index 0000000000000000000000000000000000000000..9399cdb88c016d66ea0c0ffde1fff1540244e71b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/collective/collective.py @@ -0,0 +1,789 @@ +"""APIs exposed under the namespace ray.util.collective.""" +import logging +import os +from typing import List + +import numpy as np + +import ray +from ray.util.collective import types + +_NCCL_AVAILABLE = True +_GLOO_AVAILABLE = True + +logger = logging.getLogger(__name__) + +try: + from 
ray.util.collective.collective_group.nccl_collective_group import NCCLGroup +except ImportError: + _NCCL_AVAILABLE = False + logger.warning( + "NCCL seems unavailable. Please install Cupy " + "following the guide at: " + "https://docs.cupy.dev/en/stable/install.html." + ) + +try: + from ray.util.collective.collective_group.gloo_collective_group import GLOOGroup +except ImportError: + _GLOO_AVAILABLE = False + + +def nccl_available(): + return _NCCL_AVAILABLE + + +def gloo_available(): + return _GLOO_AVAILABLE + + +class GroupManager(object): + """Use this class to manage the collective groups we created so far. + + Each process will have an instance of `GroupManager`. Each process + could belong to multiple collective groups. The membership information + and other metadata are stored in the global `_group_mgr` object. + """ + + def __init__(self): + self._name_group_map = {} + self._group_name_map = {} + + def create_collective_group(self, backend, world_size, rank, group_name): + """The entry to create new collective groups in the manager. + + Put the registration and the group information into the manager + metadata as well. 
+ """ + backend = types.Backend(backend) + if backend == types.Backend.MPI: + raise RuntimeError("Ray does not support MPI.") + elif backend == types.Backend.GLOO: + logger.debug("Creating GLOO group: '{}'...".format(group_name)) + g = GLOOGroup( + world_size, + rank, + group_name, + store_type="ray_internal_kv", + device_type="tcp", + ) + self._name_group_map[group_name] = g + self._group_name_map[g] = group_name + elif backend == types.Backend.NCCL: + logger.debug("Creating NCCL group: '{}'...".format(group_name)) + g = NCCLGroup(world_size, rank, group_name) + self._name_group_map[group_name] = g + self._group_name_map[g] = group_name + return self._name_group_map[group_name] + + def is_group_exist(self, group_name): + return group_name in self._name_group_map + + def get_group_by_name(self, group_name): + """Get the collective group handle by its name.""" + if not self.is_group_exist(group_name): + logger.warning("The group '{}' is not initialized.".format(group_name)) + return None + return self._name_group_map[group_name] + + def destroy_collective_group(self, group_name): + """Group destructor.""" + if not self.is_group_exist(group_name): + logger.warning("The group '{}' does not exist.".format(group_name)) + return + + # release the collective group resource + g = self._name_group_map[group_name] + # clean up the dicts + del self._group_name_map[g] + del self._name_group_map[group_name] + # Release the communicator resources + g.destroy_group() + + # Release the detached actors spawned by `create_collective_group()` + name = "info_" + group_name + try: + store = ray.get_actor(name) + ray.kill(store) + except ValueError: + pass + + +_group_mgr = GroupManager() + + +def is_group_initialized(group_name): + """Check if the group is initialized in this process by the group name.""" + return _group_mgr.is_group_exist(group_name) + + +def init_collective_group( + world_size: int, rank: int, backend=types.Backend.NCCL, group_name: str = "default" +): + 
def init_collective_group(
    world_size: int, rank: int, backend=types.Backend.NCCL, group_name: str = "default"
):
    """Initialize a collective group inside an actor process.

    Args:
        world_size: the total number of processes in the group.
        rank: the rank of the current process.
        backend: the CCL backend to use, NCCL or GLOO.
        group_name: the name of the collective group.

    Returns:
        None
    """
    _check_inside_actor()
    backend = types.Backend(backend)
    _check_backend_availability(backend)
    global _group_mgr
    # TODO(Hao): implement a group auto-counter.
    if not group_name:
        raise ValueError("group_name '{}' needs to be a string.".format(group_name))

    if _group_mgr.is_group_exist(group_name):
        raise RuntimeError("Trying to initialize a group twice.")

    # Validate explicitly instead of with `assert`, so the checks survive
    # `python -O` (asserts are stripped under optimization).
    if world_size <= 0:
        raise ValueError("world_size must be positive. Got '{}'.".format(world_size))
    if not 0 <= rank < world_size:
        raise ValueError(
            "rank must be in [0, world_size). Got rank '{}', "
            "world_size '{}'.".format(rank, world_size)
        )
    _group_mgr.create_collective_group(backend, world_size, rank, group_name)


def create_collective_group(
    actors,
    world_size: int,
    ranks: List[int],
    backend=types.Backend.NCCL,
    group_name: str = "default",
):
    """Declare a list of actors as a collective group.

    Note: This function should be called in a driver process.

    Args:
        actors: a list of actors to be set in a collective group.
        world_size: the total number of processes in the group.
        ranks (List[int]): the rank of each actor.
        backend: the CCL backend to use, NCCL or GLOO.
        group_name: the name of the collective group.

    Returns:
        None
    """
    backend = types.Backend(backend)
    _check_backend_availability(backend)

    name = "info_" + group_name
    try:
        ray.get_actor(name)
        raise RuntimeError("Trying to initialize a group twice.")
    except ValueError:
        # No such actor yet -- the group has not been created before.
        pass

    if len(ranks) != len(actors):
        raise RuntimeError(
            "Each actor should correspond to one rank. Got '{}' "
            "ranks but '{}' actors".format(len(ranks), len(actors))
        )

    if set(ranks) != set(range(len(ranks))):
        # BUGFIX: message used to say "from 0 to len(ranks)" (off by one)
        # and joined the ranks without separators ("012").
        raise RuntimeError(
            "Ranks must be a permutation of 0..{}. Got '{}'.".format(
                len(ranks) - 1, ", ".join(str(r) for r in ranks)
            )
        )

    if world_size <= 0:
        raise RuntimeError(
            "World size must be greater than zero. Got '{}'.".format(world_size)
        )
    # BUGFIX: the original tested `not all(ranks) >= 0`, which compares the
    # truthiness of the whole list to 0 and can never fail; check each rank.
    if not all(r >= 0 for r in ranks):
        raise RuntimeError("Ranks must be non-negative.")
    # BUGFIX: same problem with `not all(ranks) < world_size`.
    if not all(r < world_size for r in ranks):
        raise RuntimeError("Ranks cannot be greater than world_size.")

    # avoid a circular dependency
    from ray.util.collective.util import Info

    # store the information into a NamedActor that can be accessed later.
    name = "info_" + group_name
    actors_id = [a._ray_actor_id for a in actors]
    # TODO (Dacheng): how do we recycle this name actor?
    info = Info.options(name=name, lifetime="detached").remote()
    ray.get([info.set_info.remote(actors_id, world_size, ranks, backend)])


# TODO (we need a declarative destroy() API here.)
def destroy_collective_group(group_name: str = "default") -> None:
    """Destroy a collective group given its group name."""
    _check_inside_actor()
    global _group_mgr
    _group_mgr.destroy_collective_group(group_name)


def get_rank(group_name: str = "default") -> int:
    """Return the rank of this process in the given group.

    Args:
        group_name: the name of the group to query

    Returns:
        the rank of this process in the named group,
        -1 if the group does not exist or the process does
        not belong to the group.
    """
    _check_inside_actor()
    if not is_group_initialized(group_name):
        return -1
    g = _group_mgr.get_group_by_name(group_name)
    return g.rank


def get_collective_group_size(group_name: str = "default") -> int:
    """Return the size of the collective group with the given name.

    Args:
        group_name: the name of the group to query

    Returns:
        The world size of the collective group, -1 if the group does
        not exist or the process does not belong to the group.
    """
    _check_inside_actor()
    if not is_group_initialized(group_name):
        return -1
    g = _group_mgr.get_group_by_name(group_name)
    return g.world_size
+ """ + _check_inside_actor() + if not is_group_initialized(group_name): + return -1 + g = _group_mgr.get_group_by_name(group_name) + return g.world_size + + +def allreduce(tensor, group_name: str = "default", op=types.ReduceOp.SUM): + """Collective allreduce the tensor across the group. + + Args: + tensor: the tensor to be all-reduced on this process. + group_name: the collective group name to perform allreduce. + op: The reduce operation. + + Returns: + None + """ + _check_single_tensor_input(tensor) + g = _check_and_get_group(group_name) + opts = types.AllReduceOptions + opts.reduceOp = op + g.allreduce([tensor], opts) + + +def allreduce_multigpu( + tensor_list: list, group_name: str = "default", op=types.ReduceOp.SUM +): + """Collective allreduce a list of tensors across the group. + + Args: + tensor_list (List[tensor]): list of tensors to be allreduced, + each on a GPU. + group_name: the collective group name to perform allreduce. + + Returns: + None + """ + if not types.cupy_available(): + raise RuntimeError("Multigpu calls requires NCCL and Cupy.") + _check_tensor_list_input(tensor_list) + g = _check_and_get_group(group_name) + opts = types.AllReduceOptions + opts.reduceOp = op + g.allreduce(tensor_list, opts) + + +def barrier(group_name: str = "default"): + """Barrier all processes in the collective group. + + Args: + group_name: the name of the group to barrier. + + Returns: + None + """ + g = _check_and_get_group(group_name) + g.barrier() + + +def reduce( + tensor, dst_rank: int = 0, group_name: str = "default", op=types.ReduceOp.SUM +): + """Reduce the tensor across the group to the destination rank. + + Args: + tensor: the tensor to be reduced on this process. + dst_rank: the rank of the destination process. + group_name: the collective group name to perform reduce. + op: The reduce operation. 
+ + Returns: + None + """ + _check_single_tensor_input(tensor) + g = _check_and_get_group(group_name) + + # check dst rank + _check_rank_valid(g, dst_rank) + opts = types.ReduceOptions() + opts.reduceOp = op + opts.root_rank = dst_rank + opts.root_tensor = 0 + g.reduce([tensor], opts) + + +def reduce_multigpu( + tensor_list: list, + dst_rank: int = 0, + dst_tensor: int = 0, + group_name: str = "default", + op=types.ReduceOp.SUM, +): + """Reduce the tensor across the group to the destination rank + and destination tensor. + + Args: + tensor_list: the list of tensors to be reduced on this process; + each tensor located on a GPU. + dst_rank: the rank of the destination process. + dst_tensor: the index of GPU at the destination. + group_name: the collective group name to perform reduce. + op: The reduce operation. + + Returns: + None + """ + if not types.cupy_available(): + raise RuntimeError("Multigpu calls requires NCCL and Cupy.") + _check_tensor_list_input(tensor_list) + g = _check_and_get_group(group_name) + + # check dst rank + _check_rank_valid(g, dst_rank) + _check_root_tensor_valid(len(tensor_list), dst_tensor) + opts = types.ReduceOptions() + opts.reduceOp = op + opts.root_rank = dst_rank + opts.root_tensor = dst_tensor + g.reduce(tensor_list, opts) + + +def broadcast(tensor, src_rank: int = 0, group_name: str = "default"): + """Broadcast the tensor from a source process to all others. + + Args: + tensor: the tensor to be broadcasted (src) or received (destination). + src_rank: the rank of the source process. + group_name: the collective group name to perform broadcast. 
+ + Returns: + None + """ + _check_single_tensor_input(tensor) + g = _check_and_get_group(group_name) + + # check src rank + _check_rank_valid(g, src_rank) + opts = types.BroadcastOptions() + opts.root_rank = src_rank + opts.root_tensor = 0 + g.broadcast([tensor], opts) + + +def broadcast_multigpu( + tensor_list, src_rank: int = 0, src_tensor: int = 0, group_name: str = "default" +): + """Broadcast the tensor from a source GPU to all other GPUs. + + Args: + tensor_list: the tensors to broadcast (src) or receive (dst). + src_rank: the rank of the source process. + src_tensor: the index of the source GPU on the source process. + group_name: the collective group name to perform broadcast. + + Returns: + None + """ + if not types.cupy_available(): + raise RuntimeError("Multigpu calls requires NCCL and Cupy.") + _check_tensor_list_input(tensor_list) + g = _check_and_get_group(group_name) + + # check src rank + _check_rank_valid(g, src_rank) + _check_root_tensor_valid(len(tensor_list), src_tensor) + opts = types.BroadcastOptions() + opts.root_rank = src_rank + opts.root_tensor = src_tensor + g.broadcast(tensor_list, opts) + + +def allgather(tensor_list: list, tensor, group_name: str = "default"): + """Allgather tensors from each process of the group into a list. + + Args: + tensor_list: the results, stored as a list of tensors. + tensor: the tensor (to be gathered) in the current process + group_name: the name of the collective group. + + Returns: + None + """ + _check_single_tensor_input(tensor) + _check_tensor_list_input(tensor_list) + g = _check_and_get_group(group_name) + if len(tensor_list) != g.world_size: + # Typically CLL lib requires len(tensor_list) >= world_size; + # Here we make it more strict: len(tensor_list) == world_size. + raise RuntimeError( + "The length of the tensor list operands to allgather " + "must be equal to world_size." 
+ ) + opts = types.AllGatherOptions() + g.allgather([tensor_list], [tensor], opts) + + +def allgather_multigpu( + output_tensor_lists: list, input_tensor_list: list, group_name: str = "default" +): + """Allgather tensors from each gpus of the group into lists. + + Args: + output_tensor_lists (List[List[tensor]]): gathered results, with shape + must be num_gpus * world_size * shape(tensor). + input_tensor_list: (List[tensor]): a list of tensors, with shape + num_gpus * shape(tensor). + group_name: the name of the collective group. + + Returns: + None + """ + if not types.cupy_available(): + raise RuntimeError("Multigpu calls requires NCCL and Cupy.") + _check_tensor_lists_input(output_tensor_lists) + _check_tensor_list_input(input_tensor_list) + g = _check_and_get_group(group_name) + opts = types.AllGatherOptions() + g.allgather(output_tensor_lists, input_tensor_list, opts) + + +def reducescatter( + tensor, tensor_list: list, group_name: str = "default", op=types.ReduceOp.SUM +): + """Reducescatter a list of tensors across the group. + + Reduce the list of the tensors across each process in the group, then + scatter the reduced list of tensors -- one tensor for each process. + + Args: + tensor: the resulted tensor on this process. + tensor_list: The list of tensors to be reduced and scattered. + group_name: the name of the collective group. + op: The reduce operation. + + Returns: + None + """ + _check_single_tensor_input(tensor) + _check_tensor_list_input(tensor_list) + g = _check_and_get_group(group_name) + if len(tensor_list) != g.world_size: + raise RuntimeError( + "The length of the tensor list operands to reducescatter " + "must not be equal to world_size." + ) + opts = types.ReduceScatterOptions() + opts.reduceOp = op + g.reducescatter([tensor], [tensor_list], opts) + + +def reducescatter_multigpu( + output_tensor_list, + input_tensor_lists, + group_name: str = "default", + op=types.ReduceOp.SUM, +): + """Reducescatter a list of tensors across all GPUs. 
+ + Args: + output_tensor_list: the resulted list of tensors, with + shape: num_gpus * shape(tensor). + input_tensor_lists: the original tensors, with shape: + num_gpus * world_size * shape(tensor). + group_name: the name of the collective group. + op: The reduce operation. + + Returns: + None. + """ + if not types.cupy_available(): + raise RuntimeError("Multigpu calls requires NCCL and Cupy.") + _check_tensor_lists_input(input_tensor_lists) + _check_tensor_list_input(output_tensor_list) + g = _check_and_get_group(group_name) + opts = types.ReduceScatterOptions() + opts.reduceOp = op + g.reducescatter(output_tensor_list, input_tensor_lists, opts) + + +def send(tensor, dst_rank: int, group_name: str = "default"): + """Send a tensor to a remote process synchronously. + + Args: + tensor: the tensor to send. + dst_rank: the rank of the destination process. + group_name: the name of the collective group. + + Returns: + None + """ + _check_single_tensor_input(tensor) + g = _check_and_get_group(group_name) + _check_rank_valid(g, dst_rank) + if dst_rank == g.rank: + raise RuntimeError("The destination rank '{}' is self.".format(dst_rank)) + opts = types.SendOptions() + opts.dst_rank = dst_rank + g.send([tensor], opts) + + +def send_multigpu( + tensor, + dst_rank: int, + dst_gpu_index: int, + group_name: str = "default", + n_elements: int = 0, +): + """Send a tensor to a remote GPU synchronously. + + The function asssume each process owns >1 GPUs, and the sender + process and receiver process has equal nubmer of GPUs. + + Args: + tensor: the tensor to send, located on a GPU. + dst_rank: the rank of the destination process. + dst_gpu_index: the destination gpu index. + group_name: the name of the collective group. + n_elements: if specified, send the next n elements + from the starting address of tensor. 
def send_multigpu(
    tensor,
    dst_rank: int,
    dst_gpu_index: int,
    group_name: str = "default",
    n_elements: int = 0,
):
    """Send a tensor to a remote GPU synchronously.

    The function assumes each process owns >1 GPUs, and the sender
    process and receiver process have an equal number of GPUs.

    Args:
        tensor: the tensor to send, located on a GPU.
        dst_rank: the rank of the destination process.
        dst_gpu_index: the destination gpu index.
        group_name: the name of the collective group.
        n_elements: if specified, send the next n elements
            from the starting address of tensor.

    Returns:
        None
    """
    if not types.cupy_available():
        raise RuntimeError("send_multigpu call requires NCCL.")
    _check_single_tensor_input(tensor)
    g = _check_and_get_group(group_name)
    _check_rank_valid(g, dst_rank)
    if dst_rank == g.rank:
        raise RuntimeError(
            "The dst_rank '{}' is self. Considering "
            "doing GPU to GPU memcpy instead?".format(dst_rank)
        )
    if n_elements < 0:
        raise RuntimeError("The n_elements '{}' should >= 0.".format(n_elements))
    opts = types.SendOptions()
    opts.dst_rank = dst_rank
    opts.dst_gpu_index = dst_gpu_index
    opts.n_elements = n_elements
    g.send([tensor], opts)


def recv(tensor, src_rank: int, group_name: str = "default"):
    """Receive a tensor from a remote process synchronously.

    Args:
        tensor: the received tensor.
        src_rank: the rank of the source process.
        group_name: the name of the collective group.

    Returns:
        None
    """
    _check_single_tensor_input(tensor)
    g = _check_and_get_group(group_name)
    _check_rank_valid(g, src_rank)
    if src_rank == g.rank:
        # BUGFIX: the message previously said "destination rank" even though
        # the offending argument is the *source* rank of a receive.
        raise RuntimeError("The source rank '{}' is self.".format(src_rank))
    opts = types.RecvOptions()
    opts.src_rank = src_rank
    g.recv([tensor], opts)


def recv_multigpu(
    tensor,
    src_rank: int,
    src_gpu_index: int,
    group_name: str = "default",
    n_elements: int = 0,
):
    """Receive a tensor from a remote GPU synchronously.

    The function assumes each process owns >1 GPUs, and the sender
    process and receiver process have an equal number of GPUs.

    Args:
        tensor: the received tensor, located on a GPU.
        src_rank: the rank of the source process.
        src_gpu_index (int): the index of the source gpu on the src process.
        group_name: the name of the collective group.
        n_elements: if specified, receive the next n elements
            into the starting address of tensor.

    Returns:
        None
    """
    if not types.cupy_available():
        raise RuntimeError("recv_multigpu call requires NCCL.")
    _check_single_tensor_input(tensor)
    g = _check_and_get_group(group_name)
    _check_rank_valid(g, src_rank)
    if src_rank == g.rank:
        # BUGFIX: the message previously blamed "dst_rank" while formatting
        # src_rank; this is the source rank of a receive.
        raise RuntimeError(
            "The src_rank '{}' is self. Considering "
            "doing GPU to GPU memcpy instead?".format(src_rank)
        )
    if n_elements < 0:
        raise RuntimeError("The n_elements '{}' should be >= 0.".format(n_elements))
    opts = types.RecvOptions()
    opts.src_rank = src_rank
    opts.src_gpu_index = src_gpu_index
    opts.n_elements = n_elements
    g.recv([tensor], opts)


def synchronize(gpu_id: int):
    """Synchronize the current process to a given device.

    Args:
        gpu_id: the GPU device id to synchronize.

    Returns:
        None
    """
    if not types.cupy_available():
        raise RuntimeError("synchronize call requires CUDA and NCCL.")
    import cupy as cp

    cp.cuda.Device(gpu_id).synchronize()
+ name = "info_" + group_name + mgr = ray.get_actor(name=name) + ids, world_size, rank, backend = ray.get(mgr.get_info.remote()) + worker = ray._private.worker.global_worker + id_ = worker.core_worker.get_actor_id() + r = rank[ids.index(id_)] + _group_mgr.create_collective_group(backend, world_size, r, group_name) + except ValueError as exc: + # check if this group is initialized using options() + if ( + "collective_group_name" in os.environ + and os.environ["collective_group_name"] == group_name + ): + rank = int(os.environ["collective_rank"]) + world_size = int(os.environ["collective_world_size"]) + backend = os.environ["collective_backend"] + _group_mgr.create_collective_group( + backend, world_size, rank, group_name + ) + else: + raise RuntimeError( + "The collective group '{}' is not " + "initialized in the process.".format(group_name) + ) from exc + g = _group_mgr.get_group_by_name(group_name) + return g + + +def _check_single_tensor_input(tensor): + """Check if the tensor is with a supported type.""" + if isinstance(tensor, np.ndarray): + return + if types.cupy_available(): + if isinstance(tensor, types.cp.ndarray): + return + if types.torch_available(): + if isinstance(tensor, types.th.Tensor): + return + raise RuntimeError( + "Unrecognized tensor type '{}'. Supported types are: " + "np.ndarray, torch.Tensor, cupy.ndarray.".format(type(tensor)) + ) + + +def _check_backend_availability(backend: types.Backend): + """Check whether the backend is available.""" + if backend == types.Backend.GLOO: + if not gloo_available(): + raise RuntimeError("GLOO is not available.") + elif backend == types.Backend.NCCL: + if not nccl_available(): + raise RuntimeError("NCCL is not available.") + + +def _check_inside_actor(): + """Check if currently it is inside a Ray actor/task.""" + worker = ray._private.worker.global_worker + if worker.mode == ray.WORKER_MODE: + return + else: + raise RuntimeError( + "The collective APIs shall be only used inside a Ray actor or task." 
+ ) + + +def _check_rank_valid(g, rank: int): + """Check the rank: 0 <= rank < world_size.""" + if rank < 0: + raise ValueError("rank '{}' is negative.".format(rank)) + if rank >= g.world_size: + raise ValueError( + "rank '{}' must be less than world size '{}'".format(rank, g.world_size) + ) + + +def _check_tensor_list_input(tensor_list): + """Check if the input is a list of supported tensor types.""" + if not isinstance(tensor_list, list): + raise RuntimeError( + "The input must be a list of tensors. " + "Got '{}'.".format(type(tensor_list)) + ) + if not tensor_list: + raise RuntimeError("Got an empty list of tensors.") + for t in tensor_list: + _check_single_tensor_input(t) + + +def _check_tensor_lists_input(tensor_lists): + """Check if the input is a list of lists of supported tensor types.""" + if not isinstance(tensor_lists, list): + raise RuntimeError( + "The input must be a list of lists of tensors. " + "Got '{}'.".format(type(tensor_lists)) + ) + if not tensor_lists: + raise RuntimeError(f"Did not receive tensors. 
Got: {tensor_lists}") + for t in tensor_lists: + _check_tensor_list_input(t) + + +def _check_root_tensor_valid(length, root_tensor): + """Check the root_tensor device is 0 <= root_tensor < length""" + if root_tensor < 0: + raise ValueError("root_tensor '{}' is negative.".format(root_tensor)) + if root_tensor >= length: + raise ValueError( + "root_tensor '{}' is greater than the number of GPUs: " + "'{}'".format(root_tensor, length) + ) diff --git a/.venv/lib/python3.11/site-packages/ray/util/collective/collective_group/__init__.py b/.venv/lib/python3.11/site-packages/ray/util/collective/collective_group/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/util/collective/collective_group/__pycache__/base_collective_group.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/collective/collective_group/__pycache__/base_collective_group.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26d45cc57ec01fc093906fc480166fe05cd34103 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/collective/collective_group/__pycache__/base_collective_group.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/collective/collective_group/__pycache__/cuda_stream.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/collective/collective_group/__pycache__/cuda_stream.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95d905967a93c4b2b9fd01edc87e1556f7bed7d7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/collective/collective_group/__pycache__/cuda_stream.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/collective/collective_group/__pycache__/nccl_collective_group.cpython-311.pyc 
class BaseGroup(metaclass=ABCMeta):
    """Abstract base class for a collective communication group.

    Holds the immutable identity of this process within the group
    (world size, rank, group name) and declares the collective
    primitives that every concrete backend must implement.
    """

    def __init__(self, world_size, rank, group_name):
        """Record the basic group membership information.

        Args:
            world_size: The total number of processes in the group.
            rank: The rank of the current process.
            group_name: The group name.
        """
        self._world_size = world_size
        self._rank = rank
        self._group_name = group_name

    @property
    def group_name(self):
        """The user-specified name of this group."""
        return self._group_name

    @property
    def rank(self):
        """The rank of the current process within the group."""
        return self._rank

    @property
    def world_size(self):
        """The total number of processes in this group."""
        return self._world_size

    def destroy_group(self):
        """Release communicator resources; backends override as needed."""
        pass

    @classmethod
    def backend(cls):
        """The backend of this collective group."""
        raise NotImplementedError()

    @abstractmethod
    def allreduce(self, tensor, allreduce_options=AllReduceOptions()):
        raise NotImplementedError()

    @abstractmethod
    def barrier(self, barrier_options=BarrierOptions()):
        raise NotImplementedError()

    @abstractmethod
    def reduce(self, tensor, reduce_options=ReduceOptions()):
        raise NotImplementedError()

    @abstractmethod
    def allgather(self, tensor_list, tensor, allgather_options=AllGatherOptions()):
        raise NotImplementedError()

    @abstractmethod
    def broadcast(self, tensor, broadcast_options=BroadcastOptions()):
        raise NotImplementedError()

    @abstractmethod
    def reducescatter(
        self, tensor, tensor_list, reducescatter_options=ReduceScatterOptions()
    ):
        raise NotImplementedError()

    @abstractmethod
    def send(self, tensor, dst_rank):
        raise NotImplementedError()

    @abstractmethod
    def recv(self, tensor, src_rank):
        raise NotImplementedError()
NCCL_STREAM_POOL_SIZE = 32
MAX_GPU_PER_ACTOR = 16

logger = logging.getLogger(__name__)


class StreamPool:
    """The class that represents a stream pool associated with a GPU.

    When multistream is enabled, we will allocate a pool of streams for each
    GPU, and get available stream from this pool when a collective kernel is
    initialized. This enables overlapping computation/communication kernels
    using multiple CUDA streams, given that the streams are appropriately
    synchronized. The class is thread-safe.

    Args:
        device_idx: the absolute index of the device for this pool.
    """

    def __init__(self, device_idx):
        self.device_idx = device_idx

        # Guards lazy one-time initialization of the pool.
        self._initialized = False
        self._initialized_lock = threading.Lock()

        self._pool = [None] * NCCL_STREAM_POOL_SIZE
        self._counter = 0
        self._pool_lock = threading.Lock()

    def get_stream(self):
        """Get an available stream from the pool (round-robin).

        Returns:
            stream (cupy.cuda.Stream): the returned stream from pool.
        """
        with self._initialized_lock:
            if not self._initialized:
                self._init_once()

        with self._pool_lock:
            stream = self._pool[self._counter]
            self._counter = (self._counter + 1) % NCCL_STREAM_POOL_SIZE
        return stream

    def _init_once(self):
        """Initialize the stream pool exactly once (caller holds the lock)."""
        with nccl_util.Device(self.device_idx):
            for i in range(NCCL_STREAM_POOL_SIZE):
                # this is the only place where self._pool will be written.
                if ENV.NCCL_USE_MULTISTREAM.val:
                    logger.debug("NCCL multistream enabled.")
                    self._pool[i] = cupy.cuda.Stream(null=False, non_blocking=False)
                else:
                    logger.debug("NCCL multistream disabled.")
                    self._pool[i] = cupy.cuda.Stream.null
        # BUGFIX: the original set `self._init_flag`, which nothing reads;
        # `get_stream` checks `self._initialized`, so the pool was rebuilt
        # on every call. Set the flag that is actually checked.
        self._initialized = True


# This is a map from GPU index to its stream pool.
# It is supposed to be READ-ONLY out of this file
_device_stream_pool_map = dict()

# BUGFIX: the original created a fresh local Lock on every call to
# get_stream_pool, which synchronizes nothing. Use one module-level lock.
_pool_map_lock = threading.Lock()


def _init_stream_pool():
    """Populate the per-GPU stream pool map (one pool per possible GPU)."""
    global _device_stream_pool_map
    for i in range(MAX_GPU_PER_ACTOR):
        _device_stream_pool_map[i] = StreamPool(i)


def get_stream_pool(device_idx):
    """Get the CUDA stream pool of a GPU device."""
    # In case there will be multiple threads writing to the pool.
    with _pool_map_lock:
        if not _device_stream_pool_map:
            _init_stream_pool()
    return _device_stream_pool_map[device_idx]
class Rendezvous:
    """A rendezvous class for different actor/task processes to meet.

    To initialize an GLOO collective communication group, different
    actors/tasks spawned in Ray in a collective group needs to meet
    each other to synchronize the GLOOUniqueID. This class guarantees
    they meet via the GLOOUniqueIDStore, initialized on the rank=0
    process.

    Args:
        group_name: the unique user-specified group name.
    """

    def __init__(self, group_name, context, store_type, device_type):
        self._group_name = group_name
        self._context = context
        # The redis ip/port are only meaningful for the "redis" store type.
        redis_address = ray._private.worker._global_node.redis_address
        (self._redis_ip_address, self._redis_port) = (
            redis_address.split(":") if store_type == "redis" else (None, None)
        )
        self._process_ip_address = ray.util.get_node_ip_address()
        logger.debug(
            "Redis address: {}, port: {}, this actor address: {}.".format(
                self._redis_ip_address, self._redis_port, self._process_ip_address
            )
        )
        self._store_type = store_type
        self._device_type = device_type
        self._store = None
        self._device = None
        self.create_store(store_type)
        self.create_device(device_type)

    def create_store(self, store_type):
        # Build the pygloo key-value store used for the rendezvous.
        if store_type == "ray_internal_kv":
            ray_internal_kv_store = gloo_util.RayInternalKvStore(self._group_name)
            self._store = pygloo.rendezvous.CustomStore(ray_internal_kv_store)
        elif store_type == "redis":
            redisStore = pygloo.rendezvous.RedisStore(
                self._redis_ip_address, int(self._redis_port)
            )
            redis_password = ray._private.worker._global_node.redis_password
            if redis_password is None or len(redis_password) == 0:
                redis_password = ray_constants.REDIS_DEFAULT_PASSWORD
            redisStore.authorize(redis_password)
            self._store = redisStore
        elif store_type == "file":
            store_name = get_store_name(self._group_name)
            store_path = gloo_util.get_gloo_store_path(store_name)
            if self._context.rank == 0:
                if not os.path.exists(store_path):
                    os.makedirs(store_path)
                # NOTE(review): both operands of `and` are identical; this was
                # probably intended to be a single non-empty-directory check.
                elif os.listdir(store_path) and os.listdir(store_path):
                    shutil.rmtree(store_path)
                    os.makedirs(store_path)
            else:
                # Non-root ranks wait until rank 0 has created the store dir.
                while not os.path.exists(store_path):
                    time.sleep(0.1)
            # Note: multi-machines needs a shared NFS.
            fileStore = pygloo.rendezvous.FileStore(store_path)
            self._store = pygloo.rendezvous.PrefixStore(self._group_name, fileStore)
        elif store_type == "hash":
            raise NotImplementedError("No implementation for hash store.")
        else:
            raise RuntimeError("Unrecognized store type: {}.".format(store_type))

    def create_device(self, device_type):
        if device_type == "tcp":
            # Bind the gloo TCP transport to this process's node IP.
            attr = pygloo.transport.tcp.attr(self._process_ip_address)
            self._device = pygloo.transport.tcp.CreateDevice(attr)
        elif device_type == "uv":
            raise NotImplementedError("No implementation for uv.")

    def meet(self, timeout_s=180):
        """Meet at the named actor store.

        Args:
            timeout_s: timeout in seconds.

        Return:
            None
        """
        if timeout_s <= 0:
            raise ValueError(
                "The 'timeout' argument must be positive. "
                "Got '{}'.".format(timeout_s)
            )

        timeout_delta = datetime.timedelta(seconds=timeout_s)
        elapsed = datetime.timedelta(seconds=0)
        start_time = datetime.datetime.now()
        q, s = None, None

        if self._store_type == "redis" or self._store_type == "ray_internal_kv":
            # Rank 0 lazily creates two detached actors: a global queue that
            # serializes group rendezvous, and a per-group signal actor used
            # for teardown; all other ranks poll until both exist.
            while elapsed < timeout_delta:
                try:
                    # I don't quite understand why we need gloo queue actor.
                    q = ray.get_actor("gloo_queue")
                    s = ray.get_actor(f"gloo_{self._group_name}_signal")
                    break
                except ValueError:
                    if self._context.rank == 0:
                        if not q:
                            ray.remote(gloo_util.glooQueue).options(
                                name="gloo_queue", lifetime="detached"
                            ).remote(1000)
                        if not s:
                            gloo_util.SignalActor.options(
                                name=f"gloo_{self._group_name}_signal",
                                lifetime="detached",
                            ).remote(self._context.size)
                    else:
                        time.sleep(0.1)
                elapsed = datetime.datetime.now() - start_time
            if not q:
                raise RuntimeError("Unable to get gloo_queue.")
            if self._context.rank == 0:
                ray.get(q.put_nowait.remote(self._group_name))
            # Wait until this group reaches the head of the queue (index 0).
            while ray.get(q.index.remote(self._group_name)):
                time.sleep(0.1)

        self._context.connectFullMesh(self._store, self._device)
        # Every rank reports completion; rank 0 then cleans up the store
        # keys, pops the queue entry, and kills the signal actor.
        ray.get(s.send.remote(self._context.rank))
        if self._context.rank == 0:
            ray.get(s.wait.remote())
            keys = []
            keys += [f"rank_{i}" for i in range(self._context.size)]
            keys += [f"{i}" for i in range(self._context.size)]
            self._store.delKeys(keys)
            group_name = ray.get(q.get_nowait.remote())
            assert group_name == self._group_name
            ray.kill(s)

    @property
    def store_type(self):
        return self._store_type

    @property
    def store(self):
        return self._store

    @property
    def device_type(self):
        return self._device_type

    @property
    def device(self):
        return self._device

    def destroy(self):
        """GC the store and device used by this rendezvous."""
        self._device = None
class GLOOGroup(BaseGroup):
    def __init__(
        self,
        world_size,
        rank,
        group_name,
        store_type="ray_internal_kv",
        device_type="tcp",
    ):
        """Init an GLOO collective group.

        Args:
            world_size: The number of processes.
            rank: The id of process
            group_name: The unique user-specified group name.
            store_type: The store type. Optional: "redis",
                "file", "hash".
            device_type: The device type to transport.
                Optional: "tcp", "uv".
        """
        super(GLOOGroup, self).__init__(world_size, rank, group_name)
        self._gloo_context = gloo_util.create_gloo_context(self.rank, self.world_size)
        # Blocks until all ranks of the group have met and connected.
        self._rendezvous = Rendezvous(
            self.group_name, self._gloo_context, store_type, device_type
        )
        self._rendezvous.meet()

    def destroy_group(self):
        """Destroy the group and release GLOO communicators."""
        self._rendezvous.destroy()

        if self._gloo_context is not None:
            # Synchronize all ranks before dropping the communicator.
            pygloo.barrier(self._gloo_context)
            # destroy the communicator
            self._gloo_context = None

        # Only rank 0 owns the on-disk store directory for the "file" store.
        if self.rank == 0 and self._rendezvous.store_type == "file":
            store_name = get_store_name(self._group_name)
            store_path = gloo_util.get_gloo_store_path(store_name)
            if os.path.exists(store_path):
                shutil.rmtree(store_path)
        super(GLOOGroup, self).destroy_group()

    @classmethod
    def backend(cls):
        return Backend.GLOO

    def allreduce(self, tensors, allreduce_options=AllReduceOptions()):
        """AllReduce a list of tensors following options.

        Args:
            tensors: the tensors to be reduced, each tensor locates on CPU.
            allreduce_options: allreduce options (reduce op).

        Returns:
            None
        """

        def collective_fn(input_tensor, output_tensor, context):
            pygloo.allreduce(
                context,
                gloo_util.get_tensor_ptr(input_tensor),
                gloo_util.get_tensor_ptr(output_tensor),
                gloo_util.get_tensor_n_elements(input_tensor),
                gloo_util.get_gloo_tensor_dtype(input_tensor),
                gloo_util.get_gloo_reduce_op(allreduce_options.reduceOp),
            )

        # NOTE(review): self._collective is defined later in this class,
        # outside this view.
        self._collective(tensors, tensors, collective_fn)

    def barrier(self, barrier_options=BarrierOptions()):
        """Blocks until all processes reach this barrier.

        Args:
            barrier_options: barrier options.

        Returns:
            None
        """
        # Implemented as an allreduce over a trivial one-element tensor.
        barrier_tensor = numpy.array([1])
        self.allreduce([barrier_tensor])

    def reduce(self, tensors, reduce_options=ReduceOptions()):
        """Reduce tensors following options.

        Args:
            tensors: the list of tensors to be reduced,
                this list only have one tensor.
            reduce_options: reduce options.

        Returns:
            None
        """
        root_rank = reduce_options.root_rank

        def collective_fn(input_tensor, output_tensor, context):
            pygloo.reduce(
                context,
                gloo_util.get_tensor_ptr(input_tensor),
                gloo_util.get_tensor_ptr(output_tensor),
                gloo_util.get_tensor_n_elements(input_tensor),
                gloo_util.get_gloo_tensor_dtype(input_tensor),
                gloo_util.get_gloo_reduce_op(reduce_options.reduceOp),
                root_rank,
            )

        self._collective(tensors, tensors, collective_fn)

    def broadcast(self, tensors, broadcast_options=BroadcastOptions()):
        """Broadcast tensors to all other processes following options.

        Args:
            tensors: tensors to be broadcast or received.
            broadcast_options: broadcast options.

        Returns:
            None
        """
        root_rank = broadcast_options.root_rank

        def collective_fn(input_tensor, output_tensor, context):
            pygloo.broadcast(
                context,
                gloo_util.get_tensor_ptr(input_tensor),
                gloo_util.get_tensor_ptr(output_tensor),
                gloo_util.get_tensor_n_elements(input_tensor),
                gloo_util.get_gloo_tensor_dtype(input_tensor),
                root_rank,
            )

        self._collective(tensors, tensors, collective_fn)

    def allgather(self, tensor_lists, tensors, allgather_options=AllGatherOptions()):
        """Allgather tensors on CPU into a list of tensors.

        Args:
            tensor_lists (List[List[Tensor]]): allgathered tensors.
            tensors: the list of tensors to allgather across the group.
                Each tensor must locate on CPU.
            allgather_options: allgather options.

        Returns:
            None
        """

        def collective_fn(input_tensor, output_tensor, context):
            pygloo.allgather(
                context,
                gloo_util.get_tensor_ptr(input_tensor),
                gloo_util.get_tensor_ptr(output_tensor),
                gloo_util.get_tensor_n_elements(input_tensor),
                gloo_util.get_gloo_tensor_dtype(input_tensor),
            )

        # NOTE(review): these two helpers are module-level functions defined
        # elsewhere in this file (outside this view).
        _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists)
        # Gather into flat buffers, then copy back into the caller's tensors.
        output_flattened = [
            _flatten_for_scatter_gather(tensor_list, copy=False)
            for tensor_list in tensor_lists
        ]

        def postprocess_fn():
            for i, tensor_list in enumerate(tensor_lists):
                for j, tensor in enumerate(tensor_list):
                    gloo_util.copy_tensor(tensor, output_flattened[i][j])

        self._collective(
            tensors, output_flattened, collective_fn, postprocess_fn=postprocess_fn
        )
+ + Returns: + None + """ + + def collective_fn(input_tensor, output_tensor, context): + size = gloo_util.get_tensor_n_elements(input_tensor) + world_size = self._gloo_context.size + pygloo.reduce_scatter( + context, + gloo_util.get_tensor_ptr(input_tensor), + gloo_util.get_tensor_ptr(output_tensor), + size, + [size // world_size for _ in range(world_size)], + gloo_util.get_gloo_tensor_dtype(output_tensor), + gloo_util.get_gloo_reduce_op(reducescatter_options.reduceOp), + ) + + _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists) + input_flattened = [ + _flatten_for_scatter_gather(tensor_list, copy=False) + for tensor_list in tensor_lists + ] + + def preprocess_fn(): + for i, tensor_list in enumerate(tensor_lists): + for j, tensor in enumerate(tensor_list): + gloo_util.copy_tensor(input_flattened[i][j], tensor) + + self._collective( + input_flattened, tensors, collective_fn, preprocess_fn=preprocess_fn + ) + + def send(self, tensors, send_options=SendOptions()): + """Send a tensor to a destination rank in the group. + + Args: + tensors: the tensor to send. + send_options: send options. + + Returns: + None + """ + + def p2p_fn(tensor, context, peer): + pygloo.send( + context, + gloo_util.get_tensor_ptr(tensor), + gloo_util.get_tensor_n_elements(tensor), + gloo_util.get_gloo_tensor_dtype(tensor), + peer, + ) + + self._point2point(tensors, p2p_fn, send_options.dst_rank) + + def recv(self, tensors, recv_options=RecvOptions()): + """Receive a tensor from a source rank in the group. + + Args: + tensors: the received tensor. + recv_options: Receive options. 
+ + Returns: + None + """ + + def p2p_fn(tensor, context, peer): + pygloo.recv( + context, + gloo_util.get_tensor_ptr(tensor), + gloo_util.get_tensor_n_elements(tensor), + gloo_util.get_gloo_tensor_dtype(tensor), + peer, + ) + + self._point2point(tensors, p2p_fn, recv_options.src_rank) + + def _collective( + self, + input_tensors, + output_tensors, + collective_fn, + preprocess_fn=None, + postprocess_fn=None, + ): + """A method to encapsulate all collective calls. + + Args: + input_tensors: the list of the input tensors. + output_tensors: the list of the output tensors. + collective_fn: the collective function call. + preprocess_fn: preprocess procedures before collective calls. + postprocess_fn: postprocess procedures after collective calls. + + Returns: + None + """ + _check_cpu_tensors(input_tensors) + _check_cpu_tensors(output_tensors) + + if preprocess_fn: + preprocess_fn() + collective_fn(input_tensors[0], output_tensors[0], self._gloo_context) + if postprocess_fn: + postprocess_fn() + + def _point2point(self, tensors, p2p_fn, peer_rank: int): + """A method to encapsulate all peer-to-peer calls (i.e., send/recv). + + Args: + tensors: the tensor to send or receive. + p2p_fn: the p2p function call. + peer_rank: the rank of the peer process. + + Returns: + None + """ + _check_cpu_tensors(tensors) + + p2p_fn(tensors[0], self._gloo_context, peer_rank) + + +def _check_cpu_tensors(tensors): + """Check only have one tensor and located on CPU.""" + if not tensors or not isinstance(tensors, list): + raise RuntimeError("'tensors' must be a nonempty list.") + if len(tensors) != 1: + raise RuntimeError( + "Gloo only accept one tensor in the tensor list." + " Got {} != 1.".format(len(tensors)) + ) + d = gloo_util.get_tensor_device(tensors[0]) + if d != "cpu": + raise RuntimeError("Gloo only accept cpu tensor . Got {}.".format(d)) + + +def _flatten_for_scatter_gather(tensor_list, copy=False): + """Flatten the tensor for gather/scatter operations. 
+ + Args: + tensor_list: the list of tensors to be scattered/gathered. + copy: whether the copy the tensors in tensor_list into the buffer. + + Returns: + The flattened tensor buffer. + """ + if not tensor_list: + raise RuntimeError("Received an empty list.") + + t = tensor_list[0] + # note we need a numpy dtype here. + dtype = gloo_util.get_numpy_tensor_dtype(t) + buffer_shape = [len(tensor_list)] + gloo_util.get_tensor_shape(t) + + buffer = numpy.empty(buffer_shape, dtype=dtype) + if copy: + for i, tensor in enumerate(tensor_list): + gloo_util.copy_tensor(buffer[i], tensor) + return buffer + + +def _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists): + """Check the compatibility between tensor input and tensor list input.""" + if not tensors or not isinstance(tensors, list): + raise RuntimeError("The first argument 'tensors' expects a list of tensors.") + + if len(tensors) != 1: + raise RuntimeError( + "Gloo only accept one tensor in the first argument 'tensors'." + " Got {} != 1.".format(len(tensors)) + ) + + if not tensor_lists or not isinstance(tensor_lists, list): + raise RuntimeError( + "The second argument 'tensor_lists' expects a list of tensor list." + ) + + if len(tensor_lists) != 1: + raise RuntimeError( + "Gloo only accept one tensor list " + "in the second argument 'tensor_lists'." + " Got {} != 1.".format(len(tensor_lists)) + ) + + dtype = gloo_util.get_gloo_tensor_dtype(tensors[0]) + shape = gloo_util.get_tensor_shape(tensors[0]) + + # check all tensors in `tensor_lists` match. + for t in tensor_lists[0]: + # check dtype + dt = gloo_util.get_gloo_tensor_dtype(t) + if dt != dtype: + raise RuntimeError( + "All tensor operands to scatter/gather must " + "have the same dtype. Got '{}' and '{}'.".format(dt, dtype) + ) + s = gloo_util.get_tensor_shape(t) + if s != shape: + raise RuntimeError( + "All tensor operands to scatter/gather must " + "have the same shape. 
Got '{}' and '{}'.".format(s, shape) + ) diff --git a/.venv/lib/python3.11/site-packages/ray/util/collective/collective_group/gloo_util.py b/.venv/lib/python3.11/site-packages/ray/util/collective/collective_group/gloo_util.py new file mode 100644 index 0000000000000000000000000000000000000000..11e18d6fded31477e3600aa3798ee53cd6ee8b0f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/collective/collective_group/gloo_util.py @@ -0,0 +1,316 @@ +"""Code to wrap some GLOO API calls.""" +import asyncio +import time +from typing import List + +import numpy + +import ray +import ray.experimental.internal_kv as internal_kv +from ray._raylet import GcsClient +from ray.util.collective.types import ReduceOp, torch_available +from ray.util.queue import _QueueActor + +try: + import pygloo +except ImportError: + raise ImportError( + "Can not import pygloo. Please run 'pip install pygloo' to install pygloo." + ) + + +GLOO_REDUCE_OP_MAP = { + ReduceOp.SUM: pygloo.ReduceOp.SUM, + ReduceOp.PRODUCT: pygloo.ReduceOp.PRODUCT, + ReduceOp.MIN: pygloo.ReduceOp.MIN, + ReduceOp.MAX: pygloo.ReduceOp.MAX, +} + +NUMPY_GLOO_DTYPE_MAP = { + # INT types + numpy.int_: pygloo.glooDataType_t.glooInt64, + numpy.uint8: pygloo.glooDataType_t.glooUint8, + numpy.uint32: pygloo.glooDataType_t.glooUint32, + numpy.uint64: pygloo.glooDataType_t.glooUint64, + numpy.int8: pygloo.glooDataType_t.glooInt8, + numpy.int32: pygloo.glooDataType_t.glooInt32, + numpy.int64: pygloo.glooDataType_t.glooInt64, + # FLOAT types + numpy.half: pygloo.glooDataType_t.glooFloat16, + float: pygloo.glooDataType_t.glooFloat64, + numpy.float16: pygloo.glooDataType_t.glooFloat16, + numpy.float32: pygloo.glooDataType_t.glooFloat32, + numpy.float64: pygloo.glooDataType_t.glooFloat64, + numpy.double: pygloo.glooDataType_t.glooFloat64, +} + +if torch_available(): + import torch + + TORCH_GLOO_DTYPE_MAP = { + torch.int: pygloo.glooDataType_t.glooInt32, + torch.uint8: pygloo.glooDataType_t.glooUint8, + torch.int8: 
pygloo.glooDataType_t.glooInt8, + torch.int32: pygloo.glooDataType_t.glooInt32, + torch.int64: pygloo.glooDataType_t.glooInt64, + torch.long: pygloo.glooDataType_t.glooInt64, + # FLOAT types + torch.half: pygloo.glooDataType_t.glooFloat16, + torch.float: pygloo.glooDataType_t.glooFloat32, + torch.float16: pygloo.glooDataType_t.glooFloat16, + torch.float32: pygloo.glooDataType_t.glooFloat32, + torch.float64: pygloo.glooDataType_t.glooFloat64, + torch.double: pygloo.glooDataType_t.glooFloat64, + } + + TORCH_NUMPY_DTYPE_MAP = { + # INT types + torch.int: numpy.int32, + torch.uint8: numpy.uint8, + torch.int8: numpy.int8, + torch.int32: numpy.int32, + torch.int64: numpy.int64, + torch.long: numpy.int64, + # FLOAT types + torch.half: numpy.half, + torch.float: numpy.float32, + torch.float16: numpy.float16, + torch.float32: numpy.float32, + torch.float64: numpy.float64, + } + + +def create_gloo_context(rank, world_size): + """Create a GLOO context using GLOO APIs. + + Args: + rank: the rank of this process. + world_size: the number of processes of this collective group. + + Returns: + context (pygloo.Context): a GLOO context. + """ + context = pygloo.rendezvous.Context(rank, world_size) + return context + + +def get_gloo_reduce_op(reduce_op): + """Map the reduce op to GLOO reduce op type. + + Args: + reduce_op: ReduceOp Enum (SUM/PRODUCT/MIN/MAX). + + Returns: + (pygloo.ReduceOp): the mapped GLOO reduce op. + """ + if reduce_op not in GLOO_REDUCE_OP_MAP: + raise RuntimeError("Gloo does not support reduce op: '{}'.".format(reduce_op)) + return GLOO_REDUCE_OP_MAP[reduce_op] + + +def get_gloo_tensor_dtype(tensor): + """Return the corresponded GLOO dtype given a tensor.""" + if isinstance(tensor, numpy.ndarray): + return NUMPY_GLOO_DTYPE_MAP[tensor.dtype.type] + if torch_available(): + if isinstance(tensor, torch.Tensor): + if not tensor.is_cuda: + return TORCH_GLOO_DTYPE_MAP[tensor.dtype] + else: + raise ValueError( + "Expect torch CPU tensor. 
Got {}.".format(tensor.device) + ) + raise ValueError("Unsupported tensor type. Got: {}.".format(type(tensor))) + + +def get_numpy_tensor_dtype(tensor): + """Return the corresponded Cupy dtype given a tensor.""" + if isinstance(tensor, numpy.ndarray): + return tensor.dtype.type + if torch_available(): + if isinstance(tensor, torch.Tensor): + return TORCH_NUMPY_DTYPE_MAP[tensor.dtype] + raise ValueError( + "Unsupported tensor type. Got: {}. Supported " + "CPU tensor types are: torch.Tensor, " + "numpy.ndarray.".format(type(tensor)) + ) + + +def get_tensor_ptr(tensor): + """Return the pointer to the underlying memory storage of a tensor.""" + if isinstance(tensor, numpy.ndarray): + return tensor.ctypes.data + if torch_available(): + if isinstance(tensor, torch.Tensor): + if tensor.is_cuda: + raise RuntimeError( + "Torch tensor must be on CPU when using GLOO collectives." + ) + return tensor.data_ptr() + raise ValueError( + "Unsupported tensor type. Got: {}. Supported " + "CPU tensor types are: torch.Tensor, " + "numpy.ndarray.".format(type(tensor)) + ) + + +def get_tensor_n_elements(tensor): + """Return the number of elements in a tensor.""" + if isinstance(tensor, numpy.ndarray): + return tensor.size + if torch_available(): + if isinstance(tensor, torch.Tensor): + return torch.numel(tensor) + raise ValueError("Unsupported tensor type. 
Got: {}.".format(type(tensor))) + + +def get_gloo_store_path(store_name): + from ray._private.utils import get_ray_temp_dir + + store_path = f"{get_ray_temp_dir()}_collective/gloo/{store_name}" + return store_path + + +def get_tensor_device(tensor): + if isinstance(tensor, numpy.ndarray): + return "cpu" + elif torch_available() and isinstance(tensor, torch.Tensor): + if not tensor.is_cuda: + return "cpu" + else: + return "cuda" + else: + raise RuntimeError("Unrecognized tensor type: '{}'.".format(type(tensor))) + + +def get_tensor_shape(tensor): + """Return the shape of the tensor as a list.""" + if isinstance(tensor, numpy.ndarray): + return list(tensor.shape) + if torch_available(): + if isinstance(tensor, torch.Tensor): + return list(tensor.size()) + raise ValueError( + "Unsupported tensor type. Got: {}. Supported " + "CPU tensor types are: torch.Tensor, " + "numpy.ndarray.".format(type(tensor)) + ) + + +def copy_tensor(dst_tensor, src_tensor): + """Copy the content from src_tensor to dst_tensor. + + Args: + dst_tensor: the tensor to copy from. + src_tensor: the tensor to copy to. + + Returns: + None + """ + copied = True + if isinstance(dst_tensor, numpy.ndarray) and isinstance(src_tensor, numpy.ndarray): + numpy.copyto(dst_tensor, src_tensor) + elif torch_available(): + if isinstance(dst_tensor, torch.Tensor) and isinstance( + src_tensor, torch.Tensor + ): + dst_tensor.copy_(src_tensor) + elif isinstance(dst_tensor, torch.Tensor) and isinstance( + src_tensor, numpy.ndarray + ): + t = torch.Tensor(src_tensor) + dst_tensor.copy_(t) + elif isinstance(dst_tensor, numpy.ndarray) and isinstance( + src_tensor, torch.Tensor + ): + t = src_tensor.numpy() + numpy.copyto(dst_tensor, t) + else: + copied = False + else: + copied = False + if not copied: + raise ValueError( + "Unsupported tensor type. Got: {} and {}. 
Supported " + "CPU tensor types are: torch.Tensor, numpy.ndarray.".format( + type(dst_tensor), type(src_tensor) + ) + ) + + +# Note(Hao): this requires Ray >= 1.2.0, +# otherwise _QueueActor is an actor class. +class glooQueue(_QueueActor): + def index(self, group_name): + try: + return self.queue._queue.index(group_name) + except ValueError: + return -1 + + +@ray.remote(num_cpus=0) +class SignalActor: + def __init__(self, world_size): + self.ready_events = [asyncio.Event() for _ in range(world_size)] + self.world_size = world_size + + def send(self, rank, clear=False): + self.ready_events[rank].set() + if clear: + self.ready_events[rank].clear() + + async def wait(self, should_wait=True): + if should_wait: + for i in range(self.world_size): + await self.ready_events[i].wait() + + +# The custom store which is implementated in Ray internal kv storage, helping +# to store the rank meta information when setting up the gloo collective group. +class RayInternalKvStore: + def __init__(self, group_name: str): + self._group_name = group_name + self._job_id = ray.get_runtime_context().get_job_id() + gcs_address = ray._private.worker._global_node.gcs_address + self._gcs_client = GcsClient(address=gcs_address, nums_reconnect_retry=10) + internal_kv._initialize_internal_kv(self._gcs_client) + + def set(self, key: str, data: bytes) -> bool: + key = self.__concat_key_with_prefixes(key) + ret = internal_kv._internal_kv_put(key, data) + return ret + + def get(self, key: str) -> bytes: + key = self.__concat_key_with_prefixes(key) + ret = internal_kv._internal_kv_get(key) + return ret + + def delete(self, key: str) -> int: + key = self.__concat_key_with_prefixes(key) + ret = internal_kv._internal_kv_del(key) + return ret + + def del_keys(self, keys: List[str]) -> List[int]: + results = [] + for key in keys: + results.append(self.delete(key)) + return results + + def wait(self, keys: List[str]): + while True: + all_exist = True + for key in keys: + key = 
self.__concat_key_with_prefixes(key) + result = internal_kv._internal_kv_exists(key) + if not result: + all_exist = False + break + if all_exist: + return True + time.sleep(1) + + def __concat_key_with_prefixes(self, original_key): + """Concat the necessary prefixes and key for isolation purpose for + different jobs and different groups.""" + return f"{self._job_id}-{self._group_name}-{original_key}" diff --git a/.venv/lib/python3.11/site-packages/ray/util/collective/collective_group/nccl_collective_group.py b/.venv/lib/python3.11/site-packages/ray/util/collective/collective_group/nccl_collective_group.py new file mode 100644 index 0000000000000000000000000000000000000000..9c21b936d898ffdb7c45f27b3b01d36e2580c8e4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/collective/collective_group/nccl_collective_group.py @@ -0,0 +1,830 @@ +import logging +import datetime +import time + +import ray +import cupy + +from ray.util.collective.const import ENV +from ray.util.collective.collective_group import nccl_util +from ray.util.collective.collective_group.base_collective_group import BaseGroup +from ray.util.collective.const import get_store_name +from ray.util.collective.types import ( + AllReduceOptions, + BarrierOptions, + Backend, + ReduceOptions, + BroadcastOptions, + AllGatherOptions, + ReduceScatterOptions, + SendOptions, + RecvOptions, + torch_available, +) +from ray.util.collective.collective_group.cuda_stream import get_stream_pool + +logger = logging.getLogger(__name__) + + +class Rendezvous: + """A rendezvous class for different actor/task processes to meet. + + To initialize an NCCL collective communication group, different + actors/tasks spawned in Ray in a collective group needs to meet + each other to synchronize the NCCLUniqueID. This class guarantees + they meet via the NCCLUniqueIDStore, initialized on the rank=0 + process. + + Args: + store_key: the unique store key, usually as a concatanation + of group_name and communicator key. 
See `get_nccl_communicator` + for more details. + """ + + def __init__(self, store_key): + if not store_key: + raise ValueError( + "Invalid store_key. The store_key is a concatenation of " + "'group_name' and the 'communicator_key'. See the " + "docstring of `get_nccl_communicator` for details." + ) + self._store_key = store_key + self._store_name = None + self._store = None + + def meet(self, timeout_s=180): + """Meet at the named actor store. + + Args: + timeout_s: timeout in seconds. + + Return: + None + """ + if timeout_s <= 0: + raise ValueError( + "The 'timeout' argument must be positive. " + "Got '{}'.".format(timeout_s) + ) + self._store_name = get_store_name(self._store_key) + timeout_delta = datetime.timedelta(seconds=timeout_s) + elapsed = datetime.timedelta(seconds=0) + start_time = datetime.datetime.now() + while elapsed < timeout_delta: + try: + logger.debug( + "Trying to meet at the store '{}'".format(self._store_name) + ) + self._store = ray.get_actor(self._store_name) + except ValueError: + logger.debug( + "Failed to meet at the store '{}'." + "Trying again...".format(self._store_name) + ) + time.sleep(1) + elapsed = datetime.datetime.now() - start_time + continue + logger.debug("Successful rendezvous!") + break + if not self._store: + raise RuntimeError( + "Unable to meet other processes " + "at the rendezvous store. If you are using " + "P2P communication, please check if tensors " + "are put in the correct GPU. " + ) + + @property + def store(self): + return self._store + + def get_nccl_id(self, timeout_s=180): + """Get the NCCLUniqueID from the store through Ray. + + Args: + timeout_s: timeout in seconds. + + Return: + uid: the NCCLUniqueID if successful. 
+ """ + if not self._store: + raise ValueError("Rendezvous store is not setup.") + uid = None + timeout_delta = datetime.timedelta(seconds=timeout_s) + elapsed = datetime.timedelta(seconds=0) + start_time = datetime.datetime.now() + while elapsed < timeout_delta: + uid = ray.get(self._store.get_id.remote()) + if not uid: + time.sleep(1) + elapsed = datetime.datetime.now() - start_time + continue + break + if not uid: + raise RuntimeError("Unable to get the NCCLUniqueID from the store.") + return uid + + +class NCCLGroup(BaseGroup): + def __init__(self, world_size, rank, group_name): + """Init an NCCL collective group.""" + super(NCCLGroup, self).__init__(world_size, rank, group_name) + + # communicator and stream cache. + # TODO (Hao): we need a lock here... + self._dev_comm_map = {} + self._dev_streams_map = {} + + # record the used GPU IDs. + self._used_gpu_indices = set() + + # TODO(Fu): might need an event map + self._dev_event_map = {} + + if nccl_util.get_nccl_build_version() < 2000: + raise RuntimeError("NCCL in Ray requires NCCL >= 2.0.") + if nccl_util.get_nccl_runtime_version() < 2704: + logger.warning("NCCL send/recv calls requires NCCL>=2.7.4") + + def destroy_group(self): + """Destroy the group and release NCCL communicators.""" + if len(self._dev_comm_map.keys()) > 0: + # TODO(Hao): check this barrier call + # self.barrier() + + # Destroy the communicators and streams. 
+ for comm_key, comms in self._dev_comm_map.items(): + for c in comms: + c.destroy() + self._dev_comm_map[comm_key] = None + + if self.rank == 0: + for comm_key in self._dev_comm_map: + assert not self._dev_comm_map[comm_key] + group_key = self._generate_group_key(comm_key) + self._destroy_store(group_key) + self._barrier_tensor = None + self._dev_comm_map = None + self._dev_streams_map = None + super(NCCLGroup, self).destroy_group() + + @classmethod + def backend(cls): + return Backend.NCCL + + def allreduce(self, tensors, allreduce_options=AllReduceOptions()): + """AllReduce tensors across the collective group following options. + + Args: + tensors: the list of tensors to be reduced. Each tensor must + reside on one GPU of the current process. + allreduce_options: allreduce options. + + Returns: + None + """ + + def collective_fn(input_tensor, output_tensor, comm, stream): + comm.allReduce( + nccl_util.get_tensor_ptr(input_tensor), + nccl_util.get_tensor_ptr(output_tensor), + nccl_util.get_tensor_n_elements(input_tensor), + nccl_util.get_nccl_tensor_dtype(input_tensor), + nccl_util.get_nccl_reduce_op(allreduce_options.reduceOp), + stream.ptr, + ) + + self._collective(tensors, tensors, collective_fn) + + def barrier(self, barrier_options=BarrierOptions()): + """Blocks until all processes reach this barrier. + + Args: + barrier_options: barrier options. + + Returns: + None + """ + # Get the device list. + if self._used_gpu_indices: + devices = list(self._used_gpu_indices) + else: + devices = list(range(nccl_util.get_num_gpus())) + barrier_tensors = [None] * len(devices) + for i, d in enumerate(devices): + with nccl_util.Device(d): + barrier_tensors[i] = cupy.array([1]) + self.allreduce(barrier_tensors) + + def reduce(self, tensors, reduce_options=ReduceOptions()): + """Reduce tensors to a destination gpu following options. + + Args: + tensors: the list of tensors to be reduced, each tensor + must reside on one gpu of the current process. 
+ reduce_options: reduce options. + + Returns: + None + """ + root_rank = len(tensors) * reduce_options.root_rank + reduce_options.root_tensor + + def collective_fn(input_tensor, output_tensor, comm, stream): + comm.reduce( + nccl_util.get_tensor_ptr(input_tensor), + nccl_util.get_tensor_ptr(output_tensor), + nccl_util.get_tensor_n_elements(input_tensor), + nccl_util.get_nccl_tensor_dtype(input_tensor), + nccl_util.get_nccl_reduce_op(reduce_options.reduceOp), + root_rank, + stream.ptr, + ) + + self._collective(tensors, tensors, collective_fn) + + def broadcast(self, tensors, broadcast_options=BroadcastOptions()): + """Broadcast tensors to all other gpus following options. + + Args: + tensors: tensors to be broadcast or received. + broadcast_options: broadcast options. + + Returns: + None + """ + root_rank = ( + len(tensors) * broadcast_options.root_rank + broadcast_options.root_tensor + ) + + def collective_fn(input_tensor, output_tensor, comm, stream): + comm.broadcast( + nccl_util.get_tensor_ptr(input_tensor), + nccl_util.get_tensor_ptr(output_tensor), + nccl_util.get_tensor_n_elements(input_tensor), + nccl_util.get_nccl_tensor_dtype(input_tensor), + root_rank, + stream.ptr, + ) + + self._collective(tensors, tensors, collective_fn) + + def allgather(self, tensor_lists, tensors, allgather_options=AllGatherOptions()): + """Allgather tensors across gpus into a list of tensors. + + Args: + tensor_lists (List[List[Tensor]]): allgathered tensors. + tensors: the list of tensors to allgather across the group. + Each tensor must lolcate on a GPU of the process. + allgather_options: allgather options. 
+ + Returns: + None + """ + + def collective_fn(input_tensor, output_tensor, comm, stream): + comm.allGather( + nccl_util.get_tensor_ptr(input_tensor), + nccl_util.get_tensor_ptr(output_tensor), + nccl_util.get_tensor_n_elements(input_tensor), + nccl_util.get_nccl_tensor_dtype(input_tensor), + stream.ptr, + ) + + _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists) + output_flattened = [ + _flatten_for_scatter_gather(tensor_list, copy=False) + for tensor_list in tensor_lists + ] + + def postprocess_fn(stream): + # TODO(Hao): designate a copy stream. + for i, tensor_list in enumerate(tensor_lists): + for j, tensor in enumerate(tensor_list): + nccl_util.copy_tensor(tensor, output_flattened[i][j]) + + self._collective( + tensors, output_flattened, collective_fn, postprocess_fn=postprocess_fn + ) + + def reducescatter( + self, tensors, tensor_lists, reducescatter_options=ReduceScatterOptions() + ): + """Reduce then scatter a list of tensors across the group. + + Args: + tensors: the output tensors (could be unspecified), each + located on a GPU of the current process. + tensor_lists (List[List]): the list of tensors to be reduced then + scattered. + reducescatter_options: reduce-scatter options. 
+ + Returns: + None + """ + + def collective_fn(input_tensor, output_tensor, comm, stream): + comm.reduceScatter( + nccl_util.get_tensor_ptr(input_tensor), + nccl_util.get_tensor_ptr(output_tensor), + nccl_util.get_tensor_n_elements(output_tensor), + nccl_util.get_nccl_tensor_dtype(output_tensor), + nccl_util.get_nccl_reduce_op(reducescatter_options.reduceOp), + stream.ptr, + ) + + _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists) + input_flattened = [ + _flatten_for_scatter_gather(tensor_list, copy=False) + for tensor_list in tensor_lists + ] + + def preprocess_fn(stream): + for i, tensor_list in enumerate(tensor_lists): + for j, tensor in enumerate(tensor_list): + nccl_util.copy_tensor(input_flattened[i][j], tensor) + + self._collective( + input_flattened, tensors, collective_fn, preprocess_fn=preprocess_fn + ) + + def send(self, tensors, send_options=SendOptions()): + """Send a tensor to a destination gpu in the group. + + Args: + tensors: the tensor to send. + send_options: send options. + + Returns: + None + """ + + def p2p_fn(tensor, comm, stream, peer): + comm.send( + nccl_util.get_tensor_ptr(tensor), + send_options.n_elements + if send_options.n_elements > 0 + else nccl_util.get_tensor_n_elements(tensor), + nccl_util.get_nccl_tensor_dtype(tensor), + peer, + stream.ptr, + ) + + self._point2point( + tensors, p2p_fn, send_options.dst_rank, send_options.dst_gpu_index + ) + + def recv(self, tensors, recv_options=RecvOptions()): + """Receive a tensor from a source gpu in the group. + + Args: + tensors: the received tensor. + recv_options: Receive options. 
+ + Returns: + None + """ + + def p2p_fn(tensor, comm, stream, peer): + comm.recv( + nccl_util.get_tensor_ptr(tensor), + recv_options.n_elements + if recv_options.n_elements > 0 + else nccl_util.get_tensor_n_elements(tensor), + nccl_util.get_nccl_tensor_dtype(tensor), + peer, + stream.ptr, + ) + + self._point2point( + tensors, p2p_fn, recv_options.src_rank, recv_options.src_gpu_index + ) + + def _get_nccl_collective_communicator(self, comm_key, device_list): + """Create or retrieve an NCCL communicator from cache. + + If the communicator is found in cache, return the communicator. If not, + a communicator and a stream will be created and put in cache. + TODO(Hao): this function is not thread-safe now. + + Args: + comm_key: the key to query the communicator cache. + device_list: a list of GPU devices of the current process + that participates into the collective. + + Returns: + communicator: the NCCL communicator corresponded to the devices. + """ + if not comm_key: + raise RuntimeError("Got empty communicator key.") + for d in device_list: + self._used_gpu_indices.add(d) + + # TODO(Hao): lock the _dev_comm_map here. + if comm_key in self._dev_comm_map: + return self._dev_comm_map[comm_key] + + group_key = self._generate_group_key(comm_key) + if self.rank == 0: + nccl_uid = self._generate_nccl_uid(group_key) + else: + rendezvous = Rendezvous(group_key) + rendezvous.meet() + nccl_uid = rendezvous.get_nccl_id() + + # Now create the communicators + actual_world_size = len(device_list) * self.world_size + comms = [None] * len(device_list) + streams = [None] * len(device_list) + events = [None] * len(device_list) + nccl_util.groupStart() + for i, device in enumerate(device_list): + actual_rank = self.rank * len(device_list) + i + with nccl_util.Device(device): + comms[i] = nccl_util.create_nccl_communicator( + actual_world_size, nccl_uid, actual_rank + ) + # request a stream from the pool + # note the device_idx is absolute index. 
+ streams[i] = get_stream_pool(device).get_stream() + # TODO(Fu): double check the parameters + events[i] = cupy.cuda.Event() + nccl_util.groupEnd() + # TODO(Fu): lock + self._dev_comm_map[comm_key] = comms + self._dev_streams_map[comm_key] = streams + self._dev_event_map[comm_key] = events + return comms + + @staticmethod + def _sync_streams(device_list, events, streams): + """Let NCCL streams wait for current streams for every device.""" + # TODO(Fu): recordStream besides calling this function? + if ENV.NCCL_USE_MULTISTREAM.val: + for i, device in enumerate(device_list): + with nccl_util.Device(device): + events[i].record(cupy.cuda.get_current_stream()) + streams[i].wait_event(events[i]) + + def _get_nccl_p2p_communicator(self, comm_key, my_gpu_idx, peer_rank, peer_gpu_idx): + """Create or retrieve an NCCL communicator for p2p tasks. + + Note(Hao): this function is not thread-safe now. + + Args: + comm_key: communicator key. + my_gpu_idx: the gpu index on the current process. + peer_rank: the rank of the destination process. + peer_gpu_idx: the gpu index on the peer process. + Returns: + communicator + """ + if not comm_key: + raise RuntimeError("Got empty communicator key.") + + # TODO(Hao): lock the _dev_comm_map here. + if comm_key in self._dev_comm_map: + return self._dev_comm_map[comm_key] + + # Note (Hao): This is a bit complex so I decide to take a note here. + # Here we need to consider three cases: + # Case 1: src_rank != dst_rank, hence the send and recv happen on + # different process (actors/tasks); each process makes independent + # collective calls and manages corresponding communicators. + # Case 2: src_rank == dst_rank, src_gpu_idx == dst_gpu_idx; for + # this case, we simply throw a RuntimeError; + # Case 3: src_rank == dst_rank, src_gpu_idx != dst_gpu_idx, which + # means the send and recv will be called on the same process. We + # DO NOT support this case for now. 
We need to properly scope: + # (1) communicators creation, and + # (2) send/recv calls + # using groupStart(( and groupEnd() calls to avoid deadlocks. + if self.rank < peer_rank: + my_p2p_rank = 0 + elif self.rank > peer_rank: + my_p2p_rank = 1 + else: + raise RuntimeError( + "Send and recv happens on the same process! " + "ray.util.collective does not support this case as of now. " + "Alternatively, consider doing GPU to GPU memcpy?" + ) + + group_key = self._generate_group_key(comm_key) + if my_p2p_rank == 0: + nccl_uid = self._generate_nccl_uid(group_key) + else: + rendezvous = Rendezvous(group_key) + rendezvous.meet() + nccl_uid = rendezvous.get_nccl_id() + + # create the p2p communicators + with nccl_util.Device(my_gpu_idx): + comm = nccl_util.create_nccl_communicator(2, nccl_uid, my_p2p_rank) + stream = get_stream_pool(my_gpu_idx).get_stream() + event = cupy.cuda.Event() + + # TODO(Fu): lock and might need to add event + self._dev_comm_map[comm_key] = [comm] + self._dev_streams_map[comm_key] = [stream] + self._dev_event_map[comm_key] = [event] + return [comm] + + def _generate_group_key(self, comm_key): + """Generate a unique key used to initialize the KV store. + + The group key is a concatenation of the communicator key and + the group name, following: [comm_key]@[group_name]. + """ + return comm_key + "@" + self.group_name + + @staticmethod + def _destroy_store(group_key): + """Destroy the KV store (Ray named actor). + + Args: + group_key: the unique key to retrieve the KV store. + + Returns: + None + """ + store_name = get_store_name(group_key) + store = ray.get_actor(store_name) + # ray.get([store.__ray_terminate__.remote()]) + ray.kill(store) + + def _generate_nccl_uid(self, key): + """Generate an NCCL unique ID for initializing communicators. + + The method will also create a KV store using Ray named actor and store + the NCCLUniqueID in the store. The store needs to be garbage collected + when destroying the collective group. 
+ + Args: + key: the key of the . + + Returns: + NCCLUniqueID (str): NCCL unique ID. + """ + group_uid = nccl_util.get_nccl_unique_id() + store_name = get_store_name(key) + # Avoid a potential circular dependency in ray/actor.py + from ray.util.collective.util import NCCLUniqueIDStore + + store = NCCLUniqueIDStore.options(name=store_name, lifetime="detached").remote( + store_name + ) + ray.get([store.set_id.remote(group_uid)]) + return group_uid + + def _collective( + self, + input_tensors, + output_tensors, + collective_fn, + preprocess_fn=None, + postprocess_fn=None, + ): + """A method to encapsulate all collective calls. + + Args: + input_tensors: the list of the input tensors. + output_tensors: the list of the output tensors. + collective_fn: the collective function call. + preprocess_fn: preprocess procedures before collective calls. + postprocess_fn: postprocess procedures after collective calls. + + Returns: + None + """ + _check_gpu_tensors(input_tensors) + _check_gpu_tensors(output_tensors) + + devices = nccl_util.get_tensor_device_list(input_tensors) + key = _get_comm_key_from_devices(devices) + comms = self._get_nccl_collective_communicator(key, devices) + streams = self._dev_streams_map[key] + events = self._dev_event_map[key] + + # TODO(Hao): sync streams and events + self._sync_streams(devices, events, streams) + + # Make the collective call + if preprocess_fn: + preprocess_fn(streams) + + nccl_util.groupStart() + # TODO(Fu): how to recordStreams as there are no library functions + # We also need to make sure input tensors are not freed before their + # usages on ncclStreams finish. This can be achieved by calling + # c10::cuda::CUDACachingAllocator::recordStream, which remembers the + # usage stream (ncclStream), creates an event on the usage stream + # when GC attempts to free the input tensor, and delays GC until that + # event is done. 
+ for i, tensor in enumerate(input_tensors): + collective_fn(tensor, output_tensors[i], comms[i], streams[i]) + nccl_util.groupEnd() + if postprocess_fn: + postprocess_fn(streams) + + def _point2point(self, tensors, p2p_fn, peer_rank: int, peer_gpu_idx: int): + """A method to encapsulate all peer-to-peer calls (i.e., send/recv). + + Args: + tensors: the tensor to send or receive. + p2p_fn: the p2p function call. + peer_rank: the rank of the peer process. + peer_gpu_idx: the index of the gpu on the peer process. + + Returns: + None + """ + # check send/recv availability. + if nccl_util.get_nccl_runtime_version() < 2704: + raise RuntimeError( + "P2p send/recv requires NCCL >= 2.7.4. " + "Got '{}'.".format(nccl_util.get_nccl_runtime_version()) + ) + _check_gpu_tensors(tensors) + + # we currently only support single device to single device send/recv. + assert len(tensors) == 1 + my_gpu_idx = nccl_util.get_tensor_device(tensors[0]) + comm_key = _get_comm_key_send_recv( + self.rank, my_gpu_idx, peer_rank, peer_gpu_idx + ) + comms = self._get_nccl_p2p_communicator( + comm_key, my_gpu_idx, peer_rank, peer_gpu_idx + ) + streams = self._dev_streams_map[comm_key] + events = self._dev_event_map[comm_key] + + # TODO(Hao): sync streams and events + self._sync_streams([my_gpu_idx], events, streams) + + # We have made sure that self.rank != peer_rank during API check. + peer_p2p_rank = 0 if self.rank > peer_rank else 1 + for i, tensor in enumerate(tensors): + p2p_fn(tensors[i], comms[i], streams[i], peer_p2p_rank) + + +def _flatten_for_scatter_gather(tensor_list, copy=False): + """Flatten the tensor for gather/scatter operations. + + Args: + tensor_list: the list of tensors to be scattered/gathered. + copy: whether the copy the tensors in tensor_list into the buffer. + + Returns: + The flattened tensor buffer. 
+ """ + if not tensor_list: + raise RuntimeError("Received an empty list.") + t = tensor_list[0] + buffer_shape = [len(tensor_list)] + nccl_util.get_tensor_shape(t) + + # TODO(wuxibin): cupy doesn't support bfloat16 for now, + # once it is supported, we can eliminate this if statement. + if torch_available(): + import torch + + buffer = torch.empty(tuple(buffer_shape), dtype=t.dtype, device=t.device) + else: + # note we need a cupy dtype here. + dtype = nccl_util.get_cupy_tensor_dtype(t) + device = nccl_util.get_tensor_device(t) + with nccl_util.Device(device): + buffer = cupy.empty(buffer_shape, dtype=dtype) + + if copy: + for i, tensor in enumerate(tensor_list): + nccl_util.copy_tensor(buffer[i], tensor) + return buffer + + +def _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists): + """Check the compatibility between tensor input and tensor list input.""" + if not tensors or not isinstance(tensors, list): + raise RuntimeError("The first argument 'tensors' expects a list of tensors.") + if not tensor_lists or not isinstance(tensor_lists, list): + raise RuntimeError( + "The second argument 'tensor_lists' expects a list of tensor list." + ) + dtype = nccl_util.get_nccl_tensor_dtype(tensors[0]) + shape = nccl_util.get_tensor_shape(tensors[0]) + for i, tensor_list in enumerate(tensor_lists): + # check all tensor in `tensors` match. + dt = nccl_util.get_nccl_tensor_dtype(tensors[i]) + if dt != dtype: + raise RuntimeError( + "All tensor operands to scatter/gather must " + "have the same dtype. Got '{}' and '{}'.".format(dt, dtype) + ) + # Note: typically CCL libraries only requires they have the same + # number of elements; Here we make it more strict -- we require + # exact shape match. + s = nccl_util.get_tensor_shape(tensors[i]) + if s != shape: + raise RuntimeError( + "All tensor operands to scatter/gather must " + "have the same shape. Got '{}' and '{}'.".format(s, shape) + ) + # check all tensors in `tensor_lists` match. 
+ for t in tensor_lists[i]: + # check dtype + dt = nccl_util.get_nccl_tensor_dtype(t) + if dt != dtype: + raise RuntimeError( + "All tensor operands to scatter/gather must " + "have the same dtype. Got '{}' and '{}'.".format(dt, dtype) + ) + s = nccl_util.get_tensor_shape(t) + if s != shape: + raise RuntimeError( + "All tensor operands to scatter/gather must " + "have the same shape. Got '{}' and '{}'.".format(s, shape) + ) + + +def _check_gpu_tensors(tensors): + """Check all tensors are distributed on different GPUs.""" + if not tensors or not isinstance(tensors, list): + raise RuntimeError("'tensors' must be a nonempty list.") + if len(tensors) > nccl_util.get_num_gpus(): + raise RuntimeError( + "Tensor list cannot be larger than the number" + "of available GPUs. Got {} > {}.".format( + len(tensors), nccl_util.get_num_gpus() + ) + ) + t0 = tensors[0] + dt = nccl_util.get_nccl_tensor_dtype(t0) + s = nccl_util.get_tensor_shape(t0) + d = nccl_util.get_tensor_device(t0) + for i, t in enumerate(tensors): + if i == 0: + continue + # We need to check the following: + # (1) tensor is cuda (already checked during API) + # (2) tensor dtype + # (3) tensor shape match + # (4) each tensor is on a different GPU + dtype = nccl_util.get_nccl_tensor_dtype(t) + if dt != dtype: + raise RuntimeError( + "Tensors must have identical dtype. Got: '{}'.".format(dtype) + ) + shape = nccl_util.get_tensor_shape(t) + if s != shape: + raise RuntimeError( + "Tensor must have identical shape. Got: '{}'.".format(shape) + ) + device = nccl_util.get_tensor_device(t) + if device == d: + raise RuntimeError("Tensor must be on distinct GPUs.") + + +def _get_comm_key_from_devices(devices): + """Return a key from a list of devices for collective calls. + + For example, if the tensors are on gpus 0, 1, 2, 3, + then the key would be "0,1,2,3". + + Args: + devices: a list of GPU device indices + + Returns: + str: a string represents the key to query the communicator cache. 
+ + """ + return ",".join([str(d) for d in devices]) + + +def _get_comm_key_send_recv(my_rank, my_gpu_idx, peer_rank, peer_gpu_idx): + """Return a key given source and destination ranks for p2p tasks. + + The p2p key is in the following form: + [min_rank]_[gpu_index]:[max_rank]_[gpu_index]. + + Args: + my_rank: the rank of the source process. + my_gpu_idx: the source gpu index on the process. + peer_rank: the rank of the destination process. + peer_gpu_idx: the destination gpu index on the process. + + Returns: + comm_key: a string key to query the communication cache. + """ + if my_rank < peer_rank: + lower_key = str(my_rank) + "_" + str(my_gpu_idx) + higher_key = str(peer_rank) + "_" + str(peer_gpu_idx) + elif my_rank > peer_rank: + lower_key = str(peer_rank) + "_" + str(peer_gpu_idx) + higher_key = str(my_rank) + "_" + str(my_gpu_idx) + else: + raise RuntimeError( + "Send and recv happens on the same process. ray.util.collective " + "does not support this case as of now. Alternatively, consider " + "doing GPU to GPU memcpy?" 
+ ) + comm_key = lower_key + ":" + higher_key + return comm_key diff --git a/.venv/lib/python3.11/site-packages/ray/util/collective/collective_group/nccl_util.py b/.venv/lib/python3.11/site-packages/ray/util/collective/collective_group/nccl_util.py new file mode 100644 index 0000000000000000000000000000000000000000..221d5885c411fef7481ab7ffaa2d2c86ab35f7f0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/collective/collective_group/nccl_util.py @@ -0,0 +1,293 @@ +"""Code to wrap some NCCL API calls.""" +import numpy + +try: + import cupy + from cupy.cuda import nccl + from cupy.cuda import Device # noqa: F401 + from cupy.cuda.nccl import get_version + from cupy.cuda.nccl import get_build_version + from cupy.cuda.nccl import NcclCommunicator + from cupy.cuda.nccl import groupStart # noqa: F401 + from cupy.cuda.nccl import groupEnd # noqa: F401 +except ImportError: + raise ImportError("NCCL in Ray requires Cupy being available!") + +from ray.util.collective.types import ReduceOp, torch_available + +NCCL_REDUCE_OP_MAP = { + ReduceOp.SUM: nccl.NCCL_SUM, + ReduceOp.PRODUCT: nccl.NCCL_PROD, + ReduceOp.MIN: nccl.NCCL_MIN, + ReduceOp.MAX: nccl.NCCL_MAX, +} + +# cupy types are the same with numpy types +NUMPY_NCCL_DTYPE_MAP = { + # INT types + numpy.int_: nccl.NCCL_INT64, + numpy.uint8: nccl.NCCL_UINT8, + numpy.uint32: nccl.NCCL_UINT32, + numpy.uint64: nccl.NCCL_UINT64, + numpy.int8: nccl.NCCL_INT8, + numpy.int32: nccl.NCCL_INT32, + numpy.int64: nccl.NCCL_INT64, + # FLOAT types + numpy.half: nccl.NCCL_HALF, + numpy.float16: nccl.NCCL_FLOAT16, + numpy.float32: nccl.NCCL_FLOAT32, + numpy.float64: nccl.NCCL_FLOAT64, + numpy.double: nccl.NCCL_DOUBLE, +} + +if torch_available(): + import torch + import torch.utils.dlpack + + TORCH_NCCL_DTYPE_MAP = { + torch.bool: nccl.NCCL_INT8, + # INT types + torch.int: nccl.NCCL_INT, + torch.uint8: nccl.NCCL_UINT8, + torch.int8: nccl.NCCL_INT8, + torch.int32: nccl.NCCL_INT32, + torch.int64: nccl.NCCL_INT64, + torch.long: 
nccl.NCCL_INT64, + # FLOAT types + torch.half: nccl.NCCL_HALF, + torch.float: nccl.NCCL_FLOAT, + torch.float16: nccl.NCCL_FLOAT16, + torch.float32: nccl.NCCL_FLOAT32, + torch.float64: nccl.NCCL_FLOAT64, + torch.double: nccl.NCCL_DOUBLE, + } + + # Older versions of cupy don't support bfloat16. + if hasattr(nccl, "NCCL_BFLOAT16"): + TORCH_NCCL_DTYPE_MAP[torch.bfloat16] = nccl.NCCL_BFLOAT16 + + TORCH_NUMPY_DTYPE_MAP = { + # INT types + torch.int: numpy.int32, + torch.uint8: numpy.uint8, + torch.int8: numpy.int8, + torch.int32: numpy.int32, + torch.int64: numpy.int64, + torch.long: numpy.int64, + # FLOAT types + torch.half: numpy.half, + torch.float: numpy.float32, + torch.float16: numpy.float16, + torch.float32: numpy.float32, + torch.float64: numpy.float64, + } + + +def get_num_gpus(): + """Returns the number of compute-capable GPUs.""" + return cupy.cuda.runtime.getDeviceCount() + + +def get_nccl_build_version(): + return get_build_version() + + +def get_nccl_runtime_version(): + return get_version() + + +def get_nccl_unique_id(): + return nccl.get_unique_id() + + +def create_nccl_communicator(world_size, nccl_unique_id, rank): + """Create an NCCL communicator using NCCL APIs. + + Args: + world_size: the number of processes of this communicator group. + nccl_unique_id: the NCCLUniqueID for this group. + rank: the rank of this process. + Returns: + comm (nccl.ncclComm_t): an NCCL communicator. + """ + comm = NcclCommunicator(world_size, nccl_unique_id, rank) + return comm + + +def get_nccl_reduce_op(reduce_op): + """Map the reduce op to NCCL reduce op type. + + Args: + reduce_op: ReduceOp Enum (SUM/PRODUCT/MIN/MAX). + Returns: + (nccl.ncclRedOp_t): the mapped NCCL reduce op. 
+ """ + if reduce_op not in NCCL_REDUCE_OP_MAP: + raise RuntimeError("NCCL does not support reduce op: '{}'.".format(reduce_op)) + return NCCL_REDUCE_OP_MAP[reduce_op] + + +def get_nccl_tensor_dtype(tensor): + """Return the corresponded NCCL dtype given a tensor.""" + if isinstance(tensor, cupy.ndarray): + return NUMPY_NCCL_DTYPE_MAP[tensor.dtype.type] + if torch_available(): + if isinstance(tensor, torch.Tensor): + return TORCH_NCCL_DTYPE_MAP[tensor.dtype] + raise ValueError( + "Unsupported tensor type. Got: {}. Supported " + "GPU tensor types are: torch.Tensor, " + "cupy.ndarray.".format(type(tensor)) + ) + + +def get_cupy_tensor_dtype(tensor): + """Return the corresponded Cupy dtype given a tensor.""" + if isinstance(tensor, cupy.ndarray): + return tensor.dtype.type + if torch_available(): + if isinstance(tensor, torch.Tensor): + return TORCH_NUMPY_DTYPE_MAP[tensor.dtype] + raise ValueError( + "Unsupported tensor type. Got: {}. Supported " + "GPU tensor types are: torch.Tensor, " + "cupy.ndarray.".format(type(tensor)) + ) + + +def get_tensor_ptr(tensor): + """Return the pointer to the underlying memory storage of a tensor.""" + if isinstance(tensor, cupy.ndarray): + return tensor.data.ptr + if isinstance(tensor, numpy.ndarray): + return tensor.data + if torch_available(): + if isinstance(tensor, torch.Tensor): + if not tensor.is_cuda: + raise RuntimeError( + "Torch tensor must be on GPU when using NCCL collectives." + ) + return tensor.data_ptr() + raise ValueError( + "Unsupported tensor type. Got: {}. Supported " + "GPU tensor types are: torch.Tensor, " + "cupy.ndarray.".format(type(tensor)) + ) + + +def get_tensor_n_elements(tensor): + """Return the number of elements in a tensor.""" + if isinstance(tensor, cupy.ndarray) or isinstance(tensor, numpy.ndarray): + return tensor.size + if torch_available(): + if isinstance(tensor, torch.Tensor): + return torch.numel(tensor) + raise ValueError( + "Unsupported tensor type. Got: {}. 
Supported " + "GPU tensor types are: torch.Tensor, " + "cupy.ndarray.".format(type(tensor)) + ) + + +def get_tensor_shape(tensor): + """Return the shape of the tensor as a list.""" + if isinstance(tensor, cupy.ndarray): + return list(tensor.shape) + if torch_available(): + if isinstance(tensor, torch.Tensor): + return list(tensor.size()) + raise ValueError( + "Unsupported tensor type. Got: {}. Supported " + "GPU tensor types are: torch.Tensor, " + "cupy.ndarray.".format(type(tensor)) + ) + + +def get_tensor_strides(tensor): + """Return the strides of the tensor as a list.""" + if isinstance(tensor, cupy.ndarray): + return [int(stride / tensor.dtype.itemsize) for stride in tensor.strides] + if torch_available(): + if isinstance(tensor, torch.Tensor): + return list(tensor.stride()) + raise ValueError( + "Unsupported tensor type. Got: {}. Supported " + "GPU tensor types are: torch.Tensor, " + "cupy.ndarray.".format(type(tensor)) + ) + + +def get_tensor_device(tensor): + """Return the GPU index of a tensor.""" + if isinstance(tensor, cupy.ndarray): + try: + device = tensor.device.id + except AttributeError as exec: + raise RuntimeError("The tensor is not on a valid GPU.") from exec + elif torch_available() and isinstance(tensor, torch.Tensor): + device = tensor.device.index + if not isinstance(device, int): + raise RuntimeError("The tensor is not on a valid GPU.") + else: + raise ValueError("Unsupported tensor type. Got: {}.".format(type(tensor))) + return device + + +def copy_tensor(dst_tensor, src_tensor): + """Copy the content from src_tensor to dst_tensor. + + Args: + dst_tensor: the tensor to copy from. + src_tensor: the tensor to copy to. 
+ + Returns: + None + """ + copied = True + if isinstance(dst_tensor, cupy.ndarray) and isinstance(src_tensor, cupy.ndarray): + cupy.copyto(dst_tensor, src_tensor) + elif torch_available(): + if isinstance(dst_tensor, torch.Tensor) and isinstance( + src_tensor, torch.Tensor + ): + dst_tensor.copy_(src_tensor) + elif isinstance(dst_tensor, torch.Tensor) and isinstance( + src_tensor, cupy.ndarray + ): + t = torch.utils.dlpack.from_dlpack(src_tensor.toDlpack()) + dst_tensor.copy_(t) + elif isinstance(dst_tensor, cupy.ndarray) and isinstance( + src_tensor, torch.Tensor + ): + t = cupy.fromDlpack(torch.utils.dlpack.to_dlpack(src_tensor)) + cupy.copyto(dst_tensor, t) + else: + copied = False + else: + copied = False + if not copied: + raise ValueError( + "Unsupported tensor type. Got: {} and {}. Supported " + "GPU tensor types are: torch.Tensor, cupy.ndarray.".format( + type(dst_tensor), type(src_tensor) + ) + ) + + +def get_tensor_device_list(tensors): + """Returns the gpu devices of the list of input tensors. + + Args: + tensors: a list of tensors, each locates on a GPU. + + Returns: + list: the list of GPU devices. + + """ + if not isinstance(tensors, list): + raise RuntimeError( + "Expect a list of tensors each locates on a GPU device. " + "Got: '{}'.".format(type(tensors)) + ) + devices = [get_tensor_device(t) for t in tensors] + return devices diff --git a/.venv/lib/python3.11/site-packages/ray/util/collective/const.py b/.venv/lib/python3.11/site-packages/ray/util/collective/const.py new file mode 100644 index 0000000000000000000000000000000000000000..35a11d23abbf4f3708acc3c76cb4f9dde93e10cc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/collective/const.py @@ -0,0 +1,34 @@ +""" +Constants. + +Contains constants used to setup collective groups. +""" +import hashlib +import os +from enum import Enum, auto + + +def get_store_name(group_name): + """Generate the unique name for the NCCLUniqueID store (named actor). 
+ + Args: + group_name: unique user name for the store. + Return: + str: SHA1-hexlified name for the store. + """ + if not group_name: + raise ValueError("group_name is None.") + hexlified_name = hashlib.sha1(group_name.encode()).hexdigest() + return hexlified_name + + +class ENV(Enum): + """ray.util.collective environment variables.""" + + NCCL_USE_MULTISTREAM = auto(), lambda v: (v or "True") == "True" + + @property + def val(self): + """Return the output of the lambda against the system's env value.""" + _, default_fn = self.value + return default_fn(os.getenv(self.name)) diff --git a/.venv/lib/python3.11/site-packages/ray/util/collective/types.py b/.venv/lib/python3.11/site-packages/ray/util/collective/types.py new file mode 100644 index 0000000000000000000000000000000000000000..e8c1730b3d610186258c04a30a9df59e8c17e583 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/collective/types.py @@ -0,0 +1,115 @@ +"""Types conversion between different backends.""" +from enum import Enum +from dataclasses import dataclass +from datetime import timedelta + +_NUMPY_AVAILABLE = True +_TORCH_AVAILABLE = True +_CUPY_AVAILABLE = True + +try: + import torch as th # noqa: F401 +except ImportError: + _TORCH_AVAILABLE = False + +try: + import cupy as cp # noqa: F401 +except ImportError: + _CUPY_AVAILABLE = False + + +def cupy_available(): + return _CUPY_AVAILABLE + + +def torch_available(): + return _TORCH_AVAILABLE + + +class Backend(object): + """A class to represent different backends.""" + + NCCL = "nccl" + MPI = "mpi" + GLOO = "gloo" + UNRECOGNIZED = "unrecognized" + + def __new__(cls, name: str): + backend = getattr(Backend, name.upper(), Backend.UNRECOGNIZED) + if backend == Backend.UNRECOGNIZED: + raise ValueError( + "Unrecognized backend: '{}'. 
Only NCCL is supported".format(name) + ) + if backend == Backend.MPI: + raise RuntimeError("Ray does not support MPI backend.") + return backend + + +class ReduceOp(Enum): + SUM = 0 + PRODUCT = 1 + MIN = 2 + MAX = 3 + + +unset_timeout_ms = timedelta(milliseconds=-1) + + +@dataclass +class AllReduceOptions: + reduceOp = ReduceOp.SUM + timeout_ms = unset_timeout_ms + + +@dataclass +class BarrierOptions: + timeout_ms = unset_timeout_ms + + +@dataclass +class ReduceOptions: + reduceOp = ReduceOp.SUM + root_rank = 0 + root_tensor = 0 # index for multi-gpu reduce operations + timeout_ms = unset_timeout_ms + + +@dataclass +class AllGatherOptions: + timeout_ms = unset_timeout_ms + + +# +# @dataclass +# class GatherOptions: +# root_rank = 0 +# timeout = unset_timeout + + +@dataclass +class BroadcastOptions: + root_rank = 0 + root_tensor = 0 + timeout_ms = unset_timeout_ms + + +@dataclass +class ReduceScatterOptions: + reduceOp = ReduceOp.SUM + timeout_ms = unset_timeout_ms + + +@dataclass +class SendOptions: + dst_rank = 0 + dst_gpu_index = 0 + n_elements = 0 + timeout_ms = unset_timeout_ms + + +@dataclass +class RecvOptions: + src_rank = 0 + src_gpu_index = 0 + n_elements = 0 + unset_timeout_ms = unset_timeout_ms diff --git a/.venv/lib/python3.11/site-packages/ray/util/collective/util.py b/.venv/lib/python3.11/site-packages/ray/util/collective/util.py new file mode 100644 index 0000000000000000000000000000000000000000..c162c89779b1ceb08b412b4b1d379d91099fd1fc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/collective/util.py @@ -0,0 +1,68 @@ +"""Some utility class for Collectives.""" +import ray +import logging + +logger = logging.getLogger(__name__) + + +@ray.remote +class NCCLUniqueIDStore: + """NCCLUniqueID Store as a named actor class. + + Args: + name: the unique name for this named actor. + + Attributes: + name: the unique name for this named actor. + nccl_id: the NCCLUniqueID held in this store. 
+ """ + + def __init__(self, name): + self.name = name + self.nccl_id = None + + def set_id(self, uid): + """ + Initialize the NCCL unique ID for this store. + + Args: + uid: the unique ID generated via the NCCL get_unique_id API. + + Returns: + None + """ + self.nccl_id = uid + return self.nccl_id + + def get_id(self): + """Get the NCCL unique ID held in this store.""" + if not self.nccl_id: + logger.warning( + "The NCCL ID has not been set yet for store {}.".format(self.name) + ) + return self.nccl_id + + +@ray.remote +class Info: + """Store the group information created via `create_collective_group`. + + Note: Should be used as a NamedActor. + """ + + def __init__(self): + self.ids = None + self.world_size = -1 + self.rank = -1 + self.backend = None + + def set_info(self, ids, world_size, rank, backend): + """Store collective information.""" + self.ids = ids + self.world_size = world_size + self.rank = rank + self.backend = backend + + def get_info(self): + """Get previously stored collective information.""" + return self.ids, self.world_size, self.rank, self.backend diff --git a/.venv/lib/python3.11/site-packages/ray/util/horovod/__init__.py b/.venv/lib/python3.11/site-packages/ray/util/horovod/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a08001abc7961a19474b54b31bf89be0757df6cd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/horovod/__init__.py @@ -0,0 +1,4 @@ +raise DeprecationWarning( + "ray.util.horovod has been removed as of Ray 2.0. Instead, use the `horovod` " + "library directly or the `HorovodTrainer` in Ray Train." 
+) diff --git a/.venv/lib/python3.11/site-packages/ray/util/horovod/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/horovod/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c33397e0adc810ce6adb6ab73ee07f262de30ba Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/horovod/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/joblib/__init__.py b/.venv/lib/python3.11/site-packages/ray/util/joblib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1b02709aa53b5ba1d52709b6f80334b8e2308594 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/joblib/__init__.py @@ -0,0 +1,20 @@ +from joblib.parallel import register_parallel_backend + + +def register_ray(): + """Register Ray Backend to be called with parallel_backend("ray").""" + try: + from ray.util.joblib.ray_backend import RayBackend + + register_parallel_backend("ray", RayBackend) + except ImportError: + msg = ( + "To use the ray backend you must install ray." + "Try running 'pip install ray'." + "See https://docs.ray.io/en/master/installation.html" + "for more information." 
+ ) + raise ImportError(msg) + + +__all__ = ["register_ray"] diff --git a/.venv/lib/python3.11/site-packages/ray/util/joblib/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/joblib/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcbcc4989a850350dcf6f51bdb3d512d0c56edf3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/joblib/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/joblib/__pycache__/ray_backend.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/joblib/__pycache__/ray_backend.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72522081383b48e72e543b497a222f6f57f25e51 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/joblib/__pycache__/ray_backend.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/joblib/ray_backend.py b/.venv/lib/python3.11/site-packages/ray/util/joblib/ray_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..b5a6eda4daa6354c6ba36a97c861880f424c67e6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/joblib/ray_backend.py @@ -0,0 +1,96 @@ +import logging +from typing import Any, Dict, Optional + +from joblib import Parallel +from joblib._parallel_backends import MultiprocessingBackend +from joblib.pool import PicklingPool + +import ray +from ray._private.usage import usage_lib +from ray.util.multiprocessing.pool import Pool + +logger = logging.getLogger(__name__) + + +class RayBackend(MultiprocessingBackend): + """Ray backend uses ray, a system for scalable distributed computing. + More info about Ray is available here: https://docs.ray.io. 
+ """ + + def __init__( + self, + nesting_level: Optional[int] = None, + inner_max_num_threads: Optional[int] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, + **kwargs + ): + """``ray_remote_args`` will be used to configure Ray Actors + making up the pool.""" + usage_lib.record_library_usage("util.joblib") + + self.ray_remote_args = ray_remote_args + super().__init__( + nesting_level=nesting_level, + inner_max_num_threads=inner_max_num_threads, + **kwargs + ) + + # ray_remote_args is used both in __init__ and configure to allow for it to be + # set in both `parallel_backend` and `Parallel` respectively + + def configure( + self, + n_jobs: int = 1, + parallel: Optional[Parallel] = None, + prefer: Optional[str] = None, + require: Optional[str] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, + **memmappingpool_args + ): + """Make Ray Pool the father class of PicklingPool. PicklingPool is a + father class that inherits Pool from multiprocessing.pool. The next + line is a patch, which changes the inheritance of Pool to be from + ray.util.multiprocessing.pool. + + ``ray_remote_args`` will be used to configure Ray Actors making up the pool. + This will override ``ray_remote_args`` set during initialization. + """ + PicklingPool.__bases__ = (Pool,) + """Use all available resources when n_jobs == -1. Must set RAY_ADDRESS + variable in the environment or run ray.init(address=..) to run on + multiple nodes. 
+ """ + if n_jobs == -1: + if not ray.is_initialized(): + import os + + if "RAY_ADDRESS" in os.environ: + logger.info( + "Connecting to ray cluster at address='{}'".format( + os.environ["RAY_ADDRESS"] + ) + ) + else: + logger.info("Starting local ray cluster") + ray.init() + ray_cpus = int(ray._private.state.cluster_resources()["CPU"]) + n_jobs = ray_cpus + + eff_n_jobs = super(RayBackend, self).configure( + n_jobs, + parallel, + prefer, + require, + ray_remote_args=ray_remote_args + if ray_remote_args is not None + else self.ray_remote_args, + **memmappingpool_args + ) + return eff_n_jobs + + def effective_n_jobs(self, n_jobs): + eff_n_jobs = super(RayBackend, self).effective_n_jobs(n_jobs) + if n_jobs == -1: + ray_cpus = int(ray._private.state.cluster_resources()["CPU"]) + eff_n_jobs = ray_cpus + return eff_n_jobs