braindeck
Initial commit
bcdf9fa
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections.abc import AsyncGenerator
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import cloudpickle
import ray
from omegaconf import DictConfig
from starlette.requests import Request
from starlette.responses import JSONResponse, StreamingResponse
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ChatCompletionResponse, ErrorResponse
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.executor.abstract import Executor
from vllm.worker.worker_base import WorkerWrapperBase
from verl.utils.fs import copy_to_local
from verl.workers.rollout.async_server import AsyncServerBase
# Module-level logger, named after this file's path (``__file__``) rather than the dotted module name.
logger = logging.getLogger(name=__file__)
class ExternalRayDistributedExecutor(Executor):
    """An executor that engines are launched by external ray actors.

    Instead of spawning its own workers, this executor attaches to named Ray
    actors that already exist in the namespace encoded in
    ``vllm_config.instance_id``, and drives them through ``collective_rpc``.
    """

    uses_ray: bool = False

    def _init_executor(self) -> None:
        """Attach to the external worker actors and initialize them.

        ``vllm_config.instance_id`` must have the form
        ``<namespace>:<wg_prefix>:<vllm_dp_size>:<vllm_dp_rank>``.
        The executor does the following:
        - joins ``namespace``;
        - selects the ``tensor_parallel_size`` actors that belong to ``vllm_dp_rank``;
        - runs ``init_worker``, ``init_device`` and ``load_model`` on them, in that order.
        """
        assert self.vllm_config.instance_id is not None, "instance_id must be set for external ray actors."
        fields = self.vllm_config.instance_id.split(":")
        assert len(fields) == 4, f"instance_id: {self.vllm_config.instance_id} must be in the format of <namespace>:<wg_prefix>:<vllm_dp_size>:<vllm_dp_rank>."
        namespace, wg_prefix, vllm_dp_size, vllm_dp_rank = fields[0], fields[1], int(fields[2]), int(fields[3])

        # Make sure subprocess in same namespace as parent actor.
        ray.init(namespace=namespace)
        actor_names = self._select_external_actor_names(wg_prefix, vllm_dp_size, vllm_dp_rank)
        self.workers: List[WorkerWrapperBase] = [ray.get_actor(actor_name) for actor_name in actor_names]
        print(f"instance_id: {self.vllm_config.instance_id} intializes with external actors: {actor_names}")

        # init_worker receives a one-element list holding these kwargs.
        # NOTE(review): rank/local_rank are None here, presumably resolved on
        # the worker side — confirm against the worker implementation.
        kwargs = dict(
            vllm_config=self.vllm_config,
            local_rank=None,
            rank=None,
            distributed_init_method="env://",
            is_driver_worker=True,
        )
        self.collective_rpc("init_worker", args=([kwargs],))
        self.collective_rpc("init_device")
        self.collective_rpc("load_model")
        print(f"instance_id: {self.vllm_config.instance_id} intializes finished.")

    def _select_external_actor_names(self, wg_prefix: str, vllm_dp_size: int, vllm_dp_rank: int) -> List[str]:
        """Return the names of the actors that make up this DP rank's TP group.

        Actor name format: ``{wg_prefix}WorkerDict_{pg_idx}:{local_rank}``.
        All matching named actors are collected and their count is asserted to be
        ``vllm_dp_size * tensor_parallel_size``. They are then sorted by
        ``(pg_idx, local_rank)``. The contiguous slice of ``tensor_parallel_size``
        names for ``vllm_dp_rank`` is returned.
        """
        prefix = f"{wg_prefix}WorkerDict"
        actor_names = [actor_name for actor_name in ray.util.list_named_actors() if actor_name.startswith(prefix)]
        vllm_tp_size = self.vllm_config.parallel_config.tensor_parallel_size
        assert len(actor_names) == vllm_dp_size * vllm_tp_size, f"instance_id: {self.vllm_config.instance_id} has {len(actor_names)} actors, but vllm_dp_size: {vllm_dp_size} * vllm_tp_size: {vllm_tp_size} = {vllm_dp_size * vllm_tp_size} is expected."

        # sort actor names by pg_index and local_rank
        actor_names = sorted(actor_names, key=self._actor_sort_key)
        return actor_names[vllm_dp_rank * vllm_tp_size : (vllm_dp_rank + 1) * vllm_tp_size]

    @staticmethod
    def _actor_sort_key(actor_name: str) -> Tuple[int, int]:
        """Parse ``<prefix>_{pg_idx}:{local_rank}`` into ``(pg_idx, local_rank)``."""
        fields = actor_name.split(":")
        assert len(fields) == 2, f"invalid actor name: {actor_name}"
        return int(fields[0].split("_")[-1]), int(fields[1])

    def collective_rpc(
        self,
        method: Union[str, Callable],
        timeout: Optional[float] = None,
        args: Tuple = (),
        kwargs: Optional[Dict[str, Any]] = None,
    ) -> List[Any]:
        """Run ``method`` on every worker actor and gather the results.

        Args:
            method: Name of a worker method, or a callable. A callable is
                serialized with cloudpickle before being sent.
            timeout: Seconds to wait for all workers. ``None`` waits indefinitely.
                The value is forwarded to ``ray.get``, which raises
                ``ray.exceptions.GetTimeoutError`` when results are not ready in time.
            args: Positional arguments forwarded to each worker's ``execute_method``.
            kwargs: Keyword arguments forwarded to each worker's ``execute_method``.

        Returns:
            One result per worker, in the same order as ``self.workers``.
        """
        # TODO(wuxibin): support ray compiled graph
        sent_method = method if isinstance(method, str) else cloudpickle.dumps(method)
        del method
        call_kwargs = kwargs or {}
        refs = [worker.execute_method.remote(sent_method, *args, **call_kwargs) for worker in self.workers]
        # Previously `timeout` was accepted but ignored; honor it now.
        return ray.get(refs, timeout=timeout)

    def check_health(self):
        """No-op health check: the external actors are never probed, so it always passes."""
        return
