RioShiina commited on
Commit
c009d4f
·
verified ·
1 Parent(s): 090484b

Upload folder using huggingface_hub

Browse files
Files changed (46) hide show
  1. README.md +10 -13
  2. app.py +223 -0
  3. chain_injectors/__init__.py +0 -0
  4. chain_injectors/conditioning_injector.py +81 -0
  5. chain_injectors/reference_latent_injector.py +66 -0
  6. comfy_integration/__init__.py +0 -0
  7. comfy_integration/nodes.py +39 -0
  8. comfy_integration/setup.py +67 -0
  9. core/__init__.py +0 -0
  10. core/generation_logic.py +25 -0
  11. core/model_manager.py +168 -0
  12. core/pipelines/__init__.py +0 -0
  13. core/pipelines/base_pipeline.py +53 -0
  14. core/pipelines/controlnet_preprocessor.py +143 -0
  15. core/pipelines/sd_image_pipeline.py +389 -0
  16. core/pipelines/workflow_recipes/_partials/_base_sampler.yaml +23 -0
  17. core/pipelines/workflow_recipes/_partials/conditioning/flux2.yaml +56 -0
  18. core/pipelines/workflow_recipes/_partials/input/hires_fix.yaml +26 -0
  19. core/pipelines/workflow_recipes/_partials/input/img2img.yaml +19 -0
  20. core/pipelines/workflow_recipes/_partials/input/inpaint.yaml +25 -0
  21. core/pipelines/workflow_recipes/_partials/input/outpaint.yaml +38 -0
  22. core/pipelines/workflow_recipes/_partials/input/txt2img.yaml +8 -0
  23. core/pipelines/workflow_recipes/sd_unified_recipe.yaml +8 -0
  24. core/settings.py +125 -0
  25. core/shared_state.py +1 -0
  26. core/workflow_assembler.py +179 -0
  27. requirements.txt +58 -0
  28. scripts/__init__.py +0 -0
  29. scripts/build_sage_attention.py +99 -0
  30. ui/__init__.py +0 -0
  31. ui/events.py +500 -0
  32. ui/layout.py +92 -0
  33. ui/shared/hires_fix_ui.py +76 -0
  34. ui/shared/img2img_ui.py +59 -0
  35. ui/shared/inpaint_ui.py +83 -0
  36. ui/shared/outpaint_ui.py +70 -0
  37. ui/shared/txt2img_ui.py +40 -0
  38. ui/shared/ui_components.py +236 -0
  39. utils/__init__.py +0 -0
  40. utils/app_utils.py +462 -0
  41. yaml/constants.yaml +16 -0
  42. yaml/file_list.yaml +50 -0
  43. yaml/injectors.yaml +9 -0
  44. yaml/model_defaults.yaml +18 -0
  45. yaml/model_list.yaml +26 -0
  46. yaml/private_file_list.yaml +12 -0
README.md CHANGED
@@ -1,13 +1,10 @@
1
- ---
2
- title: ImageGen FLUX.2
3
- emoji: 🏆
4
- colorFrom: indigo
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 6.5.0
8
- python_version: '3.12'
9
- app_file: app.py
10
- pinned: false
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: ImageGen - FLUX.2
3
+ emoji: 🖼
4
+ colorFrom: purple
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: "5.50.0"
8
+ app_file: app.py
9
+ short_description: Multi-task image generator with dynamic, chainable workflows
10
+ ---
 
 
 
app.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# app.py — Hugging Face Spaces entry point. Sets up sys.path, applies the
# optional SageAttention runtime patch, performs authenticated model
# downloads, and launches the Gradio UI.
import spaces
import os
import sys
import requests
import site

# Absolute path of the project root (the directory containing this file).
APP_DIR = os.path.dirname(os.path.abspath(__file__))
if APP_DIR not in sys.path:
    # Make project-local packages (core, ui, utils, ...) importable.
    sys.path.insert(0, APP_DIR)
    print(f"✅ Added project root '{APP_DIR}' to sys.path.")
11
+
12
# Tracks whether the runtime attention patch has been applied; module-level
# so repeated calls from different entry points stay idempotent.
SAGE_PATCH_APPLIED = False

def apply_sage_attention_patch():
    """Force ComfyUI's model management to report SageAttention as enabled.

    Returns a human-readable status string (also printed for the Space
    logs). Safe to call repeatedly: once the patch sticks, later calls are
    no-ops; on failure the flag stays False so a retry is possible.
    """
    global SAGE_PATCH_APPLIED
    if SAGE_PATCH_APPLIED:
        return "SageAttention patch already applied."

    try:
        # Both imports must succeed before any flags are flipped.
        import sageattention  # noqa: F401 -- presence check only
        from comfy import model_management

        print("--- [Runtime Patch] sageattention package found. Applying patch... ---")
        # Monkey-patch the two capability probes ComfyUI consults when it
        # selects an attention implementation.
        model_management.sage_attention_enabled = lambda: True
        model_management.pytorch_attention_enabled = lambda: False

        SAGE_PATCH_APPLIED = True
        return "✅ Successfully enabled SageAttention."
    except ImportError:
        SAGE_PATCH_APPLIED = False
        message = "--- [Runtime Patch] ⚠️ sageattention package not found. Continuing with default attention. ---"
    except Exception as exc:
        SAGE_PATCH_APPLIED = False
        message = f"--- [Runtime Patch] ❌ An error occurred while applying SageAttention patch: {exc} ---"
    print(message)
    return message
39
+
40
@spaces.GPU
def dummy_gpu_for_startup():
    """Minimal @spaces.GPU function used as a ZeroGPU startup probe.

    Also calls apply_sage_attention_patch() from within the GPU-decorated
    context so the patch runs in the same process that serves GPU work.
    """
    print("--- [GPU Startup] Dummy function for startup check initiated. ---")
    patch_result = apply_sage_attention_patch()
    print(f"--- [GPU Startup] {patch_result} ---")
    print("--- [GPU Startup] Startup check passed. ---")
    return "Startup check passed."
47
+
48
def handle_private_downloads():
    """
    Checks for a private_file_list.yaml, downloads required models using HF_TOKEN,
    and then clears the token from the environment.

    Expected YAML shape::

        file:
          <category>:                    # keys of category_to_dir_map below
            - filename: local_name.safetensors
              repo_id: org/repo
              repository_file_path: path/in/repo.safetensors   # optional

    Downloaded files are symlinked from the HF cache into the model folders.
    HF_TOKEN is always scrubbed from the environment afterwards so the rest
    of the app never sees the credential.
    """
    import yaml
    from huggingface_hub import hf_hub_download
    from core.settings import (
        DIFFUSION_MODELS_DIR, TEXT_ENCODERS_DIR, VAE_DIR, CHECKPOINT_DIR,
        LORA_DIR, CONTROLNET_DIR, MODEL_PATCHES_DIR, EMBEDDING_DIR
    )

    print("--- [Startup] Checking for private models to download... ---")
    private_list_path = os.path.join(APP_DIR, 'yaml', 'private_file_list.yaml')

    if not os.path.exists(private_list_path):
        print("--- [Startup] No private model list found. Skipping. ---")
        # Still scrub the token: nothing else in the app should need it.
        if 'HF_TOKEN' in os.environ:
            del os.environ['HF_TOKEN']
            print("--- [Startup] Cleared HF_TOKEN environment variable as it is no longer needed. ---")
            print(f"--- [Startup] Verifying HF_TOKEN after clearing: {os.environ.get('HF_TOKEN')}")
        return

    try:
        with open(private_list_path, 'r', encoding='utf-8') as f:
            private_files_config = yaml.safe_load(f)

        if not private_files_config or 'file' not in private_files_config:
            print("--- [Startup] Private model list is empty or malformed. Skipping. ---")
            return

        # Maps YAML category keys to the on-disk directories ComfyUI scans.
        category_to_dir_map = {
            "diffusion_models": DIFFUSION_MODELS_DIR,
            "text_encoders": TEXT_ENCODERS_DIR,
            "vae": VAE_DIR,
            "checkpoints": CHECKPOINT_DIR,
            "loras": LORA_DIR,
            "controlnet": CONTROLNET_DIR,
            "model_patches": MODEL_PATCHES_DIR,
            "embeddings": EMBEDDING_DIR,
        }

        files_to_download = []
        for category, files in private_files_config.get('file', {}).items():
            dest_dir = category_to_dir_map.get(category)
            if not dest_dir:
                print(f"--- [Startup] ⚠️ Unknown category '{category}' in private_file_list.yaml. Skipping. ---")
                continue

            if isinstance(files, list):
                for file_info in files:
                    files_to_download.append((file_info, dest_dir))

        if not files_to_download:
            print("--- [Startup] No private models configured for download. ---")
            return

        print(f"--- [Startup] Found {len(files_to_download)} private model(s) to download. Using HF_TOKEN if available. ---")

        for file_info, dest_dir in files_to_download:
            filename = file_info.get("filename")
            repo_id = file_info.get("repo_id")
            # The path inside the repo may differ from the local filename.
            repo_path = file_info.get("repository_file_path", filename)

            if not all([filename, repo_id]):
                print(f"--- [Startup] ⚠️ Skipping malformed entry in private_file_list.yaml: {file_info} ---")
                continue

            dest_path = os.path.join(dest_dir, filename)
            # lexists: an existing (possibly dangling) symlink also counts.
            if os.path.lexists(dest_path):
                print(f"--- [Startup] ✅ Model '{filename}' already exists. Skipping download. ---")
                continue

            print(f"--- [Startup] ⏳ Downloading '{filename}' from repo '{repo_id}'... ---")
            try:
                # hf_hub_download picks up HF_TOKEN from the environment.
                cached_path = hf_hub_download(repo_id=repo_id, filename=repo_path)
                os.makedirs(dest_dir, exist_ok=True)
                os.symlink(cached_path, dest_path)
                print(f"--- [Startup] ✅ Successfully downloaded and linked '{filename}'. ---")
            except Exception as e:
                # Best-effort: log and continue with the remaining files.
                print(f"--- [Startup] ❌ ERROR: Failed to download '{filename}': {e}")
                print("--- [Startup] ❌ Please ensure your HF_TOKEN is set correctly and has access to the repository. ---")

    finally:
        if 'HF_TOKEN' in os.environ:
            del os.environ['HF_TOKEN']
            print("--- [Startup] ✅ Cleared HF_TOKEN environment variable. ---")
            print(f"--- [Startup] Verifying HF_TOKEN after clearing: {os.environ.get('HF_TOKEN')}")
        else:
            print("--- [Startup] Note: HF_TOKEN environment variable was not set. Private downloads may fail without it. ---")
138
+
139
def main():
    """Application entry point.

    Order matters: authenticated downloads run first (they consume and then
    clear HF_TOKEN), SageAttention is optionally built, site-packages are
    reloaded so the fresh install is importable, ComfyUI is initialized,
    model URLs are sanity-checked, and finally the Gradio UI is launched.
    """
    from utils.app_utils import print_welcome_message
    from scripts import build_sage_attention

    print_welcome_message()

    # Handle downloads that require authentication first.
    handle_private_downloads()

    print("--- [Setup] Attempting to build and install SageAttention... ---")
    try:
        build_sage_attention.install_sage_attention()
        print("--- [Setup] ✅ SageAttention installation process finished. ---")
    except Exception as e:
        # Non-fatal: the app works with the default attention backend.
        print(f"--- [Setup] ❌ SageAttention installation failed: {e}. Continuing with default attention. ---")


    print("--- [Setup] Reloading site-packages to detect newly installed packages... ---")
    try:
        site.main()
        print("--- [Setup] ✅ Site-packages reloaded. ---")
    except Exception as e:
        print(f"--- [Setup] ⚠️ Warning: Could not fully reload site-packages: {e} ---")

    # Deferred imports: these pull in ComfyUI modules that only resolve once
    # the environment above has been prepared.
    from comfy_integration import setup as setup_comfyui
    from utils.app_utils import (
        build_preprocessor_model_map,
        build_preprocessor_parameter_map
    )
    from core import shared_state
    from core.settings import ALL_MODEL_MAP, ALL_FILE_DOWNLOAD_MAP

    def check_all_model_urls_on_startup():
        """HEAD-check every configured model URL; flag unreachable models so
        the UI can disable them. One bad component flags the whole model."""
        print("--- [Setup] Checking all model URL validity (one-time check) ---")
        for display_name, model_info in ALL_MODEL_MAP.items():
            _, components, _, _ = model_info
            if not components: continue

            for filename in components.values():
                download_info = ALL_FILE_DOWNLOAD_MAP.get(filename, {})
                repo_id = download_info.get('repo_id')
                if not repo_id: continue

                repo_file_path = download_info.get('repository_file_path', filename)
                url = f"https://huggingface.co/{repo_id}/resolve/main/{repo_file_path}"

                try:
                    response = requests.head(url, timeout=5, allow_redirects=True)
                    if response.status_code >= 400:
                        print(f"❌ Invalid URL for '{display_name}' component '{filename}': {url} (Status: {response.status_code})")
                        shared_state.INVALID_MODEL_URLS[display_name] = True
                        break
                except requests.RequestException as e:
                    print(f"❌ URL check failed for '{display_name}' component '{filename}': {e}")
                    shared_state.INVALID_MODEL_URLS[display_name] = True
                    break
        print("--- [Setup] ✅ Finished checking model URLs. ---")

    print("--- Starting Application Setup ---")

    setup_comfyui.initialize_comfyui()

    check_all_model_urls_on_startup()

    print("--- Building ControlNet preprocessor maps ---")
    from core.generation_logic import build_reverse_map
    build_reverse_map()
    build_preprocessor_model_map()
    build_preprocessor_parameter_map()
    print("--- ✅ ControlNet preprocessor setup complete. ---")

    print("--- Environment configured. Proceeding with module imports. ---")
    from ui.layout import build_ui
    from ui.events import attach_event_handlers

    print(f"✅ Working directory is stable: {os.getcwd()}")

    demo = build_ui(attach_event_handlers)

    print("--- Launching Gradio Interface ---")
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    main()
chain_injectors/__init__.py ADDED
File without changes
chain_injectors/conditioning_injector.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def inject(assembler, chain_definition, chain_items):
    """Inject regional-prompt (area) conditioning into an assembled workflow.

    For each item in *chain_items* (dicts with 'prompt', 'x', 'y', 'width',
    'height', 'strength') a CLIPTextEncode + ConditioningSetArea pair is
    added, the regional results are folded together with ConditioningCombine
    nodes, merged with the KSampler's original positive conditioning, and the
    KSampler's 'positive' input is re-pointed at the merged result.
    """
    if not chain_items:
        return

    ksampler_name = chain_definition.get('ksampler_node', 'ksampler')

    target_node_id = None
    target_input_name = None

    # Locate the KSampler whose 'positive' input will be re-routed.
    if ksampler_name in assembler.node_map:
        ksampler_id = assembler.node_map[ksampler_name]
        if 'positive' in assembler.workflow[ksampler_id]['inputs']:
            target_node_id = ksampler_id
            target_input_name = 'positive'
            print(f"Conditioning injector targeting KSampler node '{ksampler_name}'.")
    else:
        print(f"Warning: KSampler node '{ksampler_name}' for Conditioning chain not found. Skipping.")
        return

    if not target_node_id:
        print("Warning: Conditioning chain could not find a valid injection point (KSampler may be missing 'positive' input). Skipping.")
        return

    # 'clip_source' in the recipe is "<node_name>:<output_index>".
    clip_source_str = chain_definition.get('clip_source')
    if not clip_source_str:
        print("Warning: 'clip_source' definition missing in the recipe for the Conditioning chain. Skipping.")
        return
    clip_node_name, clip_idx_str = clip_source_str.split(':')
    if clip_node_name not in assembler.node_map:
        print(f"Warning: CLIP source node '{clip_node_name}' for Conditioning chain not found. Skipping.")
        return
    clip_connection = [assembler.node_map[clip_node_name], int(clip_idx_str)]

    # Remember the original positive connection so it can be merged back in.
    original_positive_connection = assembler.workflow[target_node_id]['inputs'][target_input_name]

    area_conditioning_outputs = []

    # One CLIPTextEncode -> ConditioningSetArea chain per regional prompt.
    for item_data in chain_items:
        prompt = item_data.get('prompt', '')
        if not prompt or not prompt.strip():
            # Blank prompts contribute nothing; skip them entirely.
            continue

        text_encode_id = assembler._get_unique_id()
        text_encode_node = assembler._get_node_template("CLIPTextEncode")
        text_encode_node['inputs']['text'] = prompt
        text_encode_node['inputs']['clip'] = clip_connection
        assembler.workflow[text_encode_id] = text_encode_node

        set_area_id = assembler._get_unique_id()
        set_area_node = assembler._get_node_template("ConditioningSetArea")
        set_area_node['inputs']['width'] = item_data.get('width', 1024)
        set_area_node['inputs']['height'] = item_data.get('height', 1024)
        set_area_node['inputs']['x'] = item_data.get('x', 0)
        set_area_node['inputs']['y'] = item_data.get('y', 0)
        set_area_node['inputs']['strength'] = item_data.get('strength', 1.0)
        set_area_node['inputs']['conditioning'] = [text_encode_id, 0]
        assembler.workflow[set_area_id] = set_area_node

        area_conditioning_outputs.append([set_area_id, 0])

    if not area_conditioning_outputs:
        return

    # Fold all regional outputs into one conditioning via pairwise combines.
    current_combined_conditioning = area_conditioning_outputs[0]
    if len(area_conditioning_outputs) > 1:
        for i in range(1, len(area_conditioning_outputs)):
            combine_id = assembler._get_unique_id()
            combine_node = assembler._get_node_template("ConditioningCombine")
            combine_node['inputs']['conditioning_1'] = current_combined_conditioning
            combine_node['inputs']['conditioning_2'] = area_conditioning_outputs[i]
            assembler.workflow[combine_id] = combine_node
            current_combined_conditioning = [combine_id, 0]

    # Finally merge with the workflow's original positive conditioning and
    # point the target input at the merged result.
    final_combine_id = assembler._get_unique_id()
    final_combine_node = assembler._get_node_template("ConditioningCombine")
    final_combine_node['inputs']['conditioning_1'] = original_positive_connection
    final_combine_node['inputs']['conditioning_2'] = current_combined_conditioning
    assembler.workflow[final_combine_id] = final_combine_node

    assembler.workflow[target_node_id]['inputs'][target_input_name] = [final_combine_id, 0]
    print(f"Conditioning injector applied. Redirected '{target_input_name}' input with {len(area_conditioning_outputs)} regional prompts.")
chain_injectors/reference_latent_injector.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def inject(assembler, chain_definition, chain_items):
    """Thread ReferenceLatent nodes into the positive-conditioning path.

    *chain_items* is a list of image filenames. Each image is loaded,
    VAE-encoded, and chained through a ReferenceLatent node so that the
    final positive conditioning carries every reference latent. The
    injection point is the FluxGuidance node when the recipe names one,
    otherwise the KSampler's 'positive' input.
    """
    if not chain_items:
        return

    ksampler_name = chain_definition.get('ksampler_node', 'ksampler')
    flux_guidance_name = chain_definition.get('flux_guidance_node')
    vae_node_name = chain_definition.get('vae_node', 'vae_loader')

    if ksampler_name not in assembler.node_map:
        print(f"Warning: [ReferenceLatent] KSampler node '{ksampler_name}' not found. Skipping.")
        return
    if vae_node_name not in assembler.node_map:
        print(f"Warning: [ReferenceLatent] VAE loader node '{vae_node_name}' not found. Skipping.")
        return

    ksampler_id = assembler.node_map[ksampler_name]
    vae_node_id = assembler.node_map[vae_node_name]

    # Prefer injecting upstream of FluxGuidance when the recipe defines one.
    pos_target_node_id = None
    pos_target_input_name = None
    if flux_guidance_name and flux_guidance_name in assembler.node_map:
        flux_guidance_id = assembler.node_map[flux_guidance_name]
        if 'conditioning' in assembler.workflow[flux_guidance_id]['inputs']:
            pos_target_node_id = flux_guidance_id
            pos_target_input_name = 'conditioning'
            print(f"ReferenceLatent injector targeting FluxGuidance node '{flux_guidance_name}'.")

    # Fall back to the KSampler's 'positive' input.
    if not pos_target_node_id:
        if 'positive' in assembler.workflow[ksampler_id]['inputs']:
            pos_target_node_id = ksampler_id
            pos_target_input_name = 'positive'
            print(f"ReferenceLatent injector targeting KSampler node '{ksampler_name}'.")
        else:
            print(f"Warning: [ReferenceLatent] Could not find a valid positive injection point. Skipping.")
            return

    # Start from whatever the target currently receives and extend the chain.
    current_pos_conditioning = assembler.workflow[pos_target_node_id]['inputs'][pos_target_input_name]

    for i, img_filename in enumerate(chain_items):
        if not img_filename or not isinstance(img_filename, str):
            continue

        # LoadImage -> VAEEncode -> ReferenceLatent, chained onto the
        # conditioning produced by the previous iteration.
        load_id = assembler._get_unique_id()
        load_node = assembler._get_node_template("LoadImage")
        load_node['inputs']['image'] = img_filename
        assembler.workflow[load_id] = load_node

        vae_encode_id = assembler._get_unique_id()
        vae_encode_node = assembler._get_node_template("VAEEncode")
        vae_encode_node['inputs']['pixels'] = [load_id, 0]
        vae_encode_node['inputs']['vae'] = [vae_node_id, 0]
        assembler.workflow[vae_encode_id] = vae_encode_node

        latent_conn = [vae_encode_id, 0]

        ref_latent_id = assembler._get_unique_id()
        ref_latent_node = assembler._get_node_template("ReferenceLatent")
        ref_latent_node['inputs']['conditioning'] = current_pos_conditioning
        ref_latent_node['inputs']['latent'] = latent_conn
        assembler.workflow[ref_latent_id] = ref_latent_node

        current_pos_conditioning = [ref_latent_id, 0]

    # Point the target at the end of the ReferenceLatent chain.
    assembler.workflow[pos_target_node_id]['inputs'][pos_target_input_name] = current_pos_conditioning

    print(f"ReferenceLatent injector applied. Re-routed inputs through {len(chain_items)} reference image(s).")
comfy_integration/__init__.py ADDED
File without changes
comfy_integration/nodes.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# ComfyUI node bootstrap: spins up a headless PromptServer, loads built-in
# and custom nodes, and exports the node classes the pipelines use.
import asyncio
import execution
import server
from nodes import (
    init_extra_nodes, CheckpointLoaderSimple, EmptyLatentImage, KSampler,
    VAEDecode, SaveImage, NODE_CLASS_MAPPINGS, LoadImage, VAEEncode,
    VAEEncodeForInpaint, ImagePadForOutpaint, LatentUpscaleBy, RepeatLatentBatch
)


