# IMPORTANT: spaces must be imported first to avoid CUDA initialization issues
import spaces
import os
import numpy as np
from PIL import Image
import gradio as gr
import torch
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
# ────────────────────────────────────────────────
# Model + LoRA configuration
# ────────────────────────────────────────────────
# Hugging Face Hub id of the base model that the pipeline and VAE are loaded from.
MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"

# Weight precision requested for the main pipeline.
dtype = torch.bfloat16

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Catalogue of LoRAs offered in the UI. Each entry carries:
#   name             - label shown in the UI, also used as the adapter name
#   repo_id          - Hub repository holding the weights
#   filename         - weight file inside that repository
#   default_strength - strength used before the global multiplier is applied
# NOTE(review): the Lightning filename mentions "i2v_A14b" while MODEL_ID is the
# TI2V-5B checkpoint - confirm these weights are compatible with this model.
AVAILABLE_LORAS = [
    dict(
        name="Lightning (Fast 4-step)",
        repo_id="lightx2v/Wan2.2-Distill-Loras",
        filename="wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors",
        default_strength=1.0,
    ),
    dict(
        name="General NSFW",
        repo_id="lopi999/Wan2.2-I2V_General-NSFW-LoRA",
        filename="pytorch_lora_weights.safetensors",
        default_strength=0.8,
    ),
    # Add more LoRAs here - they will be pre-loaded automatically
]

# Lazily created pipeline singleton plus the registry of LoRAs that loaded.
pipe = None
# name -> {"adapter_name": str, "strength": float}
lora_adapters = dict()
def initialize_pipeline():
    """Build the Wan2.2 pipeline once and pre-load every configured LoRA.

    The first call loads the VAE (float32) and the base pipeline (``dtype``),
    moves the pipeline to ``device`` and loads each entry of
    ``AVAILABLE_LORAS`` as a named adapter. Later calls return the cached
    pipeline immediately.

    Adapters are deliberately left *unfused*: fusing merges the LoRA weights
    into the base model, after which they can no longer be enabled, disabled
    or re-weighted per request.

    Returns:
        The fully initialised pipeline, also stored in the module-level
        ``pipe``. LoRAs that loaded successfully are recorded in the
        module-level ``lora_adapters``.
    """
    global pipe, lora_adapters
    if pipe is not None:
        return pipe

    print("Loading Wan2.2-TI2V-5B base model...")
    vae = AutoencoderKLWan.from_pretrained(
        MODEL_ID,
        subfolder="vae",
        torch_dtype=torch.float32,
    )
    # Build into a local first so a failure below never leaves a half-built
    # pipeline cached in the global for subsequent calls to reuse.
    new_pipe = WanPipeline.from_pretrained(
        MODEL_ID,
        vae=vae,
        torch_dtype=dtype,
    )
    new_pipe.to(device)
    print("Base model loaded.")

    print("Pre-loading LoRAs...")
    loaded = {}
    for lora in AVAILABLE_LORAS:
        name = lora["name"]
        try:
            print(f" → {name}")
            new_pipe.load_lora_weights(
                lora["repo_id"],
                weight_name=lora["filename"],
                adapter_name=name,
            )
        except Exception as e:
            # Best effort: one broken LoRA must not block the whole app.
            print(f" Failed to load {name}: {e}")
            continue
        loaded[name] = {
            "adapter_name": name,
            "strength": lora["default_strength"],
        }

    if loaded:
        print(f"{len(loaded)} LoRA adapter(s) loaded (kept unfused so they stay switchable).")

    # Publish results only after everything above succeeded.
    lora_adapters.update(loaded)
    pipe = new_pipe
    print("Pipeline fully initialized.")
    return pipe
