# wan / app.py
# alphabagibagi's picture
# Update app.py
# 8f579e1 verified
import os
import sys
import subprocess
import time
import random
import asyncio
import threading
import io
import shutil
import numpy as np
from PIL import Image
import gradio as gr
import torch
# --- Configuration & Paths ---
# All paths follow ComfyUI's standard layout (models/unet, models/clip, ...)
# so the stock loader nodes can find the files without extra configuration.
ROOT_DIR = os.path.abspath(os.getcwd())
COMFYUI_DIR = os.path.join(ROOT_DIR, "ComfyUI")
# Make `import nodes` / `import comfy` resolvable once ComfyUI is cloned.
sys.path.append(COMFYUI_DIR)
MODELS_DIR = os.path.join(COMFYUI_DIR, "models")
UNET_DIR = os.path.join(MODELS_DIR, "unet")
CLIP_DIR = os.path.join(MODELS_DIR, "clip")
VAE_DIR = os.path.join(MODELS_DIR, "vae")
# LoRAs live in a "FusionX" subfolder; load_models() uses that relative path.
LORA_DIR = os.path.join(MODELS_DIR, "loras", "FusionX")
CUSTOM_NODES_DIR = os.path.join(COMFYUI_DIR, "custom_nodes")
GGUF_NODE_DIR = os.path.join(CUSTOM_NODES_DIR, "ComfyUI-GGUF")
# --- Model URLs ---
# Quantized (Q3_K_S) Wan2.2 T2V diffusion model in GGUF format.
URL_UNET = "https://huggingface.co/QuantStack/Wan2.2-T2V-A14B-GGUF/resolve/main/LowNoise/Wan2.2-T2V-A14B-LowNoise-Q3_K_S.gguf"
FILENAME_UNET = "Wan2.2-T2V-A14B-LowNoise-Q3_K_S.gguf"
# Quantized UMT5-XXL text encoder used as the CLIP stage.
URL_CLIP = "https://huggingface.co/city96/umt5-xxl-encoder-gguf/resolve/main/umt5-xxl-encoder-Q3_K_S.gguf"
FILENAME_CLIP = "umt5-xxl-encoder-Q3_K_S.gguf"
# Wan 2.1 VAE (safetensors, not quantized).
URL_VAE = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/vae/wan_2.1_vae.safetensors"
FILENAME_VAE = "wan_2.1_vae.safetensors"
# FusionX LoRA applied on top of the UNet (model-only, see load_models()).
URL_LORA = "https://huggingface.co/vrgamedevgirl84/Wan14BT2VFusioniX/resolve/main/FusionX_LoRa/Wan2.1_T2V_14B_FusionX_LoRA.safetensors"
FILENAME_LORA = "Wan2.1_T2V_14B_FusionX_LoRA.safetensors"
# --- Setup Functions ---
def run_command(command, desc=None):
    """Run a shell command, raising CalledProcessError on a non-zero exit.

    Args:
        command: Full shell command line. Executed with shell=True, so only
            trusted, programmatically-built strings should be passed here.
        desc: Optional human-readable description used for progress logging.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero.
    """
    if desc:
        print(f"➜ {desc}...")
    try:
        subprocess.run(command, check=True, shell=True)
    except subprocess.CalledProcessError as e:
        # Fall back to the command itself so the log never reads "during None".
        print(f"❌ Error during {desc or command}: {e}")
        raise
