egrpo / fastvideo /utils /fsdp_util_qwenimage.py
studyOverflow's picture
Add files using upload-large-folder tool
b171568 verified
import os
import functools
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.fsdp import ShardingStrategy, BackwardPrefetch, MixedPrecision, CPUOffload
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointImpl,
apply_activation_checkpointing,
checkpoint_wrapper,
)
from fastvideo.models.qwenimage.transformer_qwenimage import QwenImageTransformerBlock
class FSDPConfig:
def __init__(
self,
sharding_strategy="FULL_SHARD",
backward_prefetch="BACKWARD_PRE",
cpu_offload=False,
num_replicate=1,
num_shard=8,
mixed_precision_dtype=torch.bfloat16,
use_device_mesh=False,
):
self.sharding_strategy = sharding_strategy
self.backward_prefetch = backward_prefetch
self.cpu_offload = cpu_offload
self.num_replicate = num_replicate
self.num_shard = num_shard
self.mixed_precision_dtype = mixed_precision_dtype
self.use_device_mesh = use_device_mesh
def fsdp_wrapper(model, fsdp_config, ignored_modules=None):
if ignored_modules is None:
ignored_modules = []
device_mesh = None
if fsdp_config.sharding_strategy == 'HYBRID_SHARD' and fsdp_config.use_device_mesh:
device_mesh = init_device_mesh(
"cuda",
mesh_shape=(fsdp_config.num_replicate, fsdp_config.num_shard),
mesh_dim_names=("replicate", "shard")
)
def get_transformer_layer_cls():
return {
QwenImageTransformerBlock,
}
fsdp_model = FSDP(
model,
auto_wrap_policy=functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls=get_transformer_layer_cls(),
),
ignored_modules=ignored_modules,
mixed_precision=MixedPrecision(
param_dtype=fsdp_config.mixed_precision_dtype,
reduce_dtype=fsdp_config.mixed_precision_dtype,
buffer_dtype=fsdp_config.mixed_precision_dtype,
),
device_id=dist.get_rank() % torch.cuda.device_count(),
sharding_strategy=ShardingStrategy[fsdp_config.sharding_strategy],
backward_prefetch=BackwardPrefetch[fsdp_config.backward_prefetch],
cpu_offload=CPUOffload(offload_params=fsdp_config.cpu_offload),
device_mesh=device_mesh,
use_orig_params=True,
)
return fsdp_model