# wannsfw / app.py
# devindevine's picture
# Upload 6 files
# a482b15 verified
# IMPORTANT: spaces must be imported first to avoid CUDA initialization issues
import spaces
import os
import numpy as np
from PIL import Image
import gradio as gr
import torch
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
# ────────────────────────────────────────────────
# Model + LoRA configuration
# ────────────────────────────────────────────────
# Hub repo id of the base checkpoint; used for both the VAE and the pipeline
# in initialize_pipeline().
MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
# dtype passed to WanPipeline.from_pretrained (the VAE is loaded separately in float32).
dtype = torch.bfloat16
# Target device for the pipeline and for the seeded torch.Generator.
device = "cuda" if torch.cuda.is_available() else "cpu"
# LoRAs offered in the UI. Each entry provides:
#   name             - display name; also used as the adapter name when loading
#   repo_id          - Hub repository that hosts the weights
#   filename         - weight file inside that repository
#   default_strength - base strength, multiplied by the UI's global multiplier
AVAILABLE_LORAS = [
{
"name": "Lightning (Fast 4-step)",
"repo_id": "lightx2v/Wan2.2-Distill-Loras",
# NOTE(review): this filename refers to an "i2v_A14b" checkpoint while MODEL_ID
# is the TI2V-5B model — confirm the weights are compatible with this base.
"filename": "wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors",
"default_strength": 1.0,
},
{
"name": "General NSFW",
"repo_id": "lopi999/Wan2.2-I2V_General-NSFW-LoRA",
"filename": "pytorch_lora_weights.safetensors",
"default_strength": 0.8,
},
# Add more LoRAs here — they will be pre-loaded automatically
]
# Global pipeline + pre-loaded adapter info
# `pipe` stays None until initialize_pipeline() has run once.
pipe = None
# Filled by initialize_pipeline() for every LoRA that loads successfully.
lora_adapters = {} # name → {"adapter_name": str, "strength": float}
def initialize_pipeline():
    """Build the shared Wan2.2 pipeline once and pre-load every configured LoRA.

    The first call loads the VAE (float32) and the pipeline (``dtype``) from
    ``MODEL_ID``, moves the pipeline to ``device`` and loads each entry of
    ``AVAILABLE_LORAS`` as a named adapter. Adapters that load successfully
    are recorded in ``lora_adapters``; failures are logged and skipped.
    Later calls return the cached pipeline.

    The adapters are deliberately NOT fused into the base weights: callers
    select adapters and strengths per request via ``set_adapters`` /
    ``disable_lora``, which only works while the LoRA layers stay separate.

    Returns:
        The initialized pipeline (also stored in the module-level ``pipe``).
    """
    global pipe
    if pipe is not None:
        return pipe

    print("Loading Wan2.2-TI2V-5B base model...")
    vae = AutoencoderKLWan.from_pretrained(
        MODEL_ID,
        subfolder="vae",
        torch_dtype=torch.float32,
    )
    # Build into a local first so a failure below never leaves a
    # half-initialized pipeline cached in the global.
    new_pipe = WanPipeline.from_pretrained(
        MODEL_ID,
        vae=vae,
        torch_dtype=dtype,
    )
    new_pipe.to(device)
    print("Base model loaded.")

    print("Pre-loading LoRAs...")
    for lora in AVAILABLE_LORAS:
        name = lora["name"]
        try:
            print(f" → {name}")
            new_pipe.load_lora_weights(
                lora["repo_id"],
                weight_name=lora["filename"],
                adapter_name=name,
            )
        except Exception as e:
            # Best effort: one broken LoRA must not take the whole app down.
            print(f" Failed to load {name}: {e}")
            continue
        lora_adapters[name] = {
            "adapter_name": name,
            "strength": lora["default_strength"],
        }

    if lora_adapters:
        print(f"Loaded {len(lora_adapters)} LoRA adapter(s); kept unfused for per-request selection.")

    pipe = new_pipe
    print("Pipeline fully initialized.")
    return pipe
