# HY-Video-PRFL / scripts/prfl/train_prfl.py
# Camellia997's picture
# Upload folder using huggingface_hub
# e14f899 verified
import argparse
import json
import logging
import os
import time
import itertools
from copy import deepcopy
from collections import deque
from easydict import EasyDict
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torchvision import transforms
import torch.nn as nn
import numpy as np
import torch.amp as amp
from diffusers.optimization import get_scheduler
from einops import rearrange
from omegaconf import OmegaConf
from peft import LoraConfig, get_peft_model
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from diffusers_lite.constants import PRECISION_TO_TYPE
from diffusers_lite.datasets.image2video_dataset import Image2VideoTrainDataset
from diffusers_lite.schedulers.scheduling_flow_match_discrete import (
FlowMatchDiscreteScheduler,
)
from diffusers_lite.wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from diffusers_lite.wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from diffusers_lite.wan.modules.model import WanModel
from diffusers_lite.wan.modules.t5 import T5EncoderModel
from diffusers_lite.wan.modules.vae import WanVAE
from diffusers_lite.wan.modules.clip import CLIPModel
from diffusers_lite.utils.communication import (
broadcast,
sp_parallel_dataloader_wrapper_wanx,
all_gather,
)
from diffusers_lite.utils.data_utils import (
LengthGroupedSampler,
save_videos_grid,
crop_tensor,
BlockDistributedSampler,
VideoImageBatchIterator
)
from diffusers_lite.utils.fsdp_utils import (
apply_fsdp_checkpointing,
get_dit_fsdp_kwargs,
get_vae_fsdp_kwargs,
)
from diffusers_lite.utils.parallel_states import initialize_sequence_parallel_state,nccl_info,get_sequence_parallel_state
from diffusers_lite.utils.torch_utils import set_manual_seed, free_memory, set_logging, set_worker_seed_builder
from diffusers_lite.utils.diffusion_utils import (
batch2list,
list2batch,
vae_encode,
vae_decode,
image_encode,
prompt2states,
load_lora_state_dict,
transformer_zero_init,
prepare_video_condition_wanx,
stable_mse_loss,
)
from diffusers_lite.utils.model_utils import (
save_lora_checkpoint,
save_checkpoint,
load_state_dict,
print_parameters_information,
update_ema_model,
)
import random
# Resolve the bicubic resampling constant across torchvision versions.
try:
    from torchvision.transforms import InterpolationMode

    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    # torchvision without InterpolationMode: fall back to PIL's constant.
    # PIL is imported locally because the module-level `from PIL import Image`
    # only executes further down the file, after this fallback has already run.
    from PIL import Image

    BICUBIC = Image.BICUBIC
# Maps the `config.task` identifier to a Wan2.1 checkpoint name.
# `model_init` only uses the keys (to validate `config.task`).
# NOTE(review): "i2v-1.3b" points at the T2V-1.3B name; `model_init` builds the
# i2v-1.3b model from a modified config and loads base weights into it, so this
# looks intentional — confirm.
NAME_MAPPING = {
    "t2v-1.3b": "Wan2.1-T2V-1.3B",
    "t2v-14b": "Wan2.1-T2V-14B",
    "i2v-1.3b": "Wan2.1-T2V-1.3B",
    "i2v-14b-480p": "Wan2.1-I2V-14B-480P",
    "i2v-14b-720p": "Wan2.1-I2V-14B-720P",
    "flf2v-14b-720p": "Wan2.1-FLF2V-14B-720P",
}
from transformers import AutoProcessor, AutoModel
from PIL import Image
from diffusers_lite.utils.network import MLP, QueryAttention, forward_siamese, forward_mlp, train_model, save_model
import gc
import torch
def log_memory_usage(step_name, rank=None):
    """Print the CUDA memory statistics of the calling process.

    Nothing is printed when CUDA is unavailable.

    Args:
        step_name: Label placed in front of the statistics.
        rank: Optional rank; when given, the line is prefixed with
            ``[Rank <rank>] ``.
    """
    if not torch.cuda.is_available():
        return

    bytes_per_gb = 1024**3
    allocated_gb = torch.cuda.memory_allocated() / bytes_per_gb
    reserved_gb = torch.cuda.memory_reserved() / bytes_per_gb
    peak_gb = torch.cuda.max_memory_allocated() / bytes_per_gb

    prefix = "" if rank is None else f"[Rank {rank}] "
    print(
        f"{prefix}{step_name}: Allocated: {allocated_gb:.2f}GB, "
        f"Reserved: {reserved_gb:.2f}GB, Max: {peak_gb:.2f}GB"
    )
