| |
| """ |
| InstanceV 训练脚本 |
| |
| 基于 DiffSynth-Studio 框架,训练 WanVideo 的 InstanceV 模块: |
| - IMCA (Instance-aware Masked Cross-Attention) |
| - STAPE (Shared Timestep Adaptive Prompt Enhancement) |
| - SAUG (Spatially-Aware Unconditional Guidance) 训练时的 dropout |
| |
| 主要参考: |
| - wan_video_dit_instancev.py: 模型定义 |
| - wan_video_instanceV.py: Pipeline 定义 |
| - train.py: 基础训练框架 |
| """ |
|
|
| import torch |
| import os |
| import sys |
| import json |
| import argparse |
| import accelerate |
| import warnings |
| from pathlib import Path |
| from PIL import Image |
| from typing import Optional |
| from datetime import datetime |
| import numpy as np |
|
|
| |
# Optional dependency: wandb is only needed when --use_wandb is passed.
# WANDB_AVAILABLE lets main() skip wandb setup instead of failing at import time.
try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False
    print("[Warning] wandb not installed. Run 'pip install wandb' to enable wandb logging.")


# Put the repository root on sys.path so the `diffsynth` imports below resolve when
# this file is run directly. Must stay before those imports.
# NOTE(review): assumes the repo root is parents[3] of this file, i.e. the file sits
# at <root>/x/y/z/<this_file>.py — update if the script is moved.
sys.path.insert(0, str(Path(__file__).resolve().parents[3]))


# NOTE(review): LoadVideo and ToAbsolutePath are not referenced in this file's code.
from diffsynth.core import UnifiedDataset, load_state_dict
from diffsynth.core.data.operators import LoadVideo, ToAbsolutePath
from diffsynth.pipelines.wan_video_instanceV import WanVideoPipeline, ModelConfig
from diffsynth.diffusion import (
    DiffusionTrainingModule,
    add_general_config,
    add_video_size_config,
    launch_training_task,
    launch_data_process_task,
    ModelLogger,
)


