PencilFolder / examples /wanvideo /model_training /train_instancev.py
PencilHu's picture
Upload folder using huggingface_hub
1146a67 verified
Raw
History Blame Contribute Delete
31.7 kB
#!/usr/bin/env python3
"""
InstanceV 训练脚本
基于 DiffSynth-Studio 框架,训练 WanVideo 的 InstanceV 模块:
- IMCA (Instance-aware Masked Cross-Attention)
- STAPE (Shared Timestep Adaptive Prompt Enhancement)
- SAUG (Spatially-Aware Unconditional Guidance) 训练时的 dropout
主要参考:
- wan_video_dit_instancev.py: 模型定义
- wan_video_instanceV.py: Pipeline 定义
- train.py: 基础训练框架
"""
import torch
import os
import sys
import json
import argparse
import accelerate
import warnings
from pathlib import Path
from PIL import Image
from typing import Optional
from datetime import datetime
import numpy as np
# Wandb (可选)
try:
import wandb
WANDB_AVAILABLE = True
except ImportError:
WANDB_AVAILABLE = False
print("[Warning] wandb not installed. Run 'pip install wandb' to enable wandb logging.")
# 添加项目根目录到 path
sys.path.insert(0, str(Path(__file__).resolve().parents[3]))
from diffsynth.core import UnifiedDataset, load_state_dict
from diffsynth.core.data.operators import LoadVideo, ToAbsolutePath
from diffsynth.pipelines.wan_video_instanceV import WanVideoPipeline, ModelConfig
from diffsynth.diffusion import (
DiffusionTrainingModule,
add_general_config,
add_video_size_config,
launch_training_task,
launch_data_process_task,
ModelLogger,
)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class LoadInstanceMasks:
    """
    Data operator that loads one mask sequence per instance.

    Mask directory layout:
        {mask_dir}/{frame_id:06d}_No.{instance_id}.png

    Frames whose mask file does not exist are replaced by an all-black mask.
    """

    def __init__(
        self,
        num_frames: int,
        time_division_factor: int = 4,
        time_division_remainder: int = 1,
        target_height: Optional[int] = None,
        target_width: Optional[int] = None,
    ):
        """
        Args:
            num_frames: upper bound on the number of frames loaded per instance.
            time_division_factor: the frame count is reduced until
                ``count % time_division_factor == time_division_remainder``.
            time_division_remainder: see ``time_division_factor``.
            target_height: output mask height; masks keep their own size when
                either target dimension is None.
            target_width: output mask width; see ``target_height``.
        """
        self.num_frames = num_frames
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        self.target_height = target_height
        self.target_width = target_width

    def _adjust_num_frames(self, num_frames: int) -> int:
        """Clamp to ``self.num_frames`` and step down to a valid frame count."""
        num_frames = min(int(num_frames), self.num_frames)
        # Stop at 1 so the loop always terminates, even if no larger count
        # satisfies the factor/remainder constraint.
        while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
            num_frames -= 1
        return num_frames

    def _crop_and_resize_mask(self, mask: "Image.Image") -> "Image.Image":
        """Scale the mask to cover the target size, then center-crop to it.

        Returns the mask unchanged when no target size is configured.
        """
        if self.target_height is None or self.target_width is None:
            return mask
        width, height = mask.size
        # "Cover" scaling: the larger ratio guarantees both sides reach the target.
        scale = max(self.target_width / width, self.target_height / height)
        new_w = max(1, int(round(width * scale)))
        new_h = max(1, int(round(height * scale)))
        if (new_w, new_h) != mask.size:
            # Nearest-neighbour resampling introduces no intermediate gray levels.
            mask = mask.resize((new_w, new_h), resample=Image.NEAREST)
        left = max(0, (new_w - self.target_width) // 2)
        top = max(0, (new_h - self.target_height) // 2)
        right = left + self.target_width
        bottom = top + self.target_height
        return mask.crop((left, top, right, bottom))

    def _load_mask(self, mask_path: str, blank_size: tuple) -> "Image.Image":
        """Load one grayscale mask (or a black placeholder) at the output size."""
        if os.path.exists(mask_path):
            # The context manager releases the file handle deterministically;
            # convert() returns an independent, fully loaded image.
            with Image.open(mask_path) as source:
                mask = source.convert("L")
        else:
            # Missing mask file: use an all-black mask instead.
            mask = Image.new("L", blank_size, 0)
        return self._crop_and_resize_mask(mask)

    def __call__(self, instance_mask_dirs: list) -> list:
        """
        Args:
            instance_mask_dirs: list of {"mask_dir": str, "instance_id": int, "num_frames": int}
        Returns:
            list of list of PIL.Image: [Nins][num_frames]. An instance whose
            frame count is not positive yields an empty inner list.
        """
        # Placeholder size for missing frames; falls back to 64x64 when no
        # target size is configured. Independent of the instance, so computed once.
        blank_size = (
            self.target_width if self.target_width is not None else 64,
            self.target_height if self.target_height is not None else 64,
        )
        all_masks = []
        for mask_info in instance_mask_dirs:
            mask_dir = mask_info["mask_dir"]
            inst_id = mask_info["instance_id"]
            num_frames = self._adjust_num_frames(mask_info.get("num_frames", self.num_frames))
            masks = [
                self._load_mask(
                    os.path.join(mask_dir, f"{frame_idx:06d}_No.{inst_id}.png"),
                    blank_size,
                )
                for frame_idx in range(num_frames)
            ]
            # Subsample down to the target number of frames.
            target_frames = (
                (num_frames - self.time_division_remainder)
                // self.time_division_factor
                * self.time_division_factor
                + self.time_division_remainder
            )
            # A non-positive frame count makes target_frames negative, which
            # np.linspace rejects with a ValueError; skip subsampling then.
            if 0 <= target_frames < len(masks):
                # Uniform sampling.
                indices = np.linspace(0, len(masks) - 1, target_frames, dtype=int)
                masks = [masks[i] for i in indices]
            all_masks.append(masks)
        return all_masks
class InstanceVTrainingModule(DiffusionTrainingModule):
    """
    Training module for InstanceV.

    Subclasses DiffusionTrainingModule and adds the InstanceV-specific steps:
    attaching the InstanceV modules to the DiT, freezing the backbone, and
    passing instance prompts / masks and the SAUG settings to the pipeline.
    """

    def __init__(
        self,
        model_paths=None,
        model_id_with_origin_paths=None,
        tokenizer_path=None,
        trainable_models=None,
        lora_base_model=None,
        lora_target_modules="",
        lora_rank=32,
        lora_checkpoint=None,
        preset_lora_path=None,
        preset_lora_model=None,
        use_gradient_checkpointing=True,
        use_gradient_checkpointing_offload=False,
        extra_inputs=None,
        fp8_models=None,
        offload_models=None,
        device="cpu",
        task="sft",
        max_timestep_boundary=1.0,
        min_timestep_boundary=0.0,
        # InstanceV-specific parameters
        saug_drop_prob=0.1,
        saug_scale=0.0,
    ):
        """
        Load the pipeline, attach the InstanceV modules and freeze the backbone.

        The order below matters: the DiT is upgraded before the pipeline units
        are split and before training mode is configured, and the backbone is
        frozen last so it overrides whatever requires_grad state was set earlier
        on the DiT.

        ``extra_inputs`` is a comma-separated string of extra input names;
        ``saug_drop_prob`` and ``saug_scale`` are stored and forwarded to the
        pipeline inputs unchanged.
        """
        super().__init__()
        # Gradient checkpointing check
        if not use_gradient_checkpointing:
            warnings.warn(
                "Gradient checkpointing is disabled. This may increase VRAM usage and risk OOM."
            )
        # Load models
        model_configs = self.parse_model_configs(
            model_paths, model_id_with_origin_paths,
            fp8_models=fp8_models, offload_models=offload_models, device=device
        )
        # Fall back to the tokenizer of Wan2.1-T2V-1.3B when no path is given.
        tokenizer_config = (
            ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/")
            if tokenizer_path is None
            else ModelConfig(tokenizer_path)
        )
        # Use the InstanceV pipeline
        self.pipe = WanVideoPipeline.from_pretrained(
            torch_dtype=torch.bfloat16,
            device=device,
            model_configs=model_configs,
            tokenizer_config=tokenizer_config,
            audio_processor_config=None,  # InstanceV does not need audio
        )
        # Upgrade the model: dynamically add the InstanceV modules (IMCA, STAPE, mv)
        self._upgrade_dit_with_instancev()
        self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
        # Training mode
        self.switch_pipe_to_training_mode(
            self.pipe, trainable_models,
            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
            preset_lora_path, preset_lora_model,
            task=task,
        )
        # InstanceV-specific: freeze the backbone, train only the added modules (IMCA, STAPE, mv)
        # Paper Table 1: only 20.65% additional parameters
        self._freeze_backbone_keep_instancev()
        # Store configuration
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
        self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
        self.fp8_models = fp8_models
        self.task = task
        self.max_timestep_boundary = max_timestep_boundary
        self.min_timestep_boundary = min_timestep_boundary
        # InstanceV-specific parameters
        self.saug_drop_prob = saug_drop_prob
        self.saug_scale = saug_scale
        # Task name -> loss function; the data-processing task just returns
        # the processed inputs instead of a loss.
        self.task_to_loss = {
            "sft:data_process": lambda pipe, *args: args,
            "sft": self._compute_instancev_loss,
            "sft:train": self._compute_instancev_loss,
        }

    def _upgrade_dit_with_instancev(self):
        """
        Dynamically upgrade the original Wan model with the InstanceV modules:
        - STAPE (Shared Timestep-Adaptive Prompt Enhancement)
        - IMCA (Instance-aware Masked Cross-Attention) for each block
        - mv (gated residual) for each block
        - norm_imca (LayerNorm) for each block
        - replace each block's forward method with the InstanceV version

        Does nothing (apart from printing a warning) when the pipeline has no dit.
        """
        from diffsynth.models.wan_video_dit_instancev import (
            SharedTimestepAdaptivePromptEnhancement,
            InstanceAwareMaskedCrossAttention,
            DiTBlock as InstanceVDiTBlock,
            modulate,
        )
        import torch.nn as nn
        from typing import Optional
        if self.pipe.dit is None:
            print("Warning: No dit model found in pipeline")
            return
        dit = self.pipe.dit
        # Read the model dimensions; falls back to 12 heads when the first
        # block exposes no self_attn attribute.
        dim = dit.dim
        num_heads = dit.blocks[0].self_attn.num_heads if hasattr(dit.blocks[0], 'self_attn') else 12
        eps = 1e-6
        # Set the enable_instancev flag
        dit.enable_instancev = True
        # Add the shared STAPE module, created on the device/dtype of the
        # DiT's first parameter.
        if not hasattr(dit, 'stape') or dit.stape is None:
            dit.stape = SharedTimestepAdaptivePromptEnhancement(
                dim=dim, num_heads=num_heads, eps=eps
            ).to(device=next(dit.parameters()).device, dtype=next(dit.parameters()).dtype)
            print(f"InstanceV: Added STAPE module (dim={dim}, num_heads={num_heads})")

        # Define the InstanceV version of the forward method. It is bound to
        # each block below, so `self` is the block, not this training module.
        def instancev_forward(
            self,
            x: torch.Tensor,
            context: torch.Tensor,
            t_mod: torch.Tensor,
            freqs: torch.Tensor,
            instance_tokens: Optional[torch.Tensor] = None,
            instance_attn_mask: Optional[torch.Tensor] = None,
            empty_instance_tokens: Optional[torch.Tensor] = None,
            saug_drop_prob: float = 0.0,
        ):
            """InstanceV version of the DiTBlock forward."""
            # A 4-D t_mod carries an extra sequence axis, so the six modulation
            # terms are chunked along dim 2 instead of dim 1 and squeezed after.
            has_seq = len(t_mod.shape) == 4
            chunk_dim = 2 if has_seq else 1
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
            ).chunk(6, dim=chunk_dim)
            if has_seq:
                shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                    shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2),
                    shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2),
                )
            # 1) Self-attention
            input_x = modulate(self.norm1(x), shift_msa, scale_msa)
            x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
            # 2) IMCA + STAPE (skipped unless both instance tokens and mask are given)
            if self.enable_instancev and (self.imca is not None) and (instance_tokens is not None) and (instance_attn_mask is not None):
                # SAUG training-time drop
                if isinstance(saug_drop_prob, torch.Tensor):
                    saug_p = float(saug_drop_prob.detach().cpu().item())
                else:
                    saug_p = float(saug_drop_prob)
                # NOTE(review): this tests the block's own `training` flag, while
                # _freeze_backbone_keep_instancev only calls .train() on stape /
                # imca / norm_imca; whether the drop is active depends on how the
                # block's mode was set elsewhere -- confirm.
                if self.training and saug_p > 0.0 and empty_instance_tokens is not None:
                    # One random draw per block call: with probability saug_p
                    # the instance tokens are replaced by the empty ones.
                    if torch.rand((), device=x.device) < saug_p:
                        instance_tokens_use = empty_instance_tokens
                    else:
                        instance_tokens_use = instance_tokens
                else:
                    instance_tokens_use = instance_tokens
                # STAPE
                if self.stape is not None:
                    # The self-attention gate is passed to STAPE as alpha1.
                    alpha1 = gate_msa
                    instance_tokens_use = self.stape(instance_tokens_use, context, alpha1=alpha1)
                # IMCA
                imca_out = self.imca(self.norm_imca(x), instance_tokens_use, instance_attn_mask)
                # Gated residual; mv starts at zero (see its creation below), so
                # this branch initially leaves x unchanged.
                x = x + self.mv.to(dtype=x.dtype, device=x.device) * imca_out
            # 3) Cross-attention
            x = x + self.cross_attn(self.norm3(x), context)
            # 4) FFN
            input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
            x = self.gate(x, gate_mlp, self.ffn(input_x))
            return x

        # Add IMCA, mv and norm_imca to every DiT block and replace its forward method
        for block_idx, block in enumerate(dit.blocks):
            block.enable_instancev = True
            block.stape = dit.stape  # shared STAPE
            # Add IMCA
            if not hasattr(block, 'imca') or block.imca is None:
                block.imca = InstanceAwareMaskedCrossAttention(
                    dim=dim, num_heads=num_heads, eps=eps
                ).to(device=next(block.parameters()).device, dtype=next(block.parameters()).dtype)
                # Initialize the weights from cross_attn (suggested by the paper).
                # NOTE(review): assumed to run only when IMCA was just created,
                # so existing IMCA weights are never overwritten -- confirm.
                if hasattr(block, 'cross_attn'):
                    try:
                        block.imca.attn.q.load_state_dict(block.cross_attn.q.state_dict())
                        block.imca.attn.k.load_state_dict(block.cross_attn.k.state_dict())
                        block.imca.attn.v.load_state_dict(block.cross_attn.v.state_dict())
                        block.imca.attn.o.load_state_dict(block.cross_attn.o.state_dict())
                        block.imca.attn.norm_q.load_state_dict(block.cross_attn.norm_q.state_dict())
                        block.imca.attn.norm_k.load_state_dict(block.cross_attn.norm_k.state_dict())
                    except Exception as e:
                        # Best effort: on failure IMCA keeps its default initialization.
                        print(f"Warning: Failed to init IMCA from cross_attn in block {block_idx}: {e}")
            # Add mv (gated residual, initialized to 0)
            if not hasattr(block, 'mv') or block.mv is None:
                block.mv = nn.Parameter(torch.zeros(1, device=next(block.parameters()).device,
                                                    dtype=next(block.parameters()).dtype))
            # Add norm_imca (LayerNorm without learnable affine parameters)
            if not hasattr(block, 'norm_imca') or block.norm_imca is None:
                block.norm_imca = nn.LayerNorm(dim, eps=eps, elementwise_affine=False).to(
                    device=next(block.parameters()).device, dtype=next(block.parameters()).dtype
                )
            # Replace the forward method with the InstanceV version
            import types
            block.forward = types.MethodType(instancev_forward, block)
        print(f"InstanceV: Added IMCA, mv, norm_imca to {len(dit.blocks)} blocks and replaced forward methods")

    def _freeze_backbone_keep_instancev(self):
        """
        Freeze the original backbone and keep only the InstanceV modules trainable:
        - STAPE (Shared Timestep-Adaptive Prompt Enhancement)
        - IMCA (Instance-aware Masked Cross-Attention)
        - mv (gated residual parameter)
        - norm_imca (LayerNorm for IMCA)
        Paper Table 1: InstanceV adds only 20.65% parameters.

        Prints the trainable / total parameter counts of the DiT. Does nothing
        when the pipeline has no dit.
        """
        if self.pipe.dit is None:
            return
        dit = self.pipe.dit
        # First freeze the whole dit
        dit.requires_grad_(False)
        # Unfreeze STAPE (shared module)
        if hasattr(dit, 'stape') and dit.stape is not None:
            dit.stape.requires_grad_(True)
            dit.stape.train()
            print("InstanceV: Enabled training for dit.stape")
        # Unfreeze IMCA, mv and norm_imca in every block
        trainable_params = 0
        total_params = 0
        for block_idx, block in enumerate(dit.blocks):
            # IMCA module
            if hasattr(block, 'imca') and block.imca is not None:
                block.imca.requires_grad_(True)
                block.imca.train()
                trainable_params += sum(p.numel() for p in block.imca.parameters())
            # mv parameter (gated residual)
            if hasattr(block, 'mv') and block.mv is not None:
                block.mv.requires_grad_(True)
                trainable_params += block.mv.numel()
            # norm_imca (LayerNorm)
            if hasattr(block, 'norm_imca') and block.norm_imca is not None:
                block.norm_imca.requires_grad_(True)
                block.norm_imca.train()
                trainable_params += sum(p.numel() for p in block.norm_imca.parameters())
        # Parameter statistics; STAPE is shared, so it is counted once here
        # rather than once per block.
        total_params = sum(p.numel() for p in dit.parameters())
        if hasattr(dit, 'stape') and dit.stape is not None:
            trainable_params += sum(p.numel() for p in dit.stape.parameters())
        print(f"InstanceV: Trainable params: {trainable_params:,} / {total_params:,} "
              f"({100.0 * trainable_params / total_params:.2f}%)")

    def _compute_instancev_loss(self, pipe, inputs_shared, inputs_posi, inputs_nega):
        """
        InstanceV loss computation.

        Consists of:
        1. the standard Flow Matching SFT loss
        2. SAUG training-time dropout (handled inside the DiT blocks)

        ``inputs_nega`` is accepted for signature compatibility but not used.
        """
        from diffsynth.diffusion.loss import FlowMatchSFTLoss
        return FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi)

    def parse_extra_inputs(self, data, extra_inputs, inputs_shared):
        """Copy the requested extra inputs from ``data`` into ``inputs_shared``.

        "input_image" / "end_image" map to the first / last video frame; any
        other name is copied only if it is a key of ``data``. Returns the same
        ``inputs_shared`` dict, mutated in place.
        """
        for extra_input in extra_inputs:
            if extra_input == "input_image":
                inputs_shared["input_image"] = data["video"][0]
            elif extra_input == "end_image":
                inputs_shared["end_image"] = data["video"][-1]
            elif extra_input in data:
                inputs_shared[extra_input] = data[extra_input]
        return inputs_shared

    def get_pipeline_inputs(self, data):
        """
        Build the pipeline inputs.

        InstanceV-specific fields:
        - instance_prompts: list[str]
        - instance_masks: list[list[PIL.Image]]

        Returns:
            (inputs_shared, inputs_posi, inputs_nega); inputs_nega is empty.
        """
        inputs_posi = {"prompt": data["prompt"]}
        inputs_nega = {}
        # Base inputs; height / width / num_frames are read from the video itself.
        inputs_shared = {
            "input_video": data["video"],
            "height": data["video"][0].size[1],
            "width": data["video"][0].size[0],
            "num_frames": len(data["video"]),
            "cfg_scale": 1,
            "tiled": False,
            "rand_device": self.pipe.device,
            "use_gradient_checkpointing": self.use_gradient_checkpointing,
            "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
            "cfg_merge": False,
            "vace_scale": 1,
            "max_timestep_boundary": self.max_timestep_boundary,
            "min_timestep_boundary": self.min_timestep_boundary,
        }
        # InstanceV-specific inputs
        if "instance_prompts" in data:
            inputs_shared["instance_prompts"] = data["instance_prompts"]
        # Note: the metadata uses "instance_mask_dirs", the pipeline expects "instance_masks"
        if "instance_mask_dirs" in data:
            inputs_shared["instance_masks"] = data["instance_mask_dirs"]
        elif "instance_masks" in data:
            inputs_shared["instance_masks"] = data["instance_masks"]
        # SAUG parameters
        inputs_shared["saug_drop_prob"] = self.saug_drop_prob
        inputs_shared["saug_scale"] = self.saug_scale
        # Handle extra inputs
        inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
        return inputs_shared, inputs_posi, inputs_nega

    def forward(self, data, inputs=None):
        """Forward pass: build the inputs if needed, run every pipeline unit,
        then apply the loss function registered for the current task."""
        if inputs is None:
            inputs = self.get_pipeline_inputs(data)
        inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
        for unit in self.pipe.units:
            inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
        loss = self.task_to_loss[self.task](self.pipe, *inputs)
        return loss
