"""Multi-shot / multi-view fine-tuning entry point for Wan video models.

Builds a ``WanTrainingModule`` around DiffSynth's ``WanVideoPipeline``,
overlays dataset/training settings from the YAML file named by
``args.train_yaml`` onto the parsed CLI arguments, and hands everything to
DiffSynth's ``launch_training_task``.
"""
import argparse
import json
import os
import re
import sys
import warnings

import torch

# Make this directory (for ``multi_view``) and the bundled DiffSynth-Studio
# checkout (for ``diffsynth``) importable. abspath() guards against a relative
# ``__file__`` (e.g. ``python train.py`` on Python < 3.9), where dirname()
# would return "" and the inserted entries would depend on the CWD.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
if _THIS_DIR not in sys.path:
    sys.path.insert(0, _THIS_DIR)
_DIFFSYNTH_ROOT = os.path.join(_THIS_DIR, "DiffSynth-Studio-main")
if _DIFFSYNTH_ROOT not in sys.path:
    sys.path.insert(0, _DIFFSYNTH_ROOT)

# Set before importing diffsynth so the setting is already in the
# environment for anything that import pulls in.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Everything below relies on the sys.path entries added above.
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig  # noqa: E402
from diffsynth.trainers.utils import (  # noqa: E402
    DiffusionTrainingModule,
    launch_training_task,
    wan_parser,
)

import imageio  # noqa: E402
import matplotlib.pyplot as plt  # noqa: E402
import pandas as pd  # noqa: E402
import torchvision  # noqa: E402
import yaml  # noqa: E402
from accelerate import Accelerator  # noqa: E402
from accelerate.utils import DistributedDataParallelKwargs  # noqa: E402
from peft import LoraConfig, inject_adapter_in_model  # noqa: E402
from PIL import Image  # noqa: E402
from tqdm import tqdm  # noqa: E402

from multi_view.datasets.videodataset import MulltiShot_MultiView_Dataset  # noqa: E402


