import logging
import os
from typing import Union

import torch
import torch.distributed as dist
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.entrypoints.verl_engine import VerlEngine
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
from sglang.srt.utils import MultiprocessingSerializer
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from torch.distributed.tensor import DTensor

from verl import DataProto
from verl.protocol import all_gather_data_proto
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu
from verl.utils.torch_functional import broadcast_dict_tensor, check_cuda_is_available

from .base import BaseShardingManager

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


def _preprocess_tensor_for_update_weights(tensor: torch.Tensor):
    # DTensor shards must be materialized into a regular full tensor before
    # they are serialized and handed to the inference engine.
    if isinstance(tensor, DTensor):
        return tensor.full_tensor()
    return tensor
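

# Usage sketch (illustrative only, not part of this module): the sharding
# manager is meant to be used as a context manager around rollout generation.
# Entering it gathers the FSDP weights and pushes them into the SGLang engine;
# exiting it releases the engine's GPU memory and switches the module back to
# training mode. The names `fsdp_module`, `engine`, `cfg`, and `mesh` below are
# assumed to be provided by the surrounding trainer code.
#
#     sharding_manager = FSDPSGLangShardingManager(
#         module=fsdp_module,
#         inference_engine=engine,
#         model_config=cfg,
#         device_mesh=mesh,
#     )
#     with sharding_manager:
#         ...  # run SGLang generation here, with weights in sync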


class FSDPSGLangShardingManager(BaseShardingManager):
    @check_cuda_is_available()
    def __init__(
        self,
        module: FSDP,
        inference_engine: Union[VerlEngine, Engine],
        model_config,
        full_params: bool = False,
        device_mesh: DeviceMesh = None,
        offload_param: bool = False,
    ):
        self.module = module
        self.inference_engine = inference_engine
        self.model_config = model_config
        self.device_mesh = device_mesh
        self.offload_param = offload_param

        self.full_params = full_params
        if full_params and fsdp_version(self.module) == 1:
            FSDP.set_state_dict_type(
                self.module,
                state_dict_type=StateDictType.FULL_STATE_DICT,
                state_dict_config=FullStateDictConfig(),
            )
        elif fsdp_version(self.module) == 1:
            FSDP.set_state_dict_type(
                self.module,
                state_dict_type=StateDictType.SHARDED_STATE_DICT,
                state_dict_config=ShardedStateDictConfig(),
            )

        # Record the trainer's CUDA RNG state, then derive a per-dp-rank RNG
        # state for generation, so sampling differs across dp ranks while all
        # tp ranks within a group share the same state.
        self.torch_random_states = torch.cuda.get_rng_state()
        if self.device_mesh is not None:
            gen_dp_rank = self.device_mesh["dp"].get_local_rank()
            torch.cuda.manual_seed(gen_dp_rank + 1000)
            self.gen_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.torch_random_states)
        else:
            self.gen_random_states = None

    def __enter__(self):
        torch.cuda.empty_cache()
        log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger)
        if self.offload_param:
            load_fsdp_model_to_gpu(self.module)
        params = self.module.state_dict()
        log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)
        device = torch.cuda.current_device()
        # For FSDP2, move the state-dict tensors onto the current device before syncing.
        params = {k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items()}

        self.update_weights(params)
        log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)

        del params
        if self.offload_param:
            offload_fsdp_model_to_cpu(self.module)
        torch.cuda.empty_cache()
        log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)

        # Switch to the generation RNG state so all tp ranks sample identically.
        if self.device_mesh is not None:
            self.torch_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.gen_random_states)

    def __exit__(self, exc_type, exc_value, traceback):
        log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger)
        self.release_memory()
        log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger)

        self.module.train()

        # Free cached blocks after each generation phase.
        torch.cuda.empty_cache()

        # Restore the trainer's RNG state.
        if self.device_mesh is not None:
            self.gen_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.torch_random_states)

    def update_weights(self, params):
        self.inference_engine.resume_memory_occupation()
        self.inference_engine.update_weights_from_tensor([(k, v) for k, v in params.items()], load_format=None)

    def release_memory(self):
        self.inference_engine.release_memory_occupation()
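
    # Data flow around generation (comment sketch, assuming an infer_tp size of 4):
    # preprocess_data() all-gathers the DataProto across the 4 tp ranks so every
    # rank holds the identical, 4x-larger batch for SGLang; after generation,
    # postprocess_data() broadcasts the result within the tp group and chunks it
    # back into 4 equal pieces, with each rank keeping the chunk matching its
    # tp rank.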

    def preprocess_data(self, data: DataProto) -> DataProto:
        """All-gather data across the tp group so that every rank has identical input."""
        if self.device_mesh["infer_tp"].mesh.size()[0] == 1:
            return data

        group = self.device_mesh["infer_tp"].get_group()

        all_gather_data_proto(data=data, process_group=group)
        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
        global_rank = self.device_mesh.get_rank()
        tp_rank = self.device_mesh["infer_tp"].get_local_rank()
        tp_size = self.device_mesh["infer_tp"].mesh.size()[0]
        # Broadcast from the first global rank of this tp group.
        src_rank = global_rank // tp_size * tp_size
        broadcast_dict_tensor(data.batch, src=src_rank, group=self.device_mesh["infer_tp"].get_group())
        if tp_size > 1:
            # Each rank keeps only the chunk corresponding to its tp rank.
            local_prompts = data.chunk(chunks=tp_size)
            data = local_prompts[tp_rank]
        return data
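

# The async variant below differs from FSDPSGLangShardingManager in how weights
# reach the engine: only tp local rank 0 talks to the SGLang Engine directly,
# while the other tp ranks contribute their serialized tensors via
# dist.gather_object.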


class FSDPAsyncSGLangShardingManager(FSDPSGLangShardingManager):
    def __init__(
        self,
        module: FSDP,
        inference_engine: Engine,
        model_config,
        full_params: bool = False,
        device_mesh: DeviceMesh = None,
        offload_param: bool = False,
    ):
        super().__init__(module, inference_engine, model_config, full_params, device_mesh, offload_param)
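
    # Weight-update protocol implemented by update_weights() below: every tp
    # rank serializes its copy of each tensor (DTensors are first materialized
    # via full_tensor()), the serialized payloads are collected onto tp local
    # rank 0 with dist.gather_object, and rank 0 issues one
    # update_weights_from_tensor call per tensor, flushing the engine cache
    # only after the last one.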

    def update_weights(self, params):
        load_format = None if self.full_params else "dtensor"
        if self.device_mesh["infer_tp"].get_local_rank() == 0:
            self.inference_engine.resume_memory_occupation()

        named_tensors = [(k, v) for k, v in params.items()]
        # The serialized-tensor path below always passes load_format=None.
        load_format = None
        for tensor_index, (name, tensor) in enumerate(named_tensors):
            serialized_tensor = MultiprocessingSerializer.serialize(_preprocess_tensor_for_update_weights(tensor))

            # Gather the serialized payloads from every tp rank onto tp local rank 0.
            if self.device_mesh["infer_tp"].get_local_rank() == 0:
                gathered_serialized_tensors = [None for _ in range(self.device_mesh["infer_tp"].mesh.size()[0])]
            else:
                gathered_serialized_tensors = None
            dist.gather_object(
                obj=serialized_tensor,
                object_gather_list=gathered_serialized_tensors,
                dst=self.device_mesh["infer_tp"].mesh.tolist()[0],
                group=self.device_mesh["infer_tp"].get_group(),
            )

            if self.device_mesh["infer_tp"].get_local_rank() == 0:
                self.inference_engine.update_weights_from_tensor(
                    named_tensors=[
                        (
                            name,
                            LocalSerializedTensor(values=gathered_serialized_tensors),
                        )
                    ],
                    load_format=load_format,
                    # Flush the engine cache only after the last tensor.
                    flush_cache=tensor_index == len(named_tensors) - 1,
                )

    def release_memory(self):
        if self.device_mesh["infer_tp"].get_local_rank() == 0:
            self.inference_engine.release_memory_occupation()