| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import functools |
| import itertools |
| import json |
| import math |
| import os |
| from abc import ABC |
| from collections import OrderedDict |
| from contextlib import contextmanager, nullcontext |
| from typing import Optional, cast |
|
|
| import torch |
| import torch.distributed as dist |
| import torch.nn as nn |
| from packaging import version |
| from torch.distributed import DeviceMesh |
| from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| from torch.distributed.fsdp._runtime_utils import _lazy_init |
| from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy |
| from transformers.trainer_pt_utils import get_module_class_from_name |
|
|
| from verl.utils.device import get_device_id, get_device_name, get_torch_device |
| from verl.utils.model import check_exclude_modules, check_target_modules |
|
|
| if version.parse(torch.__version__) >= version.parse("2.6"): |
| from torch.distributed.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard |
| from torch.distributed.fsdp._fully_shard._fsdp_init import _get_post_forward_mesh_info |
| from torch.distributed.tensor import DTensor, Shard |
| from torch.distributed.tensor._dtensor_spec import DTensorSpec |
|
|
| fully_shard_module = torch.distributed.fsdp._fully_shard._fully_shard |
| elif version.parse(torch.__version__) >= version.parse("2.4"): |
| from torch.distributed._composable.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard |
|
|
| fully_shard_module = torch.distributed._composable.fsdp |
| else: |
| fully_shard, MixedPrecisionPolicy, FSDPModule, CPUOffloadPolicy, fully_shard_module = None, None, None, None, None |
|
|
|
|
def init_fn(x: torch.nn.Module):
    """Allocate ``x``'s own (non-recursive) tensors on the local device for non-zero ranks.

    Rank 0 returns the module untouched. Every other rank replaces the module's
    direct parameters/buffers with uninitialized storage on the current device
    via ``to_empty(recurse=False)`` and then empties the device cache.

    Returns:
        The module (re-materialized on non-zero ranks).
    """
    if torch.distributed.get_rank() == 0:
        return x
    materialized = x.to_empty(device=get_device_id(), recurse=False)
    get_torch_device().empty_cache()
    return materialized
|
|
|
|
def get_init_weight_context_manager(use_meta_tensor=True, mesh: DeviceMesh = None):
    """Return a zero-arg factory for the context manager used while building model weights.

    Args:
        use_meta_tensor: When True, only the "primary" rank materializes weights on CPU and
            every other rank uses ``accelerate.init_empty_weights``. When False, all
            ranks build on CPU.
        mesh: Optional device mesh. If given, the primary rank is the one whose
            coordinate on the last mesh dimension is 0; otherwise it is global rank 0.

    Returns:
        Either ``accelerate.init_empty_weights`` or a callable returning
        ``torch.device("cpu")``.
    """
    from accelerate import init_empty_weights

    def cpu_init_weights():
        return torch.device("cpu")

    if not use_meta_tensor:
        return cpu_init_weights

    if mesh is None:
        is_primary = torch.distributed.get_rank() == 0
    else:
        is_primary = mesh.get_coordinate()[-1] == 0
    return cpu_init_weights if is_primary else init_empty_weights
|
|
|
|
| |
| |
def get_fsdp_wrap_policy(module, config=None, is_lora=False):
    """Get the FSDP (v1) auto-wrap policy for ``module``.

    The result ORs together up to two sub-policies:

    * if ``is_lora`` is True, a lambda policy wrapping every leaf module whose
      ``weight`` requires grad;
    * a size-based policy when ``min_num_params > 0``, otherwise a transformer-layer
      policy built from ``transformer_layer_cls_to_wrap`` (defaulting to
      ``module._no_split_modules``).

    Args:
        module: The model to build the wrap policy for.
        config: Either mapping-like (has ``.get``) or attribute-style, with optional
            keys ``disable``, ``transformer_layer_cls_to_wrap`` (list of class names
            or a single name) and ``min_num_params``. Missing keys use defaults.
        is_lora: Whether to add the lambda policy for trainable leaf modules.

    Returns:
        A callable auto-wrap policy, or ``None`` when wrapping is disabled or no
        sub-policy applies.

    Raises:
        Exception: If a configured transformer layer class name is not found in ``module``.
    """
    from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy

    if config is None:
        config = {}

    def _get_attr(attr_name, default_value=None):
        # Dict-like configs (dict / DictConfig) and plain attribute objects are both
        # supported; attribute objects now honor ``default_value`` instead of raising
        # AttributeError for missing keys.
        if hasattr(config, "get"):
            return config.get(attr_name, default_value)
        return getattr(config, attr_name, default_value)

    if _get_attr("disable", False):
        return None

    default_transformer_cls_names_to_wrap = getattr(module, "_no_split_modules", None)
    fsdp_transformer_layer_cls_to_wrap = _get_attr(
        "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
    )
    # A single class name given as a string would otherwise be iterated per character.
    if isinstance(fsdp_transformer_layer_cls_to_wrap, str):
        fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap]
    # Treat an explicit ``None`` the same as the default 0 (size policy disabled).
    min_num_params = _get_attr("min_num_params", 0) or 0

    policies = []

    if is_lora:

        def lambda_policy_fn(submodule):
            # Wrap leaf modules that own a trainable ``weight``.
            return bool(
                len(list(submodule.named_children())) == 0
                and getattr(submodule, "weight", None) is not None
                and submodule.weight.requires_grad
            )

        policies.append(functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn))

    if min_num_params > 0:
        policies.append(functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params))
    elif fsdp_transformer_layer_cls_to_wrap is not None:
        transformer_cls_to_wrap = set()
        for layer_class in fsdp_transformer_layer_cls_to_wrap:
            transformer_cls = get_module_class_from_name(module, layer_class)
            if transformer_cls is None:
                raise Exception("Could not find the transformer layer class to wrap in the model.")
            transformer_cls_to_wrap.add(transformer_cls)
        policies.append(
            functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=transformer_cls_to_wrap)
        )

    if not policies:
        return None
    return functools.partial(_or_policy, policies=policies)