class WanTrainingModule(DiffusionTrainingModule):
    """Trainable wrapper around a ``WanVideoPipeline``.

    Args:
        model_paths: JSON-encoded list of checkpoint paths, or None.
        model_id_with_origin_paths: Comma-separated ``model_id:origin_file_pattern``
            entries, or None.
        trainable_models: Comma-separated pipeline attribute names passed to
            ``freeze_except``; None passes an empty list.
        lora_base_model: Name of the pipeline attribute to wrap with LoRA via
            ``add_lora_to_model``, or None to skip LoRA.
        lora_target_modules: Comma-separated module names that receive LoRA.
        lora_rank: LoRA rank.
        use_gradient_checkpointing: Forwarded to the pipeline through the
            shared inputs.
        use_gradient_checkpointing_offload: Forwarded the same way.
        extra_inputs: Comma-separated names stored on ``self.extra_inputs``
            (not consumed by ``forward_preprocess`` in this file).
        max_timestep_boundary: Forwarded to the training loss via the shared inputs.
        min_timestep_boundary: Forwarded to the training loss via the shared inputs.
        local_model_path: Passed to every ``ModelConfig`` built from
            ``model_id_with_origin_paths``.

    Raises:
        ValueError: if a ``model_id_with_origin_paths`` entry has no ``:``.
    """

    def __init__(
        self,
        model_paths=None,
        model_id_with_origin_paths=None,
        trainable_models=None,
        lora_base_model=None,
        lora_target_modules="q,k,v,o,ffn.0,ffn.2",
        lora_rank=32,
        use_gradient_checkpointing=True,
        use_gradient_checkpointing_offload=False,
        extra_inputs=None,
        max_timestep_boundary=1.0,
        min_timestep_boundary=0.0,
        local_model_path=None,
    ):
        super().__init__()

        # Collect model configs from explicit paths and/or model_id:pattern pairs.
        model_configs = []
        if model_paths is not None:
            model_configs += [ModelConfig(path=path) for path in json.loads(model_paths)]
        if model_id_with_origin_paths is not None:
            for entry in model_id_with_origin_paths.split(","):
                # partition() keeps any further ':' inside the file pattern intact.
                model_id, sep, origin_file_pattern = entry.partition(":")
                if not sep:
                    raise ValueError(
                        f"Expected 'model_id:origin_file_pattern', got {entry!r}"
                    )
                model_configs.append(
                    ModelConfig(
                        local_model_path=local_model_path,
                        model_id=model_id,
                        origin_file_pattern=origin_file_pattern,
                    )
                )
        self.pipe = WanVideoPipeline.from_pretrained(
            torch_dtype=torch.bfloat16,
            device="cpu",
            model_configs=model_configs,
            redirect_common_files=False,
        )

        # Reset the scheduler to a 1000-step training schedule.
        self.pipe.scheduler.set_timesteps(1000, training=True)

        # Freeze every model except the ones listed in trainable_models.
        self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))

        # Optionally wrap one pipeline component with LoRA adapters.
        if lora_base_model is not None:
            model = self.add_lora_to_model(
                getattr(self.pipe, lora_base_model),
                target_modules=lora_target_modules.split(","),
                lora_rank=lora_rank,
            )
            setattr(self.pipe, lora_base_model, model)

        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
        self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
        self.max_timestep_boundary = max_timestep_boundary
        self.min_timestep_boundary = min_timestep_boundary

    def forward_preprocess(self, data):
        """Turn a batch of samples into the keyword inputs of ``training_loss``.

        Each sample is expected to provide ``video``, ``pre_shot_caption`` and
        ``ref_images``, and ``data[0]`` also provides ``ref_num``. Frames are
        assumed to be PIL-like, with ``.size == (width, height)``; TODO confirm
        against the dataset. Height, width, frame count and reference count
        are taken from the first sample only.

        Returns:
            A dict of the shared inputs merged with the positive-prompt inputs,
            after every ``self.pipe.units`` entry has processed them.

        Raises:
            ValueError: if ``data`` is empty.
        """
        if not data:
            raise ValueError("forward_preprocess received an empty batch")
        first_sample = data[0]
        frame_size = first_sample["video"][0].size

        # CFG-sensitive parameters.
        inputs_posi = {"prompt": [d["pre_shot_caption"] for d in data], "global_caption": None}
        inputs_nega = {}

        # CFG-insensitive parameters (filled in as an inference call would be).
        inputs_shared = {
            "input_video": [d["video"] for d in data],
            "height": frame_size[1],
            "width": frame_size[0],
            "num_frames": len(first_sample["video"]),
            "ref_images": [d["ref_images"] for d in data],
            # Fixed training settings; change only if you know the downstream effect.
            "cfg_scale": 1,
            "tiled": False,
            "rand_device": self.pipe.device,
            "use_gradient_checkpointing": self.use_gradient_checkpointing,
            "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
            "cfg_merge": False,
            "vace_scale": 1,
            "max_timestep_boundary": self.max_timestep_boundary,
            "min_timestep_boundary": self.min_timestep_boundary,
            "num_ref_images": first_sample["ref_num"],
            "batch_size": len(data),
        }
        # NOTE: self.extra_inputs is not consumed here; every input is built explicitly above.

        # Let every pipeline unit transform the inputs in turn.
        for unit in self.pipe.units:
            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(
                unit, self.pipe, inputs_shared, inputs_posi, inputs_nega
            )
        return {**inputs_shared, **inputs_posi}

    def forward(self, data, args, inputs=None):
        """Compute the training loss for ``data``.

        Precomputed ``inputs`` (as returned by ``forward_preprocess``) are
        reused when given; otherwise they are built from ``data``.
        """
        if inputs is None:
            inputs = self.forward_preprocess(data)
        models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
        loss = self.pipe.training_loss(args=args, **models, **inputs)
        return loss


