#!/usr/bin/env python3
"""
InstanceV training script.

Built on the DiffSynth-Studio framework; trains the InstanceV modules of WanVideo:
- IMCA (Instance-aware Masked Cross-Attention)
- STAPE (Shared Timestep Adaptive Prompt Enhancement)
- SAUG (Spatially-Aware Unconditional Guidance), training-time dropout

Main references:
- wan_video_dit_instancev.py: model definition
- wan_video_instanceV.py: pipeline definition
- train.py: base training framework
"""

import torch
import os
import sys
import json
import argparse
import accelerate
import types
import warnings
from pathlib import Path
from PIL import Image
from typing import Optional
from datetime import datetime
import numpy as np

# Wandb is optional: the script still runs without it, with logging disabled.
try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False
    print("[Warning] wandb not installed. Run 'pip install wandb' to enable wandb logging.")

# Put the project root (parents[3] of this file's resolved path) on sys.path
# so that the `diffsynth` package below can be imported.
sys.path.insert(0, str(Path(__file__).resolve().parents[3]))

from diffsynth.core import UnifiedDataset, load_state_dict
from diffsynth.core.data.operators import LoadVideo, ToAbsolutePath
from diffsynth.pipelines.wan_video_instanceV import WanVideoPipeline, ModelConfig
from diffsynth.diffusion import (
    DiffusionTrainingModule,
    add_general_config,
    add_video_size_config,
    launch_training_task,
    launch_data_process_task,
    ModelLogger,
)

os.environ["TOKENIZERS_PARALLELISM"] = "false"