|
|
|
|
@torch.no_grad()
def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True):
    """Move an FSDP (v1) root model's flat-parameter local shards to CPU.

    FSDP2 and non-FSDP modules are delegated to ``offload_fsdp2_model_to_cpu``.
    Handles with ``_offload_params`` set are skipped.

    Args:
        model: Root FSDP1 module (asserted).
        empty_cache: Whether to empty the device cache afterwards.
    """
    if fsdp_version(model) in (2, 0):
        offload_fsdp2_model_to_cpu(model, empty_cache)
        return

    assert isinstance(model, FSDP)
    # Make sure FSDP's lazy initialization has run before touching its handles.
    _lazy_init(model, model)
    assert model._is_root, "Only support root model offloading to CPU"

    for handle in model._all_handles:
        if handle._offload_params:
            continue
        flat_param = handle.flat_param
        # Invariant before moving: ``.data`` aliases ``_local_shard``'s storage but is a
        # distinct tensor object of the same size.
        assert (
            flat_param.data.data_ptr() == flat_param._local_shard.data_ptr()
            and id(flat_param.data) != id(flat_param._local_shard)
            and flat_param.data.size() == flat_param._local_shard.size()
        )
        handle.flat_param_to(torch.device("cpu"), non_blocking=True)
        # Re-point the local shard at the relocated storage; ``.data`` yields a new
        # tensor object, so the ids still differ afterwards.
        flat_param._local_shard = flat_param.data
        assert id(flat_param._local_shard) != id(flat_param.data)

    if empty_cache:
        get_torch_device().empty_cache()
|
|
|
|
@torch.no_grad()
def offload_fsdp2_model_to_cpu(model, empty_cache: bool = True):
    """Move all parameters and buffers of ``model`` to CPU via ``Module.cpu()``.

    Args:
        model: Any ``nn.Module`` (typically FSDP2-sharded).
        empty_cache: Whether to empty the device cache afterwards.
    """
    model.cpu()
    if not empty_cache:
        return
    get_torch_device().empty_cache()
|
|
|
|
@torch.no_grad()
def load_fsdp_model_to_gpu(model: FSDP):
    """Move an FSDP (v1) root model's flat-parameter local shards onto the current device.

    Counterpart of ``offload_fsdp_model_to_cpu``. FSDP2 and non-FSDP modules are
    delegated to ``load_fsdp2_model_to_gpu``. Handles with ``_offload_params`` set
    are skipped.
    """
    if fsdp_version(model) in (2, 0):
        load_fsdp2_model_to_gpu(model)
        return

    assert isinstance(model, FSDP)
    # Make sure FSDP's lazy initialization has run before touching its handles.
    _lazy_init(model, model)
    assert model._is_root, "Only support root model loading to GPU"

    device_id = get_device_id()
    for handle in model._all_handles:
        if handle._offload_params:
            continue
        flat_param = handle.flat_param
        target = torch.device(f"{get_device_name()}:{device_id}")
        handle.flat_param_to(target, non_blocking=True)
        # Keep ``_local_shard`` pointing at the storage that now lives on the device.
        flat_param._local_shard = flat_param.data
|
|
|
|
@torch.no_grad()
def load_fsdp2_model_to_gpu(model):
    """Move ``model`` (parameters and buffers) onto the current accelerator device."""
    model.to(get_device_id())
|
|
|
|
@torch.no_grad()
def offload_fsdp_optimizer(optimizer):
    """Move every tensor in the optimizer state to CPU (non-blocking).

    Only parameters that already have state are touched. Parameters without state
    are skipped via ``state.get`` rather than ``state[param]``. On a
    ``torch.optim.Optimizer`` the state is a ``defaultdict``, so indexing would
    silently insert empty entries for them.

    Args:
        optimizer: Optimizer whose ``state`` maps parameters to per-parameter dicts.
    """
    if not optimizer.state:
        return
    for param_group in optimizer.param_groups:
        for param in param_group["params"]:
            state = optimizer.state.get(param)
            if not state:
                continue
            # Reassigning existing keys does not change the dict size, so this is safe
            # while iterating.
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to("cpu", non_blocking=True)
|
|
|
|
@torch.no_grad()
def load_fsdp_optimizer(optimizer, device_id):
    """Move every tensor in the optimizer state to ``device_id`` (non-blocking).

    Counterpart of ``offload_fsdp_optimizer``. Parameters without existing state are
    skipped via ``state.get`` instead of being indexed. Indexing a ``defaultdict`` state
    would insert empty entries for them.

    Args:
        optimizer: Optimizer whose ``state`` maps parameters to per-parameter dicts.
        device_id: Any target accepted by ``Tensor.to`` (device index, string or ``torch.device``).
    """
    if not optimizer.state:
        return
    for param_group in optimizer.param_groups:
        for param in param_group["params"]:
            state = optimizer.state.get(param)
            if not state:
                continue
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to(device_id, non_blocking=True)
|
|
|
|
@contextmanager
def meta_device_init():
    """
    Context manager under which newly registered module parameters live on the meta device.

    ``nn.Module.register_parameter`` is patched for the duration of the block. Each
    parameter is registered normally and then replaced by a same-class copy on the meta
    device that keeps ``requires_grad``. Buffers are not intercepted, so they are still
    created on the default device (e.g. CPU). Buffers may be non-persistent and filled
    with values that a meta tensor cannot capture.
    """
    meta = torch.device("meta")
    original_register = nn.Module.register_parameter
    seen = set()

    def _register_on_meta(module, name, param):
        original_register(module, name, param)
        # Skip ``None`` placeholders and parameters this context already moved to meta.
        if param is None or param in seen:
            return
        registered_param = module._parameters[name]
        param_kwargs = registered_param.__dict__
        param_kwargs["requires_grad"] = param.requires_grad
        module._parameters[name] = type(registered_param)(registered_param.to(meta), **param_kwargs)
        seen.add(module._parameters[name])

    try:
        nn.Module.register_parameter = _register_on_meta
        yield
    finally:
        seen.clear()
        nn.Module.register_parameter = original_register