def import_custom_nodes() -> None:
    """Initialize ComfyUI's extra/custom nodes.

    Node initialization expects a PromptServer and PromptQueue to exist even
    in headless use, so minimal instances are created on a fresh event loop.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    server_instance = server.PromptServer(loop)
    execution.PromptQueue(server_instance)

    loop.run_until_complete(init_extra_nodes())

# Run at import time so NODE_CLASS_MAPPINGS below includes custom nodes.
import_custom_nodes()

# Node classes resolved from the (now fully populated) registry.
CLIPTextEncode = NODE_CLASS_MAPPINGS['CLIPTextEncode']
CLIPTextEncodeSDXL = NODE_CLASS_MAPPINGS['CLIPTextEncodeSDXL']
LoraLoader = NODE_CLASS_MAPPINGS['LoraLoader']
CLIPSetLastLayer = NODE_CLASS_MAPPINGS['CLIPSetLastLayer']

# Sampler/scheduler option lists for the UI, read from KSampler's metadata;
# falls back to a short hard-coded list if the introspection fails.
try:
    KSamplerNode = NODE_CLASS_MAPPINGS['KSampler']
    SAMPLER_CHOICES = KSamplerNode.INPUT_TYPES()["required"]["sampler_name"][0]
    SCHEDULER_CHOICES = KSamplerNode.INPUT_TYPES()["required"]["scheduler"][0]
except Exception:
    print("⚠️ Could not dynamically get sampler/scheduler choices, using fallback list.")
    SAMPLER_CHOICES = ['euler', 'dpmpp_2m_sde_gpu']
    SCHEDULER_CHOICES = ['normal', 'karras']

# Shared node instances reused across requests.
checkpointloadersimple = CheckpointLoaderSimple()
loraloader = LoraLoader()


print("✅ ComfyUI custom nodes and class mappings are ready.")
comfy_integration/setup.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import shutil
4
+
5
+ from core.settings import *
6
+
7
def move_and_overwrite(src, dst):
    """Move *src* to *dst*, deleting any pre-existing entry at *dst* first.

    A source that is neither a regular file nor a directory (including a
    missing path) is silently ignored — this is a best-effort merge step.
    """
    src_is_dir = os.path.isdir(src)
    if not (src_is_dir or os.path.isfile(src)):
        return
    if os.path.exists(dst):
        # Clear the destination first so shutil.move replaces it instead of
        # nesting the source inside an existing directory.
        remover = shutil.rmtree if src_is_dir else os.remove
        remover(dst)
    shutil.move(src, dst)
16
+
17
def initialize_comfyui():
    """Clone ComfyUI into the app directory, fetch bundled extensions, and
    create every model directory the pipelines expect to write into."""
    app_dir = sys.path[0]
    temp_dir = "ComfyUI_temp"

    print("--- Cloning ComfyUI Repository ---")
    if os.path.exists(temp_dir):
        print("✅ ComfyUI repository already exists.")
    else:
        os.system(f"git clone https://github.com/comfy-Org/ComfyUI {temp_dir}")
        print("✅ ComfyUI repository cloned.")

    # Merge the checkout into the app root, overwriting stale copies but
    # leaving the clone's git metadata behind.
    print(f"--- Merging ComfyUI from '{temp_dir}' to '{app_dir}' ---")
    for entry in os.listdir(temp_dir):
        if entry == '.git':
            continue
        move_and_overwrite(os.path.join(temp_dir, entry), os.path.join(app_dir, entry))

    try:
        shutil.rmtree(temp_dir)
        print("✅ ComfyUI merged and temporary directory removed.")
    except OSError as e:
        print(f"⚠️ Could not remove temporary directory '{temp_dir}': {e}")

    print("--- Cloning third-party extensions for ComfyUI ---")
    controlnet_aux_path = os.path.join(app_dir, "custom_nodes", "comfyui_controlnet_aux")
    if os.path.exists(controlnet_aux_path):
        print("✅ comfyui_controlnet_aux extension already exists.")
    else:
        os.system(f"git clone https://github.com/Fannovel16/comfyui_controlnet_aux.git {controlnet_aux_path}")
        print("✅ comfyui_controlnet_aux extension cloned.")


    print(f"✅ Current working directory is: {os.getcwd()}")

    import comfy.model_management  # noqa: F401 -- imported for its side effects
    print("--- Environment Ready ---")

    print("✅ ComfyUI initialized with default attention mechanism.")

    # Ensure every directory the model loaders scan exists up front.
    for model_dir in (
        CHECKPOINT_DIR, LORA_DIR, EMBEDDING_DIR, CONTROLNET_DIR,
        MODEL_PATCHES_DIR, DIFFUSION_MODELS_DIR, VAE_DIR,
        TEXT_ENCODERS_DIR, INPUT_DIR,
    ):
        os.makedirs(os.path.join(app_dir, model_dir), exist_ok=True)
    print("✅ All required model directories are present.")
core/__init__.py ADDED
File without changes
core/generation_logic.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import Any, Dict
import gradio as gr

from core.pipelines.controlnet_preprocessor import ControlNetPreprocessorPipeline
from core.pipelines.sd_image_pipeline import SdImagePipeline

# Module-level pipeline singletons shared by all UI event handlers.
controlnet_preprocessor_pipeline = ControlNetPreprocessorPipeline()
sd_image_pipeline = SdImagePipeline()


def build_reverse_map():
    """Build the display-name -> node-class-name map for the ControlNet
    preprocessor module. Must run after ComfyUI's nodes are loaded; it is a
    no-op once the map has been populated."""
    from nodes import NODE_DISPLAY_NAME_MAPPINGS
    import core.pipelines.controlnet_preprocessor as cn_module

    if cn_module.REVERSE_DISPLAY_NAME_MAP is None:
        cn_module.REVERSE_DISPLAY_NAME_MAP = {v: k for k, v in NODE_DISPLAY_NAME_MAPPINGS.items()}
    # Legacy alias retained for older saved configurations.
    if "Semantic Segmentor (legacy, alias for UniFormer)" not in cn_module.REVERSE_DISPLAY_NAME_MAP:
        cn_module.REVERSE_DISPLAY_NAME_MAP["Semantic Segmentor (legacy, alias for UniFormer)"] = "SemSegPreprocessor"


def run_cn_preprocessor_entry(*args, **kwargs):
    """Thin pass-through to the shared ControlNet preprocessor pipeline."""
    return controlnet_preprocessor_pipeline.run(*args, **kwargs)

def generate_image_wrapper(ui_inputs: dict, progress=gr.Progress(track_tqdm=True)):
    """Gradio entry point for image generation; the gr.Progress default in
    the signature is the Gradio idiom that enables automatic tqdm tracking."""
    return sd_image_pipeline.run(ui_inputs=ui_inputs, progress=progress)
core/model_manager.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gc
from typing import Dict, List, Any, Set

import torch
import gradio as gr
from comfy import model_management

from core.settings import ALL_MODEL_MAP, CHECKPOINT_DIR, LORA_DIR, DIFFUSION_MODELS_DIR, VAE_DIR, TEXT_ENCODERS_DIR
from comfy_integration.nodes import LoraLoader
from nodes import NODE_CLASS_MAPPINGS
from utils.app_utils import get_value_at_index, _ensure_model_downloaded


class ModelManager:
    """Process-wide singleton that owns model loading and unloading.

    Caches UNET/CLIP/VAE combos in CPU RAM keyed by display name, rebuilds
    them when the active LoRA configuration changes, and moves them onto the
    GPU on demand via ComfyUI's model_management.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Classic singleton. object.__new__ must be called with the class
        # only: forwarding *args/**kwargs raises TypeError whenever any
        # constructor arguments are supplied.
        if not cls._instance:
            cls._instance = super(ModelManager, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        # Guard so repeated ModelManager() calls never wipe the caches.
        if hasattr(self, 'initialized'):
            return
        # display name -> {"unet": (model,), "clip": (clip,), "vae": (vae,)}
        self.loaded_models: Dict[str, Any] = {}
        # LoRA config the cached models were patched with; compared on each
        # load to decide whether the base models must be rebuilt.
        self.last_active_loras: List[Dict[str, Any]] = []
        self.initialized = True
        print("✅ ModelManager initialized.")

    def get_loaded_model_names(self) -> Set[str]:
        """Return the display names of the combos currently cached in RAM."""
        return set(self.loaded_models.keys())

    def _load_model_combo(self, display_name: str, active_loras: List[Dict[str, Any]], progress) -> Dict[str, Any]:
        """Load one UNET/CLIP/VAE combo to CPU and merge any LoRA patches.

        Raises ValueError when the model is unknown or its component list is
        incomplete.
        """
        print(f"--- [ModelManager] Loading model combo: '{display_name}' ---")

        if display_name not in ALL_MODEL_MAP:
            raise ValueError(f"Model '{display_name}' not found in configuration.")

        _, components, _, _ = ALL_MODEL_MAP[display_name]

        unet_filename = components.get('unet')
        clip_filename = components.get('clip')
        vae_filename = components.get('vae')

        if not all([unet_filename, clip_filename, vae_filename]):
            raise ValueError(f"Model '{display_name}' is missing required components (unet, clip, or vae) in model_list.yaml.")

        unet_loader = NODE_CLASS_MAPPINGS["UNETLoader"]()
        clip_loader = NODE_CLASS_MAPPINGS["CLIPLoader"]()
        vae_loader = NODE_CLASS_MAPPINGS["VAELoader"]()

        print("  - Loading UNET...")
        unet_tuple = unet_loader.load_unet(unet_name=unet_filename, weight_dtype="default")

        print("  - Loading CLIP...")
        clip_tuple = clip_loader.load_clip(clip_name=clip_filename, type="flux2", device="default")

        print("  - Loading VAE...")
        vae_tuple = vae_loader.load_vae(vae_name=vae_filename)

        unet_object = get_value_at_index(unet_tuple, 0)
        clip_object = get_value_at_index(clip_tuple, 0)

        if active_loras:
            # Apply LoRAs sequentially; each load_lora returns the patched
            # (model, clip) pair that feeds the next iteration.
            print(f"--- [ModelManager] Applying {len(active_loras)} LoRAs on CPU... ---")
            lora_loader = LoraLoader()
            patched_unet, patched_clip = unet_object, clip_object

            for lora_info in active_loras:
                patched_unet, patched_clip = lora_loader.load_lora(
                    model=patched_unet,
                    clip=patched_clip,
                    lora_name=lora_info["lora_name"],
                    strength_model=lora_info["strength_model"],
                    strength_clip=lora_info["strength_clip"]
                )

            unet_object = patched_unet
            clip_object = patched_clip
            print(f"--- [ModelManager] ✅ All LoRAs merged into the model on CPU. ---")

        loaded_combo = {
            "unet": (unet_object,),
            "clip": (clip_object,),
            "vae": vae_tuple,
        }

        print(f"--- [ModelManager] ✅ Successfully loaded combo '{display_name}' to CPU/RAM ---")
        return loaded_combo

    def move_models_to_gpu(self, required_models: List[str]):
        """Hand the cached UNETs for *required_models* to ComfyUI's GPU loader."""
        print(f"--- [ModelManager] Moving models to GPU: {required_models} ---")
        models_to_load_gpu = []
        for name in required_models:
            if name in self.loaded_models:
                model_combo = self.loaded_models[name]
                models_to_load_gpu.append(get_value_at_index(model_combo.get("unet"), 0))

        if models_to_load_gpu:
            model_management.load_models_gpu(models_to_load_gpu)
            print("--- [ModelManager] ✅ Models successfully moved to GPU. ---")
        else:
            print("--- [ModelManager] ⚠️ No component models found to move to GPU. ---")

    def ensure_models_downloaded(self, required_models: List[str], progress):
        """Download every component file of *required_models* that is missing.

        Raises gr.Error (surfaced in the UI) when any download fails.
        """
        print(f"--- [ModelManager] Ensuring models are downloaded: {required_models} ---")

        # A set first: several models may share component files.
        files_to_download = set()
        for display_name in required_models:
            if display_name in ALL_MODEL_MAP:
                _, components, _, _ = ALL_MODEL_MAP[display_name]
                for component_file in components.values():
                    files_to_download.add(component_file)

        files_to_download = list(files_to_download)
        total_files = len(files_to_download)

        for i, filename in enumerate(files_to_download):
            if callable(progress):
                progress(i / total_files, desc=f"Checking file: {filename}")
            try:
                _ensure_model_downloaded(filename, progress)
            except Exception as e:
                raise gr.Error(f"Failed to download model component '{filename}'. Reason: {e}") from e

        print(f"--- [ModelManager] ✅ All required models are present on disk. ---")

    def load_managed_models(self, required_models: List[str], active_loras: List[Dict[str, Any]], progress) -> Dict[str, Any]:
        """Return the requested combos, (re)loading only what is necessary.

        Everything is unloaded and reloaded whenever a model is no longer
        needed or the LoRA configuration changed, since LoRA patches are
        merged into the cached weights.
        """
        required_set = set(required_models)
        current_set = set(self.loaded_models.keys())

        loras_changed = active_loras != self.last_active_loras

        models_to_unload = current_set - required_set
        if models_to_unload or loras_changed:
            if models_to_unload:
                print(f"--- [ModelManager] Models to unload: {models_to_unload} ---")
            if loras_changed and not models_to_unload:
                models_to_unload = current_set.intersection(required_set)
                if active_loras:
                    print(f"--- [ModelManager] LoRA configuration changed. Reloading base model(s): {models_to_unload} ---")
                else:
                    print(f"--- [ModelManager] LoRAs removed. Reloading base model(s) to clear patches: {models_to_unload} ---")

            model_management.unload_all_models()
            self.loaded_models.clear()
            gc.collect()
            torch.cuda.empty_cache()
            print("--- [ModelManager] All models unloaded to free RAM. ---")

        models_to_load = required_set if (models_to_unload or loras_changed) else required_set - current_set

        if models_to_load:
            print(f"--- [ModelManager] Models to load: {models_to_load} ---")
            for i, display_name in enumerate(models_to_load):
                progress(i / len(models_to_load), desc=f"Loading model: {display_name}")
                try:
                    loaded_model_data = self._load_model_combo(display_name, active_loras, progress)
                    self.loaded_models[display_name] = loaded_model_data
                except Exception as e:
                    raise gr.Error(f"Failed to load model combo or apply LoRA for '{display_name}'. Reason: {e}") from e
        else:
            print(f"--- [ModelManager] All required models are already loaded. ---")

        self.last_active_loras = active_loras
        return {name: self.loaded_models[name] for name in required_models}

model_manager = ModelManager()
core/pipelines/__init__.py ADDED
File without changes
core/pipelines/base_pipeline.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import List, Any, Dict
3
+ import gradio as gr
4
+ import spaces
5
+ import tempfile
6
+ import imageio
7
+ import numpy as np
8
+
9
class BasePipeline(ABC):
    """Common scaffolding for generation pipelines.

    Provides access to the shared model manager, ZeroGPU dispatch with a
    user-overridable duration, and CPU-side encoding of frame tensors to MP4.
    """

    def __init__(self):
        # Imported lazily to avoid a circular import at module load time.
        from core.model_manager import model_manager
        self.model_manager = model_manager

    @abstractmethod
    def get_required_models(self, **kwargs) -> List[str]:
        """Return the display names of the models this pipeline needs."""
        pass

    @abstractmethod
    def run(self, *args, progress: gr.Progress, **kwargs) -> Any:
        """Execute the pipeline end to end and return its output."""
        pass

    def _ensure_models_downloaded(self, progress: gr.Progress, **kwargs):
        """Ensures model files are downloaded before requesting GPU."""
        required_models = self.get_required_models(**kwargs)
        self.model_manager.ensure_models_downloaded(required_models, progress=progress)

    def _execute_gpu_logic(self, gpu_function: callable, duration: int, default_duration: int, task_name: str, *args, **kwargs):
        """Run ``gpu_function`` under a ZeroGPU allocation.

        ``duration`` is raw user input: any ``None``, non-positive, or
        unparseable value falls back to ``default_duration`` seconds.
        """
        final_duration = default_duration
        try:
            if duration is not None and int(duration) > 0:
                final_duration = int(duration)
        except (ValueError, TypeError):
            print(f"Invalid ZeroGPU duration input for {task_name}. Using default {default_duration}s.")

        print(f"Requesting ZeroGPU for {task_name} with duration: {final_duration} seconds.")
        gpu_runner = spaces.GPU(duration=final_duration)(gpu_function)

        return gpu_runner(*args, **kwargs)

    def _encode_video_from_frames(self, frames_tensor_cpu: 'torch.Tensor', fps: int, progress: gr.Progress) -> str:
        """Encode a float frame tensor (values in [0, 1]) to an H.264 MP4.

        Assumes a (frames, H, W, C) layout — consistent with how callers stack
        per-frame arrays. Returns the path of the temporary .mp4 file, which
        is NOT auto-deleted (``delete=False``) so Gradio can serve it.
        """
        progress(0.9, desc="Encoding video on CPU...")
        frames_np = (frames_tensor_cpu.numpy() * 255.0).astype(np.uint8)

        # Only used to reserve a unique path; the writer reopens it below.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video_file:
            video_path = temp_video_file.name

        writer = imageio.get_writer(video_path, fps=fps, codec='libx264', quality=8)
        try:
            for frame in frames_np:
                writer.append_data(frame)
        finally:
            # Fix: previously the writer was left open if append_data raised,
            # leaking the encoder handle.
            writer.close()

        progress(1.0, desc="Done!")
        return video_path
core/pipelines/controlnet_preprocessor.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any, List
2
+ import imageio
3
+ import tempfile
4
+ import numpy as np
5
+ import torch
6
+ import gradio as gr
7
+ from PIL import Image
8
+ import spaces
9
+
10
+ from .base_pipeline import BasePipeline
11
+ from comfy_integration.nodes import NODE_CLASS_MAPPINGS
12
+ from nodes import NODE_DISPLAY_NAME_MAPPINGS
13
+ from utils.app_utils import get_value_at_index
14
+
15
# Display-name -> node-class-name map; populated at startup by
# `build_reverse_map` (see the RuntimeError raised in _gpu_logic when unset).
REVERSE_DISPLAY_NAME_MAP = None
# Preprocessors cheap enough to run on CPU without requesting ZeroGPU time
# (checked by ControlNetPreprocessorPipeline.run before dispatching).
CPU_ONLY_PREPROCESSORS = {
    "Binary Lines", "Canny Edge", "Color Pallete", "Fake Scribble Lines (aka scribble_hed)",
    "Image Intensity", "Image Luminance", "Inpaint Preprocessor", "PyraCanny", "Scribble Lines",
    "Scribble XDoG Lines", "Standard Lineart", "Content Shuffle", "Tile"
}
21
+
22
def run_node_by_function_name(node_instance: Any, **kwargs) -> Any:
    """Invoke a ComfyUI node through the entry point named by its ``FUNCTION`` attribute.

    Args:
        node_instance: An instantiated node object.
        **kwargs: Forwarded verbatim to the node's entry-point method.

    Returns:
        Whatever the node's entry-point method returns.

    Raises:
        AttributeError: If the node class declares no ``FUNCTION`` attribute,
            or the named method does not exist / is not callable.
    """
    cls = type(node_instance)
    entry_name = getattr(cls, 'FUNCTION', None)
    if not entry_name:
        raise AttributeError(f"Node class '{cls.__name__}' is missing the required 'FUNCTION' attribute.")
    entry = getattr(node_instance, entry_name, None)
    if not callable(entry):
        raise AttributeError(f"Method '{entry_name}' not found or not callable on node '{cls.__name__}'.")
    return entry(**kwargs)
31
+
32
class ControlNetPreprocessorPipeline(BasePipeline):
    """Applies a ControlNet auxiliary preprocessor to an image or to every frame of a video."""

    def get_required_models(self, **kwargs) -> List[str]:
        # Preprocessor checkpoints are resolved by the nodes themselves, not
        # by the shared model manager.
        return []

    def _gpu_logic(
        self, pil_images: List[Image.Image], preprocessor_name: str, model_name: str,
        params: Dict[str, Any], progress=gr.Progress(track_tqdm=True)
    ) -> List[Image.Image]:
        """Run the selected preprocessor node over every frame.

        Despite the name this is also invoked directly on CPU for entries in
        CPU_ONLY_PREPROCESSORS (see ``run``).
        """
        global REVERSE_DISPLAY_NAME_MAP
        if REVERSE_DISPLAY_NAME_MAP is None:
            raise RuntimeError("REVERSE_DISPLAY_NAME_MAP has not been initialized. `build_reverse_map` must be called on startup.")

        class_name = REVERSE_DISPLAY_NAME_MAP.get(preprocessor_name)
        if not class_name or class_name not in NODE_CLASS_MAPPINGS:
            raise ValueError(f"Preprocessor '{preprocessor_name}' not found.")

        preprocessor_instance = NODE_CLASS_MAPPINGS[class_name]()
        call_args = {**params, 'ckpt_name': model_name}

        processed_pil_images = []
        total_frames = len(pil_images)

        for i, frame_pil in enumerate(pil_images):
            progress(i / total_frames, desc=f"Processing frame {i+1}/{total_frames} with {preprocessor_name}...")

            # (1, H, W, C) float tensor in [0, 1].
            frame_tensor = torch.from_numpy(np.array(frame_pil).astype(np.float32) / 255.0).unsqueeze(0)

            # Fix: the tensor layout is (batch, H, W, C), so the largest
            # spatial side is max(shape[1], shape[2]). The previous
            # max(shape[2], shape[3]) compared the width against the *channel
            # count*, silently ignoring the height for portrait inputs.
            resolution_arg = {'resolution': max(frame_tensor.shape[1], frame_tensor.shape[2])}

            result_tuple = run_node_by_function_name(
                preprocessor_instance,
                image=frame_tensor,
                **resolution_arg,
                **call_args
            )

            processed_tensor = get_value_at_index(result_tuple, 0)
            processed_np = (processed_tensor.squeeze(0).cpu().numpy().clip(0, 1) * 255.0).astype(np.uint8)
            processed_pil_images.append(Image.fromarray(processed_np))

        return processed_pil_images

    def run(self, input_type, image_input, video_input, preprocessor_name, model_name, zero_gpu_duration, *args, progress=gr.Progress(track_tqdm=True)):
        """Entry point: read the input, collect preprocessor params from ``*args``,
        dispatch to CPU or ZeroGPU, and return PIL images or a [video_path] list."""
        from utils import app_utils
        pil_images, is_video, fps = [], False, 30

        progress(0, desc="Reading input file...")
        if input_type == "Image":
            if image_input is None: raise gr.Error("Please provide an input image.")
            pil_images = [image_input]
        elif input_type == "Video":
            if video_input is None: raise gr.Error("Please provide an input video.")
            try:
                video_reader = imageio.get_reader(video_input)
                meta = video_reader.get_meta_data()
                fps = meta.get('fps', 30)
                pil_images = [Image.fromarray(frame) for frame in video_reader]
                is_video = True
                video_reader.close()
            except Exception as e: raise gr.Error(f"Failed to read video file: {e}")
        else:
            raise gr.Error("Invalid input type selected.")

        if not pil_images: raise gr.Error("Could not extract any frames from the input.")

        if app_utils.PREPROCESSOR_PARAMETER_MAP is None:
            raise RuntimeError("Preprocessor parameter map is not built. Check startup logs.")

        # ``*args`` arrives grouped by widget kind (sliders, then dropdowns,
        # then checkboxes) in the order the UI built them; rebuild that order
        # here so positional values map back to parameter names.
        params_config = app_utils.PREPROCESSOR_PARAMETER_MAP.get(preprocessor_name, [])
        sliders_params = [p for p in params_config if p['type'] in ["INT", "FLOAT"]]
        dropdown_params = [p for p in params_config if isinstance(p['type'], list)]
        checkbox_params = [p for p in params_config if p['type'] == "BOOLEAN"]
        ordered_params_config = sliders_params + dropdown_params + checkbox_params
        param_names = [p['name'] for p in ordered_params_config]
        provided_params = {param_names[i]: args[i] for i in range(len(param_names))}

        if preprocessor_name not in CPU_ONLY_PREPROCESSORS:
            print(f"--- '{preprocessor_name}' requires GPU, requesting ZeroGPU. ---")
            try:
                processed_pil_images = self._execute_gpu_logic(
                    self._gpu_logic,
                    duration=zero_gpu_duration,
                    default_duration=60,
                    task_name=f"Preprocessor '{preprocessor_name}'",
                    pil_images=pil_images,
                    preprocessor_name=preprocessor_name,
                    model_name=model_name,
                    params=provided_params,
                    progress=progress
                )
            except Exception as e:
                import traceback; traceback.print_exc()
                raise gr.Error(f"Failed to run preprocessor '{preprocessor_name}' on GPU: {e}")
        else:
            print(f"--- Running '{preprocessor_name}' on CPU, no ZeroGPU requested. ---")
            try:
                processed_pil_images = self._gpu_logic(pil_images, preprocessor_name, model_name, provided_params, progress=progress)
            except Exception as e:
                import traceback; traceback.print_exc()
                raise gr.Error(f"Failed to run preprocessor '{preprocessor_name}' on CPU: {e}")

        if not processed_pil_images: raise gr.Error("Processing returned no frames.")

        progress(0.9, desc="Finalizing output...")
        if is_video:
            frames_np = [np.array(img) for img in processed_pil_images]
            frames_tensor = torch.from_numpy(np.stack(frames_np)).to(torch.float32) / 255.0
            video_path = self._encode_video_from_frames(frames_tensor, fps, progress)
            return [video_path]
        else:
            progress(1.0, desc="Done!")
            return processed_pil_images
core/pipelines/sd_image_pipeline.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import shutil
4
+ import torch
5
+ import gradio as gr
6
+ from PIL import Image, ImageChops
7
+ from typing import List, Dict, Any
8
+ from collections import defaultdict, deque
9
+ import numpy as np
10
+
11
+ from .base_pipeline import BasePipeline
12
+ from core.settings import *
13
+ from comfy_integration.nodes import *
14
+ from utils.app_utils import get_value_at_index, sanitize_prompt, get_lora_path, get_embedding_path, ensure_controlnet_model_downloaded, sanitize_filename
15
+ from core.workflow_assembler import WorkflowAssembler
16
+
17
class SdImagePipeline(BasePipeline):
    """Image-generation pipeline driven by a YAML-assembled ComfyUI workflow.

    ``run`` prepares models, LoRAs, embeddings, VAE overrides and input
    images on the CPU, assembles the workflow graph from
    ``sd_unified_recipe.yaml``, then hands off to ``_gpu_logic`` inside a
    ZeroGPU allocation for sampling and decoding.
    """

    def get_required_models(self, model_display_name: str, **kwargs) -> List[str]:
        # One base-model combo (unet/clip/vae) per generation.
        return [model_display_name]

    def _topological_sort(self, workflow: Dict[str, Any]) -> List[str]:
        """Return node ids in dependency order (Kahn's algorithm).

        A two-element ``[source_id, output_index]`` list in a node's inputs is
        treated as an edge from ``source_id`` to that node.

        Raises:
            RuntimeError: If the graph contains a cycle.
        """
        graph = defaultdict(list)
        in_degree = {node_id: 0 for node_id in workflow}

        for node_id, node_info in workflow.items():
            for input_value in node_info.get('inputs', {}).values():
                if isinstance(input_value, list) and len(input_value) == 2 and isinstance(input_value[0], str):
                    source_node_id = input_value[0]
                    if source_node_id in workflow:
                        graph[source_node_id].append(node_id)
                        in_degree[node_id] += 1

        queue = deque([node_id for node_id, degree in in_degree.items() if degree == 0])

        sorted_nodes = []
        while queue:
            current_node_id = queue.popleft()
            sorted_nodes.append(current_node_id)

            for neighbor_node_id in graph[current_node_id]:
                in_degree[neighbor_node_id] -= 1
                if in_degree[neighbor_node_id] == 0:
                    queue.append(neighbor_node_id)

        if len(sorted_nodes) != len(workflow):
            raise RuntimeError("Workflow contains a cycle and cannot be executed.")

        return sorted_nodes

    def _execute_workflow(self, workflow: Dict[str, Any], initial_objects: Dict[str, Any]):
        """Execute ``workflow`` node by node and return the decoded image tensor.

        ``initial_objects`` pre-seeds node outputs (the already-loaded model
        loaders) and is used directly as the output cache, so seeded nodes are
        never re-executed. The returned tensor is whatever feeds the final
        'SaveImage' node.
        """
        with torch.no_grad():
            computed_outputs = initial_objects

            try:
                sorted_node_ids = self._topological_sort(workflow)
                print(f"--- [Workflow Executor] Execution order: {sorted_node_ids}")
            except RuntimeError as e:
                print("--- [Workflow Executor] ERROR: Failed to sort workflow. Dumping graph details. ---")
                for node_id, node_info in workflow.items():
                    print(f"    Node {node_id} ({node_info['class_type']}):")
                    for input_name, input_value in node_info['inputs'].items():
                        if isinstance(input_value, list) and len(input_value) == 2 and isinstance(input_value[0], str):
                            print(f"      - {input_name} <- [{input_value[0]}, {input_value[1]}]")
                raise e

            for node_id in sorted_node_ids:
                # Pre-seeded outputs (injected model loaders) are never re-run.
                # Fix: the previous additional "loader with filename" check was
                # unreachable — computed_outputs aliases initial_objects, so
                # this `continue` already covers every seeded node.
                if node_id in computed_outputs:
                    continue

                node_info = workflow[node_id]
                class_type = node_info['class_type']

                node_class = NODE_CLASS_MAPPINGS.get(class_type)
                if node_class is None:
                    raise RuntimeError(f"Could not find node class '{class_type}'. Is it imported in comfy_integration/nodes.py?")

                node_instance = node_class()

                # Resolve each input: graph references are replaced by the
                # producing node's already-computed output value.
                kwargs = {}
                for param_name, param_value in node_info['inputs'].items():
                    if isinstance(param_value, list) and len(param_value) == 2 and isinstance(param_value[0], str):
                        source_node_id, output_index = param_value
                        if source_node_id not in computed_outputs:
                            raise RuntimeError(f"Workflow integrity error: Output of node {source_node_id} needed for {node_id} but not yet computed.")

                        source_output_tuple = computed_outputs[source_node_id]
                        kwargs[param_name] = get_value_at_index(source_output_tuple, output_index)
                    else:
                        kwargs[param_name] = param_value

                # ComfyUI convention: the class's FUNCTION attribute names the
                # execution method.
                function_name = getattr(node_class, 'FUNCTION')
                execution_method = getattr(node_instance, function_name)

                result = execution_method(**kwargs)
                computed_outputs[node_id] = result

            # The workflow's output is whatever feeds the last SaveImage node.
            final_node_id = None
            for node_id in reversed(sorted_node_ids):
                if workflow[node_id]['class_type'] == 'SaveImage':
                    final_node_id = node_id
                    break

            if not final_node_id:
                raise RuntimeError("Workflow does not contain a 'SaveImage' node as the output.")

            save_image_inputs = workflow[final_node_id]['inputs']
            image_source_node_id, image_source_index = save_image_inputs['images']

            return get_value_at_index(computed_outputs[image_source_node_id], image_source_index)

    def _gpu_logic(self, ui_inputs: Dict, loras_string: str, required_models_for_gpu: List[str], workflow: Dict[str, Any], assembler: WorkflowAssembler, progress=gr.Progress(track_tqdm=True)):
        """GPU-side half: move models to GPU, execute the workflow, and build
        PIL images carrying A1111-style generation metadata."""
        model_display_name = ui_inputs['model_display_name']

        progress(0.1, desc="Moving models to GPU...")
        self.model_manager.move_models_to_gpu(required_models_for_gpu)

        progress(0.4, desc="Executing workflow...")

        loaded_model_combo = self.model_manager.loaded_models[model_display_name]

        # Inject the pre-loaded unet/clip/vae as the loader nodes' outputs so
        # the executor skips re-loading them from disk.
        initial_objects = {}

        unet_loader_id = assembler.node_map.get("unet_loader")
        clip_loader_id = assembler.node_map.get("clip_loader")
        vae_loader_id = assembler.node_map.get("vae_loader")

        if unet_loader_id: initial_objects[unet_loader_id] = loaded_model_combo.get("unet")
        if clip_loader_id: initial_objects[clip_loader_id] = loaded_model_combo.get("clip")
        if vae_loader_id: initial_objects[vae_loader_id] = loaded_model_combo.get("vae")

        if not all([unet_loader_id, clip_loader_id, vae_loader_id]):
            raise RuntimeError("Workflow is missing one or more required loaders (unet_loader, clip_loader, vae_loader).")

        decoded_images_tensor = self._execute_workflow(workflow, initial_objects=initial_objects)

        output_images = []
        # run() resolves seed == -1 before assembling the workflow, so this
        # fallback is defensive only. Fix: draw from the same 32-bit range as
        # run(); previously this used 2**64, so a fallback-generated metadata
        # seed could never match the seed actually used by the workflow.
        start_seed = ui_inputs['seed'] if ui_inputs['seed'] != -1 else random.randint(0, 2**32 - 1)
        for i in range(decoded_images_tensor.shape[0]):
            img_tensor = decoded_images_tensor[i]
            pil_image = Image.fromarray((img_tensor.cpu().numpy() * 255.0).astype("uint8"))
            current_seed = start_seed + i

            width_for_meta = ui_inputs.get('width', 'N/A')
            height_for_meta = ui_inputs.get('height', 'N/A')

            params_string = f"{ui_inputs['positive_prompt']}\nNegative prompt: {ui_inputs['negative_prompt']}\n"
            params_string += f"Steps: {ui_inputs['num_inference_steps']}, Sampler: {ui_inputs['sampler']}, Scheduler: {ui_inputs['scheduler']}, CFG scale: {ui_inputs['guidance_scale']}, Seed: {current_seed}, Size: {width_for_meta}x{height_for_meta}, Base Model: {model_display_name}"
            if ui_inputs['task_type'] != 'txt2img': params_string += f", Denoise: {ui_inputs['denoise']}"
            if loras_string: params_string += f", {loras_string}"

            pil_image.info = {'parameters': params_string.strip()}
            output_images.append(pil_image)

        return output_images

    def run(self, ui_inputs: Dict, progress):
        """CPU-side orchestration; see class docstring for the overall flow."""
        progress(0, desc="Preparing models...")

        task_type = ui_inputs['task_type']

        ui_inputs['positive_prompt'] = sanitize_prompt(ui_inputs.get('positive_prompt', ''))
        ui_inputs['negative_prompt'] = sanitize_prompt(ui_inputs.get('negative_prompt', ''))

        required_models = self.get_required_models(model_display_name=ui_inputs['model_display_name'])

        self.model_manager.ensure_models_downloaded(required_models, progress=progress)

        # lora_data is a flat UI list grouped as [source, id, scale, file] * N.
        lora_data = ui_inputs.get('lora_data', [])
        active_loras_for_gpu, active_loras_for_meta = [], []
        sources, ids, scales, files = lora_data[0::4], lora_data[1::4], lora_data[2::4], lora_data[3::4]

        for i, (source, lora_id, scale, _) in enumerate(zip(sources, ids, scales, files)):
            if scale > 0 and lora_id and lora_id.strip():
                lora_filename = None
                if source == "File":
                    lora_filename = sanitize_filename(lora_id)
                elif source == "Civitai":
                    local_path, status = get_lora_path(source, lora_id, ui_inputs['civitai_api_key'], progress)
                    if local_path: lora_filename = os.path.basename(local_path)
                    else: raise gr.Error(f"Failed to prepare LoRA {lora_id}: {status}")

                if lora_filename:
                    active_loras_for_gpu.append({"lora_name": lora_filename, "strength_model": scale, "strength_clip": scale})
                    active_loras_for_meta.append(f"{source} {lora_id}:{scale}")

        progress(0.1, desc="Loading models into RAM...")
        self.model_manager.load_managed_models(required_models, active_loras=active_loras_for_gpu, progress=progress)

        # Task-specific denoise defaults; txt2img always samples from pure noise.
        ui_inputs['denoise'] = 1.0
        if task_type == 'img2img': ui_inputs['denoise'] = ui_inputs.get('img2img_denoise', 0.7)
        elif task_type == 'hires_fix': ui_inputs['denoise'] = ui_inputs.get('hires_denoise', 0.55)

        temp_files_to_clean = []

        if not os.path.exists(INPUT_DIR): os.makedirs(INPUT_DIR)

        # Persist UI images into INPUT_DIR so LoadImage nodes can find them by
        # basename; all temp files are removed in the finally block below.
        if task_type == 'img2img':
            input_image_pil = ui_inputs.get('img2img_image')
            if input_image_pil:
                temp_file_path = os.path.join(INPUT_DIR, f"temp_input_{random.randint(1000, 9999)}.png")
                input_image_pil.save(temp_file_path, "PNG")
                ui_inputs['input_image'] = os.path.basename(temp_file_path)
                temp_files_to_clean.append(temp_file_path)
                ui_inputs['width'] = input_image_pil.width
                ui_inputs['height'] = input_image_pil.height

        elif task_type == 'inpaint':
            inpaint_dict = ui_inputs.get('inpaint_image_dict')
            if not inpaint_dict or not inpaint_dict.get('background') or not inpaint_dict.get('layers'):
                raise gr.Error("Inpainting requires an input image and a drawn mask.")

            background_img = inpaint_dict['background'].convert("RGBA")

            # Union of all drawn layers' alpha channels = the inpaint mask.
            composite_mask_pil = Image.new('L', background_img.size, 0)
            for layer in inpaint_dict['layers']:
                if layer:
                    layer_alpha = layer.split()[-1]
                    composite_mask_pil = ImageChops.lighter(composite_mask_pil, layer_alpha)

            # LoadImage exposes inverted alpha as the mask, so bake the
            # inverted mask into the image's alpha channel.
            inverted_mask_alpha = Image.fromarray(255 - np.array(composite_mask_pil), mode='L')
            r, g, b, _ = background_img.split()
            composite_image_with_mask = Image.merge('RGBA', [r, g, b, inverted_mask_alpha])

            temp_file_path = os.path.join(INPUT_DIR, f"temp_inpaint_composite_{random.randint(1000, 9999)}.png")
            composite_image_with_mask.save(temp_file_path, "PNG")

            ui_inputs['inpaint_image'] = os.path.basename(temp_file_path)
            temp_files_to_clean.append(temp_file_path)
            ui_inputs.pop('inpaint_mask', None)

        elif task_type == 'outpaint':
            input_image_pil = ui_inputs.get('outpaint_image')
            if input_image_pil:
                temp_file_path = os.path.join(INPUT_DIR, f"temp_input_{random.randint(1000, 9999)}.png")
                input_image_pil.save(temp_file_path, "PNG")
                ui_inputs['input_image'] = os.path.basename(temp_file_path)
                temp_files_to_clean.append(temp_file_path)

        elif task_type == 'hires_fix':
            input_image_pil = ui_inputs.get('hires_image')
            if input_image_pil:
                temp_file_path = os.path.join(INPUT_DIR, f"temp_input_{random.randint(1000, 9999)}.png")
                input_image_pil.save(temp_file_path, "PNG")
                ui_inputs['input_image'] = os.path.basename(temp_file_path)
                temp_files_to_clean.append(temp_file_path)

        # Textual-inversion embeddings: flat list grouped [source, id, file] * N.
        embedding_data = ui_inputs.get('embedding_data', [])
        embedding_filenames = []
        if embedding_data:
            emb_sources, emb_ids, emb_files = embedding_data[0::3], embedding_data[1::3], embedding_data[2::3]
            for i, (source, emb_id, _) in enumerate(zip(emb_sources, emb_ids, emb_files)):
                if emb_id and emb_id.strip():
                    emb_filename = None
                    if source == "File":
                        emb_filename = sanitize_filename(emb_id)
                    elif source == "Civitai":
                        local_path, status = get_embedding_path(source, emb_id, ui_inputs['civitai_api_key'], progress)
                        if local_path: emb_filename = os.path.basename(local_path)
                        else: raise gr.Error(f"Failed to prepare Embedding {emb_id}: {status}")

                    if emb_filename:
                        embedding_filenames.append(emb_filename)

        if embedding_filenames:
            # Embeddings are injected as "embedding:<file>" prompt tokens.
            embedding_prompt_text = " ".join([f"embedding:{f}" for f in embedding_filenames])
            if ui_inputs['positive_prompt']:
                ui_inputs['positive_prompt'] = f"{ui_inputs['positive_prompt']}, {embedding_prompt_text}"
            else:
                ui_inputs['positive_prompt'] = embedding_prompt_text

        # Optional VAE override (local file or Civitai download).
        from utils.app_utils import get_vae_path
        vae_source = ui_inputs.get('vae_source')
        vae_id = ui_inputs.get('vae_id')
        vae_file = ui_inputs.get('vae_file')
        vae_name_override = None

        if vae_source and vae_source != "None":
            if vae_source == "File":
                vae_name_override = sanitize_filename(vae_id)
            elif vae_source == "Civitai" and vae_id and vae_id.strip():
                local_path, status = get_vae_path(vae_source, vae_id, ui_inputs.get('civitai_api_key'), progress)
                if local_path: vae_name_override = os.path.basename(local_path)
                else: raise gr.Error(f"Failed to prepare VAE {vae_id}: {status}")

        if vae_name_override:
            ui_inputs['vae_name'] = vae_name_override

        # Regional conditioning: flat UI list laid out as six contiguous
        # groups of num_units values (prompts, widths, heights, xs, ys, strengths).
        conditioning_data = ui_inputs.get('conditioning_data', [])
        active_conditioning = []
        if conditioning_data:
            num_units = len(conditioning_data) // 6
            prompts = conditioning_data[0*num_units : 1*num_units]
            widths = conditioning_data[1*num_units : 2*num_units]
            heights = conditioning_data[2*num_units : 3*num_units]
            xs = conditioning_data[3*num_units : 4*num_units]
            ys = conditioning_data[4*num_units : 5*num_units]
            strengths = conditioning_data[5*num_units : 6*num_units]

            for i in range(num_units):
                if prompts[i] and prompts[i].strip():
                    active_conditioning.append({
                        "prompt": prompts[i],
                        "width": int(widths[i]),
                        "height": int(heights[i]),
                        "x": int(xs[i]),
                        "y": int(ys[i]),
                        "strength": float(strengths[i])
                    })

        # Reference-latent images are saved to INPUT_DIR and passed by basename.
        reference_latent_data = ui_inputs.get('reference_latent_data', [])
        active_reference_latents = []
        if reference_latent_data:
            for img_pil in reference_latent_data:
                if img_pil is not None:
                    temp_file_path = os.path.join(INPUT_DIR, f"temp_ref_{random.randint(1000, 9999)}.png")
                    img_pil.save(temp_file_path, "PNG")
                    active_reference_latents.append(os.path.basename(temp_file_path))
                    temp_files_to_clean.append(temp_file_path)

        loras_string = f"LoRAs: [{', '.join(active_loras_for_meta)}]" if active_loras_for_meta else ""

        progress(0.8, desc="Assembling workflow...")

        # Resolve the random seed on the CPU side so the workflow and the
        # image metadata agree on the value actually used.
        if ui_inputs.get('seed') == -1:
            ui_inputs['seed'] = random.randint(0, 2**32 - 1)

        dynamic_values = {'task_type': ui_inputs['task_type']}

        recipe_path = os.path.join(os.path.dirname(__file__), "workflow_recipes", "sd_unified_recipe.yaml")
        assembler = WorkflowAssembler(recipe_path, dynamic_values=dynamic_values)

        model_display_name = ui_inputs['model_display_name']
        if model_display_name not in ALL_MODEL_MAP:
            raise gr.Error(f"Model '{model_display_name}' is not configured in model_list.yaml.")

        _, components, _, _ = ALL_MODEL_MAP[model_display_name]

        workflow_inputs = {
            "positive_prompt": ui_inputs['positive_prompt'], "negative_prompt": ui_inputs['negative_prompt'],
            "seed": ui_inputs['seed'], "steps": ui_inputs['num_inference_steps'], "cfg": ui_inputs['guidance_scale'],
            "sampler_name": ui_inputs['sampler'], "scheduler": ui_inputs['scheduler'],
            "batch_size": ui_inputs['batch_size'],
            "denoise": ui_inputs['denoise'],
            "input_image": ui_inputs.get('input_image'),
            "inpaint_image": ui_inputs.get('inpaint_image'),
            "inpaint_mask": ui_inputs.get('inpaint_mask'),
            "left": ui_inputs.get('outpaint_left'), "top": ui_inputs.get('outpaint_top'),
            "right": ui_inputs.get('outpaint_right'), "bottom": ui_inputs.get('outpaint_bottom'),
            "hires_upscaler": ui_inputs.get('hires_upscaler'), "hires_scale_by": ui_inputs.get('hires_scale_by'),
            "unet_name": components['unet'],
            "clip_name": components['clip'],
            "vae_name": ui_inputs.get('vae_name', components['vae']),
            "conditioning_chain": active_conditioning,
            "reference_latent_chain": active_reference_latents,
        }

        # Only txt2img creates an empty latent with explicit dimensions; the
        # other tasks derive size from the input image.
        if task_type == 'txt2img':
            workflow_inputs['width'] = ui_inputs['width']
            workflow_inputs['height'] = ui_inputs['height']

        workflow = assembler.assemble(workflow_inputs)

        progress(1.0, desc="All models ready. Requesting GPU for generation...")

        try:
            results = self._execute_gpu_logic(
                self._gpu_logic,
                duration=ui_inputs['zero_gpu_duration'],
                default_duration=60,
                task_name=f"ImageGen ({task_type})",
                ui_inputs=ui_inputs,
                loras_string=loras_string,
                required_models_for_gpu=required_models,
                workflow=workflow,
                assembler=assembler,
                progress=progress
            )
        finally:
            # Always remove staged input images, even when generation fails.
            for temp_file in temp_files_to_clean:
                if temp_file and os.path.exists(temp_file):
                    os.remove(temp_file)
                    print(f"✅ Cleaned up temp file: {temp_file}")

        return results
core/pipelines/workflow_recipes/_partials/_base_sampler.yaml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Shared sampler tail for every SD recipe: KSampler -> VAEDecode -> SaveImage.
# The model/conditioning and latent inputs of ksampler are wired up by the
# conditioning and input partials that import alongside this file.
nodes:
  ksampler:
    class_type: KSampler

  vae_decode:
    class_type: VAEDecode
  save_image:
    class_type: SaveImage
    params: {}

connections:
  - from: "ksampler:0"
    to: "vae_decode:samples"
  - from: "vae_decode:0"
    to: "save_image:images"

# Maps UI input names onto "node:input" targets.
ui_map:
  seed: "ksampler:seed"
  steps: "ksampler:steps"
  cfg: "ksampler:cfg"
  sampler_name: "ksampler:sampler_name"
  scheduler: "ksampler:scheduler"
  denoise: "ksampler:denoise"
core/pipelines/workflow_recipes/_partials/conditioning/flux2.yaml ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Flux2 model loading + text conditioning partial. The unet/clip/vae loader
# node outputs are pre-seeded with already-loaded models by
# SdImagePipeline._gpu_logic (keyed via assembler.node_map).
nodes:
  unet_loader:
    class_type: UNETLoader
    title: "Load Diffusion Model"
    params:
      weight_dtype: "default"
  clip_loader:
    class_type: CLIPLoader
    title: "Load CLIP"
    params:
      type: "flux2"
      device: "default"
  vae_loader:
    class_type: VAELoader
    title: "Load VAE"

  pos_prompt:
    class_type: CLIPTextEncode
    title: "CLIP Text Encode (Positive)"
  neg_prompt:
    class_type: CLIPTextEncode
    title: "CLIP Text Encode (Negative)"

connections:
  - from: "unet_loader:0"
    to: "ksampler:model"
  - from: "clip_loader:0"
    to: "pos_prompt:clip"
  - from: "clip_loader:0"
    to: "neg_prompt:clip"
  - from: "vae_loader:0"
    to: "vae_decode:vae"
  # NOTE(review): "vae_encode" exists only in the img2img/inpaint/outpaint/
  # hires input partials, not in txt2img — presumably the assembler drops
  # connections to missing nodes; verify in WorkflowAssembler.
  - from: "vae_loader:0"
    to: "vae_encode:vae"

  - from: "pos_prompt:0"
    to: "ksampler:positive"
  - from: "neg_prompt:0"
    to: "ksampler:negative"

# Hooks consumed by the chain injectors (regional conditioning units).
dynamic_conditioning_chains:
  conditioning_chain:
    ksampler_node: "ksampler"
    clip_source: "clip_loader:0"

# Hooks consumed by the reference-latent injector.
dynamic_reference_latent_chains:
  reference_latent_chain:
    ksampler_node: "ksampler"
    vae_node: "vae_loader"

ui_map:
  unet_name: "unet_loader:unet_name"
  clip_name: "clip_loader:clip_name"
  vae_name: "vae_loader:vae_name"
  positive_prompt: "pos_prompt:text"
  negative_prompt: "neg_prompt:text"
core/pipelines/workflow_recipes/_partials/input/hires_fix.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# hires_fix input partial: encode the source image, upscale the latent,
# then repeat it to the requested batch size before sampling.
nodes:
  input_image_loader:
    class_type: LoadImage

  vae_encode:
    class_type: VAEEncode

  latent_upscaler:
    class_type: LatentUpscaleBy

  # "latent_source" is the conventional name every input partial exposes;
  # the top-level recipe wires latent_source:0 into ksampler:latent_image.
  latent_source:
    class_type: RepeatLatentBatch

connections:
  - from: "input_image_loader:0"
    to: "vae_encode:pixels"
  - from: "vae_encode:0"
    to: "latent_upscaler:samples"
  - from: "latent_upscaler:0"
    to: "latent_source:samples"

ui_map:
  input_image: "input_image_loader:image"
  hires_upscaler: "latent_upscaler:upscale_method"
  hires_scale_by: "latent_upscaler:scale_by"
  batch_size: "latent_source:amount"
core/pipelines/workflow_recipes/_partials/input/img2img.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# img2img input partial: encode the source image to a latent and repeat it
# to the requested batch size before sampling.
nodes:
  input_image_loader:
    class_type: LoadImage

  vae_encode:
    class_type: VAEEncode

  # Conventional "latent_source" output node; the top-level recipe connects
  # latent_source:0 to ksampler:latent_image.
  latent_source:
    class_type: RepeatLatentBatch

connections:
  - from: "input_image_loader:0"
    to: "vae_encode:pixels"
  - from: "vae_encode:0"
    to: "latent_source:samples"

ui_map:
  input_image: "input_image_loader:image"
  batch_size: "latent_source:amount"
core/pipelines/workflow_recipes/_partials/input/inpaint.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# inpaint input partial: the pipeline bakes the inverted mask into the PNG's
# alpha channel, so LoadImage output 1 yields the mask directly.
nodes:
  inpaint_loader:
    class_type: LoadImage
    title: "Load Inpaint Image+Mask"

  vae_encode:
    class_type: VAEEncodeForInpaint
    params:
      grow_mask_by: 6

  latent_source:
    class_type: RepeatLatentBatch

connections:
  - from: "inpaint_loader:0"
    to: "vae_encode:pixels"
  # Output 1 of LoadImage is the mask derived from the image's alpha.
  - from: "inpaint_loader:1"
    to: "vae_encode:mask"

  - from: "vae_encode:0"
    to: "latent_source:samples"

ui_map:
  inpaint_image: "inpaint_loader:image"
  batch_size: "latent_source:amount"
core/pipelines/workflow_recipes/_partials/input/outpaint.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ nodes:
2
+ input_image_loader:
3
+ class_type: LoadImage
4
+
5
+ pad_image:
6
+ class_type: ImagePadForOutpaint
7
+ params:
8
+ feathering: 10
9
+
10
+ vae_encode:
11
+ class_type: VAEEncodeForInpaint
12
+ params:
13
+ grow_mask_by: 6
14
+
15
+ latent_source:
16
+ class_type: RepeatLatentBatch
17
+
18
+ connections:
19
+ - from: "input_image_loader:0"
20
+ to: "pad_image:image"
21
+
22
+ - from: "pad_image:0"
23
+ to: "vae_encode:pixels"
24
+ - from: "pad_image:1"
25
+ to: "vae_encode:mask"
26
+
27
+ - from: "vae_encode:0"
28
+ to: "latent_source:samples"
29
+
30
+ ui_map:
31
+ input_image: "input_image_loader:image"
32
+
33
+ left: "pad_image:left"
34
+ top: "pad_image:top"
35
+ right: "pad_image:right"
36
+ bottom: "pad_image:bottom"
37
+
38
+ batch_size: "latent_source:amount"
core/pipelines/workflow_recipes/_partials/input/txt2img.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
# Recipe partial: text-to-image input stage.
# Creates an empty latent of the requested size; no input image is used.
nodes:
  latent_source:
    class_type: EmptyFlux2LatentImage

# UI value -> "node:input" bindings consumed by the assembler.
ui_map:
  width: "latent_source:width"
  height: "latent_source:height"
  batch_size: "latent_source:batch_size"
core/pipelines/workflow_recipes/sd_unified_recipe.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
# Top-level recipe: composes the base sampler, the task-specific input
# stage ({{ task_type }} is substituted at load time by the assembler),
# and the conditioning chain, then wires the latent into the sampler.
imports:
  - "_partials/_base_sampler.yaml"
  - "_partials/input/{{ task_type }}.yaml"
  - "_partials/conditioning/flux2.yaml"

connections:
  - from: "latent_source:0"
    to: "ksampler:latent_image"
core/settings.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ import os
3
+ from collections import OrderedDict
4
+
5
# Relative directories (ComfyUI-style layout) for model assets and I/O.
CHECKPOINT_DIR = "models/checkpoints"
LORA_DIR = "models/loras"
EMBEDDING_DIR = "models/embeddings"
CONTROLNET_DIR = "models/controlnet"
MODEL_PATCHES_DIR = "models/model_patches"
DIFFUSION_MODELS_DIR = "models/diffusion_models"
VAE_DIR = "models/vae"
TEXT_ENCODERS_DIR = "models/text_encoders"
INPUT_DIR = "input"
OUTPUT_DIR = "output"

# Absolute paths to the YAML configuration files, resolved relative to the
# project root (the parent of this module's directory).
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
_MODEL_LIST_PATH = os.path.join(_PROJECT_ROOT, 'yaml', 'model_list.yaml')
_FILE_LIST_PATH = os.path.join(_PROJECT_ROOT, 'yaml', 'file_list.yaml')
_CONSTANTS_PATH = os.path.join(_PROJECT_ROOT, 'yaml', 'constants.yaml')
_MODEL_DEFAULTS_PATH = os.path.join(_PROJECT_ROOT, 'yaml', 'model_defaults.yaml')
21
+
22
+
23
def load_constants_from_yaml(filepath=_CONSTANTS_PATH):
    """Load UI/application constants from *filepath*.

    Returns an empty dict (with a warning) when the file is missing.
    Also returns an empty dict when the file exists but is empty:
    ``yaml.safe_load`` yields ``None`` for an empty document, which would
    break callers that use ``.get`` on the result.
    """
    if not os.path.exists(filepath):
        print(f"Warning: Constants file not found at {filepath}. Using fallback values.")
        return {}
    with open(filepath, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f) or {}
29
+
30
def load_file_download_map(filepath=_FILE_LIST_PATH):
    """Build a map of filename -> download metadata from file_list.yaml.

    Each entry under the top-level ``file`` key is a category whose value
    is a list of file records; every record gains a ``category`` field
    and is indexed by its ``filename``.

    Raises:
        FileNotFoundError: when *filepath* does not exist.
    """
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"The file list (for downloads) was not found at: {filepath}")

    with open(filepath, 'r', encoding='utf-8') as f:
        # An empty YAML document parses to None; treat it as "no files"
        # instead of crashing on .get() below.
        file_list_data = yaml.safe_load(f) or {}

    download_info_map = {}
    for category, files in file_list_data.get('file', {}).items():
        if isinstance(files, list):
            for file_info in files:
                if 'filename' in file_info:
                    file_info['category'] = category
                    download_info_map[file_info['filename']] = file_info
    return download_info_map
