Spaces:
Runtime error
Runtime error
| import os | |
| import sys | |
| import subprocess | |
| import time | |
| import random | |
| import asyncio | |
| import threading | |
| import io | |
| import shutil | |
| import numpy as np | |
| from PIL import Image | |
| import gradio as gr | |
| import torch | |
# --- Configuration & Paths ---
# Everything is resolved relative to the directory the process was started in.
ROOT_DIR = os.path.abspath(os.getcwd())
COMFYUI_DIR = os.path.join(ROOT_DIR, "ComfyUI")
# Put the ComfyUI checkout on the import path so its top-level modules can be
# imported further down in this file.
sys.path.append(COMFYUI_DIR)

MODELS_DIR = os.path.join(COMFYUI_DIR, "models")
UNET_DIR, CLIP_DIR, VAE_DIR = (
    os.path.join(MODELS_DIR, model_kind) for model_kind in ("unet", "clip", "vae")
)
LORA_DIR = os.path.join(MODELS_DIR, "loras", "FusionX")

CUSTOM_NODES_DIR = os.path.join(COMFYUI_DIR, "custom_nodes")
GGUF_NODE_DIR = os.path.join(CUSTOM_NODES_DIR, "ComfyUI-GGUF")

# --- Model URLs ---
# Each URL ends with the filename the model is stored under locally.
FILENAME_UNET = "Wan2.2-T2V-A14B-LowNoise-Q3_K_S.gguf"
URL_UNET = (
    "https://huggingface.co/QuantStack/Wan2.2-T2V-A14B-GGUF/resolve/main/LowNoise/"
    + FILENAME_UNET
)

FILENAME_CLIP = "umt5-xxl-encoder-Q3_K_S.gguf"
URL_CLIP = (
    "https://huggingface.co/city96/umt5-xxl-encoder-gguf/resolve/main/"
    + FILENAME_CLIP
)

FILENAME_VAE = "wan_2.1_vae.safetensors"
URL_VAE = (
    "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/vae/"
    + FILENAME_VAE
)

FILENAME_LORA = "Wan2.1_T2V_14B_FusionX_LoRA.safetensors"
URL_LORA = (
    "https://huggingface.co/vrgamedevgirl84/Wan14BT2VFusioniX/resolve/main/FusionX_LoRa/"
    + FILENAME_LORA
)
| # --- Setup Functions --- | |
def run_command(command, desc=None):
    """Run an external command and raise if it exits with a non-zero status.

    Args:
        command: Either a shell command line (``str``), which is executed
            through the shell as before, or a sequence of arguments, which is
            executed directly without a shell so no quoting is required.
        desc: Optional human-readable label. It is printed before the command
            starts and used in the error message.

    Raises:
        subprocess.CalledProcessError: If the command exits non-zero. The
            error is printed first and then re-raised unchanged.
    """
    if desc:
        print(f"β {desc}...")
    # Only a plain string needs the shell; an argument list is run as-is.
    use_shell = isinstance(command, str)
    try:
        subprocess.run(command, check=True, shell=use_shell)
    except subprocess.CalledProcessError as e:
        # Fall back to the command itself so the message never reads
        # "Error during None" when no description was supplied.
        label = desc if desc else command
        print(f"β Error during {label}: {e}")
        raise
def _discard_partial_download(dest_path):
    """Delete a half-written ``dest_path`` and its aria2c control file, if present."""
    for leftover in (dest_path, dest_path + ".aria2"):
        if os.path.isfile(leftover):
            os.remove(leftover)


