PencilFolder / docs /zh /Model_Details /Wan_Instance_StateMachine_T5_BBox.md
PencilHu's picture
Upload folder using huggingface_hub
1146a67 verified
|
Raw
History Blame Contribute Delete
4.84 kB

Wan 实例 StateMachine(T5 语义 + BBox Mask)方案草案

基于 diffsynth/models/wan_video_dit_statemachine_1.py 的实验版 StateMachine,这里整理一套“实例 ID + 文本语义 + BBox”设计,核心目标是让实例控制与文本语义对齐:<class> is <state> 用 T5 编码,遮罩用 bbox 自动生成,尽量减少手工语义对齐成本。

设计动机

  • 类别和状态本身有语义,embedding 表应该来自语言模型而不是离散 ID。
  • 文本提示/负提示与实例控制共享语义空间,避免“prompt 说羊在吃草,但实例 state_emb 不懂吃草”。
  • 用 bbox 代替像素级 mask,降低标注/推理成本,并便于跨帧插值或填充。

输入约定(bbox only)

  • instance_id:仅用于区分同一类别的不同个体,可继续用可训练的 nn.Embedding
  • class_text:自由文本(如 "sheep")。
  • state_text:自由文本(如 "eating""open")。
  • bbox:形状 (B, N, F, 4),xyxy 像素坐标;若为 (B, N, 4) 则对所有帧广播。可选 bbox_mask 表示实例在某帧是否存在。
  • class_state_text_embeds / instance_text_input_ids:如果想在模型内部跑 T5,可传 token ids;否则可传已经编码好的 <class> is <state> 向量。

文本编码(T5)

  1. 构造语义短语:"{class_text} is {state_text}"(必要时加少量 prompt 工程,如 "a {class_text} that is {state_text}")。
  2. 通过 T5-encoder 获得 token hidden states;使用 [EOS] 或 mean pool 得到 (B, N, D_t5)
  3. 线性映射到 DiT 维度:Linear(D_t5 -> dim),再 LayerNorm。
  4. instance_id_emb 融合:fusion([t5_sem, inst_id_emb]) -> inst_token。可继续保留 gate 以便初始不破坏生成。

对齐策略:T5 编码器最好与主模型文本编码共享词表或直接复用 T5 作为主 prompt 编码器,保证同一语义空间;否则需单独训练配准。

BBox → Patch Mask

  1. 将 bbox 从像素坐标映射到 patch grid:
    • H_p = H / patch_size_h, W_p = W / patch_size_w
    • 对每帧 bbox 取整到 patch 单位:x0//ps_w, x1//ps_w, ...,保证 x1>=x0+1
  2. 构建 (B, N, F, H_p, W_p) 二值 mask,前景=1;如果只给 (B, N, 4),则对每帧复用。
  3. 展平为 (B, N, L)MaskGuidedCrossAttention 使用(L = F_p * H_p * W_p)。
  4. 可选:对 bbox 边缘做软扩张(dilation 1~2 个 patch)以缓冲量化误差。

前向路径修改要点

  • InstanceFeatureExtractor:新增 text_dim & text_to_class_state,直接将 <class> is <state> 向量拆成类/状态两份,再与 instance_id_emb 融合;若未提供文本则回退到 class_id/state_id 路径。
  • forward 输入:新增
    • instance_text_input_ids/instance_text_attention_mask(内部跑 T5);
    • instance_class_state_text_embeds(外部已编码的 T5 向量);
    • instance_bboxes / instance_bbox_mask(bbox -> patch mask)。
  • process_masks:增加 bbox 分支,将 xyxy 投影到 patch grid,支持帧级 mask;仍兼容旧的像素级 mask。
  • 其余链路保持不变:实例 tokens 仍在每个 DiTBlock 通过 instance_tokens/instance_masks 触发一次 mask-guided cross-attention。

模型结构图(简化)

flowchart LR
    A[Class text] -->|模板 \"<class> is <state>\"| T5[T5 Encoder]
    S[State text] -->|同上| T5
    T5 --> P[Pool & Linear to dim]
    ID[Instance ID emb] --> FUSE
    P --> FUSE[Fusion MLP -> Instance Tokens]

    VID[VAE Latent Video] --> PATCH[3D Patchify]
    BBOX[BBox / Mask] --> MASK[BBox->Patch Mask]

    PATCH --> DIT[Wan DiT Blocks]
    FUSE --> DIT
    MASK --> DIT
    DIT --> HEAD[Head / Unpatchify]
    HEAD --> OUT[Pred noise]

训练与推理建议

  • 数据:需要 <class_text, state_text, bbox> 标签;若 bbox 稀疏,可用检测/分割模型自动标注。
  • 文本正负提示:确保主 prompt 与实例短语共享编码器,或者在损失中加入对齐项(如对同一语义的 CLIP/T5 空间蒸馏)。
  • 稳定性:保持 gate 零初始化,先小步微调,逐步解冻 T5/融合层。
  • 多实例帧缺失:用 bbox_mask 将缺失帧的 mask 设为 0,避免实例 token 影响不存在的帧。
  • 性能:bbox→mask 是 O(NFH_p*W_p) 的简单填充,可在 dataloader 端完成并缓存。

与现有实现的差异(对 wan_video_dit_statemachine_1.py 的映射)

  • 语义来源:从 class_id/state_id embedding 切换为 T5 文本编码;instance_id embedding 仍保留用于区分个体。
  • mask 生成:由“像素级 mask 下采样”改为“bbox 投影到 patch grid”。
  • 其余控制逻辑(逐层 mask-guided cross-attention、gate、gradient checkpoint)可复用现有代码。