@spaces.GPU(duration=180)
def generate_video(
    prompt: str,
    image: Image.Image = None,
    width: int = 1280,
    height: int = 704,
    num_frames: int = 73,
    num_inference_steps: int = 35,
    guidance_scale: float = 5.0,
    seed: int = -1,
    enabled_loras: list = None,
    lora_strength_multiplier: float = 1.0,
    progress=gr.Progress()
):
    """Generate a video with the shared pipeline and export it as an MP4.

    Args:
        prompt: Text prompt passed to the pipeline.
        image: Optional conditioning image; forwarded as ``image`` when given.
        width: Output width in pixels.
        height: Output height in pixels.
        num_frames: Number of frames to generate.
        num_inference_steps: Denoising steps. Forced to 4 when the Lightning
            LoRA is active and more than 8 steps were requested.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: RNG seed; ``-1`` (or ``None``) picks a random seed.
        enabled_loras: Display names of LoRAs to activate. Names that were
            not successfully pre-loaded are ignored.
        lora_strength_multiplier: Factor applied to each LoRA's default strength.
        progress: Gradio progress tracker.

    Returns:
        ``(video_path, status)`` on success, ``(None, error_message)`` on failure.
    """
    # Local import keeps this block self-contained (file-level imports untouched).
    import tempfile

    try:
        pipeline = initialize_pipeline()

        # Keep only the requested LoRAs that actually loaded.
        active_names = [n for n in (enabled_loras or []) if n in lora_adapters]
        active_adapters = [lora_adapters[n]["adapter_name"] for n in active_names]
        active_strengths = [
            lora_adapters[n]["strength"] * lora_strength_multiplier
            for n in active_names
        ]

        if active_adapters:
            print(f"Activating LoRAs: {active_adapters} with strengths {active_strengths}")
            # A previous request may have called disable_lora(); re-enable the
            # LoRA layers before choosing which adapters are active.
            pipeline.enable_lora()
            pipeline.set_adapters(active_adapters, adapter_weights=active_strengths)
        else:
            print("No LoRAs enabled → disabling LoRA")
            try:
                pipeline.disable_lora()
            except Exception as e:
                # Best effort (e.g. nothing was loaded); log instead of hiding it.
                print(f"Could not disable LoRA (continuing): {e}")

        # Only shorten the schedule when the Lightning adapter is really active.
        if "Lightning (Fast 4-step)" in active_names and num_inference_steps > 8:
            num_inference_steps = 4
            print("Lightning LoRA → reduced to 4 steps")

        if seed is None or int(seed) == -1:
            seed = torch.randint(0, 2**32 - 1, (1,)).item()
        seed = int(seed)
        generator = torch.Generator(device=device).manual_seed(seed)

        # UI sliders may deliver numbers as floats; the pipeline expects ints.
        width = int(width)
        height = int(height)
        num_frames = int(num_frames)
        num_inference_steps = int(num_inference_steps)

        gen_params = {
            "prompt": prompt,
            "height": height,
            "width": width,
            "num_frames": num_frames,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
        }
        if image is not None:
            # NOTE(review): the pipeline is a WanPipeline; confirm it accepts an
            # `image` kwarg — image-to-video may need WanImageToVideoPipeline.
            gen_params["image"] = image

        print(f"Generating: {width}x{height}, {num_frames} frames, steps={num_inference_steps}")
        progress(0, desc="Starting generation...")
        output = pipeline(**gen_params).frames[0]

        # Unique file per request so concurrent generations cannot overwrite
        # each other's result.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            output_path = tmp.name
        export_to_video(output, output_path, fps=24)

        status = f"Done! Seed: {seed}"
        if active_names:
            status += f"\nLoRAs: {', '.join(active_names)} @ {lora_strength_multiplier:.2f}x"
        return output_path, status
    except Exception as e:
        # Top-level boundary: report the failure in the status box instead of crashing.
        msg = f"Error: {str(e)}"
        print(msg)
        return None, msg
# ────────────────────────────────────────────────
# Gradio UI
# ────────────────────────────────────────────────
# Layout: inputs and LoRA controls in the left column, video and status in the
# right column, followed by clickable examples and usage notes.
with gr.Blocks(title="Wan2.2 Video + Fast LoRA") as demo:
    gr.Markdown("""
    # Wan2.2-TI2V-5B Video Generation
    **Text-to-Video & Image-to-Video** with optimized LoRA loading.
    """)
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt", lines=3,
                value="Two anthropomorphic cats in comfy boxing gear fight on stage"
            )
            image_input = gr.Image(label="Input Image (optional for I2V)", type="pil", sources=["upload"])
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    width_input = gr.Slider(512, 1920, step=64, value=1280, label="Width")
                    height_input = gr.Slider(512, 1080, step=64, value=704, label="Height")
                num_frames_input = gr.Slider(25, 145, step=24, value=73, label="Frames")
                num_steps_input = gr.Slider(4, 60, step=1, value=4, label="Inference Steps",
                                            info="Lightning LoRA → try 4–8 steps")
                # `step` must be a keyword: the third positional argument of
                # gr.Slider is `value`, which would clash with value=5.0.
                guidance_scale_input = gr.Slider(1.0, 15.0, step=1.0, value=5.0, label="Guidance Scale")
                seed_input = gr.Number(label="Seed (-1 = random)", value=-1, precision=0)
            with gr.Accordion("LoRA Controls", open=True):
                lora_checkbox = gr.CheckboxGroup(
                    choices=[l["name"] for l in AVAILABLE_LORAS],
                    label="Enable LoRAs",
                    value=[]
                )
                lora_strength = gr.Slider(0.1, 1.5, step=0.05, value=1.0,
                                          label="Global Strength Multiplier")
            generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
        with gr.Column():
            video_output = gr.Video(label="Generated Video", autoplay=True)
            status_output = gr.Textbox(label="Status", lines=3)
    # Examples with LoRA usage (column order matches `inputs` below)
    gr.Examples(
        examples=[
            ["Two anthropomorphic cats in comfy boxing gear fight on stage", None, 1280, 704, 73, 35, 5.0, 42, [], 1.0],
            ["A serene underwater scene with colorful coral reefs...", None, 1280, 704, 73, 4, 5.0, 123, ["Lightning (Fast 4-step)"], 1.0],
            ["Explicit adult scene, detailed", None, 1280, 704, 73, 30, 6.0, 999, ["General NSFW"], 0.9],
        ],
        inputs=[prompt_input, image_input, width_input, height_input, num_frames_input,
                num_steps_input, guidance_scale_input, seed_input, lora_checkbox, lora_strength],
        outputs=[video_output, status_output],
        fn=generate_video,
        cache_examples=False,
    )
    # The input order here must match generate_video's positional parameters.
    generate_btn.click(
        generate_video,
        inputs=[prompt_input, image_input, width_input, height_input, num_frames_input,
                num_steps_input, guidance_scale_input, seed_input, lora_checkbox, lora_strength],
        outputs=[video_output, status_output]
    )
    gr.Markdown("""
    ## Performance Notes
    - LoRAs are **pre-loaded once** → first generation may take ~10–30s longer, later ones are fast.
    - Lightning LoRA: use **4–8 steps** → generation can finish in <60s.
    - Add new LoRAs by appending to `AVAILABLE_LORAS` — they auto-load at startup.
    """)
if __name__ == "__main__":
    # Script entry point: configure the request queue (max_size=20), then launch.
    queued_app = demo.queue(max_size=20)
    queued_app.launch()