Upload folder using huggingface_hub
Browse files- .gitattributes +3 -0
- README.md +30 -8
- aoti.py +35 -0
- app.py +995 -0
- kill_bill.jpeg +3 -0
- lora_loader.py +127 -0
- model/loss.py +128 -0
- model/pytorch_msssim/__init__.py +198 -0
- model/warplayer.py +24 -0
- packages.txt +1 -0
- requirements.txt +16 -0
- wan22_input_2.jpg +3 -0
- wan_controlnet.py +284 -0
- wan_i2v_input.JPG +3 -0
- wan_t2v_controlnet_pipeline.py +798 -0
- wan_teacache.py +78 -0
- wan_transformer.py +135 -0
- workflows/sam2.1_optimized.json +0 -0
- workflows/sam_optimized.json +0 -0
- workflows/vace_optimized.json +1043 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
kill_bill.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
wan22_input_2.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
wan_i2v_input.JPG filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,14 +1,36 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
|
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
---
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: WAN 2.2 3-Step V2V Pipeline
|
| 3 |
+
emoji: 🎬
|
| 4 |
+
colorFrom: purple
|
| 5 |
+
colorTo: blue
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 5.44.1
|
| 8 |
+
python_version: "3.10"
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
| 11 |
+
short_description: I2V + T2V + 3-Step V2V (SAM2 → Composite → VACE)
|
| 12 |
+
models:
|
| 13 |
+
- facebook/sam2.1-hiera-large
|
| 14 |
+
- google/umt5-xxl
|
| 15 |
+
- Kijai/WanVideo_comfy
|
| 16 |
+
- linoyts/Wan2.2-T2V-A14B-Diffusers-BF16
|
| 17 |
+
- lkzd7/WAN2.2_LoraSet_NSFW
|
| 18 |
+
- r3gm/RIFE
|
| 19 |
+
- TestOrganizationPleaseIgnore/WAMU_v2_WAN2.2_I2V_LIGHTNING
|
| 20 |
+
- Wan-AI/Wan2.1-VACE-14B-diffusers
|
| 21 |
+
- Wan-AI/Wan2.2-T2V-A14B-Diffusers
|
| 22 |
+
- zerogpu-aoti/Wan2
|
| 23 |
---
|
| 24 |
|
| 25 |
+
# WAN 2.2 Multi-Task Video Generation
|
| 26 |
+
|
| 27 |
+
## Features
|
| 28 |
+
- **I2V**: Image-to-Video (Lightning 14B, 6-step, FP8+AoT)
|
| 29 |
+
- **T2V**: Text-to-Video (Lightning 14B, 4-step, Lightning LoRA)
|
| 30 |
+
- **V2V**: 3-Step Video-to-Video Pipeline
|
| 31 |
+
1. **SAM2 Segmentation**: Click points on first frame → auto-track through video → mask video
|
| 32 |
+
2. **Composite + GrowMask**: Original + mask → expanded mask + composite video (automatic)
|
| 33 |
+
3. **VACE Generation**: Composite + grown mask + reference image + prompt → final video
|
| 34 |
+
|
| 35 |
+
## V2V Workflow
|
| 36 |
+
Based on ComfyUI workflows: `sam_optimized`, `sam2.1_optimized`, `vace_optimized`
|
aoti.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
"""
|
| 3 |
+
|
| 4 |
+
from typing import cast
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from huggingface_hub import hf_hub_download
|
| 8 |
+
from spaces.zero.torch.aoti import ZeroGPUCompiledModel
|
| 9 |
+
from spaces.zero.torch.aoti import ZeroGPUWeights
|
| 10 |
+
from torch._functorch._aot_autograd.subclass_parametrization import unwrap_tensor_subclass_parameters
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _shallow_clone_module(module: torch.nn.Module) -> torch.nn.Module:
|
| 14 |
+
clone = object.__new__(module.__class__)
|
| 15 |
+
clone.__dict__ = module.__dict__.copy()
|
| 16 |
+
clone._parameters = module._parameters.copy()
|
| 17 |
+
clone._buffers = module._buffers.copy()
|
| 18 |
+
clone._modules = {k: _shallow_clone_module(v) for k, v in module._modules.items() if v is not None}
|
| 19 |
+
return clone
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def aoti_blocks_load(module: torch.nn.Module, repo_id: str, variant: str | None = None):
    """Swap each repeated transformer block's forward for an AoT-compiled kernel.

    One ``package.pt2`` artifact is downloaded per repeated-block class name
    from *repo_id* (subfolder optionally suffixed with ``.{variant}``).  Every
    block instance of that class then gets its ``forward`` rebound to a
    ZeroGPU compiled model backed by the block's own weights, shallow-cloned
    and unwrapped from tensor subclasses so the compiled artifact can load them.
    """
    block_names = cast(list[str], module._repeated_blocks)
    packages = {}
    for name in block_names:
        subfolder = name if variant is None else f'{name}.{variant}'
        packages[name] = hf_hub_download(
            repo_id=repo_id,
            filename='package.pt2',
            subfolder=subfolder,
        )
    for name, package_path in packages.items():
        for candidate in module.modules():
            if candidate.__class__.__name__ != name:
                continue
            detached = _shallow_clone_module(candidate)
            unwrap_tensor_subclass_parameters(detached)
            weights = ZeroGPUWeights(detached.state_dict())
            candidate.forward = ZeroGPUCompiledModel(package_path, weights)
|
app.py
ADDED
|
@@ -0,0 +1,995 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
WAN 2.2 Multi-Task Video Generation - 3-Step V2V Pipeline
|
| 3 |
+
I2V: Lightning 14B (6 steps, FP8+AoT)
|
| 4 |
+
T2V: Lightning 14B (4 steps, Lightning LoRA + FP8)
|
| 5 |
+
V2V: 3-Step Pipeline (SAM2 → Composite → VACE)
|
| 6 |
+
Step 1: SAM2 video segmentation (click points → mask video)
|
| 7 |
+
Step 2: ImageComposite (original + mask → composite video)
|
| 8 |
+
Step 3: VACE generation (composite + grown mask + ref image + prompt → final)
|
| 9 |
+
LoRA: from lkzd7/WAN2.2_LoraSet_NSFW (I2V only)
|
| 10 |
+
"""
|
| 11 |
+
import os
|
| 12 |
+
|
| 13 |
+
import spaces
|
| 14 |
+
import shutil
|
| 15 |
+
import subprocess
|
| 16 |
+
import copy
|
| 17 |
+
import random
|
| 18 |
+
import tempfile
|
| 19 |
+
import warnings
|
| 20 |
+
import time
|
| 21 |
+
import gc
|
| 22 |
+
import uuid
|
| 23 |
+
from tqdm import tqdm
|
| 24 |
+
|
| 25 |
+
import cv2
|
| 26 |
+
import numpy as np
|
| 27 |
+
import torch
|
| 28 |
+
from torch.nn import functional as F
|
| 29 |
+
from PIL import Image, ImageFilter
|
| 30 |
+
|
| 31 |
+
import gradio as gr
|
| 32 |
+
from diffusers import (
|
| 33 |
+
AutoencoderKLWan,
|
| 34 |
+
FlowMatchEulerDiscreteScheduler,
|
| 35 |
+
WanPipeline,
|
| 36 |
+
SASolverScheduler,
|
| 37 |
+
DEISMultistepScheduler,
|
| 38 |
+
DPMSolverMultistepInverseScheduler,
|
| 39 |
+
UniPCMultistepScheduler,
|
| 40 |
+
DPMSolverMultistepScheduler,
|
| 41 |
+
DPMSolverSinglestepScheduler,
|
| 42 |
+
)
|
| 43 |
+
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
|
| 44 |
+
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
|
| 45 |
+
from diffusers.pipelines.wan.pipeline_wan_vace import WanVACEPipeline
|
| 46 |
+
from diffusers.utils.export_utils import export_to_video
|
| 47 |
+
from diffusers.utils import load_video
|
| 48 |
+
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, Int8WeightOnlyConfig
|
| 49 |
+
import aoti
|
| 50 |
+
import lora_loader
|
| 51 |
+
|
| 52 |
+
# SAM2 for video mask generation
|
| 53 |
+
from sam2.sam2_video_predictor import SAM2VideoPredictor
|
| 54 |
+
|
| 55 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
| 56 |
+
warnings.filterwarnings("ignore")
|
| 57 |
+
|
| 58 |
+
def clear_vram():
    """Free as much GPU memory as possible: run GC first so dead tensors are
    collected, then release PyTorch's cached CUDA allocations back to the driver."""
    gc.collect()
    torch.cuda.empty_cache()
|
| 61 |
+
|
| 62 |
+
# ============ RIFE ============
|
| 63 |
+
get_timestamp_js = """
|
| 64 |
+
function() {
|
| 65 |
+
const video = document.querySelector('#generated-video video');
|
| 66 |
+
if (video) { return video.currentTime; }
|
| 67 |
+
return 0;
|
| 68 |
+
}
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
def extract_frame(video_path, timestamp):
    """Grab the frame nearest *timestamp* (seconds) from a video file.

    Returns the frame as an RGB numpy array, or None when the path is empty,
    the video cannot be opened, has no frames, or the read fails.
    """
    if not video_path:
        return None
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None
    # BUG FIX: the original used cv2.CAP_FPS and cv2.CAP_FRAME_COUNT, which do
    # not exist in OpenCV (AttributeError at runtime); the correct property
    # identifiers are cv2.CAP_PROP_FPS and cv2.CAP_PROP_FRAME_COUNT.
    fps = cap.get(cv2.CAP_PROP_FPS) or 16  # fall back when metadata lacks FPS
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total_frames <= 0:
        # Empty/unreadable stream: avoid seeking to frame -1 below.
        cap.release()
        return None
    target_frame_num = int(float(timestamp) * fps)
    # Clamp the target into [0, total_frames - 1].
    if target_frame_num >= total_frames:
        target_frame_num = total_frames - 1
    if target_frame_num < 0:
        target_frame_num = 0
    cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame_num)
    ret, frame = cap.read()
    cap.release()
    if ret:
        # OpenCV decodes BGR; callers expect RGB.
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return None
|
| 88 |
+
|
| 89 |
+
if not os.path.exists("RIFEv4.26_0921.zip"):
|
| 90 |
+
print("Downloading RIFE Model...")
|
| 91 |
+
subprocess.run(["wget", "-q", "https://huggingface.co/r3gm/RIFE/resolve/main/RIFEv4.26_0921.zip", "-O", "RIFEv4.26_0921.zip"], check=True)
|
| 92 |
+
subprocess.run(["unzip", "-o", "RIFEv4.26_0921.zip"], check=True)
|
| 93 |
+
|
| 94 |
+
from train_log.RIFE_HDv3 import Model
|
| 95 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 96 |
+
rife_model = Model()
|
| 97 |
+
rife_model.load_model("train_log", -1)
|
| 98 |
+
rife_model.eval()
|
| 99 |
+
|
| 100 |
+
@torch.no_grad()
def interpolate_bits(frames_np, multiplier=2, scale=1.0):
    """RIFE frame interpolation: insert (multiplier - 1) frames between each pair.

    frames_np: list of H×W×C float numpy frames, or a T×H×W×C ndarray
               (assumed float in the model's expected range — TODO confirm
               against callers).
    multiplier: output frame-count multiplier; values < 2 return the input
                unchanged (ndarray input is normalized to a list).
    scale: RIFE inference scale; also shrinks the padding granularity.
    Returns a list of H×W×C float32 numpy frames.

    NOTE(review): if T == 1 the main loop never runs and `del I0` would raise
    NameError — callers appear to always pass multi-frame clips.
    """
    if isinstance(frames_np, list):
        T = len(frames_np)
        H, W, C = frames_np[0].shape
    else:
        T, H, W, C = frames_np.shape
    if multiplier < 2:
        return list(frames_np) if isinstance(frames_np, np.ndarray) else frames_np
    n_interp = multiplier - 1
    # RIFE requires spatial dims padded up to a multiple of this granularity.
    tmp = max(128, int(128 / scale))
    ph = ((H - 1) // tmp + 1) * tmp
    pw = ((W - 1) // tmp + 1) * tmp
    padding = (0, pw - W, 0, ph - H)
    def to_tensor(frame_np):
        # H×W×C numpy -> padded 1×C×ph×pw half-precision tensor on `device`.
        t = torch.from_numpy(frame_np).to(device)
        t = t.permute(2, 0, 1).unsqueeze(0)
        return F.pad(t, padding).half()
    def from_tensor(tensor):
        # Inverse of to_tensor: crop away padding, return H×W×C float32 numpy.
        t = tensor[0, :, :H, :W]
        return t.permute(1, 2, 0).float().cpu().numpy()
    def make_inference(I0, I1, n):
        # Produce n in-between frames. Newer RIFE (>= 3.9) takes an explicit
        # timestep per call; older versions bisect recursively.
        if rife_model.version >= 3.9:
            return [rife_model.inference(I0, I1, (i+1) * 1. / (n+1), scale) for i in range(n)]
        else:
            middle = rife_model.inference(I0, I1, scale)
            if n == 1: return [middle]
            first_half = make_inference(I0, middle, n//2)
            second_half = make_inference(middle, I1, n//2)
            return [*first_half, middle, *second_half] if n % 2 else [*first_half, *second_half]
    output_frames = []
    I1 = to_tensor(frames_np[0])
    with tqdm(total=T-1, desc="Interpolating", unit="frame") as pbar:
        for i in range(T - 1):
            I0 = I1
            output_frames.append(from_tensor(I0))
            I1 = to_tensor(frames_np[i+1])
            for mid in make_inference(I0, I1, n_interp):
                output_frames.append(from_tensor(mid))
            # Batch progress updates every 50 frames to reduce tqdm overhead.
            if (i + 1) % 50 == 0:
                pbar.update(50)
        pbar.update((T-1) % 50)
    output_frames.append(from_tensor(I1))
    del I0, I1
    torch.cuda.empty_cache()
    return output_frames
|
| 146 |
+
|
| 147 |
+
# ============ Config ============
|
| 148 |
+
FIXED_FPS = 16
|
| 149 |
+
MAX_FRAMES_MODEL = 241 # ~15s@16fps, requires more VRAM/time
|
| 150 |
+
MAX_SEED = np.iinfo(np.int32).max
|
| 151 |
+
|
| 152 |
+
SCHEDULER_MAP = {
|
| 153 |
+
"FlowMatchEulerDiscrete": FlowMatchEulerDiscreteScheduler,
|
| 154 |
+
"SASolver": SASolverScheduler,
|
| 155 |
+
"DEISMultistep": DEISMultistepScheduler,
|
| 156 |
+
"DPMSolverMultistepInverse": DPMSolverMultistepInverseScheduler,
|
| 157 |
+
"UniPCMultistep": UniPCMultistepScheduler,
|
| 158 |
+
"DPMSolverMultistep": DPMSolverMultistepScheduler,
|
| 159 |
+
"DPMSolverSinglestep": DPMSolverSinglestepScheduler,
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
default_negative_prompt = (
|
| 163 |
+
"Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, "
|
| 164 |
+
"still image, overall gray, worst quality, low quality, JPEG artifacts, ugly, incomplete, "
|
| 165 |
+
"extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, "
|
| 166 |
+
"malformed limbs, fused fingers, still frame, messy background, three legs, "
|
| 167 |
+
"many people in background, walking backwards, watermark, text, signature"
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
# ============ Load I2V Pipeline (Lightning, AoT compiled) ============
|
| 171 |
+
print("Loading I2V Pipeline (Lightning 14B)...")
|
| 172 |
+
i2v_pipe = WanImageToVideoPipeline.from_pretrained(
|
| 173 |
+
"TestOrganizationPleaseIgnore/WAMU_v2_WAN2.2_I2V_LIGHTNING",
|
| 174 |
+
torch_dtype=torch.bfloat16,
|
| 175 |
+
).to('cuda')
|
| 176 |
+
i2v_original_scheduler = copy.deepcopy(i2v_pipe.scheduler)
|
| 177 |
+
|
| 178 |
+
quantize_(i2v_pipe.text_encoder, Int8WeightOnlyConfig())
|
| 179 |
+
major, minor = torch.cuda.get_device_capability()
|
| 180 |
+
supports_fp8 = (major > 8) or (major == 8 and minor >= 9)
|
| 181 |
+
if supports_fp8:
|
| 182 |
+
quantize_(i2v_pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
|
| 183 |
+
quantize_(i2v_pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
|
| 184 |
+
aoti.aoti_blocks_load(i2v_pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
|
| 185 |
+
aoti.aoti_blocks_load(i2v_pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
|
| 186 |
+
else:
|
| 187 |
+
quantize_(i2v_pipe.transformer, Int8WeightOnlyConfig())
|
| 188 |
+
quantize_(i2v_pipe.transformer_2, Int8WeightOnlyConfig())
|
| 189 |
+
|
| 190 |
+
# ============ T2V Pipeline (on-demand, 14B + Wan22 Lightning LoRA) ============
|
| 191 |
+
# Use T2V-A14B + Wan22 Lightning LoRA (separate HIGH/LOW for dual transformer)
|
| 192 |
+
# Load on-demand with CPU offload to avoid OOM alongside I2V
|
| 193 |
+
T2V_MODEL_ID = "Wan-AI/Wan2.2-T2V-A14B-Diffusers"
|
| 194 |
+
T2V_LORA_REPO = "Kijai/WanVideo_comfy"
|
| 195 |
+
T2V_LORA_HIGH = "LoRAs/Wan22-Lightning/Wan22_A14B_T2V_HIGH_Lightning_4steps_lora_250928_rank128_fp16.safetensors"
|
| 196 |
+
T2V_LORA_LOW = "LoRAs/Wan22-Lightning/Wan22_A14B_T2V_LOW_Lightning_4steps_lora_250928_rank64_fp16.safetensors"
|
| 197 |
+
t2v_pipe = None
|
| 198 |
+
t2v_ready = False
|
| 199 |
+
|
| 200 |
+
def load_t2v_pipeline():
    """Load the T2V 14B pipeline + Lightning LoRAs on demand, with CPU offload.

    Evicts the I2V pipeline to CPU first, then builds a WanPipeline with both
    BF16 transformers, fuses the HIGH Lightning LoRA into `transformer` and
    the LOW one into `transformer_2`, and enables model CPU offload so only
    one component occupies the GPU at a time.  The built pipeline is cached
    in the module-level `t2v_pipe` and reused on subsequent calls.
    """
    global t2v_pipe, t2v_ready

    if t2v_pipe is not None and t2v_ready:
        print("T2V pipeline reused from memory")
        return t2v_pipe

    print("Loading T2V Pipeline (14B + Lightning LoRA) first time...")

    # Move I2V components to CPU to make room on the GPU.
    i2v_pipe.to('cpu')
    clear_vram()

    # VAE is kept in float32 for decode quality; transformers in bfloat16.
    t2v_vae = AutoencoderKLWan.from_pretrained(T2V_MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
    t2v_pipe = WanPipeline.from_pretrained(
        T2V_MODEL_ID,
        transformer=WanTransformer3DModel.from_pretrained(
            'linoyts/Wan2.2-T2V-A14B-Diffusers-BF16',
            subfolder='transformer',
            torch_dtype=torch.bfloat16,
        ),
        transformer_2=WanTransformer3DModel.from_pretrained(
            'linoyts/Wan2.2-T2V-A14B-Diffusers-BF16',
            subfolder='transformer_2',
            torch_dtype=torch.bfloat16,
        ),
        vae=t2v_vae,
        torch_dtype=torch.bfloat16,
    )

    # Load and fuse Lightning LoRAs (HIGH for transformer, LOW for transformer_2).
    print("Fusing Lightning LoRA HIGH (transformer)...")
    # NOTE(review): `load_file` is imported but never used in this function.
    from safetensors.torch import load_file
    from huggingface_hub import hf_hub_download

    # Download LoRA files from the hub (cached locally by hf_hub_download).
    high_path = hf_hub_download(T2V_LORA_REPO, T2V_LORA_HIGH)
    low_path = hf_hub_download(T2V_LORA_REPO, T2V_LORA_LOW)

    # Load HIGH LoRA into transformer, fuse permanently, then drop the adapter
    # so the second adapter load starts from a clean state.
    t2v_pipe.load_lora_weights(high_path, adapter_name="lightning_high")
    t2v_pipe.set_adapters(["lightning_high"], adapter_weights=[1.0])
    t2v_pipe.fuse_lora(adapter_names=["lightning_high"], lora_scale=1.0, components=["transformer"])
    t2v_pipe.unload_lora_weights()

    # Load LOW LoRA into transformer_2 and fuse the same way.
    print("Fusing Lightning LoRA LOW (transformer_2)...")
    t2v_pipe.load_lora_weights(low_path, adapter_name="lightning_low", load_into_transformer_2=True)
    t2v_pipe.set_adapters(["lightning_low"], adapter_weights=[1.0])
    t2v_pipe.fuse_lora(adapter_names=["lightning_low"], lora_scale=1.0, components=["transformer_2"])
    t2v_pipe.unload_lora_weights()

    # Use model CPU offload — only one component on GPU at a time.
    t2v_pipe.enable_model_cpu_offload()

    t2v_ready = True
    print("T2V pipeline ready (14B + Lightning + CPU offload)")
    return t2v_pipe
|
| 259 |
+
|
| 260 |
+
def unload_t2v_pipeline():
    """Hand the GPU back to the I2V pipeline after a T2V run.

    The T2V pipeline itself stays cached (it already uses CPU offload), so
    only the VRAM cache is cleared before the I2V pipeline moves back.
    """
    clear_vram()
    i2v_pipe.to('cuda')
    print("I2V restored to GPU")
|
| 265 |
+
|
| 266 |
+
# Keep cache for on-demand T2V loading
|
| 267 |
+
|
| 268 |
+
# ============ SAM2 Video Segmentation ============
|
| 269 |
+
sam2_predictor = None
|
| 270 |
+
|
| 271 |
+
def get_sam2_predictor():
    """Lazily construct and cache the module-level SAM2 video predictor."""
    global sam2_predictor
    if sam2_predictor is not None:
        return sam2_predictor
    print("Loading SAM2.1 hiera-large...")
    sam2_predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2.1-hiera-large")
    print("SAM2 loaded")
    return sam2_predictor
|
| 278 |
+
|
| 279 |
+
def extract_first_frame_from_video(video_path):
    """Return the video's first frame as a PIL RGB image, or None on failure."""
    capture = cv2.VideoCapture(video_path)
    ok, bgr_frame = capture.read()
    capture.release()
    if not ok:
        return None
    rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb_frame)