def robust_download(url, dest_dir, filename):
    """Download `url` to `dest_dir/filename`, skipping if it already exists.

    Tries three methods in order of speed/reliability:
      1. aria2c (multi-connection, resumable),
      2. huggingface_hub's hf_hub_download (handles HF auth/CDN),
      3. a plain streamed `requests` download.

    Args:
        url: Hugging Face `resolve/main` URL of the file.
        dest_dir: Existing local directory to place the file in.
        filename: Local filename to save as (expected to match the remote name).
    """
    dest_path = os.path.join(dest_dir, filename)
    if os.path.exists(dest_path):
        print(f"βœ… {filename} already exists.")
        return
    print(f"⬇️ Downloading {filename}...")
    # Method 1: aria2c (fastest). Probe for the binary first so a missing
    # aria2c falls through cleanly instead of raising mid-download.
    try:
        subprocess.run(["aria2c", "--version"], check=True,
                       stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        run_command(
            f"aria2c --console-log-level=error -c -x 16 -s 16 -k 1M "
            f"{url} -d {dest_dir} -o {filename}",
            f"Downloading {filename} (aria2c)",
        )
        return
    except (subprocess.CalledProcessError, FileNotFoundError):
        print("⚠️ aria2c not found or failed, falling back to huggingface_hub...")
    # Method 2: huggingface_hub (reliable)
    try:
        from huggingface_hub import hf_hub_download
        # URL format: https://huggingface.co/USER/REPO/resolve/main/PATH/TO/FILE
        parts = url.replace("https://huggingface.co/", "").split("/resolve/main/")
        if len(parts) == 2:
            repo_id, repo_path = parts
            print(f"⏳ Downloading via HF Hub: {repo_id}/{os.path.basename(repo_path)}")
            # `filename` is the path relative to the repo root; local_dir
            # preserves the repo's directory structure under dest_dir.
            hf_hub_download(
                repo_id=repo_id,
                filename=repo_path,
                local_dir=dest_dir,
                local_dir_use_symlinks=False,
            )
            # Re-verification is handled by the existence check on next run.
            return
        # URL did not match the expected HF pattern: fall through to Method 3
        # (the original code silently returned here without downloading).
    except Exception as e:
        print(f"❌ Fallback download failed: {e}")
    # Method 3: Requests (slowest fallback), streamed to avoid loading
    # multi-GB checkpoints into memory.
    import requests
    print("⚠️ Trying simple requests download...")
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        with open(dest_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
def setup_environment():
    """Prepare the runtime: clone ComfyUI and the GGUF custom node, create
    the model directories, and fetch every required checkpoint (idempotent).
    """
    print("πŸš€ Starting Setup Environment...")
    # 1/2. Clone the required repositories, skipping any already present.
    repos = (
        ("https://github.com/comfyanonymous/ComfyUI", COMFYUI_DIR, "ComfyUI"),
        ("https://github.com/city96/ComfyUI-GGUF", GGUF_NODE_DIR, "ComfyUI-GGUF"),
    )
    for repo_url, target_dir, label in repos:
        if os.path.exists(target_dir):
            print(f"βœ… {label} found at {target_dir}")
        else:
            run_command(f"git clone {repo_url} {target_dir}", f"Cloning {label}")
    # 3. Ensure every model directory exists.
    for model_dir in (UNET_DIR, CLIP_DIR, VAE_DIR, LORA_DIR):
        os.makedirs(model_dir, exist_ok=True)
    # 4. Fetch each checkpoint (robust_download skips files already on disk).
    downloads = (
        (URL_UNET, UNET_DIR, FILENAME_UNET),
        (URL_CLIP, CLIP_DIR, FILENAME_CLIP),
        (URL_VAE, VAE_DIR, FILENAME_VAE),
        (URL_LORA, LORA_DIR, FILENAME_LORA),
    )
    for model_url, target_dir, model_filename in downloads:
        robust_download(model_url, target_dir, model_filename)
    print("πŸŽ‰ Environment Setup Complete!")
# Run setup immediately (at import time) so the models and the ComfyUI
# package exist before anything below tries to import from them.
setup_environment()
# --- ComfyUI Imports ---
# Configure Execution Arguments for ComfyUI
# Aggressively force CPU if CUDA is not available or if we want to ensure no crashes on CPU Spaces
try:
    if not torch.cuda.is_available():
        print("⚠️ CUDA not available, forcing CPU mode for ComfyUI...")
        # 1. Force environment variable so no downstream library grabs a GPU.
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        # 2. Inject --cpu argument (ComfyUI parses sys.argv at import time).
        if "--cpu" not in sys.argv:
            sys.argv.append("--cpu")
        # 3. Monkeypatch torch.cuda to ensure ComfyUI doesn't try to initialize CUDA
        # This is necessary because some ComfyUI versions checks might be aggressive
        torch.cuda.is_available = lambda: False
        torch.cuda.device_count = lambda: 0
        # NOTE(review): the real torch API returns an int here; returning None
        # could surprise any caller that does arithmetic on it — confirm harmless.
        torch.cuda.current_device = lambda: None
        print("βœ… Applied CPU enforcement patches.")
except Exception as e:
    print(f"⚠️ Error applying CPU patches: {e}")
# These must happen AFTER setup because ComfyUI folder might not exist before
try:
    import nodes
    import comfy.samplers
    from nodes import NODE_CLASS_MAPPINGS, KSamplerAdvanced, VAEDecode, CLIPTextEncode, EmptyLatentImage, VAELoader, LoraLoaderModelOnly
    from comfy_extras.nodes_model_advanced import ModelSamplingSD3
except ImportError as e:
    print("⚠️ Error importing ComfyUI nodes (expected during first build if imports happen too early):", e)
    # This might happen if sys.path.append didn't catch up or folder structured differently
    # But usually works if we just ran setup.
# --- Global Models ---
class ModelContainer:
    """Holds the lazily-loaded ComfyUI model objects shared by all requests."""

    # Attributes populated by load_models(); each stays None until then.
    _SLOTS = ("unet", "clip", "vae", "lora")

    def __init__(self):
        for attr in self._SLOTS:
            setattr(self, attr, None)
        # Flipped to True only after every model loads successfully.
        self.loaded = False

model_container = ModelContainer()
def load_models():
    """Load UNet, CLIP, VAE and LoRA into the global model_container (once).

    Idempotent: returns immediately if models are already loaded. Errors are
    caught and logged, leaving model_container.loaded False so callers can
    detect the failure.
    """
    if model_container.loaded:
        return
    print("⏳ Loading Models into Memory...")
    try:
        # Register custom nodes FIRST: the GGUF loader classes live in the
        # ComfyUI-GGUF custom node and are only inserted into
        # NODE_CLASS_MAPPINGS by init_custom_nodes(); looking them up before
        # registration would raise KeyError.
        # NOTE(review): newer ComfyUI revisions renamed this helper
        # (init_extra_nodes) — confirm against the pinned ComfyUI commit.
        from nodes import init_custom_nodes
        init_custom_nodes()
        # Instantiate loader nodes. Loader paths are relative to ComfyUI's
        # models/<type> directories, which is where robust_download() put
        # the files.
        unet_loader = NODE_CLASS_MAPPINGS["UnetLoaderGGUF"]()
        clip_loader = NODE_CLASS_MAPPINGS["CLIPLoaderGGUF"]()
        vae_loader = VAELoader()
        lora_loader = LoraLoaderModelOnly()
        # Load the quantized UNet and text encoder ("wan" CLIP type).
        model_container.unet = unet_loader.load_unet(FILENAME_UNET)[0]
        model_container.clip = clip_loader.load_clip(FILENAME_CLIP, "wan")[0]
        model_container.vae = vae_loader.load_vae(FILENAME_VAE)[0]
        # Apply the FusionX LoRA to the model only (CLIP untouched), at full
        # strength; the path is relative to models/loras.
        lora_rel_path = f"FusionX/{FILENAME_LORA}"
        model_container.lora = lora_loader.load_lora_model_only(model_container.unet, lora_rel_path, 1.0)[0]
        model_container.loaded = True
        print("βœ… All Models Loaded Successfully!")
    except Exception as e:
        print(f"❌ Error Loading Models: {e}")
        import traceback
        traceback.print_exc()
# --- Generation Function ---
def generate(prompt, negative_prompt, width, height, steps, cfg, sampler_name, scheduler_name, seed):
    """Run the full text-to-image pipeline for one request.

    Args:
        prompt / negative_prompt: Text conditioning strings.
        width, height: Output resolution in pixels.
        steps, cfg: Sampler step count and classifier-free guidance scale.
        sampler_name, scheduler_name: ComfyUI sampler/scheduler identifiers.
        seed: Noise seed; -1 picks a random 64-bit seed.

    Returns:
        (PIL.Image.Image, str): the generated image and a "Seed: N" label.

    Raises:
        gr.Error: if the models are unavailable or generation fails.
    """
    if not model_container.loaded:
        load_models()
    if not model_container.loaded:
        # load_models() swallows its own exceptions; fail fast with a clear
        # message instead of crashing later with AttributeError on None models.
        raise gr.Error("Models failed to load. Check the Space logs for details.")
    if seed == -1:
        seed = random.randint(0, 2**64 - 1)
    print(f"🎨 Generating: {width}x{height}, Steps: {steps}, CFG: {cfg}, Seed: {seed}")
    try:
        # Instantiate node objects per run (they are cheap, stateless wrappers).
        clip_text_encode = CLIPTextEncode()
        empty_latent_image = EmptyLatentImage()
        k_sampler_advanced = KSamplerAdvanced()
        vae_decode = VAEDecode()
        model_sampler_patcher = ModelSamplingSD3()
        with torch.inference_mode():
            # Encode prompts into conditioning tensors.
            positive_cond = clip_text_encode.encode(model_container.clip, prompt)[0]
            negative_cond = clip_text_encode.encode(model_container.clip, negative_prompt)[0]
            # model_container.lora is the UNet with the FusionX LoRA already
            # applied (load_lora_model_only returns the patched MODEL).
            model_with_sampler = model_sampler_patcher.patch(model_container.lora, 1.0)[0]
            # Empty latent, batch size 1.
            latent_image = empty_latent_image.generate(width, height, 1)[0]
            # Sample the full schedule (start 0 to end, no leftover noise).
            samples = k_sampler_advanced.sample(
                model=model_with_sampler,
                add_noise="enable",
                noise_seed=int(seed),
                steps=int(steps),
                cfg=float(cfg),
                sampler_name=sampler_name,
                scheduler=scheduler_name,
                positive=positive_cond,
                negative=negative_cond,
                latent_image=latent_image,
                start_at_step=0,
                end_at_step=9999,
                return_with_leftover_noise="disable"
            )[0]
            # Decode latents to an image tensor.
            decoded = vae_decode.decode(model_container.vae, samples)[0]
            # Convert to PIL.
            # NOTE(review): assumes decoded is [B, H, W, C] floats in 0..1 —
            # confirm against VAEDecode for this ComfyUI revision.
            image_np = decoded.cpu().numpy()
            image_np_uint8 = (image_np.clip(0, 1) * 255).astype(np.uint8)
            final_image = Image.fromarray(image_np_uint8[0])
            return final_image, f"Seed: {seed}"
    except gr.Error:
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise gr.Error(f"Generation Failed: {str(e)}")
# --- Interface Options ---
# Sampler names must match ComfyUI's registered sampler identifiers exactly;
# they are passed straight through to KSamplerAdvanced.sample().
SAMPLERS = [
    "euler", "euler_ancestral", "heun", "heunpp2", "dpm_2", "dpm_2_ancestral",
    "lcm", "dpmpp_2s_ancestral", "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_3m_sde",
    "ddim", "uni_pc", "uni_pc_bh2"
]
# Noise schedule names, likewise forwarded verbatim to the sampler node.
SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
# --- Gradio UI ---
# Two-column layout: inputs/controls on the left, output image on the right.
with gr.Blocks(title="Wan2.1 T2I GGUF", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 Wan2.1 Text-to-Image (GGUF)")
    gr.Markdown("Generating high-quality images using Wan2.1 14B (Quantized) via ComfyUI backend.")
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Positive Prompt", placeholder="A cinematic photo of...", lines=3)
            negative_prompt = gr.Textbox(label="Negative Prompt", value="blurry, low quality, static, frame, text, watermark, nsfw", lines=2)
            with gr.Accordion("Advanced Settings", open=True):
                with gr.Row():
                    # step=64 keeps dimensions latent-friendly (multiples of 64).
                    width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=832)
                    height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1216)
                with gr.Row():
                    steps = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=20)
                    cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=20.0, step=0.5, value=7.5)
                with gr.Row():
                    sampler = gr.Dropdown(label="Sampler", choices=SAMPLERS, value="dpmpp_2m")
                    scheduler = gr.Dropdown(label="Scheduler", choices=SCHEDULERS, value="karras")
                # precision=0 forces an integer seed; -1 means "randomize".
                seed = gr.Number(label="Seed", value=-1, precision=0, info="-1 for random")
            generate_btn = gr.Button("πŸš€ Generate", variant="primary", size="lg")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Generated Image", type="pil")
            output_seed = gr.Label(label="Seed Information")
    # Wire the button to generate(); argument order must match its signature.
    generate_btn.click(
        fn=generate,
        inputs=[prompt, negative_prompt, width, height, steps, cfg, sampler, scheduler, seed],
        outputs=[output_image, output_seed]
    )
# Pre-load models on app startup if desired, or wait for first request
# threading.Thread(target=load_models).start()
if __name__ == "__main__":
    # 0.0.0.0:7860 is the standard Hugging Face Spaces binding.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)