class LoadInstanceMasks:
    """
    Dataset operator that loads per-instance mask sequences from disk.

    Expected mask file layout:
        {mask_dir}/{frame_id:06d}_No.{instance_id}.png
    """

    def __init__(
        self,
        num_frames: int,
        time_division_factor: int = 4,
        time_division_remainder: int = 1,
        target_height: Optional[int] = None,
        target_width: Optional[int] = None,
    ):
        """
        Args:
            num_frames: Upper bound on the number of frames loaded per instance.
            time_division_factor: Frame counts are reduced until
                ``count % time_division_factor == time_division_remainder``.
            time_division_remainder: See ``time_division_factor``.
            target_height: Output mask height; masks are left untouched when
                either target dimension is ``None``.
            target_width: Output mask width.
        """
        self.num_frames = num_frames
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        self.target_height = target_height
        self.target_width = target_width

    def _adjust_num_frames(self, num_frames: int) -> int:
        """Clamp to ``self.num_frames`` and decrement until the division rule holds (or 1 is reached)."""
        num_frames = min(int(num_frames), self.num_frames)
        while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
            num_frames -= 1
        return num_frames

    def _crop_and_resize_mask(self, mask: Image.Image) -> Image.Image:
        """
        Scale the mask so it covers the target size, then center-crop to it.

        Nearest-neighbour resampling is used for the resize. The mask is
        returned unchanged when no target size is configured.
        """
        if self.target_height is None or self.target_width is None:
            return mask
        width, height = mask.size
        # Use the larger ratio so both target dimensions are covered before cropping.
        scale = max(self.target_width / width, self.target_height / height)
        new_w = max(1, int(round(width * scale)))
        new_h = max(1, int(round(height * scale)))
        if (new_w, new_h) != mask.size:
            mask = mask.resize((new_w, new_h), resample=Image.NEAREST)
        left = max(0, (new_w - self.target_width) // 2)
        top = max(0, (new_h - self.target_height) // 2)
        right = left + self.target_width
        bottom = top + self.target_height
        return mask.crop((left, top, right, bottom))

    def __call__(self, instance_mask_dirs: list) -> list:
        """
        Load the mask sequence of every instance.

        Args:
            instance_mask_dirs: list of dicts with keys ``"mask_dir"`` (str),
                ``"instance_id"`` (int) and optionally ``"num_frames"`` (int,
                defaults to ``self.num_frames``).

        Returns:
            list of list of PIL.Image in mode "L": ``[Nins][num_frames]``.
            Frames whose mask file does not exist are filled with an all-black
            mask. An instance whose adjusted frame count is 0 yields ``[]``.
        """
        # Size of the placeholder used for missing frames (64x64 when no target is set).
        blank_size = (
            self.target_width if self.target_width is not None else 64,
            self.target_height if self.target_height is not None else 64,
        )

        all_masks = []
        for mask_info in instance_mask_dirs:
            mask_dir = mask_info["mask_dir"]
            inst_id = mask_info["instance_id"]
            raw_num_frames = mask_info.get("num_frames", self.num_frames)
            num_frames = self._adjust_num_frames(raw_num_frames)

            masks = []
            for frame_idx in range(num_frames):
                mask_path = os.path.join(mask_dir, f"{frame_idx:06d}_No.{inst_id}.png")
                if os.path.exists(mask_path):
                    # convert() returns a new image, so the file can be closed right away.
                    with Image.open(mask_path) as mask_file:
                        mask = mask_file.convert("L")
                else:
                    # Missing mask file: substitute an all-black mask.
                    mask = Image.new("L", blank_size, 0)
                masks.append(self._crop_and_resize_mask(mask))

            # Resample down to the target frame count.
            target_frames = (
                (num_frames - self.time_division_remainder)
                // self.time_division_factor
                * self.time_division_factor
                + self.time_division_remainder
            )
            # target_frames can be <= 0 for very short sequences (e.g. num_frames == 0);
            # np.linspace rejects a negative sample count, so skip resampling then.
            if 0 < target_frames < len(masks):
                # Uniform sampling over the loaded frames.
                indices = np.linspace(0, len(masks) - 1, target_frames, dtype=int)
                masks = [masks[i] for i in indices]

            all_masks.append(masks)

        return all_masks


class InstanceVTrainingModule(DiffusionTrainingModule):
    """
    InstanceV training module.

    Extends DiffusionTrainingModule with InstanceV-specific handling: the DiT is
    patched in place with STAPE / IMCA / mv / norm_imca, the rest of the DiT is
    frozen, and the InstanceV inputs are forwarded to the pipeline.
    """

    def __init__(
        self,
        model_paths=None,
        model_id_with_origin_paths=None,
        tokenizer_path=None,
        trainable_models=None,
        lora_base_model=None,
        lora_target_modules="",
        lora_rank=32,
        lora_checkpoint=None,
        preset_lora_path=None,
        preset_lora_model=None,
        use_gradient_checkpointing=True,
        use_gradient_checkpointing_offload=False,
        extra_inputs=None,
        fp8_models=None,
        offload_models=None,
        device="cpu",
        task="sft",
        max_timestep_boundary=1.0,
        min_timestep_boundary=0.0,
        # InstanceV-specific parameters
        saug_drop_prob=0.1,
        saug_scale=0.0,
    ):
        """
        Build the pipeline, patch the DiT with InstanceV modules and freeze the backbone.

        Most arguments are passed straight through to the base-class helpers
        (``parse_model_configs``, ``split_pipeline_units``,
        ``switch_pipe_to_training_mode``). Arguments handled here:

        Args:
            tokenizer_path: Tokenizer location; when ``None`` the tokenizer of
                "Wan-AI/Wan2.1-T2V-1.3B" (pattern "google/umt5-xxl/") is used.
            extra_inputs: Comma-separated names of extra pipeline inputs, or ``None``.
            task: Key into ``self.task_to_loss``.
            saug_drop_prob: Forwarded to the pipeline as ``saug_drop_prob``.
            saug_scale: Forwarded to the pipeline as ``saug_scale``.
        """
        super().__init__()

        # Gradient checkpointing check
        if not use_gradient_checkpointing:
            warnings.warn(
                "Gradient checkpointing is disabled. This may increase VRAM usage and risk OOM."
            )

        # Load models
        model_configs = self.parse_model_configs(
            model_paths,
            model_id_with_origin_paths,
            fp8_models=fp8_models,
            offload_models=offload_models,
            device=device,
        )
        tokenizer_config = (
            ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/")
            if tokenizer_path is None
            else ModelConfig(tokenizer_path)
        )

        # Use the InstanceV pipeline
        self.pipe = WanVideoPipeline.from_pretrained(
            torch_dtype=torch.bfloat16,
            device=device,
            model_configs=model_configs,
            tokenizer_config=tokenizer_config,
            audio_processor_config=None,  # InstanceV does not need audio
        )

        # Upgrade the model: dynamically add the InstanceV modules (IMCA, STAPE, mv)
        self._upgrade_dit_with_instancev()

        self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)

        # Training mode
        self.switch_pipe_to_training_mode(
            self.pipe,
            trainable_models,
            lora_base_model,
            lora_target_modules,
            lora_rank,
            lora_checkpoint,
            preset_lora_path,
            preset_lora_model,
            task=task,
        )

        # InstanceV-specific: freeze the backbone and train only the added
        # modules (IMCA, STAPE, mv). Original note: paper Table 1 reports
        # only 20.65% additional parameters.
        self._freeze_backbone_keep_instancev()

        # Store configuration
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
        self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
        self.fp8_models = fp8_models
        self.task = task
        self.max_timestep_boundary = max_timestep_boundary
        self.min_timestep_boundary = min_timestep_boundary

        # InstanceV-specific parameters
        self.saug_drop_prob = saug_drop_prob
        self.saug_scale = saug_scale

        # Task name -> loss function
        self.task_to_loss = {
            "sft:data_process": lambda pipe, *args: args,
            "sft": self._compute_instancev_loss,
            "sft:train": self._compute_instancev_loss,
        }

    def _upgrade_dit_with_instancev(self):
        """
        Patch the original Wan DiT in place with the InstanceV modules.

        - STAPE (Shared Timestep-Adaptive Prompt Enhancement), one instance
          shared by all blocks
        - IMCA (Instance-aware Masked Cross-Attention) for each block
        - mv (gated residual scalar, zero-initialised) for each block
        - norm_imca (LayerNorm without affine parameters) for each block
        - each block's ``forward`` is replaced by the InstanceV version

        Modules that already exist on the model are kept as they are. Does
        nothing except print a warning when the pipeline has no DiT.
        """
        from diffsynth.models.wan_video_dit_instancev import (
            SharedTimestepAdaptivePromptEnhancement,
            InstanceAwareMaskedCrossAttention,
            modulate,
        )
        import torch.nn as nn

        if self.pipe.dit is None:
            print("Warning: No dit model found in pipeline")
            return

        dit = self.pipe.dit

        # Model dimensions
        dim = dit.dim
        num_heads = dit.blocks[0].self_attn.num_heads if hasattr(dit.blocks[0], 'self_attn') else 12
        eps = 1e-6

        # Set the enable_instancev flag
        dit.enable_instancev = True

        # Add the shared STAPE module
        if not hasattr(dit, 'stape') or dit.stape is None:
            dit_ref = next(dit.parameters())
            dit.stape = SharedTimestepAdaptivePromptEnhancement(
                dim=dim, num_heads=num_heads, eps=eps
            ).to(device=dit_ref.device, dtype=dit_ref.dtype)
            print(f"InstanceV: Added STAPE module (dim={dim}, num_heads={num_heads})")

        # InstanceV version of the block forward method
        def instancev_forward(
            self,
            x: torch.Tensor,
            context: torch.Tensor,
            t_mod: torch.Tensor,
            freqs: torch.Tensor,
            instance_tokens: Optional[torch.Tensor] = None,
            instance_attn_mask: Optional[torch.Tensor] = None,
            empty_instance_tokens: Optional[torch.Tensor] = None,
            saug_drop_prob: float = 0.0,
        ):
            """
            InstanceV version of the DiT block forward.

            Order of operations: self-attention, IMCA (with optional SAUG drop
            and STAPE), cross-attention, FFN. The IMCA branch runs only when
            it is enabled and both ``instance_tokens`` and
            ``instance_attn_mask`` are given.
            """
            # A 4-D t_mod is chunked along dim 2, otherwise along dim 1.
            has_seq = len(t_mod.shape) == 4
            chunk_dim = 2 if has_seq else 1
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
            ).chunk(6, dim=chunk_dim)
            if has_seq:
                shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                    shift_msa.squeeze(2),
                    scale_msa.squeeze(2),
                    gate_msa.squeeze(2),
                    shift_mlp.squeeze(2),
                    scale_mlp.squeeze(2),
                    gate_mlp.squeeze(2),
                )

            # 1) Self-attention
            input_x = modulate(self.norm1(x), shift_msa, scale_msa)
            x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))

            # 2) IMCA + STAPE
            if (
                self.enable_instancev
                and (self.imca is not None)
                and (instance_tokens is not None)
                and (instance_attn_mask is not None)
            ):
                # SAUG training-time drop: with probability saug_p, use the
                # empty instance tokens instead of the real ones.
                if isinstance(saug_drop_prob, torch.Tensor):
                    saug_p = float(saug_drop_prob.detach().cpu().item())
                else:
                    saug_p = float(saug_drop_prob)
                if self.training and saug_p > 0.0 and empty_instance_tokens is not None:
                    if torch.rand((), device=x.device) < saug_p:
                        instance_tokens_use = empty_instance_tokens
                    else:
                        instance_tokens_use = instance_tokens
                else:
                    instance_tokens_use = instance_tokens

                # STAPE
                if self.stape is not None:
                    alpha1 = gate_msa
                    instance_tokens_use = self.stape(instance_tokens_use, context, alpha1=alpha1)

                # IMCA, added through the gated residual mv
                imca_out = self.imca(self.norm_imca(x), instance_tokens_use, instance_attn_mask)
                x = x + self.mv.to(dtype=x.dtype, device=x.device) * imca_out

            # 3) Cross-attention
            x = x + self.cross_attn(self.norm3(x), context)

            # 4) FFN
            input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
            x = self.gate(x, gate_mlp, self.ffn(input_x))
            return x

        # Add IMCA, mv and norm_imca to every DiT block and replace its forward
        for block_idx, block in enumerate(dit.blocks):
            # Reference parameter for the device/dtype of the modules added below.
            block_ref = next(block.parameters())
            device, dtype = block_ref.device, block_ref.dtype

            block.enable_instancev = True
            block.stape = dit.stape  # shared STAPE

            # Add IMCA
            if not hasattr(block, 'imca') or block.imca is None:
                block.imca = InstanceAwareMaskedCrossAttention(
                    dim=dim, num_heads=num_heads, eps=eps
                ).to(device=device, dtype=dtype)

                # Initialise from the cross_attn weights (original note:
                # suggested by the paper). Best effort: a failure is reported
                # and the IMCA keeps its default initialisation.
                if hasattr(block, 'cross_attn'):
                    try:
                        block.imca.attn.q.load_state_dict(block.cross_attn.q.state_dict())
                        block.imca.attn.k.load_state_dict(block.cross_attn.k.state_dict())
                        block.imca.attn.v.load_state_dict(block.cross_attn.v.state_dict())
                        block.imca.attn.o.load_state_dict(block.cross_attn.o.state_dict())
                        block.imca.attn.norm_q.load_state_dict(block.cross_attn.norm_q.state_dict())
                        block.imca.attn.norm_k.load_state_dict(block.cross_attn.norm_k.state_dict())
                    except Exception as e:
                        print(f"Warning: Failed to init IMCA from cross_attn in block {block_idx}: {e}")

            # Add mv (gated residual, initialised to 0)
            if not hasattr(block, 'mv') or block.mv is None:
                block.mv = nn.Parameter(torch.zeros(1, device=device, dtype=dtype))

            # Add norm_imca
            if not hasattr(block, 'norm_imca') or block.norm_imca is None:
                block.norm_imca = nn.LayerNorm(dim, eps=eps, elementwise_affine=False).to(
                    device=device, dtype=dtype
                )

            # Replace the forward method with the InstanceV version
            block.forward = types.MethodType(instancev_forward, block)

        print(f"InstanceV: Added IMCA, mv, norm_imca to {len(dit.blocks)} blocks and replaced forward methods")

    def _freeze_backbone_keep_instancev(self):
        """
        Freeze the DiT and re-enable gradients only for the InstanceV modules.

        - STAPE (Shared Timestep-Adaptive Prompt Enhancement)
        - IMCA (Instance-aware Masked Cross-Attention)
        - mv (gated residual parameter)
        - norm_imca (LayerNorm for IMCA)

        Prints the trainable / total parameter count of the DiT. Original note:
        paper Table 1 reports that InstanceV adds only 20.65% parameters.
        """
        if self.pipe.dit is None:
            return

        dit = self.pipe.dit

        # First freeze the whole DiT
        dit.requires_grad_(False)

        trainable_params = 0

        # Unfreeze STAPE (shared module, counted once)
        if hasattr(dit, 'stape') and dit.stape is not None:
            dit.stape.requires_grad_(True)
            dit.stape.train()
            trainable_params += sum(p.numel() for p in dit.stape.parameters())
            print("InstanceV: Enabled training for dit.stape")

        # Unfreeze IMCA, mv and norm_imca in every block
        for block in dit.blocks:
            # IMCA module
            if hasattr(block, 'imca') and block.imca is not None:
                block.imca.requires_grad_(True)
                block.imca.train()
                trainable_params += sum(p.numel() for p in block.imca.parameters())

            # mv parameter (gated residual)
            if hasattr(block, 'mv') and block.mv is not None:
                block.mv.requires_grad_(True)
                trainable_params += block.mv.numel()

            # norm_imca (LayerNorm)
            if hasattr(block, 'norm_imca') and block.norm_imca is not None:
                block.norm_imca.requires_grad_(True)
                block.norm_imca.train()
                trainable_params += sum(p.numel() for p in block.norm_imca.parameters())

        # Parameter statistics
        total_params = sum(p.numel() for p in dit.parameters())
        print(f"InstanceV: Trainable params: {trainable_params:,} / {total_params:,} "
              f"({100.0 * trainable_params / total_params:.2f}%)")

    def _compute_instancev_loss(self, pipe, inputs_shared, inputs_posi, inputs_nega):
        """
        Compute the InstanceV training loss.

        This is the standard flow-matching SFT loss, called with the shared and
        positive inputs; ``inputs_nega`` is not used. The SAUG training-time
        dropout is applied inside the patched DiT block forward, not here.
        """
        from diffsynth.diffusion.loss import FlowMatchSFTLoss

        return FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi)

    def parse_extra_inputs(self, data, extra_inputs, inputs_shared):
        """
        Copy the requested extra inputs from ``data`` into ``inputs_shared``.

        "input_image" / "end_image" map to the first / last frame of
        ``data["video"]``; any other name is copied verbatim when present in
        ``data`` and ignored otherwise. Returns ``inputs_shared``.
        """
        for extra_input in extra_inputs:
            if extra_input == "input_image":
                inputs_shared["input_image"] = data["video"][0]
            elif extra_input == "end_image":
                inputs_shared["end_image"] = data["video"][-1]
            elif extra_input in data:
                inputs_shared[extra_input] = data[extra_input]
        return inputs_shared

    def get_pipeline_inputs(self, data):
        """
        Build the pipeline inputs for one sample.

        InstanceV-specific fields read from ``data``:
        - instance_prompts: list[str]
        - instance_mask_dirs (or instance_masks): list[list[PIL.Image]]

        Returns:
            Tuple ``(inputs_shared, inputs_posi, inputs_nega)``.
        """
        inputs_posi = {"prompt": data["prompt"]}
        inputs_nega = {}

        # Basic inputs
        inputs_shared = {
            "input_video": data["video"],
            "height": data["video"][0].size[1],
            "width": data["video"][0].size[0],
            "num_frames": len(data["video"]),
            "cfg_scale": 1,
            "tiled": False,
            "rand_device": self.pipe.device,
            "use_gradient_checkpointing": self.use_gradient_checkpointing,
            "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
            "cfg_merge": False,
            "vace_scale": 1,
            "max_timestep_boundary": self.max_timestep_boundary,
            "min_timestep_boundary": self.min_timestep_boundary,
        }

        # InstanceV-specific inputs
        if "instance_prompts" in data:
            inputs_shared["instance_prompts"] = data["instance_prompts"]

        # Note: the metadata uses "instance_mask_dirs", while the pipeline
        # expects "instance_masks".
        if "instance_mask_dirs" in data:
            inputs_shared["instance_masks"] = data["instance_mask_dirs"]
        elif "instance_masks" in data:
            inputs_shared["instance_masks"] = data["instance_masks"]

        # SAUG parameters
        inputs_shared["saug_drop_prob"] = self.saug_drop_prob
        inputs_shared["saug_scale"] = self.saug_scale

        # Extra inputs
        inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)

        return inputs_shared, inputs_posi, inputs_nega

    def forward(self, data, inputs=None):
        """
        Run the pipeline units on one sample and return the task's loss.

        ``inputs`` may carry pre-built pipeline inputs; when ``None`` they are
        built from ``data``. Raises ``KeyError`` when ``self.task`` is not a
        key of ``self.task_to_loss``.
        """
        if inputs is None:
            inputs = self.get_pipeline_inputs(data)
        inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
        for unit in self.pipe.units:
            inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
        loss = self.task_to_loss[self.task](self.pipe, *inputs)
        return loss