|
| 287 |
+
|
| 288 |
+
def video_to_frames_dir(video_path, max_frames=None):
    """Dump video frames as zero-padded JPEGs into a fresh temp dir (for SAM2).

    Returns (frames_dir, frame_count, fps); fps falls back to 16 when the
    container reports none.
    """
    frames_dir = tempfile.mkdtemp(prefix="sam2_frames_")
    capture = cv2.VideoCapture(video_path)
    fps = capture.get(cv2.CAP_PROP_FPS) or 16
    written = 0
    ok, frame = capture.read()
    while ok:
        if max_frames and written >= max_frames:
            break
        cv2.imwrite(os.path.join(frames_dir, f"{written:05d}.jpg"), frame)
        written += 1
        ok, frame = capture.read()
    capture.release()
    print(f"Extracted {written} frames to {frames_dir} (fps={fps:.1f})")
    return frames_dir, written, fps
|
| 305 |
+
|
| 306 |
+
@spaces.GPU(duration=120)
def generate_mask_video(video_path, points_json, num_frames_limit=None):
    """Generate a grayscale mask video with SAM2 from user-clicked points.

    video_path: source video to segment.
    points_json: JSON list of {"x", "y", "label"} dicts clicked on frame 0
                 (label 1 = foreground, anything else = background).
    num_frames_limit: optional cap on how many frames to extract/track.
    Returns the path of an mp4 where each frame is a 0/255 mask.
    Raises gr.Error (bilingual messages) on missing video or points.
    """
    import json

    if not video_path:
        raise gr.Error("请先上传视频 / Upload a video first")
    if not points_json or points_json.strip() == "[]":
        raise gr.Error("请在视频第一帧上点击要编辑的区域 / Click on the area to edit")

    points_data = json.loads(points_json)
    if not points_data:
        raise gr.Error("没有标记点 / No points marked")

    # Extract frames to a temp dir — SAM2's init_state consumes a frame folder.
    frames_dir, total_frames, fps = video_to_frames_dir(video_path, max_frames=num_frames_limit)

    predictor = get_sam2_predictor()

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        state = predictor.init_state(video_path=frames_dir)

        # Split clicks into positive/negative prompts (all anchored on frame 0).
        pos_points = []
        neg_points = []
        for p in points_data:
            if p.get("label", 1) == 1:
                pos_points.append([p["x"], p["y"]])
            else:
                neg_points.append([p["x"], p["y"]])

        all_points = pos_points + neg_points
        all_labels = [1] * len(pos_points) + [0] * len(neg_points)

        points_np = np.array(all_points, dtype=np.float32)
        labels_np = np.array(all_labels, dtype=np.int32)

        _, _, _ = predictor.add_new_points_or_box(
            state,
            frame_idx=0,
            obj_id=1,
            points=points_np,
            labels=labels_np,
        )

        # Propagate the single object's mask through the whole clip.
        all_masks = {}
        for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
            # masks shape: (num_objects, 1, H, W); threshold logits at 0.
            mask = (masks[0, 0] > 0.0).cpu().numpy().astype(np.uint8) * 255
            all_masks[frame_idx] = mask

    # Build the mask video (grayscale writer).
    out_path = os.path.join(tempfile.mkdtemp(), "mask_video.mp4")
    # Frame size comes from the first tracked mask.
    first_mask = all_masks[0]
    h, w = first_mask.shape
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h), isColor=False)
    for i in range(total_frames):
        if i in all_masks:
            writer.write(all_masks[i])
        elif all_masks:
            # Frames SAM2 skipped reuse the nearest tracked mask by index.
            nearest = min(all_masks.keys(), key=lambda k: abs(k - i))
            writer.write(all_masks[nearest])
    writer.release()

    # Remove the extracted-frames scratch directory.
    shutil.rmtree(frames_dir, ignore_errors=True)

    print(f"Mask video generated: {out_path} ({total_frames} frames, {w}x{h})")
    return out_path
|
| 379 |
+
|
| 380 |
+
# ============ Step 2: GrowMask + ImageComposite (from sam2.1_optimized workflow) ============
|
| 381 |
+
def grow_mask_frame(mask_gray, expand_pixels=5, blur=True):
    """Dilate a binary mask by *expand_pixels* (ComfyUI GrowMask equivalent).

    mask_gray: uint8 H×W array (255 = mask, 0 = background); returned
    unchanged when expand_pixels <= 0.  With *blur* set, the dilated mask is
    Gaussian-smoothed and then re-thresholded back to hard 0/255 values.
    """
    if expand_pixels <= 0:
        return mask_gray
    diameter = expand_pixels * 2 + 1
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (diameter, diameter))
    expanded = cv2.dilate(mask_gray, ellipse, iterations=1)
    if blur:
        expanded = cv2.GaussianBlur(expanded, (diameter, diameter), 0)
        # Threshold back to binary so downstream compositing sees hard edges.
        _, expanded = cv2.threshold(expanded, 127, 255, cv2.THRESH_BINARY)
    return expanded
|
| 395 |
+
|
| 396 |
+
def grow_mask_video_file(mask_video_path, expand_pixels=5):
    """Apply grow_mask_frame to every frame of a mask video.

    Returns the path of a new grayscale mp4, or the input path unchanged when
    expand_pixels <= 0 (nothing to grow).
    """
    if expand_pixels <= 0:
        return mask_video_path

    reader = cv2.VideoCapture(mask_video_path)
    fps = reader.get(cv2.CAP_PROP_FPS) or 16
    width = int(reader.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(reader.get(cv2.CAP_PROP_FRAME_HEIGHT))

    out_path = os.path.join(tempfile.mkdtemp(), "grown_mask.mp4")
    writer = cv2.VideoWriter(
        out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height), isColor=False
    )

    count = 0
    ok, frame = reader.read()
    while ok:
        # Mask videos may decode as 3-channel even though they are grayscale.
        gray = frame if len(frame.shape) != 3 else cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        writer.write(grow_mask_frame(gray, expand_pixels))
        count += 1
        ok, frame = reader.read()

    reader.release()
    writer.release()
    print(f"GrowMask applied: {count} frames, expand={expand_pixels}px → {out_path}")
    return out_path
|
| 424 |
+
|
| 425 |
+
def composite_video_from_mask(source_video_path, mask_video_path):
    """ImageComposite: paint the masked region white over the source video
    (from the sam2.1_optimized workflow).

    Creates a composite video where:
    - Masked regions (white in mask) become solid white
    - Unmasked regions show the original video
    This produces the control_video input VACE expects.
    Returns: composite video path
    """
    src_cap = cv2.VideoCapture(source_video_path)
    mask_cap = cv2.VideoCapture(mask_video_path)

    # Output geometry/fps follow the SOURCE video, not the mask.
    fps = src_cap.get(cv2.CAP_PROP_FPS) or 16
    w = int(src_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(src_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    out_path = os.path.join(tempfile.mkdtemp(), "composite.mp4")
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h))

    count = 0
    while True:
        # Read source and mask in lockstep; the source stream drives the loop.
        ret_s, src_frame = src_cap.read()
        ret_m, mask_frame = mask_cap.read()
        if not ret_s:
            break
        if not ret_m:
            # Mask video ran out: fall back to an all-black (empty) mask, so
            # the remaining frames pass through unmodified.
            mask_gray = np.zeros((h, w), dtype=np.uint8)
        else:
            # Resize mask to match source if needed (nearest keeps hard edges).
            if mask_frame.shape[:2] != (h, w):
                mask_frame = cv2.resize(mask_frame, (w, h), interpolation=cv2.INTER_NEAREST)
            mask_gray = cv2.cvtColor(mask_frame, cv2.COLOR_BGR2GRAY) if len(mask_frame.shape) == 3 else mask_frame

        # Composite: original where mask=0, white where mask=255.
        mask_bool = mask_gray > 127
        composite = src_frame.copy()
        composite[mask_bool] = 255  # White in masked region

        writer.write(composite)
        count += 1

    src_cap.release()
    mask_cap.release()
    writer.release()
    print(f"Composite video: {count} frames → {out_path}")
    return out_path