45
+
46
+
47
def load_models_from_yaml(model_list_filepath=_MODEL_LIST_PATH, download_map=None):
    """Build model maps (display_name -> model tuple) from model_list.yaml.

    Returns a dict with two OrderedDicts: ``MODEL_MAP_CHECKPOINT``
    (``Checkpoint`` category only) and ``ALL_MODEL_MAP`` (all mapped
    categories).  Each value is ``(None, components_dict, model_type, None)``.

    *download_map* is required even though its entries are not
    cross-checked here yet — presumably reserved for validating that each
    model's components are downloadable (TODO confirm against callers).

    Raises:
        FileNotFoundError: when the model list file is missing.
        ValueError: when *download_map* is None.
    """
    if not os.path.exists(model_list_filepath):
        raise FileNotFoundError(f"The model list file was not found at: {model_list_filepath}")
    if download_map is None:
        raise ValueError("download_map must be provided to load_models_from_yaml")

    with open(model_list_filepath, 'r', encoding='utf-8') as f:
        # An empty YAML document parses to None; treat it as "no models"
        # instead of crashing on .items() below.
        model_data = yaml.safe_load(f) or {}

    model_maps = {
        "MODEL_MAP_CHECKPOINT": OrderedDict(),
        "ALL_MODEL_MAP": OrderedDict(),
    }
    category_map_names = {
        "Checkpoint": "MODEL_MAP_CHECKPOINT",
    }

    for category, models in model_data.items():
        if category in category_map_names:
            map_name = category_map_names[category]
            if not isinstance(models, list):
                continue
            for model in models:
                display_name = model['display_name']
                components = model.get('components', {})

                # Tuple layout: (reserved, components, model_type, reserved).
                # Index 2 is consumed by MODEL_TYPE_MAP at module init.
                model_tuple = (
                    None,
                    components,
                    "SDXL",
                    None
                )
                model_maps[map_name][display_name] = model_tuple
                model_maps["ALL_MODEL_MAP"][display_name] = model_tuple

    return model_maps