def instancev_parser():
    """Build the command-line parser for InstanceV training.

    Starts from the general and video-size options provided by the framework,
    then registers the script-specific options in a fixed order.
    """
    parser = argparse.ArgumentParser(description="InstanceV Training Script")
    parser = add_general_config(parser)
    parser = add_video_size_config(parser)

    # (flag, add_argument keyword arguments), in registration order.
    option_table = (
        # Tokenizer
        ("--tokenizer_path", dict(type=str, default=None, help="Path to tokenizer.")),
        # InstanceV-specific options
        ("--saug_drop_prob", dict(
            type=float,
            default=0.1,
            help="SAUG training-time dropout probability (paper recommends 0.1).",
        )),
        ("--saug_scale", dict(
            type=float,
            default=0.0,
            help="SAUG unconditional guidance scale (training time, usually 0).",
        )),
        # Timestep boundaries
        ("--max_timestep_boundary", dict(type=float, default=1.0, help="Max timestep boundary.")),
        ("--min_timestep_boundary", dict(type=float, default=0.0, help="Min timestep boundary.")),
        # Model initialization
        ("--initialize_model_on_cpu", dict(
            default=False,
            action="store_true",
            help="Whether to initialize models on CPU.",
        )),
        # Wandb options
        ("--use_wandb", dict(default=False, action="store_true", help="Enable wandb logging.")),
        ("--wandb_project", dict(type=str, default="instancev-training", help="Wandb project name.")),
        ("--wandb_run_name", dict(
            type=str,
            default=None,
            help="Wandb run name. Default: auto-generated with timestamp.",
        )),
        ("--wandb_log_every", dict(type=int, default=10, help="Log to wandb every N steps.")),
        # Resuming from a checkpoint
        ("--resume_from_checkpoint", dict(
            type=str,
            default=None,
            help="Path to a trainable checkpoint (.safetensors) to resume from.",
        )),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser
def _resolve_resume_target(model, remove_prefix_in_ckpt):
if not remove_prefix_in_ckpt:
return model
prefix = remove_prefix_in_ckpt.rstrip(".")
if not prefix:
return model
target = model
for part in prefix.split("."):
if not hasattr(target, part):
return None
target = getattr(target, part)
return target
def _load_resume_checkpoint(model, checkpoint_path, remove_prefix_in_ckpt, accelerator):
if checkpoint_path is None:
return
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"Resume checkpoint not found: {checkpoint_path}")
state_dict = load_state_dict(checkpoint_path)
target = _resolve_resume_target(model, remove_prefix_in_ckpt)
if target is None:
target = model
if accelerator.is_main_process:
print(
f"[Warning] remove_prefix_in_ckpt='{remove_prefix_in_ckpt}' does not map to a model attribute. "
"Fallback to loading into the full model."
)
load_result = target.load_state_dict(state_dict, strict=False)
if accelerator.is_main_process:
missing_keys, unexpected_keys = load_result
print(f"[Resume] Loaded checkpoint: {checkpoint_path}")
print(f"[Resume] Loaded keys: {len(state_dict)}")
if missing_keys:
print(f"[Resume] Missing keys: {len(missing_keys)}")
if unexpected_keys:
print(f"[Resume] Unexpected keys: {len(unexpected_keys)}")
def print_training_info(args, accelerator):
    """Print the training configuration summary; a no-op off the main process."""
    if not accelerator.is_main_process:
        return

    rule = "=" * 60

    def report_lines():
        # Lines are produced lazily, so each value is read right before it is printed.
        yield "\n" + rule
        yield " InstanceV Training Configuration"
        yield rule
        yield f"\n[时间] {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"

        yield "\n[数据配置]"
        yield f" - 数据路径: {args.dataset_base_path}"
        yield f" - Metadata: {args.dataset_metadata_path}"
        yield f" - 数据集重复: {args.dataset_repeat}x"

        yield "\n[模型配置]"
        yield f" - 模型来源: {args.model_id_with_origin_paths or args.model_paths}"
        yield " - 可训练模块: STAPE, IMCA, mv, norm_imca"
        yield f" - 输出路径: {args.output_path}"

        yield "\n[训练参数]"
        yield f" - 分辨率: {args.width}x{args.height}"
        yield f" - 帧数: {args.num_frames}"
        yield f" - 学习率: {args.learning_rate}"
        yield f" - Epochs: {args.num_epochs}"
        yield f" - 梯度累积: {args.gradient_accumulation_steps}"
        yield f" - 保存间隔: {args.save_steps} steps"
        if args.resume_from_checkpoint:
            yield f" - 断点续跑: {args.resume_from_checkpoint}"

        yield "\n[InstanceV 参数]"
        yield f" - SAUG Dropout: {args.saug_drop_prob}"
        yield f" - SAUG Scale: {args.saug_scale}"
        yield f" - Timestep 边界: [{args.min_timestep_boundary}, {args.max_timestep_boundary}]"

        if args.use_wandb:
            yield "\n[Wandb]"
            yield f" - Project: {args.wandb_project}"
            yield f" - Run Name: {args.wandb_run_name}"
            yield f" - Log Every: {args.wandb_log_every} steps"

        yield "\n[GPU/分布式]"
        yield f" - Accelerator: {accelerator.device}"
        yield f" - 进程数: {accelerator.num_processes}"
        yield f" - 混合精度: {accelerator.mixed_precision}"
        yield "\n" + rule + "\n"

    for line in report_lines():
        print(line)
