"""Gradio Space: Wan2.2-TI2V-5B text/image-to-video generation with switchable LoRAs.

The base pipeline and every entry in ``AVAILABLE_LORAS`` are loaded lazily by
the first generation request; each request then activates only the LoRAs the
user ticked in the UI.
"""

# IMPORTANT: spaces must be imported first to avoid CUDA initialization issues
import spaces

import os
import tempfile
import traceback

import numpy as np
from PIL import Image
import gradio as gr
import torch
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video

# ────────────────────────────────────────────────
# Model + LoRA configuration
# ────────────────────────────────────────────────
MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Display/adapter name of the step-distilled LoRA. When it is actually active,
# generate_video clamps the number of inference steps.
LIGHTNING_LORA_NAME = "Lightning (Fast 4-step)"

AVAILABLE_LORAS = [
    {
        "name": LIGHTNING_LORA_NAME,
        # NOTE(review): the filename suggests weights trained for the Wan2.2
        # A14B I2V model rather than the 5B TI2V base used here. If loading
        # fails, the LoRA is skipped by initialize_pipeline — confirm
        # compatibility.
        "repo_id": "lightx2v/Wan2.2-Distill-Loras",
        "filename": "wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors",
        "default_strength": 1.0,
    },
    {
        "name": "General NSFW",
        "repo_id": "lopi999/Wan2.2-I2V_General-NSFW-LoRA",
        "filename": "pytorch_lora_weights.safetensors",
        "default_strength": 0.8,
    },
    # Add more LoRAs here — they will be pre-loaded automatically
]

# Global pipeline + pre-loaded adapter info
pipe = None
lora_adapters = {}  # name → {"adapter_name": str, "strength": float}


def initialize_pipeline():
    """Load the base pipeline and every LoRA in AVAILABLE_LORAS, exactly once.

    Later calls return the cached global pipeline. A LoRA that fails to load
    is logged and skipped, so ``lora_adapters`` lists only adapters that are
    attached to the pipeline.

    LoRAs are deliberately *not* fused. Fusing bakes their weights into the
    base model, which would defeat the per-request selection done in
    generate_video via ``set_adapters`` / ``disable_lora``.

    Returns:
        The initialized ``WanPipeline``.
    """
    global pipe

    if pipe is not None:
        return pipe

    print("Loading Wan2.2-TI2V-5B base model...")
    vae = AutoencoderKLWan.from_pretrained(
        MODEL_ID, subfolder="vae", torch_dtype=torch.float32
    )
    pipeline = WanPipeline.from_pretrained(MODEL_ID, vae=vae, torch_dtype=dtype)
    pipeline.to(device)
    print("Base model loaded.")

    print("Pre-loading LoRAs...")
    for lora in AVAILABLE_LORAS:
        name = lora["name"]
        try:
            print(f"  → {name}")
            pipeline.load_lora_weights(
                lora["repo_id"],
                weight_name=lora["filename"],
                adapter_name=name,
            )
        except Exception as e:
            # Best effort: a broken or incompatible LoRA must not take down the app.
            print(f"  Failed to load {name}: {e}")
            continue
        lora_adapters[name] = {
            "adapter_name": name,
            "strength": lora["default_strength"],
        }

    # Publish the global only once initialization has fully succeeded.
    pipe = pipeline
    print("Pipeline fully initialized.")
    return pipe


def _resolve_active_loras(enabled_loras, strength_multiplier):
    """Return (adapter_names, weights) for the requested LoRAs that are loaded."""
    adapter_names = []
    weights = []
    for lora_name in enabled_loras or []:
        info = lora_adapters.get(lora_name)
        if info is None:
            print(f"Skipping LoRA {lora_name!r}: not loaded")
            continue
        adapter_names.append(info["adapter_name"])
        weights.append(info["strength"] * strength_multiplier)
    return adapter_names, weights


def _apply_loras(pipeline, adapter_names, weights):
    """Activate exactly the given adapters, or switch LoRA off when none are given."""
    if adapter_names:
        print(f"Activating LoRAs: {adapter_names} with strengths {weights}")
        # set_adapters() may not re-enable LoRA layers that an earlier request
        # switched off with disable_lora(), so enable them explicitly first.
        pipeline.enable_lora()
        pipeline.set_adapters(adapter_names, adapter_weights=weights)
        return

    print("No LoRAs enabled → disabling LoRA")
    # Only meaningful (and only safe) when at least one adapter was attached.
    if lora_adapters:
        pipeline.disable_lora()