|
| 472 |
+
|
| 473 |
+
# ============ V2V Pipeline (VACE 14B, on-demand) ============
|
| 474 |
+
VACE_MODEL_ID = "Wan-AI/Wan2.1-VACE-14B-diffusers"
|
| 475 |
+
v2v_pipe = None
|
| 476 |
+
v2v_ready = False
|
| 477 |
+
|
| 478 |
+
def load_v2v_pipeline():
    """Load VACE 14B pipeline on-demand for mask-based video editing.

    Frees GPU memory by parking the always-resident I2V pipeline on CPU,
    then either restores the cached VACE pipeline to CUDA or builds it
    from scratch (download + quantization) on first use.

    Returns:
        The ready-to-use WanVACEPipeline on CUDA.
    """
    global v2v_pipe, v2v_ready

    # Move I2V to CPU to free GPU — both 14B models cannot coexist on one card.
    i2v_pipe.to('cpu')
    clear_vram()

    # Fast path: pipeline was already built earlier in this process.
    if v2v_pipe is not None and v2v_ready:
        v2v_pipe.to('cuda')
        print("VACE pipeline restored to GPU")
        return v2v_pipe

    print("Loading VACE 14B Pipeline first time (this downloads ~75GB)...")

    # VAE kept in fp32 for decode quality; transformer/text encoder in bf16.
    v2v_vae = AutoencoderKLWan.from_pretrained(VACE_MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
    v2v_pipe = WanVACEPipeline.from_pretrained(
        VACE_MODEL_ID,
        vae=v2v_vae,
        torch_dtype=torch.bfloat16,
    )
    v2v_pipe.scheduler = UniPCMultistepScheduler.from_config(v2v_pipe.scheduler.config, flow_shift=5.0)

    # Quantize to fit in A100 80GB.
    quantize_(v2v_pipe.text_encoder, Int8WeightOnlyConfig())
    # fp8 dynamic-activation quantization needs compute capability >= 8.9
    # (Ada/Hopper); older GPUs fall back to int8 weight-only.
    major, minor = torch.cuda.get_device_capability()
    if (major > 8) or (major == 8 and minor >= 9):
        quantize_(v2v_pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
    else:
        quantize_(v2v_pipe.transformer, Int8WeightOnlyConfig())

    v2v_pipe.to('cuda')

    v2v_ready = True
    print("VACE 14B pipeline ready (quantized, on GPU)")
    return v2v_pipe
|
| 514 |
+
|
| 515 |
+
def unload_v2v_pipeline():
    """Move V2V to CPU and restore I2V to GPU.

    Inverse of load_v2v_pipeline: the VACE pipeline is kept (on CPU) for
    reuse, and the default I2V pipeline regains the GPU.
    """
    global v2v_pipe
    if v2v_pipe is not None:
        v2v_pipe.to('cpu')
        clear_vram()
    # I2V is the default resident pipeline; bring it back regardless.
    i2v_pipe.to('cuda')
    print("VACE → CPU, I2V → GPU")
|
| 523 |
+
|
| 524 |
+
def load_video_frames_and_masks(video_path, mask_path, num_frames, target_h, target_w):
    """Load source video frames and mask video frames for VACE.

    Both sequences are capped at num_frames, padded to equal length by
    repeating the last available frame, and returned as parallel lists.
    target_h/target_w are only used for placeholder frames when a clip
    is empty — the pipeline itself does any resizing.
    """
    # Source clip as PIL images, truncated to the requested frame budget.
    src_frames = load_video(video_path)[:num_frames]
    print(f"Loaded {len(src_frames)} source frames (original size: {src_frames[0].size if src_frames else 'N/A'})")

    # Mask clip, converted to single-channel (white=edit, black=keep);
    # no resize here — the pipeline handles it.
    masks = [mask_frame.convert("L") for mask_frame in load_video(mask_path)[:num_frames]]
    print(f"Loaded {len(masks)} mask frames")

    # Pad the shorter list by repeating its last frame; fall back to an
    # all-black mask / mid-grey frame when a list is empty.
    while len(masks) < len(src_frames):
        masks.append(masks[-1] if masks else Image.new("L", src_frames[0].size, 0))
    while len(src_frames) < len(masks):
        src_frames.append(src_frames[-1] if src_frames else Image.new("RGB", (target_w, target_h), (128, 128, 128)))

    # Trim to a common length (defensive — the loops above already equalize).
    frame_count = min(len(src_frames), len(masks))
    return src_frames[:frame_count], masks[:frame_count]
|
| 551 |
+
|
| 552 |
+
# ============ Utils ============
def resize_image(image, max_dim=832, min_dim=480, square_dim=640, multiple_of=16):
    """Fit *image* into the model's supported resolution envelope.

    Square images become square_dim × square_dim. Otherwise the image is
    centre-cropped if its aspect ratio exceeds max_dim:min_dim (either
    orientation), then resized so both sides land on a multiple of
    *multiple_of* clamped to [min_dim, max_dim].
    """
    w, h = image.size

    # Square inputs get a fixed square output.
    if w == h:
        return image.resize((square_dim, square_dim), Image.LANCZOS)

    ar = w / h
    widest = max_dim / min_dim    # widest aspect ratio the model accepts
    tallest = min_dim / max_dim   # tallest (portrait) aspect ratio accepted

    if ar > widest:
        # Too wide: centre-crop horizontally down to the widest ratio.
        crop_w = int(round(h * widest))
        x0 = (w - crop_w) // 2
        image = image.crop((x0, 0, x0 + crop_w, h))
        tw, th = max_dim, min_dim
    elif ar < tallest:
        # Too tall: centre-crop vertically down to the tallest ratio.
        crop_h = int(round(w / tallest))
        y0 = (h - crop_h) // 2
        image = image.crop((0, y0, w, y0 + crop_h))
        tw, th = min_dim, max_dim
    elif w > h:
        # Landscape within bounds: pin the long side to max_dim.
        tw = max_dim
        th = int(round(tw / ar))
    else:
        # Portrait within bounds: pin the long side to max_dim.
        th = max_dim
        tw = int(round(th * ar))

    # Snap each side to the nearest multiple_of, clamped to [min_dim, max_dim].
    final_w = max(min_dim, min(max_dim, round(tw / multiple_of) * multiple_of))
    final_h = max(min_dim, min(max_dim, round(th / multiple_of) * multiple_of))
    return image.resize((final_w, final_h), Image.LANCZOS)
|
| 580 |
+
|
| 581 |
+
def resize_and_crop_to_match(target_image, reference_image):
    """Scale *target_image* to cover *reference_image*, then centre-crop to its size."""
    ref_w, ref_h = reference_image.size
    src_w, src_h = target_image.size
    # Scale factor chosen so BOTH dimensions cover the reference ("cover" fit).
    factor = max(ref_w / src_w, ref_h / src_h)
    scaled_w, scaled_h = int(src_w * factor), int(src_h * factor)
    scaled = target_image.resize((scaled_w, scaled_h), Image.Resampling.LANCZOS)
    # Centre-crop the overhang.
    x0 = (scaled_w - ref_w) // 2
    y0 = (scaled_h - ref_h) // 2
    return scaled.crop((x0, y0, x0 + ref_w, y0 + ref_h))
|
| 589 |
+
|
| 590 |
+
def get_num_frames(duration_seconds):
    """Convert a duration in seconds to a valid WAN frame count.

    The model requires (n - 1) % 4 == 0; the result is clamped to
    [9, MAX_FRAMES_MODEL].
    """
    frames = int(round(duration_seconds * FIXED_FPS))
    frames = ((frames - 1) // 4) * 4 + 1  # snap down to the 4k+1 grid
    return int(min(max(frames, 9), MAX_FRAMES_MODEL))
|
| 594 |
+
|
| 595 |
+
def extract_video_path(input_video):
    """Normalize the many shapes Gradio may hand us into a plain file path.

    Accepts None, a str path, a Gradio 5.x dict ({'video'|'path'|'name': ...}),
    or an object exposing .video/.path/.name; anything else is stringified.
    """
    if input_video is None:
        return None
    if isinstance(input_video, str):
        return input_video
    if isinstance(input_video, dict):
        # Gradio 5.x format: {'video': filepath, ...} or {'name': filepath, ...} or {'path': filepath}
        for key in ("video", "path", "name"):
            if key in input_video:
                return input_video[key]
        return None
    # Could be a Gradio VideoData object.
    for attr in ("video", "path", "name"):
        if hasattr(input_video, attr):
            return getattr(input_video, attr)
    return str(input_video)
|
| 611 |
+
|
| 612 |
+
def extract_first_frame(video_input):
    """Return the first frame of a video as a PIL RGB image, or None on failure."""
    path = extract_video_path(video_input)
    if not path or not os.path.exists(path):
        return None
    capture = cv2.VideoCapture(path)
    ok, bgr_frame = capture.read()
    capture.release()
    if not ok:
        return None
    # OpenCV decodes as BGR; convert before wrapping in PIL.
    return Image.fromarray(cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB))
|
| 622 |
+
|
| 623 |
+
# ============ Inference ============
@spaces.GPU(duration=1200)
def run_inference(
    task_type, input_image, input_video, mask_video, prompt, negative_prompt,
    duration_seconds, steps, guidance_scale, guidance_scale_2,
    current_seed, scheduler_name, flow_shift, frame_multiplier,
    quality, last_image_input, lora_groups,
    reference_image=None, grow_pixels=5,
    progress=gr.Progress(track_tqdm=True),
):
    """Run one generation job (T2V, I2V, or 3-step V2V) on the GPU worker.

    Dispatches on *task_type* substring ("T2V" / "V2V" / default I2V),
    optionally RIFE-interpolates the result to a higher FPS, and writes
    the final clip to a temp mp4.

    Returns:
        (video_path, task_id): path of the exported mp4 and a short job id.
    """
    clear_vram()
    num_frames = get_num_frames(duration_seconds)
    task_id = str(uuid.uuid4())[:8]
    print(f"Task: {task_id}, type={task_type}, duration={duration_seconds}s, frames={num_frames}")
    start = time.time()

    if "T2V" in task_type:
        # ====== T2V: 14B + Lightning LoRA (4 steps, dual guidance) ======
        # Lightning distillation needs at least 4 steps.
        t2v_steps = max(int(steps), 4)
        print(f"T2V: steps={t2v_steps}, guidance={guidance_scale}/{guidance_scale_2}, frames={num_frames}")

        pipe = load_t2v_pipeline()
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=480,
            width=832,
            num_frames=num_frames,
            guidance_scale=float(guidance_scale),
            guidance_scale_2=float(guidance_scale_2),
            num_inference_steps=t2v_steps,
            generator=torch.Generator(device="cpu").manual_seed(int(current_seed)),
            output_type="np",
        )
        unload_t2v_pipeline()

    else:
        # ====== I2V / V2V ======
        if "V2V" in task_type:
            # ====== V2V: 3-Step Pipeline (SAM2 mask → Composite → VACE) ======
            print(f"V2V 3-Step Pipeline: input_video type={type(input_video)}, value={input_video}")
            video_path = extract_video_path(input_video)
            if not video_path or not os.path.exists(video_path):
                raise gr.Error("Upload a source video for V2V / V2V请上传原视频")

            # Get mask video path (white = region to edit).
            mask_path = extract_video_path(mask_video)
            if not mask_path or not os.path.exists(mask_path):
                raise gr.Error("Upload a mask video for V2V / V2V请上传遮罩视频(黑白视频,白色=编辑区域)")

            # Step 2a: GrowMask — expand mask boundaries (from vace_optimized workflow).
            grown_mask_path = grow_mask_video_file(mask_path, expand_pixels=int(grow_pixels))
            print(f"V2V: GrowMask applied ({grow_pixels}px)")

            # Step 2b: Composite — original video with mask overlay
            # (from sam2.1_optimized workflow). NOTE: uses the ORIGINAL
            # (ungrown) mask here; the grown mask is used only for VACE below.
            composite_path = composite_video_from_mask(video_path, mask_path)
            print(f"V2V: Composite video created")

            # Step 3: VACE generation using composite as control_video + grown mask.
            target_h, target_w = 480, 832

            # Load composite video as control frames for VACE.
            src_frames = load_video(composite_path)[:num_frames]
            print(f"Loaded {len(src_frames)} composite frames")

            # Load grown mask frames as single-channel images.
            mask_frames_raw = load_video(grown_mask_path)[:num_frames]
            masks = [mf.convert("L") for mf in mask_frames_raw]
            print(f"Loaded {len(masks)} grown mask frames")

            # Pad the shorter sequence by repeating its last frame so the
            # two lists line up one-to-one.
            while len(masks) < len(src_frames):
                masks.append(masks[-1] if masks else Image.new("L", src_frames[0].size, 0))
            while len(src_frames) < len(masks):
                src_frames.append(src_frames[-1] if src_frames else Image.new("RGB", (target_w, target_h), (128, 128, 128)))

            # Ensure num_frames satisfies (n-1) % 4 == 0 for VACE, minimum 5.
            n = len(src_frames)
            n = (n - 1) // 4 * 4 + 1
            n = max(n, 5)
            src_frames = src_frames[:n]
            masks = masks[:n]

            # Load VACE pipeline (swaps I2V off the GPU).
            pipe = load_v2v_pipeline()
            # VACE is not distilled — enforce a sane minimum step count.
            v2v_steps = max(int(steps), 20)
            print(f"V2V VACE: steps={v2v_steps}, guidance={guidance_scale}, frames={len(src_frames)}, ref_image={'yes' if reference_image else 'no'}")

            # Build VACE kwargs. NOTE(review): reference_image is logged above
            # but not passed to the pipeline here — confirm whether that is
            # intentional or an omission.
            vace_kwargs = dict(
                prompt=prompt,
                negative_prompt=negative_prompt,
                video=src_frames,
                mask=masks,
                height=target_h,
                width=target_w,
                num_frames=len(src_frames),
                guidance_scale=max(float(guidance_scale), 5.0),
                num_inference_steps=v2v_steps,
                generator=torch.Generator(device="cuda").manual_seed(int(current_seed)),
                output_type="np",
            )

            result = pipe(**vace_kwargs)
            unload_v2v_pipeline()

            # Cleanup intermediate temp files (best-effort).
            for p in [grown_mask_path, composite_path]:
                try:
                    if p and os.path.exists(p):
                        os.remove(p)
                except:
                    pass

        else:
            # ====== I2V ======
            if input_image is None:
                raise gr.Error("Upload an image / 请上传图片")

            # Swap scheduler only if the requested one differs from the current.
            scheduler_class = SCHEDULER_MAP.get(scheduler_name)
            if scheduler_class and scheduler_class.__name__ != i2v_pipe.scheduler.config._class_name:
                config = copy.deepcopy(i2v_original_scheduler.config)
                # The shift parameter is named differently across scheduler classes.
                if scheduler_class == FlowMatchEulerDiscreteScheduler:
                    config['shift'] = flow_shift
                else:
                    config['flow_shift'] = flow_shift
                i2v_pipe.scheduler = scheduler_class.from_config(config)

            # Attach any selected LoRA groups; failures are non-fatal.
            lora_loaded = False
            if lora_groups:
                try:
                    for idx, name in enumerate(lora_groups):
                        if name and name != "(None)":
                            lora_loader.load_lora_to_pipe(i2v_pipe, name, adapter_name=f"lora_{idx}")
                            lora_loaded = True
                except Exception as e:
                    print(f"LoRA warning: {e}")

            resized_image = resize_image(input_image)
            # Optional last-frame conditioning image, cropped to match.
            processed_last = None
            if last_image_input:
                processed_last = resize_and_crop_to_match(last_image_input, resized_image)

            print(f"I2V: size={resized_image.size}, steps={int(steps)}, guidance={guidance_scale}/{guidance_scale_2}")

            result = i2v_pipe(
                image=resized_image,
                last_image=processed_last,
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=resized_image.height,
                width=resized_image.width,
                num_frames=num_frames,
                guidance_scale=float(guidance_scale),
                guidance_scale_2=float(guidance_scale_2),
                num_inference_steps=int(steps),
                generator=torch.Generator(device="cuda").manual_seed(int(current_seed)),
                output_type="np",
            )

            if lora_loaded:
                lora_loader.unload_lora(i2v_pipe)

    raw_frames = result.frames[0]
    elapsed = time.time() - start
    print(f"Generation took {elapsed:.1f}s ({len(raw_frames)} frames)")

    # Optional RIFE frame interpolation: frame_multiplier is a target FPS
    # (16/32/64), so the factor is 1, 2 or 4 relative to FIXED_FPS.
    frame_factor = frame_multiplier // FIXED_FPS
    if frame_factor > 1:
        rife_model.device()
        rife_model.flownet = rife_model.flownet.half()
        final_frames = interpolate_bits(raw_frames, multiplier=int(frame_factor))
    else:
        final_frames = list(raw_frames)
    final_fps = FIXED_FPS * max(1, frame_factor)

    # Reserve a temp file name, then export the clip to it.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name
    export_to_video(final_frames, video_path, fps=final_fps, quality=quality)
    return video_path, task_id
|
| 803 |
+
|
| 804 |
+
# ============ Generate ============
def generate_video(
    task_type, input_image, input_video, mask_video, prompt,
    lora_groups, duration_seconds, frame_multiplier,
    steps, guidance_scale, guidance_scale_2,
    negative_prompt, quality, seed, randomize_seed,
    scheduler, flow_shift, last_image, display_result,
    reference_image, grow_pixels,
    progress=gr.Progress(track_tqdm=True),
):
    """Gradio click handler: validate the prompt, resolve the seed, and
    delegate to run_inference.

    Returns:
        (display_video_or_None, file_path, used_seed) — the first slot is
        None when the "Display" checkbox is off so the player stays empty.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Enter a prompt / 请输入提示词")
    # Resolve the effective seed up front so it can be reported back to the UI.
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    video_path, task_id = run_inference(
        task_type, input_image, input_video, mask_video, prompt, negative_prompt,
        duration_seconds, steps, guidance_scale, guidance_scale_2,
        current_seed, scheduler, flow_shift, frame_multiplier,
        quality, last_image, lora_groups,
        reference_image=reference_image, grow_pixels=grow_pixels,
    )
    print(f"Done: {task_id}")
    return (video_path if display_result else None), video_path, current_seed
|
| 826 |
+
|
| 827 |
+
# ============ UI ============
# Keeps the timestamp Number component present (for the JS frame-grab
# round-trip) but visually hidden.
CSS = """
#hidden-timestamp { opacity: 0; height: 0; width: 0; margin: 0; padding: 0; overflow: hidden; position: absolute; }
"""

with gr.Blocks(theme=gr.themes.Soft(), css=CSS, delete_cache=(3600, 10800)) as demo:
    gr.Markdown("## WAN 2.2 Multi-Task Video Generation / 多任务视频生成")
    gr.Markdown("#### I2V (Lightning 6-step) · T2V (Lightning 14B 4-step) · V2V (3-Step: SAM2→Composite→VACE)")
    gr.Markdown("---")

    # Task selector; the V2V-only widgets below toggle on its value.
    task_type = gr.Radio(
        choices=[
            "I2V (图生视频 / Image-to-Video)",
            "T2V (文生视频 / Text-to-Video)",
            "V2V (视频生视频 / Video-to-Video)",
        ],
        value="I2V (图生视频 / Image-to-Video)",
        label="Task Type / 任务类型",
    )

    with gr.Row():
        with gr.Column():
            with gr.Group():
                input_image = gr.Image(type="pil", label="Input Image / 输入图片 (I2V)", sources=["upload", "clipboard"])
            with gr.Group():
                input_video = gr.Video(label="Source Video / 原视频 (V2V)", sources=["upload"], visible=False, interactive=True)
            with gr.Group():
                mask_video = gr.Video(label="Mask Video / 遮罩视频 (V2V, 白色=编辑区域)", sources=["upload"], visible=False, interactive=True)
            # Static walkthrough of the 3-step V2V pipeline, shown only for V2V.
            v2v_guide = gr.Markdown(
                value="""### 📖 V2V 三步流水线 / 3-Step V2V Pipeline

**Step 1 — SAM2 分割**: 上传原视频 → 提取第一帧 → 点击标记区域 → 生成遮罩视频
**Step 2 — 自动合成**: 原视频 + 遮罩 → GrowMask扩展边界 + ImageComposite合成(自动完成)
**Step 3 — VACE 生成**: 合成视频 + 遮罩 + 参考图 + Prompt → 最终成品视频

💡 也可跳过 Step 1,直接上传自己的遮罩视频(白色=编辑区域)
""",
                visible=False,
            )
            # Interactive SAM2 mask-authoring tools (V2V only).
            with gr.Group(visible=False) as v2v_mask_tools:
                first_frame_display = gr.Image(label="第一帧预览 / First Frame (点击标记区域)", type="pil", interactive=False)
                points_store = gr.State(value=[])
                points_display = gr.Textbox(label="标记点 / Points", value="无标记 / No points", interactive=False)
                with gr.Row():
                    point_mode = gr.Radio(choices=["include (编辑)", "exclude (排除)"], value="include (编辑)", label="点击模式")
                with gr.Row():
                    extract_frame_btn = gr.Button("📷 提取第一帧 / Extract First Frame", variant="secondary")
                    gen_mask_btn = gr.Button("🎭 生成遮罩 / Generate Mask (SAM2)", variant="primary")
                    clear_points_btn = gr.Button("🗑️ 清除标记 / Clear Points")
                with gr.Accordion("🖼️ V2V 高级选项 / V2V Advanced", open=True):
                    reference_image = gr.Image(type="pil", label="参考图 / Reference Image (控制编辑区域的目标外观)", sources=["upload", "clipboard"])
                    grow_pixels_sl = gr.Slider(minimum=0, maximum=30, step=1, value=5, label="GrowMask / 遮罩扩展 (像素)", info="扩展遮罩边界,让编辑区域过渡更自然")

            prompt_input = gr.Textbox(
                label="Prompt / 提示词", value="",
                placeholder="Describe the video... / 描述你想生成的视频...", lines=3,
            )
            duration_slider = gr.Slider(
                minimum=0.5, maximum=15, step=0.5, value=3,
                label="Duration / 时长 (seconds/秒)",
                info="Max ~15s (241 frames @16fps) / 最大约15秒",
            )
            frame_multi = gr.Dropdown(choices=[16, 32, 64], value=16, label="Output FPS / 输出帧率", info="RIFE interpolation / RIFE插帧")

            with gr.Accordion("⚙️ Advanced Settings / 高级设置", open=False):
                last_image = gr.Image(type="pil", label="Last Frame / 末帧 (Optional)", sources=["upload", "clipboard"])
                negative_prompt_input = gr.Textbox(label="Negative Prompt / 负面提示词", value=default_negative_prompt, lines=3)
                with gr.Row():
                    steps_slider = gr.Slider(minimum=1, maximum=50, step=1, value=6, label="Steps / 步数", info="I2V: 4-8 | T2V: 4-8 | V2V: 25-50")
                    quality_sl = gr.Slider(minimum=1, maximum=10, step=1, value=6, label="Quality / 质量")
                with gr.Row():
                    guidance_h = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance High / 引导(高噪声)")
                    guidance_l = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance Low / 引导(低噪声)")
                with gr.Row():
                    scheduler_dd = gr.Dropdown(choices=list(SCHEDULER_MAP.keys()), value="UniPCMultistep", label="Scheduler / 调度器")
                    flow_shift_sl = gr.Slider(minimum=0.5, maximum=15.0, step=0.1, value=3.0, label="Flow Shift / 流偏移")
                with gr.Row():
                    seed_sl = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=42, label="Seed / 种子")
                    random_seed_cb = gr.Checkbox(label="Random / 随机", value=True)
                lora_dd = gr.Dropdown(choices=lora_loader.get_lora_choices(), label="LoRA (I2V only / 仅I2V)", multiselect=True, info="From WAN2.2_LoraSet_NSFW")
                display_cb = gr.Checkbox(label="Display / 显示", value=True)

            generate_btn = gr.Button("🎬 Generate / 生成视频", variant="primary", size="lg")

        with gr.Column():
            video_output = gr.Video(label="Generated Video / 生成的视频", autoplay=True, sources=["upload"], show_download_button=True, show_share_button=True, interactive=False, elem_id="generated-video")
            with gr.Row():
                grab_frame_btn = gr.Button("📸 Use Frame / 使用帧", variant="secondary")
                timestamp_box = gr.Number(value=0, label="Timestamp", visible=False, elem_id="hidden-timestamp")
            file_output = gr.File(label="Download / 下载")

    def update_task_ui(task):
        """Toggle task-specific widgets and reset step/guidance defaults.

        Returns 8 updates, matching the outputs wired below:
        (input_image, input_video, mask_video, v2v_guide, v2v_mask_tools,
         steps_slider, guidance_h, guidance_l).
        """
        is_v2v = "V2V" in task
        is_t2v = "T2V" in task
        if is_t2v:
            return (gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
                    gr.update(visible=False), gr.update(visible=False),
                    gr.update(value=4), gr.update(value=1.0), gr.update(value=1.0))
        elif is_v2v:
            return (gr.update(visible=False), gr.update(visible=True), gr.update(visible=True),
                    gr.update(visible=True), gr.update(visible=True),
                    gr.update(value=30), gr.update(value=5.0), gr.update(value=1.0))
        else:
            return (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False),
                    gr.update(visible=False), gr.update(visible=False),
                    gr.update(value=6), gr.update(value=1.0), gr.update(value=1.0))

    task_type.change(update_task_ui, inputs=[task_type], outputs=[input_image, input_video, mask_video, v2v_guide, v2v_mask_tools, steps_slider, guidance_h, guidance_l])

    # V2V mask generation callbacks
    def on_extract_first_frame(video):
        """Pull the first frame of the uploaded source video and reset points."""
        vpath = extract_video_path(video)
        if not vpath or not os.path.exists(vpath):
            raise gr.Error("请先上传视频 / Upload video first")
        frame = extract_first_frame_from_video(vpath)
        if frame is None:
            raise gr.Error("无法提取第一帧 / Failed to extract first frame")
        return frame, [], "无标记 / No points"

    def on_click_frame(img, points, mode, evt: gr.SelectData):
        """Record a click as an include (label 1) / exclude (label 0) point
        and redraw all points on top of the frame preview."""
        if img is None:
            return img, points, "请先提取第一帧 / Extract first frame first"
        x, y = evt.index
        label = 1 if "include" in mode else 0
        points.append({"x": x, "y": y, "label": label})
        # Draw points on image (green = include, red = exclude).
        display_img = img.copy()
        # NOTE(review): relies on PIL.ImageDraw already being imported
        # somewhere so the attribute exists on the PIL package — confirm.
        draw = __import__('PIL').ImageDraw.Draw(display_img)
        for p in points:
            color = (0, 255, 0) if p["label"] == 1 else (255, 0, 0)
            r = 8
            draw.ellipse([p["x"]-r, p["y"]-r, p["x"]+r, p["y"]+r], fill=color, outline="white", width=2)
        info = f"{len([p for p in points if p['label']==1])} include, {len([p for p in points if p['label']==0])} exclude"
        return display_img, points, info

    def on_clear_points(original_video):
        """Discard all marked points and restore the clean first frame."""
        vpath = extract_video_path(original_video)
        if vpath and os.path.exists(vpath):
            frame = extract_first_frame_from_video(vpath)
            return frame, [], "无标记 / No points"
        return None, [], "无标记 / No points"

    def on_generate_mask(video, points):
        """Run SAM2 over the video with the marked points; fills mask_video."""
        import json
        vpath = extract_video_path(video)
        if not vpath:
            raise gr.Error("请先上传视频 / Upload video first")
        if not points:
            raise gr.Error("请先在第一帧上点击标记 / Click on first frame to mark areas")
        mask_path = generate_mask_video(vpath, json.dumps(points))
        return mask_path

    extract_frame_btn.click(fn=on_extract_first_frame, inputs=[input_video], outputs=[first_frame_display, points_store, points_display])
    first_frame_display.select(fn=on_click_frame, inputs=[first_frame_display, points_store, point_mode], outputs=[first_frame_display, points_store, points_display])
    clear_points_btn.click(fn=on_clear_points, inputs=[input_video], outputs=[first_frame_display, points_store, points_display])
    gen_mask_btn.click(fn=on_generate_mask, inputs=[input_video, points_store], outputs=[mask_video])
    generate_btn.click(
        fn=generate_video,
        inputs=[task_type, input_image, input_video, mask_video, prompt_input, lora_dd, duration_slider, frame_multi,
                steps_slider, guidance_h, guidance_l, negative_prompt_input, quality_sl, seed_sl, random_seed_cb,
                scheduler_dd, flow_shift_sl, last_image, display_cb,
                reference_image, grow_pixels_sl],
        outputs=[video_output, file_output, seed_sl],
    )
    # Frame grab: JS reads the player's current time into the hidden Number,
    # whose change event extracts that frame back into input_image.
    grab_frame_btn.click(fn=None, inputs=None, outputs=[timestamp_box], js=get_timestamp_js)
    timestamp_box.change(fn=extract_frame, inputs=[video_output, timestamp_box], outputs=[input_image])

if __name__ == "__main__":
    demo.queue().launch(mcp_server=True, show_error=True)
|
kill_bill.jpeg
ADDED
|
Git LFS Details
|
lora_loader.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LoRA Loader for WAN 2.2 - references files from lkzd7/WAN2.2_LoraSet_NSFW
|
| 3 |
+
"""
|
| 4 |
+
import urllib.parse
|
| 5 |
+
import re
|
| 6 |
+
from huggingface_hub import hf_hub_download
|
| 7 |
+
|
| 8 |
+
LORA_REPO = "lkzd7/WAN2.2_LoraSet_NSFW"
|
| 9 |
+
HF_TOKEN = None
|
| 10 |
+
|
| 11 |
+
LORA_FILES = [
|
| 12 |
+
"Blink_Squatting_Cowgirl_Position_I2V_HIGH.safetensors",
|
| 13 |
+
"Blink_Squatting_Cowgirl_Position_I2V_LOW.safetensors",
|
| 14 |
+
"PENISLORA_22_i2v_HIGH_e320.safetensors",
|
| 15 |
+
"PENISLORA_22_i2v_LOW_e496.safetensors",
|
| 16 |
+
"Pornmaster_wan 2.2_14b_I2V_bukkake_v1.4_high_noise.safetensors",
|
| 17 |
+
"Pornmaster_wan 2.2_14b_I2V_bukkake_v1.4_low_noise.safetensors",
|
| 18 |
+
"W22_Multiscene_Photoshoot_Softcore_i2v_HN.safetensors",
|
| 19 |
+
"W22_Multiscene_Photoshoot_Softcore_i2v_LN.safetensors",
|
| 20 |
+
"WAN-2.2-I2V-Double-Blowjob-HIGH-v1.safetensors",
|
| 21 |
+
"WAN-2.2-I2V-Double-Blowjob-LOW-v1.safetensors",
|
| 22 |
+
"WAN-2.2-I2V-HandjobBlowjobCombo-HIGH-v1.safetensors",
|
| 23 |
+
"WAN-2.2-I2V-HandjobBlowjobCombo-LOW-v1.safetensors",
|
| 24 |
+
"WAN-2.2-I2V-SensualTeasingBlowjob-HIGH-v1.safetensors",
|
| 25 |
+
"WAN-2.2-I2V-SensualTeasingBlowjob-LOW-v1.safetensors",
|
| 26 |
+
"iGOON_Blink_Blowjob_I2V_HIGH.safetensors",
|
| 27 |
+
"iGOON_Blink_Blowjob_I2V_LOW.safetensors",
|
| 28 |
+
"iGoon - Blink_Front_Doggystyle_I2V_HIGH.safetensors",
|
| 29 |
+
"iGoon - Blink_Front_Doggystyle_I2V_LOW.safetensors",
|
| 30 |
+
"iGoon - Blink_Missionary_I2V_HIGH.safetensors",
|
| 31 |
+
"iGoon - Blink_Missionary_I2V_LOW v2.safetensors",
|
| 32 |
+
"iGoon - Blink_Missionary_I2V_LOW.safetensors",
|
| 33 |
+
"iGoon%20-%20Blink_Back_Doggystyle_HIGH.safetensors",
|
| 34 |
+
"iGoon%20-%20Blink_Back_Doggystyle_LOW.safetensors",
|
| 35 |
+
"iGoon%20-%20Blink_Facial_I2V_HIGH.safetensors",
|
| 36 |
+
"iGoon%20-%20Blink_Facial_I2V_LOW.safetensors",
|
| 37 |
+
"iGoon_Blink_Missionary_I2V_HIGH v2.safetensors",
|
| 38 |
+
"iGoon_Blink_Titjob_I2V_HIGH.safetensors",
|
| 39 |
+
"iGoon_Blink_Titjob_I2V_LOW.safetensors",
|
| 40 |
+
"lips-bj_high_noise.safetensors",
|
| 41 |
+
"lips-bj_low_noise.safetensors",
|
| 42 |
+
"mql_casting_sex_doggy_kneel_diagonally_behind_vagina_wan22_i2v_v1_high_noise.safetensors",
|
| 43 |
+
"mql_casting_sex_doggy_kneel_diagonally_behind_vagina_wan22_i2v_v1_low_noise.safetensors",
|
| 44 |
+
"mql_casting_sex_reverse_cowgirl_lie_front_vagina_wan22_i2v_v1_high_noise.safetensors",
|
| 45 |
+
"mql_casting_sex_reverse_cowgirl_lie_front_vagina_wan22_i2v_v1_low_noise.safetensors",
|
| 46 |
+
"mql_casting_sex_spoon_wan22_i2v_v1_high_noise.safetensors",
|
| 47 |
+
"mql_casting_sex_spoon_wan22_i2v_v1_low_noise.safetensors",
|
| 48 |
+
"mql_doggy_a_wan22_t2v_v1_high_noise .safetensors",
|
| 49 |
+
"mql_doggy_a_wan22_t2v_v1_low_noise.safetensors",
|
| 50 |
+
"mql_massage_tits_wan22_i2v_v1_high_noise.safetensors",
|
| 51 |
+
"mql_massage_tits_wan22_i2v_v1_low_noise.safetensors",
|
| 52 |
+
"mql_panties_aside_wan22_i2v_v1_high_noise.safetensors",
|
| 53 |
+
"mql_panties_aside_wan22_i2v_v1_low_noise.safetensors",
|
| 54 |
+
"mqlspn_a_wan22_t2v_v1_high_noise.safetensors",
|
| 55 |
+
"mqlspn_a_wan22_t2v_v1_low_noise.safetensors",
|
| 56 |
+
"sfbehind_v2.1_high_noise.safetensors",
|
| 57 |
+
"sfbehind_v2.1_low_noise.safetensors",
|
| 58 |
+
"sid3l3g_transition_v2.0_H.safetensors",
|
| 59 |
+
"sid3l3g_transition_v2.0_L.safetensors",
|
| 60 |
+
"wan2.2_i2v_high_ulitmate_pussy_asshole.safetensors",
|
| 61 |
+
"wan2.2_i2v_low_ulitmate_pussy_asshole.safetensors",
|
| 62 |
+
"wan22-mouthfull-140epoc-high-k3nk.safetensors",
|
| 63 |
+
"wan22-mouthfull-152epoc-low-k3nk.safetensors",
|
| 64 |
+
]
|
| 65 |
+
|
| 66 |
+
LORA_PAIRS = {}
|
| 67 |
+
for f in LORA_FILES:
|
| 68 |
+
name = urllib.parse.unquote(f).replace(".safetensors", "")
|
| 69 |
+
is_high = bool(re.search(r'(high|HN|_H\b)', name, re.IGNORECASE))
|
| 70 |
+
is_low = bool(re.search(r'(low|LN|_L\b)', name, re.IGNORECASE))
|
| 71 |
+
group = re.sub(r'[\s_-]*(high|low|noise|HN|LN)([\s_-]*noise)?[\s_-]*(v?\d+(\.\d+)?)?\s*$', '', name, flags=re.IGNORECASE).strip()
|
| 72 |
+
group = re.sub(r'[\s_]+$', '', group)
|
| 73 |
+
if group not in LORA_PAIRS:
|
| 74 |
+
LORA_PAIRS[group] = {"HIGH": None, "LOW": None}
|
| 75 |
+
if is_high:
|
| 76 |
+
LORA_PAIRS[group]["HIGH"] = f
|
| 77 |
+
elif is_low:
|
| 78 |
+
LORA_PAIRS[group]["LOW"] = f
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def get_lora_choices():
    """Dropdown labels for the UI: the bare group name when both noise
    halves exist, otherwise the name annotated with which half is missing.
    Groups with neither half are omitted."""
    labels = []
    for group in sorted(LORA_PAIRS):
        pair = LORA_PAIRS[group]
        has_high = bool(pair["HIGH"])
        has_low = bool(pair["LOW"])
        if has_high and has_low:
            labels.append(group)
        elif has_high:
            labels.append(f"{group} (HIGH only)")
        elif has_low:
            labels.append(f"{group} (LOW only)")
    return labels
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def download_lora(group_name):
    """Fetch the HIGH/LOW safetensors for a LoRA group from the Hub.

    Accepts either a bare group name or a UI label with a
    "(HIGH only)" / "(LOW only)" suffix. Returns (high_path, low_path),
    each None when that half does not exist or the group is unknown.
    """
    if not group_name:
        return None, None
    # Strip the UI-only annotation added by get_lora_choices.
    clean_name = re.sub(r'\s*\(HIGH only\)|\s*\(LOW only\)', '', group_name)
    if clean_name not in LORA_PAIRS:
        return None, None
    pair = LORA_PAIRS[clean_name]
    high_path = hf_hub_download(LORA_REPO, pair["HIGH"], token=HF_TOKEN) if pair["HIGH"] else None
    low_path = hf_hub_download(LORA_REPO, pair["LOW"], token=HF_TOKEN) if pair["LOW"] else None
    return high_path, low_path
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def load_lora_to_pipe(pipe, group_name, adapter_name="lora"):
    """Download the selected LoRA group and load it into *pipe*.

    Args:
        pipe: a diffusers pipeline exposing ``load_lora_weights``.
        group_name: dropdown label as produced by get_lora_choices.
        adapter_name: base adapter name; a full pair is registered as
            ``{adapter_name}_high`` / ``{adapter_name}_low``.

    Returns:
        True if at least one checkpoint was loaded, False otherwise.
    """
    high_path, low_path = download_lora(group_name)
    if high_path and low_path:
        pipe.load_lora_weights(high_path, adapter_name=f"{adapter_name}_high")
        pipe.load_lora_weights(low_path, adapter_name=f"{adapter_name}_low")
        print(f"Loaded LoRA pair: {group_name}")
        return True
    elif high_path:
        pipe.load_lora_weights(high_path, adapter_name=adapter_name)
        print(f"Loaded LoRA: {group_name}")
        return True
    elif low_path:
        # Bug fix: "(LOW only)" groups are offered by get_lora_choices but
        # previously fell through to `return False` and were never loaded.
        pipe.load_lora_weights(low_path, adapter_name=adapter_name)
        print(f"Loaded LoRA: {group_name}")
        return True
    return False
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def unload_lora(pipe):
    """Best-effort removal of any LoRA weights loaded into *pipe*.

    Swallows errors (e.g. no LoRA loaded, pipeline without LoRA support)
    so it is always safe to call before loading a new adapter.
    """
    try:
        pipe.unload_lora_weights()
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are no longer swallowed; the best-effort contract is preserved.
        pass
|
model/loss.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import torchvision.models as models
|
| 6 |
+
|
| 7 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class EPE(nn.Module):
    """Masked end-point error between predicted flow and ground truth."""

    def __init__(self):
        super(EPE, self).__init__()

    def forward(self, flow, gt, loss_mask):
        # Squared difference per channel; gt is detached so no gradient
        # flows into the target branch.
        squared = (flow - gt.detach()) ** 2
        # Euclidean norm over the channel axis; 1e-6 keeps sqrt stable at 0.
        epe_map = (squared.sum(1, True) + 1e-6) ** 0.5
        return epe_map * loss_mask
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class Ternary(nn.Module):
    """Census-transform (ternary) loss: soft Hamming distance between
    local 7x7 census descriptors of two RGB images."""

    def __init__(self):
        super(Ternary, self).__init__()
        patch_size = 7
        out_channels = patch_size * patch_size
        # One-hot 7x7 kernels: the conv extracts each pixel of the local
        # patch into its own output channel.
        self.w = np.eye(out_channels).reshape(
            (patch_size, patch_size, 1, out_channels))
        self.w = np.transpose(self.w, (3, 2, 0, 1))
        # NOTE: kernels live on the module-level ``device`` (cuda if available).
        self.w = torch.tensor(self.w).float().to(device)

    def transform(self, img):
        # Census transform: each neighbor relative to the center pixel,
        # squashed toward (-1, 1); padding=3 preserves spatial size for 7x7.
        patches = F.conv2d(img, self.w, padding=3, bias=None)
        transf = patches - img
        transf_norm = transf / torch.sqrt(0.81 + transf**2)
        return transf_norm

    def rgb2gray(self, rgb):
        # ITU-R BT.601 luma weights; expects channel order R, G, B.
        r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :]
        gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
        return gray

    def hamming(self, t1, t2):
        # Soft (robust) Hamming distance averaged over the 49 patch channels.
        dist = (t1 - t2) ** 2
        dist_norm = torch.mean(dist / (0.1 + dist), 1, True)
        return dist_norm

    def valid_mask(self, t, padding):
        # Zero a ``padding``-wide border where the census window ran off the image.
        n, _, h, w = t.size()
        inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t)
        mask = F.pad(inner, [padding] * 4)
        return mask

    def forward(self, img0, img1):
        # Returns a per-pixel (N, 1, H, W) loss map, border masked out.
        img0 = self.transform(self.rgb2gray(img0))
        img1 = self.transform(self.rgb2gray(img1))
        return self.hamming(img0, img1) * self.valid_mask(img0, 1)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class SOBEL(nn.Module):
    """Per-pixel L1 loss between Sobel gradients of prediction and target."""

    def __init__(self):
        super(SOBEL, self).__init__()
        # Horizontal-gradient kernel; the vertical kernel is its transpose.
        self.kernelX = torch.tensor([
            [1, 0, -1],
            [2, 0, -2],
            [1, 0, -1],
        ]).float()
        self.kernelY = self.kernelX.clone().T
        # NOTE: kernels are placed on the module-level ``device``.
        self.kernelX = self.kernelX.unsqueeze(0).unsqueeze(0).to(device)
        self.kernelY = self.kernelY.unsqueeze(0).unsqueeze(0).to(device)

    def forward(self, pred, gt):
        N, C, H, W = pred.shape[0], pred.shape[1], pred.shape[2], pred.shape[3]
        # Stack pred and gt along the batch axis so each direction is
        # filtered with a single conv call, then split the results back.
        img_stack = torch.cat(
            [pred.reshape(N*C, 1, H, W), gt.reshape(N*C, 1, H, W)], 0)
        sobel_stack_x = F.conv2d(img_stack, self.kernelX, padding=1)
        sobel_stack_y = F.conv2d(img_stack, self.kernelY, padding=1)
        pred_X, gt_X = sobel_stack_x[:N*C], sobel_stack_x[N*C:]
        pred_Y, gt_Y = sobel_stack_y[:N*C], sobel_stack_y[N*C:]

        # Sum of absolute gradient differences in both directions
        # (an (N*C, 1, H, W) loss map; reduction is left to the caller).
        L1X, L1Y = torch.abs(pred_X-gt_X), torch.abs(pred_Y-gt_Y)
        loss = (L1X+L1Y)
        return loss
