vinesnt committed on
Commit 405de0f · verified · 1 Parent(s): 078b7c3

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ kill_bill.jpeg filter=lfs diff=lfs merge=lfs -text
+ wan22_input_2.jpg filter=lfs diff=lfs merge=lfs -text
+ wan_i2v_input.JPG filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,14 +1,36 @@
  ---
- title: '1111'
- emoji: 📈
- colorFrom: indigo
- colorTo: purple
+ title: WAN 2.2 3-Step V2V Pipeline
+ emoji: 🎬
+ colorFrom: purple
+ colorTo: blue
  sdk: gradio
- sdk_version: 6.9.0
+ sdk_version: 5.44.1
+ python_version: "3.10"
  app_file: app.py
  pinned: false
- license: openrail
- short_description: 哈哈
+ short_description: I2V + T2V + 3-Step V2V (SAM2 → Composite → VACE)
+ models:
+ - facebook/sam2.1-hiera-large
+ - google/umt5-xxl
+ - Kijai/WanVideo_comfy
+ - linoyts/Wan2.2-T2V-A14B-Diffusers-BF16
+ - lkzd7/WAN2.2_LoraSet_NSFW
+ - r3gm/RIFE
+ - TestOrganizationPleaseIgnore/WAMU_v2_WAN2.2_I2V_LIGHTNING
+ - Wan-AI/Wan2.1-VACE-14B-diffusers
+ - Wan-AI/Wan2.2-T2V-A14B-Diffusers
+ - zerogpu-aoti/Wan2
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # WAN 2.2 Multi-Task Video Generation
+
+ ## Features
+ - **I2V**: Image-to-Video (Lightning 14B, 6-step, FP8+AoT)
+ - **T2V**: Text-to-Video (Lightning 14B, 4-step, Lightning LoRA)
+ - **V2V**: 3-Step Video-to-Video Pipeline
+   1. **SAM2 Segmentation**: Click points on first frame → auto-track through video → mask video
+   2. **Composite + GrowMask**: Original + mask → expanded mask + composite video (automatic)
+   3. **VACE Generation**: Composite + grown mask + reference image + prompt → final video
+
+ ## V2V Workflow
+ Based on ComfyUI workflows: `sam_optimized`, `sam2.1_optimized`, `vace_optimized`
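The three V2V steps listed above can be sketched as a plain-Python data flow; every function below is an illustrative placeholder, not the app's actual API:

```python
# Hypothetical sketch of the 3-step V2V flow; each function stands in for
# the real SAM2 / compositing / VACE calls in app.py.

def sam2_segment(frames, points):
    # Step 1: points clicked on the first frame yield one mask per frame.
    return [f"mask[{f}]" for f in frames]

def composite_and_grow(frames, masks, grow_px=5):
    # Step 2: expand each mask and paint masked regions white on the source.
    grown = [f"grow({m}, {grow_px}px)" for m in masks]
    composite = [f"white_where({m}, {f})" for f, m in zip(frames, masks)]
    return composite, grown

def vace_generate(composite, grown, ref_image, prompt):
    # Step 3: VACE repaints the white regions, guided by prompt + reference.
    return [f"vace({c}, {m})" for c, m in zip(composite, grown)]

frames = ["f0", "f1", "f2"]
masks = sam2_segment(frames, points=[(100, 120)])
composite, grown = composite_and_grow(frames, masks)
final = vace_generate(composite, grown, "ref.jpg", "a red jacket")
```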
aoti.py ADDED
@@ -0,0 +1,35 @@
+ """Load ZeroGPU ahead-of-time (AoT) compiled kernels into a model's repeated transformer blocks."""
+
+ from typing import cast
+
+ import torch
+ from huggingface_hub import hf_hub_download
+ from spaces.zero.torch.aoti import ZeroGPUCompiledModel
+ from spaces.zero.torch.aoti import ZeroGPUWeights
+ from torch._functorch._aot_autograd.subclass_parametrization import unwrap_tensor_subclass_parameters
+
+
+ def _shallow_clone_module(module: torch.nn.Module) -> torch.nn.Module:
+     """Shallow-clone a module tree so its state can be re-wrapped without touching the original."""
+     clone = object.__new__(module.__class__)
+     clone.__dict__ = module.__dict__.copy()
+     clone._parameters = module._parameters.copy()
+     clone._buffers = module._buffers.copy()
+     clone._modules = {k: _shallow_clone_module(v) for k, v in module._modules.items() if v is not None}
+     return clone
+
+
+ def aoti_blocks_load(module: torch.nn.Module, repo_id: str, variant: str | None = None):
+     """Download one AoT package per repeated block class and patch it in as that block's forward."""
+     repeated_blocks = cast(list[str], module._repeated_blocks)
+     aoti_files = {name: hf_hub_download(
+         repo_id=repo_id,
+         filename='package.pt2',
+         subfolder=name if variant is None else f'{name}.{variant}',
+     ) for name in repeated_blocks}
+     for block_name, aoti_file in aoti_files.items():
+         for block in module.modules():
+             if block.__class__.__name__ == block_name:
+                 block_ = _shallow_clone_module(block)
+                 unwrap_tensor_subclass_parameters(block_)
+                 weights = ZeroGPUWeights(block_.state_dict())
+                 block.forward = ZeroGPUCompiledModel(aoti_file, weights)
app.py ADDED
@@ -0,0 +1,995 @@
+ """
+ WAN 2.2 Multi-Task Video Generation - 3-Step V2V Pipeline
+ I2V: Lightning 14B (6 steps, FP8+AoT)
+ T2V: Lightning 14B (4 steps, Lightning LoRA + FP8)
+ V2V: 3-Step Pipeline (SAM2 → Composite → VACE)
+     Step 1: SAM2 video segmentation (click points → mask video)
+     Step 2: ImageComposite (original + mask → composite video)
+     Step 3: VACE generation (composite + grown mask + ref image + prompt → final)
+ LoRA: from lkzd7/WAN2.2_LoraSet_NSFW (I2V only)
+ """
+ import os
+
+ import spaces
+ import shutil
+ import subprocess
+ import copy
+ import random
+ import tempfile
+ import warnings
+ import time
+ import gc
+ import uuid
+ from tqdm import tqdm
+
+ import cv2
+ import numpy as np
+ import torch
+ from torch.nn import functional as F
+ from PIL import Image, ImageFilter
+
+ import gradio as gr
+ from diffusers import (
+     AutoencoderKLWan,
+     FlowMatchEulerDiscreteScheduler,
+     WanPipeline,
+     SASolverScheduler,
+     DEISMultistepScheduler,
+     DPMSolverMultistepInverseScheduler,
+     UniPCMultistepScheduler,
+     DPMSolverMultistepScheduler,
+     DPMSolverSinglestepScheduler,
+ )
+ from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
+ from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
+ from diffusers.pipelines.wan.pipeline_wan_vace import WanVACEPipeline
+ from diffusers.utils.export_utils import export_to_video
+ from diffusers.utils import load_video
+ from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, Int8WeightOnlyConfig
+ import aoti
+ import lora_loader
+
+ # SAM2 for video mask generation
+ from sam2.sam2_video_predictor import SAM2VideoPredictor
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
+ warnings.filterwarnings("ignore")
+
+ def clear_vram():
+     gc.collect()
+     torch.cuda.empty_cache()
+
+ # ============ RIFE ============
+ get_timestamp_js = """
+ function() {
+     const video = document.querySelector('#generated-video video');
+     if (video) { return video.currentTime; }
+     return 0;
+ }
+ """
+
+ def extract_frame(video_path, timestamp):
+     if not video_path:
+         return None
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         return None
+     fps = cap.get(cv2.CAP_PROP_FPS)
+     target_frame_num = int(float(timestamp) * fps)
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     if target_frame_num >= total_frames:
+         target_frame_num = total_frames - 1
+     cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame_num)
+     ret, frame = cap.read()
+     cap.release()
+     if ret:
+         return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+     return None
+
+ if not os.path.exists("RIFEv4.26_0921.zip"):
+     print("Downloading RIFE Model...")
+     subprocess.run(["wget", "-q", "https://huggingface.co/r3gm/RIFE/resolve/main/RIFEv4.26_0921.zip", "-O", "RIFEv4.26_0921.zip"], check=True)
+     subprocess.run(["unzip", "-o", "RIFEv4.26_0921.zip"], check=True)
+
+ from train_log.RIFE_HDv3 import Model
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ rife_model = Model()
+ rife_model.load_model("train_log", -1)
+ rife_model.eval()
+
+ @torch.no_grad()
+ def interpolate_bits(frames_np, multiplier=2, scale=1.0):
+     if isinstance(frames_np, list):
+         T = len(frames_np)
+         H, W, C = frames_np[0].shape
+     else:
+         T, H, W, C = frames_np.shape
+     if multiplier < 2:
+         return list(frames_np) if isinstance(frames_np, np.ndarray) else frames_np
+     n_interp = multiplier - 1
+     tmp = max(128, int(128 / scale))
+     ph = ((H - 1) // tmp + 1) * tmp
+     pw = ((W - 1) // tmp + 1) * tmp
+     padding = (0, pw - W, 0, ph - H)
+     def to_tensor(frame_np):
+         t = torch.from_numpy(frame_np).to(device)
+         t = t.permute(2, 0, 1).unsqueeze(0)
+         return F.pad(t, padding).half()
+     def from_tensor(tensor):
+         t = tensor[0, :, :H, :W]
+         return t.permute(1, 2, 0).float().cpu().numpy()
+     def make_inference(I0, I1, n):
+         if rife_model.version >= 3.9:
+             return [rife_model.inference(I0, I1, (i+1) * 1. / (n+1), scale) for i in range(n)]
+         else:
+             middle = rife_model.inference(I0, I1, scale)
+             if n == 1:
+                 return [middle]
+             first_half = make_inference(I0, middle, n//2)
+             second_half = make_inference(middle, I1, n//2)
+             return [*first_half, middle, *second_half] if n % 2 else [*first_half, *second_half]
+     output_frames = []
+     I1 = to_tensor(frames_np[0])
+     with tqdm(total=T-1, desc="Interpolating", unit="frame") as pbar:
+         for i in range(T - 1):
+             I0 = I1
+             output_frames.append(from_tensor(I0))
+             I1 = to_tensor(frames_np[i+1])
+             for mid in make_inference(I0, I1, n_interp):
+                 output_frames.append(from_tensor(mid))
+             if (i + 1) % 50 == 0:
+                 pbar.update(50)
+         pbar.update((T-1) % 50)
+     output_frames.append(from_tensor(I1))
+     del I0, I1
+     torch.cuda.empty_cache()
+     return output_frames
+
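For RIFE versions below 3.9, `make_inference` builds in-between frames by recursive bisection. The same recursion can be exercised on plain numbers, with linear blending standing in for the model call (a sketch, not RIFE itself):

```python
# Recursive midpoint insertion, as in the pre-3.9 RIFE fallback.
# `midpoint` is a stand-in for rife_model.inference(I0, I1).

def midpoint(a, b):
    return (a + b) / 2  # placeholder for the learned interpolator

def make_inference(i0, i1, n):
    # Produce n in-between values for the pair (i0, i1) by bisection.
    mid = midpoint(i0, i1)
    if n == 1:
        return [mid]
    first_half = make_inference(i0, mid, n // 2)
    second_half = make_inference(mid, i1, n // 2)
    return [*first_half, mid, *second_half] if n % 2 else [*first_half, *second_half]

between = make_inference(0.0, 8.0, 3)  # 3 frames between values 0 and 8
```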
+ # ============ Config ============
+ FIXED_FPS = 16
+ MAX_FRAMES_MODEL = 241  # ~15 s @ 16 fps; longer clips need more VRAM/time
+ MAX_SEED = np.iinfo(np.int32).max
+
+ SCHEDULER_MAP = {
+     "FlowMatchEulerDiscrete": FlowMatchEulerDiscreteScheduler,
+     "SASolver": SASolverScheduler,
+     "DEISMultistep": DEISMultistepScheduler,
+     "DPMSolverMultistepInverse": DPMSolverMultistepInverseScheduler,
+     "UniPCMultistep": UniPCMultistepScheduler,
+     "DPMSolverMultistep": DPMSolverMultistepScheduler,
+     "DPMSolverSinglestep": DPMSolverSinglestepScheduler,
+ }
+
+ default_negative_prompt = (
+     "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, "
+     "still image, overall gray, worst quality, low quality, JPEG artifacts, ugly, incomplete, "
+     "extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, "
+     "malformed limbs, fused fingers, still frame, messy background, three legs, "
+     "many people in background, walking backwards, watermark, text, signature"
+ )
+
+ # ============ Load I2V Pipeline (Lightning, AoT compiled) ============
+ print("Loading I2V Pipeline (Lightning 14B)...")
+ i2v_pipe = WanImageToVideoPipeline.from_pretrained(
+     "TestOrganizationPleaseIgnore/WAMU_v2_WAN2.2_I2V_LIGHTNING",
+     torch_dtype=torch.bfloat16,
+ ).to('cuda')
+ i2v_original_scheduler = copy.deepcopy(i2v_pipe.scheduler)
+
+ quantize_(i2v_pipe.text_encoder, Int8WeightOnlyConfig())
+ # FP8 needs compute capability >= 8.9 (Ada Lovelace / Hopper or newer)
+ major, minor = torch.cuda.get_device_capability()
+ supports_fp8 = (major > 8) or (major == 8 and minor >= 9)
+ if supports_fp8:
+     quantize_(i2v_pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
+     quantize_(i2v_pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
+     aoti.aoti_blocks_load(i2v_pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
+     aoti.aoti_blocks_load(i2v_pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
+ else:
+     quantize_(i2v_pipe.transformer, Int8WeightOnlyConfig())
+     quantize_(i2v_pipe.transformer_2, Int8WeightOnlyConfig())
+
+ # ============ T2V Pipeline (on-demand, 14B + Wan22 Lightning LoRA) ============
+ # Use T2V-A14B + Wan22 Lightning LoRA (separate HIGH/LOW for the dual transformers)
+ # Loaded on demand with CPU offload to avoid OOM alongside I2V
+ T2V_MODEL_ID = "Wan-AI/Wan2.2-T2V-A14B-Diffusers"
+ T2V_LORA_REPO = "Kijai/WanVideo_comfy"
+ T2V_LORA_HIGH = "LoRAs/Wan22-Lightning/Wan22_A14B_T2V_HIGH_Lightning_4steps_lora_250928_rank128_fp16.safetensors"
+ T2V_LORA_LOW = "LoRAs/Wan22-Lightning/Wan22_A14B_T2V_LOW_Lightning_4steps_lora_250928_rank64_fp16.safetensors"
+ t2v_pipe = None
+ t2v_ready = False
+
+ def load_t2v_pipeline():
+     """Load T2V 14B + Lightning LoRA on demand, with CPU offload."""
+     global t2v_pipe, t2v_ready
+
+     if t2v_pipe is not None and t2v_ready:
+         print("T2V pipeline reused from memory")
+         return t2v_pipe
+
+     print("Loading T2V Pipeline (14B + Lightning LoRA) for the first time...")
+
+     # Move I2V components to CPU to make room
+     i2v_pipe.to('cpu')
+     clear_vram()
+
+     t2v_vae = AutoencoderKLWan.from_pretrained(T2V_MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
+     t2v_pipe = WanPipeline.from_pretrained(
+         T2V_MODEL_ID,
+         transformer=WanTransformer3DModel.from_pretrained(
+             'linoyts/Wan2.2-T2V-A14B-Diffusers-BF16',
+             subfolder='transformer',
+             torch_dtype=torch.bfloat16,
+         ),
+         transformer_2=WanTransformer3DModel.from_pretrained(
+             'linoyts/Wan2.2-T2V-A14B-Diffusers-BF16',
+             subfolder='transformer_2',
+             torch_dtype=torch.bfloat16,
+         ),
+         vae=t2v_vae,
+         torch_dtype=torch.bfloat16,
+     )
+
+     # Load and fuse Lightning LoRAs (HIGH for transformer, LOW for transformer_2)
+     print("Fusing Lightning LoRA HIGH (transformer)...")
+     from safetensors.torch import load_file
+     from huggingface_hub import hf_hub_download
+
+     # Download LoRA files
+     high_path = hf_hub_download(T2V_LORA_REPO, T2V_LORA_HIGH)
+     low_path = hf_hub_download(T2V_LORA_REPO, T2V_LORA_LOW)
+
+     # Load HIGH LoRA into transformer
+     t2v_pipe.load_lora_weights(high_path, adapter_name="lightning_high")
+     t2v_pipe.set_adapters(["lightning_high"], adapter_weights=[1.0])
+     t2v_pipe.fuse_lora(adapter_names=["lightning_high"], lora_scale=1.0, components=["transformer"])
+     t2v_pipe.unload_lora_weights()
+
+     # Load LOW LoRA into transformer_2
+     print("Fusing Lightning LoRA LOW (transformer_2)...")
+     t2v_pipe.load_lora_weights(low_path, adapter_name="lightning_low", load_into_transformer_2=True)
+     t2v_pipe.set_adapters(["lightning_low"], adapter_weights=[1.0])
+     t2v_pipe.fuse_lora(adapter_names=["lightning_low"], lora_scale=1.0, components=["transformer_2"])
+     t2v_pipe.unload_lora_weights()
+
+     # Use model CPU offload — only one component on GPU at a time
+     t2v_pipe.enable_model_cpu_offload()
+
+     t2v_ready = True
+     print("T2V pipeline ready (14B + Lightning + CPU offload)")
+     return t2v_pipe
+
+ def unload_t2v_pipeline():
+     """Restore I2V to the GPU after T2V is done."""
+     clear_vram()
+     i2v_pipe.to('cuda')
+     print("I2V restored to GPU")
+
+ # Keep cache for on-demand T2V loading
+
+ # ============ SAM2 Video Segmentation ============
+ sam2_predictor = None
+
+ def get_sam2_predictor():
+     global sam2_predictor
+     if sam2_predictor is None:
+         print("Loading SAM2.1 hiera-large...")
+         sam2_predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2.1-hiera-large")
+         print("SAM2 loaded")
+     return sam2_predictor
+
+ def extract_first_frame_from_video(video_path):
+     """Extract the first frame of a video as a PIL Image."""
+     cap = cv2.VideoCapture(video_path)
+     ret, frame = cap.read()
+     cap.release()
+     if ret:
+         return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+     return None
+
+ def video_to_frames_dir(video_path, max_frames=None):
+     """Extract video frames into a temp directory for SAM2."""
+     tmp_dir = tempfile.mkdtemp(prefix="sam2_frames_")
+     cap = cv2.VideoCapture(video_path)
+     fps = cap.get(cv2.CAP_PROP_FPS) or 16
+     idx = 0
+     while True:
+         ret, frame = cap.read()
+         if not ret:
+             break
+         if max_frames and idx >= max_frames:
+             break
+         cv2.imwrite(os.path.join(tmp_dir, f"{idx:05d}.jpg"), frame)
+         idx += 1
+     cap.release()
+     print(f"Extracted {idx} frames to {tmp_dir} (fps={fps:.1f})")
+     return tmp_dir, idx, fps
+
+ @spaces.GPU(duration=120)
+ def generate_mask_video(video_path, points_json, num_frames_limit=None):
+     """Generate a mask video with SAM2 from user-clicked points."""
+     import json
+
+     if not video_path:
+         raise gr.Error("请先上传视频 / Upload a video first")
+     if not points_json or points_json.strip() == "[]":
+         raise gr.Error("请在视频第一帧上点击要编辑的区域 / Click on the area to edit")
+
+     points_data = json.loads(points_json)
+     if not points_data:
+         raise gr.Error("没有标记点 / No points marked")
+
+     # Extract frames
+     frames_dir, total_frames, fps = video_to_frames_dir(video_path, max_frames=num_frames_limit)
+
+     predictor = get_sam2_predictor()
+
+     with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
+         state = predictor.init_state(video_path=frames_dir)
+
+         # Add points (all on frame 0)
+         pos_points = []
+         neg_points = []
+         for p in points_data:
+             if p.get("label", 1) == 1:
+                 pos_points.append([p["x"], p["y"]])
+             else:
+                 neg_points.append([p["x"], p["y"]])
+
+         all_points = pos_points + neg_points
+         all_labels = [1] * len(pos_points) + [0] * len(neg_points)
+
+         points_np = np.array(all_points, dtype=np.float32)
+         labels_np = np.array(all_labels, dtype=np.int32)
+
+         _, _, _ = predictor.add_new_points_or_box(
+             state,
+             frame_idx=0,
+             obj_id=1,
+             points=points_np,
+             labels=labels_np,
+         )
+
+         # Propagate through the video
+         all_masks = {}
+         for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
+             # masks shape: (num_objects, 1, H, W)
+             mask = (masks[0, 0] > 0.0).cpu().numpy().astype(np.uint8) * 255
+             all_masks[frame_idx] = mask
+
+     # Build the mask video
+     out_path = os.path.join(tempfile.mkdtemp(), "mask_video.mp4")
+     # Get the frame size from the first mask
+     first_mask = all_masks[0]
+     h, w = first_mask.shape
+     fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+     writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h), isColor=False)
+     for i in range(total_frames):
+         if i in all_masks:
+             writer.write(all_masks[i])
+         elif all_masks:
+             # Use the nearest available mask
+             nearest = min(all_masks.keys(), key=lambda k: abs(k - i))
+             writer.write(all_masks[nearest])
+     writer.release()
+
+     # Clean up the frames dir
+     shutil.rmtree(frames_dir, ignore_errors=True)
+
+     print(f"Mask video generated: {out_path} ({total_frames} frames, {w}x{h})")
+     return out_path
+
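When SAM2 emits no mask for some frame index, the writer in `generate_mask_video` falls back to the nearest tracked frame. That lookup is a plain distance-minimizing key search, shown here in isolation with toy values:

```python
# Nearest-key fallback for frames SAM2 skipped: pick the tracked frame
# whose index is closest to the missing one (ties resolve to the first key).

tracked = {0: "m0", 4: "m4", 9: "m9"}  # frame_idx -> mask (toy values)

def nearest_mask(masks, i):
    nearest = min(masks.keys(), key=lambda k: abs(k - i))
    return masks[nearest]

picked = [nearest_mask(tracked, i) for i in range(6)]
```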
+ # ============ Step 2: GrowMask + ImageComposite (from the sam2.1_optimized workflow) ============
+ def grow_mask_frame(mask_gray, expand_pixels=5, blur=True):
+     """Expand a mask by N pixels (matching the ComfyUI GrowMask node).
+     mask_gray: numpy uint8 H×W (255=mask, 0=bg)
+     Returns: expanded mask as numpy uint8 H×W
+     """
+     if expand_pixels <= 0:
+         return mask_gray
+     kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (expand_pixels*2+1, expand_pixels*2+1))
+     grown = cv2.dilate(mask_gray, kernel, iterations=1)
+     if blur:
+         grown = cv2.GaussianBlur(grown, (expand_pixels*2+1, expand_pixels*2+1), 0)
+         # Re-threshold back to binary: the blur smooths jagged dilation edges,
+         # and thresholding at 127 keeps the final mask strictly 0/255
+         _, grown = cv2.threshold(grown, 127, 255, cv2.THRESH_BINARY)
+     return grown
+
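The GrowMask expansion is morphological dilation. Its effect can be reproduced in pure NumPy with a shifted-maximum, which approximates `cv2.dilate` with a square kernel (a toy sketch; `np.roll` wraparound is irrelevant for the centered example below):

```python
import numpy as np

def grow_mask_np(mask, expand):
    # Dilate a binary uint8 mask by `expand` pixels: take the max over all
    # shifts within the (2*expand+1)-square neighborhood of each pixel.
    grown = mask.copy()
    for dy in range(-expand, expand + 1):
        for dx in range(-expand, expand + 1):
            shifted = np.roll(np.roll(mask, dy, axis=0), dx, axis=1)
            grown = np.maximum(grown, shifted)
    return grown

mask = np.zeros((7, 7), dtype=np.uint8)
mask[3, 3] = 255                      # single masked pixel
grown = grow_mask_np(mask, expand=1)  # expands to a 3x3 block
```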
+ def grow_mask_video_file(mask_video_path, expand_pixels=5):
+     """Apply GrowMask to every frame of a mask video. Returns the new video path."""
+     if expand_pixels <= 0:
+         return mask_video_path
+
+     cap = cv2.VideoCapture(mask_video_path)
+     fps = cap.get(cv2.CAP_PROP_FPS) or 16
+     w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+     out_path = os.path.join(tempfile.mkdtemp(), "grown_mask.mp4")
+     fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+     writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h), isColor=False)
+
+     count = 0
+     while True:
+         ret, frame = cap.read()
+         if not ret:
+             break
+         gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) if len(frame.shape) == 3 else frame
+         grown = grow_mask_frame(gray, expand_pixels)
+         writer.write(grown)
+         count += 1
+
+     cap.release()
+     writer.release()
+     print(f"GrowMask applied: {count} frames, expand={expand_pixels}px → {out_path}")
+     return out_path
+
+ def composite_video_from_mask(source_video_path, mask_video_path):
+     """ImageComposite: replace the masked region with a white overlay (from the sam2.1_optimized workflow).
+     Creates a composite video where:
+     - masked regions (white in the mask) are painted solid white
+     - unmasked regions show the original video
+     This gives VACE the control_video input it needs.
+     Returns: composite video path
+     """
+     src_cap = cv2.VideoCapture(source_video_path)
+     mask_cap = cv2.VideoCapture(mask_video_path)
+
+     fps = src_cap.get(cv2.CAP_PROP_FPS) or 16
+     w = int(src_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     h = int(src_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+     out_path = os.path.join(tempfile.mkdtemp(), "composite.mp4")
+     fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+     writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
+
+     count = 0
+     while True:
+         ret_s, src_frame = src_cap.read()
+         ret_m, mask_frame = mask_cap.read()
+         if not ret_s:
+             break
+         if not ret_m:
+             # If the mask video is shorter, fall back to an all-black mask
+             mask_gray = np.zeros((h, w), dtype=np.uint8)
+         else:
+             # Resize the mask to match the source if needed
+             if mask_frame.shape[:2] != (h, w):
+                 mask_frame = cv2.resize(mask_frame, (w, h), interpolation=cv2.INTER_NEAREST)
+             mask_gray = cv2.cvtColor(mask_frame, cv2.COLOR_BGR2GRAY) if len(mask_frame.shape) == 3 else mask_frame
+
+         # Composite: original where mask=0, white where mask=255
+         mask_bool = mask_gray > 127
+         composite = src_frame.copy()
+         composite[mask_bool] = 255  # white in the masked region
+
+         writer.write(composite)
+         count += 1
+
+     src_cap.release()
+     mask_cap.release()
+     writer.release()
+     print(f"Composite video: {count} frames → {out_path}")
+     return out_path
+
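The per-frame composite rule in `composite_video_from_mask` reduces to one vectorized NumPy statement: keep the source pixel where the mask is off, paint white where it is on. A standalone check:

```python
import numpy as np

def composite_frame(src, mask_gray):
    # White-out the masked region (mask > 127), keep the source elsewhere.
    out = src.copy()
    out[mask_gray > 127] = 255
    return out

src = np.full((4, 4, 3), 10, dtype=np.uint8)   # uniform dark frame
mask = np.zeros((4, 4), dtype=np.uint8)
mask[1:3, 1:3] = 255                            # 2x2 masked block
result = composite_frame(src, mask)
```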
+ # ============ V2V Pipeline (VACE 14B, on-demand) ============
+ VACE_MODEL_ID = "Wan-AI/Wan2.1-VACE-14B-diffusers"
+ v2v_pipe = None
+ v2v_ready = False
+
+ def load_v2v_pipeline():
+     """Load the VACE 14B pipeline on demand for mask-based video editing."""
+     global v2v_pipe, v2v_ready
+
+     # Move I2V to CPU to free the GPU
+     i2v_pipe.to('cpu')
+     clear_vram()
+
+     if v2v_pipe is not None and v2v_ready:
+         v2v_pipe.to('cuda')
+         print("VACE pipeline restored to GPU")
+         return v2v_pipe
+
+     print("Loading VACE 14B Pipeline for the first time (this downloads ~75GB)...")
+
+     v2v_vae = AutoencoderKLWan.from_pretrained(VACE_MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
+     v2v_pipe = WanVACEPipeline.from_pretrained(
+         VACE_MODEL_ID,
+         vae=v2v_vae,
+         torch_dtype=torch.bfloat16,
+     )
+     v2v_pipe.scheduler = UniPCMultistepScheduler.from_config(v2v_pipe.scheduler.config, flow_shift=5.0)
+
+     # Quantize to fit into an A100 80GB
+     quantize_(v2v_pipe.text_encoder, Int8WeightOnlyConfig())
+     major, minor = torch.cuda.get_device_capability()
+     if (major > 8) or (major == 8 and minor >= 9):
+         quantize_(v2v_pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
+     else:
+         quantize_(v2v_pipe.transformer, Int8WeightOnlyConfig())
+
+     v2v_pipe.to('cuda')
+
+     v2v_ready = True
+     print("VACE 14B pipeline ready (quantized, on GPU)")
+     return v2v_pipe
+
+ def unload_v2v_pipeline():
+     """Move V2V to CPU and restore I2V to the GPU."""
+     global v2v_pipe
+     if v2v_pipe is not None:
+         v2v_pipe.to('cpu')
+     clear_vram()
+     i2v_pipe.to('cuda')
+     print("VACE → CPU, I2V → GPU")
+
+ def load_video_frames_and_masks(video_path, mask_path, num_frames, target_h, target_w):
+     """Load source video frames and mask video frames for VACE."""
+     # Load source video frames as PIL Images
+     src_frames = load_video(video_path)[:num_frames]
+     print(f"Loaded {len(src_frames)} source frames (original size: {src_frames[0].size if src_frames else 'N/A'})")
+
+     # Load mask video frames
+     mask_frames_raw = load_video(mask_path)[:num_frames]
+
+     # Convert masks to L mode (white=edit, black=keep) — don't resize, let the pipeline handle it
+     masks = []
+     for mf in mask_frames_raw:
+         gray = mf.convert("L")
+         masks.append(gray)
+     print(f"Loaded {len(masks)} mask frames")
+
+     # Pad or trim so both lists match
+     while len(masks) < len(src_frames):
+         masks.append(masks[-1] if masks else Image.new("L", src_frames[0].size, 0))
+     while len(src_frames) < len(masks):
+         src_frames.append(src_frames[-1] if src_frames else Image.new("RGB", (target_w, target_h), (128, 128, 128)))
+
+     frame_count = min(len(src_frames), len(masks))
+     src_frames = src_frames[:frame_count]
+     masks = masks[:frame_count]
+
+     return src_frames, masks
+
+ # ============ Utils ============
+ def resize_image(image, max_dim=832, min_dim=480, square_dim=640, multiple_of=16):
+     width, height = image.size
+     if width == height:
+         return image.resize((square_dim, square_dim), Image.LANCZOS)
+     aspect_ratio = width / height
+     max_ar = max_dim / min_dim
+     min_ar = min_dim / max_dim
+     if aspect_ratio > max_ar:
+         crop_width = int(round(height * max_ar))
+         left = (width - crop_width) // 2
+         image = image.crop((left, 0, left + crop_width, height))
+         target_w, target_h = max_dim, min_dim
+     elif aspect_ratio < min_ar:
+         crop_height = int(round(width / min_ar))
+         top = (height - crop_height) // 2
+         image = image.crop((0, top, width, top + crop_height))
+         target_w, target_h = min_dim, max_dim
+     else:
+         if width > height:
+             target_w = max_dim
+             target_h = int(round(target_w / aspect_ratio))
+         else:
+             target_h = max_dim
+             target_w = int(round(target_h * aspect_ratio))
+     final_w = max(min_dim, min(max_dim, round(target_w / multiple_of) * multiple_of))
+     final_h = max(min_dim, min(max_dim, round(target_h / multiple_of) * multiple_of))
+     return image.resize((final_w, final_h), Image.LANCZOS)
+
+ def resize_and_crop_to_match(target_image, reference_image):
+     ref_w, ref_h = reference_image.size
+     tgt_w, tgt_h = target_image.size
+     scale = max(ref_w / tgt_w, ref_h / tgt_h)
+     new_w, new_h = int(tgt_w * scale), int(tgt_h * scale)
+     resized = target_image.resize((new_w, new_h), Image.Resampling.LANCZOS)
+     left, top = (new_w - ref_w) // 2, (new_h - ref_h) // 2
+     return resized.crop((left, top, left + ref_w, top + ref_h))
+
+ def get_num_frames(duration_seconds):
+     raw = int(round(duration_seconds * FIXED_FPS))
+     raw = ((raw - 1) // 4) * 4 + 1
+     return int(np.clip(raw, 9, MAX_FRAMES_MODEL))
+
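WAN-family pipelines expect frame counts of the form 4k + 1 (the V2V branch later enforces `(n - 1) % 4 == 0` too), so `get_num_frames` snaps duration × fps down to that grid and clamps the result. A standalone check of the rule, with the constants mirrored from this file's config:

```python
# The 4k+1 frame-count rule: round duration*fps, snap down to the nearest
# 4k+1, then clamp to [9, MAX_FRAMES_MODEL]. Constants mirror the app config.
FIXED_FPS = 16
MAX_FRAMES_MODEL = 241

def get_num_frames(duration_seconds):
    raw = int(round(duration_seconds * FIXED_FPS))
    raw = ((raw - 1) // 4) * 4 + 1
    return max(9, min(raw, MAX_FRAMES_MODEL))

counts = [get_num_frames(s) for s in (0.25, 1, 5, 30)]
```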
+ def extract_video_path(input_video):
+     if input_video is None:
+         return None
+     if isinstance(input_video, str):
+         return input_video
+     if isinstance(input_video, dict):
+         # Gradio 5.x format: {'video': filepath, ...}, {'name': filepath, ...} or {'path': filepath}
+         return input_video.get("video", input_video.get("path", input_video.get("name", None)))
+     # Could be a Gradio VideoData object
+     if hasattr(input_video, 'video'):
+         return input_video.video
+     if hasattr(input_video, 'path'):
+         return input_video.path
+     if hasattr(input_video, 'name'):
+         return input_video.name
+     return str(input_video)
+
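Gradio may hand the video input over as a plain path string, a dict, or a `VideoData`-like object. The dict/str normalization in `extract_video_path` can be exercised in isolation (the object-attribute branches are omitted here for brevity):

```python
# Minimal re-statement of the dict/str normalization used for Gradio
# video inputs; object-attribute fallbacks are left out of this sketch.

def extract_video_path(v):
    if v is None:
        return None
    if isinstance(v, str):
        return v
    if isinstance(v, dict):
        # Try 'video', then 'path', then 'name'
        return v.get("video", v.get("path", v.get("name")))
    return str(v)

paths = [
    extract_video_path("/tmp/a.mp4"),
    extract_video_path({"video": "/tmp/b.mp4"}),
    extract_video_path({"path": "/tmp/c.mp4"}),
    extract_video_path(None),
]
```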
+ def extract_first_frame(video_input):
+     path = extract_video_path(video_input)
+     if not path or not os.path.exists(path):
+         return None
+     cap = cv2.VideoCapture(path)
+     ret, frame = cap.read()
+     cap.release()
+     if ret:
+         return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+     return None
+
+ # ============ Inference ============
+ @spaces.GPU(duration=1200)
+ def run_inference(
+     task_type, input_image, input_video, mask_video, prompt, negative_prompt,
+     duration_seconds, steps, guidance_scale, guidance_scale_2,
+     current_seed, scheduler_name, flow_shift, frame_multiplier,
+     quality, last_image_input, lora_groups,
+     reference_image=None, grow_pixels=5,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     clear_vram()
+     num_frames = get_num_frames(duration_seconds)
+     task_id = str(uuid.uuid4())[:8]
+     print(f"Task: {task_id}, type={task_type}, duration={duration_seconds}s, frames={num_frames}")
+     start = time.time()
+
+     if "T2V" in task_type:
+         # ====== T2V: 14B + Lightning LoRA (4 steps, dual guidance) ======
+         t2v_steps = max(int(steps), 4)
+         print(f"T2V: steps={t2v_steps}, guidance={guidance_scale}/{guidance_scale_2}, frames={num_frames}")
+
+         pipe = load_t2v_pipeline()
+         result = pipe(
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             height=480,
+             width=832,
+             num_frames=num_frames,
+             guidance_scale=float(guidance_scale),
+             guidance_scale_2=float(guidance_scale_2),
+             num_inference_steps=t2v_steps,
+             generator=torch.Generator(device="cpu").manual_seed(int(current_seed)),
+             output_type="np",
+         )
+         unload_t2v_pipeline()
+
+     else:
+         # ====== I2V / V2V ======
+         if "V2V" in task_type:
+             # ====== V2V: 3-Step Pipeline (SAM2 mask → Composite → VACE) ======
+             print(f"V2V 3-Step Pipeline: input_video type={type(input_video)}, value={input_video}")
+             video_path = extract_video_path(input_video)
+             if not video_path or not os.path.exists(video_path):
+                 raise gr.Error("Upload a source video for V2V / V2V请上传原视频")
+
+             # Get the mask video path
+             mask_path = extract_video_path(mask_video)
+             if not mask_path or not os.path.exists(mask_path):
+                 raise gr.Error("Upload a mask video for V2V / V2V请上传遮罩视频(黑白视频,白色=编辑区域)")
+
+             # Step 2a: GrowMask — expand mask boundaries (from the vace_optimized workflow)
+             grown_mask_path = grow_mask_video_file(mask_path, expand_pixels=int(grow_pixels))
+             print(f"V2V: GrowMask applied ({grow_pixels}px)")
+
+             # Step 2b: Composite — original video with mask overlay (from the sam2.1_optimized workflow)
+             composite_path = composite_video_from_mask(video_path, mask_path)
+             print("V2V: Composite video created")
+
+             # Step 3: VACE generation using the composite as control_video + grown mask
+             target_h, target_w = 480, 832
+
+             # Load the composite video as control frames for VACE
+             src_frames = load_video(composite_path)[:num_frames]
+             print(f"Loaded {len(src_frames)} composite frames")
+
+             # Load the grown mask frames
+             mask_frames_raw = load_video(grown_mask_path)[:num_frames]
+             masks = [mf.convert("L") for mf in mask_frames_raw]
+             print(f"Loaded {len(masks)} grown mask frames")
+
+             # Pad or trim to match
+             while len(masks) < len(src_frames):
+                 masks.append(masks[-1] if masks else Image.new("L", src_frames[0].size, 0))
+             while len(src_frames) < len(masks):
+                 src_frames.append(src_frames[-1] if src_frames else Image.new("RGB", (target_w, target_h), (128, 128, 128)))
+
+             # Ensure num_frames satisfies (n-1) % 4 == 0 for VACE
+             n = len(src_frames)
+             n = (n - 1) // 4 * 4 + 1
+             n = max(n, 5)
+             src_frames = src_frames[:n]
+             masks = masks[:n]
+
+             # Load the VACE pipeline
+             pipe = load_v2v_pipeline()
+             v2v_steps = max(int(steps), 20)
+             print(f"V2V VACE: steps={v2v_steps}, guidance={guidance_scale}, frames={len(src_frames)}, ref_image={'yes' if reference_image else 'no'}")
+
+             # Build the VACE kwargs
+             vace_kwargs = dict(
+                 prompt=prompt,
+                 negative_prompt=negative_prompt,
+                 video=src_frames,
+                 mask=masks,
+                 height=target_h,
+                 width=target_w,
719
+ num_frames=len(src_frames),
720
+ guidance_scale=max(float(guidance_scale), 5.0),
721
+ num_inference_steps=v2v_steps,
722
+ generator=torch.Generator(device="cuda").manual_seed(int(current_seed)),
723
+ output_type="np",
724
+ )
725
+
726
+ result = pipe(**vace_kwargs)
727
+ unload_v2v_pipeline()
728
+
729
+ # Cleanup temp files
730
+ for p in [grown_mask_path, composite_path]:
731
+ try:
732
+ if p and os.path.exists(p):
733
+ os.remove(p)
734
+ except:
735
+ pass
736
+
737
+ else:
738
+ # ====== I2V ======
739
+ if input_image is None:
740
+ raise gr.Error("Upload an image / 请上传图片")
741
+
742
+ scheduler_class = SCHEDULER_MAP.get(scheduler_name)
743
+ if scheduler_class and scheduler_class.__name__ != i2v_pipe.scheduler.config._class_name:
744
+ config = copy.deepcopy(i2v_original_scheduler.config)
745
+ if scheduler_class == FlowMatchEulerDiscreteScheduler:
746
+ config['shift'] = flow_shift
747
+ else:
748
+ config['flow_shift'] = flow_shift
749
+ i2v_pipe.scheduler = scheduler_class.from_config(config)
750
+
751
+ lora_loaded = False
752
+ if lora_groups:
753
+ try:
754
+ for idx, name in enumerate(lora_groups):
755
+ if name and name != "(None)":
756
+ lora_loader.load_lora_to_pipe(i2v_pipe, name, adapter_name=f"lora_{idx}")
757
+ lora_loaded = True
758
+ except Exception as e:
759
+ print(f"LoRA warning: {e}")
760
+
761
+ resized_image = resize_image(input_image)
762
+ processed_last = None
763
+ if last_image_input:
764
+ processed_last = resize_and_crop_to_match(last_image_input, resized_image)
765
+
766
+ print(f"I2V: size={resized_image.size}, steps={int(steps)}, guidance={guidance_scale}/{guidance_scale_2}")
767
+
768
+ result = i2v_pipe(
769
+ image=resized_image,
770
+ last_image=processed_last,
771
+ prompt=prompt,
772
+ negative_prompt=negative_prompt,
773
+ height=resized_image.height,
774
+ width=resized_image.width,
775
+ num_frames=num_frames,
776
+ guidance_scale=float(guidance_scale),
777
+ guidance_scale_2=float(guidance_scale_2),
778
+ num_inference_steps=int(steps),
779
+ generator=torch.Generator(device="cuda").manual_seed(int(current_seed)),
780
+ output_type="np",
781
+ )
782
+
783
+ if lora_loaded:
784
+ lora_loader.unload_lora(i2v_pipe)
785
+
786
+ raw_frames = result.frames[0]
787
+ elapsed = time.time() - start
788
+ print(f"Generation took {elapsed:.1f}s ({len(raw_frames)} frames)")
789
+
790
+ frame_factor = frame_multiplier // FIXED_FPS
791
+ if frame_factor > 1:
792
+ rife_model.device()
793
+ rife_model.flownet = rife_model.flownet.half()
794
+ final_frames = interpolate_bits(raw_frames, multiplier=int(frame_factor))
795
+ else:
796
+ final_frames = list(raw_frames)
797
+ final_fps = FIXED_FPS * max(1, frame_factor)
798
+
799
+ with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
800
+ video_path = tmpfile.name
801
+ export_to_video(final_frames, video_path, fps=final_fps, quality=quality)
802
+ return video_path, task_id
803
+
804
+ # ============ Generate ============
805
+ def generate_video(
806
+ task_type, input_image, input_video, mask_video, prompt,
807
+ lora_groups, duration_seconds, frame_multiplier,
808
+ steps, guidance_scale, guidance_scale_2,
809
+ negative_prompt, quality, seed, randomize_seed,
810
+ scheduler, flow_shift, last_image, display_result,
811
+ reference_image, grow_pixels,
812
+ progress=gr.Progress(track_tqdm=True),
813
+ ):
814
+ if not prompt or not prompt.strip():
815
+ raise gr.Error("Enter a prompt / 请输入提示词")
816
+ current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
817
+ video_path, task_id = run_inference(
818
+ task_type, input_image, input_video, mask_video, prompt, negative_prompt,
819
+ duration_seconds, steps, guidance_scale, guidance_scale_2,
820
+ current_seed, scheduler, flow_shift, frame_multiplier,
821
+ quality, last_image, lora_groups,
822
+ reference_image=reference_image, grow_pixels=grow_pixels,
823
+ )
824
+ print(f"Done: {task_id}")
825
+ return (video_path if display_result else None), video_path, current_seed
826
+
827
+ # ============ UI ============
828
+ CSS = """
829
+ #hidden-timestamp { opacity: 0; height: 0; width: 0; margin: 0; padding: 0; overflow: hidden; position: absolute; }
830
+ """
831
+
832
+ with gr.Blocks(theme=gr.themes.Soft(), css=CSS, delete_cache=(3600, 10800)) as demo:
833
+ gr.Markdown("## WAN 2.2 Multi-Task Video Generation / 多任务视频生成")
834
+ gr.Markdown("#### I2V (Lightning 6-step) · T2V (Lightning 14B 4-step) · V2V (3-Step: SAM2→Composite→VACE)")
835
+ gr.Markdown("---")
836
+
837
+ task_type = gr.Radio(
838
+ choices=[
839
+ "I2V (图生视频 / Image-to-Video)",
840
+ "T2V (文生视频 / Text-to-Video)",
841
+ "V2V (视频生视频 / Video-to-Video)",
842
+ ],
843
+ value="I2V (图生视频 / Image-to-Video)",
844
+ label="Task Type / 任务类型",
845
+ )
846
+
847
+ with gr.Row():
848
+ with gr.Column():
849
+ with gr.Group():
850
+ input_image = gr.Image(type="pil", label="Input Image / 输入图片 (I2V)", sources=["upload", "clipboard"])
851
+ with gr.Group():
852
+ input_video = gr.Video(label="Source Video / 原视频 (V2V)", sources=["upload"], visible=False, interactive=True)
853
+ with gr.Group():
854
+ mask_video = gr.Video(label="Mask Video / 遮罩视频 (V2V, 白色=编辑区域)", sources=["upload"], visible=False, interactive=True)
855
+ v2v_guide = gr.Markdown(
856
+ value="""### 📖 V2V 三步流水线 / 3-Step V2V Pipeline
857
+
858
+ **Step 1 — SAM2 分割**: 上传原视频 → 提取第一帧 → 点击标记区域 → 生成遮罩视频
859
+ **Step 2 — 自动合成**: 原视频 + 遮罩 → GrowMask扩展边界 + ImageComposite合成(自动完成)
860
+ **Step 3 — VACE 生成**: 合成视频 + 遮罩 + 参考图 + Prompt → 最终成品视频
861
+
862
+ 💡 也可跳过 Step 1,直接上传自己的遮罩视频(白色=编辑区域)
863
+ """,
864
+ visible=False,
865
+ )
866
+ with gr.Group(visible=False) as v2v_mask_tools:
867
+ first_frame_display = gr.Image(label="第一帧预览 / First Frame (点击标记区域)", type="pil", interactive=False)
868
+ points_store = gr.State(value=[])
869
+ points_display = gr.Textbox(label="标记点 / Points", value="无标记 / No points", interactive=False)
870
+ with gr.Row():
871
+ point_mode = gr.Radio(choices=["include (编辑)", "exclude (排除)"], value="include (编辑)", label="点击模式")
872
+ with gr.Row():
873
+ extract_frame_btn = gr.Button("📷 提取第一帧 / Extract First Frame", variant="secondary")
874
+ gen_mask_btn = gr.Button("🎭 生成遮罩 / Generate Mask (SAM2)", variant="primary")
875
+ clear_points_btn = gr.Button("🗑️ 清除标记 / Clear Points")
876
+ with gr.Accordion("🖼️ V2V 高级选项 / V2V Advanced", open=True):
877
+ reference_image = gr.Image(type="pil", label="参考图 / Reference Image (控制编辑区域的目标外观)", sources=["upload", "clipboard"])
878
+ grow_pixels_sl = gr.Slider(minimum=0, maximum=30, step=1, value=5, label="GrowMask / 遮罩扩展 (像素)", info="扩展遮罩边界,让编辑区域过渡更自然")
879
+
880
+ prompt_input = gr.Textbox(
881
+ label="Prompt / 提示词", value="",
882
+ placeholder="Describe the video... / 描述你想生成的视频...", lines=3,
883
+ )
884
+ duration_slider = gr.Slider(
885
+ minimum=0.5, maximum=15, step=0.5, value=3,
886
+ label="Duration / 时长 (seconds/秒)",
887
+ info="Max ~15s (241 frames @16fps) / 最大约15秒",
888
+ )
889
+ frame_multi = gr.Dropdown(choices=[16, 32, 64], value=16, label="Output FPS / 输出帧率", info="RIFE interpolation / RIFE插帧")
890
+
891
+ with gr.Accordion("⚙️ Advanced Settings / 高级设置", open=False):
892
+ last_image = gr.Image(type="pil", label="Last Frame / 末帧 (Optional)", sources=["upload", "clipboard"])
893
+ negative_prompt_input = gr.Textbox(label="Negative Prompt / 负面提示词", value=default_negative_prompt, lines=3)
894
+ with gr.Row():
895
+ steps_slider = gr.Slider(minimum=1, maximum=50, step=1, value=6, label="Steps / 步数", info="I2V: 4-8 | T2V: 4-8 | V2V: 25-50")
896
+ quality_sl = gr.Slider(minimum=1, maximum=10, step=1, value=6, label="Quality / 质量")
897
+ with gr.Row():
898
+ guidance_h = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance High / 引导(高噪声)")
899
+ guidance_l = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance Low / 引导(低噪声)")
900
+ with gr.Row():
901
+ scheduler_dd = gr.Dropdown(choices=list(SCHEDULER_MAP.keys()), value="UniPCMultistep", label="Scheduler / 调度器")
902
+ flow_shift_sl = gr.Slider(minimum=0.5, maximum=15.0, step=0.1, value=3.0, label="Flow Shift / 流偏移")
903
+ with gr.Row():
904
+ seed_sl = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=42, label="Seed / 种子")
905
+ random_seed_cb = gr.Checkbox(label="Random / 随机", value=True)
906
+ lora_dd = gr.Dropdown(choices=lora_loader.get_lora_choices(), label="LoRA (I2V only / 仅I2V)", multiselect=True, info="From WAN2.2_LoraSet_NSFW")
907
+ display_cb = gr.Checkbox(label="Display / 显示", value=True)
908
+
909
+ generate_btn = gr.Button("🎬 Generate / 生成视频", variant="primary", size="lg")
910
+
911
+ with gr.Column():
912
+ video_output = gr.Video(label="Generated Video / 生成的视频", autoplay=True, sources=["upload"], show_download_button=True, show_share_button=True, interactive=False, elem_id="generated-video")
913
+ with gr.Row():
914
+ grab_frame_btn = gr.Button("📸 Use Frame / 使用帧", variant="secondary")
915
+ timestamp_box = gr.Number(value=0, label="Timestamp", visible=False, elem_id="hidden-timestamp")
916
+ file_output = gr.File(label="Download / 下载")
917
+
918
+ def update_task_ui(task):
919
+ is_v2v = "V2V" in task
920
+ is_t2v = "T2V" in task
921
+ if is_t2v:
922
+ return (gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
923
+ gr.update(visible=False), gr.update(visible=False),
924
+ gr.update(value=4), gr.update(value=1.0), gr.update(value=1.0))
925
+ elif is_v2v:
926
+ return (gr.update(visible=False), gr.update(visible=True), gr.update(visible=True),
927
+ gr.update(visible=True), gr.update(visible=True),
928
+ gr.update(value=30), gr.update(value=5.0), gr.update(value=1.0))
929
+ else:
930
+ return (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False),
931
+ gr.update(visible=False), gr.update(visible=False),
932
+ gr.update(value=6), gr.update(value=1.0), gr.update(value=1.0))
933
+
934
+ task_type.change(update_task_ui, inputs=[task_type], outputs=[input_image, input_video, mask_video, v2v_guide, v2v_mask_tools, steps_slider, guidance_h, guidance_l])
935
+
936
+ # V2V mask generation callbacks
937
+ def on_extract_first_frame(video):
938
+ vpath = extract_video_path(video)
939
+ if not vpath or not os.path.exists(vpath):
940
+ raise gr.Error("请先上传视频 / Upload video first")
941
+ frame = extract_first_frame_from_video(vpath)
942
+ if frame is None:
943
+ raise gr.Error("无法提取第一帧 / Failed to extract first frame")
944
+ return frame, [], "无标记 / No points"
945
+
946
+ def on_click_frame(img, points, mode, evt: gr.SelectData):
947
+ if img is None:
948
+ return img, points, "请先提取第一帧 / Extract first frame first"
949
+ x, y = evt.index
950
+ label = 1 if "include" in mode else 0
951
+ points.append({"x": x, "y": y, "label": label})
952
+ # Draw points on image
953
+ display_img = img.copy()
954
+ draw = __import__('PIL').ImageDraw.Draw(display_img)
955
+ for p in points:
956
+ color = (0, 255, 0) if p["label"] == 1 else (255, 0, 0)
957
+ r = 8
958
+ draw.ellipse([p["x"]-r, p["y"]-r, p["x"]+r, p["y"]+r], fill=color, outline="white", width=2)
959
+ info = f"{len([p for p in points if p['label']==1])} include, {len([p for p in points if p['label']==0])} exclude"
960
+ return display_img, points, info
961
+
962
+ def on_clear_points(original_video):
963
+ vpath = extract_video_path(original_video)
964
+ if vpath and os.path.exists(vpath):
965
+ frame = extract_first_frame_from_video(vpath)
966
+ return frame, [], "无标记 / No points"
967
+ return None, [], "无标记 / No points"
968
+
969
+ def on_generate_mask(video, points):
970
+ import json
971
+ vpath = extract_video_path(video)
972
+ if not vpath:
973
+ raise gr.Error("请先上传视频 / Upload video first")
974
+ if not points:
975
+ raise gr.Error("请先在第一帧上点击标记 / Click on first frame to mark areas")
976
+ mask_path = generate_mask_video(vpath, json.dumps(points))
977
+ return mask_path
978
+
979
+ extract_frame_btn.click(fn=on_extract_first_frame, inputs=[input_video], outputs=[first_frame_display, points_store, points_display])
980
+ first_frame_display.select(fn=on_click_frame, inputs=[first_frame_display, points_store, point_mode], outputs=[first_frame_display, points_store, points_display])
981
+ clear_points_btn.click(fn=on_clear_points, inputs=[input_video], outputs=[first_frame_display, points_store, points_display])
982
+ gen_mask_btn.click(fn=on_generate_mask, inputs=[input_video, points_store], outputs=[mask_video])
983
+ generate_btn.click(
984
+ fn=generate_video,
985
+ inputs=[task_type, input_image, input_video, mask_video, prompt_input, lora_dd, duration_slider, frame_multi,
986
+ steps_slider, guidance_h, guidance_l, negative_prompt_input, quality_sl, seed_sl, random_seed_cb,
987
+ scheduler_dd, flow_shift_sl, last_image, display_cb,
988
+ reference_image, grow_pixels_sl],
989
+ outputs=[video_output, file_output, seed_sl],
990
+ )
991
+ grab_frame_btn.click(fn=None, inputs=None, outputs=[timestamp_box], js=get_timestamp_js)
992
+ timestamp_box.change(fn=extract_frame, inputs=[video_output, timestamp_box], outputs=[input_image])
993
+
994
+ if __name__ == "__main__":
995
+ demo.queue().launch(mcp_server=True, show_error=True)
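The V2V branch above snaps the frame count so that `(n - 1) % 4 == 0`, which the VACE pipeline requires, and floors it at 5 frames. That rule in isolation, as a pure-Python sketch (the helper name `snap_frame_count` is ours, not from the app):

```python
def snap_frame_count(n: int) -> int:
    """Largest count <= n with (count - 1) % 4 == 0, floored at 5 frames."""
    snapped = (n - 1) // 4 * 4 + 1
    return max(snapped, 5)

print(snap_frame_count(49))  # 49 already satisfies the constraint
print(snap_frame_count(50))  # drops back to 49
print(snap_frame_count(3))   # floored to the 5-frame minimum
```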
kill_bill.jpeg ADDED

Git LFS Details

  • SHA256: d1db15fcc022a6c639d14d4b246c40729af2873ca81d4acf7b48d36d62b8d864
  • Pointer size: 131 Bytes
  • Size of remote file: 240 kB
lora_loader.py ADDED
@@ -0,0 +1,127 @@
+"""
+LoRA Loader for WAN 2.2 - references files from lkzd7/WAN2.2_LoraSet_NSFW
+"""
+import urllib.parse
+import re
+from huggingface_hub import hf_hub_download
+
+LORA_REPO = "lkzd7/WAN2.2_LoraSet_NSFW"
+HF_TOKEN = None
+
+LORA_FILES = [
+    "Blink_Squatting_Cowgirl_Position_I2V_HIGH.safetensors",
+    "Blink_Squatting_Cowgirl_Position_I2V_LOW.safetensors",
+    "PENISLORA_22_i2v_HIGH_e320.safetensors",
+    "PENISLORA_22_i2v_LOW_e496.safetensors",
+    "Pornmaster_wan 2.2_14b_I2V_bukkake_v1.4_high_noise.safetensors",
+    "Pornmaster_wan 2.2_14b_I2V_bukkake_v1.4_low_noise.safetensors",
+    "W22_Multiscene_Photoshoot_Softcore_i2v_HN.safetensors",
+    "W22_Multiscene_Photoshoot_Softcore_i2v_LN.safetensors",
+    "WAN-2.2-I2V-Double-Blowjob-HIGH-v1.safetensors",
+    "WAN-2.2-I2V-Double-Blowjob-LOW-v1.safetensors",
+    "WAN-2.2-I2V-HandjobBlowjobCombo-HIGH-v1.safetensors",
+    "WAN-2.2-I2V-HandjobBlowjobCombo-LOW-v1.safetensors",
+    "WAN-2.2-I2V-SensualTeasingBlowjob-HIGH-v1.safetensors",
+    "WAN-2.2-I2V-SensualTeasingBlowjob-LOW-v1.safetensors",
+    "iGOON_Blink_Blowjob_I2V_HIGH.safetensors",
+    "iGOON_Blink_Blowjob_I2V_LOW.safetensors",
+    "iGoon - Blink_Front_Doggystyle_I2V_HIGH.safetensors",
+    "iGoon - Blink_Front_Doggystyle_I2V_LOW.safetensors",
+    "iGoon - Blink_Missionary_I2V_HIGH.safetensors",
+    "iGoon - Blink_Missionary_I2V_LOW v2.safetensors",
+    "iGoon - Blink_Missionary_I2V_LOW.safetensors",
+    "iGoon%20-%20Blink_Back_Doggystyle_HIGH.safetensors",
+    "iGoon%20-%20Blink_Back_Doggystyle_LOW.safetensors",
+    "iGoon%20-%20Blink_Facial_I2V_HIGH.safetensors",
+    "iGoon%20-%20Blink_Facial_I2V_LOW.safetensors",
+    "iGoon_Blink_Missionary_I2V_HIGH v2.safetensors",
+    "iGoon_Blink_Titjob_I2V_HIGH.safetensors",
+    "iGoon_Blink_Titjob_I2V_LOW.safetensors",
+    "lips-bj_high_noise.safetensors",
+    "lips-bj_low_noise.safetensors",
+    "mql_casting_sex_doggy_kneel_diagonally_behind_vagina_wan22_i2v_v1_high_noise.safetensors",
+    "mql_casting_sex_doggy_kneel_diagonally_behind_vagina_wan22_i2v_v1_low_noise.safetensors",
+    "mql_casting_sex_reverse_cowgirl_lie_front_vagina_wan22_i2v_v1_high_noise.safetensors",
+    "mql_casting_sex_reverse_cowgirl_lie_front_vagina_wan22_i2v_v1_low_noise.safetensors",
+    "mql_casting_sex_spoon_wan22_i2v_v1_high_noise.safetensors",
+    "mql_casting_sex_spoon_wan22_i2v_v1_low_noise.safetensors",
+    "mql_doggy_a_wan22_t2v_v1_high_noise .safetensors",
+    "mql_doggy_a_wan22_t2v_v1_low_noise.safetensors",
+    "mql_massage_tits_wan22_i2v_v1_high_noise.safetensors",
+    "mql_massage_tits_wan22_i2v_v1_low_noise.safetensors",
+    "mql_panties_aside_wan22_i2v_v1_high_noise.safetensors",
+    "mql_panties_aside_wan22_i2v_v1_low_noise.safetensors",
+    "mqlspn_a_wan22_t2v_v1_high_noise.safetensors",
+    "mqlspn_a_wan22_t2v_v1_low_noise.safetensors",
+    "sfbehind_v2.1_high_noise.safetensors",
+    "sfbehind_v2.1_low_noise.safetensors",
+    "sid3l3g_transition_v2.0_H.safetensors",
+    "sid3l3g_transition_v2.0_L.safetensors",
+    "wan2.2_i2v_high_ulitmate_pussy_asshole.safetensors",
+    "wan2.2_i2v_low_ulitmate_pussy_asshole.safetensors",
+    "wan22-mouthfull-140epoc-high-k3nk.safetensors",
+    "wan22-mouthfull-152epoc-low-k3nk.safetensors",
+]
+
+# Group HIGH/LOW noise variants of the same LoRA under one key.
+LORA_PAIRS = {}
+for f in LORA_FILES:
+    name = urllib.parse.unquote(f).replace(".safetensors", "")
+    is_high = bool(re.search(r'(high|HN|_H\b)', name, re.IGNORECASE))
+    is_low = bool(re.search(r'(low|LN|_L\b)', name, re.IGNORECASE))
+    group = re.sub(r'[\s_-]*(high|low|noise|HN|LN)([\s_-]*noise)?[\s_-]*(v?\d+(\.\d+)?)?\s*$', '', name, flags=re.IGNORECASE).strip()
+    group = re.sub(r'[\s_]+$', '', group)
+    if group not in LORA_PAIRS:
+        LORA_PAIRS[group] = {"HIGH": None, "LOW": None}
+    if is_high:
+        LORA_PAIRS[group]["HIGH"] = f
+    elif is_low:
+        LORA_PAIRS[group]["LOW"] = f
+
+
+def get_lora_choices():
+    choices = []
+    for group in sorted(LORA_PAIRS.keys()):
+        p = LORA_PAIRS[group]
+        if p["HIGH"] and p["LOW"]:
+            choices.append(group)
+        elif p["HIGH"]:
+            choices.append(f"{group} (HIGH only)")
+        elif p["LOW"]:
+            choices.append(f"{group} (LOW only)")
+    return choices
+
+
+def download_lora(group_name):
+    if not group_name:
+        return None, None
+    clean_name = re.sub(r'\s*\(HIGH only\)|\s*\(LOW only\)', '', group_name)
+    if clean_name not in LORA_PAIRS:
+        return None, None
+    pair = LORA_PAIRS[clean_name]
+    high_path, low_path = None, None
+    if pair["HIGH"]:
+        high_path = hf_hub_download(LORA_REPO, pair["HIGH"], token=HF_TOKEN)
+    if pair["LOW"]:
+        low_path = hf_hub_download(LORA_REPO, pair["LOW"], token=HF_TOKEN)
+    return high_path, low_path
+
+
+def load_lora_to_pipe(pipe, group_name, adapter_name="lora"):
+    high_path, low_path = download_lora(group_name)
+    if high_path and low_path:
+        pipe.load_lora_weights(high_path, adapter_name=f"{adapter_name}_high")
+        pipe.load_lora_weights(low_path, adapter_name=f"{adapter_name}_low")
+        print(f"Loaded LoRA pair: {group_name}")
+        return True
+    elif high_path:
+        pipe.load_lora_weights(high_path, adapter_name=adapter_name)
+        print(f"Loaded LoRA: {group_name}")
+        return True
+    elif low_path:
+        # LOW-only entries are offered in get_lora_choices(), so load them too
+        pipe.load_lora_weights(low_path, adapter_name=adapter_name)
+        print(f"Loaded LoRA: {group_name}")
+        return True
+    return False
+
+
+def unload_lora(pipe):
+    try:
+        pipe.unload_lora_weights()
+    except Exception:
+        pass
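The HIGH/LOW pairing in `lora_loader.py` hinges on two regexes: one classifies a filename as a high- or low-noise variant, the other strips that suffix so both variants collapse to one group key. A quick standalone check of the stripping step on two entries from `LORA_FILES` (the `group_key` helper is our extraction of the loop body, not a function in the module):

```python
import re
import urllib.parse

def group_key(filename: str) -> str:
    """Reproduce the suffix-stripping used to build LORA_PAIRS keys."""
    name = urllib.parse.unquote(filename).replace(".safetensors", "")
    key = re.sub(r'[\s_-]*(high|low|noise|HN|LN)([\s_-]*noise)?[\s_-]*(v?\d+(\.\d+)?)?\s*$',
                 '', name, flags=re.IGNORECASE).strip()
    return re.sub(r'[\s_]+$', '', key)

high = "iGoon%20-%20Blink_Facial_I2V_HIGH.safetensors"
low = "iGoon%20-%20Blink_Facial_I2V_LOW.safetensors"
print(group_key(high))                      # iGoon - Blink_Facial_I2V
print(group_key(high) == group_key(low))    # True -> they pair up
```

The same key also unifies `_high_noise` / `_low_noise` style names such as the `lips-bj` pair.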
model/loss.py ADDED
@@ -0,0 +1,128 @@
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as models
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class EPE(nn.Module):
+    def __init__(self):
+        super(EPE, self).__init__()
+
+    def forward(self, flow, gt, loss_mask):
+        loss_map = (flow - gt.detach()) ** 2
+        loss_map = (loss_map.sum(1, True) + 1e-6) ** 0.5
+        return (loss_map * loss_mask)
+
+
+class Ternary(nn.Module):
+    def __init__(self):
+        super(Ternary, self).__init__()
+        patch_size = 7
+        out_channels = patch_size * patch_size
+        self.w = np.eye(out_channels).reshape(
+            (patch_size, patch_size, 1, out_channels))
+        self.w = np.transpose(self.w, (3, 2, 0, 1))
+        self.w = torch.tensor(self.w).float().to(device)
+
+    def transform(self, img):
+        patches = F.conv2d(img, self.w, padding=3, bias=None)
+        transf = patches - img
+        transf_norm = transf / torch.sqrt(0.81 + transf**2)
+        return transf_norm
+
+    def rgb2gray(self, rgb):
+        r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :]
+        gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
+        return gray
+
+    def hamming(self, t1, t2):
+        dist = (t1 - t2) ** 2
+        dist_norm = torch.mean(dist / (0.1 + dist), 1, True)
+        return dist_norm
+
+    def valid_mask(self, t, padding):
+        n, _, h, w = t.size()
+        inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t)
+        mask = F.pad(inner, [padding] * 4)
+        return mask
+
+    def forward(self, img0, img1):
+        img0 = self.transform(self.rgb2gray(img0))
+        img1 = self.transform(self.rgb2gray(img1))
+        return self.hamming(img0, img1) * self.valid_mask(img0, 1)
+
+
+class SOBEL(nn.Module):
+    def __init__(self):
+        super(SOBEL, self).__init__()
+        self.kernelX = torch.tensor([
+            [1, 0, -1],
+            [2, 0, -2],
+            [1, 0, -1],
+        ]).float()
+        self.kernelY = self.kernelX.clone().T
+        self.kernelX = self.kernelX.unsqueeze(0).unsqueeze(0).to(device)
+        self.kernelY = self.kernelY.unsqueeze(0).unsqueeze(0).to(device)
+
+    def forward(self, pred, gt):
+        N, C, H, W = pred.shape[0], pred.shape[1], pred.shape[2], pred.shape[3]
+        img_stack = torch.cat(
+            [pred.reshape(N*C, 1, H, W), gt.reshape(N*C, 1, H, W)], 0)
+        sobel_stack_x = F.conv2d(img_stack, self.kernelX, padding=1)
+        sobel_stack_y = F.conv2d(img_stack, self.kernelY, padding=1)
+        pred_X, gt_X = sobel_stack_x[:N*C], sobel_stack_x[N*C:]
+        pred_Y, gt_Y = sobel_stack_y[:N*C], sobel_stack_y[N*C:]
+
+        L1X, L1Y = torch.abs(pred_X-gt_X), torch.abs(pred_Y-gt_Y)
+        loss = (L1X+L1Y)
+        return loss
+
+
+class MeanShift(nn.Conv2d):
+    def __init__(self, data_mean, data_std, data_range=1, norm=True):
+        c = len(data_mean)
+        super(MeanShift, self).__init__(c, c, kernel_size=1)
+        std = torch.Tensor(data_std)
+        self.weight.data = torch.eye(c).view(c, c, 1, 1)
+        if norm:
+            self.weight.data.div_(std.view(c, 1, 1, 1))
+            self.bias.data = -1 * data_range * torch.Tensor(data_mean)
+            self.bias.data.div_(std)
+        else:
+            self.weight.data.mul_(std.view(c, 1, 1, 1))
+            self.bias.data = data_range * torch.Tensor(data_mean)
+        self.requires_grad = False
+
+
+class VGGPerceptualLoss(torch.nn.Module):
+    def __init__(self, rank=0):
+        super(VGGPerceptualLoss, self).__init__()
+        pretrained = True
+        self.vgg_pretrained_features = models.vgg19(pretrained=pretrained).features
+        self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, X, Y, indices=None):
+        X = self.normalize(X)
+        Y = self.normalize(Y)
+        indices = [2, 7, 12, 21, 30]
+        weights = [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10/1.5]
+        k = 0
+        loss = 0
+        for i in range(indices[-1]):
+            X = self.vgg_pretrained_features[i](X)
+            Y = self.vgg_pretrained_features[i](Y)
+            if (i+1) in indices:
+                loss += weights[k] * (X - Y.detach()).abs().mean() * 0.1
+                k += 1
+        return loss
+
+
+if __name__ == '__main__':
+    img0 = torch.zeros(3, 3, 256, 256).float().to(device)
+    img1 = torch.tensor(np.random.normal(
+        0, 1, (3, 3, 256, 256))).float().to(device)
+    ternary_loss = Ternary()
+    print(ternary_loss(img0, img1).shape)
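`Ternary.rgb2gray` above converts with the ITU-R BT.601 luma weights (0.2989, 0.5870, 0.1140). The same per-pixel computation without torch, to make the weighting concrete (the `luma` helper is ours):

```python
def luma(r: float, g: float, b: float) -> float:
    """BT.601 grayscale value for one RGB pixel, matching Ternary.rgb2gray."""
    return 0.2989 * r + 0.5870 * g + 0.1140 * b

# The weights sum to 0.9999, so pure white maps to ~1.0 rather than exactly 1.0.
print(luma(1.0, 1.0, 1.0))
print(luma(0.0, 0.0, 0.0))  # black stays 0.0
```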
model/pytorch_msssim/__init__.py ADDED
@@ -0,0 +1,198 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from math import exp
4
+ import numpy as np
5
+
6
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+
8
+ def gaussian(window_size, sigma):
9
+ gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
10
+ return gauss/gauss.sum()
11
+
12
+
13
+ def create_window(window_size, channel=1):
14
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
15
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device)
16
+ window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
17
+ return window
18
+
19
+ def create_window_3d(window_size, channel=1):
20
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
21
+ _2D_window = _1D_window.mm(_1D_window.t())
22
+ _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
23
+ window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
24
+ return window
25
+
26
+
27
+ def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
28
+ # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
29
+ if val_range is None:
30
+ if torch.max(img1) > 128:
31
+ max_val = 255
32
+ else:
33
+ max_val = 1
34
+
35
+ if torch.min(img1) < -0.5:
36
+ min_val = -1
37
+ else:
38
+ min_val = 0
39
+ L = max_val - min_val
40
+ else:
41
+ L = val_range
42
+
43
+ padd = 0
44
+ (_, channel, height, width) = img1.size()
45
+ if window is None:
46
+ real_size = min(window_size, height, width)
47
+ window = create_window(real_size, channel=channel).to(img1.device).type_as(img1)
48
+
49
+ mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)
50
+ mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)
51
+
52
+ mu1_sq = mu1.pow(2)
53
+ mu2_sq = mu2.pow(2)
54
+ mu1_mu2 = mu1 * mu2
55
+
56
+ sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_sq
57
+ sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu2_sq
58
+ sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_mu2
59
+
60
+ C1 = (0.01 * L) ** 2
61
+ C2 = (0.03 * L) ** 2
62
+
63
+ v1 = 2.0 * sigma12 + C2
64
+ v2 = sigma1_sq + sigma2_sq + C2
65
+ cs = torch.mean(v1 / v2) # contrast sensitivity
66
+
67
+ ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
68
+
69
+ if size_average:
70
+ ret = ssim_map.mean()
71
+ else:
72
+ ret = ssim_map.mean(1).mean(1).mean(1)
73
+
74
+ if full:
75
+ return ret, cs
76
+ return ret
77
+
78
+
79
+ def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
80
+ # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
81
+ if val_range is None:
82
+ if torch.max(img1) > 128:
83
+ max_val = 255
84
+ else:
85
+ max_val = 1
86
+
87
+ if torch.min(img1) < -0.5:
88
+ min_val = -1
89
+ else:
90
+ min_val = 0
91
+ L = max_val - min_val
92
+ else:
93
+ L = val_range
94
+
95
+ padd = 0
96
+ (_, _, height, width) = img1.size()
97
+ if window is None:
98
+ real_size = min(window_size, height, width)
99
+ window = create_window_3d(real_size, channel=1).to(img1.device).type_as(img1)
100
+ # Channel is set to 1 since we consider color images as volumetric images
101
+
102
+ img1 = img1.unsqueeze(1)
103
+ img2 = img2.unsqueeze(1)
104
+
105
+     mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
+     mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
+
+     mu1_sq = mu1.pow(2)
+     mu2_sq = mu2.pow(2)
+     mu1_mu2 = mu1 * mu2
+
+     sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_sq
+     sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu2_sq
+     sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_mu2
+
+     C1 = (0.01 * L) ** 2
+     C2 = (0.03 * L) ** 2
+
+     v1 = 2.0 * sigma12 + C2
+     v2 = sigma1_sq + sigma2_sq + C2
+     cs = torch.mean(v1 / v2)  # contrast sensitivity
+
+     ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
+
+     if size_average:
+         ret = ssim_map.mean()
+     else:
+         ret = ssim_map.mean(1).mean(1).mean(1)
+
+     if full:
+         return ret, cs
+     return ret
+
+
+ def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
+     device = img1.device
+     weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device).type_as(img1)
+     levels = weights.size()[0]
+     mssim = []
+     mcs = []
+     for _ in range(levels):
+         sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
+         mssim.append(sim)
+         mcs.append(cs)
+
+         img1 = F.avg_pool2d(img1, (2, 2))
+         img2 = F.avg_pool2d(img2, (2, 2))
+
+     mssim = torch.stack(mssim)
+     mcs = torch.stack(mcs)
+
+     # Normalize (avoids NaNs when training unstable models; not part of the original MS-SSIM definition)
+     if normalize:
+         mssim = (mssim + 1) / 2
+         mcs = (mcs + 1) / 2
+
+     pow1 = mcs ** weights
+     pow2 = mssim ** weights
+     # From the Matlab implementation: https://ece.uwaterloo.ca/~z70wang/research/iwssim/
+     output = torch.prod(pow1[:-1] * pow2[-1])
+     return output
+
+
+ # Classes to re-use window
+ class SSIM(torch.nn.Module):
+     def __init__(self, window_size=11, size_average=True, val_range=None):
+         super(SSIM, self).__init__()
+         self.window_size = window_size
+         self.size_average = size_average
+         self.val_range = val_range
+
+         # Assume 3 channels for SSIM
+         self.channel = 3
+         self.window = create_window(window_size, channel=self.channel)
+
+     def forward(self, img1, img2):
+         (_, channel, _, _) = img1.size()
+
+         if channel == self.channel and self.window.dtype == img1.dtype:
+             window = self.window
+         else:
+             window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
+             self.window = window
+             self.channel = channel
+
+         _ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
+         dssim = (1 - _ssim) / 2
+         return dssim
+
+
+ class MSSSIM(torch.nn.Module):
+     def __init__(self, window_size=11, size_average=True, channel=3):
+         super(MSSSIM, self).__init__()
+         self.window_size = window_size
+         self.size_average = size_average
+         self.channel = channel
+
+     def forward(self, img1, img2):
+         return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
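The `ssim_map` above is the product of a luminance term `(2·μ₁μ₂ + C1)/(μ₁² + μ₂² + C1)` and the contrast/structure term `v1/v2`. A windowless, plain-Python sketch of the same formula (a hypothetical helper, not part of this commit) shows that identical signals score 1:

```python
def ssim_global(x, y, L=1.0):
    # Windowless SSIM over two equal-length signals in [0, L], using the
    # same constants as above: C1 = (0.01*L)**2, C2 = (0.03*L)**2.
    n = len(x)
    mu_x, mu_y = sum(x) / n, sum(y) / n
    var_x = sum((a - mu_x) ** 2 for a in x) / n
    var_y = sum((b - mu_y) ** 2 for b in y) / n
    cov = sum((a - mu_x) * (b - mu_y) for a, b in zip(x, y)) / n
    C1, C2 = (0.01 * L) ** 2, (0.03 * L) ** 2
    num = (2 * mu_x * mu_y + C1) * (2 * cov + C2)
    den = (mu_x ** 2 + mu_y ** 2 + C1) * (var_x + var_y + C2)
    return num / den
```

Mismatched signals (e.g. a reversed copy) give a value strictly below 1, which is what the training loss `dssim = (1 - ssim) / 2` penalizes.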
model/warplayer.py ADDED
@@ -0,0 +1,24 @@
+ import torch
+ import torch.nn as nn
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ backwarp_tenGrid = {}
+
+
+ def warp(tenInput, tenFlow):
+     k = (str(tenFlow.device), str(tenFlow.size()))
+     if k not in backwarp_tenGrid:
+         tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=tenFlow.device).view(
+             1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
+         tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=tenFlow.device).view(
+             1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
+         backwarp_tenGrid[k] = torch.cat(
+             [tenHorizontal, tenVertical], 1).to(tenFlow.device)
+
+     tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
+                          tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
+
+     grid = backwarp_tenGrid[k].type_as(tenFlow)
+
+     g = (grid + tenFlow).permute(0, 2, 3, 1)
+     return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
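`warp` converts the pixel-space flow into `grid_sample`'s normalized [-1, 1] coordinates by dividing each component by half of `(size - 1)` (the `align_corners=True` convention). That conversion, isolated as a tiny hypothetical helper:

```python
def normalize_flow(fx, fy, height, width):
    # A pixel displacement of (size - 1) / 2 corresponds to a shift of
    # 1.0 in grid_sample's normalized [-1, 1] coordinates (align_corners=True),
    # matching the division performed inside `warp`.
    return fx / ((width - 1) / 2.0), fy / ((height - 1) / 2.0)
```

For a 5-pixel-wide, 9-pixel-tall frame, a flow of (2, 4) pixels maps to exactly (1.0, 1.0) in normalized units, i.e. a shift across half the image.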
packages.txt ADDED
@@ -0,0 +1 @@
+ unzip
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ git+https://github.com/linoytsaban/diffusers.git@wan22-loras
+
+ transformers<5
+ accelerate
+ safetensors
+ sentencepiece
+ peft
+ ftfy
+ imageio
+ imageio-ffmpeg
+ opencv-python
+ torchao==0.11.0
+ sam2
+
+ numpy
+ torchvision
wan22_input_2.jpg ADDED

Git LFS Details

  • SHA256: e5f312a03278dc2009fc02e61b1cd3f743ee1abd12ae184deb6ea504f8676a8a
  • Pointer size: 131 Bytes
  • Size of remote file: 234 kB
wan_controlnet.py ADDED
@@ -0,0 +1,284 @@
+ from typing import Any, Dict, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
+ from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.models.transformers.transformer_wan import (
+     WanTimeTextImageEmbedding,
+     WanRotaryPosEmbed,
+     WanTransformerBlock
+ )
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ def zero_module(module):
+     for p in module.parameters():
+         nn.init.zeros_(p)
+     return module
+
+
+ class WanControlnet(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+     r"""
+     A Controlnet Transformer model for video-like data used in the Wan model.
+
+     Args:
+         patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
+             3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
+         num_attention_heads (`int`, defaults to `40`):
+             The number of heads to use for multi-head attention.
+         attention_head_dim (`int`, defaults to `128`):
+             The number of channels in each attention head.
+         in_channels (`int`, defaults to `3`):
+             The number of channels in the controlnet input.
+         vae_channels (`int`, defaults to `16`):
+             The number of channels in the VAE input.
+         text_dim (`int`, defaults to `4096`):
+             Input dimension for text embeddings.
+         freq_dim (`int`, defaults to `256`):
+             Dimension for sinusoidal time embeddings.
+         ffn_dim (`int`, defaults to `13824`):
+             Intermediate dimension in the feed-forward network.
+         num_layers (`int`, defaults to `20`):
+             The number of transformer blocks to use.
+         cross_attn_norm (`bool`, defaults to `True`):
+             Enable cross-attention normalization.
+         qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
+             The type of query/key normalization to use.
+         eps (`float`, defaults to `1e-6`):
+             Epsilon value for normalization layers.
+         image_dim (`int`, *optional*, defaults to `None`):
+             Dimension of image embeddings (e.g. 1280 for the I2V model). If `None`, no image embedding is used.
+         added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+             The number of channels to use for the added key and value projections. If `None`, no projection is used.
+         downscale_coef (`int`, *optional*, defaults to `8`):
+             Coefficient for downscaling the controlnet input video.
+         out_proj_dim (`int`, *optional*, defaults to `128 * 12`):
+             Output projection dimension for the final linear layers.
+     """
+
+     _supports_gradient_checkpointing = True
+     _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
+     _no_split_modules = ["WanTransformerBlock"]
+     _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
+     _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+
+     @register_to_config
+     def __init__(
+         self,
+         patch_size: Tuple[int] = (1, 2, 2),
+         num_attention_heads: int = 40,
+         attention_head_dim: int = 128,
+         in_channels: int = 3,
+         vae_channels: int = 16,
+         text_dim: int = 4096,
+         freq_dim: int = 256,
+         ffn_dim: int = 13824,
+         num_layers: int = 20,
+         cross_attn_norm: bool = True,
+         qk_norm: Optional[str] = "rms_norm_across_heads",
+         eps: float = 1e-6,
+         image_dim: Optional[int] = None,
+         added_kv_proj_dim: Optional[int] = None,
+         rope_max_seq_len: int = 1024,
+         downscale_coef: int = 8,
+         out_proj_dim: int = 128 * 12,
+     ) -> None:
+         super().__init__()
+
+         start_channels = in_channels * (downscale_coef ** 2)
+         input_channels = [start_channels, start_channels // 2, start_channels // 4]
+
+         self.control_encoder = nn.ModuleList([
+             # Spatial compression with time awareness
+             nn.Sequential(
+                 nn.Conv3d(
+                     in_channels,
+                     input_channels[0],
+                     kernel_size=(3, downscale_coef + 1, downscale_coef + 1),
+                     stride=(1, downscale_coef, downscale_coef),
+                     padding=(1, downscale_coef // 2, downscale_coef // 2)
+                 ),
+                 nn.GELU(approximate="tanh"),
+                 nn.GroupNorm(2, input_channels[0]),
+             ),
+             # Spatio-temporal compression with spatial awareness
+             nn.Sequential(
+                 nn.Conv3d(input_channels[0], input_channels[1], kernel_size=3, stride=(2, 1, 1), padding=1),
+                 nn.GELU(approximate="tanh"),
+                 nn.GroupNorm(2, input_channels[1]),
+             ),
+             # Temporal compression with spatial awareness
+             nn.Sequential(
+                 nn.Conv3d(input_channels[1], input_channels[2], kernel_size=3, stride=(2, 1, 1), padding=1),
+                 nn.GELU(approximate="tanh"),
+                 nn.GroupNorm(2, input_channels[2]),
+             )
+         ])
+
+         inner_dim = num_attention_heads * attention_head_dim
+
+         # 1. Patch & position embedding
+         self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
+         self.patch_embedding = nn.Conv3d(vae_channels + input_channels[2], inner_dim, kernel_size=patch_size, stride=patch_size)
+
+         # 2. Condition embeddings
+         # image_embedding_dim=1280 for the I2V model
+         self.condition_embedder = WanTimeTextImageEmbedding(
+             dim=inner_dim,
+             time_freq_dim=freq_dim,
+             time_proj_dim=inner_dim * 6,
+             text_embed_dim=text_dim,
+             image_embed_dim=image_dim,
+         )
+
+         # 3. Transformer blocks
+         self.blocks = nn.ModuleList(
+             [
+                 WanTransformerBlock(
+                     inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+
+         # 4. Controlnet modules
+         self.controlnet_blocks = nn.ModuleList([])
+         for _ in range(len(self.blocks)):
+             controlnet_block = nn.Linear(inner_dim, out_proj_dim)
+             controlnet_block = zero_module(controlnet_block)
+             self.controlnet_blocks.append(controlnet_block)
+
+         self.gradient_checkpointing = False
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         timestep: torch.LongTensor,
+         encoder_hidden_states: torch.Tensor,
+         controlnet_states: torch.Tensor,
+         encoder_hidden_states_image: Optional[torch.Tensor] = None,
+         return_dict: bool = True,
+         attention_kwargs: Optional[Dict[str, Any]] = None,
+     ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+         if attention_kwargs is not None:
+             attention_kwargs = attention_kwargs.copy()
+             lora_scale = attention_kwargs.pop("scale", 1.0)
+         else:
+             lora_scale = 1.0
+
+         if USE_PEFT_BACKEND:
+             # weight the lora layers by setting `lora_scale` for each PEFT layer
+             scale_lora_layers(self, lora_scale)
+         else:
+             if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                 logger.warning(
+                     "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                 )
+
+         rotary_emb = self.rope(hidden_states)
+
+         # 0. Controlnet encoder
+         for control_encoder_block in self.control_encoder:
+             controlnet_states = control_encoder_block(controlnet_states)
+         hidden_states = torch.cat([hidden_states, controlnet_states], dim=1)
+
+         hidden_states = self.patch_embedding(hidden_states)
+         hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+         # timestep shape: (batch_size,), or (batch_size, seq_len) for Wan 2.2 ti2v
+         if timestep.ndim == 2:
+             ts_seq_len = timestep.shape[1]
+             timestep = timestep.flatten()  # batch_size * seq_len
+         else:
+             ts_seq_len = None
+
+         temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+             timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
+         )
+         if ts_seq_len is not None:
+             # batch_size, seq_len, 6, inner_dim
+             timestep_proj = timestep_proj.unflatten(2, (6, -1))
+         else:
+             # batch_size, 6, inner_dim
+             timestep_proj = timestep_proj.unflatten(1, (6, -1))
+
+         if encoder_hidden_states_image is not None:
+             encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
+
+         # 4. Transformer blocks
+         controlnet_hidden_states = ()
+         if torch.is_grad_enabled() and self.gradient_checkpointing:
+             for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
+                 hidden_states = self._gradient_checkpointing_func(
+                     block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
+                 )
+                 controlnet_hidden_states += (controlnet_block(hidden_states),)
+         else:
+             for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
+                 hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+                 controlnet_hidden_states += (controlnet_block(hidden_states),)
+
+         if USE_PEFT_BACKEND:
+             # remove `lora_scale` from each PEFT layer
+             unscale_lora_layers(self, lora_scale)
+
+         if not return_dict:
+             return (controlnet_hidden_states,)
+
+         return Transformer2DModelOutput(sample=controlnet_hidden_states)
+
+
+ if __name__ == "__main__":
+     parameters = {
+         "added_kv_proj_dim": None,
+         "attention_head_dim": 128,
+         "cross_attn_norm": True,
+         "eps": 1e-06,
+         "ffn_dim": 8960,
+         "freq_dim": 256,
+         "image_dim": None,
+         "in_channels": 3,
+         "num_attention_heads": 12,
+         "num_layers": 2,
+         "patch_size": [1, 2, 2],
+         "qk_norm": "rms_norm_across_heads",
+         "rope_max_seq_len": 1024,
+         "text_dim": 4096,
+         "downscale_coef": 8,
+         "out_proj_dim": 12 * 128,
+         "vae_channels": 16
+     }
+     controlnet = WanControlnet(**parameters)
+
+     hidden_states = torch.rand(1, 16, 13, 60, 90)
+     timestep = torch.tensor([1000]).repeat(17550).unsqueeze(0)  # torch.randint(low=0, high=1000, size=(1,), dtype=torch.long)
+     encoder_hidden_states = torch.rand(1, 512, 4096)
+     controlnet_states = torch.rand(1, 3, 49, 480, 720)
+
+     controlnet_hidden_states = controlnet(
+         hidden_states=hidden_states,
+         timestep=timestep,
+         encoder_hidden_states=encoder_hidden_states,
+         controlnet_states=controlnet_states,
+         return_dict=False
+     )
+     print("Output states count", len(controlnet_hidden_states[0]))
+     for out_hidden_states in controlnet_hidden_states[0]:
+         print(out_hidden_states.shape)
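The `zero_module` call above zero-initializes every `controlnet_blocks` projection, so at the start of training the controlnet residuals are exactly zero and the base transformer's behavior is untouched. A plain-Python illustration (hypothetical helpers, lists standing in for tensors) of why a zero-initialized projection contributes nothing:

```python
def zero_init_projection(dim_out, dim_in):
    # Analogue of zero_module(nn.Linear(...)): an all-zero weight matrix.
    return [[0.0] * dim_in for _ in range(dim_out)]

def project(weight, x):
    # Plain matrix-vector product, like the nn.Linear forward pass (no bias).
    return [sum(w * v for w, v in zip(row, x)) for row in weight]
```

Whatever the input, the projected residual is a zero vector until gradients move the weights away from zero.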
wan_i2v_input.JPG ADDED

Git LFS Details

  • SHA256: 077e3d965090c9028c69c00931675f42e1acc815c6eb450ab291b3b72d211a8e
  • Pointer size: 131 Bytes
  • Size of remote file: 251 kB
wan_t2v_controlnet_pipeline.py ADDED
@@ -0,0 +1,798 @@
+ # Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import html
+ import inspect
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
+
+ import ftfy
+ import regex as re
+ import torch
+ import numpy as np
+ from PIL import Image
+ from torchvision import transforms
+ from transformers import AutoTokenizer, UMT5EncoderModel
+
+ from diffusers import WanTransformer3DModel
+ from diffusers.image_processor import PipelineImageInput
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+ from diffusers.loaders import WanLoraLoaderMixin
+ from diffusers.models import AutoencoderKLWan
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+ from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
+ from diffusers.utils.torch_utils import randn_tensor
+ from diffusers.video_processor import VideoProcessor
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+ from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
+
+ from wan_transformer import CustomWanTransformer3DModel
+ from wan_controlnet import WanControlnet
+ from wan_teacache import TeaCache
+
+ if is_torch_xla_available():
+     import torch_xla.core.xla_model as xm
+
+     XLA_AVAILABLE = True
+ else:
+     XLA_AVAILABLE = False
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ def resize_for_crop(image, crop_h, crop_w):
+     img_h, img_w = image.shape[-2:]
+     if img_h >= crop_h and img_w >= crop_w:
+         coef = max(crop_h / img_h, crop_w / img_w)
+     elif img_h <= crop_h and img_w <= crop_w:
+         coef = max(crop_h / img_h, crop_w / img_w)
+     else:
+         coef = crop_h / img_h if crop_h > img_h else crop_w / img_w
+     out_h, out_w = int(img_h * coef), int(img_w * coef)
+     resized_image = transforms.functional.resize(image, (out_h, out_w), antialias=True)
+     return resized_image
+
+
+ def prepare_frames(input_images, video_size, do_resize=True, do_crop=True):
+     input_images = np.stack([np.array(x) for x in input_images])
+     images_tensor = torch.from_numpy(input_images).permute(0, 3, 1, 2) / 127.5 - 1
+     if do_resize:
+         images_tensor = [resize_for_crop(x, crop_h=video_size[0], crop_w=video_size[1]) for x in images_tensor]
+     if do_crop:
+         images_tensor = [transforms.functional.center_crop(x, video_size) for x in images_tensor]
+     if isinstance(images_tensor, list):
+         images_tensor = torch.stack(images_tensor)
+     return images_tensor.unsqueeze(0)
+
+
+ def prepare_controlnet_frames(controlnet_frames, height, width, dtype, device):
+     prepared_frames = prepare_frames(controlnet_frames, (height, width))
+     controlnet_encoded_frames = prepared_frames.to(dtype=dtype, device=device)
+     return controlnet_encoded_frames.permute(0, 2, 1, 3, 4).contiguous()
+
+
+ def basic_clean(text):
+     text = ftfy.fix_text(text)
+     text = html.unescape(html.unescape(text))
+     return text.strip()
+
+
+ def whitespace_clean(text):
+     text = re.sub(r"\s+", " ", text)
+     text = text.strip()
+     return text
+
+
+ def prompt_clean(text):
+     text = whitespace_clean(basic_clean(text))
+     return text
+
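When the frame is larger (or smaller) than the crop window in both dimensions, `resize_for_crop` above picks the larger of the two height/width ratios, which guarantees the resized frame covers the crop window in both dimensions before `center_crop` runs. That selection rule, sketched as a hypothetical helper:

```python
def cover_scale(img_h, img_w, crop_h, crop_w):
    # The larger of the two ratios makes BOTH resized dimensions >= the
    # crop size, so center-cropping afterwards never pads.
    return max(crop_h / img_h, crop_w / img_w)

# A 600x800 frame cropped to 300x500: width is the binding constraint,
# so the scale is 500/800 = 0.625 and the resized frame is 375x500.
scale = cover_scale(600, 800, 300, 500)
out_h, out_w = int(600 * scale), int(800 * scale)
```

Here 375 ≥ 300 and 500 ≥ 500, so the subsequent center crop only trims the excess 75 rows of height.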
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+ def retrieve_timesteps(
+     scheduler,
+     num_inference_steps: Optional[int] = None,
+     device: Optional[Union[str, torch.device]] = None,
+     timesteps: Optional[List[int]] = None,
+     sigmas: Optional[List[float]] = None,
+     **kwargs,
+ ):
+     r"""
+     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+     Args:
+         scheduler (`SchedulerMixin`):
+             The scheduler to get timesteps from.
+         num_inference_steps (`int`):
+             The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+             must be `None`.
+         device (`str` or `torch.device`, *optional*):
+             The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+         timesteps (`List[int]`, *optional*):
+             Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+             `num_inference_steps` and `sigmas` must be `None`.
+         sigmas (`List[float]`, *optional*):
+             Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+             `num_inference_steps` and `timesteps` must be `None`.
+
+     Returns:
+         `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+         second element is the number of inference steps.
+     """
+     if timesteps is not None and sigmas is not None:
+         raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+     if timesteps is not None:
+         accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+         if not accepts_timesteps:
+             raise ValueError(
+                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                 f" timestep schedules. Please check whether you are using the correct scheduler."
+             )
+         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+         num_inference_steps = len(timesteps)
+     elif sigmas is not None:
+         accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+         if not accept_sigmas:
+             raise ValueError(
+                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                 f" sigmas schedules. Please check whether you are using the correct scheduler."
+             )
+         scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+         num_inference_steps = len(timesteps)
+     else:
+         scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+     return timesteps, num_inference_steps
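`retrieve_timesteps` enforces that custom `timesteps` and custom `sigmas` are mutually exclusive, and that whichever one is supplied overrides `num_inference_steps`. The control flow, reduced to a scheduler-free sketch (a hypothetical helper, not part of the commit):

```python
def pick_schedule(num_inference_steps=None, timesteps=None, sigmas=None):
    # Same mutual exclusion as retrieve_timesteps above: `timesteps` and
    # `sigmas` cannot both be given; either one overrides the step count.
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
    if timesteps is not None:
        return list(timesteps), len(timesteps)
    if sigmas is not None:
        return list(sigmas), len(sigmas)
    return list(range(num_inference_steps)), num_inference_steps
```

Passing `timesteps=[999, 500, 0]` returns a 3-step schedule regardless of any `num_inference_steps` default, mirroring how the real function re-derives the step count from the scheduler after `set_timesteps`.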
158
+
159
+
160
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
161
+ def retrieve_latents(
162
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
163
+ ):
164
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
165
+ return encoder_output.latent_dist.sample(generator)
166
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
167
+ return encoder_output.latent_dist.mode()
168
+ elif hasattr(encoder_output, "latents"):
169
+ return encoder_output.latents
170
+ else:
171
+ raise AttributeError("Could not access latents of provided encoder_output")
172
+
173
+
174
+ class WanTextToVideoControlnetPipeline(DiffusionPipeline, WanLoraLoaderMixin):
175
+ r"""
176
+ Pipeline for text-to-video generation using Wan.
177
+
178
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
179
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
180
+
181
+ Args:
182
+ tokenizer ([`T5Tokenizer`]):
183
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
184
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
185
+ text_encoder ([`T5EncoderModel`]):
186
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
187
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
188
+ transformer ([`WanTransformer3DModel`]):
189
+ Conditional Transformer to denoise the input latents.
190
+ scheduler ([`UniPCMultistepScheduler`]):
191
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
192
+ vae ([`AutoencoderKLWan`]):
193
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
194
+ """
195
+
196
+ model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae->controlnet"
197
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
198
+ _optional_components = ["transformer_2"]
199
+
200
+ def __init__(
201
+ self,
202
+ tokenizer: AutoTokenizer,
203
+ text_encoder: UMT5EncoderModel,
204
+ transformer: CustomWanTransformer3DModel,
205
+ vae: AutoencoderKLWan,
206
+ controlnet: WanControlnet,
207
+ scheduler: FlowMatchEulerDiscreteScheduler,
208
+ transformer_2: WanTransformer3DModel = None,
209
+ boundary_ratio: Optional[float] = None,
210
+ expand_timesteps: bool = False,
211
+ ):
212
+ super().__init__()
213
+
214
+ self.register_modules(
215
+ vae=vae,
216
+ text_encoder=text_encoder,
217
+ tokenizer=tokenizer,
218
+ transformer=transformer,
219
+ controlnet=controlnet,
220
+ scheduler=scheduler,
221
+ transformer_2=transformer_2,
222
+ )
223
+ self.register_to_config(boundary_ratio=boundary_ratio)
224
+ self.register_to_config(expand_timesteps=expand_timesteps)
225
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
226
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
227
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
228
+
229
+ def _get_t5_prompt_embeds(
230
+ self,
231
+ prompt: Union[str, List[str]] = None,
232
+ num_videos_per_prompt: int = 1,
233
+ max_sequence_length: int = 226,
234
+ device: Optional[torch.device] = None,
235
+ dtype: Optional[torch.dtype] = None,
236
+ ):
237
+ device = device or self._execution_device
238
+ dtype = dtype or self.text_encoder.dtype
239
+
240
+ prompt = [prompt] if isinstance(prompt, str) else prompt
241
+ prompt = [prompt_clean(u) for u in prompt]
242
+ batch_size = len(prompt)
243
+
244
+ text_inputs = self.tokenizer(
245
+ prompt,
246
+ padding="max_length",
247
+ max_length=max_sequence_length,
248
+ truncation=True,
249
+ add_special_tokens=True,
250
+ return_attention_mask=True,
251
+ return_tensors="pt",
252
+ )
253
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
254
+ seq_lens = mask.gt(0).sum(dim=1).long()
255
+
256
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
257
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
258
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
259
+ prompt_embeds = torch.stack(
260
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
261
+ )
262
+
263
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
264
+ _, seq_len, _ = prompt_embeds.shape
265
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
266
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
267
+
268
+ return prompt_embeds
269
+
270
+ def encode_prompt(
271
+ self,
272
+ prompt: Union[str, List[str]],
273
+ negative_prompt: Optional[Union[str, List[str]]] = None,
274
+ do_classifier_free_guidance: bool = True,
275
+ num_videos_per_prompt: int = 1,
276
+ prompt_embeds: Optional[torch.Tensor] = None,
277
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
278
+ max_sequence_length: int = 226,
279
+ device: Optional[torch.device] = None,
280
+ dtype: Optional[torch.dtype] = None,
281
+ ):
282
+ r"""
283
+ Encodes the prompt into text encoder hidden states.
284
+
285
+ Args:
286
+ prompt (`str` or `List[str]`, *optional*):
287
+ prompt to be encoded
288
+ negative_prompt (`str` or `List[str]`, *optional*):
289
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
290
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
291
+ less than `1`).
292
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
293
+ Whether to use classifier free guidance or not.
294
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
295
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
296
+ prompt_embeds (`torch.Tensor`, *optional*):
297
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
298
+ provided, text embeddings will be generated from `prompt` input argument.
299
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
300
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
301
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
302
+ argument.
303
+ device: (`torch.device`, *optional*):
304
+ torch device
305
+ dtype: (`torch.dtype`, *optional*):
306
+ torch dtype
307
+ """
308
+ device = device or self._execution_device
309
+
310
+ prompt = [prompt] if isinstance(prompt, str) else prompt
311
+ if prompt is not None:
312
+ batch_size = len(prompt)
313
+ else:
314
+ batch_size = prompt_embeds.shape[0]
315
+
316
+ if prompt_embeds is None:
317
+ prompt_embeds = self._get_t5_prompt_embeds(
318
+ prompt=prompt,
319
+ num_videos_per_prompt=num_videos_per_prompt,
320
+ max_sequence_length=max_sequence_length,
321
+ device=device,
322
+ dtype=dtype,
323
+ )
324
+
325
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
326
+ negative_prompt = negative_prompt or ""
327
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
328
+
329
+ if prompt is not None and type(prompt) is not type(negative_prompt):
330
+ raise TypeError(
331
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
332
+ f" {type(prompt)}."
333
+ )
334
+ elif batch_size != len(negative_prompt):
335
+ raise ValueError(
336
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
337
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
338
+ " the batch size of `prompt`."
339
+ )
340
+
341
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
342
+ prompt=negative_prompt,
343
+ num_videos_per_prompt=num_videos_per_prompt,
344
+ max_sequence_length=max_sequence_length,
345
+ device=device,
346
+ dtype=dtype,
347
+ )
348
+
349
+ return prompt_embeds, negative_prompt_embeds
350
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ guidance_scale_2=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+ raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ controlnet_frames: List[Image.Image] = None,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ guidance_scale_2: Optional[float] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ controlnet_latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+
+ controlnet_weight: float = 1.0,
+ controlnet_guidance_start: float = 0.0,
+ controlnet_guidance_end: float = 1.0,
+ controlnet_stride: int = 3,
+
+ teacache_state: Optional[TeaCache] = None,
+ teacache_treshold: float = 0.0,
+ ):
+ r"""
488
+ The call function to the pipeline for generation.
489
+
490
+ Args:
491
+ prompt (`str` or `List[str]`, *optional*):
492
+ The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead.
493
+ negative_prompt (`str` or `List[str]`, *optional*):
494
+ The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds`
495
+ instead. Ignored when not using guidance (`guidance_scale` < `1`).
496
+ height (`int`, defaults to `480`):
497
+ The height in pixels of the generated image.
498
+ width (`int`, defaults to `832`):
499
+ The width in pixels of the generated image.
500
+ num_frames (`int`, defaults to `81`):
501
+ The number of frames in the generated video.
502
+ num_inference_steps (`int`, defaults to `50`):
503
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
504
+ expense of slower inference.
505
+ guidance_scale (`float`, defaults to `5.0`):
506
+ Guidance scale as defined in [Classifier-Free Diffusion
507
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
508
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
509
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
510
+ the text `prompt`, usually at the expense of lower image quality.
511
+ guidance_scale_2 (`float`, *optional*, defaults to `None`):
512
+ Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
513
+ `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
514
+ and the pipeline's `boundary_ratio` are not None.
515
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
516
+ The number of images to generate per prompt.
517
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
518
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
519
+ generation deterministic.
520
+ latents (`torch.Tensor`, *optional*):
521
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
522
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
523
+ tensor is generated by sampling using the supplied random `generator`.
524
+ prompt_embeds (`torch.Tensor`, *optional*):
525
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
526
+ provided, text embeddings are generated from the `prompt` input argument.
527
+ output_type (`str`, *optional*, defaults to `"np"`):
528
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
529
+ return_dict (`bool`, *optional*, defaults to `True`):
530
+ Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
531
+ attention_kwargs (`dict`, *optional*):
532
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
533
+ `self.processor` in
534
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
535
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
536
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
537
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
538
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
539
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
540
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
541
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
542
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
543
+ `._callback_tensor_inputs` attribute of your pipeline class.
544
+ max_sequence_length (`int`, defaults to `512`):
545
+ The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
546
+ truncated. If the prompt is shorter, it will be padded to this length.
547
+ controlnet_weight (`float`, defaults to `0.8`):
548
+ Wigight for controlnet modules.
549
+ controlnet_guidance_start (`float`, defaults to `0.0`):
550
+ When start do control.
551
+ controlnet_guidance_end (`float`, defaults to `0.8`):
552
+ When finish do control.
553
+ controlnet_stride (`int`, defaults to `3`):
554
+ Stride for controlnet blocks.
555
+ Examples:
556
+
557
+ Returns:
558
+ [`~WanPipelineOutput`] or `tuple`:
559
+ If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
560
+ the first element is a list with the generated images and the second element is a list of `bool`s
561
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
562
+ """
563
+ self.teacache = teacache_state or None
+ if (self.teacache is None) and (teacache_treshold > 0.0):
+ self.teacache = TeaCache(
+ num_inference_steps=num_inference_steps,
+ model_name="DEFAULT",
+ treshold=teacache_treshold
+ )
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ guidance_scale_2,
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+ guidance_scale_2 = guidance_scale
+
+ self._guidance_scale = guidance_scale
+ self._guidance_scale_2 = guidance_scale_2
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ mask = torch.ones(latents.shape, dtype=torch.float32, device=device)
+
+ # 6. Encode controlnet frames
+ if (controlnet_latents is None) and (controlnet_frames is not None):
+ duplicate_frames_count = num_frames - len(controlnet_frames)
+ print(f'Using controlnet frames: {len(controlnet_frames)}. Extended frames count: {duplicate_frames_count}')
+ if duplicate_frames_count > 0:
+ # Simply duplicate the first frame
+ # controlnet_frames = [controlnet_frames[0]] * duplicate_frames_count + controlnet_frames
+ # Or pad with reversed duplicate frames?
+ reversed_controlnet_frames = list(reversed(controlnet_frames))
+ controlnet_sum_frames = controlnet_frames + reversed_controlnet_frames
+ reversed_chunks_count = num_frames // len(controlnet_sum_frames)
+ controlnet_frames = [*controlnet_sum_frames]
+ for _ in range(reversed_chunks_count):
+ controlnet_frames += controlnet_sum_frames
+
+ # If the controlnet frame count is greater than the num_frames parameter
+ controlnet_frames = controlnet_frames[:num_frames]
+
+ controlnet_latents = prepare_controlnet_frames(
+ controlnet_frames,
+ height,
+ width,
+ dtype=self.controlnet.dtype,
+ device=self.controlnet.device
+ )
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ if self.config.boundary_ratio is not None:
+ boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+ else:
+ boundary_timestep = None
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
+ if boundary_timestep is None or t >= boundary_timestep:
+ # wan2.1 or high-noise stage in wan2.2
+ current_model = self.transformer
+ current_guidance_scale = guidance_scale
+ else:
+ # low-noise stage in wan2.2
+ current_model = self.transformer_2
+ current_guidance_scale = guidance_scale_2
+
+ latent_model_input = latents.to(transformer_dtype)
+ if self.config.expand_timesteps:
+ # seq_len: num_latent_frames * latent_height//2 * latent_width//2
+ temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
+ # batch_size, seq_len
+ timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+ else:
+ timestep = t.expand(latents.shape[0])
+
+ controlnet_states = None
+ current_sampling_percent = i / len(timesteps)
+ if (controlnet_latents is not None) and (controlnet_guidance_start <= current_sampling_percent < controlnet_guidance_end):
+ controlnet_states = self.controlnet(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ controlnet_states=controlnet_latents,
+ return_dict=False,
+ )[0]
+ if isinstance(controlnet_states, (tuple, list)):
+ controlnet_states = [x.to(dtype=self.transformer.dtype) for x in controlnet_states]
+ else:
+ controlnet_states = controlnet_states.to(dtype=self.transformer.dtype)
+
+ with current_model.cache_context("cond"):
+ noise_pred = current_model(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ controlnet_states=controlnet_states,
+ controlnet_weight=controlnet_weight,
+ controlnet_stride=controlnet_stride,
+ teacache=self.teacache,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ with current_model.cache_context("uncond"):
+ noise_uncond = current_model(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ controlnet_states=controlnet_states,
+ controlnet_weight=controlnet_weight,
+ controlnet_stride=controlnet_stride,
+ teacache=self.teacache,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+ self.teacache = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return WanPipelineOutput(frames=video)
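The classifier-free guidance combination used in the denoising loop above (`noise_uncond + guidance_scale * (noise_pred - noise_uncond)`) can be sketched in plain Python; the values below are hypothetical, chosen only to make the arithmetic visible:

```python
# Classifier-free guidance: blend the conditional and unconditional
# predictions, noise = uncond + w * (cond - uncond). With w = 1 this
# reduces to the conditional prediction; larger w pushes further
# toward the prompt-conditioned direction.
def apply_cfg(noise_cond, noise_uncond, guidance_scale):
    return [u + guidance_scale * (c - u) for c, u in zip(noise_cond, noise_uncond)]

# Hypothetical per-element predictions:
combined = apply_cfg([1.0, 2.0], [0.0, 0.0], guidance_scale=5.0)
# -> [5.0, 10.0]
```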
wan_teacache.py ADDED
@@ -0,0 +1,78 @@
+ import torch
+ import numpy as np
+
+
+ coefficients = {
+ "DEFAULT": [-1.12343328e+02, 1.50680483e+02, -5.15023303e+01, 6.24892431e+00, 6.85022158e-02],
+ }
+
+
+ class TeaCache:
+ def __init__(self, num_inference_steps, model_name, treshold=0.3, start_step_treshold=0.1, end_step_treshold=0.9):
+ self.input_bank = []
+ self.current_step = 0
+ self.accumulated_distance = 0.0
+ self.num_inference_steps = num_inference_steps * 2
+ self.start_step_teacache = int(num_inference_steps * start_step_treshold) * 2
+ self.end_step_teacache = int(num_inference_steps * end_step_treshold) * 2
+ self.treshold = treshold # [0.3, 0.5, 0.7, 0.9]
+ self.coefficients = coefficients[model_name]
+ self.step_name = "even"
+ self.init_memory()
+
+ def init_memory(self):
+ self.accumulated_distance = {
+ "even": 0.0,
+ "odd": 0.0,
+ }
+ self.flow_direction = {
+ "even": None,
+ "odd": None,
+ }
+ self.previous_modulated_input = {
+ "even": None,
+ "odd": None,
+ }
+ # print("TEACACHE MEMORY HAS BEEN CREATED")
+
+ def check_for_using_cached_value(self, modulated_input):
+ use_tea_cache = (self.treshold > 0.0) and (self.start_step_teacache <= self.current_step < self.end_step_teacache)
+ self.step_name = "even" if self.current_step % 2 == 0 else "odd"
+
+ use_cached_value = False
+ if use_tea_cache:
+ rescale_func = np.poly1d(self.coefficients)
+ current_distance = rescale_func(
+ self.calculate_distance(modulated_input, self.previous_modulated_input[self.step_name])
+ )
+ self.accumulated_distance[self.step_name] += current_distance
+
+ if self.accumulated_distance[self.step_name] < self.treshold:
+ use_cached_value = True
+ else:
+ use_cached_value = False
+ self.accumulated_distance[self.step_name] = 0.0
+
+ if self.step_name == "even":
+ self.input_bank.append(modulated_input.cpu())
+
+ self.previous_modulated_input[self.step_name] = modulated_input.clone()
+ # if use_tea_cache:
+ # print(f"[ STEP:{self.current_step} | USE CACHED VALUE: {use_cached_value} | ACCUMULATED DISTANCE: {self.accumulated_distance} | CURRENT DISTANCE: {current_distance} ]")
+ return use_cached_value
+
+ def use_cache(self, hidden_states):
+ return hidden_states + self.flow_direction[self.step_name].to(device=hidden_states.device)
+
+ def calculate_distance(self, previous_tensor, current_tensor):
+ relative_l1_distance = torch.abs(
+ previous_tensor - current_tensor
+ ).mean() / torch.abs(previous_tensor).mean()
+ return relative_l1_distance.to(torch.float32).cpu().item()
+
+ def update(self, flow_direction):
+ self.flow_direction[self.step_name] = flow_direction
+ self.current_step += 1
+ if self.current_step == self.num_inference_steps:
+ self.current_step = 0
+ self.init_memory()
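The relative L1 metric that `TeaCache.calculate_distance` accumulates can be illustrated without torch; this is a minimal plain-Python sketch of the same formula, mean(|prev - curr|) / mean(|prev|), on hypothetical lists of floats:

```python
# Relative L1 distance between two equal-length sequences:
# mean absolute difference, normalized by the mean magnitude of
# the previous sequence. Small values mean the modulated input
# barely changed, so the cached residual can be reused.
def relative_l1_distance(previous, current):
    n = len(previous)
    mean_abs_diff = sum(abs(p - c) for p, c in zip(previous, current)) / n
    mean_abs_prev = sum(abs(p) for p in previous) / n
    return mean_abs_diff / mean_abs_prev

d = relative_l1_distance([2.0, 2.0], [1.0, 1.0])
# -> 0.5 (average change of 1.0 against an average magnitude of 2.0)
```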
wan_transformer.py ADDED
@@ -0,0 +1,135 @@
+ from typing import Any, Dict, Optional, Union
+
+ import torch
+ from diffusers import WanTransformer3DModel
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
+ from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+ from wan_teacache import TeaCache
+
+
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+ class CustomWanTransformer3DModel(WanTransformer3DModel):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+
+ controlnet_states: torch.Tensor = None,
+ controlnet_weight: Optional[float] = 1.0,
+ controlnet_stride: Optional[int] = 1,
+ teacache: Optional[TeaCache] = None,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.config.patch_size
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p_h
+ post_patch_width = width // p_w
+
+ rotary_emb = self.rope(hidden_states)
+
+ hidden_states = self.patch_embedding(hidden_states)
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+ # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
+ if timestep.ndim == 2:
+ ts_seq_len = timestep.shape[1]
+ timestep = timestep.flatten() # batch_size * seq_len
+ else:
+ ts_seq_len = None
+
+ temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+ timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
+ )
+ if ts_seq_len is not None:
+ # batch_size, seq_len, 6, inner_dim
+ timestep_proj = timestep_proj.unflatten(2, (6, -1))
+ else:
+ # batch_size, 6, inner_dim
+ timestep_proj = timestep_proj.unflatten(1, (6, -1))
+
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
+
+ use_cached_value = False
+ original_hidden_states = None
+ if (teacache is not None) and (teacache.treshold > 0.0):
+ original_hidden_states = hidden_states.clone()
+ use_cached_value = teacache.check_for_using_cached_value(temb)
+
+ if use_cached_value:
+ hidden_states = teacache.use_cache(hidden_states)
+ else:
+ # 4. Transformer blocks
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for i, block in enumerate(self.blocks):
+ hidden_states = self._gradient_checkpointing_func(
+ block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
+ )
+
+ if (controlnet_states is not None) and (i % controlnet_stride == 0) and (i // controlnet_stride < len(controlnet_states)):
+ hidden_states = hidden_states + controlnet_states[i // controlnet_stride] * controlnet_weight
+ else:
+ for i, block in enumerate(self.blocks):
+ hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+
+ if (controlnet_states is not None) and (i % controlnet_stride == 0) and (i // controlnet_stride < len(controlnet_states)):
+ hidden_states = hidden_states + controlnet_states[i // controlnet_stride] * controlnet_weight
+
+ if (teacache is not None) and (teacache.treshold > 0.0):
+ teacache.update(hidden_states - original_hidden_states)
+
+ # 5. Output norm, projection & unpatchify
+ if temb.ndim == 3:
+ # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
+ shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
+ shift = shift.squeeze(2)
+ scale = scale.squeeze(2)
+ else:
+ # batch_size, inner_dim
+ shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+
+ # Move the shift and scale tensors to the same device as hidden_states.
+ # When using multi-GPU inference via accelerate these will be on the
+ # first device rather than the last device, which hidden_states ends up
+ # on.
+ shift = shift.to(hidden_states.device)
+ scale = scale.to(hidden_states.device)
+
+ hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
+ hidden_states = self.proj_out(hidden_states)
+
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
+ )
+ hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+ output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
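The stride-based injection condition in the block loop above (`i % controlnet_stride == 0` and `i // controlnet_stride < len(controlnet_states)`) maps ControlNet residuals onto transformer blocks. A minimal sketch of which (block, residual) pairs that condition selects, with hypothetical block and residual counts:

```python
# For num_blocks transformer blocks and num_residuals ControlNet outputs,
# block i receives residual i // stride when i is a multiple of stride
# and the residual index is still in range -- the same condition as in
# CustomWanTransformer3DModel.forward.
def injected_blocks(num_blocks, num_residuals, stride):
    return [
        (i, i // stride)
        for i in range(num_blocks)
        if i % stride == 0 and i // stride < num_residuals
    ]

pairs = injected_blocks(num_blocks=12, num_residuals=3, stride=3)
# -> [(0, 0), (3, 1), (6, 2)]: block 9 gets nothing because only 3 residuals exist
```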
workflows/sam2.1_optimized.json ADDED
The diff for this file is too large to render. See raw diff
 
workflows/sam_optimized.json ADDED
The diff for this file is too large to render. See raw diff
 
workflows/vace_optimized.json ADDED
@@ -0,0 +1,1043 @@
+ {
+ "id": "960108a5-bf9d-497f-a6e5-4c5c3e41c056",
+ "revision": 0,
+ "last_node_id": 37,
+ "last_link_id": 93,
+ "nodes": [
+ {
+ "id": 11,
+ "type": "ModelSamplingSD3",
+ "pos": [
+ 442.7779541015625,
+ 942.9921264648438
+ ],
+ "size": [
+ 210,
+ 58
+ ],
+ "flags": {
+ "collapsed": false
+ },
+ "order": 9,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "model",
+ "type": "MODEL",
+ "link": 91
+ }
+ ],
+ "outputs": [
+ {
+ "name": "MODEL",
+ "type": "MODEL",
+ "links": [
+ 58
+ ]
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "ModelSamplingSD3"
+ },
+ "widgets_values": [
+ 2.0000000000000004
+ ]
+ },
+ {
+ "id": 32,
+ "type": "VHS_LoadVideo",
+ "pos": [
+ 120.05851745605469,
+ 397.98248291015625
+ ],
+ "size": [
+ 253.279296875,
+ 310
+ ],
+ "flags": {},
+ "order": 6,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "meta_batch",
+ "shape": 7,
+ "type": "VHS_BatchManager",
+ "link": null
+ },
+ {
+ "name": "vae",
+ "shape": 7,
+ "type": "VAE",
+ "link": null
+ },
+ {
+ "name": "frame_load_cap",
+ "type": "INT",
+ "widget": {
+ "name": "frame_load_cap"
+ },
+ "link": 76
+ }
+ ],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 86
+ ]
+ },
+ {
+ "name": "frame_count",
+ "type": "INT",
+ "links": [
+ 78
+ ]
+ },
+ {
+ "name": "audio",
+ "type": "AUDIO",
+ "links": null
+ },
+ {
+ "name": "video_info",
+ "type": "VHS_VIDEOINFO",
+ "links": null
+ }
+ ],
+ "title": "Upload the mask-composited video",
+ "properties": {
+ "Node name for S&R": "VHS_LoadVideo"
+ },
+ "widgets_values": {
+ "video": "sam2.1_00182.mp4",
+ "force_rate": 16,
+ "custom_width": 0,
+ "custom_height": 0,
+ "frame_load_cap": 0,
+ "skip_first_frames": 0,
+ "select_every_nth": 1,
+ "format": "Wan",
+ "choose video to upload": "image",
+ "videopreview": {
+ "hidden": false,
+ "paused": false,
+ "params": {
+ "filename": "sam2.1_00182.mp4",
+ "type": "input",
+ "format": "video/mp4",
+ "force_rate": 16,
+ "custom_width": 0,
+ "custom_height": 0,
+ "frame_load_cap": 0,
+ "skip_first_frames": 0,
+ "select_every_nth": 1
+ }
+ }
+ }
+ },
+ {
+ "id": 33,
+ "type": "VHS_LoadVideo",
+ "pos": [
+ 112.58995056152344,
+ 753.9783325195312
+ ],
+ "size": [
+ 253.279296875,
+ 310
+ ],
+ "flags": {},
+ "order": 0,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "meta_batch",
+ "shape": 7,
+ "type": "VHS_BatchManager",
+ "link": null
+ },
+ {
+ "name": "vae",
+ "shape": 7,
+ "type": "VAE",
+ "link": null
+ }
+ ],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 85
+ ]
+ },
+ {
+ "name": "frame_count",
+ "type": "INT",
+ "links": [
+ 76
+ ]
+ },
+ {
+ "name": "audio",
+ "type": "AUDIO",
+ "links": null
+ },
+ {
+ "name": "video_info",
+ "type": "VHS_VIDEOINFO",
+ "links": null
+ }
+ ],
+ "title": "Upload the mask video (the black-and-white one)",
+ "properties": {
+ "Node name for S&R": "VHS_LoadVideo"
196
+ },
197
+ "widgets_values": {
198
+ "video": "sam2.1_00181.mp4",
199
+ "force_rate": 0,
200
+ "custom_width": 0,
201
+ "custom_height": 0,
202
+ "frame_load_cap": 0,
203
+ "skip_first_frames": 0,
204
+ "select_every_nth": 1,
205
+ "format": "Wan",
206
+ "choose video to upload": "image",
207
+ "videopreview": {
208
+ "hidden": false,
209
+ "paused": false,
210
+ "params": {
211
+ "filename": "sam2.1_00181.mp4",
212
+ "type": "input",
213
+ "format": "video/mp4",
214
+ "force_rate": 0,
215
+ "custom_width": 0,
216
+ "custom_height": 0,
217
+ "frame_load_cap": 0,
218
+ "skip_first_frames": 0,
219
+ "select_every_nth": 1
220
+ }
221
+ }
222
+ }
223
+ },
224
+ {
225
+ "id": 35,
226
+ "type": "GrowMask",
227
+ "pos": [
228
+ 722.2931518554688,
229
+ 1093.416015625
230
+ ],
231
+ "size": [
232
+ 270,
233
+ 82
234
+ ],
235
+ "flags": {},
236
+ "order": 10,
237
+ "mode": 0,
238
+ "inputs": [
239
+ {
240
+ "name": "mask",
241
+ "type": "MASK",
242
+ "link": 79
243
+ }
244
+ ],
245
+ "outputs": [
246
+ {
247
+ "name": "MASK",
248
+ "type": "MASK",
249
+ "links": [
250
+ 80
251
+ ]
252
+ }
253
+ ],
254
+ "properties": {
255
+ "Node name for S&R": "GrowMask"
256
+ },
257
+ "widgets_values": [
258
+ 5,
259
+ true
260
+ ]
261
+ },
262
+ {
263
+ "id": 6,
264
+ "type": "CLIPLoader",
265
+ "pos": [
266
+ 111.71733093261719,
267
+ 1112.0469970703125
268
+ ],
269
+ "size": [
270
+ 210,
271
+ 106
272
+ ],
273
+ "flags": {},
274
+ "order": 1,
275
+ "mode": 0,
276
+ "inputs": [],
277
+ "outputs": [
278
+ {
279
+ "name": "CLIP",
280
+ "type": "CLIP",
281
+ "slot_index": 0,
282
+ "links": [
283
+ 92,
284
+ 93
285
+ ]
286
+ }
287
+ ],
288
+ "properties": {
289
+ "Node name for S&R": "CLIPLoader"
290
+ },
291
+ "widgets_values": [
292
+ "umt5_xxl_fp8_e4m3fn_scaled.safetensors",
293
+ "wan",
294
+ "cpu"
295
+ ]
296
+ },
297
+ {
298
+ "id": 8,
299
+ "type": "UNETLoader",
300
+ "pos": [
301
+ 153.8439178466797,
302
+ 269.8687438964844
303
+ ],
304
+ "size": [
305
+ 210,
306
+ 82
307
+ ],
308
+ "flags": {},
309
+ "order": 2,
310
+ "mode": 0,
311
+ "inputs": [],
312
+ "outputs": [
313
+ {
314
+ "name": "MODEL",
315
+ "type": "MODEL",
316
+ "slot_index": 0,
317
+ "links": [
318
+ 91
319
+ ]
320
+ }
321
+ ],
322
+ "properties": {
323
+ "Node name for S&R": "UNETLoader"
324
+ },
325
+ "widgets_values": [
326
+ "Wan2.1_T2V_14B_FusionX_VACE-FP8_e4m3fn.safetensors",
327
+ "default"
328
+ ]
329
+ },
330
+ {
331
+ "id": 15,
332
+ "type": "CLIPTextEncode",
333
+ "pos": [
334
+ 451.8983154296875,
335
+ 311.4078674316406
336
+ ],
337
+ "size": [
338
+ 494.83367919921875,
339
+ 119.57742309570312
340
+ ],
341
+ "flags": {},
342
+ "order": 8,
343
+ "mode": 0,
344
+ "inputs": [
345
+ {
346
+ "name": "clip",
347
+ "type": "CLIP",
348
+ "link": 93
349
+ }
350
+ ],
351
+ "outputs": [
352
+ {
353
+ "name": "CONDITIONING",
354
+ "type": "CONDITIONING",
355
+ "slot_index": 0,
356
+ "links": [
357
+ 18
358
+ ]
359
+ }
360
+ ],
361
+ "title": "CLIP Text Encode (Positive Prompt)",
362
+ "properties": {
363
+ "Node name for S&R": "CLIPTextEncode"
364
+ },
365
+ "widgets_values": [
366
+ "赤身裸体的亚洲女子,主观视角,她抖动着胸部,乳房剧烈晃动。bustygrid. a completely naked asian woman with pale skin and huge breasts. she has straight dark hair. she is completely naked, and wearing pointed stilleto heels. bare legs, bare calf, sky-high heeled pumps. remove all clothes. "
367
+ ],
368
+ "color": "#232",
369
+ "bgcolor": "#353"
370
+ },
371
+ {
372
+ "id": 2,
373
+ "type": "CLIPTextEncode",
374
+ "pos": [
375
+ 453.97589111328125,
376
+ 487.16363525390625
377
+ ],
378
+ "size": [
379
+ 486.9105529785156,
380
+ 107.89899444580078
381
+ ],
382
+ "flags": {
383
+ "collapsed": false
384
+ },
385
+ "order": 7,
386
+ "mode": 0,
387
+ "inputs": [
388
+ {
389
+ "name": "clip",
390
+ "type": "CLIP",
391
+ "link": 92
392
+ }
393
+ ],
394
+ "outputs": [
395
+ {
396
+ "name": "CONDITIONING",
397
+ "type": "CONDITIONING",
398
+ "slot_index": 0,
399
+ "links": [
400
+ 19
401
+ ]
402
+ }
403
+ ],
404
+ "title": "CLIP Text Encode (Negative Prompt)",
405
+ "properties": {
406
+ "Node name for S&R": "CLIPTextEncode"
407
+ },
408
+ "widgets_values": [
409
+ "白种人,黑种人,阴部遮挡,内裤,六根手指,低像素,模糊,像素点,多余的手臂,肢体扭曲,手指模糊,脸部改变,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
410
+ ],
411
+ "color": "#322",
412
+ "bgcolor": "#533"
413
+ },
414
+ {
415
+ "id": 3,
416
+ "type": "VAELoader",
417
+ "pos": [
418
+ 433.6892395019531,
419
+ 643.1557006835938
420
+ ],
421
+ "size": [
422
+ 210,
423
+ 58
424
+ ],
425
+ "flags": {
426
+ "collapsed": false
427
+ },
428
+ "order": 3,
429
+ "mode": 0,
430
+ "inputs": [],
431
+ "outputs": [
432
+ {
433
+ "name": "VAE",
434
+ "type": "VAE",
435
+ "links": [
436
+ 16,
437
+ 20
438
+ ]
439
+ }
440
+ ],
441
+ "properties": {
442
+ "Node name for S&R": "VAELoader"
443
+ },
444
+ "widgets_values": [
445
+ "Wan2.1_VAE.safetensors"
446
+ ]
447
+ },
448
+ {
449
+ "id": 17,
450
+ "type": "WanVaceToVideo",
451
+ "pos": [
452
+ 706.262939453125,
453
+ 658.4074096679688
454
+ ],
455
+ "size": [
456
+ 224.32986450195312,
457
+ 254
458
+ ],
459
+ "flags": {},
460
+ "order": 11,
461
+ "mode": 0,
462
+ "inputs": [
463
+ {
464
+ "name": "positive",
465
+ "type": "CONDITIONING",
466
+ "link": 18
467
+ },
468
+ {
469
+ "name": "negative",
470
+ "type": "CONDITIONING",
471
+ "link": 19
472
+ },
473
+ {
474
+ "name": "vae",
475
+ "type": "VAE",
476
+ "link": 20
477
+ },
478
+ {
479
+ "name": "control_video",
480
+ "shape": 7,
481
+ "type": "IMAGE",
482
+ "link": 86
483
+ },
484
+ {
485
+ "name": "control_masks",
486
+ "shape": 7,
487
+ "type": "MASK",
488
+ "link": 80
489
+ },
490
+ {
491
+ "name": "reference_image",
492
+ "shape": 7,
493
+ "type": "IMAGE",
494
+ "link": 22
495
+ },
496
+ {
497
+ "name": "length",
498
+ "type": "INT",
499
+ "widget": {
500
+ "name": "length"
501
+ },
502
+ "link": 78
503
+ }
504
+ ],
505
+ "outputs": [
506
+ {
507
+ "name": "positive",
508
+ "type": "CONDITIONING",
509
+ "links": [
510
+ 12
511
+ ]
512
+ },
513
+ {
514
+ "name": "negative",
515
+ "type": "CONDITIONING",
516
+ "links": [
517
+ 13
518
+ ]
519
+ },
520
+ {
521
+ "name": "latent",
522
+ "type": "LATENT",
523
+ "links": [
524
+ 14
525
+ ]
526
+ },
527
+ {
528
+ "name": "trim_latent",
529
+ "type": "INT",
530
+ "links": [
531
+ 10
532
+ ]
533
+ }
534
+ ],
535
+ "properties": {
536
+ "Node name for S&R": "WanVaceToVideo"
537
+ },
538
+ "widgets_values": [
539
+ 480,
540
+ 320,
541
+ 49,
542
+ 1,
543
+ 1.0000000000000002
544
+ ]
545
+ },
546
+ {
547
+ "id": 12,
548
+ "type": "TrimVideoLatent",
549
+ "pos": [
550
+ 746.625,
551
+ 985.3895874023438
552
+ ],
553
+ "size": [
554
+ 226.2460174560547,
555
+ 58
556
+ ],
557
+ "flags": {
558
+ "collapsed": false
559
+ },
560
+ "order": 13,
561
+ "mode": 0,
562
+ "inputs": [
563
+ {
564
+ "name": "samples",
565
+ "type": "LATENT",
566
+ "link": 9
567
+ },
568
+ {
569
+ "name": "trim_amount",
570
+ "type": "INT",
571
+ "widget": {
572
+ "name": "trim_amount"
573
+ },
574
+ "link": 10
575
+ }
576
+ ],
577
+ "outputs": [
578
+ {
579
+ "name": "LATENT",
580
+ "type": "LATENT",
581
+ "links": [
582
+ 15
583
+ ]
584
+ }
585
+ ],
586
+ "properties": {
587
+ "Node name for S&R": "TrimVideoLatent"
588
+ },
589
+ "widgets_values": [
590
+ 0
591
+ ]
592
+ },
593
+ {
594
+ "id": 13,
595
+ "type": "KSampler",
596
+ "pos": [
597
+ 985.894775390625,
598
+ 349.17340087890625
599
+ ],
600
+ "size": [
601
+ 210,
602
+ 605.3333129882812
603
+ ],
604
+ "flags": {},
605
+ "order": 12,
606
+ "mode": 0,
607
+ "inputs": [
608
+ {
609
+ "name": "model",
610
+ "type": "MODEL",
611
+ "link": 58
612
+ },
613
+ {
614
+ "name": "positive",
615
+ "type": "CONDITIONING",
616
+ "link": 12
617
+ },
618
+ {
619
+ "name": "negative",
620
+ "type": "CONDITIONING",
621
+ "link": 13
622
+ },
623
+ {
624
+ "name": "latent_image",
625
+ "type": "LATENT",
626
+ "link": 14
627
+ }
628
+ ],
629
+ "outputs": [
630
+ {
631
+ "name": "LATENT",
632
+ "type": "LATENT",
633
+ "slot_index": 0,
634
+ "links": [
635
+ 9
636
+ ]
637
+ }
638
+ ],
639
+ "properties": {
640
+ "Node name for S&R": "KSampler"
641
+ },
642
+ "widgets_values": [
643
+ 109768395777514,
644
+ "randomize",
645
+ 10,
646
+ 1,
647
+ "uni_pc_bh2",
648
+ "simple",
649
+ 1
650
+ ]
651
+ },
652
+ {
653
+ "id": 14,
654
+ "type": "VAEDecode",
655
+ "pos": [
656
+ 973.5802612304688,
657
+ 1001.729736328125
658
+ ],
659
+ "size": [
660
+ 208.16270446777344,
661
+ 46
662
+ ],
663
+ "flags": {
664
+ "collapsed": false
665
+ },
666
+ "order": 14,
667
+ "mode": 0,
668
+ "inputs": [
669
+ {
670
+ "name": "samples",
671
+ "type": "LATENT",
672
+ "link": 15
673
+ },
674
+ {
675
+ "name": "vae",
676
+ "type": "VAE",
677
+ "link": 16
678
+ }
679
+ ],
680
+ "outputs": [
681
+ {
682
+ "name": "IMAGE",
683
+ "type": "IMAGE",
684
+ "slot_index": 0,
685
+ "links": [
686
+ 3
687
+ ]
688
+ }
689
+ ],
690
+ "properties": {
691
+ "Node name for S&R": "VAEDecode"
692
+ },
693
+ "widgets_values": []
694
+ },
695
+ {
696
+ "id": 4,
697
+ "type": "VHS_VideoCombine",
698
+ "pos": [
699
+ 1219.9688720703125,
700
+ 358.5111389160156
701
+ ],
702
+ "size": [
703
+ 239.620361328125,
704
+ 310
705
+ ],
706
+ "flags": {},
707
+ "order": 15,
708
+ "mode": 0,
709
+ "inputs": [
710
+ {
711
+ "name": "images",
712
+ "type": "IMAGE",
713
+ "link": 3
714
+ },
715
+ {
716
+ "name": "audio",
717
+ "shape": 7,
718
+ "type": "AUDIO",
719
+ "link": null
720
+ },
721
+ {
722
+ "name": "meta_batch",
723
+ "shape": 7,
724
+ "type": "VHS_BatchManager",
725
+ "link": null
726
+ },
727
+ {
728
+ "name": "vae",
729
+ "shape": 7,
730
+ "type": "VAE",
731
+ "link": null
732
+ }
733
+ ],
734
+ "outputs": [
735
+ {
736
+ "name": "Filenames",
737
+ "type": "VHS_FILENAMES",
738
+ "links": null
739
+ }
740
+ ],
741
+ "properties": {
742
+ "Node name for S&R": "VHS_VideoCombine"
743
+ },
744
+ "widgets_values": {
745
+ "frame_rate": 16,
746
+ "loop_count": 0,
747
+ "filename_prefix": "wan2.1",
748
+ "format": "video/h265-mp4",
749
+ "pix_fmt": "yuv420p10le",
750
+ "crf": 5,
751
+ "save_metadata": true,
752
+ "pingpong": false,
753
+ "save_output": true,
754
+ "videopreview": {
755
+ "hidden": false,
756
+ "paused": false,
757
+ "params": {
758
+ "filename": "wan2.1_00518.mp4",
759
+ "subfolder": "",
760
+ "type": "output",
761
+ "format": "video/h265-mp4",
762
+ "frame_rate": 16,
763
+ "workflow": "wan2.1_00518.png",
764
+ "fullpath": "E:\\comfyui3\\ComfyUI\\output\\wan2.1_00518.mp4"
765
+ }
766
+ }
767
+ }
768
+ },
769
+ {
770
+ "id": 25,
771
+ "type": "ImageToMask",
772
+ "pos": [
773
+ 403.78155517578125,
774
+ 1100.6531982421875
775
+ ],
776
+ "size": [
777
+ 270,
778
+ 58
779
+ ],
780
+ "flags": {},
781
+ "order": 5,
782
+ "mode": 0,
783
+ "inputs": [
784
+ {
785
+ "name": "image",
786
+ "type": "IMAGE",
787
+ "link": 85
788
+ }
789
+ ],
790
+ "outputs": [
791
+ {
792
+ "name": "MASK",
793
+ "type": "MASK",
794
+ "links": [
795
+ 79
796
+ ]
797
+ }
798
+ ],
799
+ "properties": {
800
+ "Node name for S&R": "ImageToMask"
801
+ },
802
+ "widgets_values": [
803
+ "red"
804
+ ]
805
+ },
806
+ {
807
+ "id": 5,
808
+ "type": "LoadImage",
809
+ "pos": [
810
+ -272.46954345703125,
811
+ 357.37689208984375
812
+ ],
813
+ "size": [
814
+ 335.15673828125,
815
+ 709.6021728515625
816
+ ],
817
+ "flags": {},
818
+ "order": 4,
819
+ "mode": 0,
820
+ "inputs": [],
821
+ "outputs": [
822
+ {
823
+ "name": "IMAGE",
824
+ "type": "IMAGE",
825
+ "links": [
826
+ 22
827
+ ]
828
+ },
829
+ {
830
+ "name": "MASK",
831
+ "type": "MASK",
832
+ "links": null
833
+ }
834
+ ],
835
+ "properties": {
836
+ "Node name for S&R": "LoadImage"
837
+ },
838
+ "widgets_values": [
839
+ "ComfUI_287879_.png",
840
+ "image"
841
+ ]
842
+ }
843
+ ],
844
+ "links": [
845
+ [
846
+ 3,
847
+ 14,
848
+ 0,
849
+ 4,
850
+ 0,
851
+ "IMAGE"
852
+ ],
853
+ [
854
+ 9,
855
+ 13,
856
+ 0,
857
+ 12,
858
+ 0,
859
+ "LATENT"
860
+ ],
861
+ [
862
+ 10,
863
+ 17,
864
+ 3,
865
+ 12,
866
+ 1,
867
+ "INT"
868
+ ],
869
+ [
870
+ 12,
871
+ 17,
872
+ 0,
873
+ 13,
874
+ 1,
875
+ "CONDITIONING"
876
+ ],
877
+ [
878
+ 13,
879
+ 17,
880
+ 1,
881
+ 13,
882
+ 2,
883
+ "CONDITIONING"
884
+ ],
885
+ [
886
+ 14,
887
+ 17,
888
+ 2,
889
+ 13,
890
+ 3,
891
+ "LATENT"
892
+ ],
893
+ [
894
+ 15,
895
+ 12,
896
+ 0,
897
+ 14,
898
+ 0,
899
+ "LATENT"
900
+ ],
901
+ [
902
+ 16,
903
+ 3,
904
+ 0,
905
+ 14,
906
+ 1,
907
+ "VAE"
908
+ ],
909
+ [
910
+ 18,
911
+ 15,
912
+ 0,
913
+ 17,
914
+ 0,
915
+ "CONDITIONING"
916
+ ],
917
+ [
918
+ 19,
919
+ 2,
920
+ 0,
921
+ 17,
922
+ 1,
923
+ "CONDITIONING"
924
+ ],
925
+ [
926
+ 20,
927
+ 3,
928
+ 0,
929
+ 17,
930
+ 2,
931
+ "VAE"
932
+ ],
933
+ [
934
+ 22,
935
+ 5,
936
+ 0,
937
+ 17,
938
+ 5,
939
+ "IMAGE"
940
+ ],
941
+ [
942
+ 58,
943
+ 11,
944
+ 0,
945
+ 13,
946
+ 0,
947
+ "MODEL"
948
+ ],
949
+ [
950
+ 76,
951
+ 33,
952
+ 1,
953
+ 32,
954
+ 2,
955
+ "INT"
956
+ ],
957
+ [
958
+ 78,
959
+ 32,
960
+ 1,
961
+ 17,
962
+ 6,
963
+ "INT"
964
+ ],
965
+ [
966
+ 79,
967
+ 25,
968
+ 0,
969
+ 35,
970
+ 0,
971
+ "MASK"
972
+ ],
973
+ [
974
+ 80,
975
+ 35,
976
+ 0,
977
+ 17,
978
+ 4,
979
+ "MASK"
980
+ ],
981
+ [
982
+ 85,
983
+ 33,
984
+ 0,
985
+ 25,
986
+ 0,
987
+ "IMAGE"
988
+ ],
989
+ [
990
+ 86,
991
+ 32,
992
+ 0,
993
+ 17,
994
+ 3,
995
+ "IMAGE"
996
+ ],
997
+ [
998
+ 91,
999
+ 8,
1000
+ 0,
1001
+ 11,
1002
+ 0,
1003
+ "MODEL"
1004
+ ],
1005
+ [
1006
+ 92,
1007
+ 6,
1008
+ 0,
1009
+ 2,
1010
+ 0,
1011
+ "CLIP"
1012
+ ],
1013
+ [
1014
+ 93,
1015
+ 6,
1016
+ 0,
1017
+ 15,
1018
+ 0,
1019
+ "CLIP"
1020
+ ]
1021
+ ],
1022
+ "groups": [],
1023
+ "config": {},
1024
+ "extra": {
1025
+ "ds": {
1026
+ "scale": 1.0152559799477145,
1027
+ "offset": [
1028
+ 564.1931902142793,
1029
+ -170.45932466624348
1030
+ ]
1031
+ },
1032
+ "frontendVersion": "1.25.11",
1033
+ "node_versions": {
1034
+ "comfy-core": "0.3.56",
1035
+ "ComfyUI-VideoHelperSuite": "972c87da577b47211c4e9aeed30dc38c7bae607f"
1036
+ },
1037
+ "VHS_latentpreview": true,
1038
+ "VHS_latentpreviewrate": 0,
1039
+ "VHS_MetadataImage": true,
1040
+ "VHS_KeepIntermediate": true
1041
+ },
1042
+ "version": 0.4
1043
+ }