82
+
83
def load_model_defaults(filepath=_MODEL_DEFAULTS_PATH):
    """Load per-model UI defaults (steps, cfg, ...) from model_defaults.yaml.

    Returns an empty dict when the file is missing *or empty* —
    ``yaml.safe_load`` yields ``None`` for an empty document, and callers
    expect a dict they can do lookups on.
    """
    if not os.path.exists(filepath):
        print(f"Warning: Model defaults file not found at {filepath}. Using empty defaults.")
        return {}
    with open(filepath, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f) or {}
89
+
90
# --- Module-level configuration load -----------------------------------
# Everything below runs at import time. Failures are deliberately
# non-fatal: on any error the maps fall back to empty so the app can
# still start (the error is printed for diagnosis).
try:
    ALL_FILE_DOWNLOAD_MAP = load_file_download_map()
    loaded_maps = load_models_from_yaml(download_map=ALL_FILE_DOWNLOAD_MAP)
    MODEL_MAP_CHECKPOINT = loaded_maps["MODEL_MAP_CHECKPOINT"]
    ALL_MODEL_MAP = loaded_maps["ALL_MODEL_MAP"]

    # display_name -> model-type string (index 2 of the model tuple).
    MODEL_TYPE_MAP = {k: v[2] for k, v in ALL_MODEL_MAP.items()}

    ALL_MODEL_DEFAULTS = load_model_defaults()

except Exception as e:
    print(f"FATAL: Could not load model configuration from YAML. Error: {e}")
    ALL_FILE_DOWNLOAD_MAP = {}
    MODEL_MAP_CHECKPOINT, ALL_MODEL_MAP = {}, {}
    MODEL_TYPE_MAP = {}
    ALL_MODEL_DEFAULTS = {}


# UI limits and choice lists, with hard-coded fallbacks mirroring the
# expected values in yaml/constants.yaml.
try:
    _constants = load_constants_from_yaml()
    MAX_LORAS = _constants.get('MAX_LORAS', 5)
    MAX_EMBEDDINGS = _constants.get('MAX_EMBEDDINGS', 5)
    MAX_CONDITIONINGS = _constants.get('MAX_CONDITIONINGS', 10)
    MAX_CONTROLNETS = _constants.get('MAX_CONTROLNETS', 5)
    MAX_REFERENCE_LATENTS = _constants.get('MAX_REFERENCE_LATENTS', 10)
    LORA_SOURCE_CHOICES = _constants.get('LORA_SOURCE_CHOICES', ["Civitai", "File"])
    RESOLUTION_MAP = _constants.get('RESOLUTION_MAP', {})
except Exception as e:
    print(f"FATAL: Could not load constants from YAML. Error: {e}")
    MAX_LORAS, MAX_EMBEDDINGS, MAX_CONDITIONINGS, MAX_CONTROLNETS = 5, 5, 10, 5
    MAX_REFERENCE_LATENTS = 10
    LORA_SOURCE_CHOICES = ["Civitai", "File"]
    RESOLUTION_MAP = {}


# Negative prompt applied when the user leaves the field blank.
DEFAULT_NEGATIVE_PROMPT = ""
core/shared_state.py ADDED
@@ -0,0 +1 @@
 
 
1
+ INVALID_MODEL_URLS = {}
core/workflow_assembler.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ import os
3
+ import importlib
4
+ from copy import deepcopy
5
+ from comfy_integration.nodes import NODE_CLASS_MAPPINGS
6
+
7
class WorkflowAssembler:
    """Builds a ComfyUI API-format workflow dict from a YAML recipe file.

    A recipe lists ``nodes`` (name -> class_type/params), ``connections``
    ("node:output_index" -> "node:input_name"), and a ``ui_map`` that
    wires UI values onto node inputs.  Recipes can import partials,
    substituting ``{{ key }}`` placeholders from *dynamic_values*, and may
    carry ``dynamic_*`` chain sections that are expanded by injector
    plugins declared in yaml/injectors.yaml.
    """

    def __init__(self, recipe_path, dynamic_values=None):
        # Directory from which the top-level recipe (and its relative
        # imports) are resolved.
        self.base_path = os.path.dirname(recipe_path)
        self.node_counter = 0  # source of sequential node ids ("1", "2", ...)
        self.workflow = {}     # node id -> node dict in ComfyUI API format
        self.node_map = {}     # recipe node name -> node id

        self._load_injector_config()

        self.recipe = self._load_and_merge_recipe(os.path.basename(recipe_path), dynamic_values or {})

    def _load_injector_config(self):
        """Register injector plugins and their processing order.

        Reads yaml/injectors.yaml at the project root.  Any failure is
        non-fatal: dynamic chains are simply disabled (empty order/map).
        """
        try:
            project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            injectors_path = os.path.join(project_root, 'yaml', 'injectors.yaml')

            with open(injectors_path, 'r', encoding='utf-8') as f:
                injector_config = yaml.safe_load(f)

            definitions = injector_config.get("injector_definitions", {})
            self.injector_order = injector_config.get("injector_order", [])
            self.global_injectors = {}

            # Each definition names a module that must expose an
            # ``inject(assembler, chain_def, chain_items)`` function.
            for chain_type, config in definitions.items():
                module_path = config.get("module")
                if not module_path:
                    print(f"Warning: Injector '{chain_type}' in injectors.yaml is missing 'module' path.")
                    continue
                try:
                    module = importlib.import_module(module_path)
                    if hasattr(module, 'inject'):
                        self.global_injectors[chain_type] = module.inject
                        print(f"✅ Successfully registered global injector: {chain_type} from {module_path}")
                    else:
                        print(f"⚠️ Warning: Module '{module_path}' for injector '{chain_type}' does not have an 'inject' function.")
                except ImportError as e:
                    print(f"❌ Error importing module '{module_path}' for injector '{chain_type}': {e}")

            if not self.injector_order:
                print("⚠️ Warning: 'injector_order' is not defined in injectors.yaml. Using definition order.")
                self.injector_order = list(definitions.keys())

        except FileNotFoundError:
            print(f"❌ FATAL: Could not find injectors.yaml at {injectors_path}. Dynamic chains will not work.")
            self.injector_order = []
            self.global_injectors = {}
        except Exception as e:
            print(f"❌ FATAL: Could not load or parse injectors.yaml. Dynamic chains will not work. Error: {e}")
            self.injector_order = []
            self.global_injectors = {}

    def _get_unique_id(self):
        # Node ids in the API workflow are strings of sequential integers.
        self.node_counter += 1
        return str(self.node_counter)

    def _get_node_template(self, class_type):
        """Create a blank node dict for *class_type* with default inputs.

        Defaults come from the node class's ``INPUT_TYPES()`` declaration
        (required and optional inputs merged; optional wins on clashes).

        Raises:
            ValueError: when the class is not registered in
                NODE_CLASS_MAPPINGS.
        """
        if class_type not in NODE_CLASS_MAPPINGS:
            raise ValueError(f"Node class '{class_type}' not found. Ensure it's correctly imported in comfy_integration/nodes.py.")

        node_class = NODE_CLASS_MAPPINGS[class_type]
        input_types = node_class.INPUT_TYPES()

        template = {
            "inputs": {},
            "class_type": class_type,
            "_meta": {"title": node_class.NODE_NAME if hasattr(node_class, 'NODE_NAME') else class_type}
        }

        # Input specs are (type, config?) tuples; the optional config dict
        # may supply a "default" value (None when absent).
        all_inputs = {**input_types.get('required', {}), **input_types.get('optional', {})}
        for name, details in all_inputs.items():
            config = details[1] if len(details) > 1 and isinstance(details[1], dict) else {}
            template["inputs"][name] = config.get("default")

        return template

    def _load_and_merge_recipe(self, recipe_filename, dynamic_values, search_context_dir=None):
        """Recursively load a recipe and merge its imported partials.

        ``{{ key }}`` placeholders are substituted (both in file content
        and in import paths) from *dynamic_values*.  Merge order: imports
        first (in listed order), then the importing recipe's own sections
        on top, so later definitions win.  A missing partial is skipped
        with a warning; a missing top-level recipe raises.
        """
        search_path = search_context_dir or self.base_path
        recipe_path_to_use = os.path.join(search_path, recipe_filename)

        if not os.path.exists(recipe_path_to_use):
            raise FileNotFoundError(f"Recipe file not found: {recipe_path_to_use}")

        with open(recipe_path_to_use, 'r', encoding='utf-8') as f:
            content = f.read()

        # Textual placeholder substitution: "{{ key }}" -> str(value).
        for key, value in dynamic_values.items():
            if value is not None:
                content = content.replace(f"{{{{ {key} }}}}", str(value))

        main_recipe = yaml.safe_load(content)

        merged_recipe = {'nodes': {}, 'connections': [], 'ui_map': {}}
        # Pre-create a dict for every known dynamic_* chain section.
        for key in self.injector_order:
            if key.startswith('dynamic_'):
                merged_recipe[key] = {}

        parent_recipe_dir = os.path.dirname(recipe_path_to_use)
        for import_path_template in main_recipe.get('imports', []):
            import_path = import_path_template
            for key, value in dynamic_values.items():
                if value is not None:
                    import_path = import_path.replace(f"{{{{ {key} }}}}", str(value))

            try:
                imported_recipe = self._load_and_merge_recipe(import_path, dynamic_values, search_context_dir=parent_recipe_dir)
                merged_recipe['nodes'].update(imported_recipe.get('nodes', {}))
                merged_recipe['connections'].extend(imported_recipe.get('connections', []))
                merged_recipe['ui_map'].update(imported_recipe.get('ui_map', {}))
                for key in self.injector_order:
                    if key in imported_recipe and key.startswith('dynamic_'):
                        merged_recipe[key].update(imported_recipe.get(key, {}))
            except FileNotFoundError:
                print(f"Warning: Optional recipe partial '{import_path}' not found. Skipping.")

        # The importing recipe's own sections override anything imported.
        merged_recipe['nodes'].update(main_recipe.get('nodes', {}))
        merged_recipe['connections'].extend(main_recipe.get('connections', []))
        merged_recipe['ui_map'].update(main_recipe.get('ui_map', {}))
        for key in self.injector_order:
            if key in main_recipe and key.startswith('dynamic_'):
                merged_recipe[key].update(main_recipe.get(key, {}))

        return merged_recipe

    def assemble(self, ui_values):
        """Materialize the merged recipe into a workflow dict.

        Steps: instantiate nodes with recipe params, apply *ui_values*
        through the ui_map, wire connections, then run the registered
        dynamic-chain injectors in the configured order.

        Returns:
            dict: node id -> node dict, ready to submit to ComfyUI.
        """
        for name, details in self.recipe['nodes'].items():
            class_type = details['class_type']
            template = self._get_node_template(class_type)
            node_data = deepcopy(template)

            unique_id = self._get_unique_id()
            self.node_map[name] = unique_id

            # Recipe-level params override the class defaults, but only
            # for inputs the node actually declares.
            if 'params' in details:
                for param, value in details['params'].items():
                    if param in node_data['inputs']:
                        node_data['inputs'][param] = value

            self.workflow[unique_id] = node_data

        # ui_map targets may be a single "node:input" string or a list of
        # them; values that are None are left at the node default.
        for ui_key, target in self.recipe.get('ui_map', {}).items():
            if ui_key in ui_values and ui_values[ui_key] is not None:
                target_list = target if isinstance(target, list) else [target]
                for t in target_list:
                    target_name, target_param = t.split(':')
                    if target_name in self.node_map:
                        self.workflow[self.node_map[target_name]]['inputs'][target_param] = ui_values[ui_key]

        # Connections become [source_id, output_index] link references;
        # connections touching unknown nodes are silently skipped.
        for conn in self.recipe.get('connections', []):
            from_name, from_output_idx = conn['from'].split(':')
            to_name, to_input_name = conn['to'].split(':')

            from_id = self.node_map.get(from_name)
            to_id = self.node_map.get(to_name)

            if from_id and to_id:
                self.workflow[to_id]['inputs'][to_input_name] = [from_id, int(from_output_idx)]

        print("--- [Assembler] Applying dynamic injectors ---")
        # Run only injectors whose chain section exists in the recipe,
        # in the globally configured order.
        recipe_chain_types = {key for key in self.recipe if key.startswith('dynamic_')}
        processing_order = [key for key in self.injector_order if key in recipe_chain_types]

        for chain_type in processing_order:
            injector_func = self.global_injectors.get(chain_type)
            if injector_func:
                for chain_key, chain_def in self.recipe.get(chain_type, {}).items():
                    # Inject only when the UI actually supplied items for
                    # this chain (truthy check skips empty lists too).
                    if chain_key in ui_values and ui_values[chain_key]:
                        print(f" -> Injecting '{chain_type}' for '{chain_key}'...")
                        chain_items = ui_values[chain_key]
                        injector_func(self, chain_def, chain_items)

        print("--- [Assembler] Finished applying injectors ---")

        return self.workflow
requirements.txt ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ comfyui-frontend-package==1.37.11
2
+ comfyui-workflow-templates==0.8.24
3
+ comfyui-embedded-docs==0.4.0
4
+ torch==2.9.1
5
+ torchsde
6
+ torchvision
7
+ torchaudio
8
+ numpy>=1.25.0
9
+ einops
10
+ transformers>=4.50.3
11
+ tokenizers>=0.13.3
12
+ sentencepiece
13
+ safetensors>=0.4.2
14
+ aiohttp>=3.11.8
15
+ yarl>=1.18.0
16
+ pyyaml
17
+ Pillow
18
+ scipy
19
+ tqdm
20
+ psutil
21
+ alembic
22
+ SQLAlchemy
23
+ av>=14.2.0
24
+ comfy-kitchen>=0.2.7
25
+ requests
26
+
27
+ #non essential dependencies:
28
+ kornia>=0.7.1
29
+ spandrel
30
+ pydantic~=2.0
31
+ pydantic-settings~=2.0
32
+
33
+
34
+ addict
35
+ albumentations
36
+ filelock
37
+ ftfy
38
+ fvcore
39
+ huggingface-hub
40
+ imageio[ffmpeg]
41
+ importlib_metadata
42
+ matplotlib
43
+ mediapipe
44
+ ninja
45
+ omegaconf
46
+ opencv-python>=4.7.0.72
47
+ python-dateutil
48
+ requests
49
+ scikit-image
50
+ scikit-learn
51
+ soundfile
52
+ spaces
53
+ svglib
54
+ torchsde
55
+ trimesh[easy]
56
+ yacs
57
+ yapf
58
+ onnxruntime-gpu
scripts/__init__.py ADDED
File without changes
scripts/build_sage_attention.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import sys
4
+ import textwrap
5
+
6
+ REPO_URL = "https://github.com/thu-ml/SageAttention.git"
7
+ REPO_DIR = "SageAttention"
8
+
9
def run_command(command, cwd=None, env=None):
    """Run *command*, capturing combined stdout/stderr.

    Output is printed only when the command fails, so successful build
    steps stay quiet.

    Args:
        command: argv list for subprocess.run (no shell).
        cwd: optional working directory.
        env: optional environment mapping.

    Raises:
        subprocess.CalledProcessError: on a non-zero exit code; the
            captured output is attached via the ``output`` attribute so
            callers can inspect the failure.
    """
    print(f"🚀 Running command: {' '.join(command)}")
    result = subprocess.run(
        command,
        cwd=cwd,
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True
    )

    if result.returncode != 0:
        print(result.stdout)
        # Attach the captured output instead of discarding it.
        raise subprocess.CalledProcessError(result.returncode, command, output=result.stdout)
23
+
24
def patch_setup_py(setup_py_path):
    """Patch SageAttention's setup.py so it builds without OpenMP flags.

    Drops '-fopenmp' and '-lgomp' from the CXX_FLAGS assignment; if the
    expected line is absent (upstream change), the file is rewritten
    unchanged and a warning is printed.
    """
    print(f"--- [SageAttention Build] Applying patches to {setup_py_path} ---")

    with open(setup_py_path, 'r', encoding='utf-8') as fh:
        source = fh.read()

    target = 'CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]'
    replacement = 'CXX_FLAGS = ["-g", "-O3", "-std=c++17", "-DENABLE_BF16"]'

    if target not in source:
        print("⚠️ Patch 1/1: CXX_FLAGS line not found as expected. It might have been changed upstream. Skipping.")
    else:
        source = source.replace(target, replacement)
        print("🔧 Patch 1/1: Removed '-fopenmp' and '-lgomp' from CXX_FLAGS.")

    with open(setup_py_path, 'w', encoding='utf-8') as fh:
        fh.write(source)

    print("✅ Patches applied successfully.")
43
+
44
+
45
def install_sage_attention():
    """Clone, patch, and build SageAttention from source.

    The presence of the repo directory is used as the "already installed"
    marker, so the build is skipped when it exists.  Exits the process
    with status 1 on any failure (missing git, failed subprocess, or an
    unexpected error).
    """
    print("--- [SageAttention Build] Checking environment ---")

    if os.path.isdir(REPO_DIR):
        print(f"✅ Directory '{REPO_DIR}' already exists, assuming SageAttention is installed. Skipping build.")
        return

    print(f"⏳ Directory '{REPO_DIR}' not found. Starting a fresh installation of SageAttention.")

    try:
        print(f"--- [SageAttention Build] Step 1/3: Cloning repository ---")
        run_command(["git", "clone", REPO_URL])
        print("✅ Repository cloned successfully.")

        print(f"--- [SageAttention Build] Step 2/3: Patching setup.py ---")
        setup_py_path = os.path.join(REPO_DIR, "setup.py")
        patch_setup_py(setup_py_path)

        print(f"--- [SageAttention Build] Step 3/3: Compiling and installing ---")

        # Build settings: TORCH_CUDA_ARCH_LIST=9.0 targets a single CUDA
        # architecture (SM 9.0); the remaining variables parallelize the
        # extension compilation.
        build_env = os.environ.copy()
        build_env.update({
            "TORCH_CUDA_ARCH_LIST": "9.0",
            "EXT_PARALLEL": "4",
            "NVCC_APPEND_FLAGS": "--threads 8",
            "MAX_JOBS": "32"
        })
        print("🔧 Setting build environment variables:")
        print(f" - TORCH_CUDA_ARCH_LIST='{build_env['TORCH_CUDA_ARCH_LIST']}'")
        print(f" - EXT_PARALLEL={build_env['EXT_PARALLEL']}")
        print(f" - NVCC_APPEND_FLAGS='{build_env['NVCC_APPEND_FLAGS']}'")
        print(f" - MAX_JOBS={build_env['MAX_JOBS']}")

        install_command = [sys.executable, "setup.py", "install"]

        run_command(install_command, cwd=REPO_DIR, env=build_env)

        print("🎉 SageAttention compiled and installed successfully! ---")

    except FileNotFoundError:
        # Raised by run_command when the executable (git) is missing.
        print("❌ ERROR: 'git' command not found. Please ensure Git is installed in your environment.")
        sys.exit(1)
    except subprocess.CalledProcessError as e:
        print(f"❌ Command failed with return code: {e.returncode}")
        print(f"❌ Command: {' '.join(e.cmd)}")
        print("❌ SageAttention installation failed. Please check the logs above for details.")
        sys.exit(1)
    except Exception as e:
        print(f"❌ An unknown error occurred: {e}")
        sys.exit(1)