|
|
|
|
def parallel_load_safetensors(filepath):
    """
    Parallel load safetensors from a huggingface checkpoint.

    Huggingface checkpoint contains:

    - config.json: a json file for model configuration
    - model.safetensors.index.json: a json file for safetensors (parameters & buffers) index
    - model-000x-of-ooxx.safetensors: a binary file for safetensors (parameters & buffers) chunks

    Or (when model is small),

    - model.safetensors: a binary file for all parameters and buffers

    Checkpoint files are sorted by name and split into contiguous chunks, one chunk per
    rank. Each rank loads its own chunk directly onto its device.

    Returns:
        dict mapping each parameter/buffer name either to the loaded tensor (for names in
        this rank's chunk) or to the int rank that loaded it (for all other names).
    """
    from safetensors import safe_open
    from safetensors.torch import load_file

    # safetensors file name -> parameter/buffer names stored in it
    safetensors2param = {}

    index_file = os.path.join(filepath, "model.safetensors.index.json")
    if os.path.exists(index_file):
        # Close the index file promptly instead of leaking the handle.
        with open(index_file, "rb") as f:
            index = json.load(f)
        for param_name, filename in index["weight_map"].items():
            safetensors2param.setdefault(filename, []).append(param_name)
    else:
        # Single-file checkpoint. Only the tensor names are needed here, so read them from
        # the header instead of materializing every tensor in CPU memory.
        param_file = os.path.join(filepath, "model.safetensors")
        assert os.path.exists(param_file), f"Cannot find {param_file}"
        with safe_open(param_file, framework="pt") as f:
            param_names = list(f.keys())
        if param_names:
            safetensors2param["model.safetensors"] = param_names

    total_files = len(safetensors2param)
    ckpt_chunks = sorted(safetensors2param.keys())
    world_size = dist.get_world_size()
    size = int(math.ceil(total_files / world_size))
    ckpt_chunks = [ckpt_chunks[rank * size : rank * size + size] for rank in range(world_size)]

    shard_states = {}
    device = get_device_id()
    my_rank = dist.get_rank()
    for rank, files in enumerate(ckpt_chunks):
        if rank == my_rank:
            for file in files:
                states = load_file(os.path.join(filepath, file), device=device)
                shard_states.update(states)
        else:
            # Record which rank owns each name so consumers know the broadcast source.
            for file in files:
                for param_name in safetensors2param[file]:
                    shard_states[param_name] = rank
    return shard_states
|
|
|
|
def parallel_init_module_fn(module: torch.nn.Module, shard_states: dict[str, torch.nn.Parameter]):
    """
    Generate a function to initialize sub-modules in the `module` with `shard_states`
    from huggingface checkpoint.

    Each meta-device parameter/buffer is re-allocated on the local device and filled via
    ``dist.broadcast``. A ``shard_states`` entry is either the loaded tensor on the rank
    that loaded it, or the int rank that owns it. Entries are popped from ``shard_states``
    as they are consumed, so the dict is mutated in place.

    Args:
        module (torch.nn.Module): the global module to be initialized
        shard_states (Dict[str, torch.nn.Parameter]): the shard states from huggingface checkpoint
            (values are tensors or int owner ranks, as produced by ``parallel_load_safetensors``)

    Returns:
        init_fn (Callable): a function to initialize sub-modules in the `module` with `shard_states`
    """
    # Map each parameter/buffer tensor object to every fully-qualified name it appears under;
    # remove_duplicate=False keeps all aliases (e.g. tied weights).
    state2fqn = {}
    for name, state in itertools.chain(
        module.named_parameters(remove_duplicate=False), module.named_buffers(remove_duplicate=False)
    ):
        state2fqn.setdefault(state, []).append(name)
    # Tensors reachable under more than one name; they are materialized only once.
    shared = {s for s, names in state2fqn.items() if len(names) > 1}
    # Original (meta) shared tensor -> its materialized replacement.
    materialized_states = {}

    @torch.no_grad()
    def create_and_sync_state(param_name, state, is_param):
        # Allocate an uninitialized tensor like ``state`` on the local device and fill it by
        # broadcasting from whichever rank loaded ``param_name``.
        assert param_name in shard_states, f"{param_name} not loaded"
        device = get_device_id()
        if is_param:
            param = torch.nn.Parameter(torch.empty_like(state.data, device=device), requires_grad=state.requires_grad)
        else:  # buffer
            param = torch.empty_like(state.data, device=device)
        loaded = shard_states[param_name]
        if isinstance(loaded, torch.nn.Parameter | torch.Tensor):
            # This rank loaded the tensor: copy it in and act as the broadcast source.
            param.data.copy_(loaded.data)
            dist.broadcast(param.data, src=dist.get_rank())
        else:
            # Another rank owns the tensor; ``loaded`` is that rank's index.
            assert isinstance(loaded, int)
            dist.broadcast(param.data, src=loaded)
        shard_states.pop(param_name)
        del loaded
        return param

    def init_fn(sub_mod: torch.nn.Module, recurse: bool = True):
        # Materialize ``sub_mod``'s own meta parameters/buffers, then (optionally) recurse into children.
        param_and_buffers = tuple(sub_mod.named_parameters(recurse=False)) + tuple(sub_mod.named_buffers(recurse=False))

        for name, state in param_and_buffers:
            if not state.is_meta:
                continue
            is_param = name in sub_mod._parameters
            # Consume the next recorded alias name for this tensor object.
            fqn = state2fqn[state].pop(0)

            # A meta buffer that is absent from the checkpoint cannot be filled here.
            # NOTE(review): ``state.is_meta`` is always True at this point (non-meta states were
            # skipped above), so this always raises and the ``continue`` is unreachable.
            if (not is_param) and fqn not in shard_states:
                if state.is_meta:
                    raise RuntimeError(
                        f"find a non-persistent buffer ({fqn}) initiated with device meta. Such buffer is not saved "
                        f"in checkpoint and user should guarantee to init in CPU / GPU device."
                    )
                continue

            if state in shared:
                if state not in materialized_states:
                    materialized_states[state] = create_and_sync_state(fqn, state, is_param)
                else:
                    # Already materialized under another alias; drop this alias's checkpoint entry.
                    if fqn in shard_states:
                        shard_states.pop(fqn)
                materialize_state = materialized_states[state]

            else:
                materialize_state = create_and_sync_state(fqn, state, is_param)
            if is_param:
                sub_mod._parameters[name] = materialize_state
            else:
                sub_mod._buffers[name] = materialize_state
        if recurse:
            for module in sub_mod.children():
                init_fn(module, recurse=True)

        return sub_mod

    return init_fn