def basic_init(config):
    """Initialise the distributed process, seeding, and output directories.

    Reads ``LOCAL_RANK``, ``RANK`` and ``WORLD_SIZE`` from the environment,
    initialises the NCCL process group and the sequence-parallel state, seeds
    the process, and fills ``config.save`` with the derived output paths.
    Rank 0 creates the directories, dumps the config and creates the log file.

    Args:
        config: Training configuration; ``config.save`` is updated in place.

    Returns:
        Tuple ``(config, basic_kwargs)`` where ``basic_kwargs`` is an EasyDict
        with ``local_rank``, ``rank``, ``world_size``, ``device``, ``dtype``
        and ``log_path``.

    Raises:
        KeyError: If one of the required environment variables is missing.
    """
    # Init process groups
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl")
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    dtype = PRECISION_TO_TYPE[config.train.precision]
    initialize_sequence_parallel_state(config.dataset.sp_size)
    set_logging(local_rank)

    # Init seed. The seed is offset by the sequence-parallel group id, and the
    # log line reports exactly the value that was applied (it previously
    # reported `seed + rank`, which is a different number).
    seed = config.train.seed + nccl_info.group_id
    set_manual_seed(seed)
    logging.info(f"lanuch with seed {seed}")

    # Init repository creation
    config.save.ckpt_dir = os.path.join(
        config.save.output_dir, f"{config.train_id}/checkpoints"
    )
    config.save.log_dir = os.path.join(
        config.save.output_dir, f"{config.train_id}/logs"
    )
    config.save.sanity_check_dir = f"outputs/sanity_check/wanx/{config.train_id}"
    config.save.tensorboard_dir = os.path.join(
        config.save.output_dir, f"{config.train_id}/tensorboard"
    )
    log_path = os.path.join(config.save.log_dir, "log.txt")

    if rank == 0:
        os.makedirs(config.save.output_dir, exist_ok=True)
        os.makedirs(config.save.ckpt_dir, exist_ok=True)
        os.makedirs(config.save.log_dir, exist_ok=True)
        os.makedirs(config.save.tensorboard_dir, exist_ok=True)
        OmegaConf.save(config, os.path.join(config.save.log_dir, "train_config.yaml"))
        # Only write the header when the log file does not exist yet, so a
        # resumed run does not truncate it.
        if not os.path.exists(log_path):
            with open(log_path, "w") as f:
                f.write(f"Start logging {config.train_id}:\n")
        if config.train.sanity_check_interval > 0:
            os.makedirs(config.save.sanity_check_dir, exist_ok=True)
    logging.info(f"save ckpt directory {config.save.ckpt_dir}")

    if config.train.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        logging.info("enable TF32")

    basic_kwargs = EasyDict(
        {
            "local_rank": local_rank,
            "rank": rank,
            "world_size": world_size,
            "device": device,
            "dtype": dtype,
            "log_path": log_path,
        }
    )

    torch.cuda.set_per_process_memory_fraction(0.95, device=basic_kwargs.device)
    # NOTE(review): this only assigns an attribute on the torch.cuda module; it
    # does not look like a documented PyTorch setting — confirm it has an effect.
    torch.cuda.memory_pressure_threshold = 0.8
    # NOTE(review): these variables are set after the process group and the
    # CUDA device were initialised above; CUDA / allocator settings are
    # presumably read earlier than this, so they may be ignored here — confirm,
    # and consider exporting them from the launcher instead.
    os.environ["FSDP_FLATTEN_PARAMS"] = "1"
    os.environ["FSDP_SHARD_GRAD_PARAMS"] = "1"
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    return config, basic_kwargs
def model_init(config, basic_kwargs):
    """Build the trainable transformer, the reward (LRM) stack and wrap them in FSDP.

    Loads the generator transformer (from a resume path, an init path, or the
    base directory), builds a truncated copy of the base transformer as the
    latent reward model together with its query-attention pooling and MLP
    head, optionally attaches LoRA adapters and an EMA copy, and wraps the
    transformers with FSDP (plus activation checkpointing when enabled).

    Args:
        config: Training configuration (``task``, ``model``, ``lrm`` sections).
        basic_kwargs: EasyDict from ``basic_init``; ``device`` and ``rank`` are used.

    Returns:
        EasyDict with ``transformer``, ``ema_transformer`` (or ``None``),
        ``resume_step``, ``lrm_transformer``, ``query_attention`` and ``mlp``.

    Raises:
        AssertionError: If ``config.task`` is not a key of ``NAME_MAPPING``.
        ValueError: If a resume path does not end in ``-<integer step>``.
    """
    assert config.task in NAME_MAPPING.keys()
    base_dir = config.model.base_path

    if config.model.resume_transformer_path:
        logging.info(f"loading model tranformer from {config.model.resume_transformer_path}")
        transformer = WanModel.from_pretrained(config.model.resume_transformer_path)
        # The step is encoded as the suffix after the last "-"; a trailing
        # slash on the directory is tolerated.
        resume_step = int(
            config.model.resume_transformer_path.rstrip("/").split("-")[-1]
        )
    elif config.model.init_transformer_path:
        logging.info(f"loading model tranformer from {config.model.init_transformer_path}")
        transformer = WanModel.from_pretrained(config.model.init_transformer_path)
        resume_step = 0
    else:
        if config.task in [
            "t2v-1.3b",
            "t2v-14b",
            "i2v-14b-480p",
            "i2v-14b-720p",
            "flf2v-14b-720p",
        ]:
            logging.info(f"loading model tranformer from {base_dir}")
            transformer = WanModel.from_pretrained(base_dir)
        elif config.task in ["i2v-1.3b"]:
            # Build an i2v variant from the base config with a widened input
            # (in_dim=36), then load every base weight except the patch
            # embedding, which is removed from the state dict before loading.
            with open(os.path.join(base_dir, "config.json"), "r") as config_file:
                transformer_config = json.load(config_file)
            transformer_config["in_dim"] = 36
            transformer_config["model_type"] = "i2v"
            transformer = WanModel.from_config(transformer_config)
            transformer = transformer_zero_init(transformer)
            state_dict = load_state_dict(model_dir=base_dir)
            del state_dict["patch_embedding.bias"]
            del state_dict["patch_embedding.weight"]
            m, u = transformer.load_state_dict(state_dict, strict=False)
            logging.info(f"load lora from {base_dir}.")
            logging.info(f"miss {len(m)}; unexpect {len(u)}.")
        resume_step = 0

    # lrm transformer init
    lrm_transformer = WanModel.from_pretrained(config.model.base_path)
    frozen_modules = [
        'patch_embedding',
        'text_embedding',
        'time_embedding',
        'time_projection',
        'img_emb',
        # 'freqs',
    ]
    for module_name in frozen_modules:
        if hasattr(lrm_transformer, module_name):
            module = getattr(lrm_transformer, module_name)
            for param in module.parameters():
                param.requires_grad = False

    trainable_blocks = config.lrm.trainable_blocks
    if not hasattr(config.lrm, 'feature_layer'):
        config.lrm.feature_layer = [6, 7]
        logging.info(f"Setting default feature_layer to {config.lrm.feature_layer}")
    logging.info(f"Freezing all blocks except for {trainable_blocks}")

    # Only the blocks listed in `trainable_blocks` are kept in the reward
    # model; every other block is dropped from its module list.
    new_blocks = []
    for i, block in enumerate(lrm_transformer.blocks):
        if i in trainable_blocks:
            logging.info(f"Block {i} is set to be trainable.")
            for param in block.parameters():
                param.requires_grad = True
            new_blocks.append(block)
        else:
            logging.info(f"Block {i} is frozen and removed.")
            for param in block.parameters():
                param.requires_grad = False
    lrm_transformer.blocks = nn.ModuleList(new_blocks)

    if hasattr(lrm_transformer, 'head'):
        del lrm_transformer.head
        lrm_transformer.head = None

    if hasattr(config.model, 'lrm_transformer_path') and config.model.lrm_transformer_path:
        logging.info(f"loading LRM transformer from {config.model.lrm_transformer_path}")
        state_dict = load_state_dict(config.model.lrm_transformer_path)
        lrm_transformer.load_state_dict(state_dict, strict=False)
    else:
        logging.info("No LRM transformer path specified, using base transformer")
    lrm_transformer.to(dtype=torch.float32)

    # Reward head: an MLP on top of pooled LRM features.
    mlp_input_dim = config.lrm.mlp_dim
    mlp = MLP(mlp_input_dim)
    if hasattr(config.model, 'lrm_mlp_path') and config.model.lrm_mlp_path:
        logging.info(f"loading MLP from {config.model.lrm_mlp_path}")
        try:
            mlp.load_state_dict(torch.load(config.model.lrm_mlp_path))
            logging.info("Successfully loaded MLP from checkpoint")
        except Exception as first_error:
            # `except Exception` (not a bare except) so that KeyboardInterrupt
            # and SystemExit are not swallowed by the retry.
            logging.warning(
                f"Direct MLP load failed ({first_error}); retrying with 'state_dict' key"
            )
            try:
                mlp.load_state_dict(torch.load(config.model.lrm_mlp_path)["state_dict"])
                logging.info("Successfully loaded MLP from checkpoint with state_dict key")
            except Exception as e:
                # Best effort: training continues with a randomly initialised
                # (and frozen) MLP, so this error deserves attention in the logs.
                logging.error(f"Failed to load MLP from {config.model.lrm_mlp_path}: {e}")
                logging.info("Using newly created MLP due to loading failure")
    else:
        logging.info("No MLP path specified, using newly created MLP")
    mlp.to(basic_kwargs.device)
    mlp.eval()
    for param in mlp.parameters():
        param.requires_grad = False

    # Query-attention pooling configuration (all keys optional).
    # NOTE(review): a `layer_norm` key used to be read from this config but was
    # never forwarded to QueryAttention; the unused read was removed — confirm
    # whether it should be passed to the constructor.
    query_attention_config = getattr(config.lrm, 'query_attention', {})
    num_queries = query_attention_config.get('num_queries', 1)
    num_heads = query_attention_config.get('num_heads', 8)
    dropout = query_attention_config.get('dropout', 0.)
    return_type = query_attention_config.get('return_type', None)
    product_text = query_attention_config.get('product_text', False)
    text_dim = query_attention_config.get('text_dim', 4096)
    query_attention = QueryAttention(
        feature_dim=mlp_input_dim,
        num_queries=num_queries,
        num_heads=num_heads,
        dropout=dropout,
        return_type=return_type,
        product_text=product_text,
        text_dim=text_dim
    )
    if hasattr(config.model, 'lrm_query_attention_path') and config.model.lrm_query_attention_path:
        logging.info(f"loading model query_attention from {config.model.lrm_query_attention_path}")
        checkpoint = torch.load(config.model.lrm_query_attention_path)
        query_attention.load_state_dict(checkpoint)
    query_attention = query_attention.to(device=basic_kwargs.device, dtype=torch.float32)
    query_attention.eval()

    # Set on the class, so this affects every instance of the model class.
    transformer.__class__.enable_teacache = False
    lrm_transformer.__class__.enable_teacache = False

    # Init LoRA for transformer
    if config.model.lora.use_lora:
        lora_config = LoraConfig(
            r=config.model.lora.lora_rank,
            lora_alpha=config.model.lora.lora_rank,
            init_lora_weights=True,
            target_modules=config.model.lora.target_modules,
        )
        transformer = get_peft_model(transformer, lora_config)
        if config.model.lora.resume_lora_path:
            lora_state_dict = load_lora_state_dict(config.model.lora.resume_lora_path)
            m, u = transformer.load_state_dict(lora_state_dict, strict=False)
            logging.info(f"load lora from {config.model.lora.resume_lora_path}.")
            logging.info(f"miss {len(m)}; unexpect {len(u)}.")
            # A resumed LoRA checkpoint overrides the step derived above.
            resume_step = int(
                config.model.lora.resume_lora_path.rstrip("/").split("-")[-1]
            )
    transformer = transformer.to(dtype=torch.float32)

    # Init EMA
    if config.model.ema.use_ema:
        logging.info("loading ema model")
        ema_transformer = deepcopy(transformer)
    else:
        ema_transformer = None

    # Init FSDP
    fsdp_kwargs, no_split_modules = get_dit_fsdp_kwargs(
        transformer,
        config.model.fsdp.fsdp_sharding_startegy,
        config.model.lora.use_lora,
        config.model.fsdp.use_cpu_offload,
        master_weight_type="fp32",
    )
    if config.model.lora.use_lora:
        transformer.config.lora_rank = config.model.lora.lora_rank
        transformer.config.lora_alpha = config.model.lora.lora_rank
        transformer.config.lora_target_modules = config.model.lora.target_modules
        transformer._no_split_modules = [cls.__name__ for cls in no_split_modules]
        fsdp_kwargs["auto_wrap_policy"] = fsdp_kwargs["auto_wrap_policy"](transformer)
    transformer = FSDP(transformer, **fsdp_kwargs)
    # NOTE(review): the same FSDP kwargs (built from `transformer`) are reused
    # for the reward model and the EMA copy — confirm this is intended when
    # LoRA is enabled.
    lrm_transformer = FSDP(lrm_transformer, **fsdp_kwargs)
    if config.model.ema.use_ema:
        ema_transformer = FSDP(ema_transformer, **fsdp_kwargs)

    # Init gradient checkpointing
    if config.model.gradient_checkpointing:
        apply_fsdp_checkpointing(
            transformer, no_split_modules, config.model.selective_checkpointing
        )
        apply_fsdp_checkpointing(
            lrm_transformer, no_split_modules, config.model.selective_checkpointing
        )
        if config.model.ema.use_ema:
            apply_fsdp_checkpointing(
                ema_transformer, no_split_modules, config.model.selective_checkpointing
            )
        logging.info("enable gradient checkpointing")

    # Set model as trainable
    transformer.train()
    print_parameters_information(transformer, "WAN", basic_kwargs.rank)
    if config.model.ema.use_ema:
        ema_transformer.requires_grad_(False)
        print_parameters_information(ema_transformer, "WAN EMA", basic_kwargs.rank)

    model_kwargs = EasyDict(
        {
            "transformer": transformer,
            "ema_transformer": ema_transformer,
            "resume_step": resume_step,
            "lrm_transformer": lrm_transformer,
            "query_attention": query_attention,
            "mlp": mlp,
        }
    )
    return model_kwargs