95
+
96
# Script entry point: allow running the build step directly.
if __name__ == "__main__":
    if os.path.isdir(REPO_DIR):
        print(f"Note: To force a rebuild, please delete the '{REPO_DIR}' directory first.")
    install_sage_attention()
ui/__init__.py ADDED
File without changes
ui/events.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import yaml
3
+ import os
4
+ import shutil
5
+ from functools import lru_cache
6
+ from core.settings import *
7
+ from utils.app_utils import *
8
+ from core.generation_logic import *
9
+ from comfy_integration.nodes import SAMPLER_CHOICES, SCHEDULER_CHOICES
10
+
11
+ from core.pipelines.controlnet_preprocessor import CPU_ONLY_PREPROCESSORS
12
+ from utils.app_utils import PREPROCESSOR_MODEL_MAP, PREPROCESSOR_PARAMETER_MAP, save_uploaded_file_with_hash
13
+ from ui.shared.ui_components import RESOLUTION_MAP, MAX_CONTROLNETS, MAX_EMBEDDINGS, MAX_CONDITIONINGS, MAX_LORAS, MAX_REFERENCE_LATENTS
14
+
15
+
16
def on_model_change(model_display_name):
    """
    Update UI defaults (steps, cfg) when the selected base model changes,
    using values loaded from model_defaults.yaml.
    """
    # Start from the global 'Default' section, then layer the category's
    # '_defaults' and the per-model entry on top when the model is found.
    merged = ALL_MODEL_DEFAULTS.get('Default', {}).copy()

    matched = False
    for category_name, category_models in ALL_MODEL_DEFAULTS.items():
        if category_name == 'Default' or not isinstance(category_models, dict):
            continue
        if model_display_name not in category_models:
            continue

        if '_defaults' in category_models:
            merged.update(category_models['_defaults'])
        merged.update(category_models[model_display_name])
        matched = True
        break

    if not matched:
        print(f"No specific defaults found for '{model_display_name}'. Using category or global defaults.")

    return gr.update(value=merged.get('steps')), gr.update(value=merged.get('cfg'))
42
+
43
+ def attach_event_handlers(ui_components, demo):
44
+ def update_cn_input_visibility(choice):
45
+ return {
46
+ ui_components["cn_image_input"]: gr.update(visible=choice == "Image"),
47
+ ui_components["cn_video_input"]: gr.update(visible=choice == "Video")
48
+ }
49
+ ui_components["cn_input_type"].change(
50
+ fn=update_cn_input_visibility,
51
+ inputs=[ui_components["cn_input_type"]],
52
+ outputs=[ui_components["cn_image_input"], ui_components["cn_video_input"]]
53
+ )
54
+
55
+ def update_preprocessor_models_dropdown(preprocessor_name):
56
+ models = PREPROCESSOR_MODEL_MAP.get(preprocessor_name)
57
+ if models:
58
+ model_filenames = [m[1] for m in models]
59
+ return gr.update(choices=model_filenames, value=model_filenames[0], visible=True)
60
+ else:
61
+ return gr.update(choices=[], value=None, visible=False)
62
+
63
+ def update_preprocessor_settings_ui(preprocessor_name):
64
+ from ui.layout import MAX_DYNAMIC_CONTROLS
65
+ params = PREPROCESSOR_PARAMETER_MAP.get(preprocessor_name, [])
66
+
67
+ slider_updates, dropdown_updates, checkbox_updates = [], [], []
68
+
69
+ s_idx, d_idx, c_idx = 0, 0, 0
70
+
71
+ for param in params:
72
+ if s_idx + d_idx + c_idx >= MAX_DYNAMIC_CONTROLS: break
73
+
74
+ name = param["name"]
75
+ ptype = param["type"]
76
+ config = param["config"]
77
+ label = name.replace('_', ' ').title()
78
+
79
+ if ptype == "INT" or ptype == "FLOAT":
80
+ if s_idx < MAX_DYNAMIC_CONTROLS:
81
+ slider_updates.append(gr.update(
82
+ label=label,
83
+ minimum=config.get('min', 0),
84
+ maximum=config.get('max', 255),
85
+ step=config.get('step', 0.1 if ptype == "FLOAT" else 1),
86
+ value=config.get('default', 0),
87
+ visible=True
88
+ ))
89
+ s_idx += 1
90
+ elif isinstance(ptype, list):
91
+ if d_idx < MAX_DYNAMIC_CONTROLS:
92
+ dropdown_updates.append(gr.update(
93
+ label=label,
94
+ choices=ptype,
95
+ value=config.get('default', ptype[0] if ptype else None),
96
+ visible=True
97
+ ))
98
+ d_idx += 1
99
+ elif ptype == "BOOLEAN":
100
+ if c_idx < MAX_DYNAMIC_CONTROLS:
101
+ checkbox_updates.append(gr.update(
102
+ label=label,
103
+ value=config.get('default', False),
104
+ visible=True
105
+ ))
106
+ c_idx += 1
107
+
108
+ for _ in range(s_idx, MAX_DYNAMIC_CONTROLS): slider_updates.append(gr.update(visible=False))
109
+ for _ in range(d_idx, MAX_DYNAMIC_CONTROLS): dropdown_updates.append(gr.update(visible=False))
110
+ for _ in range(c_idx, MAX_DYNAMIC_CONTROLS): checkbox_updates.append(gr.update(visible=False))
111
+
112
+ return slider_updates + dropdown_updates + checkbox_updates
113
+
114
+ def update_run_button_for_cpu(preprocessor_name):
115
+ if preprocessor_name in CPU_ONLY_PREPROCESSORS:
116
+ return gr.update(value="Run Preprocessor CPU Only", variant="primary"), gr.update(visible=False)
117
+ else:
118
+ return gr.update(value="Run Preprocessor", variant="primary"), gr.update(visible=True)
119
+
120
+ ui_components["preprocessor_cn"].change(
121
+ fn=update_preprocessor_models_dropdown,
122
+ inputs=[ui_components["preprocessor_cn"]],
123
+ outputs=[ui_components["preprocessor_model_cn"]]
124
+ ).then(
125
+ fn=update_preprocessor_settings_ui,
126
+ inputs=[ui_components["preprocessor_cn"]],
127
+ outputs=ui_components["cn_sliders"] + ui_components["cn_dropdowns"] + ui_components["cn_checkboxes"]
128
+ ).then(
129
+ fn=update_run_button_for_cpu,
130
+ inputs=[ui_components["preprocessor_cn"]],
131
+ outputs=[ui_components["run_cn"], ui_components["zero_gpu_cn"]]
132
+ )
133
+
134
+ all_dynamic_inputs = (
135
+ ui_components["cn_sliders"] +
136
+ ui_components["cn_dropdowns"] +
137
+ ui_components["cn_checkboxes"]
138
+ )
139
+
140
+ ui_components["run_cn"].click(
141
+ fn=run_cn_preprocessor_entry,
142
+ inputs=[
143
+ ui_components["cn_input_type"],
144
+ ui_components["cn_image_input"],
145
+ ui_components["cn_video_input"],
146
+ ui_components["preprocessor_cn"],
147
+ ui_components["preprocessor_model_cn"],
148
+ ui_components["zero_gpu_cn"],
149
+ ] + all_dynamic_inputs,
150
+ outputs=[ui_components["output_gallery_cn"]]
151
+ )
152
+
153
+ def create_lora_event_handlers(prefix):
154
+ lora_rows = ui_components[f'lora_rows_{prefix}']
155
+ lora_ids = ui_components[f'lora_ids_{prefix}']
156
+ lora_scales = ui_components[f'lora_scales_{prefix}']
157
+ lora_uploads = ui_components[f'lora_uploads_{prefix}']
158
+ count_state = ui_components[f'lora_count_state_{prefix}']
159
+ add_button = ui_components[f'add_lora_button_{prefix}']
160
+ del_button = ui_components[f'delete_lora_button_{prefix}']
161
+
162
+ def add_lora_row(c):
163
+ updates = {}
164
+ if c < MAX_LORAS:
165
+ c += 1
166
+ updates[lora_rows[c - 1]] = gr.update(visible=True)
167
+
168
+ updates[count_state] = c
169
+ updates[add_button] = gr.update(visible=c < MAX_LORAS)
170
+ updates[del_button] = gr.update(visible=c > 1)
171
+ return updates
172
+
173
+ def del_lora_row(c):
174
+ updates = {}
175
+ if c > 1:
176
+ updates[lora_rows[c - 1]] = gr.update(visible=False)
177
+ updates[lora_ids[c - 1]] = ""
178
+ updates[lora_scales[c - 1]] = 0.0
179
+ updates[lora_uploads[c - 1]] = None
180
+ c -= 1
181
+
182
+ updates[count_state] = c
183
+ updates[add_button] = gr.update(visible=True)
184
+ updates[del_button] = gr.update(visible=c > 1)
185
+ return updates
186
+
187
+ add_outputs = [count_state, add_button, del_button] + lora_rows
188
+ del_outputs = [count_state, add_button, del_button] + lora_rows + lora_ids + lora_scales + lora_uploads
189
+
190
+ add_button.click(add_lora_row, [count_state], add_outputs, show_progress=False)
191
+ del_button.click(del_lora_row, [count_state], del_outputs, show_progress=False)
192
+
193
+ def create_embedding_event_handlers(prefix):
194
+ rows = ui_components[f'embedding_rows_{prefix}']
195
+ ids = ui_components[f'embeddings_ids_{prefix}']
196
+ files = ui_components[f'embeddings_files_{prefix}']
197
+ count_state = ui_components[f'embedding_count_state_{prefix}']
198
+ add_button = ui_components[f'add_embedding_button_{prefix}']
199
+ del_button = ui_components[f'delete_embedding_button_{prefix}']
200
+
201
+ def add_row(c):
202
+ c += 1
203
+ return {
204
+ count_state: c,
205
+ rows[c - 1]: gr.update(visible=True),
206
+ add_button: gr.update(visible=c < MAX_EMBEDDINGS),
207
+ del_button: gr.update(visible=True)
208
+ }
209
+
210
+ def del_row(c):
211
+ c -= 1
212
+ return {
213
+ count_state: c,
214
+ rows[c]: gr.update(visible=False),
215
+ ids[c]: "",
216
+ files[c]: None,
217
+ add_button: gr.update(visible=True),
218
+ del_button: gr.update(visible=c > 0)
219
+ }
220
+
221
+ add_outputs = [count_state, add_button, del_button] + rows
222
+ del_outputs = [count_state, add_button, del_button] + rows + ids + files
223
+ add_button.click(fn=add_row, inputs=[count_state], outputs=add_outputs, show_progress=False)
224
+ del_button.click(fn=del_row, inputs=[count_state], outputs=del_outputs, show_progress=False)
225
+
226
+ def create_conditioning_event_handlers(prefix):
227
+ rows = ui_components[f'conditioning_rows_{prefix}']
228
+ prompts = ui_components[f'conditioning_prompts_{prefix}']
229
+ count_state = ui_components[f'conditioning_count_state_{prefix}']
230
+ add_button = ui_components[f'add_conditioning_button_{prefix}']
231
+ del_button = ui_components[f'delete_conditioning_button_{prefix}']
232
+
233
+ def add_row(c):
234
+ c += 1
235
+ return {
236
+ count_state: c,
237
+ rows[c - 1]: gr.update(visible=True),
238
+ add_button: gr.update(visible=c < MAX_CONDITIONINGS),
239
+ del_button: gr.update(visible=True),
240
+ }
241
+
242
+ def del_row(c):
243
+ c -= 1
244
+ return {
245
+ count_state: c,
246
+ rows[c]: gr.update(visible=False),
247
+ prompts[c]: "",
248
+ add_button: gr.update(visible=True),
249
+ del_button: gr.update(visible=c > 0),
250
+ }
251
+
252
+ add_outputs = [count_state, add_button, del_button] + rows
253
+ del_outputs = [count_state, add_button, del_button] + rows + prompts
254
+ add_button.click(fn=add_row, inputs=[count_state], outputs=add_outputs, show_progress=False)
255
+ del_button.click(fn=del_row, inputs=[count_state], outputs=del_outputs, show_progress=False)
256
+
257
+ def create_reference_latent_event_handlers(prefix):
258
+ rows = ui_components[f'reference_latent_rows_{prefix}']
259
+ images = ui_components[f'reference_latent_images_{prefix}']
260
+ count_state = ui_components[f'reference_latent_count_state_{prefix}']
261
+ add_button = ui_components[f'add_reference_latent_button_{prefix}']
262
+ del_button = ui_components[f'delete_reference_latent_button_{prefix}']
263
+
264
+ def add_row(c):
265
+ c += 1
266
+ return {
267
+ count_state: c,
268
+ rows[c - 1]: gr.update(visible=True),
269
+ add_button: gr.update(visible=c < MAX_REFERENCE_LATENTS),
270
+ del_button: gr.update(visible=True),
271
+ }
272
+
273
+ def del_row(c):
274
+ c -= 1
275
+ return {
276
+ count_state: c,
277
+ rows[c]: gr.update(visible=False),
278
+ images[c]: None,
279
+ add_button: gr.update(visible=True),
280
+ del_button: gr.update(visible=c > 0),
281
+ }
282
+
283
+ add_outputs = [count_state, add_button, del_button] + rows
284
+ del_outputs = [count_state, add_button, del_button] + rows + images
285
+ add_button.click(fn=add_row, inputs=[count_state], outputs=add_outputs, show_progress=False)
286
+ del_button.click(fn=del_row, inputs=[count_state], outputs=del_outputs, show_progress=False)
287
+
288
+ def on_vae_upload(file_obj):
289
+ if not file_obj:
290
+ return gr.update(), gr.update(), None
291
+
292
+ hashed_filename = save_uploaded_file_with_hash(file_obj, VAE_DIR)
293
+ return hashed_filename, "File", file_obj
294
+
295
+ def on_lora_upload(file_obj):
296
+ if not file_obj:
297
+ return gr.update(), gr.update()
298
+
299
+ hashed_filename = save_uploaded_file_with_hash(file_obj, LORA_DIR)
300
+ return hashed_filename, "File"
301
+
302
+ def on_embedding_upload(file_obj):
303
+ if not file_obj:
304
+ return gr.update(), gr.update(), None
305
+
306
+ hashed_filename = save_uploaded_file_with_hash(file_obj, EMBEDDING_DIR)
307
+ return hashed_filename, "File", file_obj
308
+
309
+
310
def create_run_event(prefix: str, task_type: str):
    """Wire the Run button for one generation tab.

    Builds the flat positional input list Gradio passes to the click
    handler, plus a closure that re-assembles those positional args into
    the keyword dict expected by `generate_image_wrapper`.

    :param prefix: component-name suffix for the tab (e.g. "txt2img").
    :param task_type: generation task identifier for the backend.
    """
    # Inputs common to every task. Entries fetched with .get() may be
    # None for tabs lacking the component; they are filtered out before
    # the flat positional list is built (and skipped again on unpack).
    run_inputs_map = {
        'model_display_name': ui_components[f'base_model_{prefix}'],
        'positive_prompt': ui_components[f'prompt_{prefix}'],
        'negative_prompt': ui_components[f'neg_prompt_{prefix}'],
        'seed': ui_components[f'seed_{prefix}'],
        'batch_size': ui_components[f'batch_size_{prefix}'],
        'guidance_scale': ui_components[f'cfg_{prefix}'],
        'num_inference_steps': ui_components[f'steps_{prefix}'],
        'sampler': ui_components[f'sampler_{prefix}'],
        'scheduler': ui_components[f'scheduler_{prefix}'],
        'zero_gpu_duration': ui_components[f'zero_gpu_{prefix}'],
        'civitai_api_key': ui_components.get(f'civitai_api_key_{prefix}'),
        'clip_skip': ui_components[f'clip_skip_{prefix}'],
        # Constant per tab; wrapped in gr.State so it rides along as an input.
        'task_type': gr.State(task_type)
    }

    # img2img/inpaint tabs have no explicit width/height inputs
    # (their size comes from elsewhere), so only add them otherwise.
    if task_type not in ['img2img', 'inpaint']:
        run_inputs_map.update({'width': ui_components[f'width_{prefix}'], 'height': ui_components[f'height_{prefix}']})

    # Extra inputs that only exist for particular task types; values are
    # component *names* resolved through ui_components below.
    task_specific_map = {
        'img2img': {'img2img_image': f'input_image_{prefix}', 'img2img_denoise': f'denoise_{prefix}'},
        'inpaint': {'inpaint_image_dict': f'input_image_dict_{prefix}'},
        'outpaint': {'outpaint_image': f'input_image_{prefix}', 'outpaint_left': f'outpaint_left_{prefix}', 'outpaint_top': f'outpaint_top_{prefix}', 'outpaint_right': f'outpaint_right_{prefix}', 'outpaint_bottom': f'outpaint_bottom_{prefix}'},
        'hires_fix': {'hires_image': f'input_image_{prefix}', 'hires_upscaler': f'hires_upscaler_{prefix}', 'hires_scale_by': f'hires_scale_by_{prefix}', 'hires_denoise': f'denoise_{prefix}'}
    }
    if task_type in task_specific_map:
        for key, comp_name in task_specific_map[task_type].items():
            run_inputs_map[key] = ui_components[comp_name]

    # Variable-length component groups (LoRA rows, embeddings, extra
    # conditionings, reference latents). They are appended flat after
    # the named inputs and re-sliced by length in the closure below.
    lora_data_components = ui_components.get(f'all_lora_components_flat_{prefix}', [])
    embedding_data_components = ui_components.get(f'all_embedding_components_flat_{prefix}', [])
    conditioning_data_components = ui_components.get(f'all_conditioning_components_flat_{prefix}', [])
    reference_latent_components = ui_components.get(f'all_reference_latent_components_flat_{prefix}', [])

    run_inputs_map['vae_source'] = ui_components.get(f'vae_source_{prefix}')
    run_inputs_map['vae_id'] = ui_components.get(f'vae_id_{prefix}')
    run_inputs_map['vae_file'] = ui_components.get(f'vae_file_{prefix}')

    # ORDERING CONTRACT: input_list_flat must match the positional order
    # create_ui_inputs_dict unpacks — both are derived from run_inputs_map
    # (dicts preserve insertion order), so do not reorder independently.
    input_keys = list(run_inputs_map.keys())
    input_list_flat = [v for v in run_inputs_map.values() if v is not None]
    input_list_flat += lora_data_components + embedding_data_components + conditioning_data_components + reference_latent_components

    def create_ui_inputs_dict(*args):
        """Re-assemble Gradio's positional args into the kwargs dict."""
        # None-valued entries were dropped from the flat list, so skip
        # their keys here as well to keep positions aligned.
        valid_keys = [k for k in input_keys if run_inputs_map[k] is not None]
        ui_dict = dict(zip(valid_keys, args[:len(valid_keys)]))
        arg_idx = len(valid_keys)

        # Slice the trailing flat groups back out by their known lengths.
        ui_dict['lora_data'] = list(args[arg_idx : arg_idx + len(lora_data_components)])
        arg_idx += len(lora_data_components)
        ui_dict['embedding_data'] = list(args[arg_idx : arg_idx + len(embedding_data_components)])
        arg_idx += len(embedding_data_components)
        ui_dict['conditioning_data'] = list(args[arg_idx : arg_idx + len(conditioning_data_components)])
        arg_idx += len(conditioning_data_components)
        ui_dict['reference_latent_data'] = list(args[arg_idx : arg_idx + len(reference_latent_components)])

        return ui_dict

    ui_components[f'run_{prefix}'].click(
        fn=lambda *args, progress=gr.Progress(track_tqdm=True): generate_image_wrapper(create_ui_inputs_dict(*args), progress),
        inputs=input_list_flat,
        outputs=[ui_components[f'result_{prefix}']]
    )
+
375
+
376
# Wire up every generation tab: model-change defaults, LoRA / embedding /
# conditioning / reference-latent row handlers, VAE override uploads, and
# finally the Run button itself.
for prefix, task_type in [
    ("txt2img", "txt2img"), ("img2img", "img2img"), ("inpaint", "inpaint"),
    ("outpaint", "outpaint"), ("hires_fix", "hires_fix"),
]:
    # Switching the base model refreshes the recommended steps/CFG values.
    model_dropdown = ui_components.get(f'base_model_{prefix}')
    steps_slider = ui_components.get(f'steps_{prefix}')
    cfg_slider = ui_components.get(f'cfg_{prefix}')
    if all([model_dropdown, steps_slider, cfg_slider]):
        model_dropdown.change(
            fn=on_model_change,
            inputs=[model_dropdown],
            outputs=[steps_slider, cfg_slider],
            show_progress=False
        )

    # LoRA rows: add/delete handlers plus a file-upload handler per slot.
    if f'add_lora_button_{prefix}' in ui_components:
        create_lora_event_handlers(prefix)
        lora_uploads = ui_components[f'lora_uploads_{prefix}']
        lora_ids = ui_components[f'lora_ids_{prefix}']
        lora_sources = ui_components[f'lora_sources_{prefix}']
        for i in range(MAX_LORAS):
            lora_uploads[i].upload(
                fn=on_lora_upload,
                inputs=[lora_uploads[i]],
                outputs=[lora_ids[i], lora_sources[i]],
                show_progress=False
            )

    # Embedding rows (upload widgets may be absent, hence the inner check).
    if f'add_embedding_button_{prefix}' in ui_components:
        create_embedding_event_handlers(prefix)
        if f'embeddings_uploads_{prefix}' in ui_components:
            emb_uploads = ui_components[f'embeddings_uploads_{prefix}']
            emb_ids = ui_components[f'embeddings_ids_{prefix}']
            emb_sources = ui_components[f'embeddings_sources_{prefix}']
            emb_files = ui_components[f'embeddings_files_{prefix}']
            for i in range(MAX_EMBEDDINGS):
                emb_uploads[i].upload(
                    fn=on_embedding_upload,
                    inputs=[emb_uploads[i]],
                    outputs=[emb_ids[i], emb_sources[i], emb_files[i]],
                    show_progress=False
                )
    if f'add_conditioning_button_{prefix}' in ui_components: create_conditioning_event_handlers(prefix)
    if f'add_reference_latent_button_{prefix}' in ui_components: create_reference_latent_event_handlers(prefix)
    # Optional VAE override: uploading a file fills the id field and
    # flips the source selector to "File".
    if f'vae_source_{prefix}' in ui_components:
        upload_button = ui_components.get(f'vae_upload_button_{prefix}')
        if upload_button:
            upload_button.upload(
                fn=on_vae_upload,
                inputs=[upload_button],
                outputs=[
                    ui_components[f'vae_id_{prefix}'],
                    ui_components[f'vae_source_{prefix}'],
                    ui_components[f'vae_file_{prefix}']
                ]
            )

    # Must run last so all flat component groups above are registered
    # before the Run button's input list is assembled.
    create_run_event(prefix, task_type)
+
435
def on_aspect_ratio_change(ratio_key, model_display_name):
    """Map an aspect-ratio preset to a (width, height) pair.

    Looks up the resolution table for the selected model's family,
    falling back to the SDXL table, and finally to a 1024x1024 square
    when the preset is unknown.
    """
    family = MODEL_TYPE_MAP.get(model_display_name, 'sdxl').lower()
    fallback_table = RESOLUTION_MAP.get("sdxl", {})
    table = RESOLUTION_MAP.get(family, fallback_table)
    width, height = table.get(ratio_key, (1024, 1024))
    return width, height
+
441
# Aspect-ratio presets drive the width/height fields, resolved per the
# currently selected model's family.
for prefix in ["txt2img", "img2img", "inpaint", "outpaint", "hires_fix"]:
    # Only tabs that actually expose an aspect-ratio dropdown get wired.
    if f'aspect_ratio_{prefix}' in ui_components:
        aspect_ratio_dropdown = ui_components[f'aspect_ratio_{prefix}']
        width_component = ui_components[f'width_{prefix}']
        height_component = ui_components[f'height_{prefix}']
        model_dropdown = ui_components[f'base_model_{prefix}']
        aspect_ratio_dropdown.change(fn=on_aspect_ratio_change, inputs=[aspect_ratio_dropdown, model_dropdown], outputs=[width_component, height_component], show_progress=False)
+
449
# Inpaint tab: "Fullscreen View" hides everything except the mask editor
# and enlarges it so masks can be drawn precisely.
if 'view_mode_inpaint' in ui_components:
    def toggle_inpaint_fullscreen_view(view_mode):
        """Show/hide the surrounding UI and resize the mask editor."""
        is_fullscreen = (view_mode == "Fullscreen View")
        other_elements_visible = not is_fullscreen
        # 800px editor in fullscreen; 272px is the layout's normal height.
        editor_height = 800 if is_fullscreen else 272
        return {
            ui_components['model_and_run_row_inpaint']: gr.update(visible=other_elements_visible),
            ui_components['prompts_column_inpaint']: gr.update(visible=other_elements_visible),
            ui_components['params_and_gallery_row_inpaint']: gr.update(visible=other_elements_visible),
            ui_components['accordion_wrapper_inpaint']: gr.update(visible=other_elements_visible),
            ui_components['input_image_dict_inpaint']: gr.update(height=editor_height),
        }

    # Must list every component the handler's dict may update.
    output_components = [
        ui_components['model_and_run_row_inpaint'], ui_components['prompts_column_inpaint'],
        ui_components['params_and_gallery_row_inpaint'], ui_components['accordion_wrapper_inpaint'],
        ui_components['input_image_dict_inpaint']
    ]
    ui_components['view_mode_inpaint'].change(fn=toggle_inpaint_fullscreen_view, inputs=[ui_components['view_mode_inpaint']], outputs=output_components, show_progress=False)