def _timestamp():
    """Return the current local time formatted for console messages."""
    return datetime.now().strftime('%Y-%m-%d %H:%M:%S')


def _configure_wandb(args, accelerator):
    """Set ``args.wandb_mode`` and start the wandb run on the main process."""
    if not (args.use_wandb and WANDB_AVAILABLE):
        args.wandb_mode = "disabled"
        return

    # wandb_mode is expected by runner.py (per the original note in this script).
    args.wandb_mode = "online"
    # Generate a default run name.
    if args.wandb_run_name is None:
        args.wandb_run_name = f"instancev_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    # Only the main process initializes wandb.
    if not accelerator.is_main_process:
        return

    run_config = {
        "learning_rate": args.learning_rate,
        "num_epochs": args.num_epochs,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "height": args.height,
        "width": args.width,
        "num_frames": args.num_frames,
        "saug_drop_prob": args.saug_drop_prob,
        "saug_scale": args.saug_scale,
        "trainable_models": args.trainable_models,
        "model": "Wan2.1-T2V-1.3B + InstanceV",
    }
    wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=run_config)
    print(f"[Wandb] 初始化完成: {args.wandb_project}/{args.wandb_run_name}")


def _build_instancev_dataset(args):
    """Create the dataset with the video, instance-mask and instance-prompt operators.

    data_file_keys must list every key that needs processing:
    - video: handled by the main (video loading) operator
    - instance_mask_dirs: handled by LoadInstanceMasks
    - instance_prompts: passed through unchanged
    """
    # Both operators use the same temporal constraint.
    temporal_rule = {"time_division_factor": 4, "time_division_remainder": 1}
    video_operator = UnifiedDataset.default_video_operator(
        base_path=args.dataset_base_path,
        max_pixels=args.max_pixels,
        height=args.height,
        width=args.width,
        height_division_factor=16,
        width_division_factor=16,
        num_frames=args.num_frames,
        **temporal_rule,
    )
    mask_operator = LoadInstanceMasks(
        num_frames=args.num_frames,
        target_height=args.height,
        target_width=args.width,
        **temporal_rule,
    )
    return UnifiedDataset(
        base_path=args.dataset_base_path,
        metadata_path=args.dataset_metadata_path,
        repeat=args.dataset_repeat,
        data_file_keys=["video", "instance_mask_dirs", "instance_prompts"],
        main_data_operator=video_operator,
        special_operator_map={
            "instance_mask_dirs": mask_operator,
            # instance_prompts is a list of strings and is passed through as is.
            "instance_prompts": lambda x: x,
        },
    )