|
|
|
|
| def fsdp_version(model): |
| if isinstance(model, FSDP): |
| return 1 |
| elif isinstance(model, FSDPModule): |
| return 2 |
| else: |
| return 0 |
|
|
|
|
def get_fsdp_state_ctx(model, state_type, state_cfg, optim_cfg):
    """Return a context manager selecting the FSDP1 state-dict type for ``model``.

    For FSDP1 models this is ``FSDP.state_dict_type(...)``. For every other model
    (FSDP2 or plain) it is a no-op ``nullcontext()``.
    """
    if fsdp_version(model) != 1:
        return nullcontext()
    return FSDP.state_dict_type(model, state_type, state_cfg, optim_cfg)
|
|
|
|
def get_fsdp_full_state_dict(model: torch.nn.Module, offload_to_cpu: bool = True, rank0_only: bool = True):
    """
    Get the full (unsharded) state dict from an FSDP model.

    Args:
        model (torch.nn.Module): The FSDP model to get state dict from
        offload_to_cpu (bool, optional): Whether to offload the state dict to CPU. Defaults to True.
        rank0_only (bool, optional): Whether to only get state dict on rank 0. Defaults to True.
            For FSDP2 / plain modules this is passed as ``broadcast_from_rank0=not rank0_only``.

    Returns:
        dict: The full state dict of the model

    Raises:
        NotImplementedError: If the FSDP version is unknown
    """
    # Classify once instead of re-running ``fsdp_version`` for every branch test.
    fsdp_ver = fsdp_version(model)
    if fsdp_ver == 1:
        from torch.distributed.fsdp import FullStateDictConfig, StateDictType

        state_dict_config = FullStateDictConfig(offload_to_cpu=offload_to_cpu, rank0_only=rank0_only)
        with get_fsdp_state_ctx(
            model, state_type=StateDictType.FULL_STATE_DICT, state_cfg=state_dict_config, optim_cfg=None
        ):
            return model.state_dict()
    if fsdp_ver in (0, 2):
        from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict

        state_dict_config = StateDictOptions(
            full_state_dict=True, cpu_offload=offload_to_cpu, broadcast_from_rank0=not rank0_only
        )
        return get_model_state_dict(model, options=state_dict_config)
    # Report the computed version, not the ``fsdp_version`` function object.
    raise NotImplementedError(f"Unknown FSDP version {fsdp_ver}")
|
|
|
|
def fsdp2_load_full_state_dict(model: torch.nn.Module, full_state: dict, device_mesh=None, cpu_offload=None):
    """
    Load a full state dict (possibly present only on rank 0) into an FSDP2-sharded model in place.

    Rank 0 moves its model to the device. Other ranks allocate uninitialized device
    storage. ``set_model_state_dict(..., broadcast_from_rank0=True)`` then distributes
    the weights, and all buffers are broadcast from rank 0 afterwards.

    Args:
        model (`torch.nn.Module`): The model to load the state dict into
        full_state (`dict`): The full state dict to load, can only be on rank 0
        device_mesh: Unused here; kept for interface compatibility.
        cpu_offload: Any non-None value enables CPU offload. Parameters end up on CPU,
            while buffers are moved back to the device.
    """
    if version.parse(torch.__version__) >= version.parse("2.7.0"):
        from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
    else:
        # Older torch: use the vendored copy of the checkpoint state-dict helpers.
        from verl.third_party.torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

    # Rank 0 holds real weights to broadcast; other ranks just need allocated storage.
    if dist.get_rank() != 0:
        model = model.to_empty(device=get_device_id())
    else:
        model = model.to(device=get_device_id(), non_blocking=True)

    offload_enabled = cpu_offload is not None
    load_options = StateDictOptions(full_state_dict=True, cpu_offload=offload_enabled, broadcast_from_rank0=True)
    set_model_state_dict(model, full_state, options=load_options)

    # ``model.buffers()`` also yields non-persistent buffers, which are not part of a state dict.
    for buffer in model.buffers():
        dist.broadcast(buffer, src=0)

    if not offload_enabled:
        return
    model.to("cpu", non_blocking=True)
    # Buffers stay on the device even when parameters are offloaded.
    for buffer in model.buffers():
        buffer.data = buffer.data.to(get_device_id())