|
| 82 |
+
|
| 83 |
+
class MeanShift(nn.Conv2d):
    """Fixed 1x1 convolution that (de)normalizes an image tensor.

    With ``norm=True`` it maps ``x`` to ``(x - data_range * mean) / std``
    channel-wise; with ``norm=False`` it applies the inverse
    ``x * std + data_range * mean``. Parameters are frozen.
    """

    def __init__(self, data_mean, data_std, data_range=1, norm=True):
        c = len(data_mean)
        super(MeanShift, self).__init__(c, c, kernel_size=1)
        std = torch.Tensor(data_std)
        # Identity weight, then scaled by 1/std (normalize) or std (denormalize).
        self.weight.data = torch.eye(c).view(c, c, 1, 1)
        if norm:
            self.weight.data.div_(std.view(c, 1, 1, 1))
            self.bias.data = -1 * data_range * torch.Tensor(data_mean)
            self.bias.data.div_(std)
        else:
            self.weight.data.mul_(std.view(c, 1, 1, 1))
            self.bias.data = data_range * torch.Tensor(data_mean)
        # Bug fix: the original `self.requires_grad = False` only created a
        # plain attribute on the module and froze nothing; actually disable
        # gradients on the parameters.
        for p in self.parameters():
            p.requires_grad = False
|
| 97 |
+
|
| 98 |
+
class VGGPerceptualLoss(torch.nn.Module):
    """Weighted L1 distance between VGG-19 features of X and Y.

    NOTE(review): the constructor downloads pretrained VGG-19 weights and
    calls ``.cuda()`` on the normalizer, so this class requires a GPU and
    network access on first use — confirm for CPU-only deployments.
    """

    def __init__(self, rank=0):
        super(VGGPerceptualLoss, self).__init__()
        blocks = []  # NOTE(review): unused local, kept as-is
        pretrained = True
        self.vgg_pretrained_features = models.vgg19(pretrained=pretrained).features
        # ImageNet normalization folded into a fixed 1x1 conv.
        self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()
        # Feature extractor is frozen; this module is a pure loss.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, X, Y, indices=None):
        X = self.normalize(X)
        Y = self.normalize(Y)
        # NOTE(review): the ``indices`` parameter is ignored — it is
        # unconditionally overwritten with the fixed layer list below.
        indices = [2, 7, 12, 21, 30]
        # Per-layer weights (relu1_2 .. relu5_2 style selection).
        weights = [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10/1.5]
        k = 0
        loss = 0
        # Run both inputs through VGG up to the last selected layer,
        # accumulating weighted L1 feature distances along the way.
        for i in range(indices[-1]):
            X = self.vgg_pretrained_features[i](X)
            Y = self.vgg_pretrained_features[i](Y)
            if (i+1) in indices:
                loss += weights[k] * (X - Y.detach()).abs().mean() * 0.1
                k += 1
        return loss
|
| 122 |
+
|
| 123 |
+
if __name__ == '__main__':
    # Smoke test: run the Ternary loss on a zero image vs. Gaussian noise
    # and print the resulting loss-map shape (expected (3, 1, 256, 256)).
    img0 = torch.zeros(3, 3, 256, 256).float().to(device)
    img1 = torch.tensor(np.random.normal(
        0, 1, (3, 3, 256, 256))).float().to(device)
    ternary_loss = Ternary()
    print(ternary_loss(img0, img1).shape)
|
model/pytorch_msssim/__init__.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from math import exp
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 7 |
+
|
| 8 |
+
def gaussian(window_size, sigma):
    """Discrete 1-D Gaussian of length *window_size*, normalized to sum to 1."""
    center = window_size // 2
    weights = torch.Tensor(
        [exp(-((i - center) ** 2) / float(2 * sigma ** 2)) for i in range(window_size)]
    )
    return weights / weights.sum()
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def create_window(window_size, channel=1):
    """Gaussian 2-D window of shape (channel, 1, ws, ws) for grouped SSIM convs."""
    kernel_1d = gaussian(window_size, 1.5).unsqueeze(1)
    # Outer product of the 1-D kernel with itself gives the 2-D window.
    kernel_2d = kernel_1d.mm(kernel_1d.t()).float().unsqueeze(0).unsqueeze(0).to(device)
    return kernel_2d.expand(channel, 1, window_size, window_size).contiguous()