def _apply_train_yaml(args):
    """Overlay the settings from ``args.train_yaml`` onto the parsed CLI ``args``.

    Reading the file:
        - Missing required keys raise ``KeyError``.
        - ``model_id_with_origin_paths``, ``trainable_models`` and
          ``learning_rate`` override the CLI values only when present.
        - ``debug_infer_*`` keys fall back to the defaults below.

    Derived values:
        - ``args.output_path`` becomes ``<output_path>/<visual_log_project_name>``.
        - ``args.learning_rate`` is rescaled when ``batch_size != 1``.

    Raises:
        ValueError: if the YAML top level is not a mapping, or if
            output_path / visual_log_project_name is null.
    """
    with open(args.train_yaml, "r", encoding="utf-8") as f:
        conf_info = yaml.safe_load(f)  # safe_load never constructs arbitrary Python objects
    print(conf_info)
    if not isinstance(conf_info, dict):
        raise ValueError(
            f"{args.train_yaml}: expected a YAML mapping at top level, "
            f"got {type(conf_info).__name__}"
        )
    dataset_conf = conf_info["dataset_args"]
    train_conf = conf_info["train_args"]

    args.dataset_base_path = dataset_conf["base_path"]
    args.max_checkpoints_to_keep = train_conf["max_checkpoints_to_keep"]
    args.resume_from_checkpoint = train_conf["resume_from_checkpoint"]
    args.visual_log_project_name = train_conf["visual_log_project_name"]
    args.seed = train_conf["seed"]
    args.output_path = train_conf["output_path"]
    args.save_steps = train_conf["save_steps"]
    args.save_epoches = train_conf["save_epoches"]
    print("outpath:", args.output_path)
    print("visual_log_project_name:", args.visual_log_project_name)

    # Reject null values up front instead of failing later while building the path.
    if args.output_path is None or args.visual_log_project_name is None:
        raise ValueError(f"output_path或visual_log_project_name为None: output_path={args.output_path}, visual_log_project_name={args.visual_log_project_name}")
    args.output_path = os.path.join(args.output_path, args.visual_log_project_name)

    args.batch_size = train_conf["batch_size"]
    args.local_model_path = train_conf["local_model_path"]
    # Optional overrides; otherwise the CLI / parser defaults stay in effect.
    if "model_id_with_origin_paths" in train_conf:
        args.model_id_with_origin_paths = train_conf["model_id_with_origin_paths"]
    if "trainable_models" in train_conf:
        args.trainable_models = train_conf["trainable_models"]
    if "learning_rate" in train_conf:
        # float(): PyYAML loads exponent literals without a dot (e.g. 1e-4) as strings.
        args.learning_rate = float(train_conf["learning_rate"])

    # Periodic debug-inference settings (all optional).
    args.debug_infer = bool(train_conf.get("debug_infer", False))
    args.debug_infer_interval = int(train_conf.get("debug_infer_interval", 1))
    args.debug_infer_steps = int(train_conf.get("debug_infer_steps", 8))
    args.debug_infer_cfg_scale = float(train_conf.get("debug_infer_cfg_scale", 5.0))
    args.debug_infer_cfg_scale_face = float(train_conf.get("debug_infer_cfg_scale_face", 5.0))
    args.debug_infer_seed = int(train_conf.get("debug_infer_seed", args.seed))
    args.debug_infer_tiled = bool(train_conf.get("debug_infer_tiled", True))
    args.debug_infer_use_input_video = bool(train_conf.get("debug_infer_use_input_video", True))
    args.debug_infer_negative_prompt = train_conf.get("debug_infer_negative_prompt", "")
    args.debug_infer_indices = train_conf.get("debug_infer_indices", [0])

    args.zero_face_ratio = train_conf["zero_face_ratio"]
    args.split_rope = train_conf["split_rope"]
    args.split1 = train_conf["split1"]
    args.split2 = train_conf["split2"]
    args.split3 = train_conf["split3"]

    if args.batch_size != 1:
        # Heuristic scaling: lr * (batch_size / 2) * 1.5, capped at 10x the base lr.
        # TODO: for multi-node training, also multiply by the number of machines.
        args.learning_rate = min(
            args.learning_rate * ((args.batch_size / 2) * 1.5),
            args.learning_rate * 10,
        )

    args.height = dataset_conf["height"]
    args.width = dataset_conf["width"]
    args.num_frames = dataset_conf["num_frames"]
    args.ref_num = dataset_conf["ref_num"]


if __name__ == "__main__":
    parser = wan_parser()
    # Use parse_known_args() only: parse_args() exits with
    # "unrecognized arguments" on any unknown flag, which would make the
    # tolerant parse (and the report below) unreachable.
    args, unknown = parser.parse_known_args()
    print("❗ Unknown arguments:", unknown)

    # If diffsynth was installed with `pip install -e .`, changes inside
    # diffsynth reportedly need a reinstall to take effect.
    # NOTE(review): _DIFFSYNTH_ROOT is inserted at the front of sys.path above,
    # so the bundled copy should normally shadow an installed one — confirm.
    _apply_train_yaml(args)

    dataset = MulltiShot_MultiView_Dataset(
        dataset_base_path=args.dataset_base_path,
        resolution=(args.height, args.width),
        ref_num=args.ref_num,
        training=True,
    )
    model = WanTrainingModule(
        model_paths=args.model_paths,
        model_id_with_origin_paths=args.model_id_with_origin_paths,
        trainable_models=args.trainable_models,
        lora_base_model=args.lora_base_model,
        lora_target_modules=args.lora_target_modules,
        lora_rank=args.lora_rank,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        extra_inputs=args.extra_inputs,
        max_timestep_boundary=args.max_timestep_boundary,
        min_timestep_boundary=args.min_timestep_boundary,
        local_model_path=args.local_model_path,
    )
    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)
    launch_training_task(
        args,
        dataset,
        model,
        optimizer,
        scheduler,
        num_epochs=args.num_epochs,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        output_path=args.output_path,
        save_steps=args.save_steps,  # save a checkpoint every N steps
        save_epoches=args.save_epoches,
        max_checkpoints_to_keep=args.max_checkpoints_to_keep,  # keep only the newest N checkpoints
        resume_from_checkpoint=args.resume_from_checkpoint,  # checkpoint to resume from
        seed=args.seed,
        visual_log_project_name=args.visual_log_project_name,
    )