|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """
|
| The vllm_rollout that can be applied with different backends.
|
| When working with FSDP:
|
| - Use the DTensor weight loader (recommended) or the HF weight loader.
|
| - Use the state_dict from FSDP to synchronize the weights among the tp ranks in vLLM.
|
| When working with Megatron:
|
| - Use the Megatron weight loader.
|
| - During training, only the current pp stage holds the parameters.
|
| - Before inference, broadcast the parameters of the current pp rank
|
| to all other pp ranks (so that every pp rank holds all the parameters).
|
| - Bind the parameters to the inference engine.
|
| - Run inference in tp; pp is treated as additional dp.
|
| - After inference, free all the parameters that do not belong to this pp rank.
|
| """
|
|
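| # A rough, illustrative sketch of the Megatron-side flow described above (the function
| # names below are hypothetical, not the actual verl / vLLM API):
| #
| #   params = gather_params_of_current_pp_stage(actor_module)    # only this pp rank's shard
| #   broadcast_to_all_pp_ranks(params)                           # every pp rank now holds all parameters
| #   bind_weights_to_inference_engine(inference_engine, params)  # load weights into vLLM
| #   outputs = inference_engine.generate(prompts)                # tp inference; pp acts as extra dp
| #   free_params_not_owned_by_this_pp_rank(params)               # release memory after inference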
|
| import logging
|
| import os
|
| from contextlib import contextmanager
|
| from typing import Any, Dict, List, Union
|
|
|
| import numpy as np
|
| import torch
|
| import torch.distributed
|
| from omegaconf import DictConfig
|
| from tensordict import TensorDict
|
| from vllm import LLM, SamplingParams
|
| from vllm.distributed import parallel_state as vllm_ps
|
| from vllm.worker.worker_base import WorkerWrapperBase
|
|
|
| from verl import DataProto
|
| from verl.third_party.vllm import vllm_version
|
| from verl.utils.debug import GPUMemoryLogger
|
| from verl.utils.torch_functional import get_response_mask, pad_2d_list_to_length
|
| from verl.workers.rollout.base import BaseRollout
|
|
|
| logger = logging.getLogger(__file__)
|
| logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[int]:
|
| """Strip left padding from a single prompt and return the remaining token ids as a list."""
|
| # index of the first non-pad token; everything before it is left padding
|
| non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
|
| token_ids = prompt_token_ids[non_pad_index:].tolist()
|
| return token_ids
|
|
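| # Example (illustrative values): with pad_token_id=0 and
| # prompt_token_ids = torch.tensor([0, 0, 5, 6, 7]), this returns [5, 6, 7].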
|
|
|
| def _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: int) -> Union[torch.Tensor, List[Any]]:
|
| if isinstance(value, torch.Tensor):
|
| return value.repeat_interleave(repeats, dim=0)
|
| else:
|
| return np.repeat(value, repeats, axis=0)
|
|
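| # Example: _repeat_interleave(torch.tensor([[1], [2]]), 2) -> tensor([[1], [1], [2], [2]]);
| # numpy inputs are repeated along axis 0 in the same way.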
|
|
|
| class vLLMRollout(BaseRollout):
|
| def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs):
|
| """A vLLM rollout. It requires the module is supported by the vllm.
|
|
|
| Args:
|
| module: module here follows huggingface APIs
|
| config: DictConfig
|
| tokenizer: the task/model tokenizer
|
| model_hf_config: the huggingface config to initiallize the generating model in vllm
|
| **kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group
|
| """
|
| super().__init__()
|
| self.config = config
|
| assert not (not config.enforce_eager and config.free_cache_engine), "CUDA graphs must be disabled (enforce_eager=True) when free_cache_engine is enabled"
|
|
|
| tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1)
|
| assert tensor_parallel_size <= torch.distributed.get_world_size(), "tensor parallel size should be less than or equal to the world size"
|
| max_num_batched_tokens = self.config.get("max_num_batched_tokens", 8192)
|
|
|
| if kwargs.get("train_tp") is not None:
|
|
|
| import os
|
|
|
| os.environ["CUDA_TIMER_STREAM_KAFKA_ENABLE"] = "0"
|
| os.environ["MEGATRON_IMPORT_TIMERS"] = "0"
|
| if vllm_version in (
|
| "0.5.4",
|
| "0.6.3",
|
| ):
|
| train_tp = kwargs.get("train_tp")
|
| num_tp_per_train_tp = train_tp // tensor_parallel_size
|
| vllm_ps.initialize_parallel_state(tensor_model_parallel_size=tensor_parallel_size, num_tp_per_train_tp=num_tp_per_train_tp)
|
| else:
|
| vllm_ps.initialize_model_parallel(tensor_model_parallel_size=tensor_parallel_size)
|
|
|
| # when rope scaling is configured, max_position_embeddings alone does not determine the usable context window
|
| rope_scaling_config = getattr(model_hf_config, "rope_scaling", None)
|
| if not rope_scaling_config:
|
| assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, "model context length should be at least prompt_length + response_length"
|
|
|
| max_model_len = int(config.max_model_len or config.prompt_length + config.response_length)
|
|
|
| if max_num_batched_tokens < max_model_len and self.config.enable_chunked_prefill:
|
| raise ValueError(
|
| "Chunked prefill is enabled but max_num_batched_tokens is smaller than max_model_len; please increase max_num_batched_tokens or disable chunked prefill"
|
| )
|
|
|
| trust_remote_code = kwargs.get("trust_remote_code", False)
|
| load_format = "dummy" if config.load_format.startswith("dummy") else config.load_format
|
|
|
| limit_mm_per_prompt = None
|
| if config.get("limit_images", None):
|
| limit_mm_per_prompt = {"image": config.get("limit_images")}
|
|
|
| self.inference_engine = LLM(
|
| model=model_path,
|
| enable_sleep_mode=True,
|
| tensor_parallel_size=tensor_parallel_size,
|
| distributed_executor_backend="external_launcher",
|
| dtype=config.dtype,
|
| enforce_eager=config.enforce_eager,
|
| gpu_memory_utilization=config.gpu_memory_utilization,
|
| disable_custom_all_reduce=True,
|
| disable_mm_preprocessor_cache=True,
|
| limit_mm_per_prompt=limit_mm_per_prompt,
|
| skip_tokenizer_init=False,
|
| max_model_len=max_model_len,
|
| load_format=load_format,
|
| disable_log_stats=config.disable_log_stats,
|
| max_num_batched_tokens=max_num_batched_tokens,
|
| enable_chunked_prefill=config.enable_chunked_prefill,
|
| enable_prefix_caching=True,
|
| trust_remote_code=trust_remote_code,
|
| seed=config.get("seed", 0),
|
| )
|
|
|
|
|
| self.inference_engine.sleep(level=1)
|
|
|
| kwargs = dict(
|
| n=1,
|
| logprobs=0,
|
| max_tokens=config.response_length,
|
| )
|
|
|
|
|
| if vllm_version != "0.3.1":
|
| kwargs["detokenize"] = False
|
|
|
|
|
| for k in config.keys():
|
| if hasattr(SamplingParams(), str(k)):
|
| kwargs[k] = config.get(k)
|
|
|
| print(f"kwargs: {kwargs}")
|
| self.sampling_params = SamplingParams(**kwargs)
|
|
|
| self.pad_token_id = tokenizer.pad_token_id
|
|
|
| @contextmanager
|
| def update_sampling_params(self, **kwargs):
|
| """Temporarily override sampling params inside the with-block, restoring the old values on exit."""
|
| old_sampling_params_args = {}
|
| if kwargs:
|
| for key, value in kwargs.items():
|
| if hasattr(self.sampling_params, key):
|
| old_value = getattr(self.sampling_params, key)
|
| old_sampling_params_args[key] = old_value
|
| setattr(self.sampling_params, key, value)
|
| yield
|
|
|
|
|
| for key, value in old_sampling_params_args.items():
|
| setattr(self.sampling_params, key, value)
|
|
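| # Typical usage (mirroring generate_sequences below): temporarily override sampling params
| # for a single generate call and restore them afterwards, e.g.
| #   with self.update_sampling_params(temperature=0.0, n=1):
| #       outputs = self.inference_engine.generate(...)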
|
| @GPUMemoryLogger(role="vllm rollout spmd", logger=logger)
|
| @torch.no_grad()
|
| def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
|
| """Generate one or more responses per prompt and return prompts, responses and masks as a DataProto."""
|
|
|
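| # for vLLM 0.5.4 / 0.6.3 the KV-cache engine is freed after each call, so rebuild it here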
| if (
|
| vllm_version
|
| in (
|
| "0.5.4",
|
| "0.6.3",
|
| )
|
| and self.config.free_cache_engine
|
| ):
|
| self.inference_engine.init_cache_engine()
|
|
|
| idx = prompts.batch["input_ids"]
|
|
|
| attention_mask = prompts.batch["attention_mask"]  # (batch_size, prompt_length)
|
| position_ids = prompts.batch["position_ids"]  # (batch_size, prompt_length), or (batch_size, 3, prompt_length) for mrope
|
|
|
|
|
| eos_token_id = prompts.meta_info["eos_token_id"]
|
|
|
| batch_size = idx.size(0)
|
|
|
| non_tensor_batch = prompts.non_tensor_batch
|
| if "raw_prompt_ids" not in non_tensor_batch:
|
| non_tensor_batch["raw_prompt_ids"] = np.array([_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object)
|
|
|
| if batch_size != len(non_tensor_batch["raw_prompt_ids"]):
|
| raise RuntimeError("vllm sharding manager is not work properly.")
|
|
|
| if "multi_modal_data" in non_tensor_batch:
|
| vllm_inputs = []
|
| for raw_prompt_ids, multi_modal_data in zip(non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data")):
|
| vllm_inputs.append({"prompt_token_ids": raw_prompt_ids, "multi_modal_data": multi_modal_data})
|
| else:
|
| vllm_inputs = [{"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")]
|
|
|
|
|
|
|
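| # make sure prompt_token_ids handed to vLLM is a plain list[int]; numpy object arrays can
| # sneak in through non_tensor_batch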
| for input_data in vllm_inputs:
|
| if isinstance(input_data["prompt_token_ids"], np.ndarray):
|
| input_data["prompt_token_ids"] = input_data["prompt_token_ids"].tolist()
|
| elif not isinstance(input_data["prompt_token_ids"], list):
|
| raise TypeError(f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}")
|
|
|
| do_sample = prompts.meta_info.get("do_sample", True)
|
| is_validate = prompts.meta_info.get("validate", False)
|
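| # greedy decoding when sampling is disabled; validation uses the dedicated val_kwargs overrides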
| if not do_sample:
|
| kwargs = {
|
| "best_of": 1,
|
| "top_p": 1.0,
|
| "top_k": -1,
|
| "min_p": 0.0,
|
| "temperature": 0,
|
| "n": 1,
|
| }
|
| elif is_validate:
|
|
|
| kwargs = {
|
| "top_k": self.config.val_kwargs.top_k,
|
| "top_p": self.config.val_kwargs.top_p,
|
| "temperature": self.config.val_kwargs.temperature,
|
| "n": 1,
|
| }
|
|
|
|
|
| with self.update_sampling_params(**kwargs):
|
| outputs = self.inference_engine.generate(
|
| prompts=vllm_inputs,
|
| sampling_params=self.sampling_params,
|
| use_tqdm=False,
|
| )
|
|
|
|
|
|
|
|
|
| response = []
|
| for output in outputs:
|
| for sample_id in range(len(output.outputs)):
|
| response.append(output.outputs[sample_id].token_ids)
|
|
|
| response = pad_2d_list_to_length(response, self.pad_token_id, max_length=self.config.response_length).to(idx.device)
|
|
|
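| # with n > 1 samples per prompt, repeat the prompt-side tensors so they line up with the responses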
| if self.sampling_params.n > 1 and do_sample:
|
| idx = _repeat_interleave(idx, self.sampling_params.n)
|
| attention_mask = _repeat_interleave(attention_mask, self.sampling_params.n)
|
| position_ids = _repeat_interleave(position_ids, self.sampling_params.n)
|
| batch_size = batch_size * self.sampling_params.n
|
| if "multi_modal_inputs" in non_tensor_batch.keys():
|
| non_tensor_batch["multi_modal_inputs"] = _repeat_interleave(non_tensor_batch["multi_modal_inputs"], self.sampling_params.n)
|
|
|
| if "tools_kwargs" in non_tensor_batch.keys():
|
| non_tensor_batch["tools_kwargs"] = _repeat_interleave(non_tensor_batch["tools_kwargs"], self.sampling_params.n)
|
|
|
| seq = torch.cat([idx, response], dim=-1)
|
|
|
| response_length = response.size(1)
|
| delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
|
| delta_position_id = delta_position_id.unsqueeze(0).expand(batch_size, -1)
|
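| # 3-D position_ids, e.g. mrope used by multimodal models such as Qwen2-VL, carry an extra dim of size 3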
| if position_ids.dim() == 3:
|
| delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)
|
|
|
|
|
|
|
|
|
|
|
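| # response position ids continue from the last prompt position; the response attention mask zeroes out tokens after the eos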
| response_position_ids = position_ids[..., -1:] + delta_position_id
|
| position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
|
| response_attention_mask = get_response_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
|
| attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
|
|
|
|
|
| batch = TensorDict(
|
| {
|
| "prompts": idx,
|
| "responses": response,
|
| "input_ids": seq,
|
|
|
| "attention_mask": attention_mask,
|
| "position_ids": position_ids,
|
| },
|
| batch_size=batch_size,
|
| )
|
|
|
|
|
| if (
|
| vllm_version
|
| in (
|
| "0.5.4",
|
| "0.6.3",
|
| )
|
| and self.config.free_cache_engine
|
| ):
|
| self.inference_engine.free_cache_engine()
|
|
|
| return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
|
|
|
|
|
| class vLLMAsyncRollout:
|
| """vLLMAsyncRollout is a thin wrapper of WorkerWrapperBase,
|
| which is engine in single worker process.
|
| """
|
|
|
| def __init__(self, *args, **kwargs):
|
|
|
| self.inference_engine: WorkerWrapperBase = None
|
| self.sharding_manager = None
|
| self.is_sleep = False
|
|
|
| def init_worker(self, all_kwargs: List[Dict[str, Any]]):
|
| """Initialize worker engine."""
|
| all_kwargs[0]["rank"] = int(os.environ["RANK"])
|
| all_kwargs[0]["local_rank"] = 0
|
|
|
| self.vllm_config = all_kwargs[0]["vllm_config"]
|
| self.inference_engine = WorkerWrapperBase(vllm_config=self.vllm_config)
|
| self.inference_engine.init_worker(all_kwargs)
|
|
|
| def load_model(self, *args, **kwargs):
|
| self.inference_engine.load_model(*args, **kwargs)
|
|
|
|
|
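| # the inference engine is ready now; point the sharding manager (assigned externally) at the
| # fresh engine and its model runner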
| self.sharding_manager.inference_engine = self.inference_engine
|
| self.sharding_manager.model_runner = self.inference_engine.worker.model_runner
|
|
|
| def sleep(self, *args, **kwargs):
|
| """Offload model weights and discard kv cache."""
|
| if self.is_sleep:
|
| return
|
| self.sharding_manager.__exit__(None, None, None)
|
| self.is_sleep = True
|
|
|
| def wake_up(self, *args, **kwargs):
|
| """Load model weights and build kv cache."""
|
| if not self.is_sleep:
|
| return
|
| self.sharding_manager.__enter__()
|
| self.is_sleep = False
|
|
|
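| # entry point used by the async rollout server: a few lifecycle methods are intercepted here,
| # everything else is forwarded to the underlying vLLM worker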
| def execute_method(self, method: Union[str, bytes], *args, **kwargs):
|
| if method == "init_worker":
|
| return self.init_worker(*args, **kwargs)
|
| elif method == "load_model":
|
| return self.load_model(*args, **kwargs)
|
| elif method == "sleep":
|
| return self.sleep(*args, **kwargs)
|
| elif method == "wake_up":
|
| return self.wake_up(*args, **kwargs)
|
| else:
|
| return self.inference_engine.execute_method(method, *args, **kwargs)
|
|
|