|
|
|
|
@contextmanager
def maybe_patch_fsdp_module(model):
    """Temporarily replace ``fully_shard_module.FSDPModule`` with an ABC-derived subclass.

    The patch is applied only when ``model`` is an ``ABC`` instance. It is always undone
    on exit. Presumably this lets ``fully_shard`` build its dynamic FSDP class for
    ABC-based modules; that code is not visible here, so the exact reason is unconfirmed.
    No-op when the FSDP2 API is unavailable (``fully_shard_module`` is None).
    """
    if fully_shard_module is None:
        yield
        return

    original_cls = fully_shard_module.FSDPModule

    class FSDPModuleABC(ABC, original_cls):
        pass

    needs_patch = isinstance(model, ABC)
    try:
        if needs_patch:
            fully_shard_module.FSDPModule = FSDPModuleABC
        yield
    finally:
        fully_shard_module.FSDPModule = original_cls
|
|
|
|
| def _select_fsdp2_wrap_targets(model, fsdp_transformer_layer_cls_to_wrap): |
| """Select modules to wrap individually with fully_shard in FSDP2. |
| |
| Matches transformer layers by class name, and embed_tokens/lm_head by name |
| (with isinstance fallback). Name-based matching is needed because peft wraps |
| embed_tokens in ModulesToSaveWrapper, breaking isinstance(module, nn.Embedding). |
| When tie_word_embeddings is True, embed_tokens and lm_head share weights and |
| must not be wrapped separately. |
| """ |
| _tie = getattr(model.config, "tie_word_embeddings", False) |
| _wrap_by_name = set() if _tie else {"embed_tokens", "lm_head"} |
|
|
| modules = [] |
| for name, module in model.named_modules(): |
| leaf_name = name.rsplit(".", 1)[-1] if "." in name else name |
| if ( |
| module.__class__.__name__ in fsdp_transformer_layer_cls_to_wrap |
| or (isinstance(module, nn.Embedding) and not _tie) |
| or (leaf_name in _wrap_by_name and hasattr(module, "weight")) |
| ): |
| modules.append(module) |
| return modules |
|
|
|
|
def apply_fsdp2(model, fsdp_kwargs, config):
    """Shard ``model`` (an AutoModelForCausalLM) in place with the FSDP2 ``fully_shard`` API.

    The wrap targets come from ``_select_fsdp2_wrap_targets`` and are sharded first. The
    root model is sharded last, so each selected sub-module becomes its own FSDP unit.

    Args:
        model: The model to shard.
        fsdp_kwargs: Keyword arguments forwarded to every ``fully_shard`` call.
        config: Mapping with an optional ``wrap_policy`` mapping whose
            ``transformer_layer_cls_to_wrap`` (list of class names or a single name)
            defaults to ``model._no_split_modules``.

    Raises:
        AssertionError: If FSDP2 is unavailable, or no transformer layer class is configured.
    """
    assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"

    default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None)
    # ``wrap_policy`` may be missing or explicitly None in the config.
    wrap_policy = config.get("wrap_policy") or {}
    fsdp_transformer_layer_cls_to_wrap = wrap_policy.get(
        "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
    )

    if isinstance(fsdp_transformer_layer_cls_to_wrap, str):
        fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap]

    # Fail with a clear message rather than ``TypeError: len(None)`` when nothing is configured
    # and the model has no ``_no_split_modules``.
    assert fsdp_transformer_layer_cls_to_wrap and fsdp_transformer_layer_cls_to_wrap[0] is not None, (
        "transformer_layer_cls_to_wrap must be configured (or the model must define _no_split_modules) for FSDP2"
    )

    modules = _select_fsdp2_wrap_targets(model, fsdp_transformer_layer_cls_to_wrap)

    for module in modules:
        with maybe_patch_fsdp_module(module):
            fully_shard(module, **fsdp_kwargs)

    # Shard the root last so it only owns parameters not already assigned to a child unit.
    with maybe_patch_fsdp_module(model):
        fully_shard(model, **fsdp_kwargs)
|
|
|
|
| def get_shard_placement_fn(fsdp_size): |
| """Choose the dimension that can divide fsdp_size to avoid padding""" |
|
|
| def shard_placement_fn(param): |
| shape = list(param.shape) |
| for i in range(len(shape)): |
| if shape[i] % fsdp_size == 0: |
| return Shard(i) |
| return Shard(0) |
|
|
| return shard_placement_fn |
|
|
|
|
def fsdp2_clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinite=False, foreach=None):
    """Gradient-norm clipping that also works for CPU-resident DTensor parameters.

    ``torch.nn.utils.clip_grad_norm_`` can't run on CPU parameter DTensors. This
    version computes the total norm with torch's internal helpers, moves it to the
    current device, and then scales the gradients in place.

    Returns:
        The total gradient norm as a tensor on the current device.
    """
    from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm

    params = [parameters] if isinstance(parameters, torch.Tensor) else list(parameters)
    grads = [p.grad for p in params if p.grad is not None]
    norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach).to(get_device_id(), non_blocking=True)
    _clip_grads_with_norm_(params, max_norm, norm, foreach)
    return norm