@spaces.GPU(duration=180)
def generate_video(
    prompt: str,
    image: Image.Image = None,
    width: int = 1280,
    height: int = 704,
    num_frames: int = 73,
    num_inference_steps: int = 35,
    guidance_scale: float = 5.0,
    seed: int = -1,
    enabled_loras: list = None,
    lora_strength_multiplier: float = 1.0,
    progress=gr.Progress()
):
    """Generate a video from a prompt (and optional image) and save it as MP4.

    Args:
        prompt: Text description of the video.
        image: Optional conditioning image; forwarded to the pipeline as
            ``image`` when given.
        width: Output width in pixels.
        height: Output height in pixels.
        num_frames: Number of frames to generate.
        num_inference_steps: Denoising steps. Forced to 4 when the Lightning
            LoRA is active and more than 8 steps were requested.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: RNG seed; ``-1`` or ``None`` picks a random one.
        enabled_loras: Names (from ``AVAILABLE_LORAS``) of LoRAs to activate.
            Names that were not loaded successfully are ignored.
        lora_strength_multiplier: Factor applied to each LoRA's default
            strength.
        progress: Gradio progress tracker.

    Returns:
        ``(video_path, status)`` on success, ``(None, error_message)`` on
        failure. Errors are reported to the UI instead of being raised.
    """
    try:
        # Function-scope import keeps this block self-contained.
        import tempfile

        pipeline = initialize_pipeline()

        # Resolve the requested names against the LoRAs that actually loaded.
        active_names = []
        active_adapters = []
        active_strengths = []
        for lora_name in enabled_loras or []:
            info = lora_adapters.get(lora_name)
            if info is None:
                continue
            active_names.append(lora_name)
            active_adapters.append(info["adapter_name"])
            active_strengths.append(info["strength"] * lora_strength_multiplier)

        if active_adapters:
            print(f"Activating LoRAs: {active_adapters} with strengths {active_strengths}")
            # A previous request may have disabled LoRA; re-enable it first,
            # otherwise the selected adapters would stay inactive.
            pipeline.enable_lora()
            pipeline.set_adapters(active_adapters, adapter_weights=active_strengths)
        else:
            print("No LoRAs enabled — disabling LoRA")
            if lora_adapters:
                try:
                    pipeline.disable_lora()
                except Exception as e:
                    # Best effort, but do not hide the reason.
                    print(f"Could not disable LoRA: {e}")

        # Only shorten the schedule when the Lightning LoRA is really active.
        if "Lightning (Fast 4-step)" in active_names and num_inference_steps > 8:
            num_inference_steps = 4
            print("Lightning LoRA → reduced to 4 steps")

        if seed is None or seed == -1:
            seed = torch.randint(0, 2**32 - 1, (1,)).item()
        seed = int(seed)
        generator = torch.Generator(device=device).manual_seed(seed)

        # UI widgets may deliver numbers as floats; the pipeline wants ints.
        width = int(width)
        height = int(height)
        num_frames = int(num_frames)
        num_inference_steps = int(num_inference_steps)

        gen_params = {
            "prompt": prompt,
            "height": height,
            "width": width,
            "num_frames": num_frames,
            "guidance_scale": float(guidance_scale),
            "num_inference_steps": num_inference_steps,
            "generator": generator,
        }
        if image is not None:
            # NOTE(review): confirm that the pipeline class in use accepts an
            # `image` argument; if not, this surfaces as an error in the UI.
            gen_params["image"] = image

        print(f"Generating: {width}x{height}, {num_frames} frames, steps={num_inference_steps}")
        progress(0, desc="Starting generation...")
        frames = pipeline(**gen_params).frames[0]

        # Unique file per request so queued/concurrent generations do not
        # overwrite each other's result.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            output_path = tmp.name
        export_to_video(frames, output_path, fps=24)

        status = f"Done! Seed: {seed}"
        if active_names:
            status += f"\nLoRAs: {', '.join(active_names)} @ {lora_strength_multiplier:.2f}x"
        return output_path, status
    except Exception as e:
        # UI boundary: report the failure in the status box.
        msg = f"Error: {str(e)}"
        print(msg)
        return None, msg
# ────────────────────────────────────────────────
# Gradio UI
# ────────────────────────────────────────────────
# Gradio layout: inputs and LoRA controls on the left, result on the right.
with gr.Blocks(title="Wan2.2 Video + Fast LoRA") as demo:
    gr.Markdown("""
# Wan2.2-TI2V-5B Video Generation
**Text-to-Video & Image-to-Video** with optimized LoRA loading.
""")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt", lines=3,
                value="Two anthropomorphic cats in comfy boxing gear fight on stage"
            )
            image_input = gr.Image(label="Input Image (optional for I2V)", type="pil", sources=["upload"])
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    width_input = gr.Slider(512, 1920, step=64, value=1280, label="Width")
                    height_input = gr.Slider(512, 1080, step=64, value=704, label="Height")
                num_frames_input = gr.Slider(25, 145, step=24, value=73, label="Frames")
                num_steps_input = gr.Slider(4, 60, step=1, value=4, label="Inference Steps",
                                            info="Lightning LoRA → try 4–8 steps")
                # The third positional parameter of gr.Slider is `value`, so
                # the step has to be passed by keyword.
                guidance_scale_input = gr.Slider(1.0, 15.0, step=1.0, value=5.0, label="Guidance Scale")
                seed_input = gr.Number(label="Seed (-1 = random)", value=-1, precision=0)
            with gr.Accordion("LoRA Controls", open=True):
                lora_checkbox = gr.CheckboxGroup(
                    choices=[lora["name"] for lora in AVAILABLE_LORAS],
                    label="Enable LoRAs",
                    value=[]
                )
                lora_strength = gr.Slider(0.1, 1.5, step=0.05, value=1.0,
                                          label="Global Strength Multiplier")
            generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
        with gr.Column():
            video_output = gr.Video(label="Generated Video", autoplay=True)
            status_output = gr.Textbox(label="Status", lines=3)
    # Examples with LoRA usage; column order matches `inputs` below.
    gr.Examples(
        examples=[
            ["Two anthropomorphic cats in comfy boxing gear fight on stage", None, 1280, 704, 73, 35, 5.0, 42, [], 1.0],
            ["A serene underwater scene with colorful coral reefs...", None, 1280, 704, 73, 4, 5.0, 123, ["Lightning (Fast 4-step)"], 1.0],
            ["Explicit adult scene, detailed", None, 1280, 704, 73, 30, 6.0, 999, ["General NSFW"], 0.9],
        ],
        inputs=[prompt_input, image_input, width_input, height_input, num_frames_input,
                num_steps_input, guidance_scale_input, seed_input, lora_checkbox, lora_strength],
        outputs=[video_output, status_output],
        fn=generate_video,
        cache_examples=False,
    )
    # The input order must match generate_video's positional parameters.
    generate_btn.click(
        generate_video,
        inputs=[prompt_input, image_input, width_input, height_input, num_frames_input,
                num_steps_input, guidance_scale_input, seed_input, lora_checkbox, lora_strength],
        outputs=[video_output, status_output]
    )
    gr.Markdown("""
## Performance Notes
- LoRAs are **pre-loaded once** — first generation may take ~10–30s longer, later ones are fast.
- Lightning LoRA: use **4–8 steps** → generation can finish in <60s.
- Add new LoRAs by appending to `AVAILABLE_LORAS` — they auto-load at startup.
""")
# Script entry point: enable request queuing (max_size=20) and start the app.
if __name__ == "__main__":
    demo.queue(max_size=20).launch()