PencilFolder / docs /wan_video_instance_control_design.md
PencilHu's picture
Upload folder using huggingface_hub
1146a67 verified
|
Raw
History Blame Contribute Delete
9.98 kB
# Wan Video Instance Control:模型设计说明(bbox + per-frame state weights)
本文档描述当前 DiffSynth-Studio 的 **Wan Video Statemachine + Instance Control** 设计:仅使用
- `instance_ids`(区分同类不同个体)
- `instance_class_text`(每个实例的 tag/class 文本)
- `instance_state_texts`(每个实例的 **固定** state 文本集合)
- `instance_state_weights`**逐帧** state 权重,允许软融合)
- `instance_bboxes`**逐帧** 2D bbox,xyxy 像素坐标)
来驱动 DiT 中的 instance-aware cross attention。除以上输入外,其它 instance 相关字段(`class_id/state_id/mask/state_a/b/progress` 等)不再使用。
---
## 1. 入口 API 与张量约定
入口在 `diffsynth/pipelines/wan_video_statemachine.py``WanVideoPipeline.__call__`
### 1.1 必需字段(启用 instance control 时)
当你传入以下任意一个字段不为 `None` 时,pipeline 视为启用 instance control,并要求 **全部提供**
- `instance_ids`: `Tensor`,形状 `(B, N)``(N,)`,dtype `long`
- `N`:实例数(objects)
- `instance_class_text`: `List[str]`(长度 `N`)或 `str`(单实例)
- 每个实例一个 class/tag,例如 `"egg"`, `"dog"`, `"person"`
- `instance_state_texts`: `List[List[str]]`(形状 `N × S`)或 `List[str]`(单实例)
- 每个实例有一个 **固定大小** 的 state 候选集合(`S` 个 state 文本)
- 例如单实例:`["raw", "cooked"]`;多实例:`[["idle","run"], ["open","close"]]`
- 约束:所有实例的 `S` 必须相同(当前实现强制)。
- `instance_state_weights`: `Tensor`/list,形状 `(B, N, F, S)``(N, F, S)`,dtype `float`
- `F`:逐帧权重的时间长度(推荐等于输入视频帧数 `num_frames`,但允许不同,后续会映射/下采样到 patch-time)
- `S`:state 数量,必须等于 `instance_state_texts` 的 state 数
- 语义:对每个 `(b,n,f)`,给出 `S` 个 state 的权重(可 one-hot,也可软融合)
- `instance_bboxes`: `Tensor`,形状 `(B, N, F, 4)``(N, F, 4)`,dtype `float`
- bbox 是 `xyxy`,单位为像素坐标,坐标系必须与推理时的 `height/width` 对齐
- 约束:`instance_bboxes.shape[2]` 必须等于 `instance_state_weights.shape[2]`(同一个 `F`
### 1.2 推荐的常见配置
- **单实例**(N=1)+ 两状态(S=2):
- `instance_class_text="egg"`
- `instance_state_texts=["raw","cooked"]`
- `instance_state_weights.shape=(1,1,F,2)`
- `instance_bboxes.shape=(1,1,F,4)`
- **多实例**(N>1):
- `instance_class_text` 长度必须与 `N` 相同
- `instance_state_texts` 外层长度必须与 `N` 相同
---
## 2. Pipeline 数据流(从输入到 model_fn)
对应代码:
- `diffsynth/pipelines/wan_video_statemachine.py`
- `WanVideoPipeline.__call__`
- `WanVideoUnit_InstanceStateTextEmbedder`
- `model_fn_wan_video`
### 2.1 参数归一化与校验(__call__
`__call__` 中会把输入转为 Tensor,并补齐 batch 维:
- `instance_ids`:若输入为 `(N,)` 会补成 `(1,N)`
- `instance_bboxes`:若输入为 `(N,F,4)` 会补成 `(1,N,F,4)`
- `instance_state_weights`:若输入为 `(N,F,S)` 会补成 `(1,N,F,S)`
启用 instance control 时会做关键校验:
- 5 个输入必须同时存在:`ids/class_text/state_texts/state_weights/bboxes`
- `state_weights``bboxes``F` 必须一致
### 2.2 文本编码(WanVideoUnit_InstanceStateTextEmbedder)
该 unit 负责把 `(class_text, state_texts)` 变成可供 DiT 使用的 state phrase embedding:
1. 先构造短语:
- 对每个实例 `n`,对每个 state `s`:
- phrase = `"<class_text[n]> is <state_texts[n][s]>"`
2. 使用 T5 encoder 编码短语序列,并做 mask-aware mean pooling 得到每个短语的 pooled embedding:
- 输出 `instance_state_text_embeds_multi`,形状 `(1, N, S, text_dim)`
注意:
- 这里不使用 `instance_state_weights` 做融合;融合在 DiT 内根据逐帧权重完成。
- unit 只产出 `instance_state_text_embeds_multi`,并且 pipeline 在 unit 之后会把 `instance_class_text/instance_state_texts``inputs_shared` 中移除,确保下游 model_fn 只接收张量(最小化接口)。
---
## 3. DiT 内部设计(instance tokens + bbox mask-guided attention)
对应代码:
- `diffsynth/models/wan_video_dit_instance.py`
- `InstanceFeatureExtractor`
- `MaskGuidedCrossAttention`
- `DiTBlock.forward(..., instance_tokens, instance_masks)`
### 3.1 从“逐帧权重”生成“按 patch-time 的 instance tokens”
核心目标:把 per-frame 的 state 权重变成与 DiT patch token 的时间轴一致的 instance tokens,再对每个 patch 做 masked attention。
#### 输入
- `state_text_embeds_multi`: `(B, N, S, text_dim)`
每个 state 对应短语 `"<class> is <state>"` 的 pooled embedding
- `state_weights`: `(B, N, F, S)`
每帧对 `S` 个 state 的权重
- `instance_ids`: `(B, N)`
用于区分同类个体
- `num_time_patches = f`
DiT patchify 后的时间 patch 数(由 `patch_embedding` 决定)
#### 步骤
1. **文本投影到 hidden_dim**
- `sem_multi = text_proj(state_text_embeds_multi)` → `(B, N, S, H)`
2. **权重截断与时间下采样**
- `weights = clamp(state_weights, min=0)`
-`F != f`:把 `(B,N,F,S)` 平均池化到 `(B,N,f,S)`
- 映射规则:`pt = floor(t * f / F)`
3. **按权重对 state 语义做逐时间融合**
- `sem_time[b,n,t] = sum_s( sem_multi[b,n,s] * w[b,n,t,s] ) / sum_s(w)`
- 得到 `(B, N, f, H)`
4. **融合 instance_id embedding**
- `i_feat = Emb(instance_ids)` → `(B, N, H)`,并广播到时间维 `(B, N, f, H)`
- 拼接并通过 fusion MLP:
- `token_time[b,n,t] = fusion( concat(sem_time[b,n,t], i_feat[b,n]) )`
- 输出 `inst_tokens`:`(B, f, N, D)`(注意转置后时间维在前)
### 3.2 bbox → patch mask(每个 patch 是否被某实例覆盖)
`WanModel.process_masks` 将 `instance_bboxes` 投影到 patch token 网格,返回 `inst_mask_flat`:
- 输入 bbox:`(B, N, F, 4)`,`xyxy` 像素坐标
- patch 网格:`(f_p, h_p, w_p)`
- 输出 mask:`(B, N, L)`,其中 `L = f_p * h_p * w_p`
关键映射规则:
- 空间缩放:
- `px = x * (w_p / W_img)`
- `py = y * (h_p / H_img)`
- 时间映射:
- `pt = floor(t * f_p / F_bbox)`
最终每个 `(b,n,pt)` 上把 bbox 覆盖到的 `(py0:py1, px0:px1)` patch 置 1。
### 3.3 MaskGuidedCrossAttention(log-mask trick)
每个 DiT block 都包含一个 instance cross attention:
- Q:来自 patch tokens `x`(形状 `(B, L, D)`)
- K/V:来自 instance tokens(按时间对齐后使用)
- Mask:`(B, N, L)`
attention logits 里加入 `log(mask)` 作为 bias:
- `sim = (q · k) / sqrt(d)`
- `sim = sim + log(mask.clamp(min=1e-6))`
这样 mask=0 的位置会得到接近 `-inf` 的 bias,从而 softmax 后强制为 0,实现 **只让每个 patch 关注覆盖它的实例**。
### 3.4 时间对齐方式(per-time tokens vs per-token tokens)
`MaskGuidedCrossAttention` 支持三种形状:
- `(B, N, D)`:整段序列共享同一组 instance tokens(当前不用)
- `(B, T, N, D)`:按 patch-time 切分(默认路径)
- 假设序列按时间展开:`L = T * (h*w)`,按时间分段计算 attention
- `(B, L, N, D)`:按 token 位置提供 instance tokens(用于 Unified Sequence Parallel)
在 `model_fn_wan_video` 开启 USP 时会将 `inst_tokens (B,T,N,D)` 转换成当前 rank 的 chunk 对应的 `(B, chunk_len, N, D)`:
- 先计算每个 token 在全局序列里的位置 `global_pos`
- `time_index = global_pos // (h*w)`
- `inst_tokens_chunk = inst_tokens[:, time_index]`
并对 padding 部分置 0,避免污染。
### 3.5 Reference latents 的处理
当 pipeline 使用 `reference_latents` 拼到序列前面时:
- patch token 序列会多出 1 个时间片(`f += 1`)
- `inst_mask_flat` 会在序列前补 0(reference 部分不属于任何 instance)
- `inst_tokens` 也会在时间维前补 0(reference 时间片不注入 instance 语义)
---
## 4. 重要限制与注意事项
1. **必须给每个实例提供相同数量的 state 文本(S 必须一致)**
2. **`instance_state_weights` 与 `instance_bboxes` 的时间长度 `F` 必须一致**
3. **bbox 的像素坐标必须与推理时的 `height/width` 对齐**
- 如果 pipeline 会 resize 输入图像/视频,你需要用 resize 后的坐标系提供 bbox
4. **sliding window 不支持 instance control**
- `model_fn_wan_video``sliding_window_size/stride` 与 instance 输入同时存在时直接报错
---
## 5. 最小可用示例(伪代码)
```python
F = num_frames
N = 1
S = 3
instance_ids = torch.tensor([[1]]) # (1,1)
instance_class_text = ["egg"] # len=1
instance_state_texts = [["raw", "half", "cooked"]] # (N,S)
# 逐帧权重 (1,1,F,3):例如线性从 raw -> cooked
w = torch.zeros((1,1,F,S), dtype=torch.float32)
t = torch.linspace(0, 1, F)
w[0,0,:,0] = (1 - t) # raw
w[0,0,:,2] = t # cooked
w[0,0,:,1] = 0.0 # half (可选)
# bbox (1,1,F,4):每帧一个 bbox,xyxy
b = torch.zeros((1,1,F,4), dtype=torch.float32)
b[0,0,:,0] = 100; b[0,0,:,1] = 120
b[0,0,:,2] = 260; b[0,0,:,3] = 320
video = pipe(
prompt="...",
height=H, width=W, num_frames=F,
instance_ids=instance_ids,
instance_class_text=instance_class_text,
instance_state_texts=instance_state_texts,
instance_state_weights=w,
instance_bboxes=b,
)
```
---
## 6. 代码入口索引
- Pipeline API / 文本编码:
- `diffsynth/pipelines/wan_video_statemachine.py`
- `WanVideoPipeline.__call__`
- `WanVideoUnit_InstanceStateTextEmbedder`
- `model_fn_wan_video`
- Instance-aware DiT:
- `diffsynth/models/wan_video_dit_instance.py`
- `InstanceFeatureExtractor`
- `MaskGuidedCrossAttention`
- `WanModel.forward(... instance_*)`