Spaces:
Paused
Paused
| # IMPORTANT: spaces must be imported first to avoid CUDA initialization issues | |
| import spaces | |
| import os | |
| import numpy as np | |
| from PIL import Image | |
| import gradio as gr | |
| import torch | |
| from diffusers import WanPipeline, AutoencoderKLWan | |
| from diffusers.utils import export_to_video | |
# ────────────────────────────────────────────────
# Model + LoRA configuration
# ────────────────────────────────────────────────
# Base checkpoint on the Hugging Face Hub, and the dtype/device the main
# pipeline is loaded with.
MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
dtype = torch.bfloat16
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Catalogue of LoRAs offered in the UI. Each entry gives the display name
# (also used as the adapter name), the Hub repo and weight file to download,
# and the strength applied before the user's global multiplier.
# NOTE(review): the first entry's filename mentions "i2v_A14b" while MODEL_ID
# is the TI2V-5B checkpoint — confirm this LoRA is compatible with the base.
AVAILABLE_LORAS = [
    dict(
        name="Lightning (Fast 4-step)",
        repo_id="lightx2v/Wan2.2-Distill-Loras",
        filename="wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors",
        default_strength=1.0,
    ),
    dict(
        name="General NSFW",
        repo_id="lopi999/Wan2.2-I2V_General-NSFW-LoRA",
        filename="pytorch_lora_weights.safetensors",
        default_strength=0.8,
    ),
    # Add more LoRAs here — they will be pre-loaded automatically
]

# Process-wide state: the lazily created pipeline and the adapters that were
# loaded successfully.
pipe = None
# Maps LoRA display name → {"adapter_name": str, "strength": float}
lora_adapters = dict()
def initialize_pipeline():
    """Create the shared Wan2.2 pipeline on first use and pre-load all LoRAs.

    The pipeline is built once and cached in the module-level ``pipe``;
    later calls return the cached object. Every entry of ``AVAILABLE_LORAS``
    is loaded as a named adapter and recorded in ``lora_adapters``. A LoRA
    that fails to load is logged and skipped, not fatal.

    The adapters are deliberately NOT fused into the base weights: fusing
    merges every loaded LoRA permanently, which would make per-request
    adapter selection and strength changes ineffective.

    Returns:
        The initialised ``WanPipeline`` moved to ``device``.

    Raises:
        Whatever ``from_pretrained`` / ``.to`` raise if the base model cannot
        be loaded. In that case the global ``pipe`` stays ``None`` so the
        next call retries instead of returning a half-built pipeline.
    """
    global pipe
    if pipe is not None:
        return pipe

    print("Loading Wan2.2-TI2V-5B base model...")
    # The VAE is loaded in float32 while the rest of the pipeline uses `dtype`.
    vae = AutoencoderKLWan.from_pretrained(
        MODEL_ID,
        subfolder="vae",
        torch_dtype=torch.float32,
    )
    # Build into a local first; the global is published only once everything
    # below has succeeded.
    new_pipe = WanPipeline.from_pretrained(
        MODEL_ID,
        vae=vae,
        torch_dtype=dtype,
    )
    new_pipe.to(device)
    print("Base model loaded.")

    print("Pre-loading LoRAs...")
    for lora in AVAILABLE_LORAS:
        name = lora["name"]
        try:
            print(f"  → {name}")
            new_pipe.load_lora_weights(
                lora["repo_id"],
                weight_name=lora["filename"],
                adapter_name=name,
            )
        except Exception as e:
            # Best effort: one incompatible or unavailable LoRA must not
            # prevent the app from starting.
            print(f"  Failed to load {name}: {e}")
            continue
        lora_adapters[name] = {
            "adapter_name": name,
            "strength": lora["default_strength"],
        }

    pipe = new_pipe
    print("Pipeline fully initialized.")
    return pipe
