Xin-Rui's picture
Upload folder using huggingface_hub
7155cf2 verified
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Dict, Iterable, Tuple, Union
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from vllm import LLM
from vllm.distributed import parallel_state as vllm_ps
from ...protocol import DataProto, all_gather_data_proto
from ...utils.model_utils import print_gpu_memory_usage
from .base import BaseShardingManager
class FSDPVLLMShardingManager(BaseShardingManager):
    """Context manager that moves FSDP actor weights into a vLLM engine.

    Entering the context wakes the vLLM engine up and loads the current state
    dict of the FSDP module into the engine's model. Leaving the context puts
    the engine to sleep and switches the FSDP module back to training mode.

    When a device mesh is supplied, the CUDA RNG state is swapped on entry and
    exit: inside the context the generation state (seeded from the local rank
    along the "dp" mesh dimension) is active, outside it the training state is
    restored.
    """

    def __init__(
        self,
        module: FSDP,
        inference_engine: LLM,
        device_mesh: Union[DeviceMesh, None] = None,
    ):
        """Store the collaborators and prepare state-dict and RNG bookkeeping.

        Args:
            module: FSDP-wrapped module whose weights are synced to vLLM.
            inference_engine: vLLM engine that receives the weights.
            device_mesh: Optional mesh with a "dp" dimension. When ``None``,
                no RNG state swapping is performed by this manager.
        """
        self.module = module
        self.inference_engine = inference_engine
        self.device_mesh = device_mesh

        # Configure the module to produce sharded state dicts; warnings
        # emitted by this call are suppressed.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            FSDP.set_state_dict_type(
                self.module,
                state_dict_type=StateDictType.SHARDED_STATE_DICT,
                state_dict_config=ShardedStateDictConfig(),
            )

        self.world_size = dist.get_world_size()
        self.tp_size = vllm_ps.get_tensor_model_parallel_world_size()
        self.tp_rank = vllm_ps.get_tensor_model_parallel_rank()
        self.tp_group = vllm_ps.get_tensor_model_parallel_group().device_group

        # Record freed bytes to estimate memory usage correctly
        # https://github.com/vllm-project/vllm/pull/11743#issuecomment-2754338119
        self.freed_bytes = 0

        # Note that torch_random_states may be different on each dp rank
        self.torch_random_states = torch.cuda.get_rng_state()
        if self.device_mesh is not None:
            # Derive a dedicated generation RNG state from the dp rank, then
            # put the original (training) state back in place.
            gen_dp_rank = self.device_mesh["dp"].get_local_rank()
            torch.cuda.manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states
            self.gen_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.torch_random_states)
        else:
            self.gen_random_states = None

    def _make_weight_iterator(
        self, actor_weights: Dict[str, Union[torch.Tensor, DTensor]]
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        """Lazily yield ``(name, tensor)`` pairs suitable for ``load_weights``.

        When more than one process participates, ``DTensor`` entries are
        materialized with ``full_tensor()`` one at a time as the iterator is
        consumed. Entries that are not ``DTensor`` are yielded unchanged, so a
        plain ``torch.Tensor`` in the state dict no longer raises
        ``AttributeError``.

        Args:
            actor_weights: State dict of the actor module.

        Yields:
            Parameter name and the tensor to load for it.
        """
        needs_gather = self.world_size != 1
        for name, tensor in actor_weights.items():
            if needs_gather and isinstance(tensor, DTensor):
                tensor = tensor.full_tensor()
            yield name, tensor

    def __enter__(self):
        """Wake up vLLM and load the current actor weights into it."""
        # NOTE: Basically, we only need `torch.cuda.empty_cache()` before vllm wake_up and
        # after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator.
        # Out of vllm scope, we should avoid empty cache to let pytorch using caching memory
        # to speed up memory allocations.
        #
        # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management
        # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103
        torch.cuda.empty_cache()
        print_gpu_memory_usage("Before state_dict() in sharding manager")
        actor_weights = self.module.state_dict()
        print_gpu_memory_usage("After state_dict() in sharding manager")

        self.inference_engine.wake_up()
        model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
        model.load_weights(self._make_weight_iterator(actor_weights))
        print_gpu_memory_usage("After sync model weights in sharding manager")

        # Drop the state dict reference before emptying the cache.
        del actor_weights
        torch.cuda.empty_cache()
        print_gpu_memory_usage("After del state_dict and empty_cache in sharding manager")

        # important: need to manually set the random states of each tp to be identical.
        if self.device_mesh is not None:
            # Save the training RNG state and activate the generation state.
            self.torch_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.gen_random_states)

    def __exit__(self, exc_type, exc_value, traceback):
        """Put vLLM to sleep and return the actor module to training mode.

        The exception arguments are ignored and nothing is returned, so an
        exception raised inside the ``with`` body propagates to the caller.
        """
        print_gpu_memory_usage("Before vllm offload in sharding manager")
        # Measure free device memory around sleep() to record how many bytes
        # the call released.
        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
        self.inference_engine.sleep(level=1)
        free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
        self.freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        print_gpu_memory_usage("After vllm offload in sharding manager")

        self.module.train()
        torch.cuda.empty_cache()  # add empty cache after each compute

        # restore random states
        if self.device_mesh is not None:
            # Save the generation RNG state and reactivate the training state.
            self.gen_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.torch_random_states)

    def preprocess_data(self, data: DataProto) -> DataProto:
        """All-gather across the tp group so that each rank has identical input.

        The return value of ``all_gather_data_proto`` is not used and the same
        ``data`` object is returned (presumably updated in place; confirm
        against ``all_gather_data_proto``).
        """
        all_gather_data_proto(data, size=self.tp_size, group=self.tp_group)
        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
        """Return this tp rank's chunk of ``data``, undoing the all-gather in preprocess.

        With a tp size of 1 the data is returned unchanged.
        """
        if self.tp_size > 1:
            data = data.chunk(chunks=self.tp_size)[self.tp_rank]

        return data