# Disable HuggingFace tokenizers' internal parallelism; presumably to avoid the
# fork-after-parallelism warnings/deadlocks with DataLoader worker processes.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
|
class LoadInstanceMasks:
    """Data operator that loads the mask sequences of one sample's instances.

    Expected file layout for a single instance::

        {mask_dir}/{frame_id:06d}_No.{instance_id}.png

    ``frame_id`` starts at 0. A frame whose file does not exist is replaced by an
    all-zero ("L" mode) mask, so every instance yields a full-length sequence.

    Args:
        num_frames: Maximum number of frames loaded per instance.
        time_division_factor: The per-instance frame count ``n`` is reduced until
            ``n % time_division_factor == time_division_remainder`` (or ``n <= 1``).
        time_division_remainder: See ``time_division_factor``.
        target_height: When both ``target_height`` and ``target_width`` are set,
            every mask is resized (nearest neighbour) to cover the target and then
            center-cropped to exactly ``target_width x target_height``.
        target_width: See ``target_height``.
    """

    def __init__(
        self,
        num_frames: int,
        time_division_factor: int = 4,
        time_division_remainder: int = 1,
        target_height: Optional[int] = None,
        target_width: Optional[int] = None,
    ):
        self.num_frames = num_frames
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        self.target_height = target_height
        self.target_width = target_width

    def _adjust_num_frames(self, num_frames: int) -> int:
        """Clamp ``num_frames`` to ``self.num_frames`` and round it down to a valid count.

        Counts of 1 or less are returned unchanged after clamping.
        """
        num_frames = min(int(num_frames), self.num_frames)
        while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
            num_frames -= 1
        return num_frames

    def _crop_and_resize_mask(self, mask: "Image.Image") -> "Image.Image":
        """Scale ``mask`` to cover the target size, then center-crop it.

        Returns ``mask`` unchanged when no full target size is configured.
        """
        if self.target_height is None or self.target_width is None:
            return mask
        width, height = mask.size
        # "Cover" scaling: one side matches the target, the other overflows and is cropped.
        scale = max(self.target_width / width, self.target_height / height)
        new_w = max(1, int(round(width * scale)))
        new_h = max(1, int(round(height * scale)))
        if (new_w, new_h) != mask.size:
            # NEAREST avoids interpolated grey values at mask edges.
            mask = mask.resize((new_w, new_h), resample=Image.NEAREST)
        left = max(0, (new_w - self.target_width) // 2)
        top = max(0, (new_h - self.target_height) // 2)
        return mask.crop((left, top, left + self.target_width, top + self.target_height))

    def _blank_size(self, loaded_masks: list) -> tuple:
        """Return the (width, height) for placeholder masks of missing frames.

        Uses the target size when one is fully set. Otherwise it uses the size of the
        first mask actually loaded, so placeholders match the real masks of the
        same sequence. Falls back to 64 for any dimension that is still unknown.
        """
        if self.target_height is not None and self.target_width is not None:
            return (self.target_width, self.target_height)
        if loaded_masks:
            return loaded_masks[0].size
        return (
            self.target_width if self.target_width is not None else 64,
            self.target_height if self.target_height is not None else 64,
        )

    def __call__(self, instance_mask_dirs: list) -> list:
        """Load the mask sequence of every instance.

        Args:
            instance_mask_dirs: list of dicts
                ``{"mask_dir": str, "instance_id": int, "num_frames": int}``.
                ``num_frames`` is optional; missing or None means ``self.num_frames``.

        Returns:
            list of list of PIL.Image: ``[num_instances][num_frames]``, mode "L".
        """
        all_masks = []
        for mask_info in instance_mask_dirs:
            mask_dir = mask_info["mask_dir"]
            inst_id = mask_info["instance_id"]
            raw_num_frames = mask_info.get("num_frames")
            if raw_num_frames is None:
                raw_num_frames = self.num_frames
            num_frames = self._adjust_num_frames(raw_num_frames)

            # None marks a frame whose mask file is missing; filled in below once
            # the placeholder size is known.
            masks = []
            for frame_idx in range(num_frames):
                mask_path = os.path.join(mask_dir, f"{frame_idx:06d}_No.{inst_id}.png")
                if os.path.exists(mask_path):
                    # The context manager closes the file handle Image.open opened;
                    # convert() returns an independent image.
                    with Image.open(mask_path) as img:
                        masks.append(self._crop_and_resize_mask(img.convert("L")))
                else:
                    masks.append(None)

            loaded = [m for m in masks if m is not None]
            if num_frames > 0 and not loaded:
                # Usually a wrong (e.g. relative) mask_dir; without this the model
                # would silently train on empty masks.
                warnings.warn(
                    f"No mask files found for instance {inst_id} in '{mask_dir}'; using all-zero masks.",
                    stacklevel=2,
                )
            blank_size = self._blank_size(loaded)
            # The list already has exactly num_frames entries; no resampling is needed.
            all_masks.append([m if m is not None else Image.new("L", blank_size, 0) for m in masks])

        return all_masks
|
|
|
|
class InstanceVTrainingModule(DiffusionTrainingModule):
    """Training module for the InstanceV extension of WanVideo.

    Builds a ``WanVideoPipeline`` and upgrades its DiT in place with the InstanceV
    modules: STAPE shared by all blocks, plus per-block IMCA, a zero-initialised
    ``mv`` gate and ``norm_imca``. It then freezes the original backbone, so only
    those InstanceV modules remain trainable.
    """

    def __init__(
        self,
        model_paths=None,
        model_id_with_origin_paths=None,
        tokenizer_path=None,
        trainable_models=None,
        lora_base_model=None,
        lora_target_modules="",
        lora_rank=32,
        lora_checkpoint=None,
        preset_lora_path=None,
        preset_lora_model=None,
        use_gradient_checkpointing=True,
        use_gradient_checkpointing_offload=False,
        extra_inputs=None,
        fp8_models=None,
        offload_models=None,
        device="cpu",
        task="sft",
        max_timestep_boundary=1.0,
        min_timestep_boundary=0.0,
        # InstanceV / SAUG options
        saug_drop_prob=0.1,
        saug_scale=0.0,
    ):
        """Load the pipeline, add the InstanceV modules and configure training.

        Args:
            saug_drop_prob: Per-step probability of replacing the instance tokens
                with the empty instance tokens during training (SAUG dropout).
            saug_scale: Passed to the pipeline as ``saug_scale``.
            Other arguments are forwarded to the DiffSynth pipeline helpers.

        NOTE: ``_freeze_backbone_keep_instancev`` runs after the training-mode
        switch and freezes every other DiT parameter. Presumably this includes
        LoRA layers injected on the DiT via ``lora_base_model``.
        """
        super().__init__()

        if not use_gradient_checkpointing:
            warnings.warn(
                "Gradient checkpointing is disabled. This may increase VRAM usage and risk OOM."
            )

        model_configs = self.parse_model_configs(
            model_paths, model_id_with_origin_paths,
            fp8_models=fp8_models, offload_models=offload_models, device=device,
        )
        tokenizer_config = (
            ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/")
            if tokenizer_path is None
            else ModelConfig(tokenizer_path)
        )

        self.pipe = WanVideoPipeline.from_pretrained(
            torch_dtype=torch.bfloat16,
            device=device,
            model_configs=model_configs,
            tokenizer_config=tokenizer_config,
            audio_processor_config=None,
        )

        # Per-step SAUG decision shared by every upgraded DiT block: written in
        # forward(), read inside each block's forward. Created before the upgrade
        # so the block forwards close over this very dict.
        self._saug_state = {"drop": None}

        # Upgrade first so the InstanceV modules are part of the DiT for the
        # pipeline split / training-mode calls below.
        self._upgrade_dit_with_instancev()

        self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)

        self.switch_pipe_to_training_mode(
            self.pipe, trainable_models,
            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
            preset_lora_path, preset_lora_model,
            task=task,
        )

        # Must run last: regardless of what the training-mode switch enabled,
        # only the InstanceV modules stay trainable.
        self._freeze_backbone_keep_instancev()

        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
        self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
        self.fp8_models = fp8_models
        self.task = task
        self.max_timestep_boundary = max_timestep_boundary
        self.min_timestep_boundary = min_timestep_boundary

        self.saug_drop_prob = saug_drop_prob
        self.saug_scale = saug_scale

        # The data_process task returns the processed unit outputs instead of a loss.
        self.task_to_loss = {
            "sft:data_process": lambda pipe, *args: args,
            "sft": self._compute_instancev_loss,
            "sft:train": self._compute_instancev_loss,
        }

    def _upgrade_dit_with_instancev(self):
        """Upgrade the pipeline's Wan DiT in place with the InstanceV modules.

        The following are added only where they are not already present:
          - ``dit.stape``: one STAPE module, shared by reference by all blocks;
          - per block ``imca``: initialised from the block's ``cross_attn``
            weights when possible;
          - per block ``mv``: zero-initialised gated-residual scalar;
          - per block ``norm_imca``: LayerNorm without affine parameters.
        Every block's ``forward`` is then replaced with the InstanceV version.
        If the pipeline has no DiT, it only prints a warning.
        """
        import types
        import torch.nn as nn
        from diffsynth.models.wan_video_dit_instancev import (
            SharedTimestepAdaptivePromptEnhancement,
            InstanceAwareMaskedCrossAttention,
            modulate,
        )

        if self.pipe.dit is None:
            print("Warning: No dit model found in pipeline")
            return

        dit = self.pipe.dit
        dim = dit.dim
        # The 12-head fallback is only used if block 0 has no self_attn.
        num_heads = dit.blocks[0].self_attn.num_heads if hasattr(dit.blocks[0], 'self_attn') else 12
        eps = 1e-6

        if not hasattr(self, "_saug_state"):
            self._saug_state = {"drop": None}
        saug_state = self._saug_state

        dit.enable_instancev = True

        if not hasattr(dit, 'stape') or dit.stape is None:
            ref_param = next(dit.parameters())
            dit.stape = SharedTimestepAdaptivePromptEnhancement(
                dim=dim, num_heads=num_heads, eps=eps
            ).to(device=ref_param.device, dtype=ref_param.dtype)
            print(f"InstanceV: Added STAPE module (dim={dim}, num_heads={num_heads})")

        def instancev_forward(
            self,
            x: torch.Tensor,
            context: torch.Tensor,
            t_mod: torch.Tensor,
            freqs: torch.Tensor,
            instance_tokens: Optional[torch.Tensor] = None,
            instance_attn_mask: Optional[torch.Tensor] = None,
            empty_instance_tokens: Optional[torch.Tensor] = None,
            saug_drop_prob: float = 0.0,
        ):
            """InstanceV DiTBlock forward, bound to each block as ``forward``.

            Order: modulated self-attention -> IMCA (scaled by ``mv``) ->
            text cross-attention -> modulated FFN.
            """
            # A rank-4 t_mod (presumably per-token modulation) is chunked along dim 2.
            has_seq = len(t_mod.shape) == 4
            chunk_dim = 2 if has_seq else 1

            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
            ).chunk(6, dim=chunk_dim)

            if has_seq:
                shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                    shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2),
                    shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2),
                )

            # 1) Modulated self-attention with gated residual.
            input_x = modulate(self.norm1(x), shift_msa, scale_msa)
            x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))

            # 2) IMCA, only when instance conditioning is provided.
            if self.enable_instancev and (self.imca is not None) and (instance_tokens is not None) and (instance_attn_mask is not None):
                if isinstance(saug_drop_prob, torch.Tensor):
                    saug_p = float(saug_drop_prob.detach().cpu().item())
                else:
                    saug_p = float(saug_drop_prob)

                instance_tokens_use = instance_tokens
                # SAUG training-time dropout: swap in the empty instance tokens.
                # - The decision is taken once per step (saug_state["drop"], set in
                #   the training module's forward), so either all blocks drop or
                #   none do. A fresh draw is made only if no step decision exists.
                # - imca.training is checked instead of the block's own flag,
                #   because the frozen backbone may be in eval mode while the
                #   InstanceV modules train.
                if self.imca.training and saug_p > 0.0 and empty_instance_tokens is not None:
                    drop = saug_state.get("drop")
                    if drop is None:
                        drop = bool(torch.rand((), device=x.device) < saug_p)
                    if drop:
                        instance_tokens_use = empty_instance_tokens

                # STAPE refines the instance tokens with the text context, using
                # the self-attention gate as its alpha1.
                if self.stape is not None:
                    instance_tokens_use = self.stape(instance_tokens_use, context, alpha1=gate_msa)

                imca_out = self.imca(self.norm_imca(x), instance_tokens_use, instance_attn_mask)
                x = x + self.mv.to(dtype=x.dtype, device=x.device) * imca_out

            # 3) Text cross-attention.
            x = x + self.cross_attn(self.norm3(x), context)

            # 4) Modulated FFN with gated residual.
            input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
            x = self.gate(x, gate_mlp, self.ffn(input_x))
            return x

        for block_idx, block in enumerate(dit.blocks):
            block.enable_instancev = True
            # Shared by reference: every block points at the same dit.stape module.
            block.stape = dit.stape
            ref_param = next(block.parameters())

            if not hasattr(block, 'imca') or block.imca is None:
                block.imca = InstanceAwareMaskedCrossAttention(
                    dim=dim, num_heads=num_heads, eps=eps
                ).to(device=ref_param.device, dtype=ref_param.dtype)

                # Warm-start a freshly created IMCA from the block's text cross-attention
                # (best effort: a mismatch only prints a warning).
                if hasattr(block, 'cross_attn'):
                    try:
                        block.imca.attn.q.load_state_dict(block.cross_attn.q.state_dict())
                        block.imca.attn.k.load_state_dict(block.cross_attn.k.state_dict())
                        block.imca.attn.v.load_state_dict(block.cross_attn.v.state_dict())
                        block.imca.attn.o.load_state_dict(block.cross_attn.o.state_dict())
                        block.imca.attn.norm_q.load_state_dict(block.cross_attn.norm_q.state_dict())
                        block.imca.attn.norm_k.load_state_dict(block.cross_attn.norm_k.state_dict())
                    except Exception as e:
                        print(f"Warning: Failed to init IMCA from cross_attn in block {block_idx}: {e}")

            # Zero init: the IMCA branch contributes nothing until mv is learned.
            if not hasattr(block, 'mv') or block.mv is None:
                block.mv = nn.Parameter(torch.zeros(1, device=ref_param.device, dtype=ref_param.dtype))

            if not hasattr(block, 'norm_imca') or block.norm_imca is None:
                block.norm_imca = nn.LayerNorm(dim, eps=eps, elementwise_affine=False).to(
                    device=ref_param.device, dtype=ref_param.dtype
                )

            block.forward = types.MethodType(instancev_forward, block)

        print(f"InstanceV: Added IMCA, mv, norm_imca to {len(dit.blocks)} blocks and replaced forward methods")

    def _freeze_backbone_keep_instancev(self):
        """Freeze the whole DiT, then re-enable training only for the InstanceV modules.

        The modules kept trainable (and switched to train mode) are ``dit.stape``
        and, per block, ``imca``, ``mv`` and ``norm_imca``. A summary of
        trainable vs. total DiT parameters is printed at the end.
        """
        if self.pipe.dit is None:
            return

        dit = self.pipe.dit
        dit.requires_grad_(False)

        trainable_params = 0

        if getattr(dit, 'stape', None) is not None:
            dit.stape.requires_grad_(True)
            dit.stape.train()
            trainable_params += sum(p.numel() for p in dit.stape.parameters())
            print("InstanceV: Enabled training for dit.stape")

        for block in dit.blocks:
            imca = getattr(block, 'imca', None)
            if imca is not None:
                imca.requires_grad_(True)
                # The SAUG dropout in the block forward keys off imca.training.
                imca.train()
                trainable_params += sum(p.numel() for p in imca.parameters())

            mv = getattr(block, 'mv', None)
            if mv is not None:
                mv.requires_grad_(True)
                trainable_params += mv.numel()

            norm_imca = getattr(block, 'norm_imca', None)
            if norm_imca is not None:
                norm_imca.requires_grad_(True)
                norm_imca.train()
                trainable_params += sum(p.numel() for p in norm_imca.parameters())

        # parameters() de-duplicates shared modules, so the shared STAPE counts once.
        total_params = sum(p.numel() for p in dit.parameters())
        percent = 100.0 * trainable_params / total_params if total_params else 0.0
        print(f"InstanceV: Trainable params: {trainable_params:,} / {total_params:,} "
              f"({percent:.2f}%)")

    def _compute_instancev_loss(self, pipe, inputs_shared, inputs_posi, inputs_nega):
        """Return the flow-matching SFT loss.

        SAUG dropout is not handled here; it happens inside the DiT blocks.
        ``inputs_nega`` is unused.
        """
        from diffsynth.diffusion.loss import FlowMatchSFTLoss
        return FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi)

    def parse_extra_inputs(self, data, extra_inputs, inputs_shared):
        """Copy the requested extra inputs from ``data`` into ``inputs_shared``.

        ``input_image`` / ``end_image`` are the first / last video frame; any other
        name is copied when present in ``data``.
        """
        for extra_input in extra_inputs:
            if extra_input == "input_image":
                inputs_shared["input_image"] = data["video"][0]
            elif extra_input == "end_image":
                inputs_shared["end_image"] = data["video"][-1]
            elif extra_input in data:
                inputs_shared[extra_input] = data[extra_input]
        return inputs_shared

    def get_pipeline_inputs(self, data):
        """Build the (shared, positive, negative) pipeline input dicts for one sample.

        InstanceV-specific fields:
          - ``instance_prompts``: copied from ``data`` when present;
          - ``instance_masks``: taken from ``data["instance_mask_dirs"]`` (already
            loaded by the dataset operator) or else ``data["instance_masks"]``;
          - ``saug_drop_prob`` / ``saug_scale``: this module's SAUG settings.
        """
        inputs_posi = {"prompt": data["prompt"]}
        inputs_nega = {}

        inputs_shared = {
            "input_video": data["video"],
            "height": data["video"][0].size[1],
            "width": data["video"][0].size[0],
            "num_frames": len(data["video"]),
            "cfg_scale": 1,
            "tiled": False,
            "rand_device": self.pipe.device,
            "use_gradient_checkpointing": self.use_gradient_checkpointing,
            "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
            "cfg_merge": False,
            "vace_scale": 1,
            "max_timestep_boundary": self.max_timestep_boundary,
            "min_timestep_boundary": self.min_timestep_boundary,
        }

        if "instance_prompts" in data:
            inputs_shared["instance_prompts"] = data["instance_prompts"]

        if "instance_mask_dirs" in data:
            inputs_shared["instance_masks"] = data["instance_mask_dirs"]
        elif "instance_masks" in data:
            inputs_shared["instance_masks"] = data["instance_masks"]

        inputs_shared["saug_drop_prob"] = self.saug_drop_prob
        inputs_shared["saug_scale"] = self.saug_scale

        inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)

        return inputs_shared, inputs_posi, inputs_nega

    def _sample_saug_drop(self):
        """Draw this step's SAUG decision: True with probability ``saug_drop_prob`` while training."""
        if not self.training or self.saug_drop_prob <= 0.0:
            return False
        return bool(torch.rand(()) < self.saug_drop_prob)

    def forward(self, data, inputs=None):
        """Run the pipeline units and return the task output.

        The output is the loss for "sft" / "sft:train", or the processed inputs for
        "sft:data_process". ``inputs`` may be supplied pre-processed (e.g. from a
        cache), in which case ``data`` is not used to build them.
        """
        if inputs is None:
            inputs = self.get_pipeline_inputs(data)

        inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)

        # One SAUG decision per step, shared by all blocks. It stays in place until
        # the next step, so gradient-checkpoint recomputation during backward sees
        # the same decision.
        self._saug_state["drop"] = self._sample_saug_drop()

        for unit in self.pipe.units:
            inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)

        return self.task_to_loss[self.task](self.pipe, *inputs)