+
469
def run_on_load():
    """Initialize the ControlNet tab to the default preprocessor's state.

    Mirrors what selecting "Canny Edge" manually would trigger, so the
    tab is consistent on first render.
    """
    all_updates = {}

    default_preprocessor = "Canny Edge"
    model_update = update_preprocessor_models_dropdown(default_preprocessor)
    all_updates[ui_components["preprocessor_model_cn"]] = model_update

    # The dynamic settings widgets are pre-allocated pools; map the
    # returned updates onto them positionally — sliders first, then
    # dropdowns, then checkboxes (must match the order the helper emits).
    settings_outputs = update_preprocessor_settings_ui(default_preprocessor)
    dynamic_outputs = ui_components["cn_sliders"] + ui_components["cn_dropdowns"] + ui_components["cn_checkboxes"]
    for i, comp in enumerate(dynamic_outputs):
        all_updates[comp] = settings_outputs[i]

    # Presumably adjusts the Run button / ZeroGPU field for CPU-only
    # preprocessors — confirm against update_run_button_for_cpu.
    run_button_update, zero_gpu_update = update_run_button_for_cpu(default_preprocessor)
    all_updates[ui_components["run_cn"]] = run_button_update
    all_updates[ui_components["zero_gpu_cn"]] = zero_gpu_update

    return all_updates

# Output list must cover every component run_on_load may update.
all_load_outputs = [
    ui_components["preprocessor_model_cn"],
    *ui_components["cn_sliders"],
    *ui_components["cn_dropdowns"],
    *ui_components["cn_checkboxes"],
    ui_components["run_cn"],
    ui_components["zero_gpu_cn"]
]

if all_load_outputs:
    demo.load(
        fn=run_on_load,
        outputs=all_load_outputs
    )
ui/layout.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from core.settings import *
4
+
5
+ from .shared import txt2img_ui, img2img_ui, inpaint_ui, outpaint_ui, hires_fix_ui
6
+
7
+ MAX_DYNAMIC_CONTROLS = 10
8
+
9
def get_preprocessor_choices():
    """Collect display names of ControlNet preprocessor-style nodes.

    Scans the ComfyUI node registry for classes whose name marks them as
    a preprocessor, segmentor, estimator, or detector and returns their
    display names, deduplicated and sorted.
    """
    # Imported lazily: the registry only exists once ComfyUI is set up.
    from nodes import NODE_DISPLAY_NAME_MAPPINGS

    markers = ("Preprocessor", "Segmentor", "Estimator", "Detector")
    unique_names = {
        display
        for cls_name, display in NODE_DISPLAY_NAME_MAPPINGS.items()
        if any(tag in cls_name for tag in markers)
    }
    return sorted(unique_names)
18
+
19
+
20
def build_ui(event_handler_function):
    """Assemble the full Gradio Blocks app and wire its events.

    :param event_handler_function: called as
        ``event_handler_function(ui_components, demo)`` inside the Blocks
        context, after all components exist, to attach event handlers.
    :returns: the built (not yet launched) ``gr.Blocks`` demo.
    """
    # Registry of every component, keyed "<name>_<tab prefix>"; the
    # per-tab create_ui() helpers merge their components into it.
    ui_components = {}

    with gr.Blocks() as demo:
        gr.Markdown("# ImageGen - FLUX.2")
        # NOTE(review): "Illstrious" below looks like a typo for
        # "Illustrious", but it matches the linked Space URL's spelling —
        # confirm the Space name before changing either occurrence.
        gr.Markdown(
            "This demo is a streamlined version of the [Comfy web UI](https://github.com/RioShiina47/comfy-webui)'s ImgGen functionality. "
            "Other versions are also available: "
            "[Z-Image](https://huggingface.co/spaces/RioShiina/ImageGen-Z-Image), "
            "[Qwen-Image](https://huggingface.co/spaces/RioShiina/ImageGen-Qwen-Image), "
            "[NewBie-Image](https://huggingface.co/spaces/RioShiina/ImageGen-NewBie-Image), "
            "[Illstrious](https://huggingface.co/spaces/RioShiina/ImageGen-Illstrious), "
            "[NoobAI](https://huggingface.co/spaces/RioShiina/ImageGen-NoobAI), "
            "[Pony](https://huggingface.co/spaces/RioShiina/ImageGen-Pony1), "
            "[SDXL](https://huggingface.co/spaces/RioShiina/ImageGen-SDXL), "
            "[SD1.5](https://huggingface.co/spaces/RioShiina/ImageGen-SD15)"
        )
        with gr.Tabs(elem_id="tabs_container") as tabs:
            # Tab 1: the image-generation workflows, one sub-tab per task.
            with gr.TabItem("FLUX.2", id=0):
                with gr.Tabs(elem_id="image_gen_tabs") as image_gen_tabs:
                    with gr.TabItem("Txt2Img", id=0):
                        ui_components.update(txt2img_ui.create_ui())

                    with gr.TabItem("Img2Img", id=1):
                        ui_components.update(img2img_ui.create_ui())

                    with gr.TabItem("Inpaint", id=2):
                        ui_components.update(inpaint_ui.create_ui())

                    with gr.TabItem("Outpaint", id=3):
                        ui_components.update(outpaint_ui.create_ui())

                    with gr.TabItem("Hires. Fix", id=4):
                        ui_components.update(hires_fix_ui.create_ui())

                ui_components['image_gen_tabs'] = image_gen_tabs

            # Tab 2: standalone ControlNet preprocessor playground.
            with gr.TabItem("Controlnet Preprocessors", id=1):
                gr.Markdown("## ControlNet Auxiliary Preprocessors")
                gr.Markdown("Powered by [Fannovel16/comfyui_controlnet_aux](https://github.com/Fannovel16/comfyui_controlnet_aux).")
                gr.Markdown("Upload an image or video to process it with a ControlNet preprocessor.")
                with gr.Row():
                    with gr.Column(scale=1):
                        cn_input_type = gr.Radio(["Image", "Video"], label="Input Type", value="Image")
                        cn_image_input = gr.Image(type="pil", label="Input Image", visible=True, height=384)
                        cn_video_input = gr.Video(label="Input Video", visible=False)
                        preprocessor_cn = gr.Dropdown(label="Preprocessor", choices=get_preprocessor_choices(), value="Canny Edge")
                        preprocessor_model_cn = gr.Dropdown(label="Preprocessor Model", choices=[], value=None, visible=False)
                        with gr.Column() as preprocessor_settings_ui:
                            # Pre-allocate a fixed pool of hidden widgets;
                            # event handlers show/relabel them per preprocessor.
                            cn_sliders, cn_dropdowns, cn_checkboxes = [], [], []
                            for i in range(MAX_DYNAMIC_CONTROLS):
                                cn_sliders.append(gr.Slider(visible=False, label=f"dyn_slider_{i}"))
                                cn_dropdowns.append(gr.Dropdown(visible=False, label=f"dyn_dropdown_{i}"))
                                cn_checkboxes.append(gr.Checkbox(visible=False, label=f"dyn_checkbox_{i}"))
                        run_cn = gr.Button("Run Preprocessor", variant="primary")
                    with gr.Column(scale=1):
                        output_gallery_cn = gr.Gallery(label="Output", show_label=False, object_fit="contain", height=512)
                        zero_gpu_cn = gr.Number(label="ZeroGPU Duration (s)", value=None, placeholder="Default: 60, Max: 120", info="Optional")
                ui_components.update({
                    "cn_input_type": cn_input_type, "cn_image_input": cn_image_input, "cn_video_input": cn_video_input,
                    "preprocessor_cn": preprocessor_cn, "preprocessor_model_cn": preprocessor_model_cn, "run_cn": run_cn,
                    "zero_gpu_cn": zero_gpu_cn, "output_gallery_cn": output_gallery_cn,
                    "preprocessor_settings_ui": preprocessor_settings_ui, "cn_sliders": cn_sliders,
                    "cn_dropdowns": cn_dropdowns, "cn_checkboxes": cn_checkboxes
                })

        ui_components["tabs"] = tabs

        gr.Markdown("<div style='text-align: center; margin-top: 20px;'>Made by RioShiina with ❤️<br><a href='https://github.com/RioShiina47' target='_blank'>GitHub</a> | <a href='https://huggingface.co/RioShiina' target='_blank'>Hugging Face</a> | <a href='https://civitai.com/user/RioShiina' target='_blank'>Civitai</a></div>")

        # Event handlers must be attached inside the Blocks context.
        event_handler_function(ui_components, demo)

    return demo
ui/shared/hires_fix_ui.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from core.settings import MODEL_MAP_CHECKPOINT
3
+ from comfy_integration.nodes import SAMPLER_CHOICES, SCHEDULER_CHOICES
4
+ from .ui_components import (
5
+ create_lora_settings_ui,
6
+ create_embedding_ui,
7
+ create_conditioning_ui, create_vae_override_ui, create_api_key_ui,
8
+ create_reference_latent_ui
9
+ )
10
+
11
def create_ui():
    """Create the UI components for the Hires. Fix tab.

    Returns a dict of Gradio components keyed ``<name>_hires_fix``,
    merged into the app-wide ui_components registry by the caller.
    """
    prefix = "hires_fix"
    components = {}

    with gr.Column():
        with gr.Row():
            components[f'base_model_{prefix}'] = gr.Dropdown(
                label="Base Model",
                choices=list(MODEL_MAP_CHECKPOINT.keys()),
                value=list(MODEL_MAP_CHECKPOINT.keys())[0],
                scale=3
            )
            with gr.Column(scale=1):
                components[f'run_{prefix}'] = gr.Button("Run Hires. Fix", variant="primary")

        with gr.Row():
            with gr.Column(scale=1):
                components[f'input_image_{prefix}'] = gr.Image(type="pil", label="Input Image", height=255)
            with gr.Column(scale=2):
                components[f'prompt_{prefix}'] = gr.Text(label="Prompt", lines=3, placeholder="Describe the final image...")
                components[f'neg_prompt_{prefix}'] = gr.Text(label="Negative prompt", lines=3, value="")

        with gr.Row():
            with gr.Column(scale=1):
                with gr.Row():
                    components[f'hires_upscaler_{prefix}'] = gr.Dropdown(
                        label="Upscaler",
                        choices=["nearest-exact", "bilinear", "area", "bicubic", "bislerp"],
                        value="nearest-exact"
                    )
                    components[f'hires_scale_by_{prefix}'] = gr.Slider(
                        label="Upscale by", minimum=1.0, maximum=4.0, step=0.1, value=1.5
                    )

                with gr.Row():
                    components[f'denoise_{prefix}'] = gr.Slider(label="Denoise Strength", minimum=0.0, maximum=1.0, step=0.01, value=0.55)

                with gr.Row():
                    components[f'sampler_{prefix}'] = gr.Dropdown(label="Sampler", choices=SAMPLER_CHOICES, value="euler")
                    components[f'scheduler_{prefix}'] = gr.Dropdown(label="Scheduler", choices=SCHEDULER_CHOICES, value="simple")
                with gr.Row():
                    components[f'steps_{prefix}'] = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=4)
                    components[f'cfg_{prefix}'] = gr.Slider(label="CFG Scale", minimum=1.0, maximum=20.0, step=0.1, value=1.0)
                with gr.Row():
                    components[f'seed_{prefix}'] = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
                    components[f'batch_size_{prefix}'] = gr.Slider(label="Batch Size", minimum=1, maximum=16, step=1, value=1)
                with gr.Row():
                    components[f'zero_gpu_{prefix}'] = gr.Number(label="ZeroGPU Duration (s)", value=None, placeholder="Default: 60, Max: 120", info="Optional: Set how long to reserve the GPU.")

                # Hidden fixed inputs the shared run handler still expects.
                components[f'clip_skip_{prefix}'] = gr.State(value=1)
                components[f'width_{prefix}'] = gr.State(value=512)
                components[f'height_{prefix}'] = gr.State(value=512)

            with gr.Column(scale=1):
                components[f'result_{prefix}'] = gr.Gallery(label="Result", show_label=False, columns=1, object_fit="contain", height=610)

    # Shared accordions (API key, LoRA, conditioning, ...); commented
    # entries are features disabled for this model family.
    components.update(create_api_key_ui(prefix))
    components.update(create_lora_settings_ui(prefix))
    # components.update(create_diffsynth_controlnet_ui(prefix))
    # components.update(create_controlnet_ui(prefix))
    # components.update(create_embedding_ui(prefix))
    components.update(create_reference_latent_ui(prefix))
    components.update(create_conditioning_ui(prefix))
    # components.update(create_vae_override_ui(prefix))

    return components
ui/shared/img2img_ui.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from core.settings import MODEL_MAP_CHECKPOINT
3
+ from comfy_integration.nodes import SAMPLER_CHOICES, SCHEDULER_CHOICES
4
+ from .ui_components import (
5
+ create_lora_settings_ui,
6
+ create_embedding_ui,
7
+ create_conditioning_ui, create_vae_override_ui, create_api_key_ui,
8
+ create_reference_latent_ui
9
+ )
10
+
11
def create_ui():
    """Create the UI components for the Img2Img tab.

    Returns a dict of Gradio components keyed ``<name>_img2img``,
    merged into the app-wide ui_components registry by the caller.
    """
    prefix = "img2img"
    components = {}

    with gr.Column():
        with gr.Row():
            components[f'base_model_{prefix}'] = gr.Dropdown(label="Base Model", choices=list(MODEL_MAP_CHECKPOINT.keys()), value=list(MODEL_MAP_CHECKPOINT.keys())[0], scale=3)
            with gr.Column(scale=1):
                components[f'run_{prefix}'] = gr.Button("Run", variant="primary")

        with gr.Row():
            with gr.Column(scale=1):
                components[f'input_image_{prefix}'] = gr.Image(type="pil", label="Input Image", height=255)

            with gr.Column(scale=2):
                components[f'prompt_{prefix}'] = gr.Text(label="Prompt", lines=3, placeholder="Enter your prompt")
                components[f'neg_prompt_{prefix}'] = gr.Text(label="Negative prompt", lines=3, value="")

        with gr.Row():
            with gr.Column(scale=1):
                components[f'denoise_{prefix}'] = gr.Slider(label="Denoise Strength", minimum=0.0, maximum=1.0, step=0.01, value=0.7)

                with gr.Row():
                    components[f'sampler_{prefix}'] = gr.Dropdown(label="Sampler", choices=SAMPLER_CHOICES, value="euler")
                    components[f'scheduler_{prefix}'] = gr.Dropdown(label="Scheduler", choices=SCHEDULER_CHOICES, value="simple")
                with gr.Row():
                    components[f'steps_{prefix}'] = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=4)
                    components[f'cfg_{prefix}'] = gr.Slider(label="CFG Scale", minimum=1.0, maximum=20.0, step=0.1, value=1.0)
                with gr.Row():
                    components[f'seed_{prefix}'] = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
                    components[f'batch_size_{prefix}'] = gr.Slider(label="Batch Size", minimum=1, maximum=16, step=1, value=1)
                with gr.Row():
                    components[f'zero_gpu_{prefix}'] = gr.Number(label="ZeroGPU Duration (s)", value=None, placeholder="Default: 60, Max: 120", info="Optional: Set how long to reserve the GPU. Longer jobs may need more time.")

                # Hidden fixed input the shared run handler still expects.
                # No width/height here: img2img sizes from the input image.
                components[f'clip_skip_{prefix}'] = gr.State(value=1)

            with gr.Column(scale=1):
                components[f'result_{prefix}'] = gr.Gallery(label="Result", show_label=False, columns=1, object_fit="contain", height=505)

    # Shared accordions; commented entries are disabled for this family.
    components.update(create_api_key_ui(prefix))
    components.update(create_lora_settings_ui(prefix))
    # components.update(create_diffsynth_controlnet_ui(prefix))
    # components.update(create_controlnet_ui(prefix))
    # components.update(create_embedding_ui(prefix))
    components.update(create_reference_latent_ui(prefix))
    components.update(create_conditioning_ui(prefix))
    # components.update(create_vae_override_ui(prefix))

    return components
ui/shared/inpaint_ui.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from core.settings import MODEL_MAP_CHECKPOINT
3
+ from .ui_components import (
4
+ create_base_parameter_ui, create_lora_settings_ui,
5
+ create_embedding_ui,
6
+ create_conditioning_ui, create_vae_override_ui, create_api_key_ui,
7
+ create_reference_latent_ui
8
+ )
9
+
10
def create_ui():
    """Create the UI components for the Inpaint tab.

    Returns a dict of Gradio components keyed ``<name>_inpaint``. Named
    layout rows/columns (model row, prompts column, params row, accordion
    wrapper) are exported too so the fullscreen view-mode handler can
    hide them while the mask editor is enlarged.
    """
    prefix = "inpaint"
    components = {}

    with gr.Column():
        with gr.Row() as model_and_run_row:
            components[f'base_model_{prefix}'] = gr.Dropdown(
                label="Base Model",
                choices=list(MODEL_MAP_CHECKPOINT.keys()),
                value=list(MODEL_MAP_CHECKPOINT.keys())[0],
                scale=3
            )
            with gr.Column(scale=1):
                components[f'run_{prefix}'] = gr.Button("Run Inpaint", variant="primary")

        # Exported so the fullscreen toggle can hide this row.
        components[f'model_and_run_row_{prefix}'] = model_and_run_row

        with gr.Row():
            with gr.Column(scale=1) as editor_column:
                components[f'view_mode_{prefix}'] = gr.Radio(
                    ["Normal View", "Fullscreen View"],
                    label="Editor View",
                    value="Normal View",
                    interactive=True
                )
                components[f'input_image_dict_{prefix}'] = gr.ImageEditor(
                    type="pil",
                    label="Image & Mask",
                    height=272
                )
            components[f'editor_column_{prefix}'] = editor_column

            with gr.Column(scale=2) as prompts_column:
                components[f'prompt_{prefix}'] = gr.Text(label="Prompt", lines=6, placeholder="Describe what to fill in the mask...")
                components[f'neg_prompt_{prefix}'] = gr.Text(label="Negative prompt", lines=6, value="")
            components[f'prompts_column_{prefix}'] = prompts_column

        with gr.Row() as params_and_gallery_row:
            with gr.Column(scale=1):
                # NOTE(review): imported lazily here rather than at module
                # level — presumably to defer loading the ComfyUI node
                # registry; confirm before hoisting.
                from comfy_integration.nodes import SAMPLER_CHOICES, SCHEDULER_CHOICES
                with gr.Row():
                    components[f'sampler_{prefix}'] = gr.Dropdown(label="Sampler", choices=SAMPLER_CHOICES, value="euler")
                    components[f'scheduler_{prefix}'] = gr.Dropdown(label="Scheduler", choices=SCHEDULER_CHOICES, value="simple")
                with gr.Row():
                    components[f'steps_{prefix}'] = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=4)
                    components[f'cfg_{prefix}'] = gr.Slider(label="CFG Scale", minimum=1.0, maximum=20.0, step=0.1, value=1.0)
                with gr.Row():
                    components[f'seed_{prefix}'] = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
                    components[f'batch_size_{prefix}'] = gr.Slider(label="Batch Size", minimum=1, maximum=16, step=1, value=1)
                with gr.Row():
                    components[f'zero_gpu_{prefix}'] = gr.Number(label="ZeroGPU Duration (s)", value=None, placeholder="Default: 60, Max: 120", info="Optional: Set how long to reserve the GPU.")

                # Hidden fixed inputs the shared run handler still expects.
                components[f'clip_skip_{prefix}'] = gr.State(value=1)
                components[f'width_{prefix}'] = gr.State(value=512)
                components[f'height_{prefix}'] = gr.State(value=512)

            with gr.Column(scale=1):
                components[f'result_{prefix}'] = gr.Gallery(label="Result", show_label=False, columns=1, object_fit="contain", height=414)

        components[f'params_and_gallery_row_{prefix}'] = params_and_gallery_row

        # Shared accordions wrapped in one column so the fullscreen
        # toggle can hide them all at once; commented entries are
        # features disabled for this model family.
        with gr.Column() as accordion_wrapper:
            components.update(create_api_key_ui(prefix))
            components.update(create_lora_settings_ui(prefix))
            # components.update(create_diffsynth_controlnet_ui(prefix))
            # components.update(create_controlnet_ui(prefix))
            # components.update(create_embedding_ui(prefix))
            components.update(create_reference_latent_ui(prefix))
            components.update(create_conditioning_ui(prefix))
            # components.update(create_vae_override_ui(prefix))
        components[f'accordion_wrapper_{prefix}'] = accordion_wrapper

    return components
ui/shared/outpaint_ui.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from core.settings import MODEL_MAP_CHECKPOINT
3
+ from comfy_integration.nodes import SAMPLER_CHOICES, SCHEDULER_CHOICES
4
+ from .ui_components import (
5
+ create_lora_settings_ui,
6
+ create_embedding_ui,
7
+ create_conditioning_ui, create_vae_override_ui, create_api_key_ui,
8
+ create_reference_latent_ui
9
+ )
10
+
11
def create_ui():
    """Create the UI components for the Outpaint tab.

    Returns a dict of Gradio components keyed ``<name>_outpaint``,
    merged into the app-wide ui_components registry by the caller.
    """
    prefix = "outpaint"
    components = {}

    with gr.Column():
        with gr.Row():
            components[f'base_model_{prefix}'] = gr.Dropdown(
                label="Base Model",
                choices=list(MODEL_MAP_CHECKPOINT.keys()),
                value=list(MODEL_MAP_CHECKPOINT.keys())[0],
                scale=3
            )
            with gr.Column(scale=1):
                components[f'run_{prefix}'] = gr.Button("Run Outpaint", variant="primary")

        with gr.Row():
            with gr.Column(scale=1):
                components[f'input_image_{prefix}'] = gr.Image(type="pil", label="Input Image", height=255)
            with gr.Column(scale=2):
                components[f'prompt_{prefix}'] = gr.Text(label="Prompt", lines=3, placeholder="Describe the content for the expanded areas...")
                components[f'neg_prompt_{prefix}'] = gr.Text(label="Negative prompt", lines=3, value="")

        with gr.Row():
            with gr.Column(scale=1):
                # Padding per side in pixels; 64-px steps (latent-grid
                # aligned). Default expands 256px to the right only.
                with gr.Row():
                    components[f'outpaint_left_{prefix}'] = gr.Slider(label="Pad Left", minimum=0, maximum=512, step=64, value=0)
                    components[f'outpaint_right_{prefix}'] = gr.Slider(label="Pad Right", minimum=0, maximum=512, step=64, value=256)
                with gr.Row():
                    components[f'outpaint_top_{prefix}'] = gr.Slider(label="Pad Top", minimum=0, maximum=512, step=64, value=0)
                    components[f'outpaint_bottom_{prefix}'] = gr.Slider(label="Pad Bottom", minimum=0, maximum=512, step=64, value=0)

                with gr.Row():
                    components[f'sampler_{prefix}'] = gr.Dropdown(label="Sampler", choices=SAMPLER_CHOICES, value="euler")
                    components[f'scheduler_{prefix}'] = gr.Dropdown(label="Scheduler", choices=SCHEDULER_CHOICES, value="simple")
                with gr.Row():
                    components[f'steps_{prefix}'] = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=4)
                    components[f'cfg_{prefix}'] = gr.Slider(label="CFG Scale", minimum=1.0, maximum=20.0, step=0.1, value=1.0)
                with gr.Row():
                    components[f'seed_{prefix}'] = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
                    components[f'batch_size_{prefix}'] = gr.Slider(label="Batch Size", minimum=1, maximum=16, step=1, value=1)
                with gr.Row():
                    components[f'zero_gpu_{prefix}'] = gr.Number(label="ZeroGPU Duration (s)", value=None, placeholder="Default: 60, Max: 120", info="Optional: Set how long to reserve the GPU.")

                # Hidden fixed inputs the shared run handler still expects.
                components[f'clip_skip_{prefix}'] = gr.State(value=1)
                components[f'width_{prefix}'] = gr.State(value=512)
                components[f'height_{prefix}'] = gr.State(value=512)

            with gr.Column(scale=1):
                components[f'result_{prefix}'] = gr.Gallery(label="Result", show_label=False, columns=1, object_fit="contain", height=595)

    # Shared accordions; commented entries are disabled for this family.
    components.update(create_api_key_ui(prefix))
    components.update(create_lora_settings_ui(prefix))
    # components.update(create_diffsynth_controlnet_ui(prefix))
    # components.update(create_controlnet_ui(prefix))
    # components.update(create_embedding_ui(prefix))
    components.update(create_reference_latent_ui(prefix))
    components.update(create_conditioning_ui(prefix))
    # components.update(create_vae_override_ui(prefix))

    return components