|
|
|
|
def layered_summon_lora_params(fsdp_module) -> OrderedDict:
    """Collect LoRA adapter weights by gathering one FSDP sub-module at a time.

    For each known PEFT prefix, every direct child module under that prefix that is
    itself FSDP-managed (``fsdp_version > 0``) is gathered with
    ``FSDP.summon_full_params``. Children whose names end in ``.model``/``.layers``
    are skipped. The child's PEFT (LoRA) state dict is copied to CPU. Keys are prefixed
    with the child's name, with the FSDP1 wrapper segment rewritten to ``base_model.model.``.

    Returns:
        OrderedDict mapping prefixed parameter names to CPU tensors.
    """
    from peft.utils.save_and_load import get_peft_model_state_dict

    def _direct_children_with_prefix(root, prefix):
        # Yield (name, module) pairs exactly one name component below ``prefix``.
        for qualified_name, candidate in root.named_modules():
            if qualified_name.startswith(prefix) and "." not in qualified_name[len(prefix) :]:
                yield qualified_name, candidate

    collected = OrderedDict()
    prefixes = [
        # names as seen through an FSDP1 ``_fsdp_wrapped_module`` wrapper
        "_fsdp_wrapped_module.base_model.model.",
        "_fsdp_wrapped_module.base_model.model.model.",
        "_fsdp_wrapped_module.base_model.model.model.layers.",
        "_fsdp_wrapped_module.base_model.model.model.language_model.layers.",
        # names without that wrapper
        "base_model.model.",
        "base_model.model.model.",
        "base_model.model.model.layers.",
        "base_model.model.model.language_model.layers.",
    ]
    peft_model = getattr(fsdp_module, "_fsdp_wrapped_module", fsdp_module)
    for prefix in prefixes:
        for module_name, submodule in _direct_children_with_prefix(fsdp_module, prefix):
            key_prefix = module_name.replace("_fsdp_wrapped_module.base_model.model.", "base_model.model.")
            if module_name.endswith((".model", ".layers")):
                continue
            if fsdp_version(submodule) <= 0:
                continue
            with FSDP.summon_full_params(submodule, writeback=False):
                sub_state = get_peft_model_state_dict(peft_model, state_dict=submodule.state_dict())
                collected.update(
                    {
                        f"{key_prefix}.{param_name}": (
                            tensor.full_tensor().detach().cpu()
                            if hasattr(tensor, "full_tensor")
                            else tensor.detach().cpu()
                        )
                        for param_name, tensor in sub_state.items()
                    }
                )
                submodule._is_root = False
            get_torch_device().empty_cache()
    return collected
|
|
|
|
def collect_lora_params(module: FSDP, layered_summon: bool, base_sync_done: bool) -> OrderedDict:
    """
    Collect LoRA params, or the full base-model params if the base model is not yet in vLLM.

    Intended for the case ``isinstance(self.module._fsdp_wrapped_module, PeftModel)``.

    * ``base_sync_done=True``: returns the PEFT (adapter) state dict.
    * ``base_sync_done=False``: returns the base-model weights on CPU. Names containing
      ``_flat_param``/``lora_`` are excluded, and ``_fsdp_wrapped_module.`` / ``.base_layer``
      are stripped from names.
    * For FSDP modules with ``layered_summon``, delegates to ``layered_summon_lora_params``.
      This requires ``base_sync_done``.

    Raises:
        ValueError: If ``layered_summon`` is requested before the base model was synced.
    """
    from peft.utils.save_and_load import get_peft_model_state_dict

    def _to_cpu(tensor):
        # Gather a DTensor into a full tensor first; plain tensors are just detached.
        if hasattr(tensor, "full_tensor"):
            return tensor.full_tensor().detach().cpu()
        return tensor.detach().cpu()

    def _base_weights_on_cpu(peft, convert):
        # Snapshot base-model weights with wrapper segments stripped, restoring the model's device.
        base_model = peft.base_model.model
        restore_device = "cpu" if "cpu" in str(next(base_model.parameters()).device) else get_device_name()
        base_model = base_model.to("cpu")
        weights = OrderedDict()
        for key, value in base_model.state_dict().items():
            if any(marker in key for marker in ["_flat_param", "lora_"]):
                continue
            key = key.replace("_fsdp_wrapped_module.", "").replace(".base_layer", "")
            weights[key] = convert(value)
        base_model.to(restore_device)
        return weights

    peft_model = getattr(module, "_fsdp_wrapped_module", module)

    if fsdp_version(module) <= 0:
        if base_sync_done:
            return get_peft_model_state_dict(peft_model)
        return _base_weights_on_cpu(peft_model, lambda tensor: tensor.detach().cpu())

    if layered_summon:
        if not base_sync_done:
            raise ValueError(
                "To use layered_summon, you must make sure base-model is preloaded in vllm, e.g. let "
                "rollout.load_format=safetensors"
            )
        return layered_summon_lora_params(module)

    with FSDP.summon_full_params(module, writeback=False):
        if base_sync_done:
            lora_params = {name: _to_cpu(param) for name, param in get_peft_model_state_dict(peft_model).items()}
        else:
            lora_params = _base_weights_on_cpu(peft_model, _to_cpu)
    get_torch_device().empty_cache()
    return lora_params
|
|
|
|
def replace_lora_wrapper(k, peft_config):
    """Map a parameter key to its PEFT ``base_layer`` equivalent when applicable.

    Used to load weights into vLLM when the base model has not been synced yet. A key
    ending in ``.weight`` or ``.bias`` becomes ``<module>.base_layer.<weight|bias>`` when
    its module is not excluded by ``peft_config`` and either ends with a standard
    projection name (q/k/v/o/gate/up/down_proj) or is a target module. Every other
    key is returned unchanged.

    Args:
        k (str): Original parameter key name.
        peft_config: PEFT config consulted by ``check_exclude_modules`` / ``check_target_modules``.

    Returns:
        str: Transformed parameter key for base layer (or ``k`` itself).
    """
    stacked_params = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    for suffix in (".weight", ".bias"):
        if not k.endswith(suffix):
            continue
        module_k = k[: -len(suffix)]
        if check_exclude_modules(peft_config, module_k):
            return k
        if any([module_k.endswith(s) for s in stacked_params]) or check_target_modules(peft_config, module_k):
            return f"{module_k}.base_layer{suffix}"
    return k