def _build_training_module(args, accelerator):
    """Instantiate InstanceVTrainingModule from the parsed arguments."""
    if args.initialize_model_on_cpu:
        init_device = "cpu"
    else:
        init_device = accelerator.device
    return InstanceVTrainingModule(
        model_paths=args.model_paths,
        model_id_with_origin_paths=args.model_id_with_origin_paths,
        tokenizer_path=args.tokenizer_path,
        trainable_models=args.trainable_models,
        lora_base_model=args.lora_base_model,
        lora_target_modules=args.lora_target_modules,
        lora_rank=args.lora_rank,
        lora_checkpoint=args.lora_checkpoint,
        preset_lora_path=args.preset_lora_path,
        preset_lora_model=args.preset_lora_model,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        extra_inputs=args.extra_inputs,
        fp8_models=args.fp8_models,
        offload_models=args.offload_models,
        task=args.task,
        device=init_device,
        max_timestep_boundary=args.max_timestep_boundary,
        min_timestep_boundary=args.min_timestep_boundary,
        saug_drop_prob=args.saug_drop_prob,
        saug_scale=args.saug_scale,
    )


def main():
    """Script entry point: parse arguments, build everything, run the task."""
    args = instancev_parser().parse_args()

    # Initialize the Accelerator.
    ddp_kwargs = accelerate.DistributedDataParallelKwargs(
        find_unused_parameters=args.find_unused_parameters
    )
    accelerator = accelerate.Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        kwargs_handlers=[ddp_kwargs],
    )

    _configure_wandb(args, accelerator)
    print_training_info(args, accelerator)

    dataset = _build_instancev_dataset(args)
    model = _build_training_module(args, accelerator)
    model_logger = ModelLogger(
        args.output_path,
        remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
    )

    # Resume from a checkpoint (weights only); skipped for the data-processing task.
    should_resume = args.resume_from_checkpoint and not args.task.endswith(":data_process")
    if should_resume:
        _load_resume_checkpoint(
            model,
            args.resume_from_checkpoint,
            args.remove_prefix_in_ckpt,
            accelerator,
        )

    # Task name -> launcher.
    launchers = {
        "sft:data_process": launch_data_process_task,
        "sft": launch_training_task,
        "sft:train": launch_training_task,
    }
    if accelerator.is_main_process:
        print(f"\n[{_timestamp()}] 开始训练...\n")
    launchers[args.task](accelerator, dataset, model, model_logger, args=args)

    # Everything below is main-process-only reporting and cleanup.
    if not accelerator.is_main_process:
        return
    print(f"\n[{_timestamp()}] 训练完成!")
    print(f"[{_timestamp()}] Checkpoints 保存至: {args.output_path}")
    # Close the wandb run started by this process.
    if args.use_wandb and WANDB_AVAILABLE:
        wandb.finish()
        print(f"[{_timestamp()}] Wandb 已关闭")
if __name__ == "__main__":
main()