def extra_model_init(config, basic_kwargs):
    """Create the noise schedulers and the VAE used around the transformer.

    Args:
        config: Training configuration; the ``extra_model`` section and
            ``model.base_path`` are read.
        basic_kwargs: EasyDict providing ``dtype`` and ``device`` for the VAE.

    Returns:
        EasyDict with ``noise_scheduler``, ``noise_scheduler_refl``, ``vae``
        and the entries ``tokenizer``, ``text_encoder``, ``image_encoder`` and
        ``reward_model``, which are all ``None``.
    """
    scheduler_cfg = config.extra_model.scheduler
    checkpoint_root = config.model.base_path

    # Scheduler used by the regular training step.
    train_scheduler = FlowMatchDiscreteScheduler(shift=scheduler_cfg.flow_shift)
    train_scheduler.set_timesteps(
        scheduler_cfg.num_train_timesteps, dtype=torch.int64
    )

    # Multistep scheduler used by the reward-feedback step.
    refl_scheduler = FlowUniPCMultistepScheduler(
        num_train_timesteps=scheduler_cfg.num_train_timesteps,
        shift=1,
        use_dynamic_shifting=False,
    )

    vae_weights = os.path.join(checkpoint_root, config.extra_model.vae.name)
    vae = WanVAE(
        vae_pth=vae_weights,
        dtype=basic_kwargs.dtype,
        device=basic_kwargs.device,
    )

    # No tokenizer / encoders / reward model are instantiated here.
    extra_model_kwargs = EasyDict(
        noise_scheduler=train_scheduler,
        noise_scheduler_refl=refl_scheduler,
        vae=vae,
        tokenizer=None,
        text_encoder=None,
        image_encoder=None,
        reward_model=None,
    )
    logging.info("extra model initialized")
    return extra_model_kwargs