@spaces.GPU(duration=180)
def generate_video(
    prompt: str,
    image: Image.Image = None,
    width: int = 1280,
    height: int = 704,
    num_frames: int = 73,
    num_inference_steps: int = 35,
    guidance_scale: float = 5.0,
    seed: int = -1,
    enabled_loras: list = None,
    lora_strength_multiplier: float = 1.0,
    progress=gr.Progress(),
):
    """Generate a video with the Wan pipeline and write it to a temporary .mp4.

    Args:
        prompt: Text prompt.
        image: Optional PIL image; when given it is passed to the pipeline as ``image``.
        width, height: Output resolution in pixels.
        num_frames: Number of frames to generate (exported at 24 fps).
        num_inference_steps: Denoising steps. Forced to 4 when the Lightning
            LoRA is actually active and more than 8 steps were requested.
        guidance_scale: Classifier-free guidance scale.
        seed: RNG seed; -1 (or an empty field, ``None``) picks a random seed.
        enabled_loras: Names from AVAILABLE_LORAS to activate; names that
            failed to load are skipped.
        lora_strength_multiplier: Multiplies each LoRA's default strength.
        progress: Gradio progress tracker.

    Returns:
        ``(video_path, status)``. On failure, ``video_path`` is ``None`` and
        ``status`` carries the error message. Exceptions are reported in
        ``status`` rather than raised, so the UI always gets a response.
    """
    try:
        pipeline = initialize_pipeline()

        active_adapters, active_strengths = _resolve_active_loras(
            enabled_loras, lora_strength_multiplier
        )
        _apply_loras(pipeline, active_adapters, active_strengths)

        num_inference_steps = int(num_inference_steps)
        # Check the adapters that were really activated, not merely requested:
        # a LoRA that failed to load must not force a 4-step run.
        if LIGHTNING_LORA_NAME in active_adapters and num_inference_steps > 8:
            num_inference_steps = 4
            print("Lightning LoRA → reduced to 4 steps")

        # gr.Number yields None when the field is cleared.
        if seed is None or int(seed) == -1:
            seed = torch.randint(0, 2**32 - 1, (1,)).item()
        seed = int(seed)
        generator = torch.Generator(device=device).manual_seed(seed)

        width, height, num_frames = int(width), int(height), int(num_frames)
        gen_params = {
            "prompt": prompt,
            "height": height,
            "width": width,
            "num_frames": num_frames,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
        }
        if image is not None:
            # NOTE(review): WanPipeline is the text-to-video pipeline; confirm it
            # accepts `image`. Diffusers exposes Wan image-to-video through
            # WanImageToVideoPipeline.
            gen_params["image"] = image

        print(f"Generating: {width}x{height}, {num_frames} frames, steps={num_inference_steps}")
        progress(0, desc="Starting generation...")
        output = pipeline(**gen_params).frames[0]

        # Unique file per request, so queued or concurrent jobs don't overwrite
        # each other's result.
        fd, output_path = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)
        export_to_video(output, output_path, fps=24)

        status = f"Done! Seed: {seed}"
        if active_adapters:
            status += f"\nLoRAs: {', '.join(active_adapters)} @ {lora_strength_multiplier:.2f}x"
        return output_path, status

    except Exception as e:
        msg = f"Error: {str(e)}"
        print(msg)
        traceback.print_exc()
        return None, msg


# ────────────────────────────────────────────────
# Gradio UI
# ────────────────────────────────────────────────
with gr.Blocks(title="Wan2.2 Video + Fast LoRA") as demo:
    gr.Markdown(
        """
# Wan2.2-TI2V-5B Video Generation
**Text-to-Video & Image-to-Video** with optimized LoRA loading.
"""
    )

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                lines=3,
                value="Two anthropomorphic cats in comfy boxing gear fight on stage",
            )
            image_input = gr.Image(
                label="Input Image (optional for I2V)", type="pil", sources=["upload"]
            )

            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    width_input = gr.Slider(512, 1920, step=64, value=1280, label="Width")
                    height_input = gr.Slider(512, 1080, step=64, value=704, label="Height")
                # Steps of 24 starting at 25 give 25, 49, 73, ... (always 4k + 1).
                num_frames_input = gr.Slider(25, 145, step=24, value=73, label="Frames")
                # Default matches generate_video's default; 4 steps only suits
                # the Lightning LoRA, which is off by default.
                num_steps_input = gr.Slider(
                    4, 60, step=1, value=35,
                    label="Inference Steps",
                    info="Lightning LoRA → try 4–8 steps",
                )
                # gr.Slider's third positional parameter is `value`, so the
                # step must be passed by keyword.
                guidance_scale_input = gr.Slider(
                    1.0, 15.0, step=1.0, value=5.0, label="Guidance Scale"
                )
                seed_input = gr.Number(label="Seed (-1 = random)", value=-1, precision=0)

            with gr.Accordion("LoRA Controls", open=True):
                lora_checkbox = gr.CheckboxGroup(
                    choices=[lora["name"] for lora in AVAILABLE_LORAS],
                    label="Enable LoRAs",
                    value=[],
                )
                lora_strength = gr.Slider(
                    0.1, 1.5, step=0.05, value=1.0, label="Global Strength Multiplier"
                )

            generate_btn = gr.Button("Generate Video", variant="primary", size="lg")

        with gr.Column():
            video_output = gr.Video(label="Generated Video", autoplay=True)
            status_output = gr.Textbox(label="Status", lines=3)

    # Shared by the examples and the button; the order must match generate_video's parameters.
    generation_inputs = [
        prompt_input,
        image_input,
        width_input,
        height_input,
        num_frames_input,
        num_steps_input,
        guidance_scale_input,
        seed_input,
        lora_checkbox,
        lora_strength,
    ]

    # Examples with LoRA usage
    gr.Examples(
        examples=[
            ["Two anthropomorphic cats in comfy boxing gear fight on stage", None, 1280, 704, 73, 35, 5.0, 42, [], 1.0],
            ["A serene underwater scene with colorful coral reefs...", None, 1280, 704, 73, 4, 5.0, 123, [LIGHTNING_LORA_NAME], 1.0],
            ["Explicit adult scene, detailed", None, 1280, 704, 73, 30, 6.0, 999, ["General NSFW"], 0.9],
        ],
        inputs=generation_inputs,
        outputs=[video_output, status_output],
        fn=generate_video,
        cache_examples=False,
    )

    generate_btn.click(
        generate_video,
        inputs=generation_inputs,
        outputs=[video_output, status_output],
    )

    gr.Markdown(
        """
## Performance Notes
- LoRAs are **pre-loaded once** → first generation may take ~10–30s longer, later ones are fast.
- Lightning LoRA: use **4–8 steps** → generation can finish in <60s.
- Add new LoRAs by appending to `AVAILABLE_LORAS` — they are loaded together with the model on the first generation.
"""
    )


if __name__ == "__main__":
    demo.queue(max_size=20).launch()