def instancev_parser():
    """Build the command-line parser: general and video-size options plus the InstanceV options."""
    parser = argparse.ArgumentParser(description="InstanceV Training Script")
    parser = add_general_config(parser)
    parser = add_video_size_config(parser)

    # Tokenizer
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default=None,
        help="Path to tokenizer.",
    )

    # InstanceV-specific parameters
    parser.add_argument(
        "--saug_drop_prob",
        type=float,
        default=0.1,
        help="SAUG training-time dropout probability (paper recommends 0.1).",
    )
    parser.add_argument(
        "--saug_scale",
        type=float,
        default=0.0,
        help="SAUG unconditional guidance scale (training time, usually 0).",
    )

    # Timestep boundaries
    parser.add_argument(
        "--max_timestep_boundary",
        type=float,
        default=1.0,
        help="Max timestep boundary.",
    )
    parser.add_argument(
        "--min_timestep_boundary",
        type=float,
        default=0.0,
        help="Min timestep boundary.",
    )

    # Model initialisation
    parser.add_argument(
        "--initialize_model_on_cpu",
        default=False,
        action="store_true",
        help="Whether to initialize models on CPU.",
    )

    # Wandb parameters
    parser.add_argument(
        "--use_wandb",
        default=False,
        action="store_true",
        help="Enable wandb logging.",
    )
    parser.add_argument(
        "--wandb_project",
        type=str,
        default="instancev-training",
        help="Wandb project name.",
    )
    parser.add_argument(
        "--wandb_run_name",
        type=str,
        default=None,
        help="Wandb run name. Default: auto-generated with timestamp.",
    )
    parser.add_argument(
        "--wandb_log_every",
        type=int,
        default=10,
        help="Log to wandb every N steps.",
    )

    # Resume parameters
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="Path to a trainable checkpoint (.safetensors) to resume from.",
    )

    return parser