def generate_video(
    prompt: str,
    image: Image.Image = None,
    width: int = 1280,
    height: int = 704,
    num_frames: int = 73,
    num_inference_steps: int = 35,
    guidance_scale: float = 5.0,
    seed: int = -1,
    enabled_loras: list = None,
    lora_strength_multiplier: float = 1.0,
    progress=gr.Progress()
):
    """Generate a video from a prompt (and optional image) and save it as MP4.

    Args:
        prompt: Text description of the video.
        image: Optional conditioning image; forwarded to the pipeline as
            ``image`` when provided.
        width: Output width in pixels.
        height: Output height in pixels.
        num_frames: Number of frames to generate.
        num_inference_steps: Denoising steps. Forced to 4 when the Lightning
            LoRA is active and more than 8 steps were requested.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: RNG seed; ``-1`` (or an empty value) picks a random seed.
        enabled_loras: Display names of the LoRAs to activate. Names that
            were not loaded successfully are ignored.
        lora_strength_multiplier: Factor applied to each LoRA's default
            strength.
        progress: Gradio progress tracker.

    Returns:
        ``(video_path, status_message)`` on success, or ``(None, error_text)``
        if anything fails. Errors are reported to the UI, never raised.
    """
    try:
        pipeline = initialize_pipeline()

        # UI widgets may deliver numbers as floats (or None for a cleared
        # seed box); the pipeline and torch.Generator need real ints.
        width = int(width)
        height = int(height)
        num_frames = int(num_frames)
        num_inference_steps = int(num_inference_steps)
        seed = -1 if seed is None else int(seed)

        # Keep only the requested LoRAs that actually loaded, and scale each
        # one's default strength by the global multiplier.
        enabled = enabled_loras or []
        active_adapters = [name for name in enabled if name in lora_adapters]
        active_strengths = [
            lora_adapters[name]["strength"] * lora_strength_multiplier
            for name in active_adapters
        ]

        if active_adapters:
            print(f"Activating LoRAs: {active_adapters} with strengths {active_strengths}")
            # A previous request may have called disable_lora(); re-enable the
            # LoRA layers before choosing which adapters are active.
            pipeline.enable_lora()
            pipeline.set_adapters(active_adapters, adapter_weights=active_strengths)
        else:
            print("No LoRAs enabled → disabling LoRA")
            try:
                pipeline.disable_lora()
            except Exception as e:
                # Best effort (e.g. no adapter was ever loaded); keep going
                # with the base model, but leave a trace in the log.
                print(f"Could not disable LoRA layers (continuing): {e}")

        # Only drop to 4 steps when the Lightning adapter is really active,
        # not merely requested.
        if "Lightning (Fast 4-step)" in active_adapters and num_inference_steps > 8:
            num_inference_steps = 4
            print("Lightning LoRA → reduced to 4 steps")

        if seed == -1:
            seed = torch.randint(0, 2**32 - 1, (1,)).item()
        generator = torch.Generator(device=device).manual_seed(seed)

        gen_params = {
            "prompt": prompt,
            "height": height,
            "width": width,
            "num_frames": num_frames,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
        }
        if image is not None:
            # NOTE(review): confirm the loaded pipeline class accepts an
            # `image` argument; if it does not, this call fails and the error
            # is reported through the status box below.
            gen_params["image"] = image

        print(f"Generating: {width}x{height}, {num_frames} frames, steps={num_inference_steps}")
        progress(0, desc="Starting generation...")
        output = pipeline(**gen_params).frames[0]

        output_path = "output.mp4"
        export_to_video(output, output_path, fps=24)

        status = f"Done! Seed: {seed}"
        if active_adapters:
            status += f"\nLoRAs: {', '.join(active_adapters)} @ {lora_strength_multiplier:.2f}x"
        return output_path, status
    except Exception as e:
        # Top-level UI boundary: surface the failure as a status message.
        msg = f"Error: {str(e)}"
        print(msg)
        return None, msg
# ────────────────────────────────────────────────
# Gradio UI
# ────────────────────────────────────────────────
# Build the Gradio interface: inputs on the left, generated video and status
# on the right, followed by clickable examples and usage notes.
with gr.Blocks(title="Wan2.2 Video + Fast LoRA") as demo:
    gr.Markdown(
        "# Wan2.2-TI2V-5B Video Generation\n"
        "**Text-to-Video & Image-to-Video** with optimized LoRA loading.\n"
    )

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                lines=3,
                value="Two anthropomorphic cats in comfy boxing gear fight on stage",
            )
            image_input = gr.Image(
                label="Input Image (optional for I2V)",
                type="pil",
                sources=["upload"],
            )

            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    width_input = gr.Slider(512, 1920, step=64, value=1280, label="Width")
                    # Maximum kept on the 64-px grid (512 + 8*64); 1080 was
                    # never reachable with this step.
                    height_input = gr.Slider(512, 1024, step=64, value=704, label="Height")
                num_frames_input = gr.Slider(25, 145, step=24, value=73, label="Frames")
                num_steps_input = gr.Slider(
                    4, 60, step=1, value=4,
                    label="Inference Steps",
                    info="Lightning LoRA → try 4–8 steps",
                )
                # `step` must be given by keyword: the third positional
                # parameter of gr.Slider is `value`, which would clash with
                # value=5.0.
                guidance_scale_input = gr.Slider(
                    1.0, 15.0, step=1.0, value=5.0, label="Guidance Scale"
                )
                seed_input = gr.Number(label="Seed (-1 = random)", value=-1, precision=0)

            with gr.Accordion("LoRA Controls", open=True):
                lora_checkbox = gr.CheckboxGroup(
                    choices=[lora["name"] for lora in AVAILABLE_LORAS],
                    label="Enable LoRAs",
                    value=[],
                )
                lora_strength = gr.Slider(
                    0.1, 1.5, step=0.05, value=1.0,
                    label="Global Strength Multiplier",
                )

            generate_btn = gr.Button("Generate Video", variant="primary", size="lg")

        with gr.Column():
            video_output = gr.Video(label="Generated Video", autoplay=True)
            status_output = gr.Textbox(label="Status", lines=3)

    # Examples with LoRA usage
    gr.Examples(
        examples=[
            ["Two anthropomorphic cats in comfy boxing gear fight on stage", None, 1280, 704, 73, 35, 5.0, 42, [], 1.0],
            ["A serene underwater scene with colorful coral reefs...", None, 1280, 704, 73, 4, 5.0, 123, ["Lightning (Fast 4-step)"], 1.0],
            ["Explicit adult scene, detailed", None, 1280, 704, 73, 30, 6.0, 999, ["General NSFW"], 0.9],
        ],
        inputs=[prompt_input, image_input, width_input, height_input, num_frames_input,
                num_steps_input, guidance_scale_input, seed_input, lora_checkbox, lora_strength],
        outputs=[video_output, status_output],
        fn=generate_video,
        cache_examples=False,
    )

    generate_btn.click(
        generate_video,
        inputs=[prompt_input, image_input, width_input, height_input, num_frames_input,
                num_steps_input, guidance_scale_input, seed_input, lora_checkbox, lora_strength],
        outputs=[video_output, status_output],
    )

    gr.Markdown(
        "## Performance Notes\n"
        "- LoRAs are **pre-loaded once** → first generation may take ~10–30s longer, later ones are fast.\n"
        "- Lightning LoRA: use **4–8 steps** → generation can finish in <60s.\n"
        "- Add new LoRAs by appending to `AVAILABLE_LORAS` → they auto-load at startup.\n"
    )
# Script entry point: enable request queuing (at most 20 waiting jobs) and
# start the Gradio server.
if __name__ == "__main__":
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()