File size: 4,518 Bytes
b171568 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
# ruff: noqa: E731
# This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0.
import functools
from functools import partial
import torch
from peft.utils.other import fsdp_auto_wrap_policy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper)
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformerBlock
from fastvideo.utils.load import get_no_split_modules
# Activation-checkpoint wrapper shared by all selectively-wrapped blocks.
non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)


def check_fn(submodule):
    """Return True if *submodule* is a Mochi transformer block (wrap predicate)."""
    return isinstance(submodule, MochiTransformerBlock)


def apply_fsdp_checkpointing(model, no_split_modules, p=1):
    # https://github.com/foundation-model-stack/fms-fsdp/blob/408c7516d69ea9b6bcd4c0f5efab26c0f64b3c2d/fms_fsdp/policies/ac_handler.py#L16
    """Apply selective activation checkpointing to ``model`` in place.

    Every submodule that is an instance of one of ``no_split_modules`` is a
    candidate; roughly a fraction ``p`` of those candidates, evenly spaced,
    is wrapped with a non-reentrant checkpoint wrapper.

    Args:
        model: module to modify in place (returns None).
        no_split_modules: tuple of block classes eligible for checkpointing.
        p: fraction of eligible blocks to checkpoint; may be numeric or a
            string such as "1/3" (fractions arrive as strings via argv).
    """
    print("--> applying fsdp activation checkpointing...")
    if isinstance(p, str):
        # Parse "1/3" etc. safely instead of eval()-ing arbitrary argv text.
        from fractions import Fraction
        p = float(Fraction(p))
    block_idx = 0
    cut_off = 1 / 2

    def selective_checkpointing(submodule):
        nonlocal block_idx
        nonlocal cut_off
        if isinstance(submodule, no_split_modules):
            block_idx += 1
            # Wrap this block when the running fraction block_idx * p crosses
            # the next half-integer threshold -> evenly spaced selection at
            # rate p over the sequence of eligible blocks.
            if block_idx * p >= cut_off:
                cut_off += 1
                return True
        return False

    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=non_reentrant_wrapper,
        check_fn=selective_checkpointing,
    )
def get_mixed_precision(master_weight_type="fp32"):
    """Build an FSDP ``MixedPrecision`` config with a uniform dtype.

    ``master_weight_type`` == "fp32" selects float32; any other value
    selects bfloat16. The same dtype is used for parameters, gradient
    reduction, and buffers.
    """
    dtype = torch.float32 if master_weight_type == "fp32" else torch.bfloat16
    return MixedPrecision(
        param_dtype=dtype,
        # Gradient communication precision.
        reduce_dtype=dtype,
        # Buffer precision.
        buffer_dtype=dtype,
        cast_forward_inputs=False,
    )
def get_dit_fsdp_kwargs(
    transformer,
    sharding_strategy,
    use_lora=False,
    cpu_offload=False,
    master_weight_type="fp32",
):
    """Assemble the FSDP constructor kwargs for the DiT transformer.

    Args:
        transformer: model whose no-split (transformer block) classes are
            discovered via ``get_no_split_modules``.
        sharding_strategy: one of "full", "hybrid_full", "none",
            "hybrid_zero2". A non-string value is passed through unchanged
            (assumed to already be a ``ShardingStrategy`` — TODO confirm
            against callers).
        use_lora: use peft's auto-wrap policy and LoRA-friendly settings.
        cpu_offload: offload parameters to CPU.
        master_weight_type: "fp32" (default) for float32 master weights,
            anything else for bfloat16.

    Returns:
        Tuple of (fsdp_kwargs, no_split_modules).

    Raises:
        ValueError: if ``sharding_strategy`` is an unrecognized string.
    """
    no_split_modules = get_no_split_modules(transformer)
    if use_lora:
        auto_wrap_policy = fsdp_auto_wrap_policy
    else:
        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls=no_split_modules,
        )

    # We use float32 for FSDP but autocast during training.
    mixed_precision = get_mixed_precision(master_weight_type)

    if sharding_strategy == "full":
        sharding_strategy = ShardingStrategy.FULL_SHARD
    elif sharding_strategy == "hybrid_full":
        sharding_strategy = ShardingStrategy.HYBRID_SHARD
    elif sharding_strategy == "none":
        sharding_strategy = ShardingStrategy.NO_SHARD
        # Without sharding there is nothing to gain from wrapping submodules.
        auto_wrap_policy = None
    elif sharding_strategy == "hybrid_zero2":
        sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2
    elif isinstance(sharding_strategy, str):
        # Previously an unknown string fell through silently and only failed
        # deep inside FSDP construction; fail fast with a clear message.
        raise ValueError(f"Unknown sharding strategy: {sharding_strategy!r}")

    device_id = torch.cuda.current_device()
    cpu_offload = (torch.distributed.fsdp.CPUOffload(
        offload_params=True) if cpu_offload else None)
    fsdp_kwargs = {
        "auto_wrap_policy": auto_wrap_policy,
        "mixed_precision": mixed_precision,
        "sharding_strategy": sharding_strategy,
        "device_id": device_id,
        "limit_all_gathers": True,
        "cpu_offload": cpu_offload,
    }

    # Add LoRA-specific settings when LoRA is enabled.
    if use_lora:
        fsdp_kwargs.update({
            "use_orig_params": False,  # Required for LoRA memory savings
            "sync_module_states": True,
        })

    return fsdp_kwargs, no_split_modules
def get_discriminator_fsdp_kwargs(master_weight_type="fp32"):
    """Build FSDP kwargs for the discriminator: no sharding, no auto-wrap.

    ``master_weight_type`` selects the mixed-precision dtype exactly as in
    ``get_mixed_precision``.
    """
    # Reuse the shared mixed-precision settings.
    return {
        "auto_wrap_policy": None,
        "mixed_precision": get_mixed_precision(master_weight_type),
        "sharding_strategy": ShardingStrategy.NO_SHARD,
        "device_id": torch.cuda.current_device(),
        "limit_all_gathers": True,
    }
|