| # Wan Video Instance Control:模型设计说明(bbox + per-frame state weights) |
|
|
| 本文档描述当前 DiffSynth-Studio 的 **Wan Video Statemachine + Instance Control** 设计:仅使用 |
|
|
| - `instance_ids`(区分同类不同个体) |
| - `instance_class_text`(每个实例的 tag/class 文本) |
| - `instance_state_texts`(每个实例的 **固定** state 文本集合) |
| - `instance_state_weights`(**逐帧** state 权重,允许软融合) |
| - `instance_bboxes`(**逐帧** 2D bbox,xyxy 像素坐标) |
|
|
| 来驱动 DiT 中的 instance-aware cross attention。除以上输入外,其它 instance 相关字段(`class_id/state_id/mask/state_a/b/progress` 等)不再使用。 |
|
|
| --- |
|
|
| ## 1. 入口 API 与张量约定 |
|
|
| 入口在 `diffsynth/pipelines/wan_video_statemachine.py` 的 `WanVideoPipeline.__call__`。 |
|
|
| ### 1.1 必需字段(启用 instance control 时) |
|
|
| 当你传入以下任意一个字段不为 `None` 时,pipeline 视为启用 instance control,并要求 **全部提供**: |
|
|
| - `instance_ids`: `Tensor`,形状 `(B, N)` 或 `(N,)`,dtype `long` |
| - `N`:实例数(objects) |
| - `instance_class_text`: `List[str]`(长度 `N`)或 `str`(单实例) |
| - 每个实例一个 class/tag,例如 `"egg"`, `"dog"`, `"person"`… |
| - `instance_state_texts`: `List[List[str]]`(形状 `N × S`)或 `List[str]`(单实例) |
| - 每个实例有一个 **固定大小** 的 state 候选集合(`S` 个 state 文本) |
| - 例如单实例:`["raw", "cooked"]`;多实例:`[["idle","run"], ["open","close"]]` |
| - 约束:所有实例的 `S` 必须相同(当前实现强制)。 |
| - `instance_state_weights`: `Tensor`/list,形状 `(B, N, F, S)` 或 `(N, F, S)`,dtype `float` |
| - `F`:逐帧权重的时间长度(推荐等于输入视频帧数 `num_frames`,但允许不同,后续会映射/下采样到 patch-time) |
| - `S`:state 数量,必须等于 `instance_state_texts` 的 state 数 |
| - 语义:对每个 `(b,n,f)`,给出 `S` 个 state 的权重(可 one-hot,也可软融合) |
| - `instance_bboxes`: `Tensor`,形状 `(B, N, F, 4)` 或 `(N, F, 4)`,dtype `float` |
| - bbox 是 `xyxy`,单位为像素坐标,坐标系必须与推理时的 `height/width` 对齐 |
| - 约束:`instance_bboxes.shape[2]` 必须等于 `instance_state_weights.shape[2]`(同一个 `F`) |
|
|
| ### 1.2 推荐的常见配置 |
|
|
| - **单实例**(N=1)+ 两状态(S=2): |
| - `instance_class_text="egg"` |
| - `instance_state_texts=["raw","cooked"]` |
| - `instance_state_weights.shape=(1,1,F,2)` |
| - `instance_bboxes.shape=(1,1,F,4)` |
| - **多实例**(N>1): |
| - `instance_class_text` 长度必须与 `N` 相同 |
| - `instance_state_texts` 外层长度必须与 `N` 相同 |
|
|
| --- |
|
|
| ## 2. Pipeline 数据流(从输入到 model_fn) |
| |
| 对应代码: |
| |
| - `diffsynth/pipelines/wan_video_statemachine.py` |
| - `WanVideoPipeline.__call__` |
| - `WanVideoUnit_InstanceStateTextEmbedder` |
| - `model_fn_wan_video` |
|
|
| ### 2.1 参数归一化与校验(__call__) |
|
|
| 在 `__call__` 中会把输入转为 Tensor,并补齐 batch 维: |
|
|
| - `instance_ids`:若输入为 `(N,)` 会补成 `(1,N)` |
| - `instance_bboxes`:若输入为 `(N,F,4)` 会补成 `(1,N,F,4)` |
| - `instance_state_weights`:若输入为 `(N,F,S)` 会补成 `(1,N,F,S)` |
|
|
| 启用 instance control 时会做关键校验: |
|
|
| - 5 个输入必须同时存在:`ids/class_text/state_texts/state_weights/bboxes` |
| - `state_weights` 与 `bboxes` 的 `F` 必须一致 |
|
|
| ### 2.2 文本编码(WanVideoUnit_InstanceStateTextEmbedder) |
| |
| 该 unit 负责把 `(class_text, state_texts)` 变成可供 DiT 使用的 state phrase embedding: |
| |
| 1. 先构造短语: |
| - 对每个实例 `n`,对每个 state `s`: |
| - phrase = `"<class_text[n]> is <state_texts[n][s]>"` |
| 2. 使用 T5 encoder 编码短语序列,并做 mask-aware mean pooling 得到每个短语的 pooled embedding: |
| - 输出 `instance_state_text_embeds_multi`,形状 `(1, N, S, text_dim)` |
|
|
| 注意: |
|
|
| - 这里不使用 `instance_state_weights` 做融合;融合在 DiT 内根据逐帧权重完成。 |
| - unit 只产出 `instance_state_text_embeds_multi`,并且 pipeline 在 unit 之后会把 `instance_class_text/instance_state_texts` 从 `inputs_shared` 中移除,确保下游 model_fn 只接收张量(最小化接口)。 |
| |
| --- |
| |
| ## 3. DiT 内部设计(instance tokens + bbox mask-guided attention) |
| |
| 对应代码: |
| |
| - `diffsynth/models/wan_video_dit_instance.py` |
| - `InstanceFeatureExtractor` |
| - `MaskGuidedCrossAttention` |
| - `DiTBlock.forward(..., instance_tokens, instance_masks)` |
|
|
| ### 3.1 从“逐帧权重”生成“按 patch-time 的 instance tokens” |
|
|
| 核心目标:把 per-frame 的 state 权重变成与 DiT patch token 的时间轴一致的 instance tokens,再对每个 patch 做 masked attention。 |
|
|
| #### 输入 |
|
|
| - `state_text_embeds_multi`: `(B, N, S, text_dim)` |
| 每个 state 对应短语 `"<class> is <state>"` 的 pooled embedding |
| - `state_weights`: `(B, N, F, S)` |
| 每帧对 `S` 个 state 的权重 |
| - `instance_ids`: `(B, N)` |
| 用于区分同类个体 |
| - `num_time_patches = f` |
| DiT patchify 后的时间 patch 数(由 `patch_embedding` 决定) |
|
|
| #### 步骤 |
|
|
| 1. **文本投影到 hidden_dim** |
| - `sem_multi = text_proj(state_text_embeds_multi)` → `(B, N, S, H)` |
| 2. **权重截断与时间下采样** |
| - `weights = clamp(state_weights, min=0)` |
| - 若 `F != f`:把 `(B,N,F,S)` 平均池化到 `(B,N,f,S)` |
| - 映射规则:`pt = floor(t * f / F)` |
| 3. **按权重对 state 语义做逐时间融合** |
| - `sem_time[b,n,t] = sum_s( sem_multi[b,n,s] * w[b,n,t,s] ) / sum_s(w)` |
| - 得到 `(B, N, f, H)` |
| 4. **融合 instance_id embedding** |
| - `i_feat = Emb(instance_ids)` → `(B, N, H)`,并广播到时间维 `(B, N, f, H)` |
| - 拼接并通过 fusion MLP: |
| - `token_time[b,n,t] = fusion( concat(sem_time[b,n,t], i_feat[b,n]) )` |
| - 输出 `inst_tokens`:`(B, f, N, D)`(注意转置后时间维在前) |
| |
| ### 3.2 bbox → patch mask(每个 patch 是否被某实例覆盖) |
| |
| `WanModel.process_masks` 将 `instance_bboxes` 投影到 patch token 网格,返回 `inst_mask_flat`: |
| |
| - 输入 bbox:`(B, N, F, 4)`,`xyxy` 像素坐标 |
| - patch 网格:`(f_p, h_p, w_p)` |
| - 输出 mask:`(B, N, L)`,其中 `L = f_p * h_p * w_p` |
| |
| 关键映射规则: |
| |
| - 空间缩放: |
| - `px = x * (w_p / W_img)` |
| - `py = y * (h_p / H_img)` |
| - 时间映射: |
| - `pt = floor(t * f_p / F_bbox)` |
| |
| 最终每个 `(b,n,pt)` 上把 bbox 覆盖到的 `(py0:py1, px0:px1)` patch 置 1。 |
| |
| ### 3.3 MaskGuidedCrossAttention(log-mask trick) |
| |
| 每个 DiT block 都包含一个 instance cross attention: |
| |
| - Q:来自 patch tokens `x`(形状 `(B, L, D)`) |
| - K/V:来自 instance tokens(按时间对齐后使用) |
| - Mask:`(B, N, L)` |
| |
| attention logits 里加入 `log(mask)` 作为 bias: |
| |
| - `sim = (q · k) / sqrt(d)` |
| - `sim = sim + log(mask.clamp(min=1e-6))` |
| |
| 这样 mask=0 的位置会得到接近 `-inf` 的 bias,从而 softmax 后强制为 0,实现 **只让每个 patch 关注覆盖它的实例**。 |
| |
| ### 3.4 时间对齐方式(per-time tokens vs per-token tokens) |
| |
| `MaskGuidedCrossAttention` 支持三种形状: |
| |
| - `(B, N, D)`:整段序列共享同一组 instance tokens(当前不用) |
| - `(B, T, N, D)`:按 patch-time 切分(默认路径) |
| - 假设序列按时间展开:`L = T * (h*w)`,按时间分段计算 attention |
| - `(B, L, N, D)`:按 token 位置提供 instance tokens(用于 Unified Sequence Parallel) |
| |
| 在 `model_fn_wan_video` 开启 USP 时会将 `inst_tokens (B,T,N,D)` 转换成当前 rank 的 chunk 对应的 `(B, chunk_len, N, D)`: |
| |
| - 先计算每个 token 在全局序列里的位置 `global_pos` |
| - `time_index = global_pos // (h*w)` |
| - `inst_tokens_chunk = inst_tokens[:, time_index]` |
| |
| 并对 padding 部分置 0,避免污染。 |
| |
| ### 3.5 Reference latents 的处理 |
| |
| 当 pipeline 使用 `reference_latents` 拼到序列前面时: |
| |
| - patch token 序列会多出 1 个时间片(`f += 1`) |
| - `inst_mask_flat` 会在序列前补 0(reference 部分不属于任何 instance) |
| - `inst_tokens` 也会在时间维前补 0(reference 时间片不注入 instance 语义) |
| |
| --- |
| |
| ## 4. 重要限制与注意事项 |
| |
| 1. **必须给每个实例提供相同数量的 state 文本(S 必须一致)** |
| 2. **`instance_state_weights` 与 `instance_bboxes` 的时间长度 `F` 必须一致** |
| 3. **bbox 的像素坐标必须与推理时的 `height/width` 对齐** |
| - 如果 pipeline 会 resize 输入图像/视频,你需要用 resize 后的坐标系提供 bbox |
| 4. **sliding window 不支持 instance control** |
| - `model_fn_wan_video` 在 `sliding_window_size/stride` 与 instance 输入同时存在时直接报错 |
|
|
| --- |
|
|
| ## 5. 最小可用示例(伪代码) |
|
|
| ```python |
| F = num_frames |
| N = 1 |
| S = 3 |
| |
| instance_ids = torch.tensor([[1]]) # (1,1) |
| instance_class_text = ["egg"] # len=1 |
| instance_state_texts = [["raw", "half", "cooked"]] # (N,S) |
| |
| # 逐帧权重 (1,1,F,3):例如线性从 raw -> cooked |
| w = torch.zeros((1,1,F,S), dtype=torch.float32) |
| t = torch.linspace(0, 1, F) |
| w[0,0,:,0] = (1 - t) # raw |
| w[0,0,:,2] = t # cooked |
| w[0,0,:,1] = 0.0 # half (可选) |
| |
| # bbox (1,1,F,4):每帧一个 bbox,xyxy |
| b = torch.zeros((1,1,F,4), dtype=torch.float32) |
| b[0,0,:,0] = 100; b[0,0,:,1] = 120 |
| b[0,0,:,2] = 260; b[0,0,:,3] = 320 |
| |
| video = pipe( |
| prompt="...", |
| height=H, width=W, num_frames=F, |
| instance_ids=instance_ids, |
| instance_class_text=instance_class_text, |
| instance_state_texts=instance_state_texts, |
| instance_state_weights=w, |
| instance_bboxes=b, |
| ) |
| ``` |
|
|
| --- |
|
|
| ## 6. 代码入口索引 |
|
|
| - Pipeline API / 文本编码: |
| - `diffsynth/pipelines/wan_video_statemachine.py` |
| - `WanVideoPipeline.__call__` |
| - `WanVideoUnit_InstanceStateTextEmbedder` |
| - `model_fn_wan_video` |
| - Instance-aware DiT: |
| - `diffsynth/models/wan_video_dit_instance.py` |
| - `InstanceFeatureExtractor` |
| - `MaskGuidedCrossAttention` |
| - `WanModel.forward(... instance_*)` |
|
|
|
|