alphabagibagi commited on
Commit
29a9839
Β·
verified Β·
1 Parent(s): fdfc588

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +285 -0
app.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import subprocess
4
+ import time
5
+ import random
6
+ import asyncio
7
+ import threading
8
+ import io
9
+ import shutil
10
+ import numpy as np
11
+ from PIL import Image
12
+ import gradio as gr
13
+ import torch
14
+
15
+ # --- Configuration & Paths ---
16
+ ROOT_DIR = os.path.abspath(os.getcwd())
17
+ COMFYUI_DIR = os.path.join(ROOT_DIR, "ComfyUI")
18
+ sys.path.append(COMFYUI_DIR)
19
+
20
+ MODELS_DIR = os.path.join(COMFYUI_DIR, "models")
21
+ UNET_DIR = os.path.join(MODELS_DIR, "unet")
22
+ CLIP_DIR = os.path.join(MODELS_DIR, "clip")
23
+ VAE_DIR = os.path.join(MODELS_DIR, "vae")
24
+ LORA_DIR = os.path.join(MODELS_DIR, "loras", "FusionX")
25
+ CUSTOM_NODES_DIR = os.path.join(COMFYUI_DIR, "custom_nodes")
26
+ GGUF_NODE_DIR = os.path.join(CUSTOM_NODES_DIR, "ComfyUI-GGUF")
27
+
28
+ # --- Model URLs ---
29
+ URL_UNET = "https://huggingface.co/QuantStack/Wan2.2-T2V-A14B-GGUF/resolve/main/LowNoise/Wan2.2-T2V-A14B-LowNoise-Q3_K_S.gguf"
30
+ FILENAME_UNET = "Wan2.2-T2V-A14B-LowNoise-Q3_K_S.gguf"
31
+
32
+ URL_CLIP = "https://huggingface.co/city96/umt5-xxl-encoder-gguf/resolve/main/umt5-xxl-encoder-Q3_K_S.gguf"
33
+ FILENAME_CLIP = "umt5-xxl-encoder-Q3_K_S.gguf"
34
+
35
+ URL_VAE = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/vae/wan_2.1_vae.safetensors"
36
+ FILENAME_VAE = "wan_2.1_vae.safetensors"
37
+
38
+ URL_LORA = "https://huggingface.co/vrgamedevgirl84/Wan14BT2VFusioniX/resolve/main/FusionX_LoRa/Wan2.1_T2V_14B_FusionX_LoRA.safetensors"
39
+ FILENAME_LORA = "Wan2.1_T2V_14B_FusionX_LoRA.safetensors"
40
+
41
+ # --- Setup Functions ---
42
+ def run_command(command, desc=None):
43
+ if desc:
44
+ print(f"➜ {desc}...")
45
+ try:
46
+ subprocess.run(command, check=True, shell=True)
47
+ except subprocess.CalledProcessError as e:
48
+ print(f"❌ Error during {desc}: {e}")
49
+ raise
50
+
51
+ def setup_environment():
52
+ print("πŸš€ Starting Setup Environment...")
53
+
54
+ # 1. Clone ComfyUI if not exists
55
+ if not os.path.exists(COMFYUI_DIR):
56
+ run_command(f"git clone https://github.com/comfyanonymous/ComfyUI {COMFYUI_DIR}", "Cloning ComfyUI")
57
+ else:
58
+ print(f"βœ… ComfyUI found at {COMFYUI_DIR}")
59
+
60
+ # 2. Clone Custom Node (ComfyUI-GGUF)
61
+ if not os.path.exists(GGUF_NODE_DIR):
62
+ run_command(f"git clone https://github.com/city96/ComfyUI-GGUF {GGUF_NODE_DIR}", "Cloning ComfyUI-GGUF")
63
+ else:
64
+ print(f"βœ… ComfyUI-GGUF found at {GGUF_NODE_DIR}")
65
+
66
+ # 3. Create Directories
67
+ for d in [UNET_DIR, CLIP_DIR, VAE_DIR, LORA_DIR]:
68
+ os.makedirs(d, exist_ok=True)
69
+
70
+ # 4. Download Models
71
+ download_list = [
72
+ (URL_UNET, UNET_DIR, FILENAME_UNET),
73
+ (URL_CLIP, CLIP_DIR, FILENAME_CLIP),
74
+ (URL_VAE, VAE_DIR, FILENAME_VAE),
75
+ (URL_LORA, LORA_DIR, FILENAME_LORA)
76
+ ]
77
+
78
+ for url, dest_dir, filename in download_list:
79
+ dest_path = os.path.join(dest_dir, filename)
80
+ if not os.path.exists(dest_path):
81
+ print(f"⬇️ Downloading {filename}...")
82
+ # Use aria2c if available, else wget/curl, or fallback to python
83
+ # Since we installed aria2 in Dockerfile, try that first
84
+ try:
85
+ run_command(f"aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {url} -d {dest_dir} -o {filename}", f"Downloading {filename}")
86
+ except:
87
+ print("⚠️ Aria2 failed, falling back to basic download methods could be added here if needed.")
88
+ # Basic fallback using huggingface_hub or wget could go here
89
+ from huggingface_hub import hf_hub_download
90
+ # This is a bit complex since URLs are direct, but for now assuming aria2/git works or manual download
91
+ else:
92
+ print(f"βœ… {filename} already exists.")
93
+
94
+ print("πŸŽ‰ Environment Setup Complete!")
95
+
96
+ # Run setup immediately
97
+ setup_environment()
98
+
99
+ # --- ComfyUI Imports ---
100
+ # These must happen AFTER setup because ComfyUI folder might not exist before
101
+ try:
102
+ import nodes
103
+ import comfy.samplers
104
+ from nodes import NODE_CLASS_MAPPINGS, KSamplerAdvanced, VAEDecode, CLIPTextEncode, EmptyLatentImage, VAELoader, LoraLoaderModelOnly
105
+ from comfy_extras.nodes_model_advanced import ModelSamplingSD3
106
+ except ImportError as e:
107
+ print("⚠️ Error importing ComfyUI nodes (expected during first build if imports happen too early):", e)
108
+ # This might happen if sys.path.append didn't catch up or folder structured differently
109
+ # But usually works if we just ran setup.
110
+
111
+ # --- Global Models ---
112
+ class ModelContainer:
113
+ def __init__(self):
114
+ self.unet = None
115
+ self.clip = None
116
+ self.vae = None
117
+ self.lora = None
118
+ self.loaded = False
119
+
120
+ model_container = ModelContainer()
121
+
122
+ def load_models():
123
+ if model_container.loaded:
124
+ return
125
+
126
+ print("⏳ Loading Models into Memory...")
127
+ try:
128
+ # Initialize Node Classes
129
+ UnetLoaderGGUF = NODE_CLASS_MAPPINGS["UnetLoaderGGUF"]()
130
+ CLIPLoaderGGUF = NODE_CLASS_MAPPINGS["CLIPLoaderGGUF"]()
131
+ vae_loader = VAELoader()
132
+ lora_loader = LoraLoaderModelOnly()
133
+
134
+ # Load Models
135
+ # NOTE: Paths in ComfyUI loaders are relative to the 'models' directory usually,
136
+ # but UnetLoaderGGUF might expect just the filename if it scans the directory.
137
+ # We need to make sure ComfyUI "knows" about these paths.
138
+ # By default ComfyUI scans 'models/unet', 'models/clip' etc.
139
+
140
+ # We also need to load custom nodes explicitly sometimes
141
+ # In headless, we might need to trigger the registration of custom nodes
142
+ from nodes import init_custom_nodes
143
+ init_custom_nodes()
144
+
145
+ # Load Unet
146
+ # Scan dir to ensure we find it
147
+ model_container.unet = UnetLoaderGGUF.load_unet(FILENAME_UNET)[0]
148
+
149
+ # Load CLIP
150
+ model_container.clip = CLIPLoaderGGUF.load_clip(FILENAME_CLIP, "wan")[0]
151
+
152
+ # Load VAE
153
+ model_container.vae = vae_loader.load_vae(FILENAME_VAE)[0]
154
+
155
+ # Load LoRA (Applying to Model only as per notebook logic)
156
+ # Note: notebook logic: lora_loader.load_lora_model_only(unet_model, "FusionX/Wan2.1_T2V_14B_FusionX_LoRA.safetensors", 1.0)[0]
157
+ # ComfyUI LoRA loader usually expects relative path from models/loras
158
+ lora_rel_path = f"FusionX/{FILENAME_LORA}"
159
+ model_container.lora = lora_loader.load_lora_model_only(model_container.unet, lora_rel_path, 1.0)[0]
160
+
161
+ model_container.loaded = True
162
+ print("βœ… All Models Loaded Successfully!")
163
+
164
+ except Exception as e:
165
+ print(f"❌ Error Loading Models: {e}")
166
+ import traceback
167
+ traceback.print_exc()
168
+
169
+ # --- Generation Function ---
170
+ def generate(prompt, negative_prompt, width, height, steps, cfg, sampler_name, scheduler_name, seed):
171
+ if not model_container.loaded:
172
+ load_models()
173
+
174
+ if seed == -1:
175
+ seed = random.randint(0, 2**64 - 1)
176
+
177
+ print(f"🎨 Generating: {width}x{height}, Steps: {steps}, CFG: {cfg}, Seed: {seed}")
178
+
179
+ try:
180
+ # Instantiate Nodes for this run
181
+ clip_text_encode = CLIPTextEncode()
182
+ empty_latent_image = EmptyLatentImage()
183
+ k_sampler_advanced = KSamplerAdvanced()
184
+ vae_decode = VAEDecode()
185
+ model_sampler_patcher = ModelSamplingSD3()
186
+
187
+ with torch.inference_mode():
188
+ # Encode Prompts
189
+ positive_cond = clip_text_encode.encode(model_container.clip, prompt)[0]
190
+ negative_cond = clip_text_encode.encode(model_container.clip, negative_prompt)[0]
191
+
192
+ # Patch Model
193
+ # Note: Notebook uses 'lora_model' passed to patcher.
194
+ # In our container, 'lora' IS the model with lora applied (returned from load_lora_model_only)
195
+ # wait, load_lora_model_only returns (MODEL, CLIP).
196
+ # Let's double check the notebook.
197
+ # Notebook: lora_model = lora_loader.load_lora_model_only(unet_model, ...)[0] -> This is the unet with lora.
198
+ # Then: model_with_sampler = model_sampler_patcher.patch(lora_model, 1.0)[0]
199
+ model_with_sampler = model_sampler_patcher.patch(model_container.lora, 1.0)[0]
200
+
201
+ # Empty Latent
202
+ latent_image = empty_latent_image.generate(width, height, 1)[0]
203
+
204
+ # Sample
205
+ samples = k_sampler_advanced.sample(
206
+ model=model_with_sampler,
207
+ add_noise="enable",
208
+ noise_seed=int(seed),
209
+ steps=int(steps),
210
+ cfg=float(cfg),
211
+ sampler_name=sampler_name,
212
+ scheduler=scheduler_name,
213
+ positive=positive_cond,
214
+ negative=negative_cond,
215
+ latent_image=latent_image,
216
+ start_at_step=0,
217
+ end_at_step=9999,
218
+ return_with_leftover_noise="disable"
219
+ )[0]
220
+
221
+ # Decode
222
+ decoded = vae_decode.decode(model_container.vae, samples)[0]
223
+
224
+ # Convert to PIL
225
+ image_np = decoded.cpu().numpy()
226
+ image_np_uint8 = (image_np.clip(0, 1) * 255).astype(np.uint8)
227
+ final_image = Image.fromarray(image_np_uint8[0])
228
+
229
+ return final_image, f"Seed: {seed}"
230
+
231
+ except Exception as e:
232
+ import traceback
233
+ traceback.print_exc()
234
+ raise gr.Error(f"Generation Failed: {str(e)}")
235
+
236
+ # --- Interface Options ---
237
+ SAMPLERS = [
238
+ "euler", "euler_ancestral", "heun", "heunpp2", "dpm_2", "dpm_2_ancestral",
239
+ "lcm", "dpmpp_2s_ancestral", "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_3m_sde",
240
+ "ddim", "uni_pc", "uni_pc_bh2"
241
+ ]
242
+ SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
243
+
244
+ # --- Gradio UI ---
245
+ with gr.Blocks(title="Wan2.1 T2I GGUF", theme=gr.themes.Soft()) as demo:
246
+ gr.Markdown("# 🎨 Wan2.1 Text-to-Image (GGUF)")
247
+ gr.Markdown("Generating high-quality images using Wan2.1 14B (Quantized) via ComfyUI backend.")
248
+
249
+ with gr.Row():
250
+ with gr.Column(scale=1):
251
+ prompt = gr.Textbox(label="Positive Prompt", placeholder="A cinematic photo of...", lines=3)
252
+ negative_prompt = gr.Textbox(label="Negative Prompt", value="blurry, low quality, static, frame, text, watermark, nsfw", lines=2)
253
+
254
+ with gr.Accordion("Advanced Settings", open=True):
255
+ with gr.Row():
256
+ width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=832)
257
+ height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1216)
258
+
259
+ with gr.Row():
260
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=20)
261
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=20.0, step=0.5, value=7.5)
262
+
263
+ with gr.Row():
264
+ sampler = gr.Dropdown(label="Sampler", choices=SAMPLERS, value="dpmpp_2m")
265
+ scheduler = gr.Dropdown(label="Scheduler", choices=SCHEDULERS, value="karras")
266
+
267
+ seed = gr.Number(label="Seed", value=-1, precision=0, info="-1 for random")
268
+
269
+ generate_btn = gr.Button("πŸš€ Generate", variant="primary", size="lg")
270
+
271
+ with gr.Column(scale=1):
272
+ output_image = gr.Image(label="Generated Image", type="pil")
273
+ output_seed = gr.Label(label="Seed Information")
274
+
275
+ generate_btn.click(
276
+ fn=generate,
277
+ inputs=[prompt, negative_prompt, width, height, steps, cfg, sampler, scheduler, seed],
278
+ outputs=[output_image, output_seed]
279
+ )
280
+
281
+ # Pre-load models on app startup if desired, or wait for first request
282
+ # threading.Thread(target=load_models).start()
283
+
284
+ if __name__ == "__main__":
285
+ demo.queue().launch(server_name="0.0.0.0", server_port=7860)