# ruff: noqa: E731
import functools
from fractions import Fraction
from functools import partial

import torch
from peft.utils.other import fsdp_auto_wrap_policy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from .load import get_no_split_modules

non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)


def apply_fsdp_checkpointing(model, no_split_modules, p=1):
    """Apply selective activation checkpointing to `model` in place.

    `p` is the fraction of transformer blocks to checkpoint; it may arrive
    as a string such as "1/3" when passed on the command line.
    Returns None, as the model is updated directly.
    Adapted from
    https://github.com/foundation-model-stack/fms-fsdp/blob/408c7516d69ea9b6bcd4c0f5efab26c0f64b3c2d/fms_fsdp/policies/ac_handler.py#L16
    """
    print("--> applying fsdp activation checkpointing...")
    block_idx = 0
    cut_off = 1 / 2
    # Fractions such as "1/3" arrive as strings from argv; Fraction parses
    # them safely, without resorting to eval().
    p = float(Fraction(p)) if isinstance(p, str) else p

    def selective_checkpointing(submodule):
        nonlocal block_idx
        nonlocal cut_off

        if isinstance(submodule, no_split_modules):
            block_idx += 1
            if block_idx * p >= cut_off:
                cut_off += 1
                return True
        return False

    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=non_reentrant_wrapper,
        check_fn=selective_checkpointing,
    )
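

# Illustration (not part of the module): with p = 1/2 the predicate above
# fires on every other transformer block. Tracing four blocks:
#   block 1: 1 * 0.5 = 0.5 >= 0.5 -> checkpoint, cut_off becomes 1.5
#   block 2: 2 * 0.5 = 1.0 <  1.5 -> skip
#   block 3: 3 * 0.5 = 1.5 >= 1.5 -> checkpoint, cut_off becomes 2.5
#   block 4: 4 * 0.5 = 2.0 <  2.5 -> skip
# so roughly a fraction p of the blocks ends up checkpointed.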


def get_mixed_precision(master_weight_type="fp32"):
    weight_type = torch.float32 if master_weight_type == "fp32" else torch.bfloat16
    mixed_precision = MixedPrecision(
        param_dtype=weight_type,
        # Gradient communication precision.
        reduce_dtype=weight_type,
        # Buffer precision.
        buffer_dtype=weight_type,
        cast_forward_inputs=False,
    )
    return mixed_precision
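

# For example (illustrative only): any value other than "fp32" selects
# torch.bfloat16 for parameters, gradient reduction, and buffers alike:
#
#     mp = get_mixed_precision("bf16")
#     assert mp.param_dtype == torch.bfloat16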


def get_dit_fsdp_kwargs(
    transformer,
    sharding_strategy,
    use_lora=False,
    cpu_offload=False,
    master_weight_type="fp32",
):
    no_split_modules = get_no_split_modules(transformer)
    if use_lora:
        auto_wrap_policy = fsdp_auto_wrap_policy
    else:
        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls=no_split_modules,
        )

    # We keep float32 master weights for FSDP and rely on autocast during training.
    mixed_precision = get_mixed_precision(master_weight_type)

    if sharding_strategy == "full":
        sharding_strategy = ShardingStrategy.FULL_SHARD
    elif sharding_strategy == "hybrid_full":
        sharding_strategy = ShardingStrategy.HYBRID_SHARD
    elif sharding_strategy == "none":
        # NOTE: with NO_SHARD nothing is sharded, so no wrapping policy is needed.
        sharding_strategy = ShardingStrategy.NO_SHARD
        auto_wrap_policy = None
    elif sharding_strategy == "hybrid_zero2":
        sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2
    elif sharding_strategy == "shard_grad_op":
        sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
    else:
        raise ValueError(f"Unknown sharding strategy: {sharding_strategy}")

    device_id = torch.cuda.current_device()
    cpu_offload = (
        torch.distributed.fsdp.CPUOffload(offload_params=True) if cpu_offload else None
    )
    fsdp_kwargs = {
        "auto_wrap_policy": auto_wrap_policy,
        "mixed_precision": mixed_precision,
        "sharding_strategy": sharding_strategy,
        "device_id": device_id,
        "limit_all_gathers": True,
        "cpu_offload": cpu_offload,
    }

    # Add LoRA-specific settings when LoRA is enabled.
    if len(no_split_modules) != 0 and use_lora:
        fsdp_kwargs.update(
            {
                "use_orig_params": False,  # Required for LoRA memory savings.
                "sync_module_states": True,
            }
        )
    elif len(no_split_modules) == 0 and use_lora:
        fsdp_kwargs.update({"use_orig_params": True})

    return fsdp_kwargs, no_split_modules
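

# Example usage (a sketch, not part of the module; `transformer` stands in
# for any model whose block classes `get_no_split_modules` can report):
#
#     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#
#     fsdp_kwargs, no_split_modules = get_dit_fsdp_kwargs(
#         transformer, sharding_strategy="full"
#     )
#     transformer = FSDP(transformer, **fsdp_kwargs)
#     apply_fsdp_checkpointing(transformer, no_split_modules, p="1/3")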


def get_discriminator_fsdp_kwargs(master_weight_type="fp32"):
    auto_wrap_policy = None
    # Use existing mixed precision settings.
    mixed_precision = get_mixed_precision(master_weight_type)
    sharding_strategy = ShardingStrategy.NO_SHARD
    device_id = torch.cuda.current_device()
    fsdp_kwargs = {
        "auto_wrap_policy": auto_wrap_policy,
        "mixed_precision": mixed_precision,
        "sharding_strategy": sharding_strategy,
        "device_id": device_id,
        "limit_all_gathers": True,
    }
    return fsdp_kwargs


def get_vae_fsdp_kwargs(master_weight_type="fp32", cpu_offload=False):
    auto_wrap_policy = None
    # Use existing mixed precision settings.
    mixed_precision = get_mixed_precision(master_weight_type)
    # FULL_SHARD rather than SHARD_GRAD_OP; NO_SHARD is a commented-out
    # fallback option.
    sharding_strategy = ShardingStrategy.FULL_SHARD
    # sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
    # sharding_strategy = ShardingStrategy.NO_SHARD
    device_id = torch.cuda.current_device()
    cpu_offload = (
        torch.distributed.fsdp.CPUOffload(offload_params=True) if cpu_offload else None
    )
    fsdp_kwargs = {
        "auto_wrap_policy": auto_wrap_policy,
        "mixed_precision": mixed_precision,
        "sharding_strategy": sharding_strategy,
        "device_id": device_id,
        "limit_all_gathers": True,
        "cpu_offload": cpu_offload,
        "use_orig_params": True,  # Keep the original parameter structure.
        # "backward_prefetch": BackwardPrefetch.BACKWARD_PRE,
    }
    return fsdp_kwargs
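

# Example usage (a sketch, not part of the module; `vae` stands in for any
# VAE module to be wrapped):
#
#     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#
#     vae = FSDP(vae, **get_vae_fsdp_kwargs(cpu_offload=True))
#
# With use_orig_params=True, named_parameters() on the wrapped module still
# reports the original (unflattened) parameters, which keeps optimizer param
# groups straightforward to build.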