# Camellia997's picture
# Upload folder using huggingface_hub
# e14f899 verified
# ruff: noqa: E731
import functools
from functools import partial
import torch
from peft.utils.other import fsdp_auto_wrap_policy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointImpl,
apply_activation_checkpointing,
checkpoint_wrapper,
)
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from .load import get_no_split_modules
from torch.distributed.fsdp import BackwardPrefetch
# Checkpoint wrapper factory used by apply_fsdp_checkpointing below.
# NO_REENTRANT is the recommended (non-reentrant autograd) checkpoint
# implementation in modern PyTorch.
non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
def apply_fsdp_checkpointing(model, no_split_modules, p=1):
    # https://github.com/foundation-model-stack/fms-fsdp/blob/408c7516d69ea9b6bcd4c0f5efab26c0f64b3c2d/fms_fsdp/policies/ac_handler.py#L16
    """Apply selective activation checkpointing to ``model`` in place.

    Wraps roughly a fraction ``p`` of the transformer blocks (instances of
    ``no_split_modules``) with a non-reentrant checkpoint wrapper, spreading
    the checkpointed blocks evenly across the model.

    Args:
        model: model to modify; updated directly, returns None.
        no_split_modules: tuple of module classes that count as blocks.
        p: fraction of blocks to checkpoint. May arrive as a string such as
           "1/3" (fractions passed on argv are strings); parsed safely below.
    """
    print("--> applying fsdp activation checkpointing...")
    block_idx = 0
    cut_off = 1 / 2
    if isinstance(p, str):
        # Parse "1/3"-style argv strings with fractions.Fraction instead of
        # eval(), which would execute arbitrary code from the command line.
        from fractions import Fraction

        p = float(Fraction(p))

    def selective_checkpointing(submodule):
        nonlocal block_idx
        nonlocal cut_off

        if isinstance(submodule, no_split_modules):
            block_idx += 1
            # Checkpoint every block whose cumulative share crosses the next
            # half-step threshold; selects ~p of the blocks, evenly spaced.
            if block_idx * p >= cut_off:
                cut_off += 1
                return True
        return False

    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=non_reentrant_wrapper,
        check_fn=selective_checkpointing,
    )
def get_mixed_precision(master_weight_type="fp32"):
    """Build an FSDP ``MixedPrecision`` config.

    Parameters, gradient-reduction, and buffers all use the same dtype:
    ``torch.float32`` when ``master_weight_type`` is "fp32", otherwise
    ``torch.bfloat16``. Forward inputs are not cast.
    """
    if master_weight_type == "fp32":
        dtype = torch.float32
    else:
        dtype = torch.bfloat16
    return MixedPrecision(
        param_dtype=dtype,
        reduce_dtype=dtype,  # gradient communication precision
        buffer_dtype=dtype,  # buffer precision
        cast_forward_inputs=False,
    )
