"""
|
|
|
This file contains a Megatron style Hybrid Engine that shares the weights of the actor with the inference engine.
|
|
|
"""

import inspect
import logging
import os

import torch
import torch.distributed
import torch.distributed as dist
from megatron.core import DistributedDataParallel as LocalDDP
from megatron.core import parallel_state as mpu
from megatron.core.transformer.module import Float16Module
from torch import nn
from torch.distributed import new_group
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

import verl.utils.megatron.tensor_parallel as tp_utils
from verl import DataProto
from verl.models.mcore.weight_converter import McoreToHFWeightConverterBase
from verl.third_party.vllm import LLM, vllm_version
from verl.third_party.vllm import parallel_state as vllm_ps
from verl.utils.debug import GPUMemoryLogger
from verl.utils.megatron_utils import (
    broadcast_from_megatron_pp,
    broadcast_str_from_megatron_pp,
    convert_megatron_model_to_transformers_model,
    get_model,
    unwrap_model,
)
from verl.utils.memory_buffer import (
    build_memory_buffer,
    build_memory_reference_from_module,
    get_weight_buffer_meta_from_module,
)
from verl.utils.model import normalize_model_name
from verl.utils.torch_functional import allgather_dict_tensors, check_cuda_is_available
from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader

from .base import BaseShardingManager

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


class AllGatherPPModel:
    def __init__(self, model_provider, use_distributed_optimizer=True) -> None:
        print(
            "[WARNING] This class is deprecated and will no longer be supported. "
            "Consider using the `MegatronPPOActor` class directly as a replacement."
        )
        self._pp_group = mpu.get_pipeline_model_parallel_group()
        self._pp_rank = mpu.get_pipeline_model_parallel_rank()
        self._pp_size = mpu.get_pipeline_model_parallel_world_size()
        self._vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
        self._model_chunk_size = self._vpp_size or 1

        # each pp rank holds a list of model chunks (one per virtual pp stage)
        self._pp_models = [None] * self.pp_size

        # build the current pp rank's model last, since it is the only one wrapped with DDP
        rank_list = list(range(self.pp_size))
        rank_list[self.pp_rank], rank_list[-1] = rank_list[-1], rank_list[self.pp_rank]
        self._this_rank_models = None

        # one memory buffer per pp rank, holding that rank's parameters contiguously
        self.memory_buffers = [None] * self.pp_size
        for cur_pp_rank in rank_list:
            print(
                "create pp model",
                f"torch allocated {torch.cuda.memory_allocated() / 1e9:.4f} GB, reserved {torch.cuda.memory_reserved() / 1e9:.4f} GB",
            )
            mpu.set_pipeline_model_parallel_rank(cur_pp_rank)
            if cur_pp_rank != self.pp_rank:
                models = get_model(model_provider, wrap_with_ddp=False, use_distributed_optimizer=False)
                models = nn.ModuleList(models)
                assert len(models) == self._model_chunk_size, f"{len(models)} != {self._model_chunk_size}"
                self.pp_models[cur_pp_rank] = models
            else:
                models = get_model(model_provider, wrap_with_ddp=True, use_distributed_optimizer=use_distributed_optimizer)
                assert len(models) == self._model_chunk_size, f"{len(models)} != {self._model_chunk_size}"
                self._this_rank_models = nn.ModuleList(models)
                self.pp_models[cur_pp_rank] = nn.ModuleList(unwrap_model(models, (torchDDP, LocalDDP)))

            self._build_param_buffer(cur_pp_rank)
            self._build_param_references(cur_pp_rank, maintain_weight=cur_pp_rank == self.pp_rank)

            # models of other pp ranks are only used for inference; keep them in eval mode and offload to CPU
            if cur_pp_rank != self.pp_rank:
                for model in self.pp_models[cur_pp_rank]:
                    model.eval()
                self._offload_params_to_cpu(cur_pp_rank)

    def _build_param_buffer(self, pp_rank):
        """Build the parameter buffer in each pp rank"""
        if pp_rank == self._pp_rank:
            from verl.utils.memory_buffer import MemoryBuffer

            # for the current pp rank, reuse the contiguous param buffer already owned by Megatron DDP
            source = self._this_rank_models[0].buffers[0].param_data
            self.memory_buffers[pp_rank] = {torch.bfloat16: MemoryBuffer(source.numel(), source.numel(), torch.bfloat16, source)}
        else:
            model = self.pp_models[pp_rank]
            weight_buffer_meta = get_weight_buffer_meta_from_module(model)
            self.memory_buffers[pp_rank] = build_memory_buffer(weight_buffer_meta)

    def _build_param_references(self, pp_rank, maintain_weight=False):
        if pp_rank == self._pp_rank:
            return
        model = self.pp_models[pp_rank]
        build_memory_reference_from_module(model, self.memory_buffers[pp_rank], maintain_weight=maintain_weight)

    def _load_params_to_cuda(self, pp_rank, to_empty=False):
        assert pp_rank != self.pp_rank, f"unexpected to load current pp rank [{pp_rank}] back to cuda"
        for buffer in self.memory_buffers[pp_rank].values():
            if not to_empty:
                buffer.data = buffer.data.to(torch.cuda.current_device(), non_blocking=True)
            else:
                buffer.data = torch.empty_like(buffer.data, device="cuda")
        # rebuild the parameter views into the (re)allocated buffer
        self._build_param_references(pp_rank)

    def _offload_params_to_cpu(self, pp_rank, to_empty=False):
        assert pp_rank != self.pp_rank, f"unexpected to offload current pp rank [{pp_rank}] to cpu"
        for buffer in self.memory_buffers[pp_rank].values():
            if not to_empty:
                buffer.data = buffer.data.to("cpu", non_blocking=True)
            else:
                buffer.data = torch.empty_like(buffer.data, device="cpu")
        self._build_param_references(pp_rank)

    def load_params_to_cuda(self, to_empty=False):
        """Load the params of all pp ranks other than the current one to cuda"""
        for cur_pp_rank in range(self.pp_size):
            if cur_pp_rank != self.pp_rank:
                self._load_params_to_cuda(cur_pp_rank, to_empty=to_empty)

    def allgather_params(self):
        """Synchronously broadcast the params of every pp rank to all other pp ranks"""
        for cur_pp_rank in range(self.pp_size):
            global_src = dist.get_global_rank(group=self.pp_group, group_rank=cur_pp_rank)
            # broadcast the parameters owned by cur_pp_rank into every rank's buffer for that stage
            for _, param in sorted(self.pp_models[cur_pp_rank].named_parameters()):
                dist.broadcast(tensor=param.data, src=global_src, group=self.pp_group, async_op=False)

    def forward(self, *inputs, **kwargs):
        try:
            prev_output = None
            for cur_chunk_rank in range(self._model_chunk_size):
                if self._vpp_size:
                    mpu.set_virtual_pipeline_model_parallel_rank(cur_chunk_rank)

                # run this chunk of every pp stage locally, feeding each stage's output into the next
                for cur_pp_rank in range(self.pp_size):
                    mpu.set_pipeline_model_parallel_rank(cur_pp_rank)
                    self.pp_models[cur_pp_rank][cur_chunk_rank].set_input_tensor(prev_output)
                    ret = self.pp_models[cur_pp_rank][cur_chunk_rank](*inputs, **kwargs)
                    self.pp_models[cur_pp_rank][cur_chunk_rank].set_input_tensor(None)
                    prev_output = ret
        finally:
            # restore the parallel state of the current rank
            if self._vpp_size:
                mpu.set_virtual_pipeline_model_parallel_rank(0)
            mpu.set_pipeline_model_parallel_rank(self.pp_rank)
        return ret

    def __call__(self, *inputs, **kwargs):
        return self.forward(*inputs, **kwargs)

    def eval(self):
        for model in self.pp_models[self.pp_rank]:
            model.eval()

    def train(self):
        for model in self.pp_models[self.pp_rank]:
            model.train()

    def offload_params_to_cpu(self, to_empty=False):
        """Offload the params of models that are not of the current pp rank to cpu"""
        for cur_pp_rank in range(self.pp_size):
            if cur_pp_rank != self.pp_rank:
                self._offload_params_to_cpu(cur_pp_rank, to_empty=to_empty)

    def get_all_params(self):
        """Get all the parameters of the models on all pp ranks

        Returns:
            params: List[List[Dict[str, Tensor]]]: a list over pp ranks, where each entry is a list
                (one per model chunk) of dicts mapping parameter names to tensors
        """
        params = []
        for pp_rank in range(self.pp_size):
            params.append([])
            for model_chunk_idx in range(len(self.pp_models[pp_rank])):
                params[pp_rank].append({})
                pp_model = self.pp_models[pp_rank][model_chunk_idx]
                pp_model = unwrap_model(pp_model, (torchDDP, LocalDDP, Float16Module))
                for name, param in pp_model.named_parameters():
                    # skip LoRA adapter weights
                    if "lora" in name:
                        continue
                    params[pp_rank][model_chunk_idx][name] = param

        return params

    def update_this_rank_models(self, new_models):
        self._this_rank_models = new_models
        self._pp_models[self.pp_rank] = unwrap_model(new_models, (torchDDP, LocalDDP))

    @property
    def this_rank_models(self):
        return self._this_rank_models

    @property
    def pp_size(self):
        return self._pp_size

    @property
    def pp_rank(self):
        return self._pp_rank

    @property
    def pp_group(self):
        return self._pp_group

    @property
    def pp_models(self):
        return self._pp_models


"""
|
|
|
Megatron Hybrid Engine:
|
|
|
- During training, only the current pp stage holds the parameters
|
|
|
- Before inference, broadcast the parameters of the current pp rank
|
|
|
to all other pp ranks (all pp ranks holds all the parameters)
|
|
|
- Bind the parameters to the inference engine
|
|
|
- Do inference in tp. pp is treated as additional dp
|
|
|
- After inference, all the parameters that doesn't belong to this pp rank is freed.
|
|
|
"""

_MICRO_DATA_PARALLEL_GROUP = None


class MegatronVLLMShardingManager(BaseShardingManager):
    @check_cuda_is_available()
    def __init__(
        self,
        actor_module: nn.ModuleList,
        inference_engine: LLM,
        model_config,
        layer_name_mapping,
        weight_converter: McoreToHFWeightConverterBase,
        module: AllGatherPPModel = None,
    ):
        from megatron.core import parallel_state as mpu

        self.actor_module = actor_module
        self.inference_engine = inference_engine
        self.model_config = model_config
        self.layer_name_mapping = layer_name_mapping
        self.weight_converter = weight_converter
        self.module = module

        global _MICRO_DATA_PARALLEL_GROUP
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        self.infer_tp_size = vllm_ps.get_tensor_model_parallel_world_size()
        self.infer_tp_rank = vllm_ps.get_tensor_model_parallel_rank()
        self.infer_tp_group = vllm_ps.get_tensor_model_parallel_group()
        self.train_tp_size = mpu.get_tensor_model_parallel_world_size()
        self.train_tp_rank = mpu.get_tensor_model_parallel_rank()
        self.train_tp_group = mpu.get_tensor_model_parallel_group()
        self.need_tp_reshard = self.infer_tp_size != self.train_tp_size

        assert self.infer_tp_size <= self.train_tp_size, "Not implemented for infer_tp > train_tp"
        assert self.train_tp_size % self.infer_tp_size == 0

        micro_dp_size = self.train_tp_size // self.infer_tp_size
        num_micro_dp_groups = world_size // micro_dp_size
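        # Worked example (illustrative numbers, not taken from any particular config): with
        # train_tp_size=4 and infer_tp_size=2, micro_dp_size=2, so on an 8-GPU job
        # num_micro_dp_groups=4 and the micro-dp groups below are ranks {0,1}, {2,3}, {4,5}, {6,7}.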
        assert _MICRO_DATA_PARALLEL_GROUP is None, "micro data parallel group is already initialized"
        for i in range(num_micro_dp_groups):
            ranks = range(i * micro_dp_size, (i + 1) * micro_dp_size)
            group = new_group(ranks=ranks)
            if rank in ranks:
                _MICRO_DATA_PARALLEL_GROUP = group

    def per_tensor_generator(self, convert_qkv_gate_up_by_simple_split=True):
        """
        Yield (name, tensor) pairs for the full model, one parameter at a time.

        convert_qkv_gate_up_by_simple_split is a parameter whose appropriate value depends on the vLLM version.
        """
        from megatron.core import parallel_state as mpu

        pp_rank = mpu.get_pipeline_model_parallel_rank()
        pp_size = mpu.get_pipeline_model_parallel_world_size()
        vpp_size = len(self.actor_module)

        all_gather_group = (
            get_micro_data_parallel_group()
            if vllm_version
            in (
                "0.5.4",
                "0.6.3",
            )
            else self.train_tp_group
        )
        all_gather_group_size = torch.distributed.get_world_size(group=all_gather_group)

        def tensor_generator():
            for scan_vpp_idx in range(vpp_size):
                yield from self.actor_module[scan_vpp_idx].named_parameters()

        # collect (pp_rank, vpp_idx, param_idx, name) for every local parameter and all-gather
        # this metadata across the pp group so every rank knows the global parameter layout
        meta_info = []
        for scan_vpp_idx in range(vpp_size):
            for idx, (name, _) in enumerate(self.actor_module[scan_vpp_idx].named_parameters()):
                meta_info.append((pp_rank, scan_vpp_idx, idx, name))

        obj_spec_output = [None] * mpu.get_pipeline_model_parallel_world_size()
        torch.distributed.all_gather_object(object_list=obj_spec_output, obj=meta_info, group=mpu.get_pipeline_model_parallel_group())
        layer_list_meta = [item for sublist in obj_spec_output for item in sublist]

        gen_func = tensor_generator()

        for cur_pp_rank, scan_vpp_idx, idx, name in layer_list_meta:
            # the owning pp rank produces the tensor; every other rank passes None and
            # receives it via the pp broadcast below
            if cur_pp_rank == pp_rank:
                try:
                    cur_name, cur_tensor = next(gen_func)
                except StopIteration:
                    cur_name, cur_tensor = None, None
                cur_name = normalize_model_name(name, cur_pp_rank, scan_vpp_idx, pp_size, vpp_size, self.model_config.num_hidden_layers)
            else:
                cur_tensor, cur_name = None, None

            cur_name = broadcast_str_from_megatron_pp(cur_name)
            broad_pp_tensor = broadcast_from_megatron_pp(cur_tensor)

            # strip any (possibly repeated) "module." prefixes left by DDP/Float16Module wrappers
            while cur_name.startswith("module."):
                cur_name = cur_name[len("module.") :]

            # tp-split params are all-gathered across the tp group and concatenated back
            if tp_utils.is_tensor_parallel_param(broad_pp_tensor):
                if all_gather_group_size <= 1:
                    infer_params = [broad_pp_tensor]
                else:
                    infer_params = [torch.empty_like(broad_pp_tensor) for _ in range(all_gather_group_size)]
                    torch.distributed.all_gather(infer_params, broad_pp_tensor, group=mpu.get_tensor_model_parallel_group())
                infer_params = self.default_tp_concat_fn(cur_name, broad_pp_tensor, infer_params, self.model_config, convert_qkv_gate_up_by_simple_split)
            else:
                infer_params = broad_pp_tensor

            if vllm_version in ("0.4.2", "0.5.4", "0.6.3"):
                converted_names, converted_params = convert_megatron_model_to_transformers_model(
                    cur_name,
                    infer_params,
                    self.model_config,
                    self.train_tp_size,
                    0,
                    convert_qkv_gate_up_by_trunk_concat=False,
                )
            else:
                if not isinstance(infer_params, list):
                    infer_params = [infer_params]
                converted_names, converted_params = self.weight_converter.convert_param(cur_name, infer_params)

            yield from zip(converted_names, converted_params)

    def default_tp_concat_fn(self, name, param, infer_params, model_config, convert_qkv_gate_up_by_simple_split=False):
        """
        name: name of the parameter
        param: training parameter (the local tp shard)
        infer_params (Iterable[torch.Tensor]): the tp shards all-gathered from the train tp group
            (vllm 0.8.2) or from the micro-dp group (vllm <= 0.6.3)
        model_config: huggingface model_config
        TODO(zhangchi.usc1992): currently, the implementation is ad hoc. We can move this function to the model
        definition so that it is model-agnostic. If the model doesn't implement this function,
        we can throw an error to force the user to disable the TP HybridEngine.
        """
        if self.layer_name_mapping.get("qkv_layer_name") in name and "layer_norm" not in name:
            # fused qkv: split each tp shard by query group, then rebuild the full q, k, v
            q_lst = []
            k_lst = []
            v_lst = []
            assert model_config.num_attention_heads % model_config.num_key_value_heads == 0
            num_q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads
            assert infer_params[0].shape[0] % (num_q_per_kv + 2) == 0, f"param '{name}' shape '{infer_params[0].shape}' dim0 is not divisible by {num_q_per_kv + 2}"
            kv_size_per_tp = infer_params[0].shape[0] // (num_q_per_kv + 2)
            split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp]
            for infer_param in infer_params:
                num_query_groups_per_partition = model_config.num_key_value_heads // self.train_tp_size
                for chunk in infer_param.chunk(num_query_groups_per_partition):
                    split_size = [
                        kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition,
                        kv_size_per_tp // num_query_groups_per_partition,
                        kv_size_per_tp // num_query_groups_per_partition,
                    ]
                    q, k, v = chunk.split(split_size)
                    q_lst.append(q)
                    k_lst.append(k)
                    v_lst.append(v)
            q = torch.cat(q_lst, dim=0)
            k = torch.cat(k_lst, dim=0)
            v = torch.cat(v_lst, dim=0)
            infer_params = torch.cat((q, k, v), dim=0) if not convert_qkv_gate_up_by_simple_split else [q, k, v]

        elif self.layer_name_mapping.get("gate_proj_layer_name") in name:
            # fused gate_up_proj: split each tp shard into gate and up halves, then rebuild both
            gate_lst = []
            up_lst = []
            for infer_param in infer_params:
                gate, up = infer_param.chunk(2)
                gate_lst.append(gate)
                up_lst.append(up)
            gate = torch.cat(gate_lst, dim=0)
            up = torch.cat(up_lst, dim=0)
            infer_params = torch.cat((gate, up), dim=0) if not convert_qkv_gate_up_by_simple_split else [gate, up]

        elif "mlp.experts.linear_fc2.weight" in name:
            # row-parallel expert down-projection: concatenate along the input dim
            infer_params = torch.cat(infer_params, dim=1)

        else:
            # default: concatenate along the tensor-parallel partition dim recorded on the param
            infer_params = torch.cat(infer_params, dim=tp_utils.get_tensor_parallel_partition_dim(param))

        return infer_params

    def _post_process_params(self, params, convert_qkv_gate_up_by_simple_split=False):
        """
        For each param, if it is a tp-split param, all-gather it from the train tp group
        (vllm 0.8.2) or from the micro-dp group (vllm <= 0.6.3) and concatenate the shards.
        """
        all_gather_group = (
            get_micro_data_parallel_group()
            if vllm_version
            in (
                "0.5.4",
                "0.6.3",
            )
            else self.train_tp_group
        )
        all_gather_group_size = torch.distributed.get_world_size(group=all_gather_group)

        for name, param in params:
            if tp_utils.is_tensor_parallel_param(param):
                if all_gather_group_size <= 1:
                    infer_params = [param]
                else:
                    infer_params = [torch.empty_like(param) for _ in range(all_gather_group_size)]
                    torch.distributed.all_gather(infer_params, param, group=all_gather_group)
                infer_params = self.default_tp_concat_fn(name, param, infer_params, self.model_config, convert_qkv_gate_up_by_simple_split)
            else:
                infer_params = param
            if vllm_version in ("0.4.2", "0.5.4", "0.6.3"):
                converted_names, converted_params = convert_megatron_model_to_transformers_model(
                    name,
                    infer_params,
                    self.model_config,
                    self.train_tp_size,
                    self.module.pp_models[0][0].config.num_query_groups,
                    convert_qkv_gate_up_by_trunk_concat=False,
                )
            else:
                if not isinstance(infer_params, list):
                    infer_params = [infer_params]
                converted_names, converted_params = self.weight_converter.convert_param(name, infer_params)
            yield from zip(converted_names, converted_params)

    @GPUMemoryLogger(role="megatron vllm sharding_manager", logger=logger)
    def __enter__(self):
        if vllm_version in (
            "0.5.4",
            "0.6.3",
        ):
            per_tensor_param = self.per_tensor_generator(convert_qkv_gate_up_by_simple_split=False)
            self.inference_engine.sync_model_weights(per_tensor_param, load_format="megatron")
        else:
            # wake the engine up in stages if supported: load the weights first, then the kv cache
            if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
                self.inference_engine.wake_up(tags=["weights"])
            else:
                self.inference_engine.wake_up()
            per_tensor_param = self.per_tensor_generator()
            model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
            patch_vllm_moe_model_weight_loader(model)
            loaded_params = model.load_weights(per_tensor_param)
            info = f"vLLM load weights, loaded_params: {len(loaded_params)}"
            logger.info(info)

            if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
                self.inference_engine.wake_up(tags=["kv_cache"])

    @GPUMemoryLogger(role="megatron vllm sharding_manager", logger=logger)
    def __exit__(self, exc_type, exc_value, traceback):
        if vllm_version in (
            "0.5.4",
            "0.6.3",
        ):
            self.inference_engine.offload_model_weights()
        else:
            self.inference_engine.sleep(level=1)
        for model in self.actor_module:
            model.train()

        torch.cuda.empty_cache()

    @GPUMemoryLogger(role="megatron vllm sharding_manager", logger=logger)
    def preprocess_data(self, data: DataProto) -> DataProto:
        # each micro-dp rank only keeps its own chunk of the prompts
        micro_dp_size = get_micro_data_parallel_world_size()
        if micro_dp_size > 1:
            local_prompts = data.chunk(chunks=micro_dp_size)
            data = local_prompts[get_micro_data_parallel_rank()]

        return data

    @GPUMemoryLogger(role="megatron vllm sharding_manager", logger=logger)
    def postprocess_data(self, data: DataProto) -> DataProto:
        # gather the generated sequences back from all micro-dp ranks
        micro_dp_size = get_micro_data_parallel_world_size()
        if micro_dp_size > 1:
            data.batch = allgather_dict_tensors(
                data.batch.contiguous(),
                size=get_micro_data_parallel_world_size(),
                group=get_micro_data_parallel_group(),
                dim=0,
            )
        return data


"""
|
|
|
Micro Data parallel group
|
|
|
"""
|
|
|
|
|
|
|
|
|
def get_micro_data_parallel_group():
    assert _MICRO_DATA_PARALLEL_GROUP is not None
    return _MICRO_DATA_PARALLEL_GROUP


def get_micro_data_parallel_world_size():
    return torch.distributed.get_world_size(group=get_micro_data_parallel_group())


def get_micro_data_parallel_rank():
    return torch.distributed.get_rank(group=get_micro_data_parallel_group())
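

# A minimal sketch of how the micro-dp helpers above are meant to be used (illustrative only;
# `data` follows the DataProto chunk/index semantics used in preprocess_data above):
#
#     micro_dp_size = get_micro_data_parallel_world_size()   # equals train_tp // infer_tp
#     if micro_dp_size > 1:
#         shard = data.chunk(chunks=micro_dp_size)[get_micro_data_parallel_rank()]
#         # ... run generation on `shard`, then all-gather results across the micro-dp group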