|
|
|
|
| def set_reshard_after_forward(module: FSDPModule, reshard_after_forward: bool, recurse: bool = True) -> None: |
| """ |
| Sets if the module should reshard parameters after forward. This can be |
| used to change the ``reshard_after_forward`` FSDP arg at runtime. For |
| example, this can be used to set the FSDP root module's value to |
| ``True`` (since it is otherwise specially set to ``False``), or it can |
| set an FSDP module's value to ``False`` for running evals and set back |
| to ``True`` for training. |
| |
| Args: |
| reshard_after_forward (bool): Whether to reshard parameters after |
| forward. |
| recurse (bool): Whether to set for all FSDP submodules or just the |
| passed-in module. |
| |
| --- |
| Copied from https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fully_shard.py to |
| address the absence of the set_reshard_after_forward function in torch versions earlier than 2.8.0. |
| """ |
|
|
| if not isinstance(reshard_after_forward, bool): |
| raise ValueError(f"reshard_after_forward should be a bool, got {type(reshard_after_forward)}") |
| self_module = cast(nn.Module, module) |
| modules = list(self_module.modules()) if recurse else [self_module] |
| for module in modules: |
| if isinstance(module, FSDPModule): |
| state = module._get_fsdp_state() |
| state._auto_reshard_after_forward = False |
| if fsdp_param_group := state._fsdp_param_group: |
| fsdp_param_group.post_forward_mesh_info = _get_post_forward_mesh_info( |
| reshard_after_forward, fsdp_param_group.mesh_info |
| ) |
|
|
|
|
def normalize_peft_param_name(params: dict) -> dict:
    """
    Convert PEFT model parameter names to base-model parameter names.

    The fragments ``base_model.model.``, ``base_model.`` and ``.base_layer`` are removed
    (in that order), and keys whose normalized name contains ``lora_`` or ``.adapter_``
    are dropped. For example:

        base_model.model.model.embed_tokens.weight -> model.embed_tokens.weight
        base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight -> model.layers.0.self_attn.q_proj.weight
        base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight -> (removed)

    If two keys normalize to the same name, the later value wins.
    """

    def _strip_peft_fragments(key: str) -> str:
        for fragment in ("base_model.model.", "base_model.", ".base_layer"):
            key = key.replace(fragment, "")
        return key

    normalized = {}
    for original_key, value in params.items():
        key = _strip_peft_fragments(original_key)
        if "lora_" in key or ".adapter_" in key:
            continue
        normalized[key] = value
    return normalized
|
|
|
|
def _merge_or_unmerge_lora_(module, merge: bool):
    """Merge (``merge=True``) or unmerge (``merge=False``) every peft ``LoraLayer`` in ``module``.

    Layers already in the requested state (per their ``merged`` attribute, default
    False) are left alone. Runs under ``torch.no_grad()``.
    """
    from peft.tuners.lora import LoraLayer

    with torch.no_grad():
        lora_layers = (candidate for candidate in module.modules() if isinstance(candidate, LoraLayer))
        for layer in lora_layers:
            already_merged = getattr(layer, "merged", False)
            if merge and not already_merged:
                layer.merge()
            elif already_merged and not merge:
                layer.unmerge()
|
|
|
|
| |
def _clean_merged_lora_(module):
    """Reset ``merged_adapters`` to ``[]`` on every peft ``LoraLayer`` that reports merged adapters.

    Only the bookkeeping list is cleared; no weights are modified.
    """
    from peft.tuners.lora import LoraLayer

    with torch.no_grad():
        for layer in filter(lambda candidate: isinstance(candidate, LoraLayer), module.modules()):
            if getattr(layer, "merged_adapters", False):
                layer.merged_adapters = []
|
|
|
|
def fsdp_merge_unmerge(module: nn.Module, do_merge: bool):
    """Merge or unmerge LoRA adapters inside an FSDP-wrapped module.

    FSDP (v1): gathers all model parameters on each device at once, which may cause OOM.
    FSDP2: visits each nested ``FSDPModule`` (excluding the root itself) one at a time to
    reduce the memory footprint.

    Args:
        module: The FSDP module to merge/unmerge LoRA adapters
        do_merge: If True, merge LoRA into base model; if False, unmerge LoRA

    Raises:
        AssertionError: If ``module`` is not FSDP1/FSDP2-wrapped.
    """
    fsdp_ver = fsdp_version(module)
    assert fsdp_ver in [1, 2], f"fsdp_merge_unmerge requires FSDP module, got version {fsdp_ver}"

    if fsdp_ver != 1:
        for name, submodule in module.named_modules():
            if name == "" or not isinstance(submodule, FSDPModule):
                continue
            with FSDP.summon_full_params(submodule, writeback=True, with_grads=False):
                _merge_or_unmerge_lora_(submodule, merge=do_merge)
        return

    with FSDP.summon_full_params(module, writeback=True, with_grads=False):
        _merge_or_unmerge_lora_(module, merge=do_merge)
|
|
|
|
def backup_base_model_weights(module):
    """Snapshot base-model weights to CPU.

    For a ``PeftModel``, adapters are disabled while snapshotting, and parameters whose
    lowercased name contains ``lora`` are skipped. For any other module, every parameter
    is copied.

    Args:
        module: The model (typically a PEFT model with LoRA adapters).

    Returns:
        dict: Parameter name -> CPU clone of that parameter's data.
    """
    from peft import PeftModel

    with torch.no_grad():
        if not isinstance(module, PeftModel):
            return {name: param.data.clone().cpu() for name, param in module.named_parameters()}
        with module.disable_adapter():
            return {
                name: param.data.clone().cpu()
                for name, param in module.named_parameters()
                if "lora" not in name.lower()
            }