ui/shared/txt2img_ui.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from core.settings import MODEL_MAP_CHECKPOINT
3
+ from .ui_components import (
4
+ create_base_parameter_ui, create_lora_settings_ui,
5
+ create_embedding_ui,
6
+ create_conditioning_ui, create_vae_override_ui, create_api_key_ui,
7
+ create_reference_latent_ui
8
+ )
9
+
10
def create_ui():
    """Creates the UI components for the Txt2Img tab.

    Returns a dict of Gradio components keyed ``<name>_txt2img``,
    merged into the app-wide ui_components registry by the caller.
    """
    prefix = "txt2img"
    components = {}

    with gr.Column():
        with gr.Row():
            components[f'base_model_{prefix}'] = gr.Dropdown(label="Base Model", choices=list(MODEL_MAP_CHECKPOINT.keys()), value=list(MODEL_MAP_CHECKPOINT.keys())[0], scale=3)
            with gr.Column(scale=1):
                components[f'run_{prefix}'] = gr.Button("Run", variant="primary")

        components[f'prompt_{prefix}'] = gr.Text(label="Prompt", lines=3, placeholder="Enter your prompt")
        components[f'neg_prompt_{prefix}'] = gr.Text(label="Negative prompt", lines=3, value="")

        with gr.Row():
            with gr.Column(scale=1):
                # Shared parameter stack (aspect ratio, size, sampler, seed, ...).
                param_defaults = {'w': 1024, 'h': 1024, 'cs_vis': False, 'cs_val': 1}
                components.update(create_base_parameter_ui(prefix, param_defaults))
            with gr.Column(scale=1):
                components[f'result_{prefix}'] = gr.Gallery(label="Result", show_label=False, columns=2, object_fit="contain", height=627)

    # Shared accordions; commented entries are disabled for this family.
    components.update(create_api_key_ui(prefix))
    components.update(create_lora_settings_ui(prefix))
    # components.update(create_diffsynth_controlnet_ui(prefix))
    # components.update(create_controlnet_ui(prefix))
    # components.update(create_embedding_ui(prefix))
    components.update(create_reference_latent_ui(prefix))
    components.update(create_conditioning_ui(prefix))
    # components.update(create_vae_override_ui(prefix))

    return components
ui/shared/ui_components.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from comfy_integration.nodes import SAMPLER_CHOICES, SCHEDULER_CHOICES
3
+ from core.settings import (
4
+ MAX_LORAS, LORA_SOURCE_CHOICES, MAX_EMBEDDINGS, MAX_CONDITIONINGS,
5
+ MAX_CONTROLNETS, RESOLUTION_MAP, MAX_REFERENCE_LATENTS
6
+ )
7
+ import yaml
8
+ import os
9
+ from functools import lru_cache
10
+
11
def create_base_parameter_ui(prefix, defaults=None):
    """Builds the generation-parameter widgets shared by all tabs.

    Args:
        prefix: Tab identifier appended to every component key.
        defaults: Optional per-tab defaults. Only 'w' and 'h' are read
            here; other keys (e.g. 'cs_vis', 'cs_val') are accepted from
            callers but ignored by this implementation.

    Returns:
        Dict mapping prefixed names to gradio components.
    """
    if defaults is None:
        defaults = {}

    components = {}

    with gr.Row():
        components[f'aspect_ratio_{prefix}'] = gr.Dropdown(
            label="Aspect Ratio",
            choices=list(RESOLUTION_MAP['sdxl'].keys()),
            value="1:1 (Square)",
            interactive=True
        )
    with gr.Row():
        components[f'width_{prefix}'] = gr.Number(label="Width", value=defaults.get('w', 1024), interactive=True)
        components[f'height_{prefix}'] = gr.Number(label="Height", value=defaults.get('h', 1024), interactive=True)
    with gr.Row():
        components[f'sampler_{prefix}'] = gr.Dropdown(label="Sampler", choices=SAMPLER_CHOICES, value="euler")
        components[f'scheduler_{prefix}'] = gr.Dropdown(label="Scheduler", choices=SCHEDULER_CHOICES, value="simple")
    with gr.Row():
        # Low defaults (4 steps / CFG 1.0) — presumably tuned for distilled
        # few-step models; confirm against the default checkpoint.
        components[f'steps_{prefix}'] = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=4)
        components[f'cfg_{prefix}'] = gr.Slider(label="CFG Scale", minimum=1.0, maximum=20.0, step=0.1, value=1.0)
    with gr.Row():
        components[f'seed_{prefix}'] = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
        components[f'batch_size_{prefix}'] = gr.Slider(label="Batch Size", minimum=1, maximum=16, step=1, value=1)
    with gr.Row():
        components[f'zero_gpu_{prefix}'] = gr.Number(label="ZeroGPU Duration (s)", value=None, placeholder="Default: 60, Max: 120", info="Optional: Set how long to reserve the GPU. Longer jobs may need more time.")

    # Clip-skip is a hidden State (always 1) so downstream event handlers
    # keep a uniform input signature across tabs.
    components[f'clip_skip_{prefix}'] = gr.State(value=1)

    return components
42
+
43
+
44
def create_api_key_ui(prefix: str):
    """Builds the collapsible Civitai API-key section for one tab.

    Returns:
        Dict with the accordion and a password textbox, keyed by
        'api_key_accordion_{prefix}' and 'civitai_api_key_{prefix}'.
    """
    built = {}
    with gr.Accordion("API Key Settings", open=False) as api_key_accordion:
        built[f'api_key_accordion_{prefix}'] = api_key_accordion
        gr.Markdown("💡 **Tip:** Enter API key (optional). An API key is required for resources that need a login to download. The key will be used for all Civitai downloads on this tab. You can also manually upload the corresponding files to avoid API Key leakage caused by potential vulnerabilities.")
        with gr.Row():
            key_box = gr.Textbox(
                label="Civitai API Key",
                type="password",
                placeholder="Enter your Civitai API key here (optional)"
            )
            built[f'civitai_api_key_{prefix}'] = key_box
    return built
56
+
57
+
58
def create_lora_settings_ui(prefix: str):
    """Builds the LoRA accordion: up to MAX_LORAS rows of
    (source dropdown, ID textbox, scale slider, upload button).

    Returns:
        Dict containing the accordion, add/remove buttons, the per-widget
        lists, and a flattened list for event wiring.
    """
    built = {}
    rows, sources, ids, scales, uploads = [], [], [], [], []

    with gr.Accordion("LoRA Settings", open=False) as lora_accordion:
        built[f'lora_accordion_{prefix}'] = lora_accordion
        gr.Markdown("💡 **Tip:** When downloading from Civitai, please use the **Version ID**, not the Model ID. You can find the Version ID in the URL (e.g., `civitai.com/models/123?modelVersionId=456`) or under the model's download button.")
        # Tracks how many rows are currently revealed.
        built[f'lora_count_state_{prefix}'] = gr.State(1)

        for idx in range(MAX_LORAS):
            # Only the first row starts visible; "Add LoRA" reveals the rest.
            with gr.Row(visible=idx == 0) as row:
                rows.append(row)
                sources.append(gr.Dropdown(label=f"LoRA Source {idx+1}", choices=LORA_SOURCE_CHOICES, value=LORA_SOURCE_CHOICES[0], scale=1))
                ids.append(gr.Textbox(label=f"Civitai Version ID / File", placeholder="Civitai Version ID or Filename", scale=2, type="text"))
                scales.append(gr.Slider(label=f"Scale", minimum=-2.0, maximum=2.0, step=0.05, value=0.8, scale=1))
                uploads.append(gr.UploadButton(label="Upload", file_types=[".safetensors"], scale=1))

        with gr.Row():
            built[f'add_lora_button_{prefix}'] = gr.Button("Add LoRA", variant="secondary")
            built[f'delete_lora_button_{prefix}'] = gr.Button("Remove LoRA", variant="secondary", visible=False)

    built[f'lora_rows_{prefix}'] = rows
    built[f'lora_sources_{prefix}'] = sources
    built[f'lora_ids_{prefix}'] = ids
    built[f'lora_scales_{prefix}'] = scales
    built[f'lora_uploads_{prefix}'] = uploads

    # Flatten row-wise (source, id, scale, upload) for gradio event I/O lists.
    flattened = []
    for quad in zip(sources, ids, scales, uploads):
        flattened.extend(quad)
    built[f'all_lora_components_flat_{prefix}'] = flattened

    return built
97
+
98
def create_embedding_ui(prefix: str):
    """Builds the (currently disabled in txt2img) textual-embedding accordion.

    Returns:
        Dict with the accordion, add/delete buttons, per-widget lists, and
        a flattened (source, id, file-state) list for event wiring.
    """
    components = {}
    key = lambda name: f"{name}_{prefix}"

    with gr.Accordion("Embedding Settings", open=False, visible=True) as accordion:
        components[key('embedding_accordion')] = accordion
        gr.Markdown("💡 **Tip:** Embeddings are automatically added to your prompt using `embedding:filename` syntax. When downloading from Civitai, please use the **Version ID**, not the Model ID. You can find the Version ID in the URL (e.g., `civitai.com/models/123?modelVersionId=456`) or under the model's download button. For instance, using the Version ID `456` from the example above would automatically append `embedding:civitai_456` to your positive prompt.")

        embedding_rows, sources, ids, files, upload_buttons = [], [], [], [], []
        components.update({
            key('embedding_rows'): embedding_rows,
            key('embeddings_sources'): sources,
            key('embeddings_ids'): ids,
            key('embeddings_files'): files,
            key('embeddings_uploads'): upload_buttons
        })

        for i in range(MAX_EMBEDDINGS):
            # Only the first row starts visible; "Add Embedding" reveals more.
            with gr.Row(visible=(i < 1)) as row:
                sources.append(gr.Dropdown(label=f"Embedding Source {i+1}", choices=LORA_SOURCE_CHOICES, value="Civitai", scale=1, interactive=True))
                ids.append(gr.Textbox(label="Civitai Version ID / File", placeholder="Civitai Version ID or Filename", scale=3, interactive=True, type="text"))
                upload_btn = gr.UploadButton("Upload", file_types=[".safetensors"], scale=1)
                # Hidden State holding the uploaded file path (set by events).
                files.append(gr.State(None))
                upload_buttons.append(upload_btn)
                embedding_rows.append(row)

        with gr.Row():
            components[key('add_embedding_button')] = gr.Button("✚ Add Embedding")
            components[key('delete_embedding_button')] = gr.Button("➖ Delete Embedding", visible=False)
        components[key('embedding_count_state')] = gr.State(1)

        # Flatten row-wise (source, id, file) for gradio event I/O lists.
        all_embedding_components_flat = []
        for i in range(MAX_EMBEDDINGS):
            all_embedding_components_flat.extend([sources[i], ids[i], files[i]])
        components[key('all_embedding_components_flat')] = all_embedding_components_flat

    return components
135
+
136
def create_conditioning_ui(prefix: str):
    """Builds the regional-conditioning accordion: per-area prompts with
    rectangle coordinates and a strength slider.

    Returns:
        Dict with the accordion, add/delete buttons, per-widget lists, and
        a flattened list (prompts + widths + heights + xs + ys + strengths)
        for event wiring. NOTE: the flat list is grouped by field, not by
        row, unlike the LoRA/embedding sections.
    """
    components = {}
    key = lambda name: f"{name}_{prefix}"

    with gr.Accordion("Conditioning Settings", open=False) as accordion:
        components[key('conditioning_accordion')] = accordion
        gr.Markdown("💡 **Tip:** Define rectangular areas and assign specific prompts to them. Coordinates (X, Y) start from the top-left corner.")

        cond_rows, prompts, widths, heights, xs, ys, strengths = [], [], [], [], [], [], []
        components.update({
            key('conditioning_rows'): cond_rows,
            key('conditioning_prompts'): prompts,
            key('conditioning_widths'): widths,
            key('conditioning_heights'): heights,
            key('conditioning_xs'): xs,
            key('conditioning_ys'): ys,
            key('conditioning_strengths'): strengths
        })

        for i in range(MAX_CONDITIONINGS):
            # Only the first area starts visible; "Add Area" reveals more.
            with gr.Column(visible=(i < 1)) as row_wrapper:
                prompts.append(gr.Textbox(label=f"Area Prompt {i+1}", lines=2, interactive=True))
                with gr.Row():
                    # step=8 keeps coordinates aligned to the latent grid —
                    # presumably the VAE downscale factor; confirm.
                    xs.append(gr.Number(label="X", value=0, interactive=True, step=8, scale=1))
                    ys.append(gr.Number(label="Y", value=0, interactive=True, step=8, scale=1))
                    widths.append(gr.Number(label="Width", value=512, interactive=True, step=8, scale=1))
                    heights.append(gr.Number(label="Height", value=512, interactive=True, step=8, scale=1))
                    strengths.append(gr.Slider(label="Strength", minimum=0.1, maximum=2.0, step=0.05, value=1.0, interactive=True, scale=2))
                cond_rows.append(row_wrapper)

        with gr.Row():
            components[key('add_conditioning_button')] = gr.Button("✚ Add Area")
            components[key('delete_conditioning_button')] = gr.Button("➖ Delete Area", visible=False)
        components[key('conditioning_count_state')] = gr.State(1)

        all_cond_components_flat = prompts + widths + heights + xs + ys + strengths
        components[key('all_conditioning_components_flat')] = all_cond_components_flat

    return components
175
+
176
def create_reference_latent_ui(prefix: str):
    """Builds the reference-image accordion used for multimodal edit/combine
    workflows (e.g. FLUX.2 reference latents).

    Returns:
        Dict with the accordion, add/delete buttons, the image-column list,
        and the flat list of gr.Image components for event wiring.
    """
    components = {}
    key = lambda name: f"{name}_{prefix}"

    with gr.Accordion("Reference Edit", open=False) as accordion:
        components[key('reference_latent_accordion')] = accordion
        gr.Markdown("💡 **Tip:** For multimodal models (like FLUX.2), this feature enables powerful editing and combining capabilities. In txt2img mode, adding a single reference image performs an **Image Edit**, while adding multiple images performs an **Image Combine**.")

        ref_rows, ref_images = [], []
        components.update({
            key('reference_latent_rows'): ref_rows,
            key('reference_latent_images'): ref_images,
        })

        with gr.Row():
            for i in range(MAX_REFERENCE_LATENTS):
                # Only the first slot starts visible; the add button reveals more.
                with gr.Column(visible=(i < 1), min_width=160) as row_wrapper:
                    ref_images.append(gr.Image(type="pil", label=f"Reference {i+1}", sources=["upload"], height=150))
                    ref_rows.append(row_wrapper)

        with gr.Row():
            components[key('add_reference_latent_button')] = gr.Button("✚ Add Reference Image")
            components[key('delete_reference_latent_button')] = gr.Button("➖ Delete Reference Image", visible=False)
        components[key('reference_latent_count_state')] = gr.State(1)

        # One widget per row, so the flat list is just the images.
        components[key('all_reference_latent_components_flat')] = ref_images

    return components
204
+
205
def create_vae_override_ui(prefix: str):
    """Builds the optional VAE-override accordion (source, ID, upload).

    Returns:
        Dict with the accordion, source dropdown, ID textbox, upload
        button, and a hidden State holding the uploaded file path.
    """
    built = {}
    key = lambda name: f"{name}_{prefix}"
    # "None" first so the override is opt-in.
    source_choices = ["None"] + LORA_SOURCE_CHOICES

    with gr.Accordion("VAE Settings (Override)", open=False) as accordion:
        built[key('vae_accordion')] = accordion
        gr.Markdown("💡 **Tip:** When downloading from Civitai, please use the **Version ID**, not the Model ID. You can find the Version ID in the URL (e.g., `civitai.com/models/123?modelVersionId=456`) or under the model's download button.")
        with gr.Row():
            source_dd = gr.Dropdown(
                label="VAE Source",
                choices=source_choices,
                value="None",
                scale=1,
                interactive=True
            )
            id_box = gr.Textbox(
                label="Civitai Version ID / File",
                placeholder="Civitai Version ID or Filename",
                scale=3,
                interactive=True,
                type="text"
            )
            upload_button = gr.UploadButton(
                "Upload",
                file_types=[".safetensors"],
                scale=1
            )
        built[key('vae_source')] = source_dd
        built[key('vae_id')] = id_box
        built[key('vae_upload_button')] = upload_button
        # Filled by the upload event handler with the saved file path.
        built[key('vae_file')] = gr.State(None)

    return built
utils/__init__.py ADDED
File without changes
utils/app_utils.py ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import hashlib
4
+ import re
5
+ from typing import Sequence, Mapping, Any, Union, Set
6
+ from pathlib import Path
7
+ import shutil
8
+
9
+ import gradio as gr
10
+ from huggingface_hub import hf_hub_download, constants as hf_constants
11
+ import torch
12
+ import numpy as np
13
+ from PIL import Image, ImageChops
14
+
15
+
16
+ from core.settings import *
17
+
18
# Hard cap (GiB) on the HF hub cache; enforce_disk_limit() evicts the
# oldest files once usage exceeds this.
DISK_LIMIT_GB = 120
MODELS_ROOT_DIR = "ComfyUI/models"

# Lazily-built caches for ControlNet preprocessor discovery; populated by
# build_preprocessor_model_map / build_preprocessor_parameter_map.
PREPROCESSOR_MODEL_MAP = None
PREPROCESSOR_PARAMETER_MAP = None
+
25
def save_uploaded_file_with_hash(file_obj: gr.File, target_dir: str) -> str:
    """Copies an uploaded file into *target_dir* under a content-hash name.

    The destination name is the SHA-256 digest of the file's contents plus
    the original (lower-cased) extension, which de-duplicates identical
    uploads across sessions.

    Returns:
        The hashed filename, or "" when no file was provided.
    """
    if not file_obj:
        return ""

    source_path = file_obj.name

    # Hash in 64 KiB chunks so large model files don't load into memory.
    digest = hashlib.sha256()
    with open(source_path, 'rb') as handle:
        while True:
            chunk = handle.read(65536)
            if not chunk:
                break
            digest.update(chunk)

    extension = os.path.splitext(source_path)[1]
    hashed_filename = digest.hexdigest() + extension.lower()
    dest_path = os.path.join(target_dir, hashed_filename)

    os.makedirs(target_dir, exist_ok=True)
    if os.path.exists(dest_path):
        print(f"ℹ️ File already exists (deduplicated): {dest_path}")
    else:
        shutil.copy(source_path, dest_path)
        print(f"✅ Saved uploaded file as: {dest_path}")

    return hashed_filename
50
+
51
+
52
def bytes_to_gb(byte_size: int) -> float:
    """Converts a byte count to gibibytes, rounded to two decimal places.

    None and 0 both map to 0.0 so callers can pass missing sizes directly.
    """
    if not byte_size:
        return 0.0
    return round(byte_size / (1 << 30), 2)
56
+
57
def get_directory_size(path: str) -> int:
    """Returns the total size in bytes of all regular files under *path*.

    Symlinks are skipped so linked cache entries are not double-counted.
    A missing path yields 0; an OSError mid-walk yields the partial sum.
    """
    if not os.path.exists(path):
        return 0
    total = 0
    try:
        for dirpath, _, filenames in os.walk(path):
            for name in filenames:
                full_path = os.path.join(dirpath, name)
                if not os.path.isfile(full_path):
                    continue
                if os.path.islink(full_path):
                    continue
                total += os.path.getsize(full_path)
    except OSError as e:
        print(f"Warning: Could not access {path} to calculate size: {e}")
    return total
70
+
71
def enforce_disk_limit():
    """Keeps the HF hub cache under DISK_LIMIT_GB by deleting oldest files.

    Walks the cache, sums file sizes, and if over the limit deletes files
    in ascending creation-time order until usage fits. In-flight downloads
    (*.incomplete) and lock files are never touched. All failures are
    logged and swallowed — this is a best-effort janitor, not a guarantee.
    """
    disk_limit_bytes = DISK_LIMIT_GB * (1024 ** 3)
    cache_dir = hf_constants.HF_HUB_CACHE

    if not os.path.exists(cache_dir):
        return

    print(f"--- [Storage Manager] Checking disk usage in '{cache_dir}' (Limit: {DISK_LIMIT_GB} GB) ---")

    try:
        all_files = []
        current_size_bytes = 0
        for dirpath, _, filenames in os.walk(cache_dir):
            for f in filenames:
                # Skip partial downloads and lock files so we never corrupt
                # a transfer that is still in progress.
                if f.endswith(".incomplete") or f.endswith(".lock"):
                    continue
                file_path = os.path.join(dirpath, f)
                # Symlinks are skipped: the HF cache links blobs into
                # snapshot dirs, and counting both would double the total.
                if os.path.isfile(file_path) and not os.path.islink(file_path):
                    try:
                        file_size = os.path.getsize(file_path)
                        creation_time = os.path.getctime(file_path)
                        all_files.append((creation_time, file_path, file_size))
                        current_size_bytes += file_size
                    except OSError:
                        continue

        print(f"--- [Storage Manager] Current usage: {bytes_to_gb(current_size_bytes)} GB ---")

        if current_size_bytes > disk_limit_bytes:
            print(f"--- [Storage Manager] Usage exceeds limit. Starting cleanup... ---")
            # Oldest first (by ctime — on Linux this is inode-change time,
            # not strictly creation; acceptable for eviction ordering).
            all_files.sort(key=lambda x: x[0])

            while current_size_bytes > disk_limit_bytes and all_files:
                oldest_file_time, oldest_file_path, oldest_file_size = all_files.pop(0)
                try:
                    os.remove(oldest_file_path)
                    current_size_bytes -= oldest_file_size
                    print(f"--- [Storage Manager] Deleted oldest file: {os.path.basename(oldest_file_path)} ({bytes_to_gb(oldest_file_size)} GB freed) ---")
                except OSError as e:
                    print(f"--- [Storage Manager] Error deleting file {oldest_file_path}: {e} ---")

            print(f"--- [Storage Manager] Cleanup finished. New usage: {bytes_to_gb(current_size_bytes)} GB ---")
        else:
            print("--- [Storage Manager] Disk usage is within the limit. No action needed. ---")

    except Exception as e:
        print(f"--- [Storage Manager] An unexpected error occurred: {e} ---")
