File size: 8,650 Bytes
d16eb70
 
 
 
 
 
 
 
 
 
 
 
dcd7d2e
 
 
 
d16eb70
 
 
 
dcd7d2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d16eb70
dcd7d2e
d16eb70
 
dcd7d2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a482b15
dcd7d2e
 
 
 
 
 
 
 
 
 
 
 
 
d16eb70
 
dcd7d2e
d16eb70
 
 
 
 
 
 
 
dcd7d2e
 
d5ed7d9
 
d16eb70
 
decedbf
33e1890
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da6bf2a
dcd7d2e
 
 
 
d16eb70
 
 
 
 
 
 
 
 
 
 
 
 
a482b15
d16eb70
a482b15
d16eb70
dcd7d2e
d5ed7d9
 
a482b15
d16eb70
 
 
 
dcd7d2e
 
 
d16eb70
dcd7d2e
d16eb70
dcd7d2e
 
 
 
d16eb70
dcd7d2e
 
 
d16eb70
dcd7d2e
 
 
a482b15
dcd7d2e
d16eb70
 
 
 
dcd7d2e
d16eb70
 
a482b15
d16eb70
 
 
dcd7d2e
 
 
7994d21
dcd7d2e
7994d21
dcd7d2e
 
 
 
 
 
 
d16eb70
a482b15
dcd7d2e
d16eb70
 
 
 
dcd7d2e
 
d16eb70
dcd7d2e
d16eb70
 
dcd7d2e
 
 
d16eb70
dcd7d2e
 
d16eb70
 
 
 
 
 
dcd7d2e
 
 
d16eb70
 
 
dcd7d2e
 
 
 
 
 
d16eb70
 
dcd7d2e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
# IMPORTANT: spaces must be imported first to avoid CUDA initialization issues
import spaces

import inspect
import os
import tempfile
import traceback

import numpy as np
from PIL import Image
import gradio as gr

import torch
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video

# ────────────────────────────────────────────────
#  Model + LoRA configuration
# ────────────────────────────────────────────────

# Hugging Face repo of the base text/image-to-video model.
MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
# Precision used for the main pipeline weights.
dtype = torch.bfloat16
# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Every entry is loaded as a named adapter the first time the pipeline is built.
# Keys: "name" (display + adapter name), "repo_id" / "filename" (where the
# weights live on the Hub), "default_strength" (base adapter weight).
# To add a LoRA, append another entry with the same four keys.
AVAILABLE_LORAS = [
    # NOTE(review): the file name mentions "i2v_A14b"; it may not match the
    # 5B TI2V transformer and could fail to load - confirm.
    dict(
        name="Lightning (Fast 4-step)",
        repo_id="lightx2v/Wan2.2-Distill-Loras",
        filename="wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors",
        default_strength=1.0,
    ),
    dict(
        name="General NSFW",
        repo_id="lopi999/Wan2.2-I2V_General-NSFW-LoRA",
        filename="pytorch_lora_weights.safetensors",
        default_strength=0.8,
    ),
]

# Lazily-built pipeline shared by all requests (None until first use).
pipe = None
# Adapters that loaded successfully: name -> {"adapter_name": str, "strength": float}.
lora_adapters = {}

def initialize_pipeline():
    """Build the Wan2.2 pipeline once and register every LoRA in AVAILABLE_LORAS.

    The first call loads the base model (VAE in float32, everything else in
    ``dtype``), moves it to ``device`` and loads each LoRA as a named adapter
    whose adapter name equals its display name. Adapters that load are recorded
    in ``lora_adapters`` with their default strength. A LoRA that fails to load
    is logged and skipped. Later calls return the cached pipeline.

    The adapters are deliberately NOT fused. ``generate_video`` switches them on
    and off per request with ``set_adapters`` / ``disable_lora``, and those only
    act on unfused adapters. Fusing would bake the weights into the base model,
    so "no LoRA" requests would still use them, and re-activating an adapter
    would apply it a second time on top of the fused copy.

    Returns:
        The initialized pipeline (also stored in the module-level ``pipe``).
    """
    global pipe, lora_adapters

    if pipe is not None:
        return pipe

    print("Loading Wan2.2-TI2V-5B base model...")
    vae = AutoencoderKLWan.from_pretrained(
        MODEL_ID,
        subfolder="vae",
        torch_dtype=torch.float32,
    )
    new_pipe = WanPipeline.from_pretrained(
        MODEL_ID,
        vae=vae,
        torch_dtype=dtype,
    )
    new_pipe.to(device)
    print("Base model loaded.")

    print("Pre-loading LoRAs...")
    loaded = {}
    for lora in AVAILABLE_LORAS:
        name = lora["name"]
        try:
            print(f"  → {name}")
            new_pipe.load_lora_weights(
                lora["repo_id"],
                weight_name=lora["filename"],
                adapter_name=name,
            )
        except Exception as e:
            # Best effort: one missing/incompatible LoRA must not take the app down.
            print(f"    Failed to load {name}: {e}")
            continue
        loaded[name] = {
            "adapter_name": name,
            "strength": lora["default_strength"],
        }

    if loaded:
        print(f"LoRAs loaded (unfused, toggled per request): {', '.join(loaded)}")

    # Publish only after setup has fully succeeded, so an exception above leaves
    # ``pipe`` as None and the next call retries instead of using a half-built
    # pipeline. The dict is updated in place to keep its identity.
    lora_adapters.clear()
    lora_adapters.update(loaded)
    pipe = new_pipe

    print("Pipeline fully initialized.")
    return pipe