|
|
|
|
def instancev_parser():
    """Build the command-line parser.

    The parser holds the shared DiffSynth general and video-size options plus the
    InstanceV-specific flags (tokenizer, SAUG, timestep range, CPU init, wandb,
    resume).
    """
    parser = argparse.ArgumentParser(description="InstanceV Training Script")
    parser = add_general_config(parser)
    parser = add_video_size_config(parser)

    # (flag, add_argument keyword options), registered in this order.
    extra_arguments = (
        # Tokenizer
        ("--tokenizer_path", dict(type=str, default=None, help="Path to tokenizer.")),
        # SAUG
        ("--saug_drop_prob", dict(
            type=float, default=0.1,
            help="SAUG training-time dropout probability (paper recommends 0.1).",
        )),
        ("--saug_scale", dict(
            type=float, default=0.0,
            help="SAUG unconditional guidance scale (training time, usually 0).",
        )),
        # Timestep sampling range
        ("--max_timestep_boundary", dict(type=float, default=1.0, help="Max timestep boundary.")),
        ("--min_timestep_boundary", dict(type=float, default=0.0, help="Min timestep boundary.")),
        # Model placement
        ("--initialize_model_on_cpu", dict(
            default=False, action="store_true", help="Whether to initialize models on CPU.",
        )),
        # Weights & Biases
        ("--use_wandb", dict(default=False, action="store_true", help="Enable wandb logging.")),
        ("--wandb_project", dict(type=str, default="instancev-training", help="Wandb project name.")),
        ("--wandb_run_name", dict(
            type=str, default=None,
            help="Wandb run name. Default: auto-generated with timestamp.",
        )),
        ("--wandb_log_every", dict(type=int, default=10, help="Log to wandb every N steps.")),
        # Resuming
        ("--resume_from_checkpoint", dict(
            type=str, default=None,
            help="Path to a trainable checkpoint (.safetensors) to resume from.",
        )),
    )
    for flag, options in extra_arguments:
        parser.add_argument(flag, **options)

    return parser