def _resolve_resume_target(model, remove_prefix_in_ckpt):
    """
    Return the sub-module of ``model`` addressed by the dotted checkpoint prefix.

    Returns ``model`` itself when the prefix is empty, and ``None`` when any
    component of the prefix is not an attribute along the path.
    """
    if not remove_prefix_in_ckpt:
        return model
    prefix = remove_prefix_in_ckpt.rstrip(".")
    if not prefix:
        return model
    target = model
    for part in prefix.split("."):
        if not hasattr(target, part):
            return None
        target = getattr(target, part)
    return target


def _load_resume_checkpoint(model, checkpoint_path, remove_prefix_in_ckpt, accelerator):
    """
    Load checkpoint weights (non-strict) into the module addressed by the prefix.

    Does nothing when ``checkpoint_path`` is ``None``. Falls back to loading
    into the full model when the prefix does not resolve. Only the main process
    prints the summary.

    Raises:
        FileNotFoundError: ``checkpoint_path`` does not exist.
    """
    if checkpoint_path is None:
        return
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Resume checkpoint not found: {checkpoint_path}")

    state_dict = load_state_dict(checkpoint_path)
    target = _resolve_resume_target(model, remove_prefix_in_ckpt)
    if target is None:
        target = model
        if accelerator.is_main_process:
            print(
                f"[Warning] remove_prefix_in_ckpt='{remove_prefix_in_ckpt}' does not map to a model attribute. "
                "Fallback to loading into the full model."
            )

    load_result = target.load_state_dict(state_dict, strict=False)
    if accelerator.is_main_process:
        missing_keys, unexpected_keys = load_result
        print(f"[Resume] Loaded checkpoint: {checkpoint_path}")
        print(f"[Resume] Loaded keys: {len(state_dict)}")
        if missing_keys:
            print(f"[Resume] Missing keys: {len(missing_keys)}")
        if unexpected_keys:
            print(f"[Resume] Unexpected keys: {len(unexpected_keys)}")


