# Wan 实例 StateMachine(T5 语义 + BBox Mask)方案草案 基于 `diffsynth/models/wan_video_dit_statemachine_1.py` 的实验版 StateMachine,这里整理一套“实例 ID + 文本语义 + BBox”设计,核心目标是让实例控制与文本语义对齐:` is ` 用 T5 编码,遮罩用 bbox 自动生成,尽量减少手工语义对齐成本。 ## 设计动机 - 类别和状态本身有语义,embedding 表应该来自语言模型而不是离散 ID。 - 文本提示/负提示与实例控制共享语义空间,避免“prompt 说羊在吃草,但实例 state_emb 不懂吃草”。 - 用 bbox 代替像素级 mask,降低标注/推理成本,并便于跨帧插值或填充。 ## 输入约定(bbox only) - `instance_id`:仅用于区分同一类别的不同个体,可继续用可训练的 `nn.Embedding`。 - `class_text`:自由文本(如 `"sheep"`)。 - `state_text`:自由文本(如 `"eating"` 或 `"open"`)。 - `bbox`:形状 `(B, N, F, 4)`,xyxy 像素坐标;若为 `(B, N, 4)` 则对所有帧广播。可选 `bbox_mask` 表示实例在某帧是否存在。 - `class_state_text_embeds` / `instance_text_input_ids`:如果想在模型内部跑 T5,可传 token ids;否则可传已经编码好的 ` is ` 向量。 ## 文本编码(T5) 1. 构造语义短语:`"{class_text} is {state_text}"`(必要时加少量 prompt 工程,如 `"a {class_text} that is {state_text}"`)。 2. 通过 T5-encoder 获得 token hidden states;使用 `[EOS]` 或 mean pool 得到 `(B, N, D_t5)`。 3. 线性映射到 DiT 维度:`Linear(D_t5 -> dim)`,再 LayerNorm。 4. 与 `instance_id_emb` 融合:`fusion([t5_sem, inst_id_emb]) -> inst_token`。可继续保留 gate 以便初始不破坏生成。 对齐策略:T5 编码器最好与主模型文本编码共享词表或直接复用 T5 作为主 prompt 编码器,保证同一语义空间;否则需单独训练配准。 ## BBox → Patch Mask 1. 将 bbox 从像素坐标映射到 patch grid: - `H_p = H / patch_size_h`, `W_p = W / patch_size_w`。 - 对每帧 bbox 取整到 patch 单位:`x0//ps_w, x1//ps_w, ...`,保证 `x1>=x0+1`。 2. 构建 `(B, N, F, H_p, W_p)` 二值 mask,前景=1;如果只给 `(B, N, 4)`,则对每帧复用。 3. 展平为 `(B, N, L)` 供 `MaskGuidedCrossAttention` 使用(`L = F_p * H_p * W_p`)。 4. 可选:对 bbox 边缘做软扩张(dilation 1~2 个 patch)以缓冲量化误差。 ## 前向路径修改要点 - `InstanceFeatureExtractor`:新增 `text_dim` & `text_to_class_state`,直接将 ` is ` 向量拆成类/状态两份,再与 `instance_id_emb` 融合;若未提供文本则回退到 `class_id/state_id` 路径。 - `forward` 输入:新增 - `instance_text_input_ids`/`instance_text_attention_mask`(内部跑 T5); - `instance_class_state_text_embeds`(外部已编码的 T5 向量); - `instance_bboxes` / `instance_bbox_mask`(bbox -> patch mask)。 - `process_masks`:增加 bbox 分支,将 xyxy 投影到 patch grid,支持帧级 mask;仍兼容旧的像素级 mask。 - 其余链路保持不变:实例 tokens 仍在每个 DiTBlock 通过 `instance_tokens/instance_masks` 触发一次 mask-guided cross-attention。 ## 模型结构图(简化) ```mermaid flowchart LR A[Class text] -->|模板 \" is \"| T5[T5 Encoder] S[State text] -->|同上| T5 T5 --> P[Pool & Linear to dim] ID[Instance ID emb] --> FUSE P --> FUSE[Fusion MLP -> Instance Tokens] VID[VAE Latent Video] --> PATCH[3D Patchify] BBOX[BBox / Mask] --> MASK[BBox->Patch Mask] PATCH --> DIT[Wan DiT Blocks] FUSE --> DIT MASK --> DIT DIT --> HEAD[Head / Unpatchify] HEAD --> OUT[Pred noise] ``` ## 训练与推理建议 - 数据:需要 `` 标签;若 bbox 稀疏,可用检测/分割模型自动标注。 - 文本正负提示:确保主 prompt 与实例短语共享编码器,或者在损失中加入对齐项(如对同一语义的 CLIP/T5 空间蒸馏)。 - 稳定性:保持 `gate` 零初始化,先小步微调,逐步解冻 T5/融合层。 - 多实例帧缺失:用 `bbox_mask` 将缺失帧的 mask 设为 0,避免实例 token 影响不存在的帧。 - 性能:bbox→mask 是 O(N*F*H_p*W_p) 的简单填充,可在 dataloader 端完成并缓存。 ## 与现有实现的差异(对 `wan_video_dit_statemachine_1.py` 的映射) - 语义来源:从 `class_id/state_id` embedding 切换为 T5 文本编码;`instance_id` embedding 仍保留用于区分个体。 - mask 生成:由“像素级 mask 下采样”改为“bbox 投影到 patch grid”。 - 其余控制逻辑(逐层 mask-guided cross-attention、gate、gradient checkpoint)可复用现有代码。