|
|
|
|
| def _resolve_resume_target(model, remove_prefix_in_ckpt): |
| if not remove_prefix_in_ckpt: |
| return model |
| prefix = remove_prefix_in_ckpt.rstrip(".") |
| if not prefix: |
| return model |
| target = model |
| for part in prefix.split("."): |
| if not hasattr(target, part): |
| return None |
| target = getattr(target, part) |
| return target |
|
|
|
|
def _load_resume_checkpoint(model, checkpoint_path, remove_prefix_in_ckpt, accelerator):
    """Load a trainable-parameter checkpoint into ``model`` before training resumes.

    The checkpoint is loaded into the sub-module addressed by
    ``remove_prefix_in_ckpt``. If that prefix cannot be resolved, it falls back to
    the full model. Loading is non-strict. On the main process the function
    reports how many keys actually matched and whether any *trainable*
    parameters were left unloaded. Missing frozen backbone keys are expected,
    since the checkpoint normally holds only trainable weights.

    Args:
        model: Training module to load into.
        checkpoint_path: Path to the checkpoint; ``None`` makes this a no-op.
        remove_prefix_in_ckpt: Dotted prefix that was stripped from the checkpoint keys.
        accelerator: ``accelerate.Accelerator``; only the main process prints.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    if checkpoint_path is None:
        return
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Resume checkpoint not found: {checkpoint_path}")
    state_dict = load_state_dict(checkpoint_path)

    target = _resolve_resume_target(model, remove_prefix_in_ckpt)
    if target is None:
        target = model
        if accelerator.is_main_process:
            print(
                f"[Warning] remove_prefix_in_ckpt='{remove_prefix_in_ckpt}' does not map to a model attribute. "
                "Fallback to loading into the full model."
            )

    missing_keys, unexpected_keys = target.load_state_dict(state_dict, strict=False)
    if not accelerator.is_main_process:
        return

    # strict=False silently skips unmatched keys; count what actually landed.
    num_loaded = len(state_dict) - len(unexpected_keys)
    # named_parameters() de-duplicates shared modules, so aliases of an already
    # loaded shared parameter are not reported as missing trainable keys.
    trainable_names = {name for name, param in target.named_parameters() if param.requires_grad}
    missing_trainable = [key for key in missing_keys if key in trainable_names]

    print(f"[Resume] Loaded checkpoint: {checkpoint_path}")
    print(f"[Resume] Loaded keys: {num_loaded} / {len(state_dict)}")
    if missing_keys:
        print(f"[Resume] Missing keys: {len(missing_keys)}")
    if missing_trainable:
        print(
            f"[Resume] Missing trainable keys: {len(missing_trainable)} "
            f"(e.g. {missing_trainable[0]})"
        )
    if unexpected_keys:
        print(f"[Resume] Unexpected keys: {len(unexpected_keys)}")
    if state_dict and num_loaded == 0:
        print(
            "[Warning] No checkpoint key matched the model; nothing was resumed. "
            "Check --remove_prefix_in_ckpt."
        )
|
|
|
|
def print_training_info(args, accelerator):
    """Print a sectioned summary of the run configuration (main process only).

    Sections, in order: data, model, training (plus the resume path when set),
    InstanceV/SAUG settings, wandb (only with ``--use_wandb``) and
    accelerator/distribution details.
    """
    if not accelerator.is_main_process:
        return

    rule = "=" * 60

    training_rows = [
        f" - 分辨率: {args.width}x{args.height}",
        f" - 帧数: {args.num_frames}",
        f" - 学习率: {args.learning_rate}",
        f" - Epochs: {args.num_epochs}",
        f" - 梯度累积: {args.gradient_accumulation_steps}",
        f" - 保存间隔: {args.save_steps} steps",
    ]
    if args.resume_from_checkpoint:
        training_rows.append(f" - 断点续跑: {args.resume_from_checkpoint}")

    sections = [
        ("[数据配置]", [
            f" - 数据路径: {args.dataset_base_path}",
            f" - Metadata: {args.dataset_metadata_path}",
            f" - 数据集重复: {args.dataset_repeat}x",
        ]),
        ("[模型配置]", [
            f" - 模型来源: {args.model_id_with_origin_paths or args.model_paths}",
            f" - 可训练模块: STAPE, IMCA, mv, norm_imca",
            f" - 输出路径: {args.output_path}",
        ]),
        ("[训练参数]", training_rows),
        ("[InstanceV 参数]", [
            f" - SAUG Dropout: {args.saug_drop_prob}",
            f" - SAUG Scale: {args.saug_scale}",
            f" - Timestep 边界: [{args.min_timestep_boundary}, {args.max_timestep_boundary}]",
        ]),
    ]
    if args.use_wandb:
        sections.append(("[Wandb]", [
            f" - Project: {args.wandb_project}",
            f" - Run Name: {args.wandb_run_name}",
            f" - Log Every: {args.wandb_log_every} steps",
        ]))
    sections.append(("[GPU/分布式]", [
        f" - Accelerator: {accelerator.device}",
        f" - 进程数: {accelerator.num_processes}",
        f" - 混合精度: {accelerator.mixed_precision}",
    ]))

    print("\n" + rule)
    print(" InstanceV Training Configuration")
    print(rule)
    print(f"\n[时间] {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    for title, rows in sections:
        print(f"\n{title}")
        for row in rows:
            print(row)
    print("\n" + rule + "\n")
|
|
|
|
def main():
    """Script entry point.

    Parses the CLI, optionally starts a wandb run, builds the dataset and the
    InstanceV training module, optionally resumes from a checkpoint, and runs the
    launcher selected by ``--task``.

    Raises:
        ValueError: If ``--task`` is not one of the supported tasks. This is
            checked before the model is loaded.
    """
    parser = instancev_parser()
    args = parser.parse_args()

    # Supported tasks -> launchers. Validated before anything expensive (wandb
    # init, model loading) so a typo fails immediately.
    launcher_map = {
        "sft:data_process": launch_data_process_task,
        "sft": launch_training_task,
        "sft:train": launch_training_task,
    }
    if args.task not in launcher_map:
        raise ValueError(
            f"Unsupported task '{args.task}'. Supported tasks: {', '.join(launcher_map)}"
        )

    accelerator = accelerate.Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        kwargs_handlers=[
            accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)
        ],
    )

    # wandb is active only when requested AND importable; the run lives on the main process.
    use_wandb = args.use_wandb and WANDB_AVAILABLE
    if args.use_wandb and not WANDB_AVAILABLE and accelerator.is_main_process:
        print("[Warning] --use_wandb was set but wandb is not installed; wandb logging is disabled.")

    if use_wandb:
        args.wandb_mode = "online"

        if args.wandb_run_name is None:
            args.wandb_run_name = f"instancev_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

        if accelerator.is_main_process:
            wandb.init(
                project=args.wandb_project,
                name=args.wandb_run_name,
                config={
                    "learning_rate": args.learning_rate,
                    "num_epochs": args.num_epochs,
                    "gradient_accumulation_steps": args.gradient_accumulation_steps,
                    "height": args.height,
                    "width": args.width,
                    "num_frames": args.num_frames,
                    "saug_drop_prob": args.saug_drop_prob,
                    "saug_scale": args.saug_scale,
                    "trainable_models": args.trainable_models,
                    "model": "Wan2.1-T2V-1.3B + InstanceV",
                },
            )
            print(f"[Wandb] 初始化完成: {args.wandb_project}/{args.wandb_run_name}")
    else:
        args.wandb_mode = "disabled"

    print_training_info(args, accelerator)

    # Videos go through the default video operator. Instance masks and prompts
    # use the special operators below: masks are loaded and cropped to the same
    # target size; prompts pass through unchanged.
    dataset = UnifiedDataset(
        base_path=args.dataset_base_path,
        metadata_path=args.dataset_metadata_path,
        repeat=args.dataset_repeat,
        data_file_keys=["video", "instance_mask_dirs", "instance_prompts"],
        main_data_operator=UnifiedDataset.default_video_operator(
            base_path=args.dataset_base_path,
            max_pixels=args.max_pixels,
            height=args.height,
            width=args.width,
            height_division_factor=16,
            width_division_factor=16,
            num_frames=args.num_frames,
            time_division_factor=4,
            time_division_remainder=1,
        ),
        special_operator_map={
            "instance_mask_dirs": LoadInstanceMasks(
                num_frames=args.num_frames,
                time_division_factor=4,
                time_division_remainder=1,
                target_height=args.height,
                target_width=args.width,
            ),
            "instance_prompts": lambda x: x,
        },
    )

    model = InstanceVTrainingModule(
        model_paths=args.model_paths,
        model_id_with_origin_paths=args.model_id_with_origin_paths,
        tokenizer_path=args.tokenizer_path,
        trainable_models=args.trainable_models,
        lora_base_model=args.lora_base_model,
        lora_target_modules=args.lora_target_modules,
        lora_rank=args.lora_rank,
        lora_checkpoint=args.lora_checkpoint,
        preset_lora_path=args.preset_lora_path,
        preset_lora_model=args.preset_lora_model,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        extra_inputs=args.extra_inputs,
        fp8_models=args.fp8_models,
        offload_models=args.offload_models,
        task=args.task,
        device="cpu" if args.initialize_model_on_cpu else accelerator.device,
        max_timestep_boundary=args.max_timestep_boundary,
        min_timestep_boundary=args.min_timestep_boundary,
        saug_drop_prob=args.saug_drop_prob,
        saug_scale=args.saug_scale,
    )

    model_logger = ModelLogger(
        args.output_path,
        remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
    )

    # Resuming is skipped for the data_process task.
    if args.resume_from_checkpoint and not args.task.endswith(":data_process"):
        _load_resume_checkpoint(
            model,
            args.resume_from_checkpoint,
            args.remove_prefix_in_ckpt,
            accelerator,
        )

    if accelerator.is_main_process:
        print(f"\n[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] 开始训练...\n")

    launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)

    if accelerator.is_main_process:
        print(f"\n[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] 训练完成!")
        print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Checkpoints 保存至: {args.output_path}")

        # wandb.init ran only on the main process, so only it has a run to finish.
        if use_wandb:
            wandb.finish()
            print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Wandb 已关闭")
|
|
|
|
# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|