Wan 实例 StateMachine(T5 语义 + BBox Mask)方案草案
基于 diffsynth/models/wan_video_dit_statemachine_1.py 的实验版 StateMachine,这里整理一套“实例 ID + 文本语义 + BBox”设计,核心目标是让实例控制与文本语义对齐:<class> is <state> 用 T5 编码,遮罩用 bbox 自动生成,尽量减少手工语义对齐成本。
设计动机
- 类别和状态本身有语义,embedding 表应该来自语言模型而不是离散 ID。
- 文本提示/负提示与实例控制共享语义空间,避免“prompt 说羊在吃草,但实例 state_emb 不懂吃草”。
- 用 bbox 代替像素级 mask,降低标注/推理成本,并便于跨帧插值或填充。
输入约定(bbox only)
instance_id:仅用于区分同一类别的不同个体,可继续用可训练的nn.Embedding。class_text:自由文本(如"sheep")。state_text:自由文本(如"eating"或"open")。bbox:形状(B, N, F, 4),xyxy 像素坐标;若为(B, N, 4)则对所有帧广播。可选bbox_mask表示实例在某帧是否存在。class_state_text_embeds/instance_text_input_ids:如果想在模型内部跑 T5,可传 token ids;否则可传已经编码好的<class> is <state>向量。
文本编码(T5)
- 构造语义短语:
"{class_text} is {state_text}"(必要时加少量 prompt 工程,如"a {class_text} that is {state_text}")。 - 通过 T5-encoder 获得 token hidden states;使用
[EOS]或 mean pool 得到(B, N, D_t5)。 - 线性映射到 DiT 维度:
Linear(D_t5 -> dim),再 LayerNorm。 - 与
instance_id_emb融合:fusion([t5_sem, inst_id_emb]) -> inst_token。可继续保留 gate 以便初始不破坏生成。
对齐策略:T5 编码器最好与主模型文本编码共享词表或直接复用 T5 作为主 prompt 编码器,保证同一语义空间;否则需单独训练配准。
BBox → Patch Mask
- 将 bbox 从像素坐标映射到 patch grid:
H_p = H / patch_size_h,W_p = W / patch_size_w。- 对每帧 bbox 取整到 patch 单位:
x0//ps_w, x1//ps_w, ...,保证x1>=x0+1。
- 构建
(B, N, F, H_p, W_p)二值 mask,前景=1;如果只给(B, N, 4),则对每帧复用。 - 展平为
(B, N, L)供MaskGuidedCrossAttention使用(L = F_p * H_p * W_p)。 - 可选:对 bbox 边缘做软扩张(dilation 1~2 个 patch)以缓冲量化误差。
前向路径修改要点
InstanceFeatureExtractor:新增text_dim&text_to_class_state,直接将<class> is <state>向量拆成类/状态两份,再与instance_id_emb融合;若未提供文本则回退到class_id/state_id路径。forward输入:新增instance_text_input_ids/instance_text_attention_mask(内部跑 T5);instance_class_state_text_embeds(外部已编码的 T5 向量);instance_bboxes/instance_bbox_mask(bbox -> patch mask)。
process_masks:增加 bbox 分支,将 xyxy 投影到 patch grid,支持帧级 mask;仍兼容旧的像素级 mask。- 其余链路保持不变:实例 tokens 仍在每个 DiTBlock 通过
instance_tokens/instance_masks触发一次 mask-guided cross-attention。
模型结构图(简化)
flowchart LR
A[Class text] -->|模板 \"<class> is <state>\"| T5[T5 Encoder]
S[State text] -->|同上| T5
T5 --> P[Pool & Linear to dim]
ID[Instance ID emb] --> FUSE
P --> FUSE[Fusion MLP -> Instance Tokens]
VID[VAE Latent Video] --> PATCH[3D Patchify]
BBOX[BBox / Mask] --> MASK[BBox->Patch Mask]
PATCH --> DIT[Wan DiT Blocks]
FUSE --> DIT
MASK --> DIT
DIT --> HEAD[Head / Unpatchify]
HEAD --> OUT[Pred noise]
训练与推理建议
- 数据:需要
<class_text, state_text, bbox>标签;若 bbox 稀疏,可用检测/分割模型自动标注。 - 文本正负提示:确保主 prompt 与实例短语共享编码器,或者在损失中加入对齐项(如对同一语义的 CLIP/T5 空间蒸馏)。
- 稳定性:保持
gate零初始化,先小步微调,逐步解冻 T5/融合层。 - 多实例帧缺失:用
bbox_mask将缺失帧的 mask 设为 0,避免实例 token 影响不存在的帧。 - 性能:bbox→mask 是 O(NFH_p*W_p) 的简单填充,可在 dataloader 端完成并缓存。
与现有实现的差异(对 wan_video_dit_statemachine_1.py 的映射)
- 语义来源:从
class_id/state_idembedding 切换为 T5 文本编码;instance_idembedding 仍保留用于区分个体。 - mask 生成:由“像素级 mask 下采样”改为“bbox 投影到 patch grid”。
- 其余控制逻辑(逐层 mask-guided cross-attention、gate、gradient checkpoint)可复用现有代码。