# Copyright 2024 Bytedance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import functools import itertools import json import math import os from abc import ABC from collections import OrderedDict from contextlib import contextmanager, nullcontext from typing import Optional, cast import torch import torch.distributed as dist import torch.nn as nn from packaging import version from torch.distributed import DeviceMesh from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp._runtime_utils import _lazy_init from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy from transformers.trainer_pt_utils import get_module_class_from_name from verl.utils.device import get_device_id, get_device_name, get_torch_device from verl.utils.model import check_exclude_modules, check_target_modules if version.parse(torch.__version__) >= version.parse("2.6"): from torch.distributed.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard from torch.distributed.fsdp._fully_shard._fsdp_init import _get_post_forward_mesh_info from torch.distributed.tensor import DTensor, Shard from torch.distributed.tensor._dtensor_spec import DTensorSpec fully_shard_module = torch.distributed.fsdp._fully_shard._fully_shard elif version.parse(torch.__version__) >= version.parse("2.4"): from torch.distributed._composable.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard fully_shard_module = torch.distributed._composable.fsdp else: fully_shard, MixedPrecisionPolicy, FSDPModule, CPUOffloadPolicy, fully_shard_module = None, None, None, None, None def init_fn(x: torch.nn.Module): if torch.distributed.get_rank() != 0: x = x.to_empty(device=get_device_id(), recurse=False) get_torch_device().empty_cache() return x def get_init_weight_context_manager(use_meta_tensor=True, mesh: DeviceMesh = None): from accelerate import init_empty_weights cpu_init_weights = lambda: torch.device("cpu") if use_meta_tensor: if mesh is None: init_context = init_empty_weights if torch.distributed.get_rank() != 0 else cpu_init_weights else: init_context = init_empty_weights if mesh.get_coordinate()[-1] != 0 else cpu_init_weights else: init_context = cpu_init_weights return init_context # Copyright 2020-present the HuggingFace Inc. team. # Adapted from https://github.com/huggingface/transformers/src/transformers/trainer.py def get_fsdp_wrap_policy(module, config=None, is_lora=False): """Get FSDP wrap policy for the module. Args: module: The module to get wrap policy for config: Configuration for wrap policy is_lora: Whether to enable lambda policy for LoRA modules """ if config is None: config = {} # NOTE: This is a temporary workaround to be compatible with the OmegaConf & dataclass. We will remove this # once we have make all config in verl from OmegaConf to data class. def _get_attr(attr_name, default_value=None): if hasattr(config, "get"): return config.get(attr_name, default_value) else: return config.__getattribute__(attr_name) if _get_attr("disable", False): return None default_transformer_cls_names_to_wrap = getattr(module, "_no_split_modules", None) fsdp_transformer_layer_cls_to_wrap = _get_attr( "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap ) min_num_params = _get_attr("min_num_params", 0) auto_wrap_policy = None policies = [] from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy # Add lambda policy for LoRA modules if is_lora is True if is_lora: def lambda_policy_fn(module): return bool( len(list(module.named_children())) == 0 and getattr(module, "weight", None) is not None and module.weight.requires_grad ) lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) policies.append(lambda_policy) if min_num_params > 0: size_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params) policies.append(size_policy) elif fsdp_transformer_layer_cls_to_wrap is not None: transformer_cls_to_wrap = set() for layer_class in fsdp_transformer_layer_cls_to_wrap: transformer_cls = get_module_class_from_name(module, layer_class) if transformer_cls is None: raise Exception("Could not find the transformer layer class to wrap in the model.") else: transformer_cls_to_wrap.add(transformer_cls) transformer_policy = functools.partial( transformer_auto_wrap_policy, transformer_layer_cls=transformer_cls_to_wrap, ) policies.append(transformer_policy) if len(policies) > 0: auto_wrap_policy = functools.partial(_or_policy, policies=policies) return auto_wrap_policy @torch.no_grad() def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True): if fsdp_version(model) == 2 or fsdp_version(model) == 0: offload_fsdp2_model_to_cpu(model, empty_cache) return assert isinstance(model, FSDP) # lazy init FSDP model _lazy_init(model, model) assert model._is_root, "Only support root model offloading to CPU" for handle in model._all_handles: if handle._offload_params: continue flat_param = handle.flat_param assert ( flat_param.data.data_ptr() == flat_param._local_shard.data_ptr() and id(flat_param.data) != id(flat_param._local_shard) and flat_param.data.size() == flat_param._local_shard.size() ) handle.flat_param_to(torch.device("cpu"), non_blocking=True) # the following still keeps id(._local_shard) != id(.data) flat_param._local_shard = flat_param.data assert id(flat_param._local_shard) != id(flat_param.data) if empty_cache: get_torch_device().empty_cache() @torch.no_grad() def offload_fsdp2_model_to_cpu(model, empty_cache: bool = True): model.cpu() if empty_cache: get_torch_device().empty_cache() @torch.no_grad() def load_fsdp_model_to_gpu(model: FSDP): if fsdp_version(model) == 2 or fsdp_version(model) == 0: load_fsdp2_model_to_gpu(model) return assert isinstance(model, FSDP) # lazy init FSDP model _lazy_init(model, model) assert model._is_root, "Only support root model loading to GPU" device_id = get_device_id() for handle in model._all_handles: if handle._offload_params: continue flat_param = handle.flat_param handle.flat_param_to(torch.device(f"{get_device_name()}:{device_id}"), non_blocking=True) # the following still keeps id(._local_shard) != id(.data) flat_param._local_shard = flat_param.data @torch.no_grad() def load_fsdp2_model_to_gpu(model): device = get_device_id() model.to(device) @torch.no_grad() def offload_fsdp_optimizer(optimizer): if not optimizer.state: return for param_group in optimizer.param_groups: for param in param_group["params"]: state = optimizer.state[param] for key, value in state.items(): if isinstance(value, torch.Tensor): state[key] = value.to("cpu", non_blocking=True) @torch.no_grad() def load_fsdp_optimizer(optimizer, device_id): if not optimizer.state: return for param_group in optimizer.param_groups: for param in param_group["params"]: state = optimizer.state[param] for key, value in state.items(): if isinstance(value, torch.Tensor): state[key] = value.to(device_id, non_blocking=True) @contextmanager def meta_device_init(): """ Create model parameters with meta device. Note buffers in model will still be initialized in default device (e.g., CPU), since the buffers can be non-persistent and filled with expected values that can NOT be captured in meta device. """ device = torch.device("meta") old_register_parameter = nn.Module.register_parameter registered = set() def register_empty_parameter(module, name, param): old_register_parameter(module, name, param) # we will skip register shared parameters as it # is already registered previously if param is not None and param not in registered: param_cls = type(module._parameters[name]) kwargs = module._parameters[name].__dict__ kwargs["requires_grad"] = param.requires_grad module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) registered.add(module._parameters[name]) try: nn.Module.register_parameter = register_empty_parameter yield finally: registered.clear() nn.Module.register_parameter = old_register_parameter def parallel_load_safetensors(filepath): """ Parallel load safetensors from huggingface checkpoint Huggingface checkpoint contains: - config.json: a json file for model configuration - model.safetensor.index.json: a json file for safetensors (parameters & buffers) index - model-000x-of-ooxx.safetensors: a binary file for safetensors (parameters & buffers) chunks Or (when model is small), - model.safetensors: a binary file for all parameters and buffers Each rank will own a part of model chunks and load them directly into GPU memory. """ from safetensors.torch import load_file safetensors2param = {} index_file = os.path.join(filepath, "model.safetensors.index.json") if os.path.exists(index_file): index = json.load(open(index_file, "rb")) for param_name, filename in index["weight_map"].items(): safetensors2param.setdefault(filename, []).append(param_name) else: # in this case, the model is small and we can load it all at once param_file = os.path.join(filepath, "model.safetensors") assert os.path.exists(param_file), f"Cannot find {param_file}" states = load_file(param_file) for param_name in states: safetensors2param.setdefault("model.safetensors", []).append(param_name) del states total_files = len(safetensors2param) ckpt_chunks = sorted(safetensors2param.keys()) world_size = dist.get_world_size() size = int(math.ceil(total_files / world_size)) ckpt_chunks = [ckpt_chunks[rank * size : rank * size + size] for rank in range(world_size)] shard_states = {} device = get_device_id() for rank, files in enumerate(ckpt_chunks): if rank == dist.get_rank(): for file in files: file = os.path.join(filepath, file) states = load_file(file, device=device) # print(f"rank {rank} loading {file}...") shard_states.update(states) else: for file in files: for param_name in safetensors2param[file]: shard_states[param_name] = rank return shard_states def parallel_init_module_fn(module: torch.nn.Module, shard_states: dict[str, torch.nn.Parameter]): """ Generate a function to initialize sub-modules in the `module` with `shard_states` from huggingface checkpoint. Args: module (torch.nn.Module): the global module to be initialized shard_states (Dict[str, torch.nn.Parameter]): the shard states from huggingface checkpoint Returns: init_fn (Callable): a function to initialize sub-modules in the `module` with `shard_states` """ state2fqn = {} for name, state in itertools.chain( module.named_parameters(remove_duplicate=False), module.named_buffers(remove_duplicate=False) ): state2fqn.setdefault(state, []).append(name) # remove standalone parameters and buffers shared = {s for s, names in state2fqn.items() if len(names) > 1} materialized_states = {} @torch.no_grad() def create_and_sync_state(param_name, state, is_param): assert param_name in shard_states, f"{param_name} not loaded" device = get_device_id() if is_param: param = torch.nn.Parameter(torch.empty_like(state.data, device=device), requires_grad=state.requires_grad) else: # buffer param = torch.empty_like(state.data, device=device) loaded = shard_states[param_name] if isinstance(loaded, torch.nn.Parameter | torch.Tensor): # NOTE: loaded.dtype can be different with param.dtype param.data.copy_(loaded.data) dist.broadcast(param.data, src=dist.get_rank()) else: assert isinstance(loaded, int) # the rank that holds the state dist.broadcast(param.data, src=loaded) shard_states.pop(param_name) del loaded return param def init_fn(sub_mod: torch.nn.Module, recurse: bool = True): param_and_buffers = tuple(sub_mod.named_parameters(recurse=False)) + tuple(sub_mod.named_buffers(recurse=False)) # param_and_buffers = sorted(sub_mod.named_parameters(recurse=False), key=lambda x: x[0]) for name, state in param_and_buffers: if not state.is_meta: continue is_param = name in sub_mod._parameters fqn = state2fqn[state].pop(0) # non-persistent buffers will not be saved in state dict, we can safely skip it if (not is_param) and fqn not in shard_states: if state.is_meta: raise RuntimeError( f"find a non-persistent buffer ({fqn}) initiated with device meta. Such buffer is not saved " f"in checkpoint and user should guarantee to init in CPU / GPU device." ) continue # for shared parameter, we get it from the first time it is created if state in shared: if state not in materialized_states: materialized_states[state] = create_and_sync_state(fqn, state, is_param) else: if fqn in shard_states: shard_states.pop(fqn) materialize_state = materialized_states[state] # for not shared parameter, we create it directly else: materialize_state = create_and_sync_state(fqn, state, is_param) if is_param: sub_mod._parameters[name] = materialize_state else: sub_mod._buffers[name] = materialize_state if recurse: for module in sub_mod.children(): init_fn(module, recurse=True) # for debug # if len(shard_states) == 0: print("clear") return sub_mod return init_fn def fsdp_version(model): if isinstance(model, FSDP): return 1 elif isinstance(model, FSDPModule): return 2 else: return 0 def get_fsdp_state_ctx(model, state_type, state_cfg, optim_cfg): if fsdp_version(model) == 1: return FSDP.state_dict_type(model, state_type, state_cfg, optim_cfg) else: return nullcontext() def get_fsdp_full_state_dict(model: torch.nn.Module, offload_to_cpu: bool = True, rank0_only: bool = True): """ Get the full state dict from an FSDP model. Args: model (torch.nn.Module): The FSDP model to get state dict from offload_to_cpu (bool, optional): Whether to offload the state dict to CPU. Defaults to True. rank0_only (bool, optional): Whether to only get state dict on rank 0. Defaults to True. Returns: dict: The full state dict of the model Raises: NotImplementedError: If the FSDP version is unknown """ if fsdp_version(model) == 1: from torch.distributed.fsdp import FullStateDictConfig, StateDictType state_dict_config = FullStateDictConfig(offload_to_cpu=offload_to_cpu, rank0_only=rank0_only) with get_fsdp_state_ctx( model, state_type=StateDictType.FULL_STATE_DICT, state_cfg=state_dict_config, optim_cfg=None ): state_dict = model.state_dict() return state_dict elif fsdp_version(model) == 2 or fsdp_version(model) == 0: from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict state_dict_config = StateDictOptions( full_state_dict=True, cpu_offload=offload_to_cpu, broadcast_from_rank0=not rank0_only ) state_dict = get_model_state_dict(model, options=state_dict_config) return state_dict else: raise NotImplementedError(f"Unknown FSDP version {fsdp_version}") def fsdp2_load_full_state_dict(model: torch.nn.Module, full_state: dict, device_mesh=None, cpu_offload=None): """ Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the parameters from rank 0 to all other ranks. This function modifies the model in-place. Args: model (`torch.nn.Module`): The model to load the state dict into full_state (`dict`): The full state dict to load, can only be on rank 0 """ if version.parse(torch.__version__) >= version.parse("2.7.0"): from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict else: # official torch 2.6.0 set_model_state_dict API leads to OOM # use torch 2.7.0 copy from verl/third_party/torch/distributed/checkpoint from verl.third_party.torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict # To broadcast, it needs to be instantiated in the GPU. if dist.get_rank() == 0: model = model.to(device=get_device_id(), non_blocking=True) else: model = model.to_empty(device=get_device_id()) cpu_offload = cpu_offload is not None options = StateDictOptions(full_state_dict=True, cpu_offload=cpu_offload, broadcast_from_rank0=True) set_model_state_dict(model, full_state, options=options) # rotary_emb is not in state_dict, so we need to broadcast it manually for name, buf in model.named_buffers(): dist.broadcast(buf, src=0) if cpu_offload: model.to("cpu", non_blocking=True) for buf in model.buffers(): buf.data = buf.data.to(get_device_id()) @contextmanager def maybe_patch_fsdp_module(model): if fully_shard_module is None: yield return orig_fsdp_module = fully_shard_module.FSDPModule class FSDPModuleABC(ABC, orig_fsdp_module): pass try: if isinstance(model, ABC): fully_shard_module.FSDPModule = FSDPModuleABC yield finally: fully_shard_module.FSDPModule = orig_fsdp_module def _select_fsdp2_wrap_targets(model, fsdp_transformer_layer_cls_to_wrap): """Select modules to wrap individually with fully_shard in FSDP2. Matches transformer layers by class name, and embed_tokens/lm_head by name (with isinstance fallback). Name-based matching is needed because peft wraps embed_tokens in ModulesToSaveWrapper, breaking isinstance(module, nn.Embedding). When tie_word_embeddings is True, embed_tokens and lm_head share weights and must not be wrapped separately. """ _tie = getattr(model.config, "tie_word_embeddings", False) _wrap_by_name = set() if _tie else {"embed_tokens", "lm_head"} modules = [] for name, module in model.named_modules(): leaf_name = name.rsplit(".", 1)[-1] if "." in name else name if ( module.__class__.__name__ in fsdp_transformer_layer_cls_to_wrap or (isinstance(module, nn.Embedding) and not _tie) or (leaf_name in _wrap_by_name and hasattr(module, "weight")) ): modules.append(module) return modules def apply_fsdp2(model, fsdp_kwargs, config): """model: AutoModelForCausalLM""" assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None) fsdp_transformer_layer_cls_to_wrap = config.get("wrap_policy", {}).get( "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap ) if isinstance(fsdp_transformer_layer_cls_to_wrap, str): fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap] assert len(fsdp_transformer_layer_cls_to_wrap) > 0 and fsdp_transformer_layer_cls_to_wrap[0] is not None modules = _select_fsdp2_wrap_targets(model, fsdp_transformer_layer_cls_to_wrap) for idx, module in enumerate(modules): # if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: # print(f"wrap module {module.__class__.__name__}") with maybe_patch_fsdp_module(module): fully_shard(module, **fsdp_kwargs) # if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: # print(f"wrap module {model.__class__.__name__}") with maybe_patch_fsdp_module(model): fully_shard(model, **fsdp_kwargs) # fsdp2 will not reshard_after_forward for root module def get_shard_placement_fn(fsdp_size): """Choose the dimension that can divide fsdp_size to avoid padding""" def shard_placement_fn(param): shape = list(param.shape) for i in range(len(shape)): if shape[i] % fsdp_size == 0: return Shard(i) return Shard(0) return shard_placement_fn def fsdp2_clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinite=False, foreach=None): """torch.nn.utils.clip_grad_norm_ cann't run on cpu parameter DTensor""" from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm if isinstance(parameters, torch.Tensor): parameters = [parameters] else: # prevent generators from being exhausted parameters = list(parameters) grads = [p.grad for p in parameters if p.grad is not None] total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach) total_norm = total_norm.to(get_device_id(), non_blocking=True) _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach) return total_norm def layered_summon_lora_params(fsdp_module) -> OrderedDict: from peft.utils.save_and_load import get_peft_model_state_dict def __prefix_submodules(module, prefix): for name, submodule in module.named_modules(): if name.startswith(prefix) and "." not in name[len(prefix) :]: yield name, submodule lora_params = OrderedDict() prefix_list = [ # fsdp "_fsdp_wrapped_module.base_model.model.", "_fsdp_wrapped_module.base_model.model.model.", "_fsdp_wrapped_module.base_model.model.model.layers.", "_fsdp_wrapped_module.base_model.model.model.language_model.layers.", # fsdp2 "base_model.model.", "base_model.model.model.", "base_model.model.model.layers.", "base_model.model.model.language_model.layers.", ] peft_model = getattr(fsdp_module, "_fsdp_wrapped_module", fsdp_module) for prefix in prefix_list: for name, submodule in __prefix_submodules(fsdp_module, prefix): prefix = name.replace("_fsdp_wrapped_module.base_model.model.", "base_model.model.") if name.endswith(".model") or name.endswith(".layers"): continue if fsdp_version(submodule) > 0: with FSDP.summon_full_params(submodule, writeback=False): sub_lora_params = get_peft_model_state_dict(peft_model, state_dict=submodule.state_dict()) sub_lora_params = { f"{prefix}.{name}": param.full_tensor().detach().cpu() if hasattr(param, "full_tensor") else param.detach().cpu() for name, param in sub_lora_params.items() } lora_params.update(sub_lora_params) submodule._is_root = False get_torch_device().empty_cache() return lora_params def collect_lora_params(module: FSDP, layered_summon: bool, base_sync_done: bool) -> OrderedDict: """ collect lora params or full params if base model is not ready in vllm work with if isinstance(self.module._fsdp_wrapped_module, PeftModel) """ from peft.utils.save_and_load import get_peft_model_state_dict lora_params = OrderedDict() peft_model = getattr(module, "_fsdp_wrapped_module", module) if fsdp_version(module) > 0: if layered_summon: if not base_sync_done: raise ValueError( "To use layered_summon, you must make sure base-model is preloaded in vllm, e.g. let " "rollout.load_format=safetensors" ) lora_params = layered_summon_lora_params(module) else: with FSDP.summon_full_params(module, writeback=False): if base_sync_done: lora_params = get_peft_model_state_dict(peft_model) lora_params = { name: param.full_tensor().detach().cpu() if hasattr(param, "full_tensor") else param.detach().cpu() for name, param in lora_params.items() } else: model = peft_model.base_model.model orig_dev = "cpu" if "cpu" in str(next(model.parameters()).device) else get_device_name() model = model.to("cpu") for name, param in model.state_dict().items(): if any(x in name for x in ["_flat_param", "lora_"]): continue name = name.replace("_fsdp_wrapped_module.", "").replace(".base_layer", "") lora_params[name] = ( param.full_tensor().detach().cpu() if hasattr(param, "full_tensor") else param.detach().cpu() ) model = model.to(orig_dev) get_torch_device().empty_cache() else: if base_sync_done: lora_params = get_peft_model_state_dict(peft_model) else: model = peft_model.base_model.model orig_dev = "cpu" if "cpu" in str(next(model.parameters()).device) else get_device_name() model = model.to("cpu") for name, param in model.state_dict().items(): if any(x in name for x in ["_flat_param", "lora_"]): continue name = name.replace("_fsdp_wrapped_module.", "").replace(".base_layer", "") lora_params[name] = param.detach().cpu() model = model.to(orig_dev) return lora_params def replace_lora_wrapper(k, peft_config): """Replace LoRA parameter keys with base layer equivalents. Transforms LoRA parameter names to their corresponding base layer names for proper weight loading in vLLM when base model sync is not done. Args: k (str): Original parameter key name. Returns: str: Transformed parameter key for base layer. """ stacked_params = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] if k.endswith(".weight"): module_k = k[: -len(".weight")] if check_exclude_modules(peft_config, module_k): return k elif any([module_k.endswith(s) for s in stacked_params]) or check_target_modules(peft_config, module_k): return f"{module_k}.base_layer.weight" if k.endswith(".bias"): module_k = k[: -len(".bias")] if check_exclude_modules(peft_config, module_k): return k elif any([module_k.endswith(s) for s in stacked_params]) or check_target_modules(peft_config, module_k): return f"{module_k}.base_layer.bias" return k def set_reshard_after_forward(module: FSDPModule, reshard_after_forward: bool, recurse: bool = True) -> None: """ Sets if the module should reshard parameters after forward. This can be used to change the ``reshard_after_forward`` FSDP arg at runtime. For example, this can be used to set the FSDP root module's value to ``True`` (since it is otherwise specially set to ``False``), or it can set an FSDP module's value to ``False`` for running evals and set back to ``True`` for training. Args: reshard_after_forward (bool): Whether to reshard parameters after forward. recurse (bool): Whether to set for all FSDP submodules or just the passed-in module. --- Copied from https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fully_shard.py to address the absence of the set_reshard_after_forward function in torch versions earlier than 2.8.0. """ if not isinstance(reshard_after_forward, bool): raise ValueError(f"reshard_after_forward should be a bool, got {type(reshard_after_forward)}") self_module = cast(nn.Module, module) modules = list(self_module.modules()) if recurse else [self_module] for module in modules: if isinstance(module, FSDPModule): state = module._get_fsdp_state() state._auto_reshard_after_forward = False if fsdp_param_group := state._fsdp_param_group: fsdp_param_group.post_forward_mesh_info = _get_post_forward_mesh_info( reshard_after_forward, fsdp_param_group.mesh_info ) def normalize_peft_param_name(params: dict) -> dict: """ Converts peft model parameter name to base parameter name For example, base_model.model.model.embed_tokens.weight -> model.embed_tokens.weight base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight -> model.layers.0.self_attn.q_proj.weight and remove params such as base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight, base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight """ def _normalize_peft_name(name: str) -> str: return name.replace("base_model.model.", "").replace("base_model.", "").replace(".base_layer", "") def _is_lora_key(name: str) -> bool: # catch typical PEFT keys return ("lora_" in name) or (".adapter_" in name) params = [(_normalize_peft_name(k), v) for k, v in params.items()] # strip any residual LoRA tensors params = {k: v for k, v in params if not _is_lora_key(k)} return params def _merge_or_unmerge_lora_(module, merge: bool): """Merge or unmerge LoRA adapters in a module. Args: module: The module containing LoRA layers merge: If True, merge LoRA into base model; if False, unmerge LoRA """ from peft.tuners.lora import LoraLayer with torch.no_grad(): for m in module.modules(): if isinstance(m, LoraLayer): is_merged = getattr(m, "merged", False) if merge and not is_merged: m.merge() elif (not merge) and is_merged: m.unmerge() # merged_adapters def _clean_merged_lora_(module): """Cleans the merged lora adapters""" from peft.tuners.lora import LoraLayer with torch.no_grad(): for m in module.modules(): if isinstance(m, LoraLayer): merged_adapters = getattr(m, "merged_adapters", False) if merged_adapters: m.merged_adapters = [] def fsdp_merge_unmerge(module: nn.Module, do_merge: bool): """Merge or unmerge LoRA adapters in FSDP module. For FSDP (v1), it gathers all model parameters to each device, which may cause OOM. For FSDP2, it gathers model parameters layer-by-layer to reduce memory footprint. Args: module: The FSDP module to merge/unmerge LoRA adapters do_merge: If True, merge LoRA into base model; if False, unmerge LoRA """ version = fsdp_version(module) assert version in [1, 2], f"fsdp_merge_unmerge requires FSDP module, got version {version}" if version == 1: # Unshard → merge → Reshard with FSDP.summon_full_params(module, writeback=True, with_grads=False): _merge_or_unmerge_lora_(module, merge=do_merge) else: # FSDP2: Unshard → merge → Reshard layer-by-layer for name, submodule in module.named_modules(): if isinstance(submodule, FSDPModule) and name != "": # skip root model with FSDP.summon_full_params(submodule, writeback=True, with_grads=False): _merge_or_unmerge_lora_(submodule, merge=do_merge) def backup_base_model_weights(module): """Backup base model weights to CPU with LoRA temporarily disabled. This function temporarily disables LoRA adapters, backs up the clean base model weights to CPU, then re-enables the adapters. Args: module: The PEFT model with LoRA adapters Returns: dict: Dictionary mapping parameter name to CPU tensor backup of base model weights """ from peft import PeftModel backup = {} with torch.no_grad(): # Check if module is a PEFT model if isinstance(module, PeftModel): # Temporarily disable adapters to get clean base model weights with module.disable_adapter(): # Backup base model weights (excluding lora parameters) for name, param in module.named_parameters(): if "lora" not in name.lower(): backup[name] = param.data.clone().cpu() else: # For non-PEFT models, just backup all parameters for name, param in module.named_parameters(): backup[name] = param.data.clone().cpu() return backup def restore_base_model_weights(module, backup): """Restore base model weights from CPU backup. This function restores the base model weights from the CPU backup, effectively undoing any LoRA merge operations. Args: module: The PEFT model with LoRA adapters backup: Dictionary mapping parameter name to CPU tensor backup of base model weights """ with torch.no_grad(): for name, param in module.named_parameters(): if name in backup: param.data.copy_(backup[name].to(param.device)) @contextmanager def merged_lora_context(actor, backup_adapters=False): """Context manager to temporarily merge LoRA adapters. This context manager merges LoRA adapters into the base model weights, performs operations (like syncing weights to vLLM), then restores the base model weights from backup. Args: actor: The actor module with LoRA adapters to merge backup_adapters: If True, backup base model weights (with LoRA disabled) before merging and restore them after. This is more numerically stable than unmerging. Yields: None """ base_weights_backup = None if backup_adapters: # Backup base model weights with LoRA temporarily disabled base_weights_backup = backup_base_model_weights(actor) # Merge LoRA adapters into base model fsdp_merge_unmerge(actor, do_merge=True) try: # Do work while merged (sync_to_vllm / generate / etc.) yield finally: if backup_adapters and base_weights_backup is not None: # Restore base model weights from CPU backup (effectively undoing the merge) restore_base_model_weights(actor, base_weights_backup) _clean_merged_lora_(actor) else: # Fall back to unmerge if no backup was made fsdp_merge_unmerge(actor, do_merge=False) def fsdp2_sharded_save_to_cpu( model: torch.nn.Module, ) -> tuple[dict[str, tuple[torch.Tensor, DTensorSpec]], DTensorSpec]: """ Sharded Save: Each process only saves the local DTensor shard from its own GPU to CPU memory. Args: model: FSDP2-wrapped model whose parameters are of DTensor type. Returns: cpu_sharded_state: Dictionary of CPU shards for the current process. Key = parameter name, Value = (CPU shard tensor, original DTensorSpec) global_spec: DTensorSpec of the first parameter (used to verify global rules during loading) """ cpu_sharded_state = {} global_spec = None # Record global sharding rules (all parameters follow the same spec) for param_name, param in model.named_parameters(): # Only process sharded parameters of DTensor type (core parameters of FSDP2) if not isinstance(param, DTensor): # Save non-sharded parameters (e.g., running_mean of BatchNorm) as local data cpu_tensor = param.detach().cpu() cpu_sharded_state[param_name] = (cpu_tensor, None) continue # Record global sharding rules (take spec of the first DTensor to ensure consistency) if global_spec is None: global_spec = param._spec assert hasattr(global_spec, "device_mesh"), "DTensorSpec must contain 'device_mesh' attribute" assert hasattr(global_spec, "placements"), "DTensorSpec must contain 'placements' attribute" # 1. Extract local shard data from the current GPU (_local_tensor) local_gpu_tensor = param._local_tensor # Local shard attribute defined in your DTensor class # 2. Move to CPU memory and detach from computation graph local_cpu_tensor = local_gpu_tensor.detach().cpu() # 3. Save CPU shard + original DTensorSpec (ensure sharding rules remain unchanged) cpu_sharded_state[param_name] = (local_cpu_tensor, param._spec) assert global_spec is not None, "No DTensor-type parameters found in the model. FSDP2 sharding may not be enabled." return cpu_sharded_state, global_spec def fsdp2_sharded_load_from_cpu( model: torch.nn.Module, cpu_sharded_state: dict[str, tuple[torch.Tensor, Optional[DTensorSpec]]], target_spec: DTensorSpec, ) -> None: """ Sharded Load: Each process only loads the CPU shard it is responsible for to the GPU, keeping sharding rules unchanged. Args: model: FSDP2 model to be restored (must have the same structure as when saved) cpu_sharded_state: Shard data read from CPU memory by the current process (from fsdp2_sharded_save_to_cpu) target_spec: Global DTensorSpec from saving (used to verify sharding rule consistency) """ # Verify device_mesh consistency (core: ensure loaded shards map to original GPUs) current_device_mesh = None for param in model.parameters(): if isinstance(param, DTensor): current_device_mesh = param._spec.device_mesh break assert current_device_mesh is not None, "DTensor parameters not initialized in the model to be loaded" assert current_device_mesh == target_spec.device_mesh, ( f"device_mesh mismatch during loading! Original: {target_spec.device_mesh}, Current: {current_device_mesh}" ) for param_name, param in model.named_parameters(): # Skip parameters not in the saved state (e.g., newly added parameters) if param_name not in cpu_sharded_state: continue # Extract CPU shard data and original Spec local_cpu_tensor, saved_spec = cpu_sharded_state[param_name] # Handle different parameter types: DTensor sharded parameters vs. regular parameters if isinstance(param, DTensor): # 1. Verify sharding rule consistency (placements must match original Spec) assert saved_spec is not None, f"DTensorSpec missing in saved state for parameter {param_name}" assert saved_spec.placements == target_spec.placements, ( f"Sharding strategy mismatch for parameter {param_name} (conflicts with global rules)!" ) # 2. Move CPU shard data to the current GPU (device of param._local_tensor) target_device = param._local_tensor.device local_gpu_tensor = local_cpu_tensor.to(target_device) # 3. Restore to DTensor's local shard (directly copy to _local_tensor, keep spec unchanged) param._local_tensor.copy_(local_gpu_tensor) else: # Regular parameters: load directly to original device target_device = param.device param.data.copy_(local_cpu_tensor.to(target_device)) # Process synchronization: ensure all processes complete loading before proceeding dist.barrier()