# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import logging
import os

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

from verl import DataProto
from verl.protocol import all_gather_data_proto
from verl.third_party.vllm import LLM, vllm_version
from verl.third_party.vllm import parallel_state as vllm_ps
from verl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage
from verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu
from verl.utils.torch_functional import check_cuda_is_available
from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader

from .base import BaseShardingManager

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


class FSDPVLLMShardingManager(BaseShardingManager):
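    """Context manager that syncs FSDP-trained weights into a colocated vLLM engine.

    On ``__enter__`` the (full or sharded) FSDP state dict is gathered and loaded
    into vLLM; on ``__exit__`` vLLM is put to sleep and training state is restored.
    """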
@check_cuda_is_available()
def __init__(
self,
module: FSDP,
inference_engine: LLM,
model_config,
full_params: bool = False,
device_mesh: DeviceMesh = None,
offload_param: bool = False,
):
self.module = module
        # For AsyncLLM, inference_engine and model_runner are initialized lazily in vLLMAsyncRollout.load_model
self.inference_engine = inference_engine
        engine_cls_name = str(type(self.inference_engine))
        if "vllm_v_0_6_3" in engine_cls_name or "vllm_v_0_5_4" in engine_cls_name:
            # vLLM <= v0.6.3
            self.model_runner = self.inference_engine.llm_engine.model_executor.worker.model_runner if self.inference_engine else None
        else:
            # vLLM > v0.6.3
            self.model_runner = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner if self.inference_engine else None
self.model_config = model_config
self.device_mesh = device_mesh
self.offload_param = offload_param
        # Decide how FSDP materializes the state dict: full (HF-format) params vs. sharded DTensors
self.full_params = full_params
        if full_params and fsdp_version(self.module) == 1:
            FSDP.set_state_dict_type(
                self.module,
                state_dict_type=StateDictType.FULL_STATE_DICT,
                state_dict_config=FullStateDictConfig(),
            )
elif fsdp_version(self.module) == 1:
FSDP.set_state_dict_type(
self.module,
state_dict_type=StateDictType.SHARDED_STATE_DICT,
state_dict_config=ShardedStateDictConfig(),
)
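        # NOTE: a device_mesh with "infer_tp" and "dp" dimensions is effectively required
        # here, even though the parameter defaults to None.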
self.tp_size = self.device_mesh["infer_tp"].size()
self.tp_rank = self.device_mesh["infer_tp"].get_local_rank()
# Note that torch_random_states may be different on each dp rank
self.torch_random_states = torch.cuda.get_rng_state()
        # derive a dedicated rng state for generation
if self.device_mesh is not None:
gen_dp_rank = self.device_mesh["dp"].get_local_rank()
torch.cuda.manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states
self.gen_random_states = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(self.torch_random_states)
else:
self.gen_random_states = None

    @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger)
def __enter__(self):
        # NOTE: Basically, we only need `torch.cuda.empty_cache()` before vllm wake_up and
        # after vllm sleep, since vllm has its own caching memory allocator (CuMemAllocator).
        # Outside of the vllm scope, we should avoid emptying the cache so that pytorch can
        # reuse its cached memory and speed up allocations.
#
# pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management
# vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103
torch.cuda.empty_cache()
log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger)
if self.offload_param:
load_fsdp_model_to_gpu(self.module)
params = self.module.state_dict()
log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)
        # Weights are copied into the inference engine, not shared in memory
        load_format = "hf" if self.full_params else "dtensor"
if vllm_version in (
"0.5.4",
"0.6.3",
):
self.inference_engine.sync_model_weights(params, load_format=load_format)
log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)
del params
else:
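            # Newer vLLM supports a staged wake_up: restore only the weight buffers first,
            # and defer kv_cache allocation until after the state dict is freed below,
            # which lowers peak GPU memory.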
if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
self.inference_engine.wake_up(tags=["weights"])
else:
self.inference_engine.wake_up()
# update model params
self.update_params(params)
log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)
del params
if self.offload_param:
offload_fsdp_model_to_cpu(self.module)
torch.cuda.empty_cache()
if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
self.inference_engine.wake_up(tags=["kv_cache"])
log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)
        # important: need to manually set the random states of each tp rank to be identical.
if self.device_mesh is not None:
self.torch_random_states = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(self.gen_random_states)

    @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger)
def __exit__(self, exc_type, exc_value, traceback):
# TODO(ZSL): check this
if vllm_version in (
"0.5.4",
"0.6.3",
):
self.inference_engine.offload_model_weights()
else:
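            # sleep(level=1) offloads vLLM's weights to CPU and discards the kv cache,
            # freeing GPU memory for training; weights are re-synced on every __enter__.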
self.inference_engine.sleep(level=1)
self.module.train()
        # empty the CUDA cache after each compute
torch.cuda.empty_cache()
# restore random states
if self.device_mesh is not None:
self.gen_random_states = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(self.torch_random_states)

    @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger)
def preprocess_data(self, data: DataProto) -> DataProto:
"""All gather across tp group to make each rank has identical input."""
if self.tp_size == 1:
return data
# TODO: Current impl doesn't consider FSDP with torch micro-dp
if vllm_version in (
"0.5.4",
"0.6.3",
):
group = vllm_ps.get_tensor_model_parallel_group()
else:
group = vllm_ps.get_tensor_model_parallel_group().device_group
all_gather_data_proto(data=data, process_group=group)
return data

    @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger)
def postprocess_data(self, data: DataProto) -> DataProto:
"""Get chunk data of this tp rank since we do all gather in preprocess."""
if self.tp_size == 1:
return data
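        # e.g. with tp_size=2, each rank keeps its own half of the gathered batch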
return data.chunk(chunks=self.tp_size)[self.tp_rank]

    def update_params(self, updated_params):
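        """Load the gathered FSDP parameters into the colocated vLLM model in place."""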
model = self.model_runner.model
patch_vllm_moe_model_weight_loader(model)
world_size = torch.distributed.get_world_size()
        device = torch.cuda.current_device()  # used when fsdp2 sets cpu_offload_policy
        loaded_params = model.load_weights(
            (
                name,
                param.to(device, non_blocking=True).full_tensor()
                if world_size != 1 and hasattr(param, "full_tensor")
                else param,
            )
            for name, param in updated_params.items()
        )
logger.info("vLLM load weights, loaded_params: %d", len(loaded_params))
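

# A minimal usage sketch (illustrative only; `fsdp_module`, `llm`, `model_config`, and
# `device_mesh` are assumed to be built by the surrounding worker code):
#
#   sharding_manager = FSDPVLLMShardingManager(
#       module=fsdp_module,
#       inference_engine=llm,
#       model_config=model_config,
#       device_mesh=device_mesh,
#   )
#   with sharding_manager:
#       batch = sharding_manager.preprocess_data(batch)
#       # ... run vLLM generation here ...
#       batch = sharding_manager.postprocess_data(batch)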