def get_dit_fsdp_kwargs(
    transformer,
    sharding_strategy,
    use_lora=False,
    cpu_offload=False,
    master_weight_type="fp32",
):
    """Build the FSDP kwargs used to wrap the DiT transformer.

    Args:
        transformer: model whose no-split (transformer block) classes are
            discovered via ``get_no_split_modules``.
        sharding_strategy: one of "full", "hybrid_full", "none",
            "hybrid_zero2", "shard_grad_op".
        use_lora: use PEFT's LoRA-aware auto-wrap policy and LoRA-specific
            FSDP settings.
        cpu_offload: offload parameters to CPU when True.
        master_weight_type: "fp32" for float32 master weights, else bf16.

    Returns:
        Tuple of (fsdp_kwargs dict, no_split_modules).

    Raises:
        ValueError: if ``sharding_strategy`` is not a recognized name.
    """
    no_split_modules = get_no_split_modules(transformer)

    if use_lora:
        # PEFT's policy knows how to wrap LoRA-injected modules.
        auto_wrap_policy = fsdp_auto_wrap_policy
    else:
        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls=no_split_modules,
        )

    # We use float32 master weights for FSDP but autocast during training.
    mixed_precision = get_mixed_precision(master_weight_type)

    strategy_map = {
        "full": ShardingStrategy.FULL_SHARD,
        "hybrid_full": ShardingStrategy.HYBRID_SHARD,
        "none": ShardingStrategy.NO_SHARD,
        "hybrid_zero2": ShardingStrategy._HYBRID_SHARD_ZERO2,
        "shard_grad_op": ShardingStrategy.SHARD_GRAD_OP,
    }
    if sharding_strategy not in strategy_map:
        # Fail fast instead of passing an unrecognized string into FSDP.
        raise ValueError(f"Unknown sharding strategy: {sharding_strategy!r}")
    if sharding_strategy == "none":
        # NO_SHARD keeps the model whole, so no auto-wrapping is needed.
        auto_wrap_policy = None
    sharding_strategy = strategy_map[sharding_strategy]

    device_id = torch.cuda.current_device()
    cpu_offload = (
        torch.distributed.fsdp.CPUOffload(offload_params=True) if cpu_offload else None
    )
    fsdp_kwargs = {
        "auto_wrap_policy": auto_wrap_policy,
        "mixed_precision": mixed_precision,
        "sharding_strategy": sharding_strategy,
        "device_id": device_id,
        "limit_all_gathers": True,
        "cpu_offload": cpu_offload,
    }

    # Add LoRA-specific settings when LoRA is enabled.
    if use_lora:
        if len(no_split_modules) != 0:
            fsdp_kwargs.update(
                {
                    "use_orig_params": False,  # Required for LoRA memory savings
                    "sync_module_states": True,
                }
            )
        else:
            fsdp_kwargs.update({"use_orig_params": True})
    return fsdp_kwargs, no_split_modules
def get_discriminator_fsdp_kwargs(master_weight_type="fp32"):
    """FSDP kwargs for the discriminator: NO_SHARD with no auto-wrapping.

    Args:
        master_weight_type: "fp32" for float32 master weights, else bf16.

    Returns:
        dict of keyword arguments for FSDP construction.
    """
    # Reuse the shared mixed-precision settings; the discriminator is small
    # enough to stay unsharded on each rank.
    return {
        "auto_wrap_policy": None,
        "mixed_precision": get_mixed_precision(master_weight_type),
        "sharding_strategy": ShardingStrategy.NO_SHARD,
        "device_id": torch.cuda.current_device(),
        "limit_all_gathers": True,
    }
def get_vae_fsdp_kwargs(master_weight_type="fp32", cpu_offload=False):
    """FSDP kwargs for the VAE: FULL_SHARD with optional CPU offload.

    Args:
        master_weight_type: "fp32" for float32 master weights, else bf16.
        cpu_offload: when True, offload parameters to CPU via CPUOffload.

    Returns:
        dict of keyword arguments for FSDP construction.
    """
    # Reuse the shared mixed-precision settings.
    mixed_precision = get_mixed_precision(master_weight_type)
    # FULL_SHARD chosen over the earlier SHARD_GRAD_OP / NO_SHARD variants.
    sharding_strategy = ShardingStrategy.FULL_SHARD
    device_id = torch.cuda.current_device()
    offload = (
        torch.distributed.fsdp.CPUOffload(offload_params=True) if cpu_offload else None
    )
    return {
        "auto_wrap_policy": None,
        "mixed_precision": mixed_precision,
        "sharding_strategy": sharding_strategy,
        "device_id": device_id,
        # NOTE: this key was accidentally listed twice in the original dict.
        "limit_all_gathers": True,
        "cpu_offload": offload,
        "use_orig_params": True,  # keep the original parameter structure
        # "backward_prefetch": BackwardPrefetch.BACKWARD_PRE,
    }