def robust_download(url, dest_dir, filename):
    """Ensure ``dest_dir/filename`` exists, downloading it from ``url`` if needed.

    Three strategies are tried in order: aria2c, ``huggingface_hub`` and a
    plain ``requests`` stream. Whichever succeeds, the file ends up at exactly
    ``dest_dir/filename``.

    Args:
        url: Source URL. The huggingface_hub strategy is only used for URLs of
            the form ``https://huggingface.co/<repo>/resolve/main/<path>``.
        dest_dir: Directory the file is stored in (created if missing).
        filename: Name of the file inside ``dest_dir``.

    Raises:
        Exception: Whatever the final ``requests`` strategy raises (import
            error, HTTP error, I/O error) when every strategy has failed.
    """
    import shlex

    dest_path = os.path.join(dest_dir, filename)
    # aria2c keeps a "<file>.aria2" control file next to an unfinished
    # download, so a file that still has one is not complete yet.
    if os.path.exists(dest_path) and not os.path.exists(dest_path + ".aria2"):
        print(f"β {filename} already exists.")
        return
    print(f"β¬οΈ Downloading {filename}...")
    os.makedirs(dest_dir, exist_ok=True)

    # Method 1: Try aria2c (fastest)
    try:
        # Check if aria2c is installed
        subprocess.run(
            ["aria2c", "--version"],
            check=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        # Quote every interpolated value: the command goes through a shell.
        run_command(
            "aria2c --console-log-level=error -c -x 16 -s 16 -k 1M "
            f"{shlex.quote(url)} -d {shlex.quote(dest_dir)} -o {shlex.quote(filename)}",
            f"Downloading {filename} (aria2c)",
        )
        return
    except (subprocess.CalledProcessError, FileNotFoundError):
        print("β οΈ aria2c not found or failed, falling back to huggingface_hub...")
        # Do not let a partial aria2c file pass for a finished download.
        _discard_partial_download(dest_path)

    # Method 2: huggingface_hub (reliable)
    try:
        from huggingface_hub import hf_hub_download

        # URL format: https://huggingface.co/USER/REPO/resolve/main/PATH/TO/FILE
        parts = url.replace("https://huggingface.co/", "").split("/resolve/main/")
        if len(parts) == 2:
            repo_id, repo_file = parts
            remote_filename = os.path.basename(repo_file)
            print(f"β³ Downloading via HF Hub: {repo_id}/{remote_filename}")
            downloaded_path = hf_hub_download(
                repo_id=repo_id,
                filename=repo_file,  # path of the file relative to the repo root
                local_dir=dest_dir,
                local_dir_use_symlinks=False,
            )
            # With local_dir the repo sub-folders are recreated below dest_dir,
            # so move the file to the flat location the rest of the app expects.
            if os.path.abspath(downloaded_path) != os.path.abspath(dest_path):
                shutil.move(downloaded_path, dest_path)
            return
    except Exception as e:
        # Deliberate best effort: any failure here falls through to requests.
        print(f"β Fallback download failed: {e}")

    # Method 3: Requests (slowest fallback)
    import requests

    print("β οΈ Trying simple requests download...")
    # Stream into a temporary name and rename at the end, so an interrupted
    # transfer never leaves a truncated file at dest_path.
    partial_path = dest_path + ".part"
    try:
        # (connect, read) timeouts keep a stalled server from hanging startup.
        with requests.get(url, stream=True, timeout=(10, 60)) as r:
            r.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
        os.replace(partial_path, dest_path)
    finally:
        if os.path.exists(partial_path):
            os.remove(partial_path)
    return
def setup_environment():
    """Prepare the on-disk environment: repositories, model folders and weights.

    Steps, in order:
      1. Clone ComfyUI and the ComfyUI-GGUF custom node unless their target
         directories already exist.
      2. Create the model directories.
      3. Download every model file via ``robust_download``.

    Raises:
        subprocess.CalledProcessError: If a ``git clone`` fails.
        Exception: Whatever ``robust_download`` raises for a failed download.
    """
    import shlex

    print("π Starting Setup Environment...")

    # 1 + 2. Clone ComfyUI first, then the custom node that lives inside it.
    repositories = (
        ("ComfyUI", "https://github.com/comfyanonymous/ComfyUI", COMFYUI_DIR),
        ("ComfyUI-GGUF", "https://github.com/city96/ComfyUI-GGUF", GGUF_NODE_DIR),
    )
    for label, repo_url, target_dir in repositories:
        if os.path.exists(target_dir):
            print(f"β {label} found at {target_dir}")
            continue
        # The command is run through a shell, so quote the target path in case
        # the working directory contains spaces or shell metacharacters.
        run_command(f"git clone {repo_url} {shlex.quote(target_dir)}", f"Cloning {label}")

    # 3. Create Directories
    for model_dir in (UNET_DIR, CLIP_DIR, VAE_DIR, LORA_DIR):
        os.makedirs(model_dir, exist_ok=True)

    # 4. Download Models
    downloads = (
        (URL_UNET, UNET_DIR, FILENAME_UNET),
        (URL_CLIP, CLIP_DIR, FILENAME_CLIP),
        (URL_VAE, VAE_DIR, FILENAME_VAE),
        (URL_LORA, LORA_DIR, FILENAME_LORA),
    )
    for url, dest_dir, filename in downloads:
        robust_download(url, dest_dir, filename)

    print("π Environment Setup Complete!")
# Run setup immediately
# This executes at import time: the ComfyUI imports further down can only
# resolve once the repository has been cloned into COMFYUI_DIR.
setup_environment()
# --- ComfyUI Imports ---
# Configure Execution Arguments for ComfyUI
# Aggressively force CPU if CUDA is not available or if we want to ensure no crashes on CPU Spaces
# NOTE: these patches must stay ABOVE the ComfyUI imports below so they are in
# place before any ComfyUI module is loaded.
try:
    if not torch.cuda.is_available():
        print("β οΈ CUDA not available, forcing CPU mode for ComfyUI...")
        # 1. Force environment variable
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        # 2. Inject --cpu argument
        # NOTE(review): presumably ComfyUI parses sys.argv when its modules are
        # imported - confirm against the ComfyUI version in use.
        if "--cpu" not in sys.argv:
            sys.argv.append("--cpu")
        # 3. Monkeypatch torch.cuda to ensure ComfyUI doesn't try to initialize CUDA
        # This is necessary because some ComfyUI versions checks might be aggressive
        torch.cuda.is_available = lambda: False
        torch.cuda.device_count = lambda: 0
        torch.cuda.current_device = lambda: None
        print("β Applied CPU enforcement patches.")
except Exception as e:
    # Best effort: a failure to patch is reported but does not stop startup.
    print(f"β οΈ Error applying CPU patches: {e}")
# These must happen AFTER setup because ComfyUI folder might not exist before
try:
    import nodes
    import comfy.samplers
    from nodes import NODE_CLASS_MAPPINGS, KSamplerAdvanced, VAEDecode, CLIPTextEncode, EmptyLatentImage, VAELoader, LoraLoaderModelOnly
    from comfy_extras.nodes_model_advanced import ModelSamplingSD3
except ImportError as e:
    # The error is only printed. The names imported above then stay undefined,
    # so the functions below that use them will fail when called.
    print("β οΈ Error importing ComfyUI nodes (expected during first build if imports happen too early):", e)
    # This might happen if sys.path.append didn't catch up or folder structured differently
    # But usually works if we just ran setup.
| # --- Global Models --- | |
class ModelContainer:
    """Mutable holder for the model objects shared by all requests."""

    def __init__(self):
        # Every model slot starts empty and is filled in on first use.
        for slot in ("unet", "clip", "vae", "lora"):
            setattr(self, slot, None)
        # Flipped to True only after every slot has been populated.
        self.loaded = False


# Single module-level instance used by the loading and generation functions.
model_container = ModelContainer()
def _register_comfy_custom_nodes():
    """Ask ComfyUI to import its custom-node packages so their classes get registered."""
    import inspect

    import nodes as comfy_nodes

    # NOTE(review): older ComfyUI releases expose init_custom_nodes(); newer
    # ones appear to expose init_extra_nodes() instead, possibly as a
    # coroutine - confirm against the ComfyUI revision that gets cloned.
    init_fn = getattr(comfy_nodes, "init_custom_nodes", None)
    if init_fn is None:
        init_fn = getattr(comfy_nodes, "init_extra_nodes", None)
    if init_fn is None:
        raise ImportError(
            "ComfyUI 'nodes' module exposes neither init_custom_nodes nor init_extra_nodes"
        )
    outcome = init_fn()
    if inspect.iscoroutine(outcome):
        # Drive an async initialiser to completion before continuing.
        asyncio.run(outcome)


def load_models():
    """Load the UNet, text encoder, VAE and LoRA into ``model_container`` once.

    Does nothing when the models are already loaded. Failures are printed
    together with a traceback and are NOT raised; in that case
    ``model_container.loaded`` stays ``False`` and callers must check it.
    """
    if model_container.loaded:
        return
    print("β³ Loading Models into Memory...")
    try:
        # Custom nodes have to be registered BEFORE their classes are looked
        # up: the GGUF loaders come from the ComfyUI-GGUF custom node, so
        # indexing NODE_CLASS_MAPPINGS first raises KeyError.
        _register_comfy_custom_nodes()

        # Initialize Node Classes
        unet_loader = NODE_CLASS_MAPPINGS["UnetLoaderGGUF"]()
        clip_loader = NODE_CLASS_MAPPINGS["CLIPLoaderGGUF"]()
        vae_loader = VAELoader()
        lora_loader = LoraLoaderModelOnly()

        # Load Models
        # The loaders are given bare filenames.
        # NOTE(review): presumably they resolve these against ComfyUI's model
        # folders (models/unet, models/clip, ...) - confirm.
        model_container.unet = unet_loader.load_unet(FILENAME_UNET)[0]
        model_container.clip = clip_loader.load_clip(FILENAME_CLIP, "wan")[0]
        model_container.vae = vae_loader.load_vae(FILENAME_VAE)[0]

        # Load LoRA (applied to the model only). The path is given relative to
        # the LoRA root because the file sits in the "FusionX" sub-folder.
        lora_rel_path = f"FusionX/{FILENAME_LORA}"
        model_container.lora = lora_loader.load_lora_model_only(
            model_container.unet, lora_rel_path, 1.0
        )[0]

        model_container.loaded = True
        print("β All Models Loaded Successfully!")
    except Exception as e:
        # Deliberate best effort: report the failure and leave `loaded` False.
        print(f"β Error Loading Models: {e}")
        import traceback
        traceback.print_exc()
| # --- Generation Function --- | |
def generate(prompt, negative_prompt, width, height, steps, cfg, sampler_name, scheduler_name, seed):
    """Render one image for the given prompts and sampler settings.

    Args:
        prompt: Positive text prompt.
        negative_prompt: Negative text prompt.
        width: Image width in pixels (coerced to ``int``).
        height: Image height in pixels (coerced to ``int``).
        steps: Number of sampling steps (coerced to ``int``).
        cfg: Classifier-free guidance scale (coerced to ``float``).
        sampler_name: Sampler name passed through to the sampler node.
        scheduler_name: Scheduler name passed through to the sampler node.
        seed: Noise seed; ``-1`` or ``None`` picks a random seed.

    Returns:
        A ``(PIL.Image, str)`` tuple: the generated image and a
        ``"Seed: <n>"`` string naming the seed that was actually used.

    Raises:
        gr.Error: If the models could not be loaded or generation fails.
    """
    if not model_container.loaded:
        load_models()
    # load_models() reports failures by printing and leaving `loaded` False,
    # so check again instead of continuing with missing models.
    if not model_container.loaded:
        raise gr.Error("Generation Failed: models could not be loaded, check the server logs.")

    # An emptied seed field arrives as None; treat it like -1 (random).
    if seed is None or seed == -1:
        seed = random.randint(0, 2**64 - 1)
    seed = int(seed)
    # UI sliders may deliver floats; coerce like the other numeric inputs.
    width, height = int(width), int(height)

    print(f"π¨ Generating: {width}x{height}, Steps: {steps}, CFG: {cfg}, Seed: {seed}")
    try:
        # Instantiate Nodes for this run
        clip_text_encode = CLIPTextEncode()
        empty_latent_image = EmptyLatentImage()
        k_sampler_advanced = KSamplerAdvanced()
        vae_decode = VAEDecode()
        model_sampler_patcher = ModelSamplingSD3()

        with torch.inference_mode():
            # Encode Prompts
            positive_cond = clip_text_encode.encode(model_container.clip, prompt)[0]
            negative_cond = clip_text_encode.encode(model_container.clip, negative_prompt)[0]

            # Patch Model
            # model_container.lora holds the first element returned by
            # load_lora_model_only, i.e. the UNet with the LoRA applied.
            model_with_sampler = model_sampler_patcher.patch(model_container.lora, 1.0)[0]

            # Empty Latent (batch size 1)
            latent_image = empty_latent_image.generate(width, height, 1)[0]

            # Sample
            samples = k_sampler_advanced.sample(
                model=model_with_sampler,
                add_noise="enable",
                noise_seed=seed,
                steps=int(steps),
                cfg=float(cfg),
                sampler_name=sampler_name,
                scheduler=scheduler_name,
                positive=positive_cond,
                negative=negative_cond,
                latent_image=latent_image,
                start_at_step=0,
                end_at_step=9999,
                return_with_leftover_noise="disable",
            )[0]

            # Decode
            decoded = vae_decode.decode(model_container.vae, samples)[0]

            # Convert to PIL: scale the clipped [0, 1] values to 8-bit and
            # keep the first image of the batch.
            image_np = decoded.cpu().numpy()
            image_np_uint8 = (image_np.clip(0, 1) * 255).astype(np.uint8)
            final_image = Image.fromarray(image_np_uint8[0])
            return final_image, f"Seed: {seed}"
    except Exception as e:
        import traceback
        traceback.print_exc()
        # Chain the cause so the original error stays visible in the logs.
        raise gr.Error(f"Generation Failed: {str(e)}") from e
# --- Interface Options ---
# Choices offered by the "Sampler" dropdown.
SAMPLERS = [
    "euler",
    "euler_ancestral",
    "heun",
    "heunpp2",
    "dpm_2",
    "dpm_2_ancestral",
    "lcm",
    "dpmpp_2s_ancestral",
    "dpmpp_2m",
    "dpmpp_2m_sde",
    "dpmpp_3m_sde",
    "ddim",
    "uni_pc",
    "uni_pc_bh2",
]

# Choices offered by the "Scheduler" dropdown.
SCHEDULERS = [
    "normal",
    "karras",
    "exponential",
    "sgm_uniform",
    "simple",
    "ddim_uniform",
]
# --- Gradio UI ---
# Two columns: prompts and settings on the left, the result on the right.
with gr.Blocks(title="Wan2.1 T2I GGUF", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# π¨ Wan2.1 Text-to-Image (GGUF)")
    gr.Markdown("Generating high-quality images using Wan2.1 14B (Quantized) via ComfyUI backend.")

    with gr.Row():
        # Left column: inputs
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Positive Prompt",
                placeholder="A cinematic photo of...",
                lines=3,
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="blurry, low quality, static, frame, text, watermark, nsfw",
                lines=2,
            )

            with gr.Accordion("Advanced Settings", open=True):
                with gr.Row():
                    width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=832)
                    height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1216)
                with gr.Row():
                    steps = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=20)
                    cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=20.0, step=0.5, value=7.5)
                with gr.Row():
                    sampler = gr.Dropdown(label="Sampler", choices=SAMPLERS, value="dpmpp_2m")
                    scheduler = gr.Dropdown(label="Scheduler", choices=SCHEDULERS, value="karras")
                seed = gr.Number(label="Seed", value=-1, precision=0, info="-1 for random")

            generate_btn = gr.Button("π Generate", variant="primary", size="lg")

        # Right column: outputs
        with gr.Column(scale=1):
            output_image = gr.Image(label="Generated Image", type="pil")
            output_seed = gr.Label(label="Seed Information")

    # The order of the inputs must match the parameter order of generate().
    generate_btn.click(
        fn=generate,
        inputs=[
            prompt,
            negative_prompt,
            width,
            height,
            steps,
            cfg,
            sampler,
            scheduler,
            seed,
        ],
        outputs=[output_image, output_seed],
    )

# Pre-load models on app startup if desired, or wait for first request
# threading.Thread(target=load_models).start()

if __name__ == "__main__":
    # Enable the request queue, then serve on all interfaces on port 7860.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)