def _select_active_loras(enabled, multiplier):
    """Return (adapter_names, weights) for the requested LoRAs that actually loaded.

    Names that are unknown, or whose LoRA failed to load, are skipped. Each
    weight is the LoRA's default strength scaled by ``multiplier``.
    """
    names = []
    weights = []
    for lora_name in enabled:
        info = lora_adapters.get(lora_name)
        if info is None:
            continue
        names.append(lora_name)
        weights.append(info["strength"] * multiplier)
    return names, weights


def _apply_loras(pipeline, adapter_names, weights):
    """Activate exactly ``adapter_names`` with ``weights``, or turn LoRA off if empty."""
    if adapter_names:
        print(f"Activating LoRAs: {adapter_names} with strengths {weights}")
        # A disable_lora() from an earlier request leaves the LoRA layers
        # disabled; set_adapters() only picks adapters/weights, so re-enable first.
        pipeline.enable_lora()
        pipeline.set_adapters(adapter_names, adapter_weights=weights)
    else:
        print("No LoRAs enabled → disabling LoRA")
        try:
            pipeline.disable_lora()
        except Exception as e:
            # Best effort: with no adapters loaded there may be nothing to disable.
            print(f"disable_lora() skipped: {e}")


def _resolve_seed(seed):
    """Return ``seed`` as an int; a missing (None) or negative seed draws a random one."""
    if seed is None or seed < 0:
        return int(torch.randint(0, 2**32 - 1, (1,)).item())
    return int(seed)


@spaces.GPU(duration=180)
def generate_video(
    prompt: str,
    image: Image.Image = None,
    width: int = 1280,
    height: int = 704,
    num_frames: int = 73,
    num_inference_steps: int = 35,
    guidance_scale: float = 5.0,
    seed: int = -1,
    enabled_loras: list = None,
    lora_strength_multiplier: float = 1.0,
    progress=gr.Progress()
):
    """Generate a video from a prompt (and optionally an image) and save it as MP4.

    Args:
        prompt: Text prompt.
        image: Optional PIL image. It is used only if the loaded pipeline's
            ``__call__`` accepts an ``image`` argument; otherwise an error
            status is returned.
        width, height, num_frames, num_inference_steps, guidance_scale:
            Forwarded to the pipeline.
        seed: RNG seed; -1 (or any negative / None) picks a random seed.
        enabled_loras: Display names of the LoRAs to activate.
        lora_strength_multiplier: Scales every active LoRA's default strength.
        progress: Gradio progress tracker.

    Returns:
        ``(video_path, status)`` on success, ``(None, "Error: ...")`` on failure.
        Errors are reported, not raised, because this is a UI callback.
    """
    try:
        pipeline = initialize_pipeline()

        if image is not None and "image" not in inspect.signature(pipeline.__call__).parameters:
            # NOTE(review): image-to-video probably needs diffusers'
            # WanImageToVideoPipeline; the text-to-video WanPipeline has no `image` input.
            msg = ("Error: the loaded pipeline does not accept an input image, so "
                   "image-to-video is unavailable. Remove the image to run text-to-video.")
            print(msg)
            return None, msg

        active_adapters, active_strengths = _select_active_loras(
            enabled_loras or [], lora_strength_multiplier
        )
        _apply_loras(pipeline, active_adapters, active_strengths)

        # Only shortcut to 4 steps when the distilled Lightning LoRA actually loaded;
        # otherwise the base model would run with far too few steps.
        if "Lightning (Fast 4-step)" in active_adapters and num_inference_steps > 8:
            num_inference_steps = 4
            print("Lightning LoRA → reduced to 4 steps")

        seed = _resolve_seed(seed)
        generator = torch.Generator(device=device).manual_seed(seed)

        gen_params = {
            "prompt": prompt,
            "height": int(height),
            "width": int(width),
            "num_frames": int(num_frames),
            "guidance_scale": guidance_scale,
            "num_inference_steps": int(num_inference_steps),
            "generator": generator,
        }
        if image is not None:
            gen_params["image"] = image

        print(f"Generating: {width}x{height}, {num_frames} frames, steps={num_inference_steps}")
        progress(0, desc="Starting generation...")

        output = pipeline(**gen_params).frames[0]

        # One file per request so queued/concurrent runs never overwrite each other.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            output_path = tmp.name
        export_to_video(output, output_path, fps=24)

        status = f"Done! Seed: {seed}"
        if active_adapters:
            status += f"\nLoRAs: {', '.join(active_adapters)} @ {lora_strength_multiplier:.2f}x"

        return output_path, status

    except Exception as e:
        # UI boundary: log the full traceback, show a short message in the status box.
        traceback.print_exc()
        msg = f"Error: {str(e)}"
        print(msg)
        return None, msg