def dataloader_init(config, basic_kwargs, resume_step=0):
    """Build the training batch iterator.

    Args:
        config: Training configuration (``task``, ``dataset``, ``model.patch_size``
            and ``train.seed`` are read).
        basic_kwargs: EasyDict providing ``world_size`` and ``rank``.
        resume_step: Passed to the sampler as ``start_index``.

    Returns:
        A ``VideoImageBatchIterator`` wrapping the video ``DataLoader``.
    """
    dataset_cfg = config.dataset
    per_step_batch = dataset_cfg.batch_size
    worker_count = dataset_cfg.num_workers

    train_dataset = Image2VideoTrainDataset(
        dataset_type="refl",
        task=config.task,
        meta_file_list=dataset_cfg.meta_file_list,
        uncond_prob=dataset_cfg.uncond_prob,
        sp_size=dataset_cfg.sp_size,
        patch_size=config.model.patch_size,
    )
    logging.info(f"dataset length {len(train_dataset)}")

    # One replica per sequence-parallel group; ranks in a group share a shard.
    replica_count = basic_kwargs.world_size // nccl_info.sp_size
    block_sampler = BlockDistributedSampler(
        dataset=train_dataset,
        num_replicas=replica_count,
        rank=nccl_info.group_id,
        shuffle=True,
        seed=config.train.seed,
        drop_last=True,
        batch_size=per_step_batch,
        start_index=resume_step,
    )

    loader = DataLoader(
        train_dataset,
        sampler=block_sampler,
        pin_memory=True,
        batch_size=per_step_batch,
        num_workers=worker_count,
        drop_last=True,
        worker_init_fn=set_worker_seed_builder(basic_kwargs.rank),
        # Persistent workers are only enabled when worker processes exist.
        persistent_workers=worker_count != 0,
    )
    return VideoImageBatchIterator(
        video_dataloader=loader, sp_size=nccl_info.sp_size
    )
def optimizer_init(config, basic_kwargs, model_kwargs):
    """Create the AdamW optimizer and learning-rate scheduler.

    Only transformer parameters with ``requires_grad`` set are optimised.

    Args:
        config: Training configuration; the ``optimizer`` section is read.
        basic_kwargs: Unused here; kept for a uniform init signature.
        model_kwargs: EasyDict providing ``transformer``.

    Returns:
        EasyDict with ``optimizer`` and ``lr_scheduler``.
    """
    opt_cfg = config.optimizer
    trainable_params = [
        param
        for param in model_kwargs.transformer.parameters()
        if param.requires_grad
    ]

    adamw = torch.optim.AdamW(
        trainable_params,
        lr=opt_cfg.learning_rate,
        betas=(opt_cfg.adam_beta1, opt_cfg.adam_beta2),
        weight_decay=opt_cfg.weight_decay,
        eps=1e-8,
    )
    schedule = get_scheduler(
        opt_cfg.lr_scheduler,
        optimizer=adamw,
        num_warmup_steps=opt_cfg.lr_warmup_steps,
        num_training_steps=opt_cfg.max_train_steps,
        num_cycles=opt_cfg.lr_num_cycles,
        power=opt_cfg.lr_power,
    )

    optimizer_kwargs = EasyDict(optimizer=adamw, lr_scheduler=schedule)
    logging.info("optimizer initialized")
    return optimizer_kwargs
def before_train_step(config, sp_dataloader, basic_kwargs, extra_model_kwargs):
    """Fetch the next batch and prepare the tensors for a training step.

    Freezes the VAE (when present), pulls one batch from ``sp_dataloader``,
    moves the tensors to the training device/dtype, prepends the 4-channel
    frame mask to a 16-channel condition latent, regroups the image
    embeddings, and (without sequence parallelism) crops the latents.

    Args:
        config: Training configuration (``task``, ``dataset``, ``model.patch_size``).
        sp_dataloader: Iterator yielding 6-tuples ``(latents, text_states,
            uncond_text_states, image_embeds, latents_condition, long_caption)``.
        basic_kwargs: EasyDict providing ``device`` and ``dtype``.
        extra_model_kwargs: EasyDict providing ``vae``.

    Returns:
        EasyDict with ``latents``, ``text_states``, ``image_embeds``,
        ``latents_condition``, ``max_sequence_length``, ``uncond_text_states``
        and ``text_prompt``.
    """
    # Model: the VAE is never trained.
    vae = extra_model_kwargs.vae
    if vae is not None:
        vae.model.requires_grad_(False)
        vae.model.eval()

    # Data
    (
        latents,
        text_states,
        uncond_text_states,
        image_embeds,
        latents_condition,
        long_caption
    ) = next(sp_dataloader)

    device = basic_kwargs.device
    dtype = basic_kwargs.dtype
    # Image conditioning is only used by the i2v / flf2v tasks.
    uses_image_condition = "i2v" in config.task or "flf2v" in config.task

    latents = latents.to(device, dtype=dtype)
    text_states = text_states.to(device, dtype=dtype)
    uncond_text_states = uncond_text_states.to(device, dtype=dtype)

    latents_condition = (
        latents_condition.to(device, dtype=dtype) if uses_image_condition else None
    )
    if latents_condition is not None:
        b, c, f, h, w = latents_condition.shape
        if int(c) == 16:
            # The mask is only built when it is actually concatenated:
            # 4 channels, ones on the first frame and zeros on the rest.
            mask_lat_size = torch.ones((b, 4, f, h, w), dtype=dtype, device=device)
            mask_lat_size[:, :, 1:, ...] = 0.0
            latents_condition = torch.concat([mask_lat_size, latents_condition], dim=1)

    image_embeds = (
        image_embeds.to(device, dtype=dtype) if uses_image_condition else None
    )
    if image_embeds is not None:
        # Split the token axis into groups of 257 tokens and fold the groups
        # into the batch axis.
        # NOTE(review): 257 is presumably the per-image token count of the
        # image encoder — confirm.
        N = image_embeds.shape[1] // 257
        image_embeds = rearrange(image_embeds, "b (n s) d -> (b n) s d", n=N)

    # Cropping is skipped under sequence parallelism.
    if config.dataset.sp_size <= 1:
        latents, latents_condition = crop_tensor(
            latents,
            latents_condition,
            config.dataset.crop_ratio[0],
            config.dataset.crop_ratio[1],
            config.dataset.crop_type,
            crop_time_ratio=config.dataset.crop_ratio[2],
        )

    _, _, latents_t, latents_h, latents_w = latents.shape
    max_sequence_length = (
        latents_t
        * latents_h
        * latents_w
        // (config.model.patch_size[1] * config.model.patch_size[2])
    )

    data_kwargs = EasyDict(
        {
            "latents": latents,
            "text_states": text_states,
            "image_embeds": image_embeds,
            "latents_condition": latents_condition,
            "max_sequence_length": max_sequence_length,
            "uncond_text_states": uncond_text_states,
            "text_prompt": long_caption,
        }
    )
    return data_kwargs
