import os
import sys
import asyncio
import random
import tempfile
from typing import Sequence, Mapping, Any, Union

import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
import spaces

# --- 1. Model Download and Setup ---
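# The Hub cache keeps a single copy of each file; symlinking into ComfyUI's models/ layout
# lets the stock loader nodes find the weights by filename without duplicating them on disk.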

def hf_hub_download_local(repo_id, filename, local_dir, **kwargs):
    """Downloads a file from Hugging Face Hub and symlinks it to a local directory."""
    downloaded_path = hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
    os.makedirs(local_dir, exist_ok=True)
    base_filename = os.path.basename(filename)
    target_path = os.path.join(local_dir, base_filename)
    
    # Remove any existing file or symlink first (os.path.exists follows symlinks, so islink is needed to catch broken links)
    if os.path.exists(target_path) or os.path.islink(target_path):
        os.remove(target_path)
    
    os.symlink(downloaded_path, target_path)
    return target_path

print("Downloading models from Hugging Face Hub...")
hf_hub_download_local(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors", local_dir="models/text_encoders")
hf_hub_download_local(repo_id="Comfy-Org/Wan_2.2_ComfyUI_Repackaged", filename="split_files/diffusion_models/wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", local_dir="models/unet")
hf_hub_download_local(repo_id="Comfy-Org/Wan_2.2_ComfyUI_Repackaged", filename="split_files/diffusion_models/wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors", local_dir="models/unet")
hf_hub_download_local(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/vae/wan_2.1_vae.safetensors", local_dir="models/vae")
hf_hub_download_local(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/clip_vision/clip_vision_h.safetensors", local_dir="models/clip_vision")
hf_hub_download_local(repo_id="Kijai/WanVideo_comfy", filename="Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors", local_dir="models/loras")
hf_hub_download_local(repo_id="Kijai/WanVideo_comfy", filename="Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors", local_dir="models/loras")
print("Downloads complete.")


# --- 2. ComfyUI Backend Initialization ---

def find_path(name: str, path: str = None) -> str:
    """Recursively finds a directory with a given name."""
    if path is None: path = os.getcwd()
    if name in os.listdir(path): return os.path.join(path, name)
    parent_directory = os.path.dirname(path)
    return find_path(name, parent_directory) if parent_directory != path else None

def add_comfyui_directory_to_sys_path() -> None:
    """Adds the ComfyUI directory to sys.path for imports."""
    comfyui_path = find_path("ComfyUI")
    if comfyui_path and os.path.isdir(comfyui_path):
        sys.path.append(comfyui_path)
        print(f"'{comfyui_path}' added to sys.path")

def add_extra_model_paths() -> None:
    """Initializes ComfyUI's folder_paths with custom paths."""
    from main import apply_custom_paths
    apply_custom_paths()

def import_custom_nodes() -> None:
    """Initializes all ComfyUI custom nodes."""
    import nodes
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(nodes.init_extra_nodes(init_custom_nodes=True))

print("Setting up ComfyUI paths and nodes...")
add_comfyui_directory_to_sys_path()
add_extra_model_paths()
import_custom_nodes()
print("ComfyUI setup complete.")


# --- 3. Global Model & Node Loading and Patching ---

from nodes import NODE_CLASS_MAPPINGS
import folder_paths
from comfy import model_management

# Set VRAM mode to HIGH to prevent models from being offloaded from GPU after use.
# model_management.vram_state = model_management.VRAMState.HIGH_VRAM

MODELS_AND_NODES = {}

def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Helper to safely access outputs from ComfyUI nodes, which are often tuples."""
    try:
        return obj[index]
    except (KeyError, TypeError):
        # Fallback for custom nodes that might return a dictionary with a 'result' key
        if isinstance(obj, Mapping) and "result" in obj:
            return obj["result"][index]
        raise

print("Loading models and instantiating nodes into memory. This may take a few minutes...")

# Instantiate Node Classes that will be used for loading and patching
cliploader = NODE_CLASS_MAPPINGS["CLIPLoader"]()
unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
modelsamplingsd3 = NODE_CLASS_MAPPINGS["ModelSamplingSD3"]()
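# Note: "PathchSageAttentionKJ" (sic) is the node's actual registration name in ComfyUI-KJNodes.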
pathchsageattentionkj = NODE_CLASS_MAPPINGS["PathchSageAttentionKJ"]()

# Load base models into CPU RAM initially
MODELS_AND_NODES["clip"] = cliploader.load_clip(clip_name="umt5_xxl_fp8_e4m3fn_scaled.safetensors", type="wan")
unet_low_noise = unetloader.load_unet(unet_name="wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", weight_dtype="default")
unet_high_noise = unetloader.load_unet(unet_name="wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors", weight_dtype="default")
MODELS_AND_NODES["vae"] = vaeloader.load_vae(vae_name="wan_2.1_vae.safetensors")
MODELS_AND_NODES["clip_vision"] = clipvisionloader.load_clip(clip_name="clip_vision_h.safetensors")

# Chain all patching operations together for the final models
print("Applying all patches to models...")

# --- Low Noise Model Chain ---
model_low_with_lora = loraloadermodelonly.load_lora_model_only(
    lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors",
    strength_model=0.8, model=get_value_at_index(unet_low_noise, 0))
model_low_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_low_with_lora, 0))
MODELS_AND_NODES["model_low_noise"] = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_low_patched, 0))

# --- High Noise Model Chain ---
model_high_with_lora = loraloadermodelonly.load_lora_model_only(
    lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors",
    strength_model=0.8, model=get_value_at_index(unet_high_noise, 0))
model_high_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_high_with_lora, 0))
MODELS_AND_NODES["model_high_noise"] = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_high_patched, 0))

# Instantiate all other node classes ONCE and store them
MODELS_AND_NODES["CLIPTextEncode"] = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
MODELS_AND_NODES["LoadImage"] = NODE_CLASS_MAPPINGS["LoadImage"]()
MODELS_AND_NODES["CLIPVisionEncode"] = NODE_CLASS_MAPPINGS["CLIPVisionEncode"]()
MODELS_AND_NODES["WanFirstLastFrameToVideo"] = NODE_CLASS_MAPPINGS["WanFirstLastFrameToVideo"]()
MODELS_AND_NODES["KSamplerAdvanced"] = NODE_CLASS_MAPPINGS["KSamplerAdvanced"]()
MODELS_AND_NODES["VAEDecode"] = NODE_CLASS_MAPPINGS["VAEDecode"]()
MODELS_AND_NODES["CreateVideo"] = NODE_CLASS_MAPPINGS["CreateVideo"]()
MODELS_AND_NODES["SaveVideo"] = NODE_CLASS_MAPPINGS["SaveVideo"]()

# Move all final, fully-patched models to the GPU
print("Moving final models to GPU...")
model_loaders_final = [
    MODELS_AND_NODES["clip"],
#    MODELS_AND_NODES["vae"],
    MODELS_AND_NODES["model_low_noise"],
    MODELS_AND_NODES["model_high_noise"],
    MODELS_AND_NODES["clip_vision"],
]
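# CLIP and CLIP Vision loaders return wrapper objects that expose the underlying ModelPatcher
# via .patcher, while UNETLoader returns a ModelPatcher directly; normalize before loading.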
model_management.load_models_gpu([
    loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0] for loader in model_loaders_final
], force_patch_weights=True) # force_patch_weights permanently merges the LoRA

print("All models loaded, patched, and on GPU. Gradio app is ready.")


# --- 4. Application Logic and Gradio Interface ---

def calculate_video_dimensions(width, height, max_size=832, min_size=480):
    """Calculates video dimensions, ensuring they are multiples of 16."""
    if width == height:
        return min_size, min_size
    aspect_ratio = width / height
    if width > height:
        video_width = max_size
        video_height = int(max_size / aspect_ratio)
    else:
        video_height = max_size
        video_width = int(max_size * aspect_ratio)
    video_width = max(16, round(video_width / 16) * 16)
    video_height = max(16, round(video_height / 16) * 16)
    return video_width, video_height

def resize_and_crop_to_match(target_image, reference_image):
    """Resizes and center-crops the target image to match the reference image's dimensions."""
    ref_width, ref_height = reference_image.size
    target_width, target_height = target_image.size
    scale = max(ref_width / target_width, ref_height / target_height)
    new_width, new_height = int(target_width * scale), int(target_height * scale)
    resized = target_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
    left, top = (new_width - ref_width) // 2, (new_height - ref_height) // 2
    return resized.crop((left, top, left + ref_width, top + ref_height))

@spaces.GPU(duration=120)
def generate_video(
    start_image_pil,
    end_image_pil,
    prompt,
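    # Default Chinese negative prompt commonly used with Wan models; roughly: "garish colors,
    # overexposed, static, blurry details, subtitles, style, artwork, painting, still frame,
    # overall gray, worst quality, low quality, JPEG artifacts, ugly, incomplete, extra fingers,
    # poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fused fingers,
    # motionless frame, cluttered background, three legs, crowded background, walking backwards,
    # overexposed".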
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,",
    duration=33,
    progress=gr.Progress(track_tqdm=True)
):
    """
    Generates a video by interpolating between a start and end image, guided by a text prompt.
    This function relies on globally pre-loaded models and pre-instantiated ComfyUI nodes.
    """
    FPS = 16
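    # At 16 FPS the default duration of 33 frames is roughly a 2-second clip (66 frames ~ 4s).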

    # --- 1. Retrieve Pre-loaded and Pre-patched Models & Node Instances ---
    # These are not re-instantiated; we are just getting references to the global objects.
    clip = MODELS_AND_NODES["clip"]
    vae = MODELS_AND_NODES["vae"]
    model_low_final = MODELS_AND_NODES["model_low_noise"]
    model_high_final = MODELS_AND_NODES["model_high_noise"]
    clip_vision = MODELS_AND_NODES["clip_vision"]
    
    cliptextencode = MODELS_AND_NODES["CLIPTextEncode"]
    loadimage = MODELS_AND_NODES["LoadImage"]
    clipvisionencode = MODELS_AND_NODES["CLIPVisionEncode"]
    wanfirstlastframetovideo = MODELS_AND_NODES["WanFirstLastFrameToVideo"]
    ksampleradvanced = MODELS_AND_NODES["KSamplerAdvanced"]
    vaedecode = MODELS_AND_NODES["VAEDecode"]
    createvideo = MODELS_AND_NODES["CreateVideo"]
    savevideo = MODELS_AND_NODES["SaveVideo"]

    # --- 2. Image Preprocessing for the Current Run ---
    print("Preprocessing images with Pillow...")
    processed_start_image = start_image_pil.copy()
    processed_end_image = resize_and_crop_to_match(end_image_pil, start_image_pil)
    video_width, video_height = calculate_video_dimensions(processed_start_image.width, processed_start_image.height)

    # Save processed images to temporary files for the LoadImage node
    temp_dir = "input" # ComfyUI's default input directory
    os.makedirs(temp_dir, exist_ok=True)
    
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False, dir=temp_dir) as start_file, \
         tempfile.NamedTemporaryFile(suffix=".png", delete=False, dir=temp_dir) as end_file:
        processed_start_image.save(start_file.name)
        processed_end_image.save(end_file.name)
        start_image_path = os.path.basename(start_file.name)  # LoadImage resolves names relative to the input dir
        end_image_path = os.path.basename(end_file.name)
    print(f"Images resized to {video_width}x{video_height} and saved temporarily.")
    
    # --- 3. Execute the ComfyUI Workflow in Inference Mode ---
    with torch.inference_mode():
        progress(0.1, desc="Encoding text and images...")
        
        # Encode prompts and vision models
        positive_conditioning = cliptextencode.encode(text=prompt, clip=get_value_at_index(clip, 0))
        negative_conditioning = cliptextencode.encode(text=negative_prompt, clip=get_value_at_index(clip, 0))

        start_image_loaded = loadimage.load_image(image=start_image_path)
        end_image_loaded = loadimage.load_image(image=end_image_path)
        
        clip_vision_encoded_start = clipvisionencode.encode(crop="none", clip_vision=get_value_at_index(clip_vision, 0), image=get_value_at_index(start_image_loaded, 0))
        clip_vision_encoded_end = clipvisionencode.encode(crop="none", clip_vision=get_value_at_index(clip_vision, 0), image=get_value_at_index(end_image_loaded, 0))

        progress(0.2, desc="Preparing initial latents...")
        initial_latents = wanfirstlastframetovideo.EXECUTE_NORMALIZED(
            width=video_width, height=video_height, length=duration, batch_size=1,
            positive=get_value_at_index(positive_conditioning, 0),
            negative=get_value_at_index(negative_conditioning, 0),
            vae=get_value_at_index(vae, 0),
            clip_vision_start_image=get_value_at_index(clip_vision_encoded_start, 0),
            clip_vision_end_image=get_value_at_index(clip_vision_encoded_end, 0),
            start_image=get_value_at_index(start_image_loaded, 0),
            end_image=get_value_at_index(end_image_loaded, 0),
        )

        ksampler_positive = get_value_at_index(initial_latents, 0)
        ksampler_negative = get_value_at_index(initial_latents, 1)
        ksampler_latent = get_value_at_index(initial_latents, 2)
        
        progress(0.5, desc="Denoising (Step 1/2)...")
        latent_step1 = ksampleradvanced.sample(
            add_noise="enable", noise_seed=random.randint(1, 2**64), steps=8, cfg=1,
            sampler_name="euler", scheduler="simple", start_at_step=0, end_at_step=4,
            return_with_leftover_noise="enable", model=get_value_at_index(model_high_final, 0),
            positive=ksampler_positive,
            negative=ksampler_negative,
            latent_image=ksampler_latent,
        )
        
        progress(0.7, desc="Denoising (Step 2/2)...")
        latent_step2 = ksampleradvanced.sample(
            add_noise="disable", noise_seed=random.randint(1, 2**64), steps=8, cfg=1,
            sampler_name="euler", scheduler="simple", start_at_step=4, end_at_step=10000,
            return_with_leftover_noise="disable", model=get_value_at_index(model_low_final, 0),
            positive=ksampler_positive,
            negative=ksampler_negative,
            latent_image=get_value_at_index(latent_step1, 0),
        )

        progress(0.8, desc="Decoding VAE...")
        decoded_images = vaedecode.decode(samples=get_value_at_index(latent_step2, 0), vae=get_value_at_index(vae, 0))

        progress(0.9, desc="Creating and saving video...")
        video_data = createvideo.create_video(fps=FPS, images=get_value_at_index(decoded_images, 0))
        
        # Save the video to ComfyUI's default output directory
        save_result = savevideo.save_video(
            filename_prefix="GradioVideo", format="mp4", codec="h264",
            video=get_value_at_index(video_data, 0),
        )
        
        progress(1.0, desc="Done!")

        # --- 4. Cleanup and Return ---
        try:
            os.remove(start_file.name)
            os.remove(end_file.name)
        except Exception as e:
            print(f"Error cleaning up temporary files: {e}")

        # Gradio video component expects a filepath relative to the root of the app
        return f"output/{save_result['ui']['images'][0]['filename']}"


css = '''
.fillable{max-width: 1100px !important}
.dark .progress-text {color: white}
'''
with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
    gr.Markdown("# Wan 2.2 First/Last Frame Video Fast")
    gr.Markdown("Running the [Wan 2.2 First/Last Frame ComfyUI workflow](https://www.reddit.com/r/StableDiffusion/comments/1me4306/psa_wan_22_does_first_frame_last_frame_out_of_the/) and the [lightx2v/Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) 8-step LoRA on ZeroGPU")
    
    with gr.Row():
        with gr.Column():
            with gr.Group():
                with gr.Row():
                    start_image = gr.Image(type="pil", label="Start Frame")
                    end_image = gr.Image(type="pil", label="End Frame")
                
                prompt = gr.Textbox(label="Prompt", info="Describe the transition between the two images")
    
                with gr.Accordion("Advanced Settings", open=False, visible=False):
                    duration = gr.Radio(
                        [("Short (2s)", 33), ("Mid (4s)", 66)],
                        value=33, 
                        label="Video Duration",
                        visible=False
                    )
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt", 
                        value="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,",
                        visible=False
                    )
                
                generate_button = gr.Button("Generate Video", variant="primary")
        
        with gr.Column():
            output_video = gr.Video(label="Generated Video", autoplay=True)

    generate_button.click(
        fn=generate_video,
        inputs=[start_image, end_image, prompt, negative_prompt, duration],
        outputs=output_video
    )

    gr.Examples(
        examples=[
            ["poli_tower.png", "tower_takes_off.png", "the man turns around"],
            ["ugly_sonic.jpeg", "squatting_sonic.png", "the character dodges the missiles"],
            ["capyabara_zoomed.png", "capybara.webp", "a dramatic dolly zoom"],
        ],
        inputs=[start_image, end_image, prompt],
        outputs=output_video,
        fn=generate_video,
        cache_examples="lazy",
    )

if __name__ == "__main__":
    app.launch(share=True)