|
|
|
|
def restore_base_model_weights(module, backup):
    """Copy backed-up weights back into ``module`` in place.

    Only parameters whose name appears in ``backup`` are touched. Each backup tensor is
    moved to the parameter's device before copying. This undoes a LoRA merge when
    ``backup`` came from ``backup_base_model_weights``.

    Args:
        module: The model whose parameters are restored.
        backup: Mapping of parameter name -> tensor to copy from.
    """
    with torch.no_grad():
        for name, param in module.named_parameters():
            if name not in backup:
                continue
            saved = backup[name]
            param.data.copy_(saved.to(param.device))
|
|
|
|
@contextmanager
def merged_lora_context(actor, backup_adapters=False):
    """Temporarily merge LoRA adapters into the base weights of ``actor``.

    On entry the adapters are merged (``fsdp_merge_unmerge(do_merge=True)``). On exit:

    * with ``backup_adapters=True``, the base weights captured before merging are restored
      and the layers' ``merged_adapters`` bookkeeping is cleared (more numerically stable
      than unmerging);
    * otherwise the adapters are unmerged.

    Args:
        actor: The actor module with LoRA adapters to merge.
        backup_adapters: Whether to back up and restore the base weights instead of unmerging.

    Yields:
        None
    """
    saved_base = backup_base_model_weights(actor) if backup_adapters else None

    fsdp_merge_unmerge(actor, do_merge=True)
    try:
        yield
    finally:
        restore_from_backup = backup_adapters and saved_base is not None
        if restore_from_backup:
            restore_base_model_weights(actor, saved_base)
            _clean_merged_lora_(actor)
        else:
            fsdp_merge_unmerge(actor, do_merge=False)
|
|
|
|
| def fsdp2_sharded_save_to_cpu( |
| model: torch.nn.Module, |
| ) -> tuple[dict[str, tuple[torch.Tensor, DTensorSpec]], DTensorSpec]: |
| """ |
| Sharded Save: Each process only saves the local DTensor shard from its own GPU to CPU memory. |
| |
| Args: |
| model: FSDP2-wrapped model whose parameters are of DTensor type. |
| |
| Returns: |
| cpu_sharded_state: Dictionary of CPU shards for the current process. |
| Key = parameter name, Value = (CPU shard tensor, original DTensorSpec) |
| global_spec: DTensorSpec of the first parameter (used to verify global rules during loading) |
| """ |
| cpu_sharded_state = {} |
| global_spec = None |
|
|
| for param_name, param in model.named_parameters(): |
| |
| if not isinstance(param, DTensor): |
| |
| cpu_tensor = param.detach().cpu() |
| cpu_sharded_state[param_name] = (cpu_tensor, None) |
| continue |
|
|
| |
| if global_spec is None: |
| global_spec = param._spec |
| assert hasattr(global_spec, "device_mesh"), "DTensorSpec must contain 'device_mesh' attribute" |
| assert hasattr(global_spec, "placements"), "DTensorSpec must contain 'placements' attribute" |
|
|
| |
| local_gpu_tensor = param._local_tensor |
| |
| local_cpu_tensor = local_gpu_tensor.detach().cpu() |
| |
| cpu_sharded_state[param_name] = (local_cpu_tensor, param._spec) |
|
|
| assert global_spec is not None, "No DTensor-type parameters found in the model. FSDP2 sharding may not be enabled." |
| return cpu_sharded_state, global_spec |
|
|
|
|
| def fsdp2_sharded_load_from_cpu( |
| model: torch.nn.Module, |
| cpu_sharded_state: dict[str, tuple[torch.Tensor, Optional[DTensorSpec]]], |
| target_spec: DTensorSpec, |
| ) -> None: |
| """ |
| Sharded Load: Each process only loads the CPU shard it is responsible for to the GPU, |
| keeping sharding rules unchanged. |
| |
| Args: |
| model: FSDP2 model to be restored (must have the same structure as when saved) |
| cpu_sharded_state: Shard data read from CPU memory by the current process |
| (from fsdp2_sharded_save_to_cpu) |
| target_spec: Global DTensorSpec from saving (used to verify sharding rule consistency) |
| """ |
| |
| current_device_mesh = None |
| for param in model.parameters(): |
| if isinstance(param, DTensor): |
| current_device_mesh = param._spec.device_mesh |
| break |
| assert current_device_mesh is not None, "DTensor parameters not initialized in the model to be loaded" |
| assert current_device_mesh == target_spec.device_mesh, ( |
| f"device_mesh mismatch during loading! Original: {target_spec.device_mesh}, Current: {current_device_mesh}" |
| ) |
|
|
| for param_name, param in model.named_parameters(): |
| |
| if param_name not in cpu_sharded_state: |
| continue |
|
|
| |
| local_cpu_tensor, saved_spec = cpu_sharded_state[param_name] |
|
|
| |
| if isinstance(param, DTensor): |
| |
| assert saved_spec is not None, f"DTensorSpec missing in saved state for parameter {param_name}" |
| assert saved_spec.placements == target_spec.placements, ( |
| f"Sharding strategy mismatch for parameter {param_name} (conflicts with global rules)!" |
| ) |
|
|
| |
| target_device = param._local_tensor.device |
| local_gpu_tensor = local_cpu_tensor.to(target_device) |
|
|
| |
| param._local_tensor.copy_(local_gpu_tensor) |
|
|
| else: |
| |
| target_device = param.device |
| param.data.copy_(local_cpu_tensor.to(target_device)) |
|
|
| |
| dist.barrier() |
|
|