def train_step_refl(
    config,
    step,
    basic_kwargs,
    model_kwargs,
    extra_model_kwargs,
    optimizer_kwargs,
    data_kwargs,
):
    """Run one reward-feedback training step and return its log values.

    The step:
      1. samples noise and denoises it without gradients up to a randomly
         chosen intermediate step (chosen on rank 0 and broadcast),
      2. runs the transformer once more at that step with gradients,
      3. advances the latent by one scheduler step,
      4. scores the resulting latent with the reward stack
         (``lrm_transformer`` -> pooling -> ``mlp``) and builds a hinge loss
         towards a target reward,
      5. backpropagates, clips gradients and steps the optimizer on
         accumulation boundaries.

    Args:
        config: Training configuration.
        step: Global step index (used for accumulation and sanity dumps).
        basic_kwargs: EasyDict with ``device``, ``dtype`` and ``rank``.
        model_kwargs: EasyDict with ``transformer``, ``lrm_transformer``,
            ``query_attention`` and ``mlp``.
        extra_model_kwargs: EasyDict with ``vae`` and ``noise_scheduler_refl``.
        optimizer_kwargs: EasyDict with ``optimizer`` and ``lr_scheduler``.
        data_kwargs: EasyDict produced by ``before_train_step``.

    Returns:
        EasyDict with ``loss`` and ``grad_norm``. When the step is skipped
        (non-finite loss on any rank, or an error during backward) ``loss`` is
        a zero tensor and ``grad_norm`` is ``0``.
    """
    # Resolved once; only used to prefix the memory log lines.
    log_rank = dist.get_rank() if hasattr(dist, 'get_rank') else None
    uses_image_condition = "i2v" in config.task or "flf2v" in config.task

    log_memory_usage("Training step start", log_rank)
    transformer = model_kwargs.transformer
    if hasattr(transformer, 'gradient_checkpointing_enable'):
        transformer.gradient_checkpointing_enable()
    lrm_transformer = model_kwargs.lrm_transformer
    query_attention = model_kwargs.query_attention
    mlp = model_kwargs.mlp
    vae = extra_model_kwargs.vae
    if vae is not None:
        if hasattr(vae.model, 'gradient_checkpointing_enable'):
            vae.model.gradient_checkpointing_enable()
            logging.info("Enabled gradient checkpointing for VAE model")
        else:
            if hasattr(vae.model, 'enable_gradient_checkpointing'):
                vae.model.enable_gradient_checkpointing()
                logging.info("Enabled gradient checkpointing for VAE model via alternative method")
            else:
                logging.warning("Gradient checkpointing not supported for VAE model")
    else:
        logging.info("VAE is None, skipping VAE-related operations")

    noise_scheduler = extra_model_kwargs.noise_scheduler_refl
    latents = data_kwargs.latents
    text_states = data_kwargs.text_states
    latents_condition = data_kwargs.latents_condition
    image_embeds = data_kwargs.image_embeds
    max_sequence_length = data_kwargs.max_sequence_length

    # Optimizer
    optimizer = optimizer_kwargs.optimizer
    lr_scheduler = optimizer_kwargs.lr_scheduler

    # Forward
    bsz = latents.shape[0]
    inference_steps = 40
    noise_scheduler.set_timesteps(
        num_inference_steps=inference_steps,
        device=basic_kwargs.device,
        shift=config.extra_model.scheduler.flow_shift,
    )
    timesteps = noise_scheduler.timesteps
    transformer.eval()

    # The data latents only provide the shape; sampling starts from pure noise.
    latent = torch.randn_like(latents)
    # Rank 0 picks the step at which gradients are enabled; the upper bound
    # keeps `mid_timestep + 1` a valid index into `timesteps`.
    if basic_kwargs.rank == 0:
        mid_timestep = random.randint(0, inference_steps - 2)
    else:
        mid_timestep = 0
    del latents
    torch.cuda.empty_cache()
    gc.collect()
    log_memory_usage("After creating noise latents", log_rank)

    mid_timestep_tensor = torch.tensor(mid_timestep, device=latent.device, dtype=torch.long)
    dist.broadcast(mid_timestep_tensor, src=0)
    mid_timestep = mid_timestep_tensor.item()

    # Sequence-parallel broadcast
    if config.dataset.sp_size > 1:
        if uses_image_condition:
            broadcast(latents_condition)
            broadcast(image_embeds)
        broadcast(latent)
        broadcast(text_states)
    log_memory_usage("After sequence parallel broadcast", log_rank)

    # ========== 1. infer with no grad to mid timestep ==========
    with torch.no_grad():
        for i in range(mid_timestep):
            t = timesteps[i]
            with torch.autocast("cuda", dtype=basic_kwargs.dtype):
                latent_model_input = latent
                timestep_tensor = torch.tensor([t], device=basic_kwargs.device)
                arg_c = {
                    "x": batch2list(latent_model_input),
                    "t": timestep_tensor,
                    "context": batch2list(text_states),
                    "seq_len": max_sequence_length,
                    "clip_fea": image_embeds,
                    "y": (
                        batch2list(latents_condition)
                        if uses_image_condition
                        else None
                    ),
                    'cond_flag': True,
                }
                noise_pred = transformer(**arg_c)
                noise_pred = list2batch(noise_pred)
            # NOTE(review): the scheduler update runs outside autocast here,
            # matching the update after the gradient step below — confirm.
            scheduler_output = noise_scheduler.step(noise_pred, t, latent, return_dict=False)
            latent = scheduler_output[0] if isinstance(scheduler_output, tuple) else scheduler_output
            del latent_model_input, timestep_tensor, noise_pred, scheduler_output, arg_c
            torch.cuda.empty_cache()
            if i % 10 == 0:
                gc.collect()
            dist.barrier()
            log_memory_usage(f"After inference step {i}", log_rank)
    log_memory_usage("After inference loop", log_rank)

    # ========== 2. cal gradient ==========
    transformer.train()
    t_mid = timesteps[mid_timestep]
    timestep_mid = torch.tensor([t_mid], device=basic_kwargs.device)
    arg_c = {
        "x": batch2list(latent),
        "t": timestep_mid,
        "context": batch2list(text_states),
        "seq_len": max_sequence_length,
        "clip_fea": image_embeds,
        "y": (
            batch2list(latents_condition)
            if uses_image_condition
            else None
        ),
        'cond_flag': True,
    }
    with torch.autocast("cuda", dtype=basic_kwargs.dtype, enabled=True):
        noise_pred = transformer(**arg_c)
    noise_pred = list2batch(noise_pred)
    del timestep_mid, arg_c
    torch.cuda.empty_cache()
    gc.collect()
    log_memory_usage("After gradient computation", log_rank)

    # ========== 3. cal pred_original_sample ==========
    scheduler_output = noise_scheduler.step(noise_pred, t_mid, latent, return_dict=False)
    latent = scheduler_output[0] if isinstance(scheduler_output, tuple) else scheduler_output
    del scheduler_output
    torch.cuda.empty_cache()
    gc.collect()
    dist.barrier()
    log_memory_usage("After pred_original_sample computation", log_rank)

    # ========== 4. cal reward ==========
    t_mid_1 = timesteps[mid_timestep+1]
    timestep_mid_1 = torch.tensor([t_mid_1], device=basic_kwargs.device)
    # NOTE(review): pooling, the MLP head and the loss are kept inside the
    # autocast region together with the LRM forward — confirm.
    with torch.autocast("cuda", dtype=basic_kwargs.dtype, enabled=True):
        lrm_cond_kwargs = {
            "x": batch2list(latent),
            "t": timestep_mid_1,
            "context": batch2list(text_states),
            "seq_len": max_sequence_length,
            "clip_fea": image_embeds,
            "y": (
                batch2list(latents_condition)
                if uses_image_condition
                else None
            ),
            "output_features": True,
            "selected_layers": config.lrm.feature_layer,
        }
        lrm_features = lrm_transformer(**lrm_cond_kwargs)
        lrm_features = list2batch(lrm_features)

        # Reduce the LRM features to one vector per sample; the layout
        # comments below are the ones carried by the original code.
        if config.dataset.sp_size > 1:
            if len(lrm_features.shape) == 4:  # [sp_size, batch, seq_len_per_device, feature_dim]
                if config.lrm.pool == 'q_attn':
                    lrm_features_final = query_attention(lrm_features)
                else:
                    lrm_features_pooled = lrm_features.mean(dim=2)  # [sp_size, batch, feature_dim]
                    lrm_features_final = lrm_features_pooled.mean(dim=0)  # [batch, feature_dim]
            else:
                original_batch_size = bsz
                lrm_features_flat = lrm_features.view(original_batch_size, -1)
                lrm_features_final = lrm_features_flat.mean(dim=1, keepdim=True)  # [batch, 1]
        else:
            if len(lrm_features.shape) == 3:  # [batch, seq_len, feature_dim]
                if config.lrm.pool == 'q_attn':
                    lrm_features_final = query_attention(lrm_features)
                else:
                    lrm_features_final = lrm_features.mean(dim=1)  # [batch, feature_dim]
            elif len(lrm_features.shape) == 4:  # [batch, channels, seq_len, feature_dim] or similar
                if config.lrm.pool == 'q_attn':
                    lrm_features_final = query_attention(lrm_features)
                else:
                    lrm_features_pooled = lrm_features.mean(dim=2)  # [batch, feature_dim]
                    lrm_features_final = lrm_features_pooled.mean(dim=1)  # [batch, feature_dim]
            elif len(lrm_features.shape) == 2:  # [batch, feature_dim] - already good
                lrm_features_final = lrm_features
            else:
                batch_size = lrm_features.shape[0]
                lrm_features_final = lrm_features.view(batch_size, -1).mean(dim=1, keepdim=True)  # [batch, 1]

        reward_scores = forward_mlp(mlp, lrm_features_final)
        # Hinge loss: only rewards below the target contribute.
        target_reward = 2
        loss = 0.1 * F.relu(-reward_scores.squeeze() + target_reward).mean()

    # Check that the loss is valid. The decision to skip is agreed on by all
    # ranks: if only the affected rank returned early, the others would block
    # in the collectives below (barrier / backward / all_reduce).
    loss_invalid = bool(torch.isnan(loss) or torch.isinf(loss))
    skip_flag = torch.tensor(
        [1.0 if loss_invalid else 0.0], device=basic_kwargs.device
    )
    dist.all_reduce(skip_flag, op=dist.ReduceOp.MAX)
    if skip_flag.item() > 0:
        if loss_invalid:
            print("ERROR: Loss is NaN or Inf!")
        else:
            logging.warning("Skipping step: another rank reported a NaN or Inf loss")
        del lrm_features, lrm_features_final, reward_scores, lrm_cond_kwargs, timestep_mid_1, t_mid_1
        del image_embeds, text_states, latents_condition, noise_pred, latent
        torch.cuda.empty_cache()
        gc.collect()
        # Same container type as the regular return value.
        return EasyDict({"loss": torch.tensor(0.0), "grad_norm": 0})

    if abs(loss.item()) > 1e6:
        print(f"WARNING: Loss value {loss.item()} is very large, clipping to 1e6")
        loss = torch.clamp(loss, -1e6, 1e6)
    del lrm_features, lrm_features_final, reward_scores, lrm_cond_kwargs, timestep_mid_1, t_mid_1, image_embeds, text_states, latents_condition
    torch.cuda.empty_cache()
    gc.collect()
    dist.barrier()
    log_memory_usage("After LRM computation", log_rank)

    # ========== 5. backwards ==========
    try:
        loss /= config.train.gradient_accumulation_steps
        loss.backward()
        grad_norm = transformer.clip_grad_norm_(max_norm=1.0)
        if (step + 1) % config.train.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()
    except Exception as e:
        # Best effort, kept from the original behaviour.
        # NOTE(review): if this fires on a single rank the remaining ranks
        # will still wait in the collectives below — consider re-raising.
        print(f"ERROR during backward/optimization: {e}")
        del latent
        torch.cuda.empty_cache()
        gc.collect()
        return EasyDict({"loss": torch.tensor(0.0), "grad_norm": 0})
    torch.cuda.empty_cache()
    gc.collect()
    dist.barrier()
    log_memory_usage("After optimization", log_rank)

    avg_loss = loss.detach().clone()
    dist.all_reduce(avg_loss, dist.ReduceOp.AVG)

    # Logs results
    if (
        config.train.sanity_check_interval >= 0 and step <= 50  # and step % config.train.sanity_check_interval == 0
    ):
        if basic_kwargs.rank == 0:
            # The directory may not exist yet (e.g. when the interval is 0),
            # so make sure it does before writing videos.
            os.makedirs(config.save.sanity_check_dir, exist_ok=True)
            with torch.no_grad():
                sigma_t = noise_scheduler.sigmas[mid_timestep+1]
                pred_original_sample = latent - sigma_t * noise_pred
                pred_x0_s = vae_decode(
                    vae, pred_original_sample.clone().detach(), dtype=basic_kwargs.dtype, vae_type="wanx"
                )
                latents_s = vae_decode(
                    vae, latent.clone(), dtype=basic_kwargs.dtype, vae_type="wanx"
                )
                pred_video_path = os.path.join(
                    config.save.sanity_check_dir,
                    f"step{step}_pred_x0_rank{basic_kwargs.rank}_{sigma_t.item()}.mp4",
                )
                print("save_videos_grid:", pred_video_path)
                save_videos_grid(
                    pred_x0_s.to(torch.float32).cpu(),
                    pred_video_path,
                    fps=15,
                    rescale=True,
                )
                save_videos_grid(
                    latents_s.to(torch.float32).cpu(),
                    os.path.join(
                        config.save.sanity_check_dir,
                        f"step{step}_real_x0_rank{basic_kwargs.rank}.mp4",
                    ),
                    fps=15,
                    rescale=True,
                )
                del pred_original_sample, sigma_t, latents_s, pred_x0_s
        torch.cuda.empty_cache()
        gc.collect()

    log_kwargs = EasyDict({
        "loss": avg_loss,
        "grad_norm": grad_norm,
    })
    del latent, noise_pred
    dist.barrier()
    free_memory()
    log_memory_usage("Training step end", log_rank)
    return log_kwargs
