import warnings
from typing import Dict, Iterable, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from vllm import LLM
from vllm.distributed import parallel_state as vllm_ps

from ...protocol import DataProto, all_gather_data_proto
from ...utils.model_utils import print_gpu_memory_usage
from .base import BaseShardingManager


class FSDPVLLMShardingManager(BaseShardingManager):
    """Context manager that shuttles weights between an FSDP-trained actor and a
    colocated vLLM inference engine: entering syncs the actor weights into vLLM,
    exiting puts vLLM to sleep and returns the actor to training mode."""

    def __init__(
        self,
        module: FSDP,
        inference_engine: LLM,
        device_mesh: Optional[DeviceMesh] = None,
    ):
        self.module = module
        self.inference_engine = inference_engine
        self.device_mesh = device_mesh
        # Configure the module to produce sharded (DTensor) state dicts,
        # suppressing the warnings FSDP may emit here.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            FSDP.set_state_dict_type(
                self.module,
                state_dict_type=StateDictType.SHARDED_STATE_DICT,
                state_dict_config=ShardedStateDictConfig(),
            )

        self.world_size = dist.get_world_size()
        self.tp_size = vllm_ps.get_tensor_model_parallel_world_size()
        self.tp_rank = vllm_ps.get_tensor_model_parallel_rank()
        self.tp_group = vllm_ps.get_tensor_model_parallel_group().device_group

        # Number of bytes released by the most recent vLLM sleep; see __exit__.
        self.freed_bytes = 0

        # Save the trainer's CUDA RNG state, then derive a per-dp-rank
        # generation seed so that rollouts differ across dp ranks while all tp
        # ranks within the same dp group share identical random states.
        self.torch_random_states = torch.cuda.get_rng_state()
        if self.device_mesh is not None:
            gen_dp_rank = self.device_mesh["dp"].get_local_rank()
            torch.cuda.manual_seed(gen_dp_rank + 1000)
            self.gen_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.torch_random_states)
        else:
            self.gen_random_states = None

    def _make_weight_iterator(
        self, actor_weights: Dict[str, Union[torch.Tensor, DTensor]]
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        # With more than one rank, the sharded state dict holds DTensors, so
        # all-gather each shard into a full tensor before handing it to vLLM.
        for name, tensor in actor_weights.items():
            yield name, tensor.full_tensor() if self.world_size != 1 else tensor

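    # For illustration only, a minimal sketch of what full_tensor() does above,
    # assuming a 1-D mesh over all ranks and execution under torchrun (this
    # snippet is not part of the manager):
    #
    #     import torch
    #     import torch.distributed as dist
    #     from torch.distributed._tensor import Shard, distribute_tensor
    #     from torch.distributed.device_mesh import init_device_mesh
    #
    #     dist.init_process_group("nccl")
    #     torch.cuda.set_device(dist.get_rank())
    #     torch.manual_seed(0)  # same source tensor on every rank
    #     mesh = init_device_mesh("cuda", (dist.get_world_size(),))
    #     full = torch.randn(8, 8, device="cuda")
    #     sharded = distribute_tensor(full, mesh, [Shard(0)])  # rank keeps a row block
    #     assert torch.equal(sharded.full_tensor(), full)  # all-gather rebuilds it
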
    def __enter__(self):
        # vLLM keeps its own pool of cached memory, so we only empty PyTorch's
        # allocator cache right before waking vLLM up (and again after it
        # sleeps in __exit__); elsewhere we leave the cache alone so PyTorch
        # can reuse it for fast allocations.
        torch.cuda.empty_cache()
        print_gpu_memory_usage("Before state_dict() in sharding manager")
        actor_weights = self.module.state_dict()
        print_gpu_memory_usage("After state_dict() in sharding manager")

        # Wake the vLLM engine from sleep, then stream the gathered actor
        # weights directly into its model runner.
        self.inference_engine.wake_up()
        model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
        model.load_weights(self._make_weight_iterator(actor_weights))
        print_gpu_memory_usage("After sync model weights in sharding manager")

        del actor_weights
        torch.cuda.empty_cache()
        print_gpu_memory_usage("After del state_dict and empty_cache in sharding manager")

        # Switch to the generation RNG states so that sampling is reproducible
        # and identical across tp ranks (see __init__).
        if self.device_mesh is not None:
            self.torch_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.gen_random_states)

    def __exit__(self, exc_type, exc_value, traceback):
        print_gpu_memory_usage("Before vllm offload in sharding manager")
        # Put vLLM to sleep (level=1 offloads the weights to CPU and discards
        # the KV cache) and record how many bytes the sleep actually released.
        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
        self.inference_engine.sleep(level=1)
        free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
        self.freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        print_gpu_memory_usage("After vllm offload in sharding manager")

        self.module.train()
        torch.cuda.empty_cache()

        # Restore the trainer's RNG states, saving the generation states for
        # the next rollout.
        if self.device_mesh is not None:
            self.gen_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.torch_random_states)

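    # Callers can inspect how much memory the rollout engine released after the
    # context exits, e.g. (hypothetical usage, `manager` built by the caller):
    #
    #     with manager:
    #         ...
    #     print(f"vLLM sleep freed {manager.freed_bytes / 2**30:.1f} GiB")
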
    def preprocess_data(self, data: DataProto) -> DataProto:
        """All-gather across the tp group so that every rank has identical input."""
        all_gather_data_proto(data, size=self.tp_size, group=self.tp_group)
        return data

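    # For example (hypothetical sizes): with tp_size == 2, ranks holding batches
    # of 4 each all-gather to a shared batch of 8, vLLM generates for all 8
    # prompts on every tp rank, and postprocess_data below returns rows 0-3 to
    # tp rank 0 and rows 4-7 to tp rank 1.
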
    def postprocess_data(self, data: DataProto) -> DataProto:
        """Return this tp rank's chunk of the data, undoing the all-gather in preprocess_data."""
        if self.tp_size > 1:
            data = data.chunk(chunks=self.tp_size)[self.tp_rank]

        return data
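
# Usage sketch (hypothetical names; `actor_module` is an FSDP-wrapped policy,
# `rollout_engine` a vLLM LLM created with sleep mode enabled, and `mesh` a
# DeviceMesh with a "dp" dimension; none of these are defined in this module):
#
#     manager = FSDPVLLMShardingManager(actor_module, rollout_engine, mesh)
#     with manager:  # __enter__ gathers the FSDP shards and loads them into vLLM
#         batch = manager.preprocess_data(batch)       # identical input on tp ranks
#         out_proto = rollout_engine.generate(...)     # rollout with fresh weights
#         out_proto = manager.postprocess_data(out_proto)  # keep this rank's rows
#     # __exit__ put vLLM to sleep and returned actor_module to train mode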