def print_training_info(args, accelerator):
    """Print the training configuration (main process only)."""
    if not accelerator.is_main_process:
        return

    print("\n" + "="*60)
    print(" InstanceV Training Configuration")
    print("="*60)

    print(f"\n[时间] {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    print("\n[数据配置]")
    print(f" - 数据路径: {args.dataset_base_path}")
    print(f" - Metadata: {args.dataset_metadata_path}")
    print(f" - 数据集重复: {args.dataset_repeat}x")

    print("\n[模型配置]")
    print(f" - 模型来源: {args.model_id_with_origin_paths or args.model_paths}")
    print(" - 可训练模块: STAPE, IMCA, mv, norm_imca")
    print(f" - 输出路径: {args.output_path}")

    print("\n[训练参数]")
    print(f" - 分辨率: {args.width}x{args.height}")
    print(f" - 帧数: {args.num_frames}")
    print(f" - 学习率: {args.learning_rate}")
    print(f" - Epochs: {args.num_epochs}")
    print(f" - 梯度累积: {args.gradient_accumulation_steps}")
    print(f" - 保存间隔: {args.save_steps} steps")
    if args.resume_from_checkpoint:
        print(f" - 断点续跑: {args.resume_from_checkpoint}")

    print("\n[InstanceV 参数]")
    print(f" - SAUG Dropout: {args.saug_drop_prob}")
    print(f" - SAUG Scale: {args.saug_scale}")
    print(f" - Timestep 边界: [{args.min_timestep_boundary}, {args.max_timestep_boundary}]")

    if args.use_wandb:
        print("\n[Wandb]")
        print(f" - Project: {args.wandb_project}")
        print(f" - Run Name: {args.wandb_run_name}")
        print(f" - Log Every: {args.wandb_log_every} steps")

    print("\n[GPU/分布式]")
    print(f" - Accelerator: {accelerator.device}")
    print(f" - 进程数: {accelerator.num_processes}")
    print(f" - 混合精度: {accelerator.mixed_precision}")

    print("\n" + "="*60 + "\n")


def main():
    """Script entry point: parse arguments, build dataset and model, launch the selected task."""
    parser = instancev_parser()
    args = parser.parse_args()

    # Task name -> launcher. Checked up front so an unsupported task is
    # rejected before the dataset and model are built.
    launcher_map = {
        "sft:data_process": launch_data_process_task,
        "sft": launch_training_task,
        "sft:train": launch_training_task,
    }
    if args.task not in launcher_map:
        parser.error(
            f"Unsupported task '{args.task}'. Supported tasks: {', '.join(launcher_map)}"
        )

    # Initialise the Accelerator
    accelerator = accelerate.Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        kwargs_handlers=[
            accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)
        ],
    )

    # Set wandb_mode (original note: required by runner.py)
    if args.use_wandb and WANDB_AVAILABLE:
        args.wandb_mode = "online"
        # Generate the default run name
        if args.wandb_run_name is None:
            args.wandb_run_name = f"instancev_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

        # Initialise wandb on the main process only
        if accelerator.is_main_process:
            wandb.init(
                project=args.wandb_project,
                name=args.wandb_run_name,
                config={
                    "learning_rate": args.learning_rate,
                    "num_epochs": args.num_epochs,
                    "gradient_accumulation_steps": args.gradient_accumulation_steps,
                    "height": args.height,
                    "width": args.width,
                    "num_frames": args.num_frames,
                    "saug_drop_prob": args.saug_drop_prob,
                    "saug_scale": args.saug_scale,
                    "trainable_models": args.trainable_models,
                    "model": "Wan2.1-T2V-1.3B + InstanceV",
                },
            )
            print(f"[Wandb] 初始化完成: {args.wandb_project}/{args.wandb_run_name}")
    else:
        args.wandb_mode = "disabled"

    # Print the training configuration
    print_training_info(args, accelerator)

    # Build the dataset.
    # Original note: data_file_keys must list every key that needs processing
    # (UnifiedDataset only processes these keys):
    # - video: handled by main_data_operator (the video loader)
    # - instance_mask_dirs: handled by LoadInstanceMasks via special_operator_map
    # - instance_prompts: handled by the pass-through lambda in special_operator_map
    dataset = UnifiedDataset(
        base_path=args.dataset_base_path,
        metadata_path=args.dataset_metadata_path,
        repeat=args.dataset_repeat,
        data_file_keys=["video", "instance_mask_dirs", "instance_prompts"],
        main_data_operator=UnifiedDataset.default_video_operator(
            base_path=args.dataset_base_path,
            max_pixels=args.max_pixels,
            height=args.height,
            width=args.width,
            height_division_factor=16,
            width_division_factor=16,
            num_frames=args.num_frames,
            time_division_factor=4,
            time_division_remainder=1,
        ),
        special_operator_map={
            # InstanceV-specific: load instance masks
            "instance_mask_dirs": LoadInstanceMasks(
                num_frames=args.num_frames,
                time_division_factor=4,
                time_division_remainder=1,
                target_height=args.height,
                target_width=args.width,
            ),
            # instance_prompts is passed through unchanged (list of strings)
            "instance_prompts": lambda x: x,
        },
    )

    # Build the training module
    model = InstanceVTrainingModule(
        model_paths=args.model_paths,
        model_id_with_origin_paths=args.model_id_with_origin_paths,
        tokenizer_path=args.tokenizer_path,
        trainable_models=args.trainable_models,
        lora_base_model=args.lora_base_model,
        lora_target_modules=args.lora_target_modules,
        lora_rank=args.lora_rank,
        lora_checkpoint=args.lora_checkpoint,
        preset_lora_path=args.preset_lora_path,
        preset_lora_model=args.preset_lora_model,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        extra_inputs=args.extra_inputs,
        fp8_models=args.fp8_models,
        offload_models=args.offload_models,
        task=args.task,
        device="cpu" if args.initialize_model_on_cpu else accelerator.device,
        max_timestep_boundary=args.max_timestep_boundary,
        min_timestep_boundary=args.min_timestep_boundary,
        saug_drop_prob=args.saug_drop_prob,
        saug_scale=args.saug_scale,
    )

    # Model logger
    model_logger = ModelLogger(
        args.output_path,
        remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
    )

    # Resume (loads weights only)
    if args.resume_from_checkpoint and not args.task.endswith(":data_process"):
        _load_resume_checkpoint(
            model,
            args.resume_from_checkpoint,
            args.remove_prefix_in_ckpt,
            accelerator,
        )

    # Launch the task
    if accelerator.is_main_process:
        print(f"\n[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] 开始训练...\n")

    launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)

    # Training finished
    if accelerator.is_main_process:
        print(f"\n[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] 训练完成!")
        print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Checkpoints 保存至: {args.output_path}")

        # Close wandb
        if args.use_wandb and WANDB_AVAILABLE:
            wandb.finish()
            print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Wandb 已关闭")


if __name__ == "__main__":
    main()