def train_step(
    config,
    step,
    basic_kwargs,
    model_kwargs,
    extra_model_kwargs,
    optimizer_kwargs,
    data_kwargs,
):
    """Run one regular (denoising-loss) training step and return its log values.

    Samples a timestep/sigma, noises the latents, runs the transformer and
    minimises the weighted squared error against the scheduler's training
    target. Gradients are clipped every call; the optimizer and LR scheduler
    step on accumulation boundaries.

    Args:
        config: Training configuration.
        step: Global step index (used for accumulation and sanity dumps).
        basic_kwargs: EasyDict with ``dtype`` and ``rank``.
        model_kwargs: EasyDict with ``transformer``.
        extra_model_kwargs: EasyDict with ``vae`` and ``noise_scheduler``.
        optimizer_kwargs: EasyDict with ``optimizer`` and ``lr_scheduler``.
        data_kwargs: EasyDict produced by ``before_train_step``.

    Returns:
        EasyDict with ``loss`` (averaged over ranks) and ``grad_norm``.
    """
    # Model
    transformer = model_kwargs.transformer
    vae = extra_model_kwargs.vae
    noise_scheduler = extra_model_kwargs.noise_scheduler
    latents = data_kwargs.latents
    text_states = data_kwargs.text_states
    latents_condition = data_kwargs.latents_condition
    image_embeds = data_kwargs.image_embeds
    max_sequence_length = data_kwargs.max_sequence_length
    uses_image_condition = "i2v" in config.task or "flf2v" in config.task

    # Optimizer
    optimizer = optimizer_kwargs.optimizer
    lr_scheduler = optimizer_kwargs.lr_scheduler

    # Forward
    bsz = latents.shape[0]
    noise = torch.randn_like(latents)
    timestep, sigma = noise_scheduler.get_train_timestep_and_sigma(
        weighting_scheme=config.extra_model.scheduler.weighting_scheme,
        batch_size=bsz,
        logit_mean=config.extra_model.scheduler.logit_mean,
        logit_std=config.extra_model.scheduler.logit_std,
        device=latents.device,
        n_dim=latents.ndim,
    )

    # Under sequence parallelism every tensor that was sampled or loaded
    # locally is broadcast.
    if config.dataset.sp_size > 1:
        if uses_image_condition:
            broadcast(latents_condition)
            broadcast(image_embeds)
        broadcast(sigma)
        broadcast(noise)
        broadcast(timestep)
        broadcast(latents)
        broadcast(text_states)

    noisy_latents = noise_scheduler.add_noise(latents, noise, sigma)
    cond_kwargs = {
        "x": batch2list(noisy_latents),
        "t": timestep,
        "context": batch2list(text_states),
        "seq_len": max_sequence_length,
        "clip_fea": image_embeds,
        "y": (
            batch2list(latents_condition)
            if uses_image_condition
            else None
        ),
    }
    with torch.autocast("cuda", dtype=basic_kwargs.dtype):
        model_pred = transformer(**cond_kwargs)
    model_pred = list2batch(model_pred)

    # Compute loss (in float32)
    training_target = noise_scheduler.get_train_target(latents, noise)
    weighting = noise_scheduler.get_train_loss_weighting(sigma)
    loss = torch.mean(
        weighting.float() * (model_pred.float() - training_target.float()) ** 2
    )
    loss /= config.train.gradient_accumulation_steps
    loss.backward()
    grad_norm = transformer.clip_grad_norm_(max_norm=1.0)
    if (step + 1) % config.train.gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()

    avg_loss = loss.detach().clone()
    dist.all_reduce(avg_loss, dist.ReduceOp.AVG)
    log_kwargs = EasyDict(
        {
            "loss": avg_loss,
            "grad_norm": grad_norm,
        }
    )
    if dist.get_rank() == 0:
        print(log_kwargs)

    # Sanity dump. This must run BEFORE the tensors are released below: it
    # reads `model_pred`, `noisy_latents`, `sigma` and `latents`.
    if (
        config.train.sanity_check_interval > 0
        and step % config.train.sanity_check_interval == 0
        and step <= 50
    ):
        if basic_kwargs.rank == 0:
            with torch.no_grad():
                pred_x0 = noise_scheduler.get_x0(model_pred, noisy_latents, sigma)
                pred_x0 = pred_x0.to(dtype=basic_kwargs.dtype)
                pred_x0_s = vae_decode(
                    vae, pred_x0.clone().detach(), dtype=basic_kwargs.dtype, vae_type="wanx"
                )
                latents_s = vae_decode(
                    vae, latents.clone(), dtype=basic_kwargs.dtype, vae_type="wanx"
                )
            # First sigma of the batch; identical to `sigma.item()` for a
            # single-element sigma, but does not fail for larger batches.
            sigma_tag = sigma.flatten()[0].item()
            pred_video_path = os.path.join(
                config.save.sanity_check_dir,
                f"step{step}_pred_x0_rank{basic_kwargs.rank}_{sigma_tag}.mp4",
            )
            print("save_videos_grid_path:", pred_video_path)
            save_videos_grid(
                pred_x0_s.to(torch.float32).cpu(),
                pred_video_path,
                fps=15,
                rescale=True,
            )
            save_videos_grid(
                latents_s.to(torch.float32).cpu(),
                os.path.join(
                    config.save.sanity_check_dir,
                    f"step{step}_real_x0_rank{basic_kwargs.rank}.mp4",
                ),
                fps=15,
                rescale=True,
            )

    # Release the step's tensors only after the last use above.
    del sigma, noise, timestep, latents, text_states, latents_condition, image_embeds
    del loss, training_target, weighting, model_pred, noisy_latents
    torch.cuda.empty_cache()
    dist.barrier()
    free_memory()
    return log_kwargs