# ────────────────────────────────────────────────
#               Gradio UI
# ────────────────────────────────────────────────

with gr.Blocks(title="Wan2.2 Video + Fast LoRA") as demo:
    gr.Markdown("""
    # Wan2.2-TI2V-5B Video Generation  
    **Text-to-Video & Image-to-Video** with optimized LoRA loading.
    """)

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt", lines=3,
                value="Two anthropomorphic cats in comfy boxing gear fight on stage"
            )
            image_input = gr.Image(label="Input Image (optional for I2V)", type="pil", sources=["upload"])

            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    width_input  = gr.Slider(512, 1920, step=64, value=1280, label="Width")
                    height_input = gr.Slider(512, 1080, step=64, value=704, label="Height")
                # Starting at 25 with step 24, every choice has the form 4k + 1
                # (25, 49, 73, ..., 145) - presumably what the model expects; confirm.
                num_frames_input = gr.Slider(25, 145, step=24, value=73, label="Frames")
                # Default matches generate_video's own default (35). generate_video
                # drops it to 4 when the Lightning LoRA is selected and steps > 8.
                num_steps_input  = gr.Slider(4, 60, step=1, value=35, label="Inference Steps",
                                             info="Lightning LoRA → try 4–8 steps")
                # gr.Slider's third positional parameter is `value`, so the step must be
                # given by keyword (positional 1.0 plus value=5.0 raised a TypeError).
                guidance_scale_input = gr.Slider(1.0, 15.0, step=0.5, value=5.0, label="Guidance Scale")
                seed_input = gr.Number(label="Seed (-1 = random)", value=-1, precision=0)

            with gr.Accordion("LoRA Controls", open=True):
                lora_checkbox = gr.CheckboxGroup(
                    choices=[l["name"] for l in AVAILABLE_LORAS],
                    label="Enable LoRAs",
                    value=[]
                )
                lora_strength = gr.Slider(0.1, 1.5, step=0.05, value=1.0,
                                          label="Global Strength Multiplier")

            generate_btn = gr.Button("Generate Video", variant="primary", size="lg")

        with gr.Column():
            video_output = gr.Video(label="Generated Video", autoplay=True)
            status_output = gr.Textbox(label="Status", lines=3)

    # Input order shared by the examples and the button; each example row below
    # lists its values in exactly this order.
    generation_inputs = [prompt_input, image_input, width_input, height_input, num_frames_input,
                         num_steps_input, guidance_scale_input, seed_input, lora_checkbox, lora_strength]

    # Examples with LoRA usage
    gr.Examples(
        examples=[
            ["Two anthropomorphic cats in comfy boxing gear fight on stage", None, 1280, 704, 73, 35, 5.0, 42, [], 1.0],
            ["A serene underwater scene with colorful coral reefs...", None, 1280, 704, 73, 4, 5.0, 123, ["Lightning (Fast 4-step)"], 1.0],
            ["Explicit adult scene, detailed", None, 1280, 704, 73, 30, 6.0, 999, ["General NSFW"], 0.9],
        ],
        inputs=generation_inputs,
        outputs=[video_output, status_output],
        fn=generate_video,
        cache_examples=False,
    )

    generate_btn.click(
        generate_video,
        inputs=generation_inputs,
        outputs=[video_output, status_output]
    )

    gr.Markdown("""
    ## Performance Notes
    - LoRAs are **pre-loaded once** → first generation may take ~10–30s longer, later ones are fast.
    - Lightning LoRA: use **4–8 steps** → generation can finish in <60s.
    - Add new LoRAs by appending to `AVAILABLE_LORAS` — they are loaded automatically on the first generation.
    """)

if __name__ == "__main__":
    demo.queue(max_size=20).launch()