|
| 18 |
+
|
| 19 |
+
def create_window_3d(window_size, channel=1):
    """Separable Gaussian 3-D window of shape (1, channel, ws, ws, ws)."""
    kernel_1d = gaussian(window_size, 1.5).unsqueeze(1)
    kernel_2d = kernel_1d.mm(kernel_1d.t())
    # Outer product along a third axis turns the 2-D window into a cube.
    kernel_3d = kernel_2d.unsqueeze(2) @ (kernel_1d.t())
    return kernel_3d.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
    """Single-scale SSIM between two image batches of shape (N, C, H, W).

    Returns the mean SSIM (scalar if ``size_average``, else per-sample);
    with ``full=True`` also returns the contrast-sensitivity term.
    """
    # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
    if val_range is None:
        # Guess the dynamic range L from the data when not given explicitly.
        if torch.max(img1) > 128:
            max_val = 255
        else:
            max_val = 1

        if torch.min(img1) < -0.5:
            min_val = -1
        else:
            min_val = 0
        L = max_val - min_val
    else:
        L = val_range

    padd = 0
    (_, channel, height, width) = img1.size()
    if window is None:
        real_size = min(window_size, height, width)
        window = create_window(real_size, channel=channel).to(img1.device).type_as(img1)

    # NOTE(review): the replicate padding below is hard-coded to 5 on every
    # side, which matches window_size == 11 only — confirm before calling
    # with other window sizes (output shape would change).
    mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)
    mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local (co)variances via E[x^2] - E[x]^2.
    sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu2_sq
    sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_mu2

    # Standard SSIM stabilization constants scaled by the dynamic range.
    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2

    v1 = 2.0 * sigma12 + C2
    v2 = sigma1_sq + sigma2_sq + C2
    cs = torch.mean(v1 / v2)  # contrast sensitivity

    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

    if size_average:
        ret = ssim_map.mean()
    else:
        # Per-sample mean over channel and spatial dims.
        ret = ssim_map.mean(1).mean(1).mean(1)

    if full:
        return ret, cs
    return ret
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
    """MATLAB-style SSIM: treats (N, C, H, W) images as single-channel 3-D
    volumes and convolves a 3-D Gaussian window over (C, H, W)."""
    # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
    if val_range is None:
        # Guess the dynamic range L from the data when not given explicitly.
        if torch.max(img1) > 128:
            max_val = 255
        else:
            max_val = 1

        if torch.min(img1) < -0.5:
            min_val = -1
        else:
            min_val = 0
        L = max_val - min_val
    else:
        L = val_range

    padd = 0
    (_, _, height, width) = img1.size()
    if window is None:
        real_size = min(window_size, height, width)
        window = create_window_3d(real_size, channel=1).to(img1.device).type_as(img1)
        # Channel is set to 1 since we consider color images as volumetric images

    # Insert a singleton channel axis: (N, 1, C, H, W) for conv3d.
    img1 = img1.unsqueeze(1)
    img2 = img2.unsqueeze(1)

    # NOTE(review): padding is hard-coded to 5 per side, which assumes
    # window_size == 11 — confirm before using other sizes.
    mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
    mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local (co)variances via E[x^2] - E[x]^2.
    sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_sq
    sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu2_sq
    sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_mu2

    # Standard SSIM stabilization constants scaled by the dynamic range.
    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2

    v1 = 2.0 * sigma12 + C2
    v2 = sigma1_sq + sigma2_sq + C2
    cs = torch.mean(v1 / v2)  # contrast sensitivity

    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

    if size_average:
        ret = ssim_map.mean()
    else:
        ret = ssim_map.mean(1).mean(1).mean(1)

    if full:
        return ret, cs
    return ret
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
    """Multi-scale SSIM over 5 dyadic scales (Wang et al. weights)."""
    device = img1.device
    # Canonical per-scale exponents from the original MS-SSIM paper.
    weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device).type_as(img1)
    levels = weights.size()[0]
    mssim = []
    mcs = []
    for _ in range(levels):
        sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
        mssim.append(sim)
        mcs.append(cs)

        # Halve the resolution for the next scale.
        img1 = F.avg_pool2d(img1, (2, 2))
        img2 = F.avg_pool2d(img2, (2, 2))

    mssim = torch.stack(mssim)
    mcs = torch.stack(mcs)

    # Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
    if normalize:
        mssim = (mssim + 1) / 2
        mcs = (mcs + 1) / 2

    pow1 = mcs ** weights
    pow2 = mssim ** weights
    # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
    # NOTE(review): the canonical combination is prod(mcs[:-1]**w[:-1]) *
    # mssim[-1]**w[-1]; here pow2[-1] broadcasts into every term of
    # pow1[:-1] before the product — confirm this is intended.
    output = torch.prod(pow1[:-1] * pow2[-1])
    return output
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
# Classes to re-use window
|
| 165 |
+
# Classes to re-use window
class SSIM(torch.nn.Module):
    """Module wrapper around :func:`ssim` that caches the Gaussian window.

    NOTE: forward returns DSSIM = (1 - SSIM) / 2 (a loss), not raw SSIM.
    """

    def __init__(self, window_size=11, size_average=True, val_range=None):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.val_range = val_range

        # Assume 3 channel for SSIM
        self.channel = 3
        self.window = create_window(window_size, channel=self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        # Reuse the cached window while channel count and dtype still match;
        # otherwise rebuild it for the new input and update the cache.
        if channel == self.channel and self.window.dtype == img1.dtype:
            window = self.window
        else:
            window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
            self.window = window
            self.channel = channel

        _ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
        dssim = (1 - _ssim) / 2
        return dssim
|
| 189 |
+
|
| 190 |
+
class MSSSIM(torch.nn.Module):
    """Stateful wrapper around :func:`msssim` holding the window settings.

    Note: ``channel`` is stored but not forwarded to :func:`msssim`.
    """

    def __init__(self, window_size=11, size_average=True, channel=3):
        super(MSSSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = channel

    def forward(self, img1, img2):
        return msssim(
            img1,
            img2,
            window_size=self.window_size,
            size_average=self.size_average,
        )
|
model/warplayer.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Cache of base sampling grids keyed by (device, flow size), so the identity
# meshgrid is built only once per flow shape.
backwarp_tenGrid = {}


def warp(tenInput, tenFlow):
    """Backward-warp ``tenInput`` by the optical flow ``tenFlow``.

    tenFlow is expected as (N, 2, H, W) pixel-space displacements
    (channel 0 = x, channel 1 = y); sampling uses bilinear grid_sample
    with border padding.
    """
    k = (str(tenFlow.device), str(tenFlow.size()))
    if k not in backwarp_tenGrid:
        # Build the normalized [-1, 1] identity grid grid_sample expects.
        tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=tenFlow.device).view(
            1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
        tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=tenFlow.device).view(
            1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
        backwarp_tenGrid[k] = torch.cat(
            [tenHorizontal, tenVertical], 1).to(tenFlow.device)

    # Convert pixel-space flow to normalized grid offsets ([-1, 1] range).
    tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
                         tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)

    grid = backwarp_tenGrid[k].type_as(tenFlow)

    # grid_sample wants (N, H, W, 2) sampling coordinates.
    g = (grid + tenFlow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
|
packages.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
unzip
|
requirements.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
git+https://github.com/linoytsaban/diffusers.git@wan22-loras
|
| 2 |
+
|
| 3 |
+
transformers<5
|
| 4 |
+
accelerate
|
| 5 |
+
safetensors
|
| 6 |
+
sentencepiece
|
| 7 |
+
peft
|
| 8 |
+
ftfy
|
| 9 |
+
imageio
|
| 10 |
+
imageio-ffmpeg
|
| 11 |
+
opencv-python
|
| 12 |
+
torchao==0.11.0
|
| 13 |
+
sam2
|
| 14 |
+
|
| 15 |
+
numpy
|
| 16 |
+
torchvision
|
wan22_input_2.jpg
ADDED
|
Git LFS Details
|
wan_controlnet.py
ADDED
|
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 7 |
+
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
|
| 8 |
+
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
| 9 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 10 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 11 |
+
from diffusers.models.transformers.transformer_wan import (
|
| 12 |
+
WanTimeTextImageEmbedding,
|
| 13 |
+
WanRotaryPosEmbed,
|
| 14 |
+
WanTransformerBlock
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def zero_module(module):
    """Return *module* with every parameter zero-initialized in place."""
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 25 |
+
|
| 26 |
+
# NOTE(review): exact duplicate of the zero_module defined earlier in this
# file; this second definition shadows the first. Harmless (identical), but
# one copy should eventually be removed.
def zero_module(module):
    # Zero-initialize every parameter of *module* in place and return it.
    for p in module.parameters():
        nn.init.zeros_(p)
    return module
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class WanControlnet(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    r"""
    A Controlnet Transformer model for video-like data used in the Wan model.

    Args:
        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
        num_attention_heads (`int`, defaults to `40`):
            The number of attention heads.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each head.
        vae_channels (`int`, defaults to `16`):
            The number of channels in the vae input.
        in_channels (`int`, defaults to `3`):
            The number of channels in the controlnet (pixel-space) input.
        text_dim (`int`, defaults to `4096`):
            Input dimension for text embeddings.
        freq_dim (`int`, defaults to `256`):
            Dimension for sinusoidal time embeddings.
        ffn_dim (`int`, defaults to `13824`):
            Intermediate dimension in feed-forward network.
        num_layers (`int`, defaults to `20`):
            The number of layers of transformer blocks to use.
        cross_attn_norm (`bool`, defaults to `True`):
            Enable cross-attention normalization.
        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
            Query/key normalization variant.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the added key and value projections. If `None`, no projection is used.
        downscale_coef (`int`, *optional*, defaults to `8`):
            Coefficient for spatially downscaling the controlnet input video.
        out_proj_dim (`int`, *optional*, defaults to `128 * 12`):
            Output projection dimension for the final linear layers.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
    _no_split_modules = ["WanTransformerBlock"]
    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]

    @register_to_config
    def __init__(
        self,
        patch_size: Tuple[int] = (1, 2, 2),
        num_attention_heads: int = 40,
        attention_head_dim: int = 128,
        in_channels: int = 3,
        vae_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 13824,
        num_layers: int = 20,
        cross_attn_norm: bool = True,
        qk_norm: Optional[str] = "rms_norm_across_heads",
        eps: float = 1e-6,
        image_dim: Optional[int] = None,
        added_kv_proj_dim: Optional[int] = None,
        rope_max_seq_len: int = 1024,
        downscale_coef: int = 8,
        out_proj_dim: int = 128 * 12,
    ) -> None:
        super().__init__()

        # Channel schedule for the pixel-space control encoder: expand to
        # downscale_coef**2 channels on spatial downscale, then halve twice.
        start_channels = in_channels * (downscale_coef ** 2)
        input_channels = [start_channels, start_channels // 2, start_channels // 4]

        self.control_encoder = nn.ModuleList([
            ## Spatial compression with time awareness
            nn.Sequential(
                nn.Conv3d(
                    in_channels,
                    input_channels[0],
                    kernel_size=(3, downscale_coef + 1, downscale_coef + 1),
                    stride=(1, downscale_coef, downscale_coef),
                    padding=(1, downscale_coef // 2, downscale_coef // 2)
                ),
                nn.GELU(approximate="tanh"),
                nn.GroupNorm(2, input_channels[0]),
            ),
            ## Spatio-Temporal compression with spatial awareness
            nn.Sequential(
                nn.Conv3d(input_channels[0], input_channels[1], kernel_size=3, stride=(2, 1, 1), padding=1),
                nn.GELU(approximate="tanh"),
                nn.GroupNorm(2, input_channels[1]),
            ),
            ## Temporal compression with spatial awareness
            nn.Sequential(
                nn.Conv3d(input_channels[1], input_channels[2], kernel_size=3, stride=(2, 1, 1), padding=1),
                nn.GELU(approximate="tanh"),
                nn.GroupNorm(2, input_channels[2]),
            )
        ])

        inner_dim = num_attention_heads * attention_head_dim

        # 1. Patch & position embedding
        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
        # Patchifies the concatenation of VAE latents and encoded control video.
        self.patch_embedding = nn.Conv3d(vae_channels + input_channels[2], inner_dim, kernel_size=patch_size, stride=patch_size)

        # 2. Condition embeddings
        # image_embedding_dim=1280 for I2V model
        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            image_embed_dim=image_dim,
        )
        # 3. Transformer blocks
        self.blocks = nn.ModuleList(
            [
                WanTransformerBlock(
                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
                )
                for _ in range(num_layers)
            ]
        )

        # 4 Controlnet modules
        # One zero-initialized projection per transformer block, so the
        # controlnet contributes nothing at the start of training.
        self.controlnet_blocks = nn.ModuleList([])

        for _ in range(len(self.blocks)):
            controlnet_block = nn.Linear(inner_dim, out_proj_dim)
            controlnet_block = zero_module(controlnet_block)
            self.controlnet_blocks.append(controlnet_block)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        controlnet_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Run the controlnet and return one projected residual per block.

        Args:
            hidden_states: VAE latents, (B, vae_channels, T, H, W).
            timestep: per-batch timesteps, or (B, seq_len) for wan 2.2 ti2v.
            encoder_hidden_states: text embeddings.
            controlnet_states: pixel-space control video,
                (B, in_channels, frames, height, width).
            encoder_hidden_states_image: optional image embeddings (I2V).
            return_dict: wrap the tuple of per-block states in a
                Transformer2DModelOutput when True.
            attention_kwargs: optional dict; a "scale" entry is used as the
                LoRA scale under the PEFT backend.
        """
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        rotary_emb = self.rope(hidden_states)

        # 0. Controlnet encoder
        # Compress the control video until it matches the latent grid,
        # then concatenate along channels with the VAE latents.
        for control_encoder_block in self.control_encoder:
            controlnet_states = control_encoder_block(controlnet_states)
        # print("+" * 50, hidden_states.shape, controlnet_states.shape)
        hidden_states = torch.cat([hidden_states, controlnet_states], dim=1)

        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
        if timestep.ndim == 2:
            ts_seq_len = timestep.shape[1]
            timestep = timestep.flatten()  # batch_size * seq_len
        else:
            ts_seq_len = None

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
        )
        if ts_seq_len is not None:
            # batch_size, seq_len, 6, inner_dim
            timestep_proj = timestep_proj.unflatten(2, (6, -1))
        else:
            # batch_size, 6, inner_dim
            timestep_proj = timestep_proj.unflatten(1, (6, -1))

        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        # 4. Transformer blocks
        # Collect each block's output passed through its zero-init projection.
        controlnet_hidden_states = ()
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
                )
                controlnet_hidden_states += (controlnet_block(hidden_states),)
        else:
            for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
                controlnet_hidden_states += (controlnet_block(hidden_states),)


        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (controlnet_hidden_states,)

        return Transformer2DModelOutput(sample=controlnet_hidden_states)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
if __name__ == "__main__":
|
| 248 |
+
parameters = {
|
| 249 |
+
"added_kv_proj_dim": None,
|
| 250 |
+
"attention_head_dim": 128,
|
| 251 |
+
"cross_attn_norm": True,
|
| 252 |
+
"eps": 1e-06,
|
| 253 |
+
"ffn_dim": 8960,
|
| 254 |
+
"freq_dim": 256,
|
| 255 |
+
"image_dim": None,
|
| 256 |
+
"in_channels": 3,
|
| 257 |
+
"num_attention_heads": 12,
|
| 258 |
+
"num_layers": 2,
|
| 259 |
+
"patch_size": [1, 2, 2],
|
| 260 |
+
"qk_norm": "rms_norm_across_heads",
|
| 261 |
+
"rope_max_seq_len": 1024,
|
| 262 |
+
"text_dim": 4096,
|
| 263 |
+
"downscale_coef": 8,
|
| 264 |
+
"out_proj_dim": 12 * 128,
|
| 265 |
+
"vae_channels": 16
|
| 266 |
+
}
|
| 267 |
+
controlnet = WanControlnet(**parameters)
|
| 268 |
+
|
| 269 |
+
hidden_states = torch.rand(1, 16, 13, 60, 90)
|
| 270 |
+
timestep = torch.tensor([1000]).repeat(17550).unsqueeze(0) #torch.randint(low=0, high=1000, size=(1,), dtype=torch.long)
|
| 271 |
+
encoder_hidden_states = torch.rand(1, 512, 4096)
|
| 272 |
+
controlnet_states = torch.rand(1, 3, 49, 480, 720)
|
| 273 |
+
|
| 274 |
+
controlnet_hidden_states = controlnet(
|
| 275 |
+
hidden_states=hidden_states,
|
| 276 |
+
timestep=timestep,
|
| 277 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 278 |
+
controlnet_states=controlnet_states,
|
| 279 |
+
return_dict=False
|
| 280 |
+
)
|
| 281 |
+
print("Output states count", len(controlnet_hidden_states[0]))
|
| 282 |
+
for out_hidden_states in controlnet_hidden_states[0]:
|
| 283 |
+
print(out_hidden_states.shape)
|
| 284 |
+
|
wan_i2v_input.JPG
ADDED
|
|
Git LFS Details
|
wan_t2v_controlnet_pipeline.py
ADDED
|
@@ -0,0 +1,798 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# # Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
|
| 2 |
+
# #
|
| 3 |
+
# # Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# # you may not use this file except in compliance with the License.
|
| 5 |
+
# # You may obtain a copy of the License at
|
| 6 |
+
# #
|
| 7 |
+
# # http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
# #
|
| 9 |
+
# # Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# # distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# # See the License for the specific language governing permissions and
|
| 13 |
+
# # limitations under the License.
|
| 14 |
+
|
| 15 |
+
import html
|
| 16 |
+
import inspect
|
| 17 |
+
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
| 18 |
+
|
| 19 |
+
import ftfy
|
| 20 |
+
import regex as re
|
| 21 |
+
import torch
|
| 22 |
+
import numpy as np
|
| 23 |
+
from PIL import Image
|
| 24 |
+
from torchvision import transforms
|
| 25 |
+
from transformers import AutoTokenizer, UMT5EncoderModel
|
| 26 |
+
|
| 27 |
+
from diffusers import WanTransformer3DModel
|
| 28 |
+
from diffusers.image_processor import PipelineImageInput
|
| 29 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 30 |
+
from diffusers.loaders import WanLoraLoaderMixin
|
| 31 |
+
from diffusers.models import AutoencoderKLWan
|
| 32 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 33 |
+
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
|
| 34 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 35 |
+
from diffusers.video_processor import VideoProcessor
|
| 36 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 37 |
+
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
|
| 38 |
+
|
| 39 |
+
from wan_transformer import CustomWanTransformer3DModel
|
| 40 |
+
from wan_controlnet import WanControlnet
|
| 41 |
+
from wan_teacache import TeaCache
|
| 42 |
+
|
| 43 |
+
if is_torch_xla_available():
|
| 44 |
+
import torch_xla.core.xla_model as xm
|
| 45 |
+
|
| 46 |
+
XLA_AVAILABLE = True
|
| 47 |
+
else:
|
| 48 |
+
XLA_AVAILABLE = False
|
| 49 |
+
|
| 50 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def resize_for_crop(image, crop_h, crop_w):
|
| 54 |
+
img_h, img_w = image.shape[-2:]
|
| 55 |
+
if img_h >= crop_h and img_w >= crop_w:
|
| 56 |
+
coef = max(crop_h / img_h, crop_w / img_w)
|
| 57 |
+
elif img_h <= crop_h and img_w <= crop_w:
|
| 58 |
+
coef = max(crop_h / img_h, crop_w / img_w)
|
| 59 |
+
else:
|
| 60 |
+
coef = crop_h / img_h if crop_h > img_h else crop_w / img_w
|
| 61 |
+
out_h, out_w = int(img_h * coef), int(img_w * coef)
|
| 62 |
+
resized_image = transforms.functional.resize(image, (out_h, out_w), antialias=True)
|
| 63 |
+
return resized_image
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def prepare_frames(input_images, video_size, do_resize=True, do_crop=True):
|
| 67 |
+
input_images = np.stack([np.array(x) for x in input_images])
|
| 68 |
+
images_tensor = torch.from_numpy(input_images).permute(0, 3, 1, 2) / 127.5 - 1
|
| 69 |
+
if do_resize:
|
| 70 |
+
images_tensor = [resize_for_crop(x, crop_h=video_size[0], crop_w=video_size[1]) for x in images_tensor]
|
| 71 |
+
if do_crop:
|
| 72 |
+
images_tensor = [transforms.functional.center_crop(x, video_size) for x in images_tensor]
|
| 73 |
+
if isinstance(images_tensor, list):
|
| 74 |
+
images_tensor = torch.stack(images_tensor)
|
| 75 |
+
return images_tensor.unsqueeze(0)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def prepare_controlnet_frames(controlnet_frames, height, width, dtype, device):
|
| 79 |
+
prepared_frames = prepare_frames(controlnet_frames, (height, width))
|
| 80 |
+
controlnet_encoded_frames = prepared_frames.to(dtype=dtype, device=device)
|
| 81 |
+
return controlnet_encoded_frames.permute(0, 2, 1, 3, 4).contiguous()
|
| 82 |
+
|
| 83 |
+
def basic_clean(text):
|
| 84 |
+
text = ftfy.fix_text(text)
|
| 85 |
+
text = html.unescape(html.unescape(text))
|
| 86 |
+
return text.strip()
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def whitespace_clean(text):
|
| 90 |
+
text = re.sub(r"\s+", " ", text)
|
| 91 |
+
text = text.strip()
|
| 92 |
+
return text
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def prompt_clean(text):
|
| 96 |
+
text = whitespace_clean(basic_clean(text))
|
| 97 |
+
return text
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 101 |
+
def retrieve_timesteps(
|
| 102 |
+
scheduler,
|
| 103 |
+
num_inference_steps: Optional[int] = None,
|
| 104 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 105 |
+
timesteps: Optional[List[int]] = None,
|
| 106 |
+
sigmas: Optional[List[float]] = None,
|
| 107 |
+
**kwargs,
|
| 108 |
+
):
|
| 109 |
+
r"""
|
| 110 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 111 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
scheduler (`SchedulerMixin`):
|
| 115 |
+
The scheduler to get timesteps from.
|
| 116 |
+
num_inference_steps (`int`):
|
| 117 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 118 |
+
must be `None`.
|
| 119 |
+
device (`str` or `torch.device`, *optional*):
|
| 120 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 121 |
+
timesteps (`List[int]`, *optional*):
|
| 122 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 123 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 124 |
+
sigmas (`List[float]`, *optional*):
|
| 125 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 126 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 130 |
+
second element is the number of inference steps.
|
| 131 |
+
"""
|
| 132 |
+
if timesteps is not None and sigmas is not None:
|
| 133 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 134 |
+
if timesteps is not None:
|
| 135 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 136 |
+
if not accepts_timesteps:
|
| 137 |
+
raise ValueError(
|
| 138 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 139 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 140 |
+
)
|
| 141 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 142 |
+
timesteps = scheduler.timesteps
|
| 143 |
+
num_inference_steps = len(timesteps)
|
| 144 |
+
elif sigmas is not None:
|
| 145 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 146 |
+
if not accept_sigmas:
|
| 147 |
+
raise ValueError(
|
| 148 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 149 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 150 |
+
)
|
| 151 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 152 |
+
timesteps = scheduler.timesteps
|
| 153 |
+
num_inference_steps = len(timesteps)
|
| 154 |
+
else:
|
| 155 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 156 |
+
timesteps = scheduler.timesteps
|
| 157 |
+
return timesteps, num_inference_steps
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
| 161 |
+
def retrieve_latents(
|
| 162 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
| 163 |
+
):
|
| 164 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
| 165 |
+
return encoder_output.latent_dist.sample(generator)
|
| 166 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
| 167 |
+
return encoder_output.latent_dist.mode()
|
| 168 |
+
elif hasattr(encoder_output, "latents"):
|
| 169 |
+
return encoder_output.latents
|
| 170 |
+
else:
|
| 171 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class WanTextToVideoControlnetPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
| 175 |
+
r"""
|
| 176 |
+
Pipeline for text-to-video generation using Wan.
|
| 177 |
+
|
| 178 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
| 179 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
| 180 |
+
|
| 181 |
+
Args:
|
| 182 |
+
tokenizer ([`T5Tokenizer`]):
|
| 183 |
+
Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
|
| 184 |
+
specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
| 185 |
+
text_encoder ([`T5EncoderModel`]):
|
| 186 |
+
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
| 187 |
+
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
| 188 |
+
transformer ([`WanTransformer3DModel`]):
|
| 189 |
+
Conditional Transformer to denoise the input latents.
|
| 190 |
+
scheduler ([`UniPCMultistepScheduler`]):
|
| 191 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
| 192 |
+
vae ([`AutoencoderKLWan`]):
|
| 193 |
+
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
| 194 |
+
"""
|
| 195 |
+
|
| 196 |
+
model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae->controlnet"
|
| 197 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
| 198 |
+
_optional_components = ["transformer_2"]
|
| 199 |
+
|
| 200 |
+
def __init__(
|
| 201 |
+
self,
|
| 202 |
+
tokenizer: AutoTokenizer,
|
| 203 |
+
text_encoder: UMT5EncoderModel,
|
| 204 |
+
transformer: CustomWanTransformer3DModel,
|
| 205 |
+
vae: AutoencoderKLWan,
|
| 206 |
+
controlnet: WanControlnet,
|
| 207 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 208 |
+
transformer_2: WanTransformer3DModel = None,
|
| 209 |
+
boundary_ratio: Optional[float] = None,
|
| 210 |
+
expand_timesteps: bool = False,
|
| 211 |
+
):
|
| 212 |
+
super().__init__()
|
| 213 |
+
|
| 214 |
+
self.register_modules(
|
| 215 |
+
vae=vae,
|
| 216 |
+
text_encoder=text_encoder,
|
| 217 |
+
tokenizer=tokenizer,
|
| 218 |
+
transformer=transformer,
|
| 219 |
+
controlnet=controlnet,
|
| 220 |
+
scheduler=scheduler,
|
| 221 |
+
transformer_2=transformer_2,
|
| 222 |
+
)
|
| 223 |
+
self.register_to_config(boundary_ratio=boundary_ratio)
|
| 224 |
+
self.register_to_config(expand_timesteps=expand_timesteps)
|
| 225 |
+
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
|
| 226 |
+
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
|
| 227 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
| 228 |
+
|
| 229 |
+
def _get_t5_prompt_embeds(
|
| 230 |
+
self,
|
| 231 |
+
prompt: Union[str, List[str]] = None,
|
| 232 |
+
num_videos_per_prompt: int = 1,
|
| 233 |
+
max_sequence_length: int = 226,
|
| 234 |
+
device: Optional[torch.device] = None,
|
| 235 |
+
dtype: Optional[torch.dtype] = None,
|
| 236 |
+
):
|
| 237 |
+
device = device or self._execution_device
|
| 238 |
+
dtype = dtype or self.text_encoder.dtype
|
| 239 |
+
|
| 240 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 241 |
+
prompt = [prompt_clean(u) for u in prompt]
|
| 242 |
+
batch_size = len(prompt)
|
| 243 |
+
|
| 244 |
+
text_inputs = self.tokenizer(
|
| 245 |
+
prompt,
|
| 246 |
+
padding="max_length",
|
| 247 |
+
max_length=max_sequence_length,
|
| 248 |
+
truncation=True,
|
| 249 |
+
add_special_tokens=True,
|
| 250 |
+
return_attention_mask=True,
|
| 251 |
+
return_tensors="pt",
|
| 252 |
+
)
|
| 253 |
+
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
|
| 254 |
+
seq_lens = mask.gt(0).sum(dim=1).long()
|
| 255 |
+
|
| 256 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
|
| 257 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 258 |
+
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
| 259 |
+
prompt_embeds = torch.stack(
|
| 260 |
+
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 264 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 265 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
| 266 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 267 |
+
|
| 268 |
+
return prompt_embeds
|
| 269 |
+
|
| 270 |
+
def encode_prompt(
|
| 271 |
+
self,
|
| 272 |
+
prompt: Union[str, List[str]],
|
| 273 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 274 |
+
do_classifier_free_guidance: bool = True,
|
| 275 |
+
num_videos_per_prompt: int = 1,
|
| 276 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 277 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 278 |
+
max_sequence_length: int = 226,
|
| 279 |
+
device: Optional[torch.device] = None,
|
| 280 |
+
dtype: Optional[torch.dtype] = None,
|
| 281 |
+
):
|
| 282 |
+
r"""
|
| 283 |
+
Encodes the prompt into text encoder hidden states.
|
| 284 |
+
|
| 285 |
+
Args:
|
| 286 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 287 |
+
prompt to be encoded
|
| 288 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 289 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 290 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 291 |
+
less than `1`).
|
| 292 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
| 293 |
+
Whether to use classifier free guidance or not.
|
| 294 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 295 |
+
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
| 296 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 297 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 298 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 299 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 300 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 301 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 302 |
+
argument.
|
| 303 |
+
device: (`torch.device`, *optional*):
|
| 304 |
+
torch device
|
| 305 |
+
dtype: (`torch.dtype`, *optional*):
|
| 306 |
+
torch dtype
|
| 307 |
+
"""
|
| 308 |
+
device = device or self._execution_device
|
| 309 |
+
|
| 310 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 311 |
+
if prompt is not None:
|
| 312 |
+
batch_size = len(prompt)
|
| 313 |
+
else:
|
| 314 |
+
batch_size = prompt_embeds.shape[0]
|
| 315 |
+
|
| 316 |
+
if prompt_embeds is None:
|
| 317 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 318 |
+
prompt=prompt,
|
| 319 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 320 |
+
max_sequence_length=max_sequence_length,
|
| 321 |
+
device=device,
|
| 322 |
+
dtype=dtype,
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 326 |
+
negative_prompt = negative_prompt or ""
|
| 327 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 328 |
+
|
| 329 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 330 |
+
raise TypeError(
|
| 331 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 332 |
+
f" {type(prompt)}."
|
| 333 |
+
)
|
| 334 |
+
elif batch_size != len(negative_prompt):
|
| 335 |
+
raise ValueError(
|
| 336 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 337 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 338 |
+
" the batch size of `prompt`."
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
| 342 |
+
prompt=negative_prompt,
|
| 343 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 344 |
+
max_sequence_length=max_sequence_length,
|
| 345 |
+
device=device,
|
| 346 |
+
dtype=dtype,
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
return prompt_embeds, negative_prompt_embeds
|
| 350 |
+
|
| 351 |
+
def check_inputs(
|
| 352 |
+
self,
|
| 353 |
+
prompt,
|
| 354 |
+
negative_prompt,
|
| 355 |
+
height,
|
| 356 |
+
width,
|
| 357 |
+
prompt_embeds=None,
|
| 358 |
+
negative_prompt_embeds=None,
|
| 359 |
+
callback_on_step_end_tensor_inputs=None,
|
| 360 |
+
guidance_scale_2=None,
|
| 361 |
+
):
|
| 362 |
+
if height % 16 != 0 or width % 16 != 0:
|
| 363 |
+
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
| 364 |
+
|
| 365 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 366 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 367 |
+
):
|
| 368 |
+
raise ValueError(
|
| 369 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
if prompt is not None and prompt_embeds is not None:
|
| 373 |
+
raise ValueError(
|
| 374 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 375 |
+
" only forward one of the two."
|
| 376 |
+
)
|
| 377 |
+
elif negative_prompt is not None and negative_prompt_embeds is not None:
|
| 378 |
+
raise ValueError(
|
| 379 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
|
| 380 |
+
" only forward one of the two."
|
| 381 |
+
)
|
| 382 |
+
elif prompt is None and prompt_embeds is None:
|
| 383 |
+
raise ValueError(
|
| 384 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 385 |
+
)
|
| 386 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 387 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 388 |
+
elif negative_prompt is not None and (
|
| 389 |
+
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
|
| 390 |
+
):
|
| 391 |
+
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
| 392 |
+
|
| 393 |
+
if self.config.boundary_ratio is None and guidance_scale_2 is not None:
|
| 394 |
+
raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
|
| 395 |
+
|
| 396 |
+
def prepare_latents(
|
| 397 |
+
self,
|
| 398 |
+
batch_size: int,
|
| 399 |
+
num_channels_latents: int = 16,
|
| 400 |
+
height: int = 480,
|
| 401 |
+
width: int = 832,
|
| 402 |
+
num_frames: int = 81,
|
| 403 |
+
dtype: Optional[torch.dtype] = None,
|
| 404 |
+
device: Optional[torch.device] = None,
|
| 405 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 406 |
+
latents: Optional[torch.Tensor] = None,
|
| 407 |
+
) -> torch.Tensor:
|
| 408 |
+
if latents is not None:
|
| 409 |
+
return latents.to(device=device, dtype=dtype)
|
| 410 |
+
|
| 411 |
+
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
| 412 |
+
shape = (
|
| 413 |
+
batch_size,
|
| 414 |
+
num_channels_latents,
|
| 415 |
+
num_latent_frames,
|
| 416 |
+
int(height) // self.vae_scale_factor_spatial,
|
| 417 |
+
int(width) // self.vae_scale_factor_spatial,
|
| 418 |
+
)
|
| 419 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 420 |
+
raise ValueError(
|
| 421 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 422 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 426 |
+
return latents
|
| 427 |
+
|
| 428 |
+
@property
|
| 429 |
+
def guidance_scale(self):
|
| 430 |
+
return self._guidance_scale
|
| 431 |
+
|
| 432 |
+
@property
|
| 433 |
+
def do_classifier_free_guidance(self):
|
| 434 |
+
return self._guidance_scale > 1.0
|
| 435 |
+
|
| 436 |
+
@property
|
| 437 |
+
def num_timesteps(self):
|
| 438 |
+
return self._num_timesteps
|
| 439 |
+
|
| 440 |
+
@property
|
| 441 |
+
def current_timestep(self):
|
| 442 |
+
return self._current_timestep
|
| 443 |
+
|
| 444 |
+
@property
|
| 445 |
+
def interrupt(self):
|
| 446 |
+
return self._interrupt
|
| 447 |
+
|
| 448 |
+
@property
|
| 449 |
+
def attention_kwargs(self):
|
| 450 |
+
return self._attention_kwargs
|
| 451 |
+
|
| 452 |
+
@torch.no_grad()
def __call__(
    self,
    controlnet_frames: List[Image.Image] = None,
    prompt: Union[str, List[str]] = None,
    negative_prompt: Union[str, List[str]] = None,
    height: int = 480,
    width: int = 832,
    num_frames: int = 81,
    num_inference_steps: int = 50,
    guidance_scale: float = 5.0,
    guidance_scale_2: Optional[float] = None,
    num_videos_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    controlnet_latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "np",
    return_dict: bool = True,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,

    controlnet_weight: float = 1.0,
    controlnet_guidance_start: float = 0.0,
    controlnet_guidance_end: float = 1.0,
    controlnet_stride: int = 3,

    teacache_state: Optional[TeaCache]= None,
    teacache_treshold: float = 0.0,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        controlnet_frames (`List[PIL.Image.Image]`, *optional*):
            Conditioning frames encoded by the ControlNet. If fewer than `num_frames` are
            given, they are extended by ping-pong (forward + reversed) repetition; if more,
            they are truncated to `num_frames`. Ignored when `controlnet_latents` is passed.
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds`
            instead. Ignored when not using guidance (`guidance_scale` < `1`).
        height (`int`, defaults to `480`):
            The height in pixels of the generated image.
        width (`int`, defaults to `832`):
            The width in pixels of the generated image.
        num_frames (`int`, defaults to `81`):
            The number of frames in the generated video. Rounded so that `num_frames - 1` is
            divisible by the temporal VAE scale factor.
        num_inference_steps (`int`, defaults to `50`):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, defaults to `5.0`):
            Guidance scale as defined in [Classifier-Free Diffusion
            Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
            of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
            `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
            the text `prompt`, usually at the expense of lower image quality.
        guidance_scale_2 (`float`, *optional*, defaults to `None`):
            Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
            `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
            and the pipeline's `boundary_ratio` are not None.
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        controlnet_latents (`torch.FloatTensor`, *optional*):
            Pre-encoded ControlNet conditioning; when given, `controlnet_frames` is not encoded.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        output_type (`str`, *optional*, defaults to `"np"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
        attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
            each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
            DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
            list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        max_sequence_length (`int`, defaults to `512`):
            The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
            truncated. If the prompt is shorter, it will be padded to this length.
        controlnet_weight (`float`, defaults to `1.0`):
            Weight applied to each ControlNet residual injected into the transformer.
        controlnet_guidance_start (`float`, defaults to `0.0`):
            Fraction of the denoising schedule at which ControlNet conditioning starts.
        controlnet_guidance_end (`float`, defaults to `1.0`):
            Fraction of the denoising schedule at which ControlNet conditioning stops.
        controlnet_stride (`int`, defaults to `3`):
            Inject a ControlNet residual every `controlnet_stride` transformer blocks.
        teacache_state (`TeaCache`, *optional*):
            Pre-built TeaCache instance to reuse; when omitted and `teacache_treshold > 0`
            a fresh one is created for this call.
        teacache_treshold (`float`, defaults to `0.0`):
            Accumulated-distance threshold for TeaCache step skipping; `0.0` disables it.
    Examples:

    Returns:
        [`~WanPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
            the first element is the generated video frames.
    """
    # TeaCache is per-call state: take the caller's instance or build one on demand.
    self.teacache = teacache_state or None
    if (self.teacache is None) and (teacache_treshold > 0.0):
        self.teacache = TeaCache(
            num_inference_steps=num_inference_steps,
            model_name="DEFAULT",
            treshold=teacache_treshold
        )

    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        negative_prompt,
        height,
        width,
        prompt_embeds,
        negative_prompt_embeds,
        callback_on_step_end_tensor_inputs,
        guidance_scale_2,
    )

    # The temporal VAE compresses (num_frames - 1) frames; round down to the
    # nearest frame count the VAE can actually decode.
    if num_frames % self.vae_scale_factor_temporal != 1:
        logger.warning(
            f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
        )
        num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
    num_frames = max(num_frames, 1)

    if self.config.boundary_ratio is not None and guidance_scale_2 is None:
        guidance_scale_2 = guidance_scale

    self._guidance_scale = guidance_scale
    self._guidance_scale_2 = guidance_scale_2
    self._attention_kwargs = attention_kwargs
    self._current_timestep = None
    self._interrupt = False

    device = self._execution_device

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt=prompt,
        negative_prompt=negative_prompt,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        num_videos_per_prompt=num_videos_per_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        max_sequence_length=max_sequence_length,
        device=device,
    )

    transformer_dtype = self.transformer.dtype
    prompt_embeds = prompt_embeds.to(transformer_dtype)
    if negative_prompt_embeds is not None:
        negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_videos_per_prompt,
        num_channels_latents,
        height,
        width,
        num_frames,
        torch.float32,
        device,
        generator,
        latents,
    )

    # All-ones mask, used below to broadcast the timestep over the latent grid
    # when `expand_timesteps` is enabled (Wan 2.2 ti2v).
    mask = torch.ones(latents.shape, dtype=torch.float32, device=device)

    # 6. Encode controlnet frames
    if (controlnet_latents is None) and (controlnet_frames is not None):
        duplicate_frames_count = num_frames - len(controlnet_frames)
        print(f'Using controlnet frames: {len(controlnet_frames)}. Extended frames count: {duplicate_frames_count}')
        if duplicate_frames_count > 0:
            # Simple duplicate first frame
            # controlnet_frames = [controlnet_frames[0]] * duplicate_frames_count + controlnet_frames
            # Or reversed duplicate frames ?
            # Ping-pong extension: forward run + reversed run, repeated until long enough.
            reversed_controlnet_frames = list(reversed(controlnet_frames))
            controlnet_sum_frames = controlnet_frames + reversed_controlnet_frames
            reversed_chunks_count = num_frames // len(controlnet_sum_frames)
            controlnet_frames = [*controlnet_sum_frames]
            for _ in range(reversed_chunks_count):
                controlnet_frames += controlnet_sum_frames

        # If controlnet frames count greater than num_frames parameter
        controlnet_frames = controlnet_frames[:num_frames]

        controlnet_latents = prepare_controlnet_frames(
            controlnet_frames,
            height,
            width,
            dtype=self.controlnet.dtype,
            device=self.controlnet.device
        )

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)

    # Wan 2.2 uses two transformers: high-noise steps (t >= boundary) go to
    # `transformer`, low-noise steps to `transformer_2`.
    if self.config.boundary_ratio is not None:
        boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
    else:
        boundary_timestep = None

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            self._current_timestep = t

            if boundary_timestep is None or t >= boundary_timestep:
                # wan2.1 or high-noise stage in wan2.2
                current_model = self.transformer
                current_guidance_scale = guidance_scale
            else:
                # low-noise stage in wan2.2
                current_model = self.transformer_2
                current_guidance_scale = guidance_scale_2

            latent_model_input = latents.to(transformer_dtype)
            if self.config.expand_timesteps:
                # seq_len: num_latent_frames * latent_height//2 * latent_width//2
                temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
                # batch_size, seq_len
                timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
            else:
                timestep = t.expand(latents.shape[0])

            # ControlNet is only consulted inside the [start, end) window of the schedule.
            controlnet_states = None
            current_sampling_percent = i / len(timesteps)
            if (controlnet_latents is not None) and (controlnet_guidance_start <= current_sampling_percent < controlnet_guidance_end):
                controlnet_states = self.controlnet(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    attention_kwargs=attention_kwargs,
                    controlnet_states=controlnet_latents,
                    return_dict=False,
                )[0]
                if isinstance(controlnet_states, (tuple, list)):
                    controlnet_states = [x.to(dtype=self.transformer.dtype) for x in controlnet_states]
                else:
                    controlnet_states = controlnet_states.to(dtype=self.transformer.dtype)

            # Conditional pass.
            with current_model.cache_context("cond"):
                noise_pred = current_model(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    controlnet_states=controlnet_states,
                    controlnet_weight=controlnet_weight,
                    controlnet_stride=controlnet_stride,
                    teacache=self.teacache,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                )[0]

            # Unconditional pass + CFG combination.
            if self.do_classifier_free_guidance:
                with current_model.cache_context("uncond"):
                    noise_uncond = current_model(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=negative_prompt_embeds,
                        controlnet_states=controlnet_states,
                        controlnet_weight=controlnet_weight,
                        controlnet_stride=controlnet_stride,
                        teacache=self.teacache,
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                    )[0]
                noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    self._current_timestep = None
    # Drop per-call TeaCache state so it is not reused by a later call.
    self.teacache = None

    if not output_type == "latent":
        # Undo the VAE latent normalization before decoding.
        latents = latents.to(self.vae.dtype)
        latents_mean = (
            torch.tensor(self.vae.config.latents_mean)
            .view(1, self.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
            latents.device, latents.dtype
        )
        latents = latents / latents_std + latents_mean
        video = self.vae.decode(latents, return_dict=False)[0]
        video = self.video_processor.postprocess_video(video, output_type=output_type)
    else:
        video = latents

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (video,)

    return WanPipelineOutput(frames=video)
|
wan_teacache.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
coefficients = {
    # Polynomial coefficients (highest degree first, consumed by np.poly1d) used
    # to rescale the raw relative-L1 distance between timestep embeddings.
    "DEFAULT": [-1.12343328e+02, 1.50680483e+02, -5.15023303e+01, 6.24892431e+00, 6.85022158e-02],
}


class TeaCache:
    """Timestep-embedding-aware cache for transformer forward passes.

    Tracks how much the modulated timestep embedding changes from step to step;
    while the accumulated (polynomial-rescaled) change stays below ``treshold``
    the transformer blocks may be skipped and the previously recorded residual
    ("flow direction") re-applied instead.

    State is kept separately for "even" and "odd" steps because the pipeline
    runs the transformer twice per scheduler step (cond / uncond), which is
    also why every step count below is multiplied by 2.
    """

    def __init__(self, num_inference_steps, model_name, treshold=0.3, start_step_treshold=0.1, end_step_treshold=0.9):
        # NOTE(review): input_bank grows without bound and is never reset by
        # init_memory(); nothing in this class reads it — presumably kept for
        # external inspection. Verify before removing.
        self.input_bank = []
        self.current_step = 0
        self.accumulated_distance = 0.0
        # Doubled: one cond + one uncond transformer call per scheduler step.
        self.num_inference_steps = num_inference_steps * 2
        # Caching is only active inside [start_step_teacache, end_step_teacache).
        self.start_step_teacache = int(num_inference_steps * start_step_treshold) * 2
        self.end_step_teacache = int(num_inference_steps * end_step_treshold) * 2
        # Typical values: 0.3, 0.5, 0.7, 0.9. The "treshold" spelling is part of
        # the public interface (read by the transformer), so it is kept.
        self.treshold = treshold
        self.coefficients = coefficients[model_name]
        self.step_name = "even"
        self.init_memory()

    def init_memory(self):
        """Reset the per-parity accumulators and cached tensors."""
        self.accumulated_distance = {
            "even": 0.0,
            "odd": 0.0,
        }
        self.flow_direction = {
            "even": None,
            "odd": None,
        }
        self.previous_modulated_input = {
            "even": None,
            "odd": None,
        }

    def check_for_using_cached_value(self, modulated_input):
        """Decide whether the cached residual can replace a full forward pass.

        Returns True when caching is active for the current step window and the
        accumulated rescaled distance is still below the threshold. Always
        records ``modulated_input`` as the reference for the next call of the
        same parity.
        """
        use_tea_cache = (self.treshold > 0.0) and (self.start_step_teacache <= self.current_step < self.end_step_teacache)
        self.step_name = "even" if self.current_step % 2 == 0 else "odd"

        use_cached_value = False
        previous_input = self.previous_modulated_input[self.step_name]
        # Guard: on the first observation for a parity there is no reference
        # embedding yet (None). The original code crashed here whenever
        # start_step_treshold was 0.0; treat it as a cache miss instead.
        if use_tea_cache and previous_input is not None:
            rescale_func = np.poly1d(self.coefficients)
            # Note: calculate_distance normalizes by its *first* argument, i.e.
            # by the current input here.
            current_distance = rescale_func(
                self.calculate_distance(modulated_input, previous_input)
            )
            self.accumulated_distance[self.step_name] += current_distance

            if self.accumulated_distance[self.step_name] < self.treshold:
                use_cached_value = True
            else:
                use_cached_value = False
                self.accumulated_distance[self.step_name] = 0.0

        if self.step_name == "even":
            self.input_bank.append(modulated_input.cpu())

        self.previous_modulated_input[self.step_name] = modulated_input.clone()
        return use_cached_value

    def use_cache(self, hidden_states):
        """Apply the stored residual instead of running the transformer blocks."""
        return hidden_states + self.flow_direction[self.step_name].to(device=hidden_states.device)

    def calculate_distance(self, previous_tensor, current_tensor):
        """Relative L1 distance, normalized by the mean magnitude of the first argument.

        (Despite the parameter names, the caller passes the current input first.)
        """
        relative_l1_distance = torch.abs(
            previous_tensor - current_tensor
        ).mean() / torch.abs(previous_tensor).mean()
        return relative_l1_distance.to(torch.float32).cpu().item()

    def update(self, flow_direction):
        """Store this step's residual and advance the (wrapping) step counter."""
        self.flow_direction[self.step_name] = flow_direction
        self.current_step += 1
        if self.current_step == self.num_inference_steps:
            self.current_step = 0
            self.init_memory()
|
wan_transformer.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, Optional, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from diffusers import WanTransformer3DModel
|
| 5 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 6 |
+
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
| 7 |
+
from wan_teacache import TeaCache
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class CustomWanTransformer3DModel(WanTransformer3DModel):
    """`WanTransformer3DModel` extended with ControlNet residual injection and optional TeaCache step skipping."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,

        controlnet_states: torch.Tensor = None,
        controlnet_weight: Optional[float] = 1.0,
        controlnet_stride: Optional[int] = 1,
        teacache: Optional[TeaCache] = None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Denoising forward pass.

        Args:
            hidden_states: Video latents of shape (batch, channels, frames, height, width).
            timestep: Scheduler timestep(s); shape (batch,) or (batch, seq_len) for Wan 2.2 ti2v.
            encoder_hidden_states: Text-encoder embeddings.
            encoder_hidden_states_image: Optional image embeddings, concatenated before the text embeddings.
            return_dict: When False, return a plain tuple instead of `Transformer2DModelOutput`.
            attention_kwargs: Only "scale" is consumed here (LoRA scaling under the PEFT backend).
            controlnet_states: Sequence of per-block ControlNet residuals added to the hidden states.
            controlnet_weight: Multiplier applied to each injected ControlNet residual.
            controlnet_stride: Inject a residual after every `controlnet_stride`-th transformer block.
            teacache: Optional TeaCache; when its threshold is > 0 the transformer
                blocks may be skipped and the cached residual re-applied instead.
        """
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.config.patch_size
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w

        rotary_emb = self.rope(hidden_states)

        # Patchify: (b, c, f, h, w) -> (b, seq_len, inner_dim)
        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
        if timestep.ndim == 2:
            ts_seq_len = timestep.shape[1]
            timestep = timestep.flatten()  # batch_size * seq_len
        else:
            ts_seq_len = None

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
        )
        if ts_seq_len is not None:
            # batch_size, seq_len, 6, inner_dim
            timestep_proj = timestep_proj.unflatten(2, (6, -1))
        else:
            # batch_size, 6, inner_dim
            timestep_proj = timestep_proj.unflatten(1, (6, -1))

        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        # TeaCache: decide from the timestep embedding whether the transformer
        # blocks can be skipped for this step.
        use_cached_value = False
        original_hidden_states = None
        if (teacache is not None) and (teacache.treshold > 0.0):
            original_hidden_states = hidden_states.clone()
            use_cached_value = teacache.check_for_using_cached_value(temb)

        if use_cached_value:
            # Replay the cached residual instead of running the blocks.
            hidden_states = teacache.use_cache(hidden_states)
        else:
            # 4. Transformer blocks
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                for i, block in enumerate(self.blocks):
                    hidden_states = self._gradient_checkpointing_func(
                        block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
                    )

                    # Add the matching ControlNet residual every `controlnet_stride` blocks.
                    if (controlnet_states is not None) and (i % controlnet_stride == 0) and (i // controlnet_stride < len(controlnet_states)):
                        hidden_states = hidden_states + controlnet_states[i // controlnet_stride] * controlnet_weight
            else:
                for i, block in enumerate(self.blocks):
                    hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)

                    if (controlnet_states is not None) and (i % controlnet_stride == 0) and (i // controlnet_stride < len(controlnet_states)):
                        hidden_states = hidden_states + controlnet_states[i // controlnet_stride] * controlnet_weight

        # Record this step's residual (vs. the pre-block hidden states) so a
        # future step can replay it.
        if (teacache is not None) and (teacache.treshold > 0.0):
            teacache.update(hidden_states - original_hidden_states)

        # 5. Output norm, projection & unpatchify
        if temb.ndim == 3:
            # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
            shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
            shift = shift.squeeze(2)
            scale = scale.squeeze(2)
        else:
            # batch_size, inner_dim
            shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)

        # Move the shift and scale tensors to the same device as hidden_states.
        # When using multi-GPU inference via accelerate these will be on the
        # first device rather than the last device, which hidden_states ends up
        # on.
        shift = shift.to(hidden_states.device)
        scale = scale.to(hidden_states.device)

        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
        hidden_states = self.proj_out(hidden_states)

        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
        )
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
|
workflows/sam2.1_optimized.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
workflows/sam_optimized.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
workflows/vace_optimized.json
ADDED
|
@@ -0,0 +1,1043 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"id": "960108a5-bf9d-497f-a6e5-4c5c3e41c056",
|
| 3 |
+
"revision": 0,
|
| 4 |
+
"last_node_id": 37,
|
| 5 |
+
"last_link_id": 93,
|
| 6 |
+
"nodes": [
|
| 7 |
+
{
|
| 8 |
+
"id": 11,
|
| 9 |
+
"type": "ModelSamplingSD3",
|
| 10 |
+
"pos": [
|
| 11 |
+
442.7779541015625,
|
| 12 |
+
942.9921264648438
|
| 13 |
+
],
|
| 14 |
+
"size": [
|
| 15 |
+
210,
|
| 16 |
+
58
|
| 17 |
+
],
|
| 18 |
+
"flags": {
|
| 19 |
+
"collapsed": false
|
| 20 |
+
},
|
| 21 |
+
"order": 9,
|
| 22 |
+
"mode": 0,
|
| 23 |
+
"inputs": [
|
| 24 |
+
{
|
| 25 |
+
"name": "model",
|
| 26 |
+
"type": "MODEL",
|
| 27 |
+
"link": 91
|
| 28 |
+
}
|
| 29 |
+
],
|
| 30 |
+
"outputs": [
|
| 31 |
+
{
|
| 32 |
+
"name": "MODEL",
|
| 33 |
+
"type": "MODEL",
|
| 34 |
+
"links": [
|
| 35 |
+
58
|
| 36 |
+
]
|
| 37 |
+
}
|
| 38 |
+
],
|
| 39 |
+
"properties": {
|
| 40 |
+
"Node name for S&R": "ModelSamplingSD3"
|
| 41 |
+
},
|
| 42 |
+
"widgets_values": [
|
| 43 |
+
2.0000000000000004
|
| 44 |
+
]
|
| 45 |
+
},
|
| 46 |
+
{
|
| 47 |
+
"id": 32,
|
| 48 |
+
"type": "VHS_LoadVideo",
|
| 49 |
+
"pos": [
|
| 50 |
+
120.05851745605469,
|
| 51 |
+
397.98248291015625
|
| 52 |
+
],
|
| 53 |
+
"size": [
|
| 54 |
+
253.279296875,
|
| 55 |
+
310
|
| 56 |
+
],
|
| 57 |
+
"flags": {},
|
| 58 |
+
"order": 6,
|
| 59 |
+
"mode": 0,
|
| 60 |
+
"inputs": [
|
| 61 |
+
{
|
| 62 |
+
"name": "meta_batch",
|
| 63 |
+
"shape": 7,
|
| 64 |
+
"type": "VHS_BatchManager",
|
| 65 |
+
"link": null
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"name": "vae",
|
| 69 |
+
"shape": 7,
|
| 70 |
+
"type": "VAE",
|
| 71 |
+
"link": null
|
| 72 |
+
},
|
| 73 |
+
{
|
| 74 |
+
"name": "frame_load_cap",
|
| 75 |
+
"type": "INT",
|
| 76 |
+
"widget": {
|
| 77 |
+
"name": "frame_load_cap"
|
| 78 |
+
},
|
| 79 |
+
"link": 76
|
| 80 |
+
}
|
| 81 |
+
],
|
| 82 |
+
"outputs": [
|
| 83 |
+
{
|
| 84 |
+
"name": "IMAGE",
|
| 85 |
+
"type": "IMAGE",
|
| 86 |
+
"links": [
|
| 87 |
+
86
|
| 88 |
+
]
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"name": "frame_count",
|
| 92 |
+
"type": "INT",
|
| 93 |
+
"links": [
|
| 94 |
+
78
|
| 95 |
+
]
|
| 96 |
+
},
|
| 97 |
+
{
|
| 98 |
+
"name": "audio",
|
| 99 |
+
"type": "AUDIO",
|
| 100 |
+
"links": null
|
| 101 |
+
},
|
| 102 |
+
{
|
| 103 |
+
"name": "video_info",
|
| 104 |
+
"type": "VHS_VIDEOINFO",
|
| 105 |
+
"links": null
|
| 106 |
+
}
|
| 107 |
+
],
|
| 108 |
+
"title": "上传遮罩合成视频",
|
| 109 |
+
"properties": {
|
| 110 |
+
"Node name for S&R": "VHS_LoadVideo"
|
| 111 |
+
},
|
| 112 |
+
"widgets_values": {
|
| 113 |
+
"video": "sam2.1_00182.mp4",
|
| 114 |
+
"force_rate": 16,
|
| 115 |
+
"custom_width": 0,
|
| 116 |
+
"custom_height": 0,
|
| 117 |
+
"frame_load_cap": 0,
|
| 118 |
+
"skip_first_frames": 0,
|
| 119 |
+
"select_every_nth": 1,
|
| 120 |
+
"format": "Wan",
|
| 121 |
+
"choose video to upload": "image",
|
| 122 |
+
"videopreview": {
|
| 123 |
+
"hidden": false,
|
| 124 |
+
"paused": false,
|
| 125 |
+
"params": {
|
| 126 |
+
"filename": "sam2.1_00182.mp4",
|
| 127 |
+
"type": "input",
|
| 128 |
+
"format": "video/mp4",
|
| 129 |
+
"force_rate": 16,
|
| 130 |
+
"custom_width": 0,
|
| 131 |
+
"custom_height": 0,
|
| 132 |
+
"frame_load_cap": 0,
|
| 133 |
+
"skip_first_frames": 0,
|
| 134 |
+
"select_every_nth": 1
|
| 135 |
+
}
|
| 136 |
+
}
|
| 137 |
+
}
|
| 138 |
+
},
|
| 139 |
+
{
|
| 140 |
+
"id": 33,
|
| 141 |
+
"type": "VHS_LoadVideo",
|
| 142 |
+
"pos": [
|
| 143 |
+
112.58995056152344,
|
| 144 |
+
753.9783325195312
|
| 145 |
+
],
|
| 146 |
+
"size": [
|
| 147 |
+
253.279296875,
|
| 148 |
+
310
|
| 149 |
+
],
|
| 150 |
+
"flags": {},
|
| 151 |
+
"order": 0,
|
| 152 |
+
"mode": 0,
|
| 153 |
+
"inputs": [
|
| 154 |
+
{
|
| 155 |
+
"name": "meta_batch",
|
| 156 |
+
"shape": 7,
|
| 157 |
+
"type": "VHS_BatchManager",
|
| 158 |
+
"link": null
|
| 159 |
+
},
|
| 160 |
+
{
|
| 161 |
+
"name": "vae",
|
| 162 |
+
"shape": 7,
|
| 163 |
+
"type": "VAE",
|
| 164 |
+
"link": null
|
| 165 |
+
}
|
| 166 |
+
],
|
| 167 |
+
"outputs": [
|
| 168 |
+
{
|
| 169 |
+
"name": "IMAGE",
|
| 170 |
+
"type": "IMAGE",
|
| 171 |
+
"links": [
|
| 172 |
+
85
|
| 173 |
+
]
|
| 174 |
+
},
|
| 175 |
+
{
|
| 176 |
+
"name": "frame_count",
|
| 177 |
+
"type": "INT",
|
| 178 |
+
"links": [
|
| 179 |
+
76
|
| 180 |
+
]
|
| 181 |
+
},
|
| 182 |
+
{
|
| 183 |
+
"name": "audio",
|
| 184 |
+
"type": "AUDIO",
|
| 185 |
+
"links": null
|
| 186 |
+
},
|
| 187 |
+
{
|
| 188 |
+
"name": "video_info",
|
| 189 |
+
"type": "VHS_VIDEOINFO",
|
| 190 |
+
"links": null
|
| 191 |
+
}
|
| 192 |
+
],
|
| 193 |
+
"title": "上传遮罩视频(黑白那个)",
|
| 194 |
+
"properties": {
|
| 195 |
+
"Node name for S&R": "VHS_LoadVideo"
|
| 196 |
+
},
|
| 197 |
+
"widgets_values": {
|
| 198 |
+
"video": "sam2.1_00181.mp4",
|
| 199 |
+
"force_rate": 0,
|
| 200 |
+
"custom_width": 0,
|
| 201 |
+
"custom_height": 0,
|
| 202 |
+
"frame_load_cap": 0,
|
| 203 |
+
"skip_first_frames": 0,
|
| 204 |
+
"select_every_nth": 1,
|
| 205 |
+
"format": "Wan",
|
| 206 |
+
"choose video to upload": "image",
|
| 207 |
+
"videopreview": {
|
| 208 |
+
"hidden": false,
|
| 209 |
+
"paused": false,
|
| 210 |
+
"params": {
|
| 211 |
+
"filename": "sam2.1_00181.mp4",
|
| 212 |
+
"type": "input",
|
| 213 |
+
"format": "video/mp4",
|
| 214 |
+
"force_rate": 0,
|
| 215 |
+
"custom_width": 0,
|
| 216 |
+
"custom_height": 0,
|
| 217 |
+
"frame_load_cap": 0,
|
| 218 |
+
"skip_first_frames": 0,
|
| 219 |
+
"select_every_nth": 1
|
| 220 |
+
}
|
| 221 |
+
}
|
| 222 |
+
}
|
| 223 |
+
},
|
| 224 |
+
{
|
| 225 |
+
"id": 35,
|
| 226 |
+
"type": "GrowMask",
|
| 227 |
+
"pos": [
|
| 228 |
+
722.2931518554688,
|
| 229 |
+
1093.416015625
|
| 230 |
+
],
|
| 231 |
+
"size": [
|
| 232 |
+
270,
|
| 233 |
+
82
|
| 234 |
+
],
|
| 235 |
+
"flags": {},
|
| 236 |
+
"order": 10,
|
| 237 |
+
"mode": 0,
|
| 238 |
+
"inputs": [
|
| 239 |
+
{
|
| 240 |
+
"name": "mask",
|
| 241 |
+
"type": "MASK",
|
| 242 |
+
"link": 79
|
| 243 |
+
}
|
| 244 |
+
],
|
| 245 |
+
"outputs": [
|
| 246 |
+
{
|
| 247 |
+
"name": "MASK",
|
| 248 |
+
"type": "MASK",
|
| 249 |
+
"links": [
|
| 250 |
+
80
|
| 251 |
+
]
|
| 252 |
+
}
|
| 253 |
+
],
|
| 254 |
+
"properties": {
|
| 255 |
+
"Node name for S&R": "GrowMask"
|
| 256 |
+
},
|
| 257 |
+
"widgets_values": [
|
| 258 |
+
5,
|
| 259 |
+
true
|
| 260 |
+
]
|
| 261 |
+
},
|
| 262 |
+
{
|
| 263 |
+
"id": 6,
|
| 264 |
+
"type": "CLIPLoader",
|
| 265 |
+
"pos": [
|
| 266 |
+
111.71733093261719,
|
| 267 |
+
1112.0469970703125
|
| 268 |
+
],
|
| 269 |
+
"size": [
|
| 270 |
+
210,
|
| 271 |
+
106
|
| 272 |
+
],
|
| 273 |
+
"flags": {},
|
| 274 |
+
"order": 1,
|
| 275 |
+
"mode": 0,
|
| 276 |
+
"inputs": [],
|
| 277 |
+
"outputs": [
|
| 278 |
+
{
|
| 279 |
+
"name": "CLIP",
|
| 280 |
+
"type": "CLIP",
|
| 281 |
+
"slot_index": 0,
|
| 282 |
+
"links": [
|
| 283 |
+
92,
|
| 284 |
+
93
|
| 285 |
+
]
|
| 286 |
+
}
|
| 287 |
+
],
|
| 288 |
+
"properties": {
|
| 289 |
+
"Node name for S&R": "CLIPLoader"
|
| 290 |
+
},
|
| 291 |
+
"widgets_values": [
|
| 292 |
+
"umt5_xxl_fp8_e4m3fn_scaled.safetensors",
|
| 293 |
+
"wan",
|
| 294 |
+
"cpu"
|
| 295 |
+
]
|
| 296 |
+
},
|
| 297 |
+
{
|
| 298 |
+
"id": 8,
|
| 299 |
+
"type": "UNETLoader",
|
| 300 |
+
"pos": [
|
| 301 |
+
153.8439178466797,
|
| 302 |
+
269.8687438964844
|
| 303 |
+
],
|
| 304 |
+
"size": [
|
| 305 |
+
210,
|
| 306 |
+
82
|
| 307 |
+
],
|
| 308 |
+
"flags": {},
|
| 309 |
+
"order": 2,
|
| 310 |
+
"mode": 0,
|
| 311 |
+
"inputs": [],
|
| 312 |
+
"outputs": [
|
| 313 |
+
{
|
| 314 |
+
"name": "MODEL",
|
| 315 |
+
"type": "MODEL",
|
| 316 |
+
"slot_index": 0,
|
| 317 |
+
"links": [
|
| 318 |
+
91
|
| 319 |
+
]
|
| 320 |
+
}
|
| 321 |
+
],
|
| 322 |
+
"properties": {
|
| 323 |
+
"Node name for S&R": "UNETLoader"
|
| 324 |
+
},
|
| 325 |
+
"widgets_values": [
|
| 326 |
+
"Wan2.1_T2V_14B_FusionX_VACE-FP8_e4m3fn.safetensors",
|
| 327 |
+
"default"
|
| 328 |
+
]
|
| 329 |
+
},
|
| 330 |
+
{
|
| 331 |
+
"id": 15,
|
| 332 |
+
"type": "CLIPTextEncode",
|
| 333 |
+
"pos": [
|
| 334 |
+
451.8983154296875,
|
| 335 |
+
311.4078674316406
|
| 336 |
+
],
|
| 337 |
+
"size": [
|
| 338 |
+
494.83367919921875,
|
| 339 |
+
119.57742309570312
|
| 340 |
+
],
|
| 341 |
+
"flags": {},
|
| 342 |
+
"order": 8,
|
| 343 |
+
"mode": 0,
|
| 344 |
+
"inputs": [
|
| 345 |
+
{
|
| 346 |
+
"name": "clip",
|
| 347 |
+
"type": "CLIP",
|
| 348 |
+
"link": 93
|
| 349 |
+
}
|
| 350 |
+
],
|
| 351 |
+
"outputs": [
|
| 352 |
+
{
|
| 353 |
+
"name": "CONDITIONING",
|
| 354 |
+
"type": "CONDITIONING",
|
| 355 |
+
"slot_index": 0,
|
| 356 |
+
"links": [
|
| 357 |
+
18
|
| 358 |
+
]
|
| 359 |
+
}
|
| 360 |
+
],
|
| 361 |
+
"title": "CLIP Text Encode (Positive Prompt)",
|
| 362 |
+
"properties": {
|
| 363 |
+
"Node name for S&R": "CLIPTextEncode"
|
| 364 |
+
},
|
| 365 |
+
"widgets_values": [
|
| 366 |
+
"赤身裸体的亚洲女子,主观视角,她抖动着胸部,乳房剧烈晃动。bustygrid. a completely naked asian woman with pale skin and huge breasts. she has straight dark hair. she is completely naked, and wearing pointed stilleto heels. bare legs, bare calf, sky-high heeled pumps. remove all clothes. "
|
| 367 |
+
],
|
| 368 |
+
"color": "#232",
|
| 369 |
+
"bgcolor": "#353"
|
| 370 |
+
},
|
| 371 |
+
{
|
| 372 |
+
"id": 2,
|
| 373 |
+
"type": "CLIPTextEncode",
|
| 374 |
+
"pos": [
|
| 375 |
+
453.97589111328125,
|
| 376 |
+
487.16363525390625
|
| 377 |
+
],
|
| 378 |
+
"size": [
|
| 379 |
+
486.9105529785156,
|
| 380 |
+
107.89899444580078
|
| 381 |
+
],
|
| 382 |
+
"flags": {
|
| 383 |
+
"collapsed": false
|
| 384 |
+
},
|
| 385 |
+
"order": 7,
|
| 386 |
+
"mode": 0,
|
| 387 |
+
"inputs": [
|
| 388 |
+
{
|
| 389 |
+
"name": "clip",
|
| 390 |
+
"type": "CLIP",
|
| 391 |
+
"link": 92
|
| 392 |
+
}
|
| 393 |
+
],
|
| 394 |
+
"outputs": [
|
| 395 |
+
{
|
| 396 |
+
"name": "CONDITIONING",
|
| 397 |
+
"type": "CONDITIONING",
|
| 398 |
+
"slot_index": 0,
|
| 399 |
+
"links": [
|
| 400 |
+
19
|
| 401 |
+
]
|
| 402 |
+
}
|
| 403 |
+
],
|
| 404 |
+
"title": "CLIP Text Encode (Negative Prompt)",
|
| 405 |
+
"properties": {
|
| 406 |
+
"Node name for S&R": "CLIPTextEncode"
|
| 407 |
+
},
|
| 408 |
+
"widgets_values": [
|
| 409 |
+
"白种人,黑种人,阴部遮挡,内裤,六根手指,低像素,模糊,像素点,多余的手臂,肢体扭曲,手指模糊,脸部改变,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
|
| 410 |
+
],
|
| 411 |
+
"color": "#322",
|
| 412 |
+
"bgcolor": "#533"
|
| 413 |
+
},
|
| 414 |
+
{
|
| 415 |
+
"id": 3,
|
| 416 |
+
"type": "VAELoader",
|
| 417 |
+
"pos": [
|
| 418 |
+
433.6892395019531,
|
| 419 |
+
643.1557006835938
|
| 420 |
+
],
|
| 421 |
+
"size": [
|
| 422 |
+
210,
|
| 423 |
+
58
|
| 424 |
+
],
|
| 425 |
+
"flags": {
|
| 426 |
+
"collapsed": false
|
| 427 |
+
},
|
| 428 |
+
"order": 3,
|
| 429 |
+
"mode": 0,
|
| 430 |
+
"inputs": [],
|
| 431 |
+
"outputs": [
|
| 432 |
+
{
|
| 433 |
+
"name": "VAE",
|
| 434 |
+
"type": "VAE",
|
| 435 |
+
"links": [
|
| 436 |
+
16,
|
| 437 |
+
20
|
| 438 |
+
]
|
| 439 |
+
}
|
| 440 |
+
],
|
| 441 |
+
"properties": {
|
| 442 |
+
"Node name for S&R": "VAELoader"
|
| 443 |
+
},
|
| 444 |
+
"widgets_values": [
|
| 445 |
+
"Wan2.1_VAE.safetensors"
|
| 446 |
+
]
|
| 447 |
+
},
|
| 448 |
+
{
|
| 449 |
+
"id": 17,
|
| 450 |
+
"type": "WanVaceToVideo",
|
| 451 |
+
"pos": [
|
| 452 |
+
706.262939453125,
|
| 453 |
+
658.4074096679688
|
| 454 |
+
],
|
| 455 |
+
"size": [
|
| 456 |
+
224.32986450195312,
|
| 457 |
+
254
|
| 458 |
+
],
|
| 459 |
+
"flags": {},
|
| 460 |
+
"order": 11,
|
| 461 |
+
"mode": 0,
|
| 462 |
+
"inputs": [
|
| 463 |
+
{
|
| 464 |
+
"name": "positive",
|
| 465 |
+
"type": "CONDITIONING",
|
| 466 |
+
"link": 18
|
| 467 |
+
},
|
| 468 |
+
{
|
| 469 |
+
"name": "negative",
|
| 470 |
+
"type": "CONDITIONING",
|
| 471 |
+
"link": 19
|
| 472 |
+
},
|
| 473 |
+
{
|
| 474 |
+
"name": "vae",
|
| 475 |
+
"type": "VAE",
|
| 476 |
+
"link": 20
|
| 477 |
+
},
|
| 478 |
+
{
|
| 479 |
+
"name": "control_video",
|
| 480 |
+
"shape": 7,
|
| 481 |
+
"type": "IMAGE",
|
| 482 |
+
"link": 86
|
| 483 |
+
},
|
| 484 |
+
{
|
| 485 |
+
"name": "control_masks",
|
| 486 |
+
"shape": 7,
|
| 487 |
+
"type": "MASK",
|
| 488 |
+
"link": 80
|
| 489 |
+
},
|
| 490 |
+
{
|
| 491 |
+
"name": "reference_image",
|
| 492 |
+
"shape": 7,
|
| 493 |
+
"type": "IMAGE",
|
| 494 |
+
"link": 22
|
| 495 |
+
},
|
| 496 |
+
{
|
| 497 |
+
"name": "length",
|
| 498 |
+
"type": "INT",
|
| 499 |
+
"widget": {
|
| 500 |
+
"name": "length"
|
| 501 |
+
},
|
| 502 |
+
"link": 78
|
| 503 |
+
}
|
| 504 |
+
],
|
| 505 |
+
"outputs": [
|
| 506 |
+
{
|
| 507 |
+
"name": "positive",
|
| 508 |
+
"type": "CONDITIONING",
|
| 509 |
+
"links": [
|
| 510 |
+
12
|
| 511 |
+
]
|
| 512 |
+
},
|
| 513 |
+
{
|
| 514 |
+
"name": "negative",
|
| 515 |
+
"type": "CONDITIONING",
|
| 516 |
+
"links": [
|
| 517 |
+
13
|
| 518 |
+
]
|
| 519 |
+
},
|
| 520 |
+
{
|
| 521 |
+
"name": "latent",
|
| 522 |
+
"type": "LATENT",
|
| 523 |
+
"links": [
|
| 524 |
+
14
|
| 525 |
+
]
|
| 526 |
+
},
|
| 527 |
+
{
|
| 528 |
+
"name": "trim_latent",
|
| 529 |
+
"type": "INT",
|
| 530 |
+
"links": [
|
| 531 |
+
10
|
| 532 |
+
]
|
| 533 |
+
}
|
| 534 |
+
],
|
| 535 |
+
"properties": {
|
| 536 |
+
"Node name for S&R": "WanVaceToVideo"
|
| 537 |
+
},
|
| 538 |
+
"widgets_values": [
|
| 539 |
+
480,
|
| 540 |
+
320,
|
| 541 |
+
49,
|
| 542 |
+
1,
|
| 543 |
+
1.0000000000000002
|
| 544 |
+
]
|
| 545 |
+
},
|
| 546 |
+
{
|
| 547 |
+
"id": 12,
|
| 548 |
+
"type": "TrimVideoLatent",
|
| 549 |
+
"pos": [
|
| 550 |
+
746.625,
|
| 551 |
+
985.3895874023438
|
| 552 |
+
],
|
| 553 |
+
"size": [
|
| 554 |
+
226.2460174560547,
|
| 555 |
+
58
|
| 556 |
+
],
|
| 557 |
+
"flags": {
|
| 558 |
+
"collapsed": false
|
| 559 |
+
},
|
| 560 |
+
"order": 13,
|
| 561 |
+
"mode": 0,
|
| 562 |
+
"inputs": [
|
| 563 |
+
{
|
| 564 |
+
"name": "samples",
|
| 565 |
+
"type": "LATENT",
|
| 566 |
+
"link": 9
|
| 567 |
+
},
|
| 568 |
+
{
|
| 569 |
+
"name": "trim_amount",
|
| 570 |
+
"type": "INT",
|
| 571 |
+
"widget": {
|
| 572 |
+
"name": "trim_amount"
|
| 573 |
+
},
|
| 574 |
+
"link": 10
|
| 575 |
+
}
|
| 576 |
+
],
|
| 577 |
+
"outputs": [
|
| 578 |
+
{
|
| 579 |
+
"name": "LATENT",
|
| 580 |
+
"type": "LATENT",
|
| 581 |
+
"links": [
|
| 582 |
+
15
|
| 583 |
+
]
|
| 584 |
+
}
|
| 585 |
+
],
|
| 586 |
+
"properties": {
|
| 587 |
+
"Node name for S&R": "TrimVideoLatent"
|
| 588 |
+
},
|
| 589 |
+
"widgets_values": [
|
| 590 |
+
0
|
| 591 |
+
]
|
| 592 |
+
},
|
| 593 |
+
{
|
| 594 |
+
"id": 13,
|
| 595 |
+
"type": "KSampler",
|
| 596 |
+
"pos": [
|
| 597 |
+
985.894775390625,
|
| 598 |
+
349.17340087890625
|
| 599 |
+
],
|
| 600 |
+
"size": [
|
| 601 |
+
210,
|
| 602 |
+
605.3333129882812
|
| 603 |
+
],
|
| 604 |
+
"flags": {},
|
| 605 |
+
"order": 12,
|
| 606 |
+
"mode": 0,
|
| 607 |
+
"inputs": [
|
| 608 |
+
{
|
| 609 |
+
"name": "model",
|
| 610 |
+
"type": "MODEL",
|
| 611 |
+
"link": 58
|
| 612 |
+
},
|
| 613 |
+
{
|
| 614 |
+
"name": "positive",
|
| 615 |
+
"type": "CONDITIONING",
|
| 616 |
+
"link": 12
|
| 617 |
+
},
|
| 618 |
+
{
|
| 619 |
+
"name": "negative",
|
| 620 |
+
"type": "CONDITIONING",
|
| 621 |
+
"link": 13
|
| 622 |
+
},
|
| 623 |
+
{
|
| 624 |
+
"name": "latent_image",
|
| 625 |
+
"type": "LATENT",
|
| 626 |
+
"link": 14
|
| 627 |
+
}
|
| 628 |
+
],
|
| 629 |
+
"outputs": [
|
| 630 |
+
{
|
| 631 |
+
"name": "LATENT",
|
| 632 |
+
"type": "LATENT",
|
| 633 |
+
"slot_index": 0,
|
| 634 |
+
"links": [
|
| 635 |
+
9
|
| 636 |
+
]
|
| 637 |
+
}
|
| 638 |
+
],
|
| 639 |
+
"properties": {
|
| 640 |
+
"Node name for S&R": "KSampler"
|
| 641 |
+
},
|
| 642 |
+
"widgets_values": [
|
| 643 |
+
109768395777514,
|
| 644 |
+
"randomize",
|
| 645 |
+
10,
|
| 646 |
+
1,
|
| 647 |
+
"uni_pc_bh2",
|
| 648 |
+
"simple",
|
| 649 |
+
1
|
| 650 |
+
]
|
| 651 |
+
},
|
| 652 |
+
{
|
| 653 |
+
"id": 14,
|
| 654 |
+
"type": "VAEDecode",
|
| 655 |
+
"pos": [
|
| 656 |
+
973.5802612304688,
|
| 657 |
+
1001.729736328125
|
| 658 |
+
],
|
| 659 |
+
"size": [
|
| 660 |
+
208.16270446777344,
|
| 661 |
+
46
|
| 662 |
+
],
|
| 663 |
+
"flags": {
|
| 664 |
+
"collapsed": false
|
| 665 |
+
},
|
| 666 |
+
"order": 14,
|
| 667 |
+
"mode": 0,
|
| 668 |
+
"inputs": [
|
| 669 |
+
{
|
| 670 |
+
"name": "samples",
|
| 671 |
+
"type": "LATENT",
|
| 672 |
+
"link": 15
|
| 673 |
+
},
|
| 674 |
+
{
|
| 675 |
+
"name": "vae",
|
| 676 |
+
"type": "VAE",
|
| 677 |
+
"link": 16
|
| 678 |
+
}
|
| 679 |
+
],
|
| 680 |
+
"outputs": [
|
| 681 |
+
{
|
| 682 |
+
"name": "IMAGE",
|
| 683 |
+
"type": "IMAGE",
|
| 684 |
+
"slot_index": 0,
|
| 685 |
+
"links": [
|
| 686 |
+
3
|
| 687 |
+
]
|
| 688 |
+
}
|
| 689 |
+
],
|
| 690 |
+
"properties": {
|
| 691 |
+
"Node name for S&R": "VAEDecode"
|
| 692 |
+
},
|
| 693 |
+
"widgets_values": []
|
| 694 |
+
},
|
| 695 |
+
{
|
| 696 |
+
"id": 4,
|
| 697 |
+
"type": "VHS_VideoCombine",
|
| 698 |
+
"pos": [
|
| 699 |
+
1219.9688720703125,
|
| 700 |
+
358.5111389160156
|
| 701 |
+
],
|
| 702 |
+
"size": [
|
| 703 |
+
239.620361328125,
|
| 704 |
+
310
|
| 705 |
+
],
|
| 706 |
+
"flags": {},
|
| 707 |
+
"order": 15,
|
| 708 |
+
"mode": 0,
|
| 709 |
+
"inputs": [
|
| 710 |
+
{
|
| 711 |
+
"name": "images",
|
| 712 |
+
"type": "IMAGE",
|
| 713 |
+
"link": 3
|
| 714 |
+
},
|
| 715 |
+
{
|
| 716 |
+
"name": "audio",
|
| 717 |
+
"shape": 7,
|
| 718 |
+
"type": "AUDIO",
|
| 719 |
+
"link": null
|
| 720 |
+
},
|
| 721 |
+
{
|
| 722 |
+
"name": "meta_batch",
|
| 723 |
+
"shape": 7,
|
| 724 |
+
"type": "VHS_BatchManager",
|
| 725 |
+
"link": null
|
| 726 |
+
},
|
| 727 |
+
{
|
| 728 |
+
"name": "vae",
|
| 729 |
+
"shape": 7,
|
| 730 |
+
"type": "VAE",
|
| 731 |
+
"link": null
|
| 732 |
+
}
|
| 733 |
+
],
|
| 734 |
+
"outputs": [
|
| 735 |
+
{
|
| 736 |
+
"name": "Filenames",
|
| 737 |
+
"type": "VHS_FILENAMES",
|
| 738 |
+
"links": null
|
| 739 |
+
}
|
| 740 |
+
],
|
| 741 |
+
"properties": {
|
| 742 |
+
"Node name for S&R": "VHS_VideoCombine"
|
| 743 |
+
},
|
| 744 |
+
"widgets_values": {
|
| 745 |
+
"frame_rate": 16,
|
| 746 |
+
"loop_count": 0,
|
| 747 |
+
"filename_prefix": "wan2.1",
|
| 748 |
+
"format": "video/h265-mp4",
|
| 749 |
+
"pix_fmt": "yuv420p10le",
|
| 750 |
+
"crf": 5,
|
| 751 |
+
"save_metadata": true,
|
| 752 |
+
"pingpong": false,
|
| 753 |
+
"save_output": true,
|
| 754 |
+
"videopreview": {
|
| 755 |
+
"hidden": false,
|
| 756 |
+
"paused": false,
|
| 757 |
+
"params": {
|
| 758 |
+
"filename": "wan2.1_00518.mp4",
|
| 759 |
+
"subfolder": "",
|
| 760 |
+
"type": "output",
|
| 761 |
+
"format": "video/h265-mp4",
|
| 762 |
+
"frame_rate": 16,
|
| 763 |
+
"workflow": "wan2.1_00518.png",
|
| 764 |
+
"fullpath": "E:\\comfyui3\\ComfyUI\\output\\wan2.1_00518.mp4"
|
| 765 |
+
}
|
| 766 |
+
}
|
| 767 |
+
}
|
| 768 |
+
},
|
| 769 |
+
{
|
| 770 |
+
"id": 25,
|
| 771 |
+
"type": "ImageToMask",
|
| 772 |
+
"pos": [
|
| 773 |
+
403.78155517578125,
|
| 774 |
+
1100.6531982421875
|
| 775 |
+
],
|
| 776 |
+
"size": [
|
| 777 |
+
270,
|
| 778 |
+
58
|
| 779 |
+
],
|
| 780 |
+
"flags": {},
|
| 781 |
+
"order": 5,
|
| 782 |
+
"mode": 0,
|
| 783 |
+
"inputs": [
|
| 784 |
+
{
|
| 785 |
+
"name": "image",
|
| 786 |
+
"type": "IMAGE",
|
| 787 |
+
"link": 85
|
| 788 |
+
}
|
| 789 |
+
],
|
| 790 |
+
"outputs": [
|
| 791 |
+
{
|
| 792 |
+
"name": "MASK",
|
| 793 |
+
"type": "MASK",
|
| 794 |
+
"links": [
|
| 795 |
+
79
|
| 796 |
+
]
|
| 797 |
+
}
|
| 798 |
+
],
|
| 799 |
+
"properties": {
|
| 800 |
+
"Node name for S&R": "ImageToMask"
|
| 801 |
+
},
|
| 802 |
+
"widgets_values": [
|
| 803 |
+
"red"
|
| 804 |
+
]
|
| 805 |
+
},
|
| 806 |
+
{
|
| 807 |
+
"id": 5,
|
| 808 |
+
"type": "LoadImage",
|
| 809 |
+
"pos": [
|
| 810 |
+
-272.46954345703125,
|
| 811 |
+
357.37689208984375
|
| 812 |
+
],
|
| 813 |
+
"size": [
|
| 814 |
+
335.15673828125,
|
| 815 |
+
709.6021728515625
|
| 816 |
+
],
|
| 817 |
+
"flags": {},
|
| 818 |
+
"order": 4,
|
| 819 |
+
"mode": 0,
|
| 820 |
+
"inputs": [],
|
| 821 |
+
"outputs": [
|
| 822 |
+
{
|
| 823 |
+
"name": "IMAGE",
|
| 824 |
+
"type": "IMAGE",
|
| 825 |
+
"links": [
|
| 826 |
+
22
|
| 827 |
+
]
|
| 828 |
+
},
|
| 829 |
+
{
|
| 830 |
+
"name": "MASK",
|
| 831 |
+
"type": "MASK",
|
| 832 |
+
"links": null
|
| 833 |
+
}
|
| 834 |
+
],
|
| 835 |
+
"properties": {
|
| 836 |
+
"Node name for S&R": "LoadImage"
|
| 837 |
+
},
|
| 838 |
+
"widgets_values": [
|
| 839 |
+
"ComfUI_287879_.png",
|
| 840 |
+
"image"
|
| 841 |
+
]
|
| 842 |
+
}
|
| 843 |
+
],
|
| 844 |
+
"links": [
|
| 845 |
+
[
|
| 846 |
+
3,
|
| 847 |
+
14,
|
| 848 |
+
0,
|
| 849 |
+
4,
|
| 850 |
+
0,
|
| 851 |
+
"IMAGE"
|
| 852 |
+
],
|
| 853 |
+
[
|
| 854 |
+
9,
|
| 855 |
+
13,
|
| 856 |
+
0,
|
| 857 |
+
12,
|
| 858 |
+
0,
|
| 859 |
+
"LATENT"
|
| 860 |
+
],
|
| 861 |
+
[
|
| 862 |
+
10,
|
| 863 |
+
17,
|
| 864 |
+
3,
|
| 865 |
+
12,
|
| 866 |
+
1,
|
| 867 |
+
"INT"
|
| 868 |
+
],
|
| 869 |
+
[
|
| 870 |
+
12,
|
| 871 |
+
17,
|
| 872 |
+
0,
|
| 873 |
+
13,
|
| 874 |
+
1,
|
| 875 |
+
"CONDITIONING"
|
| 876 |
+
],
|
| 877 |
+
[
|
| 878 |
+
13,
|
| 879 |
+
17,
|
| 880 |
+
1,
|
| 881 |
+
13,
|
| 882 |
+
2,
|
| 883 |
+
"CONDITIONING"
|
| 884 |
+
],
|
| 885 |
+
[
|
| 886 |
+
14,
|
| 887 |
+
17,
|
| 888 |
+
2,
|
| 889 |
+
13,
|
| 890 |
+
3,
|
| 891 |
+
"LATENT"
|
| 892 |
+
],
|
| 893 |
+
[
|
| 894 |
+
15,
|
| 895 |
+
12,
|
| 896 |
+
0,
|
| 897 |
+
14,
|
| 898 |
+
0,
|
| 899 |
+
"LATENT"
|
| 900 |
+
],
|
| 901 |
+
[
|
| 902 |
+
16,
|
| 903 |
+
3,
|
| 904 |
+
0,
|
| 905 |
+
14,
|
| 906 |
+
1,
|
| 907 |
+
"VAE"
|
| 908 |
+
],
|
| 909 |
+
[
|
| 910 |
+
18,
|
| 911 |
+
15,
|
| 912 |
+
0,
|
| 913 |
+
17,
|
| 914 |
+
0,
|
| 915 |
+
"CONDITIONING"
|
| 916 |
+
],
|
| 917 |
+
[
|
| 918 |
+
19,
|
| 919 |
+
2,
|
| 920 |
+
0,
|
| 921 |
+
17,
|
| 922 |
+
1,
|
| 923 |
+
"CONDITIONING"
|
| 924 |
+
],
|
| 925 |
+
[
|
| 926 |
+
20,
|
| 927 |
+
3,
|
| 928 |
+
0,
|
| 929 |
+
17,
|
| 930 |
+
2,
|
| 931 |
+
"VAE"
|
| 932 |
+
],
|
| 933 |
+
[
|
| 934 |
+
22,
|
| 935 |
+
5,
|
| 936 |
+
0,
|
| 937 |
+
17,
|
| 938 |
+
5,
|
| 939 |
+
"IMAGE"
|
| 940 |
+
],
|
| 941 |
+
[
|
| 942 |
+
58,
|
| 943 |
+
11,
|
| 944 |
+
0,
|
| 945 |
+
13,
|
| 946 |
+
0,
|
| 947 |
+
"MODEL"
|
| 948 |
+
],
|
| 949 |
+
[
|
| 950 |
+
76,
|
| 951 |
+
33,
|
| 952 |
+
1,
|
| 953 |
+
32,
|
| 954 |
+
2,
|
| 955 |
+
"INT"
|
| 956 |
+
],
|
| 957 |
+
[
|
| 958 |
+
78,
|
| 959 |
+
32,
|
| 960 |
+
1,
|
| 961 |
+
17,
|
| 962 |
+
6,
|
| 963 |
+
"INT"
|
| 964 |
+
],
|
| 965 |
+
[
|
| 966 |
+
79,
|
| 967 |
+
25,
|
| 968 |
+
0,
|
| 969 |
+
35,
|
| 970 |
+
0,
|
| 971 |
+
"MASK"
|
| 972 |
+
],
|
| 973 |
+
[
|
| 974 |
+
80,
|
| 975 |
+
35,
|
| 976 |
+
0,
|
| 977 |
+
17,
|
| 978 |
+
4,
|
| 979 |
+
"MASK"
|
| 980 |
+
],
|
| 981 |
+
[
|
| 982 |
+
85,
|
| 983 |
+
33,
|
| 984 |
+
0,
|
| 985 |
+
25,
|
| 986 |
+
0,
|
| 987 |
+
"IMAGE"
|
| 988 |
+
],
|
| 989 |
+
[
|
| 990 |
+
86,
|
| 991 |
+
32,
|
| 992 |
+
0,
|
| 993 |
+
17,
|
| 994 |
+
3,
|
| 995 |
+
"IMAGE"
|
| 996 |
+
],
|
| 997 |
+
[
|
| 998 |
+
91,
|
| 999 |
+
8,
|
| 1000 |
+
0,
|
| 1001 |
+
11,
|
| 1002 |
+
0,
|
| 1003 |
+
"MODEL"
|
| 1004 |
+
],
|
| 1005 |
+
[
|
| 1006 |
+
92,
|
| 1007 |
+
6,
|
| 1008 |
+
0,
|
| 1009 |
+
2,
|
| 1010 |
+
0,
|
| 1011 |
+
"CLIP"
|
| 1012 |
+
],
|
| 1013 |
+
[
|
| 1014 |
+
93,
|
| 1015 |
+
6,
|
| 1016 |
+
0,
|
| 1017 |
+
15,
|
| 1018 |
+
0,
|
| 1019 |
+
"CLIP"
|
| 1020 |
+
]
|
| 1021 |
+
],
|
| 1022 |
+
"groups": [],
|
| 1023 |
+
"config": {},
|
| 1024 |
+
"extra": {
|
| 1025 |
+
"ds": {
|
| 1026 |
+
"scale": 1.0152559799477145,
|
| 1027 |
+
"offset": [
|
| 1028 |
+
564.1931902142793,
|
| 1029 |
+
-170.45932466624348
|
| 1030 |
+
]
|
| 1031 |
+
},
|
| 1032 |
+
"frontendVersion": "1.25.11",
|
| 1033 |
+
"node_versions": {
|
| 1034 |
+
"comfy-core": "0.3.56",
|
| 1035 |
+
"ComfyUI-VideoHelperSuite": "972c87da577b47211c4e9aeed30dc38c7bae607f"
|
| 1036 |
+
},
|
| 1037 |
+
"VHS_latentpreview": true,
|
| 1038 |
+
"VHS_latentpreviewrate": 0,
|
| 1039 |
+
"VHS_MetadataImage": true,
|
| 1040 |
+
"VHS_KeepIntermediate": true
|
| 1041 |
+
},
|
| 1042 |
+
"version": 0.4
|
| 1043 |
+
}
|