def after_train_step(config, step, basic_kwargs, model_kwargs,
                     log_kwargs_normal, log_kwargs_reward, writer):
    """Report step statistics, refresh the EMA weights and checkpoint.

    Args:
        config: Run configuration. Reads ``model.ema.use_ema``,
            ``model.ema.ema_decay``, ``model.lora.use_lora``,
            ``train.save_interval`` and ``save.ckpt_dir``.
        step: Current training step; used as the scalar x-axis and as the
            checkpoint tag.
        basic_kwargs: Provides ``rank``, ``local_rank``, ``world_size`` and
            ``log_path``.
        model_kwargs: Provides ``transformer`` and ``ema_transformer``.
        log_kwargs_normal: Statistics of the normal pass (``loss``,
            ``grad_norm``, ``step_time``, ``avg_step_time``, ``lr``).
        log_kwargs_reward: Statistics of the reward pass (``loss``,
            ``grad_norm``, ``step_time``, ``avg_step_time``).
        writer: TensorBoard writer, or ``None`` when this process does not
            write scalars.
    """
    transformer = model_kwargs.transformer
    ema_transformer = model_kwargs.ema_transformer

    # Read every statistic up front so a missing field fails on all ranks,
    # not only on the ranks that actually log.
    log_loss_normal = log_kwargs_normal.loss
    log_grad_norm_normal = log_kwargs_normal.grad_norm
    log_step_time_normal = log_kwargs_normal.step_time
    log_avg_step_time_normal = log_kwargs_normal.avg_step_time
    log_lr = log_kwargs_normal.lr
    log_loss_reward = log_kwargs_reward.loss
    log_grad_norm_reward = log_kwargs_reward.grad_norm
    log_step_time_reward = log_kwargs_reward.step_time
    log_avg_step_time_reward = log_kwargs_reward.avg_step_time

    # One process per node prints; only global rank 0 persists anything.
    if basic_kwargs.local_rank == 0:
        log_info = (
            f"│ Rank {basic_kwargs.rank:02d} │ Workers: {basic_kwargs.world_size} │ "
            f"Step {step:05d} │ LR: {log_lr:.2e} │\n"
            f"│ Normal - Loss: {log_loss_normal:.4f} │ Grad: {log_grad_norm_normal:.4f} │ "
            f"Time: {log_step_time_normal:>6.2f}s │ Avg: {log_avg_step_time_normal:>6.2f}s │\n"
            f"│ Reward - Loss: {log_loss_reward:.4f} │ Grad: {log_grad_norm_reward:.4f} │ "
            f"Time: {log_step_time_reward:>6.2f}s │ Avg: {log_avg_step_time_reward:>6.2f}s │"
        )
        print(log_info)

        if basic_kwargs.rank == 0:
            if writer is not None:
                # Per-pass scalars, emitted in a fixed order.
                per_pass_scalars = (
                    ('train/normal_loss', log_loss_normal),
                    ('train/normal_grad_norm', log_grad_norm_normal),
                    ('train/normal_step_time', log_step_time_normal),
                    ('train/normal_avg_step_time', log_avg_step_time_normal),
                    ('train/reward_loss', log_loss_reward),
                    ('train/reward_grad_norm', log_grad_norm_reward),
                    ('train/reward_step_time', log_step_time_reward),
                    ('train/reward_avg_step_time', log_avg_step_time_reward),
                    ('train/lr', log_lr),
                )
                for tag, value in per_pass_scalars:
                    writer.add_scalar(tag, value, step)

                # Combined view across both passes.
                combined_loss = log_loss_normal + log_loss_reward
                combined_time = log_step_time_normal + log_step_time_reward
                writer.add_scalar('train/total_loss', combined_loss, step)
                writer.add_scalar('train/total_step_time', combined_time, step)

            # Append the same text that was printed to the run's log file.
            with open(basic_kwargs.log_path, "a", encoding="utf-8") as f:
                f.write(log_info + "\n")

    if config.model.ema.use_ema:
        # Synchronise ranks before touching the EMA copy.
        dist.barrier()
        update_ema_model(transformer, ema_transformer, config.model.ema.ema_decay)

    if config.train.save_interval > 0 and step % config.train.save_interval == 0:
        dist.barrier()
        # LoRA and full-weight runs use different savers with the same call shape.
        save_fn = save_lora_checkpoint if config.model.lora.use_lora else save_checkpoint
        save_fn(transformer, basic_kwargs.rank, config.save.ckpt_dir, step)
        if config.model.ema.use_ema:
            save_fn(ema_transformer, basic_kwargs.rank,
                    config.save.ckpt_dir, step, ema=True)
        logging.info(f"Checkpoint saved at step {step}")

    # NOTE(review): assumed to run on every step, not only after a save --
    # confirm against the original indentation.
    free_memory()
