# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import logging
import os

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

from verl import DataProto
from verl.protocol import all_gather_data_proto
from verl.third_party.vllm import LLM, vllm_version
from verl.third_party.vllm import parallel_state as vllm_ps
from verl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage
from verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu
from verl.utils.torch_functional import check_cuda_is_available
from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader

from .base import BaseShardingManager

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


class FSDPVLLMShardingManager(BaseShardingManager):
    @check_cuda_is_available()
    def __init__(
        self,
        module: FSDP,
        inference_engine: LLM,
        model_config,
        full_params: bool = False,
        device_mesh: DeviceMesh = None,
        offload_param: bool = False,
    ):
        self.module = module
        # For AsyncLLM, inference_engine and model_runner are initialized lazily in vLLMAsyncRollout.load_model.
        self.inference_engine = inference_engine
        if "vllm_v_0_6_3" in str(type(self.inference_engine)) or "vllm_v_0_5_4" in str(type(self.inference_engine)):
            # vLLM <= v0.6.3
            self.model_runner = (
                self.inference_engine.llm_engine.model_executor.worker.model_runner if self.inference_engine else None
            )
        else:
            # vLLM > v0.6.3
            self.model_runner = (
                self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner
                if self.inference_engine
                else None
            )
        self.model_config = model_config
        self.device_mesh = device_mesh
        self.offload_param = offload_param

        # Full params
        self.full_params = full_params
        if full_params and fsdp_version(self.module) == 1:
            FSDP.set_state_dict_type(
                self.module,
                state_dict_type=StateDictType.FULL_STATE_DICT,
                state_dict_config=FullStateDictConfig(),
            )
        elif fsdp_version(self.module) == 1:
            FSDP.set_state_dict_type(
                self.module,
                state_dict_type=StateDictType.SHARDED_STATE_DICT,
                state_dict_config=ShardedStateDictConfig(),
            )

        self.tp_size = self.device_mesh["infer_tp"].size()
        self.tp_rank = self.device_mesh["infer_tp"].get_local_rank()

        # Note that torch_random_states may be different on each dp rank.
        self.torch_random_states = torch.cuda.get_rng_state()
        # Build a dedicated rng state for generation.
        if self.device_mesh is not None:
            gen_dp_rank = self.device_mesh["dp"].get_local_rank()
            torch.cuda.manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states
            self.gen_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.torch_random_states)
        else:
            self.gen_random_states = None
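
    # A minimal sketch (an assumption about the caller, not part of this module) of the
    # device mesh this manager expects: a 2-D CUDA mesh with a data-parallel dim "dp"
    # and an inference tensor-parallel dim "infer_tp", e.g.
    #
    #     from torch.distributed.device_mesh import init_device_mesh
    #
    #     device_mesh = init_device_mesh(
    #         "cuda",
    #         mesh_shape=(world_size // infer_tp_size, infer_tp_size),
    #         mesh_dim_names=("dp", "infer_tp"),
    #     )
    #
    # where `infer_tp_size` is the rollout tensor-parallel size chosen by the caller.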

    @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger)
    def __enter__(self):
        # NOTE: We only need `torch.cuda.empty_cache()` before vLLM wake_up and after vLLM
        # sleep, since vLLM has its own caching memory allocator, CuMemAllocator. Outside
        # the vLLM scope, we should avoid emptying the cache so that PyTorch can keep using
        # its caching allocator to speed up memory allocations.
        #
        # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management
        # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103
        torch.cuda.empty_cache()

        log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger)
        if self.offload_param:
            load_fsdp_model_to_gpu(self.module)
        params = self.module.state_dict()
        log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)
        # Copy the weights into the inference engine; do not share memory.
        load_format = "hf" if self.full_params else "dtensor"

        if vllm_version in (
            "0.5.4",
            "0.6.3",
        ):
            self.inference_engine.sync_model_weights(params, load_format=load_format)
            log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)
            del params
        else:
            if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
                self.inference_engine.wake_up(tags=["weights"])
            else:
                self.inference_engine.wake_up()
            # Update model params.
            self.update_params(params)
            log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)
            del params
            if self.offload_param:
                offload_fsdp_model_to_cpu(self.module)
            torch.cuda.empty_cache()

            if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
                self.inference_engine.wake_up(tags=["kv_cache"])

        log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)

        # important: we need to manually set the random states of each tp rank to be identical.
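        # Rationale: every rank in an "infer_tp" group executes the same sampling path in
        # vLLM, so their CUDA RNG states must match or the sampled tokens would diverge
        # across tp ranks. We switch to the per-dp-rank generation state seeded in
        # __init__ (seed = dp_rank + 1000) for the duration of rollout and restore the
        # trainer's RNG state in __exit__.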
        if self.device_mesh is not None:
            self.torch_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.gen_random_states)

    @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger)
    def __exit__(self, exc_type, exc_value, traceback):
        # TODO(ZSL): check this
        if vllm_version in (
            "0.5.4",
            "0.6.3",
        ):
            self.inference_engine.offload_model_weights()
        else:
            self.inference_engine.sleep(level=1)

        self.module.train()

        # Empty the cache after each compute step.
        torch.cuda.empty_cache()

        # Save the generation random states for the next rollout and restore the trainer's states.
        if self.device_mesh is not None:
            self.gen_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.torch_random_states)

    @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger)
    def preprocess_data(self, data: DataProto) -> DataProto:
        """All-gather data across the tp group so that every rank holds identical input."""
        if self.tp_size == 1:
            return data

        # TODO: The current implementation doesn't consider FSDP with torch micro-dp.
        if vllm_version in (
            "0.5.4",
            "0.6.3",
        ):
            group = vllm_ps.get_tensor_model_parallel_group()
        else:
            group = vllm_ps.get_tensor_model_parallel_group().device_group

        all_gather_data_proto(data=data, process_group=group)
        return data

    @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger)
    def postprocess_data(self, data: DataProto) -> DataProto:
        """Return this tp rank's chunk of the data, since preprocess_data all-gathered it."""
        if self.tp_size == 1:
            return data
        return data.chunk(chunks=self.tp_size)[self.tp_rank]

    def update_params(self, updated_params):
        model = self.model_runner.model
        patch_vllm_moe_model_weight_loader(model)
        world_size = torch.distributed.get_world_size()
        device = torch.cuda.current_device()  # needed when FSDP2 sets a CPU offload policy
        loaded_params = model.load_weights(
            (
                name,
                param.to(device, non_blocking=True).full_tensor()
                if world_size != 1 and hasattr(param, "full_tensor")
                else param,
            )
            for name, param in updated_params.items()
        )
        logger.info("vLLM load weights, loaded_params: %d", len(loaded_params))
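

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module's API): the sharding
# manager is meant to wrap rollout generation inside an FSDP rollout worker.
# Names such as `rollout`, `prompts`, and `generate_sequences` below are
# hypothetical placeholders for the caller's objects.
#
#     sharding_manager = FSDPVLLMShardingManager(
#         module=fsdp_module,                # FSDP-wrapped actor module
#         inference_engine=vllm_engine,      # verl vLLM engine wrapper
#         model_config=model_config,
#         full_params=False,
#         device_mesh=device_mesh,           # must expose "dp" and "infer_tp" dims
#         offload_param=True,
#     )
#     with sharding_manager:                 # __enter__: load weights into vLLM
#         prompts = sharding_manager.preprocess_data(prompts)
#         output = rollout.generate_sequences(prompts)
#         output = sharding_manager.postprocess_data(output)
#     # __exit__: vLLM sleeps / offloads weights, trainer RNG state is restored
# ---------------------------------------------------------------------------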