118
+
119
+
120
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Safely fetches obj[index], falling back to obj["result"][index].

    ComfyUI node outputs are usually tuples, but some wrappers return a
    mapping that holds the tuple under "result"; this helper handles both
    and returns None instead of raising when neither lookup succeeds.
    """
    try:
        return obj[index]
    except (KeyError, IndexError):
        try:
            # BUG FIX: when obj is a sequence, obj["result"] raises
            # TypeError (not KeyError), which the previous handler missed —
            # e.g. get_value_at_index([], 0) crashed instead of returning None.
            return obj["result"][index]
        except (KeyError, IndexError, TypeError):
            return None
128
+
129
def sanitize_prompt(prompt: str) -> str:
    """Strips non-printable characters from a prompt.

    Newlines and tabs are preserved; any non-string input yields "".
    """
    if not isinstance(prompt, str):
        return ""
    kept = [ch for ch in prompt if ch.isprintable() or ch in ('\n', '\t')]
    return "".join(kept)
133
+
134
def sanitize_id(input_id: str) -> str:
    """Reduces an arbitrary ID string to its ASCII digits only.

    Non-string input yields "". Matches the old regex [^0-9] removal, so
    only the characters 0-9 survive (not Unicode digits).
    """
    if not isinstance(input_id, str):
        return ""
    return "".join(ch for ch in input_id if ch in "0123456789")
138
+
139
def sanitize_url(url: str) -> str:
    """Validates that *url* is a plausible http(s) URL.

    Returns:
        The stripped URL.

    Raises:
        ValueError: if *url* is not a string or does not look like an
            HTTP/HTTPS URL.
    """
    if not isinstance(url, str):
        raise ValueError("URL must be a string.")
    cleaned = url.strip()
    pattern = r'^https?://[^\s/$.?#].[^\s]*$'
    if re.match(pattern, cleaned) is None:
        raise ValueError("Invalid URL format or scheme. Only HTTP and HTTPS are allowed.")
    return cleaned
146
+
147
def sanitize_filename(filename: str) -> str:
    """Makes a string safe to use as a bare filename.

    Strips leading path separators, collapses '..' traversal sequences,
    and replaces anything outside word chars / '.' / '-' with '_'.
    Non-string input yields "".
    """
    if not isinstance(filename, str):
        return ""
    # BUG FIX: the previous implementation ran lstrip('/\\') *after* the
    # regex had already converted every separator to '_', so the strip was
    # dead code. Stripping first gives 'etc_passwd' for '/etc/passwd'
    # instead of '_etc_passwd'.
    sanitized = filename.lstrip('/\\')
    sanitized = sanitized.replace('..', '')
    return re.sub(r'[^\w\.\-]', '_', sanitized)
153
+
154
+
155
def get_civitai_file_info(version_id: str) -> dict | None:
    """Queries the Civitai API for file metadata of one model version.

    Prefers the 'Model'-typed file with a known weights extension and
    falls back to the first listed file. Returns None on any failure
    (network error, bad status, malformed payload, empty file list).
    """
    endpoint = f"https://civitai.com/api/v1/model-versions/{version_id}"
    weight_exts = ('.safetensors', '.pt', '.bin')
    try:
        resp = requests.get(endpoint, timeout=10)
        resp.raise_for_status()
        payload = resp.json()

        files = payload.get('files', [])
        for entry in files:
            if entry.get('type') == 'Model' and entry['name'].endswith(weight_exts):
                return entry

        # No primary model file found — fall back to whatever is listed first.
        if files:
            return files[0]
    except Exception:
        # Broad by design: any API hiccup degrades to "no info available".
        return None
170
+
171
+
172
def download_file(url: str, save_path: str, api_key: str = None, progress=None, desc: str = "") -> str:
    """Streams *url* to *save_path*, reporting progress and returning a
    human-readable status string.

    The status string is also the error channel: success messages start
    with "Successfully downloaded", failures with "Download failed for".
    Callers branch on these substrings. Partial files are removed on error.

    Args:
        url: HTTP(S) source URL.
        save_path: Destination file path.
        api_key: Optional bearer token (e.g. Civitai API key).
        progress: Optional gradio progress callback (fraction, desc=...).
        desc: Label shown next to the progress bar.
    """
    # Make room in the HF cache before pulling a potentially large file.
    enforce_disk_limit()

    if os.path.exists(save_path):
        return f"File already exists: {os.path.basename(save_path)}"

    headers = {'Authorization': f'Bearer {api_key}'} if api_key and api_key.strip() else {}
    try:
        if progress:
            progress(0, desc=desc)

        # timeout=15 covers connect and per-chunk read, not the whole body.
        response = requests.get(url, stream=True, headers=headers, timeout=15)
        response.raise_for_status()
        # content-length may be absent (0) — then no progress is reported.
        total_size = int(response.headers.get('content-length', 0))

        with open(save_path, "wb") as f:
            downloaded = 0
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
                if progress and total_size > 0:
                    downloaded += len(chunk)
                    progress(downloaded / total_size, desc=desc)
        return f"Successfully downloaded: {os.path.basename(save_path)}"
    except Exception as e:
        # Never leave a truncated file behind — it would pass the
        # "already exists" check on the next attempt.
        if os.path.exists(save_path):
            os.remove(save_path)
        return f"Download failed for {os.path.basename(save_path)}: {e}"
199
+
200
+
201
def get_lora_path(source: str, id_or_url: str, civitai_key: str, progress) -> tuple[str | None, str]:
    """Resolves a LoRA selection to a local file path, downloading if needed.

    Only the "Civitai" source is resolved here; uploaded files are handled
    elsewhere. Returns (local_path, status) — local_path is None on failure.
    """
    if not id_or_url or not id_or_url.strip():
        return None, "No ID/URL provided."

    try:
        if source == "Civitai":
            version_id = sanitize_id(id_or_url)
            if not version_id:
                return None, "Invalid Civitai ID provided. Must be numeric."
            # Deterministic local name keyed by version ID enables caching.
            filename = sanitize_filename(f"civitai_{version_id}.safetensors")
            local_path = os.path.join(LORA_DIR, filename)
            file_info = get_civitai_file_info(version_id)
            api_key_to_use = civitai_key
            source_name = f"Civitai ID {version_id}"
        else:
            return None, "Invalid source."

    except ValueError as e:
        return None, f"Input validation failed: {e}"

    if os.path.exists(local_path):
        return local_path, "File already exists."

    if not file_info or not file_info.get('downloadUrl'):
        return None, f"Could not get download link for {source_name}."

    status = download_file(file_info['downloadUrl'], local_path, api_key_to_use, progress=progress, desc=f"Downloading {source_name}")

    # download_file signals success via its status string.
    return (local_path, status) if "Successfully" in status else (None, status)
230
+
231
def get_embedding_path(source: str, id_or_url: str, civitai_key: str, progress) -> tuple[str | None, str]:
    """Resolves a textual-inversion embedding to a local path, downloading
    from Civitai if needed.

    Unlike get_lora_path, the extension is taken from the remote file when
    it is .pt/.bin (embeddings are commonly shipped in those formats).
    Returns (local_path, status) — local_path is None on failure.
    """
    if not id_or_url or not id_or_url.strip():
        return None, "No ID/URL provided."

    try:
        file_ext = ".safetensors"

        if source == "Civitai":
            version_id = sanitize_id(id_or_url)
            if not version_id:
                return None, "Invalid Civitai ID. Must be numeric."

            file_info = get_civitai_file_info(version_id)
            # Keep the remote extension for legacy .pt/.bin embeddings so
            # loaders that dispatch on extension work.
            if file_info and file_info['name'].lower().endswith(('.pt', '.bin')):
                file_ext = os.path.splitext(file_info['name'])[1]

            filename = sanitize_filename(f"civitai_{version_id}{file_ext}")
            local_path = os.path.join(EMBEDDING_DIR, filename)
            api_key_to_use = civitai_key
            source_name = f"Embedding Civitai ID {version_id}"
        else:
            return None, "Invalid source."

    except ValueError as e:
        return None, f"Input validation failed: {e}"

    if os.path.exists(local_path):
        return local_path, "File already exists."

    if not file_info or not file_info.get('downloadUrl'):
        return None, f"Could not get download link for {source_name}."

    status = download_file(file_info['downloadUrl'], local_path, api_key_to_use, progress=progress, desc=f"Downloading {source_name}")

    return (local_path, status) if "Successfully" in status else (None, status)
266
+
267
def get_vae_path(source: str, id_or_url: str, civitai_key: str, progress) -> tuple[str | None, str]:
    """Resolves a VAE override to a local path, downloading from Civitai
    if needed.

    Mirrors get_embedding_path (same download/caching flow) but stores
    into VAE_DIR. Returns (local_path, status) — local_path is None on
    failure.
    """
    if not id_or_url or not id_or_url.strip():
        return None, "No ID/URL provided."

    try:
        file_ext = ".safetensors"

        if source == "Civitai":
            version_id = sanitize_id(id_or_url)
            if not version_id:
                return None, "Invalid Civitai ID. Must be numeric."

            file_info = get_civitai_file_info(version_id)
            # Preserve remote .pt/.bin extension, as with embeddings.
            if file_info and file_info['name'].lower().endswith(('.pt', '.bin')):
                file_ext = os.path.splitext(file_info['name'])[1]

            filename = sanitize_filename(f"civitai_{version_id}{file_ext}")
            local_path = os.path.join(VAE_DIR, filename)
            api_key_to_use = civitai_key
            source_name = f"VAE Civitai ID {version_id}"
        else:
            return None, "Invalid source."

    except ValueError as e:
        return None, f"Input validation failed: {e}"

    if os.path.exists(local_path):
        return local_path, "File already exists."

    if not file_info or not file_info.get('downloadUrl'):
        return None, f"Could not get download link for {source_name}."

    status = download_file(file_info['downloadUrl'], local_path, api_key_to_use, progress=progress, desc=f"Downloading {source_name}")

    return (local_path, status) if "Successfully" in status else (None, status)
302
+
303
+
304
def _ensure_model_downloaded(filename: str, progress=gr.Progress()):
    """Ensures *filename* is present in its ComfyUI models directory.

    Looks the file up in ALL_FILE_DOWNLOAD_MAP (yaml/file_list.yaml), then
    either symlinks it from the HF hub cache or downloads it from Civitai.

    Returns:
        The filename unchanged, so callers can use the call inline.

    Raises:
        gr.Error: for unknown files or failed downloads.
        ValueError: for an unknown model category (programming error in
            file_list.yaml).
    """
    download_info = ALL_FILE_DOWNLOAD_MAP.get(filename)
    if not download_info:
        raise gr.Error(f"Model component '{filename}' not found in file_list.yaml. Cannot download.")

    # Map file_list.yaml categories onto ComfyUI model sub-directories.
    category_to_dir_map = {
        "diffusion_models": DIFFUSION_MODELS_DIR,
        "text_encoders": TEXT_ENCODERS_DIR,
        "vae": VAE_DIR,
        "checkpoints": CHECKPOINT_DIR,
        "loras": LORA_DIR,
        "controlnet": CONTROLNET_DIR,
        "model_patches": MODEL_PATCHES_DIR,
        "clip_vision": os.path.join(os.path.dirname(LORA_DIR), "clip_vision")
    }

    category = download_info.get('category')
    dest_dir = category_to_dir_map.get(category)
    if not dest_dir:
        raise ValueError(f"Unknown model category '{category}' for file '{filename}'.")

    dest_path = os.path.join(dest_dir, filename)

    # lexists catches dangling symlinks left behind after the HF cache
    # evicted the underlying blob; re-download in that case.
    if os.path.lexists(dest_path):
        if not os.path.exists(dest_path):
            print(f"⚠️ Found and removed broken symlink: {dest_path}")
            os.remove(dest_path)
        else:
            return filename

    source = download_info.get("source")
    try:
        progress(0, desc=f"Downloading: {filename}")

        if source == "hf":
            repo_id = download_info.get("repo_id")
            hf_filename = download_info.get("repository_file_path", filename)
            if not repo_id:
                raise ValueError(f"repo_id is missing for HF model '{filename}'")

            # Download into the shared HF cache, then symlink into the
            # ComfyUI tree so the disk-limit janitor manages one copy.
            cached_path = hf_hub_download(repo_id=repo_id, filename=hf_filename)
            os.makedirs(dest_dir, exist_ok=True)
            os.symlink(cached_path, dest_path)
            print(f"✅ Symlinked '{cached_path}' to '{dest_path}'")

        elif source == "civitai":
            model_version_id = download_info.get("model_version_id")
            if not model_version_id:
                raise ValueError(f"model_version_id is missing for Civitai model '{filename}'")

            file_info = get_civitai_file_info(model_version_id)
            if not file_info or not file_info.get('downloadUrl'):
                raise ConnectionError(f"Could not get download URL for Civitai model version ID {model_version_id}")

            status = download_file(
                file_info['downloadUrl'], dest_path, progress=progress, desc=f"Downloading: {filename}"
            )
            # BUG FIX: download_file reports errors as "Download failed for
            # ..." (lowercase); the previous check for the capitalized word
            # "Failed" never matched, so failed Civitai downloads were
            # silently treated as successes.
            if "failed" in status.lower():
                raise ConnectionError(status)
        else:
            raise NotImplementedError(f"Download source '{source}' is not implemented for '{filename}'")

        progress(1.0, desc=f"Downloaded: {filename}")

    except Exception as e:
        # Remove any partial file or half-made symlink so the next attempt
        # starts from a clean slate.
        if os.path.lexists(dest_path):
            try:
                os.remove(dest_path)
            except OSError:
                pass
        raise gr.Error(f"Failed to download and link '{filename}': {e}")

    return filename
376
+
377
def ensure_controlnet_model_downloaded(filename: str, progress):
    """Downloads a ControlNet model unless the selection is empty or "None"."""
    if filename and filename != "None":
        _ensure_model_downloaded(filename, progress)
381
+
382
+
383
def build_preprocessor_model_map():
    """Builds {preprocessor display name: [(repo_id, filename), ...]} by
    scraping the comfyui_controlnet_aux node-wrapper sources.

    The result is cached in the module-level PREPROCESSOR_MODEL_MAP; the
    empty dict is also cached when the wrappers directory is missing, so
    the scan runs at most once per process.
    """
    global PREPROCESSOR_MODEL_MAP
    if PREPROCESSOR_MODEL_MAP is not None: return PREPROCESSOR_MODEL_MAP
    print("--- Building ControlNet Preprocessor model map ---")
    # Wrappers whose model files cannot be discovered by the regex scrape
    # below (they don't call from_pretrained with a literal filename).
    manual_map = {
        "dwpose": [("yzd-v/DWPose", "yolox_l.onnx"), ("yzd-v/DWPose", "dw-ll_ucoco_384.onnx"), ("hr16/UnJIT-DWPose", "dw-ll_ucoco.onnx"), ("hr16/DWPose-TorchScript-BatchSize5", "dw-ll_ucoco_384_bs5.torchscript.pt"), ("hr16/DWPose-TorchScript-BatchSize5", "rtmpose-m_ap10k_256_bs5.torchscript.pt"), ("hr16/yolo-nas-fp16", "yolo_nas_l_fp16.onnx"), ("hr16/yolo-nas-fp16", "yolo_nas_m_fp16.onnx"), ("hr16/yolo-nas-fp16", "yolo_nas_s_fp16.onnx")],
        "densepose": [("LayerNorm/DensePose-TorchScript-with-hint-image", "densepose_r50_fpn_dl.torchscript"), ("LayerNorm/DensePose-TorchScript-with-hint-image", "densepose_r101_fpn_dl.torchscript")]
    }
    temp_map = {}
    # Imported lazily: ComfyUI's `nodes` module is only importable after
    # the ComfyUI runtime has been set up (see comfy_integration/setup.py).
    from nodes import NODE_DISPLAY_NAME_MAPPINGS
    wrappers_dir = Path("./custom_nodes/comfyui_controlnet_aux/node_wrappers/")
    if not wrappers_dir.exists():
        print("⚠️ ControlNet AUX wrappers directory not found. Cannot build model map.")
        PREPROCESSOR_MODEL_MAP = {}; return PREPROCESSOR_MODEL_MAP
    for wrapper_file in wrappers_dir.glob("*.py"):
        if wrapper_file.name == "__init__.py": continue
        with open(wrapper_file, 'r', encoding='utf-8') as f:
            content = f.read()
        # Pull display names out of each wrapper's own
        # NODE_DISPLAY_NAME_MAPPINGS literal (text scrape, not an import).
        display_name_matches = re.findall(r'NODE_DISPLAY_NAME_MAPPINGS\s*=\s*{(?:.|\n)*?["\'](.*?)["\']\s*:\s*["\'](.*?)["\']', content)
        for _, display_name in display_name_matches:
            if display_name not in temp_map: temp_map[display_name] = []
            manual_key = wrapper_file.stem
            if manual_key in manual_map: temp_map[display_name].extend(manual_map[manual_key])
            # Any literal filename passed to from_pretrained(...) is assumed
            # to be a downloadable model weight.
            matches = re.findall(r"from_pretrained\s*\(\s*(?:filename=)?\s*f?[\"']([^\"']+)[\"']", content)
            for model_filename in matches:
                # Default annotator repo; special-case known exceptions.
                repo_id = "lllyasviel/Annotators"
                if "depth_anything" in model_filename and "v2" in model_filename: repo_id = "LiheYoung/Depth-Anything-V2"
                elif "depth_anything" in model_filename: repo_id = "LiheYoung/Depth-Anything"
                elif "diffusion_edge" in model_filename: repo_id = "hr16/Diffusion-Edge"
                temp_map[display_name].append((repo_id, model_filename))
    # De-duplicate and drop preprocessors with no discovered models.
    final_map = {name: sorted(list(set(models))) for name, models in temp_map.items() if models}
    PREPROCESSOR_MODEL_MAP = final_map
    print("✅ ControlNet Preprocessor model map built."); return PREPROCESSOR_MODEL_MAP
416
+
417
def build_preprocessor_parameter_map():
    """Builds {preprocessor display name: [param spec, ...]} by introspecting
    the INPUT_TYPES of every comfyui_controlnet_aux node class.

    Each param spec is {"name", "type", "config"}. The result is cached in
    the module-level PREPROCESSOR_PARAMETER_MAP; unlike the model map, this
    function returns nothing — callers read the global afterwards.
    """
    global PREPROCESSOR_PARAMETER_MAP
    if PREPROCESSOR_PARAMETER_MAP is not None: return
    print("--- Building ControlNet Preprocessor parameter map ---")
    param_map = {}
    # Imported lazily: ComfyUI's `nodes` module is only importable after
    # the ComfyUI runtime has been set up.
    from nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
    for class_name, node_class in NODE_CLASS_MAPPINGS.items():
        if not hasattr(node_class, "INPUT_TYPES"): continue
        # Only consider nodes that come from the controlnet_aux wrappers.
        if hasattr(node_class, '__module__') and 'comfyui_controlnet_aux.node_wrappers' not in node_class.__module__: continue
        display_name = NODE_DISPLAY_NAME_MAPPINGS.get(class_name)
        if not display_name: continue
        try:
            input_types = node_class.INPUT_TYPES()
            # Optional inputs override required ones on key collision.
            all_inputs = {**input_types.get('required', {}), **input_types.get('optional', {})}
            params = []
            for name, details in all_inputs.items():
                # Skip inputs supplied by the pipeline, not the user.
                if name in ['image', 'resolution', 'pose_kps']: continue
                if not isinstance(details, (list, tuple)) or not details: continue
                param_type = details[0]
                param_config = details[1] if len(details) > 1 and isinstance(details[1], dict) else {}
                param_info = {"name": name, "type": param_type, "config": param_config}
                params.append(param_info)
            if params: param_map[display_name] = params
        except Exception as e:
            # A single misbehaving node shouldn't break the whole map.
            print(f"⚠️ Could not parse parameters for {display_name}: {e}")
    PREPROCESSOR_PARAMETER_MAP = param_map
    print("✅ ControlNet Preprocessor parameter map built.")
444
+
445
def print_welcome_message():
    """Prints the startup banner with author credit and GPL-3.0 notice."""
    author_name = "RioShiina"
    project_url = "https://huggingface.co/RioShiina"
    border = "=" * 72

    banner_lines = [
        "",
        border,
        "",
        " Thank you for using this project!",
        "",
        f" **Author:** {author_name}",
        f" **Find more from the author:** {project_url}",
        "",
        " This project is open-source under the GNU General Public License v3.0 (GPL-3.0).",
        " As it's built upon GPL-3.0 components (like ComfyUI), any modifications you",
        " distribute must also be open-sourced under the same license.",
        "",
        " Your respect for the principles of free software is greatly appreciated!",
        "",
        border,
        "",
    ]
    print("\n".join(banner_lines))
yaml/constants.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MAX_LORAS: 5
2
+ MAX_CONTROLNETS: 5
3
+ MAX_EMBEDDINGS: 5
4
+ MAX_CONDITIONINGS: 10
5
+ MAX_REFERENCE_LATENTS: 10
6
+ LORA_SOURCE_CHOICES: ["Civitai", "File"]
7
+
8
+ RESOLUTION_MAP:
9
+ sdxl:
10
+ "1:1 (Square)": [1024, 1024]
11
+ "16:9 (Landscape)": [1344, 768]
12
+ "9:16 (Portrait)": [768, 1344]
13
+ "4:3 (Classic)": [1152, 896]
14
+ "3:4 (Classic Portrait)": [896, 1152]
15
+ "3:2 (Photography)": [1216, 832]
16
+ "2:3 (Photography Portrait)": [832, 1216]
yaml/file_list.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ file:
2
+ diffusion_models:
3
+ # FLUX.2-klein-4B
4
+ - filename: "flux-2-klein-4b-fp8.safetensors"
5
+ source: "hf"
6
+ repo_id: "black-forest-labs/FLUX.2-klein-4b-fp8"
7
+ repository_file_path: "flux-2-klein-4b-fp8.safetensors"
8
+ # FLUX.2-klein-base-4B
9
+ - filename: "flux-2-klein-base-4b-fp8.safetensors"
10
+ source: "hf"
11
+ repo_id: "black-forest-labs/FLUX.2-klein-base-4b-fp8"
12
+ repository_file_path: "flux-2-klein-base-4b-fp8.safetensors"
13
+ # FLUX.2-klein-9B
14
+ - filename: "flux-2-klein-9b-fp8.safetensors"
15
+ source: "hf"
16
+ repo_id: "black-forest-labs/FLUX.2-klein-9b-fp8"
17
+ repository_file_path: "flux-2-klein-9b-fp8.safetensors"
18
+ # FLUX.2-klein-base-9B
19
+ - filename: "flux-2-klein-base-9b-fp8.safetensors"
20
+ source: "hf"
21
+ repo_id: "black-forest-labs/FLUX.2-klein-base-9b-fp8"
22
+ repository_file_path: "flux-2-klein-base-9b-fp8.safetensors"
23
+ # FLUX.2-dev
24
+ - filename: "flux2_dev_fp8mixed.safetensors"
25
+ source: "hf"
26
+ repo_id: "Comfy-Org/flux2-dev"
27
+ repository_file_path: "split_files/diffusion_models/flux2_dev_fp8mixed.safetensors"
28
+
29
+ text_encoders:
30
+ # FLUX.2-klein-4B & base
31
+ - filename: "qwen_3_4b.safetensors"
32
+ source: "hf"
33
+ repo_id: "Comfy-Org/vae-text-encorder-for-flux-klein-4b"
34
+ repository_file_path: "split_files/text_encoders/qwen_3_4b.safetensors"
35
+ # FLUX.2-klein-9B & base
36
+ - filename: "qwen_3_8b_fp8mixed.safetensors"
37
+ source: "hf"
38
+ repo_id: "Comfy-Org/vae-text-encorder-for-flux-klein-9b"
39
+ repository_file_path: "split_files/text_encoders/qwen_3_8b_fp8mixed.safetensors"
40
+ # FLUX.2-dev
41
+ - filename: "mistral_3_small_flux2_fp8.safetensors"
42
+ source: "hf"
43
+ repo_id: "Comfy-Org/flux2-dev"
44
+ repository_file_path: "split_files/text_encoders/mistral_3_small_flux2_fp8.safetensors"
45
+
46
+ vae:
47
+ - filename: "flux2-vae.safetensors"
48
+ source: "hf"
49
+ repo_id: "Comfy-Org/flux2-dev"
50
+ repository_file_path: "split_files/vae/flux2-vae.safetensors"
yaml/injectors.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ injector_definitions:
2
+ dynamic_conditioning_chains:
3
+ module: "chain_injectors.conditioning_injector"
4
+ dynamic_reference_latent_chains:
5
+ module: "chain_injectors.reference_latent_injector"
6
+
7
+ injector_order:
8
+ - dynamic_reference_latent_chains
9
+ - dynamic_conditioning_chains
yaml/model_defaults.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Default:
2
+ steps: 20
3
+ cfg: 4.0
4
+ sampler_name: "euler"
5
+ scheduler: "simple"
6
+
7
+ FLUX.2:
8
+ _defaults:
9
+ steps: 20
10
+ cfg: 4.0
11
+ sampler_name: "euler"
12
+ scheduler: "simple"
13
+ "black-forest-labs/FLUX.2-klein-4B":
14
+ steps: 4
15
+ cfg: 1.0
16
+ "black-forest-labs/FLUX.2-klein-9B":
17
+ steps: 4
18
+ cfg: 1.0
yaml/model_list.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Checkpoint:
2
+ - display_name: "black-forest-labs/FLUX.2-klein-4B"
3
+ components:
4
+ unet: "flux-2-klein-4b-fp8.safetensors"
5
+ clip: "qwen_3_4b.safetensors"
6
+ vae: "flux2-vae.safetensors"
7
+ - display_name: "black-forest-labs/FLUX.2-klein-9B"
8
+ components:
9
+ unet: "flux-2-klein-9b-fp8.safetensors"
10
+ clip: "qwen_3_8b_fp8mixed.safetensors"
11
+ vae: "flux2-vae.safetensors"
12
+ - display_name: "black-forest-labs/FLUX.2-klein-base-4B"
13
+ components:
14
+ unet: "flux-2-klein-base-4b-fp8.safetensors"
15
+ clip: "qwen_3_4b.safetensors"
16
+ vae: "flux2-vae.safetensors"
17
+ - display_name: "black-forest-labs/FLUX.2-klein-base-9B"
18
+ components:
19
+ unet: "flux-2-klein-base-9b-fp8.safetensors"
20
+ clip: "qwen_3_8b_fp8mixed.safetensors"
21
+ vae: "flux2-vae.safetensors"
22
+ - display_name: "black-forest-labs/FLUX.2-dev (Need to set ZeroGPU Duration to 120)"
23
+ components:
24
+ unet: "flux2_dev_fp8mixed.safetensors"
25
+ clip: "mistral_3_small_flux2_fp8.safetensors"
26
+ vae: "flux2-vae.safetensors"
yaml/private_file_list.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ file:
2
+ diffusion_models:
3
+ # FLUX.2-klein-9B
4
+ - filename: "flux-2-klein-9b-fp8.safetensors"
5
+ source: "hf"
6
+ repo_id: "black-forest-labs/FLUX.2-klein-9b-fp8"
7
+ repository_file_path: "flux-2-klein-9b-fp8.safetensors"
8
+ # FLUX.2-klein-base-9B
9
+ - filename: "flux-2-klein-base-9b-fp8.safetensors"
10
+ source: "hf"
11
+ repo_id: "black-forest-labs/FLUX.2-klein-base-9b-fp8"
12
+ repository_file_path: "flux-2-klein-base-9b-fp8.safetensors"