def _record_step_timing(log_kwargs, window, elapsed, lr):
    """Append ``elapsed`` to ``window`` and store timing and LR fields in ``log_kwargs``."""
    window.append(elapsed)
    log_kwargs.update(
        {
            "step_time": elapsed,
            "avg_step_time": sum(window) / len(window),
            "lr": lr,
        }
    )


def main(config):
    """Run the full training loop.

    Initialises the run state, models, optimizer and dataloader, then for
    every step fetches one batch and runs two passes on it: ``train_step``
    followed by ``train_step_refl``. Each pass is timed separately and the
    statistics of both are handed to ``after_train_step``.

    Args:
        config: Loaded run configuration (the entry point passes the result
            of ``OmegaConf.load``).
    """
    config, basic_kwargs = basic_init(config)
    model_kwargs = model_init(config, basic_kwargs)
    extra_model_kwargs = extra_model_init(config, basic_kwargs)
    optimizer_kwargs = optimizer_init(config, basic_kwargs, model_kwargs)
    sp_dataloader = dataloader_init(config, basic_kwargs, model_kwargs.resume_step)
    dist.barrier()
    free_memory()

    # Only global rank 0 owns a TensorBoard writer; other ranks carry None.
    writer = SummaryWriter(config.save.tensorboard_dir) if basic_kwargs.rank == 0 else None
    try:
        total_batch_size = (
            config.dataset.batch_size
            * (basic_kwargs.world_size // nccl_info.sp_size)
            * config.train.gradient_accumulation_steps
        )
        trainable_params = sum(
            p.numel() for p in model_kwargs['transformer'].parameters() if p.requires_grad
        )
        logging.info("***** Running training *****")
        logging.info(
            f" Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}"
        )
        logging.info(
            f" Total training parameters per FSDP shard = {trainable_params / 1e9} B"
        )

        # Rolling windows (last 100 steps) for the average time of each pass.
        step_times = deque(maxlen=100)
        step_times_2 = deque(maxlen=100)

        for step in range(
            model_kwargs.resume_step + 1, config.optimizer.max_train_steps + 1
        ):
            # First pass; its timing also covers fetching the batch.
            # perf_counter is monotonic, so clock adjustments cannot skew it.
            start_time = time.perf_counter()
            data_kwargs = before_train_step(
                config, sp_dataloader, basic_kwargs, extra_model_kwargs
            )
            log_kwargs = train_step(
                config,
                step,
                basic_kwargs,
                model_kwargs,
                extra_model_kwargs,
                optimizer_kwargs,
                data_kwargs,
            )
            _record_step_timing(
                log_kwargs,
                step_times,
                time.perf_counter() - start_time,
                optimizer_kwargs.optimizer.param_groups[0]["lr"],
            )

            # Second pass on the same batch.
            start_time = time.perf_counter()
            log_kwargs2 = train_step_refl(
                config,
                step,
                basic_kwargs,
                model_kwargs,
                extra_model_kwargs,
                optimizer_kwargs,
                data_kwargs,
            )
            _record_step_timing(
                log_kwargs2,
                step_times_2,
                time.perf_counter() - start_time,
                optimizer_kwargs.optimizer.param_groups[0]["lr"],
            )

            after_train_step(
                config, step, basic_kwargs, model_kwargs, log_kwargs, log_kwargs2, writer
            )
    finally:
        # Close the writer even when training raises, so buffered scalars
        # are flushed to disk.
        if writer is not None:
            writer.close()
if __name__ == "__main__":
    # Script entry point: parse the config path, load it and start training.
    parser = argparse.ArgumentParser()
    # The option is required, so argparse would never apply a default;
    # the previous dead default (another script's config) was removed.
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="Path to the training configuration file loaded with OmegaConf.",
    )
    args = parser.parse_args()
    main(OmegaConf.load(args.config_path))