import inspect
import logging
import os

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

from verl import DataProto
from verl.protocol import all_gather_data_proto
from verl.third_party.vllm import LLM, vllm_version
from verl.third_party.vllm import parallel_state as vllm_ps
from verl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage
from verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu
from verl.utils.torch_functional import check_cuda_is_available
from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader

from .base import BaseShardingManager

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


class FSDPVLLMShardingManager(BaseShardingManager):
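    """Context manager bridging an FSDP-trained actor and a vLLM inference engine.

    On entry it gathers the (possibly CPU-offloaded) FSDP parameters, loads them
    into vLLM, and swaps in a dedicated generation RNG state; on exit it releases
    the vLLM weights and restores the training RNG state. ``preprocess_data`` and
    ``postprocess_data`` handle the all-gather/chunk round-trip across the
    inference tensor-parallel group.
    """
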
    @check_cuda_is_available()
    def __init__(
        self,
        module: FSDP,
        inference_engine: LLM,
        model_config,
        full_params: bool = False,
        device_mesh: DeviceMesh = None,
        offload_param: bool = False,
    ):
        self.module = module
        self.inference_engine = inference_engine

        # The customized vllm 0.5.4 / 0.6.3 engines expose the model runner under
        # `model_executor.worker`; newer engines nest it under `driver_worker.worker`.
        if self.inference_engine is None:
            self.model_runner = None
        elif "vllm_v_0_6_3" in str(type(self.inference_engine)) or "vllm_v_0_5_4" in str(type(self.inference_engine)):
            self.model_runner = self.inference_engine.llm_engine.model_executor.worker.model_runner
        else:
            self.model_runner = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner

        self.model_config = model_config
        self.device_mesh = device_mesh
        self.offload_param = offload_param

        self.full_params = full_params
        if full_params and fsdp_version(self.module) == 1:
            FSDP.set_state_dict_type(
                self.module,
                state_dict_type=StateDictType.FULL_STATE_DICT,
                state_dict_config=FullStateDictConfig(),
            )
        elif fsdp_version(self.module) == 1:
            FSDP.set_state_dict_type(
                self.module,
                state_dict_type=StateDictType.SHARDED_STATE_DICT,
                state_dict_config=ShardedStateDictConfig(),
            )

        # device_mesh is expected to provide the "infer_tp" dimension for rollout TP.
        self.tp_size = self.device_mesh["infer_tp"].size()
        self.tp_rank = self.device_mesh["infer_tp"].get_local_rank()

        # Note that torch_random_states may be different on each dp rank.
        self.torch_random_states = torch.cuda.get_rng_state()
        # Derive a dedicated generation RNG state per dp rank, so that all tp ranks
        # within the same dp group sample identically.
        if self.device_mesh is not None:
            gen_dp_rank = self.device_mesh["dp"].get_local_rank()
            torch.cuda.manual_seed(gen_dp_rank + 1000)
            self.gen_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.torch_random_states)
        else:
            self.gen_random_states = None

    @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger)
    def __enter__(self):
        # Free PyTorch's cached blocks before handing memory over to vLLM's own
        # caching allocator.
        torch.cuda.empty_cache()

        log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger)
        if self.offload_param:
            load_fsdp_model_to_gpu(self.module)
        params = self.module.state_dict()
        log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)

        load_format = "hf" if self.full_params else "dtensor"

        if vllm_version in ("0.5.4", "0.6.3"):
            self.inference_engine.sync_model_weights(params, load_format=load_format)
            log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)
            del params
        else:
            # Newer engines support tag-based wake-up, letting us load weights
            # before the KV cache is re-allocated.
            if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
                self.inference_engine.wake_up(tags=["weights"])
            else:
                self.inference_engine.wake_up()

            self.update_params(params)
            log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)
            del params
            if self.offload_param:
                offload_fsdp_model_to_cpu(self.module)
            torch.cuda.empty_cache()

            if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
                self.inference_engine.wake_up(tags=["kv_cache"])

            log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)

        # Swap to the generation RNG state saved in __init__.
        if self.device_mesh is not None:
            self.torch_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.gen_random_states)

    @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger)
    def __exit__(self, exc_type, exc_value, traceback):
        if vllm_version in ("0.5.4", "0.6.3"):
            self.inference_engine.offload_model_weights()
        else:
            # Put the engine to sleep, releasing GPU memory until the next wake_up.
            self.inference_engine.sleep(level=1)

        self.module.train()

        # Empty cache after each compute phase.
        torch.cuda.empty_cache()

        # Restore the training RNG state saved in __enter__.
        if self.device_mesh is not None:
            self.gen_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.torch_random_states)

    @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger)
    def preprocess_data(self, data: DataProto) -> DataProto:
        """All-gather across the tp group so that every rank holds an identical input."""
        if self.tp_size == 1:
            return data

        if vllm_version in ("0.5.4", "0.6.3"):
            group = vllm_ps.get_tensor_model_parallel_group()
        else:
            group = vllm_ps.get_tensor_model_parallel_group().device_group

        all_gather_data_proto(data=data, process_group=group)
        return data

    @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger)
    def postprocess_data(self, data: DataProto) -> DataProto:
        """Return this tp rank's chunk of the data, undoing the all-gather in preprocess_data."""
        if self.tp_size == 1:
            return data

        return data.chunk(chunks=self.tp_size)[self.tp_rank]
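
    # A minimal sketch of the gather/chunk round-trip above (`mgr` and `engine`
    # are illustrative names, not part of this module):
    #
    #     data = mgr.preprocess_data(data)   # every tp rank now holds the full batch
    #     out = engine.generate(...)         # identical inputs on all tp ranks
    #     out = mgr.postprocess_data(out)    # back to this rank's shard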

    def update_params(self, updated_params):
        """Load the (possibly DTensor-sharded) FSDP state dict into the vLLM model."""
        model = self.model_runner.model
        patch_vllm_moe_model_weight_loader(model)
        world_size = torch.distributed.get_world_size()
        device = torch.cuda.current_device()

        def gathered_params():
            # DTensor shards must be gathered into full tensors before vLLM can load them.
            for name, param in updated_params.items():
                if world_size != 1 and hasattr(param, "full_tensor"):
                    param = param.to(device, non_blocking=True).full_tensor()
                yield name, param

        loaded_params = model.load_weights(gathered_params())
        logger.info("vLLM load weights, loaded_params: %d", len(loaded_params))
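
# A minimal usage sketch under assumed names (`actor_module`, `llm`, `cfg`, `mesh`,
# and `rollout` are illustrative, not part of this module): the manager is entered
# around each generation phase so training and inference can share the same GPUs.
#
#     sharding_manager = FSDPVLLMShardingManager(
#         module=actor_module, inference_engine=llm, model_config=cfg, device_mesh=mesh
#     )
#     with sharding_manager:       # syncs FSDP weights into vLLM, swaps RNG state
#         outputs = rollout(...)   # generation happens here
#     # on exit: vLLM weights are released and the training RNG state is restored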