@ray.remote(num_cpus=1)
class AsyncvLLMServer(AsyncServerBase):
    """
    AsyncvLLMServer is a wrapper for AsyncLLM, it uses ExternalRayDistributedExecutor to launch engines
    in hybrid rollout workers, i.e AsyncActorRolloutRefWorker.

    AsyncvLLMServer works as follows:
    1. Start FastAPI server first.
    2. Initialize AsyncLLM with ExternalRayDistributedExecutor.
    3. AsyncLLM spawn EngineCore in subprocess.
    4. EngineCore initialize ExternalRayDistributedExecutor.
    5. ExternalRayDistributedExecutor lookup its corresponding actors by name.
    6. ExternalRayDistributedExecutor init executor: init_worker, init_device, load_model.

    For vLLM AsyncLLM design, see: https://github.com/vllm-project/vllm/pull/9826
    """

    def __init__(self, config: DictConfig, vllm_dp_size: int, vllm_dp_rank: int, wg_prefix: str):
        """
        Args:
            config: DictConfig, actor_rollout_ref config.
            vllm_dp_size: int, vllm data parallel size.
            vllm_dp_rank: int, vllm data parallel rank.
            wg_prefix: str, worker group prefix, used to lookup actors.
        """
        super().__init__()
        self.config = config
        self.vllm_dp_size = vllm_dp_size
        self.vllm_dp_rank = vllm_dp_rank
        self.wg_prefix = wg_prefix
        # Both are created by init_engine(); the chat endpoints require them.
        self.engine: Optional[AsyncLLM] = None
        self.openai_serving_chat: Optional[OpenAIServingChat] = None

    async def init_engine(self):
        """Init vLLM AsyncLLM engine and the OpenAI chat-serving layer on top of it.

        The method reads ``config.model`` (path, trust_remote_code) and ``config.rollout``
        (parallelism, lengths, memory and sampling settings). It then builds an
        ``AsyncLLM`` whose executor attaches to this worker group's actors through
        ``ExternalRayDistributedExecutor``.
        """
        config = self.config
        model_path = config.model.path
        # Served model name: the last two path components, e.g. "org/model".
        model_name = "/".join(model_path.split("/")[-2:])
        local_path = copy_to_local(model_path)
        trust_remote_code = config.model.get("trust_remote_code", False)
        config = config.rollout

        tensor_parallel_size = config.get("tensor_model_parallel_size", 1)
        max_num_batched_tokens = config.get("max_num_batched_tokens", 8192)
        max_model_len = config.max_model_len if config.max_model_len else config.prompt_length + config.response_length
        max_model_len = int(max_model_len)

        # Override default generation config from hugging face model config,
        # user can still override them by passing kwargs in each request.
        kwargs = dict(
            n=1,
            logprobs=0,
            max_tokens=config.response_length,
        )
        # Any rollout config key that is also a SamplingParams attribute overrides the
        # defaults above. The probe instance is built once instead of once per key.
        default_sampling_params = SamplingParams()
        for k in config.keys():
            if hasattr(default_sampling_params, str(k)):
                kwargs[k] = config.get(k)
        print(f"override_generation_config: {kwargs}")

        engine_args = AsyncEngineArgs(
            model=local_path,
            enable_sleep_mode=True,
            override_generation_config=kwargs,
            tensor_parallel_size=tensor_parallel_size,
            distributed_executor_backend=ExternalRayDistributedExecutor,
            dtype=config.dtype,
            enforce_eager=config.enforce_eager,
            gpu_memory_utilization=config.gpu_memory_utilization,
            disable_custom_all_reduce=True,
            disable_mm_preprocessor_cache=True,
            skip_tokenizer_init=False,
            max_model_len=max_model_len,
            load_format="auto",
            disable_log_stats=config.disable_log_stats,
            max_num_batched_tokens=max_num_batched_tokens,
            enable_chunked_prefill=config.enable_chunked_prefill,
            enable_prefix_caching=True,
            trust_remote_code=trust_remote_code,
            seed=self.vllm_dp_rank,
        )

        # init async llm engine
        vllm_config = engine_args.create_engine_config()
        namespace = ray.get_runtime_context().namespace
        # ExternalRayDistributedExecutor parses this id to find its worker actors.
        vllm_config.instance_id = f"{namespace}:{self.wg_prefix}:{self.vllm_dp_size}:{self.vllm_dp_rank}"
        self.engine = AsyncLLM.from_vllm_config(vllm_config)

        # build serving chat
        model_config = self.engine.model_config
        BASE_MODEL_PATHS = [BaseModelPath(name=model_name, model_path=model_path)]
        models = OpenAIServingModels(self.engine, model_config, BASE_MODEL_PATHS)
        self.openai_serving_chat = OpenAIServingChat(
            self.engine,
            model_config,
            models,
            "assistant",
            request_logger=RequestLogger(max_log_len=4096),
            chat_template=None,
            chat_template_content_format="auto",
        )

    async def chat_completion(self, raw_request: Request):
        """OpenAI-compatible HTTP endpoint.

        API reference: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html

        Returns one of the following:
        - a JSONResponse carrying the error body and its status code, when vLLM
          reports an error;
        - a StreamingResponse (``text/event-stream``) when ``request.stream`` is set;
        - otherwise, a JSONResponse with the full completion.
        """
        request_json = await raw_request.json()
        request = ChatCompletionRequest(**request_json)
        generator = await self.openai_serving_chat.create_chat_completion(request, raw_request)

        if isinstance(generator, ErrorResponse):
            return JSONResponse(content=generator.model_dump(), status_code=generator.code)
        if request.stream:
            return StreamingResponse(content=generator, media_type="text/event-stream")
        else:
            assert isinstance(generator, ChatCompletionResponse)
            return JSONResponse(content=generator.model_dump())

    async def chat_completion_generator(self, request: ChatCompletionRequest) -> AsyncGenerator[Tuple[int, str], None]:
        """Direct chat completion without FastAPI.

        Args:
            request: ChatCompletionRequest, request object.

        Returns:
            AsyncGenerator[Tuple[int, str], None]: async generator of (status_code, data) pairs.
            On error, a single ``(error.code, "data: <json>\\n\\n")`` pair is yielded
            and the generator stops.
        """
        generator = await self.openai_serving_chat.create_chat_completion(request)
        if isinstance(generator, ErrorResponse):
            data = generator.model_dump_json(exclude_unset=True)
            yield generator.code, f"data: {data}\n\n"
            # Stop here: without this, the ErrorResponse fell through to
            # `async for` (TypeError) or the ChatCompletionResponse assert.
            return

        if request.stream:
            async for chunk in generator:
                yield 200, chunk
        else:
            assert isinstance(generator, ChatCompletionResponse)
            data = generator.model_dump_json(exclude_unset=True)
            yield 200, f"data: {data}\n\n"

    async def wake_up(self):
        """Resume the engine after sleep() (engine created with enable_sleep_mode=True)."""
        await self.engine.wake_up()

    async def sleep(self):
        """Reset the prefix cache, then put the engine to sleep."""
        # TODO: https://github.com/vllm-project/vllm/issues/17103
        await self.engine.reset_prefix_cache()
        await self.engine.sleep()