Spaces:
Running on Zero
Running on Zero
| import os | |
| import random | |
| import shutil | |
| import torch | |
| import gradio as gr | |
| from PIL import Image, ImageChops | |
| from typing import List, Dict, Any | |
| from collections import defaultdict, deque | |
| import numpy as np | |
| from .base_pipeline import BasePipeline | |
| from core.settings import * | |
| from comfy_integration.nodes import * | |
| from utils.app_utils import get_value_at_index, sanitize_prompt, get_lora_path, get_embedding_path, ensure_controlnet_model_downloaded, sanitize_filename | |
| from core.workflow_assembler import WorkflowAssembler | |
class SdImagePipeline(BasePipeline):
    """Stable Diffusion image pipeline that executes an assembled node workflow.

    ``run`` does the CPU-side preparation (model / LoRA / embedding / VAE
    downloads, staging input images into INPUT_DIR), assembles a workflow from
    ``workflow_recipes/sd_unified_recipe.yaml`` and dispatches ``_gpu_logic``
    through ``_execute_gpu_logic``.
    """

    # At most this many existing ``gen_*.png`` files are kept in OUTPUT_DIR
    # before new results are saved (oldest are deleted first).
    _MAX_SAVED_OUTPUTS = 50

    # UI key holding the source image for each task that feeds ``input_image``.
    _TASK_IMAGE_KEYS = {
        'img2img': 'img2img_image',
        'outpaint': 'outpaint_image',
        'hires_fix': 'hires_image',
    }

    def get_required_models(self, model_display_name: str, **kwargs) -> List[str]:
        """Return the models that must be available locally before ``run``.

        Only the base model is needed; extra keyword arguments are ignored.
        """
        return [model_display_name]

    @staticmethod
    def _is_link(value: Any) -> bool:
        """Return True if a node input is a link of the form ``[source_node_id, output_index]``."""
        return isinstance(value, list) and len(value) == 2 and isinstance(value[0], str)

    def _topological_sort(self, workflow: Dict[str, Any]) -> List[str]:
        """Order workflow node ids so that every node comes after the nodes it links to.

        Kahn's algorithm over the links found in each node's ``inputs``. Links to
        node ids that are not part of ``workflow`` are ignored here; they must be
        satisfied by pre-computed objects at execution time.

        Raises:
            RuntimeError: if the link graph contains a cycle.
        """
        graph = defaultdict(list)
        in_degree = {node_id: 0 for node_id in workflow}
        for node_id, node_info in workflow.items():
            for input_value in node_info.get('inputs', {}).values():
                if self._is_link(input_value) and input_value[0] in workflow:
                    graph[input_value[0]].append(node_id)
                    in_degree[node_id] += 1
        queue = deque(node_id for node_id, degree in in_degree.items() if degree == 0)
        sorted_nodes = []
        while queue:
            current_node_id = queue.popleft()
            sorted_nodes.append(current_node_id)
            for neighbor_node_id in graph[current_node_id]:
                in_degree[neighbor_node_id] -= 1
                if in_degree[neighbor_node_id] == 0:
                    queue.append(neighbor_node_id)
        if len(sorted_nodes) != len(workflow):
            raise RuntimeError("Workflow contains a cycle and cannot be executed.")
        return sorted_nodes

    def _execute_workflow(self, workflow: Dict[str, Any], initial_objects: Dict[str, Any]):
        """Execute ``workflow`` node by node and return the image wired into its SaveImage node.

        Args:
            workflow: Mapping of node id -> ``{'class_type': ..., 'inputs': {...}}``.
                Link inputs (``[node_id, output_index]``) are resolved from the
                outputs of nodes that already ran.
            initial_objects: Pre-computed outputs keyed by node id; those nodes are
                not executed. The mapping is copied, never mutated.

        Returns:
            The value at the output slot linked to the ``images`` input of the last
            ``SaveImage`` node in execution order.

        Raises:
            RuntimeError: on a cyclic graph, an unknown node class, an unresolved
                link, or a workflow without a ``SaveImage`` node.
        """
        with torch.no_grad():
            # Copy so the caller's dict is not filled with every intermediate result.
            computed_outputs = dict(initial_objects)
            try:
                sorted_node_ids = self._topological_sort(workflow)
            except RuntimeError:
                print("--- [Workflow Executor] ERROR: Failed to sort workflow. Dumping graph details. ---")
                for node_id, node_info in workflow.items():
                    print(f" Node {node_id} ({node_info['class_type']}):")
                    for input_name, input_value in node_info.get('inputs', {}).items():
                        if self._is_link(input_value):
                            print(f" - {input_name} <- [{input_value[0]}, {input_value[1]}]")
                raise
            print(f"--- [Workflow Executor] Execution order: {sorted_node_ids}")

            for node_id in sorted_node_ids:
                # Nodes supplied through ``initial_objects`` are treated as already executed.
                if node_id in computed_outputs:
                    continue
                node_info = workflow[node_id]
                class_type = node_info['class_type']
                node_class = NODE_CLASS_MAPPINGS.get(class_type)
                if node_class is None:
                    raise RuntimeError(f"Could not find node class '{class_type}'. Is it imported in comfy_integration/nodes.py?")
                node_instance = node_class()
                kwargs = {}
                for param_name, param_value in node_info.get('inputs', {}).items():
                    if not self._is_link(param_value):
                        kwargs[param_name] = param_value
                        continue
                    source_node_id, output_index = param_value
                    if source_node_id not in computed_outputs:
                        raise RuntimeError(f"Workflow integrity error: Output of node {source_node_id} needed for {node_id} but not yet computed.")
                    kwargs[param_name] = get_value_at_index(computed_outputs[source_node_id], output_index)
                # Each node class names its entry-point method in the ``FUNCTION`` attribute.
                execution_method = getattr(node_instance, node_class.FUNCTION)
                computed_outputs[node_id] = execution_method(**kwargs)

            final_node_id = next(
                (nid for nid in reversed(sorted_node_ids) if workflow[nid]['class_type'] == 'SaveImage'),
                None,
            )
            if final_node_id is None:
                raise RuntimeError("Workflow does not contain a 'SaveImage' node as the output.")
            image_source_node_id, image_source_index = workflow[final_node_id]['inputs']['images']
            return get_value_at_index(computed_outputs[image_source_node_id], image_source_index)

    def _gpu_logic(self, ui_inputs: Dict, loras_string: str, workflow: Dict[str, Any], assembler: WorkflowAssembler, progress=gr.Progress(track_tqdm=True)):
        """Execute the assembled workflow and convert its output into annotated PIL images.

        Args:
            ui_inputs: UI parameters; read to build the ``parameters`` text
                (prompts, steps, sampler, scheduler, CFG, seed, denoise, model).
            loras_string: Pre-formatted LoRA summary appended to the parameters
                text; empty when no LoRA is active.
            workflow: Assembled workflow graph to execute.
            assembler: Not used by this method.
            progress: Gradio progress callback.

        Returns:
            One ``PIL.Image.Image`` per item of the decoded batch, each with
            ``info['parameters']`` set.
        """
        model_display_name = ui_inputs['model_display_name']
        progress(0.4, desc="Executing workflow...")
        decoded_images_tensor = self._execute_workflow(workflow, initial_objects={})

        seed = ui_inputs['seed']
        # ``run`` already replaces -1 with a concrete seed; the fallback uses the
        # same 32-bit range as ``run`` for direct callers.
        start_seed = seed if seed != -1 else random.randint(0, 2**32 - 1)

        # Everything except the seed and the image size is shared by the whole batch.
        header = f"{ui_inputs['positive_prompt']}\nNegative prompt: {ui_inputs['negative_prompt']}\n"
        settings = (
            f"Steps: {ui_inputs['num_inference_steps']}, Sampler: {ui_inputs['sampler']}, "
            f"Scheduler: {ui_inputs['scheduler']}, CFG scale: {ui_inputs['guidance_scale']}"
        )
        suffix = f", Base Model: {model_display_name}"
        if ui_inputs['task_type'] != 'txt2img':
            suffix += f", Denoise: {ui_inputs['denoise']}"
        if loras_string:
            suffix += f", {loras_string}"

        output_images = []
        for i in range(decoded_images_tensor.shape[0]):
            # Assumes each batch item is an (H, W, C) float image scaled to [0, 1] — TODO confirm.
            pixels = decoded_images_tensor[i].cpu().numpy() * 255.0
            # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
            pil_image = Image.fromarray(np.clip(pixels, 0, 255).astype(np.uint8))
            # Report the real output size (hires_fix / outpaint change it).
            params_string = (
                f"{header}{settings}, Seed: {start_seed + i}, "
                f"Size: {pil_image.width}x{pil_image.height}{suffix}"
            )
            pil_image.info = {'parameters': params_string.strip()}
            output_images.append(pil_image)
        return output_images

    @staticmethod
    def _resolve_loras(ui_inputs: Dict, progress):
        """Resolve the active LoRAs from the flat ``[source, id, scale, file] * N`` UI list.

        Returns:
            ``(chain_entries, meta_labels)`` where each chain entry is a dict with
            ``lora_name`` / ``strength_model`` / ``strength_clip`` and each label is
            ``"<source> <id>:<scale>"``.

        Raises:
            gr.Error: if a Civitai LoRA cannot be prepared.
        """
        active_loras_for_gpu, active_loras_for_meta = [], []
        lora_data = ui_inputs.get('lora_data', []) or []
        for source, lora_id, scale, _ in zip(lora_data[0::4], lora_data[1::4], lora_data[2::4], lora_data[3::4]):
            if not (scale > 0 and lora_id and lora_id.strip()):
                continue
            lora_filename = None
            if source == "File":
                lora_filename = sanitize_filename(lora_id)
            elif source == "Civitai":
                local_path, status = get_lora_path(source, lora_id, ui_inputs.get('civitai_api_key'), progress)
                if not local_path:
                    raise gr.Error(f"Failed to prepare LoRA {lora_id}: {status}")
                lora_filename = os.path.basename(local_path)
            if lora_filename:
                active_loras_for_gpu.append({"lora_name": lora_filename, "strength_model": scale, "strength_clip": scale})
                active_loras_for_meta.append(f"{source} {lora_id}:{scale}")
        return active_loras_for_gpu, active_loras_for_meta

    @staticmethod
    def _save_temp_png(image, prefix: str, temp_files: List[str]) -> str:
        """Save ``image`` as a uniquely named PNG in INPUT_DIR and return its basename.

        The path is registered in ``temp_files`` before writing, so a partial
        write is still cleaned up.
        """
        import uuid

        temp_file_path = os.path.join(INPUT_DIR, f"{prefix}_{uuid.uuid4().hex}.png")
        temp_files.append(temp_file_path)
        image.save(temp_file_path, "PNG")
        return os.path.basename(temp_file_path)

    @staticmethod
    def _build_inpaint_composite(inpaint_dict):
        """Merge the drawn layers into one mask and return the background with the inverted mask as alpha.

        Painted pixels (non-zero layer alpha) become transparent in the result.

        Raises:
            gr.Error: if the background image or the drawn layers are missing.
        """
        if not inpaint_dict or not inpaint_dict.get('background') or not inpaint_dict.get('layers'):
            raise gr.Error("Inpainting requires an input image and a drawn mask.")
        background_img = inpaint_dict['background'].convert("RGBA")
        composite_mask_pil = Image.new('L', background_img.size, 0)
        for layer in inpaint_dict['layers']:
            if layer:
                # The last band is taken as the layer's alpha (assumes RGBA layers — TODO confirm).
                composite_mask_pil = ImageChops.lighter(composite_mask_pil, layer.split()[-1])
        r, g, b, _ = background_img.split()
        return Image.merge('RGBA', [r, g, b, ImageChops.invert(composite_mask_pil)])

    def _stage_task_inputs(self, ui_inputs: Dict, task_type: str, temp_files: List[str]) -> None:
        """Write the task's input image into INPUT_DIR and record its file name in ``ui_inputs``."""
        if task_type == 'inpaint':
            composite = self._build_inpaint_composite(ui_inputs.get('inpaint_image_dict'))
            ui_inputs['inpaint_image'] = self._save_temp_png(composite, "temp_inpaint_composite", temp_files)
            # The mask now travels inside the composite's alpha channel.
            ui_inputs.pop('inpaint_mask', None)
            return
        image_key = self._TASK_IMAGE_KEYS.get(task_type)
        input_image_pil = ui_inputs.get(image_key) if image_key else None
        if input_image_pil is None:
            return
        ui_inputs['input_image'] = self._save_temp_png(input_image_pil, "temp_input", temp_files)
        if task_type == 'img2img':
            # img2img renders at the input image's resolution.
            ui_inputs['width'] = input_image_pil.width
            ui_inputs['height'] = input_image_pil.height

    @staticmethod
    def _append_embeddings_to_prompt(ui_inputs: Dict, progress) -> None:
        """Resolve the flat ``[source, id, file] * N`` embedding list and append ``embedding:<name>`` tokens to the positive prompt.

        Raises:
            gr.Error: if a Civitai embedding cannot be prepared.
        """
        embedding_data = ui_inputs.get('embedding_data', []) or []
        embedding_filenames = []
        for source, emb_id, _ in zip(embedding_data[0::3], embedding_data[1::3], embedding_data[2::3]):
            if not (emb_id and emb_id.strip()):
                continue
            emb_filename = None
            if source == "File":
                emb_filename = sanitize_filename(emb_id)
            elif source == "Civitai":
                local_path, status = get_embedding_path(source, emb_id, ui_inputs.get('civitai_api_key'), progress)
                if not local_path:
                    raise gr.Error(f"Failed to prepare Embedding {emb_id}: {status}")
                emb_filename = os.path.basename(local_path)
            if emb_filename:
                embedding_filenames.append(emb_filename)
        if not embedding_filenames:
            return
        embedding_prompt_text = " ".join(f"embedding:{name}" for name in embedding_filenames)
        positive = ui_inputs['positive_prompt']
        ui_inputs['positive_prompt'] = f"{positive}, {embedding_prompt_text}" if positive else embedding_prompt_text

    @staticmethod
    def _resolve_vae_name(ui_inputs: Dict, progress):
        """Return the VAE file name selected in the UI, or None to keep the model's default VAE.

        Raises:
            gr.Error: if a Civitai VAE cannot be prepared.
        """
        from utils.app_utils import get_vae_path

        vae_source = ui_inputs.get('vae_source')
        vae_id = ui_inputs.get('vae_id')
        # Both sources need a non-blank id; previously "File" passed a missing id straight to sanitize_filename.
        if not vae_source or vae_source == "None" or not (vae_id and vae_id.strip()):
            return None
        if vae_source == "File":
            return sanitize_filename(vae_id)
        if vae_source == "Civitai":
            local_path, status = get_vae_path(vae_source, vae_id, ui_inputs.get('civitai_api_key'), progress)
            if not local_path:
                raise gr.Error(f"Failed to prepare VAE {vae_id}: {status}")
            return os.path.basename(local_path)
        return None

    @staticmethod
    def _collect_conditioning(conditioning_data) -> List[Dict[str, Any]]:
        """Convert the flat conditioning UI list into dicts for ``conditioning_chain``.

        ``conditioning_data`` is laid out as six consecutive equal-length groups:
        prompts, widths, heights, xs, ys, strengths. Units with a blank prompt
        are skipped.
        """
        if not conditioning_data:
            return []
        num_units = len(conditioning_data) // 6
        columns = [conditioning_data[k * num_units:(k + 1) * num_units] for k in range(6)]
        active_conditioning = []
        for prompt, width, height, x, y, strength in zip(*columns):
            if prompt and prompt.strip():
                active_conditioning.append({
                    "prompt": prompt,
                    "width": int(width),
                    "height": int(height),
                    "x": int(x),
                    "y": int(y),
                    "strength": float(strength),
                })
        return active_conditioning

    @classmethod
    def _prune_output_dir(cls, out_dir: str) -> None:
        """Best effort: delete the oldest ``gen_*.png`` files so at most ``_MAX_SAVED_OUTPUTS`` remain."""
        import glob

        try:
            existing_files = sorted(glob.glob(os.path.join(out_dir, "gen_*.png")), key=os.path.getmtime)
            excess = max(0, len(existing_files) - cls._MAX_SAVED_OUTPUTS)
            for stale_file in existing_files[:excess]:
                os.remove(stale_file)
        except OSError as e:
            print(f"Warning: Failed to cleanup output dir: {e}")

    def _persist_results(self, results, workflow: Dict[str, Any]) -> list:
        """Save PIL results as PNGs in OUTPUT_DIR and return their paths.

        Each PNG gets a ``parameters`` text chunk (from ``img.info``) and the
        workflow JSON under ``prompt``. Non-PIL items are passed through unchanged.
        """
        import json
        import uuid
        from PIL import PngImagePlugin

        prompt_json = json.dumps(workflow)
        out_dir = os.path.abspath(OUTPUT_DIR)
        os.makedirs(out_dir, exist_ok=True)
        self._prune_output_dir(out_dir)
        final_results = []
        for img in results:
            if not isinstance(img, Image.Image):
                final_results.append(img)
                continue
            metadata = PngImagePlugin.PngInfo()
            params_string = img.info.get("parameters", "")
            if params_string:
                metadata.add_text("parameters", params_string)
            metadata.add_text("prompt", prompt_json)
            # uuid4 keeps concurrent requests from overwriting each other's outputs.
            filepath = os.path.join(out_dir, f"gen_{uuid.uuid4().hex}.png")
            img.save(filepath, "PNG", pnginfo=metadata)
            final_results.append(filepath)
        return final_results

    @staticmethod
    def _cleanup_temp_files(temp_files: List[str]) -> None:
        """Delete staged temp inputs; failures are logged so they never mask the original error."""
        for temp_file in temp_files:
            if not (temp_file and os.path.exists(temp_file)):
                continue
            try:
                os.remove(temp_file)
            except OSError as e:
                print(f"Warning: Failed to remove temp file {temp_file}: {e}")
            else:
                print(f"✅ Cleaned up temp file: {temp_file}")

    def run(self, ui_inputs: Dict, progress):
        """Prepare all assets for one generation request, run it on the GPU and save the results.

        ``ui_inputs`` is updated in place: sanitized prompts, ``denoise``, a
        concrete ``seed`` when it was -1, staged input image names and an
        optional ``vae_name`` override.

        Returns:
            One entry per item produced by ``_gpu_logic``: the absolute path of the
            saved PNG for PIL images, otherwise the item unchanged.

        Raises:
            gr.Error: if a LoRA / embedding / VAE cannot be prepared, inpainting
                lacks an image or mask, or the model is not configured.
        """
        progress(0, desc="Preparing models...")
        task_type = ui_inputs['task_type']
        ui_inputs['positive_prompt'] = sanitize_prompt(ui_inputs.get('positive_prompt', ''))
        ui_inputs['negative_prompt'] = sanitize_prompt(ui_inputs.get('negative_prompt', ''))
        required_models = self.get_required_models(model_display_name=ui_inputs['model_display_name'])
        self.model_manager.ensure_models_downloaded(required_models, progress=progress)
        active_loras_for_gpu, active_loras_for_meta = self._resolve_loras(ui_inputs, progress)

        if task_type == 'img2img':
            ui_inputs['denoise'] = ui_inputs.get('img2img_denoise', 0.7)
        elif task_type == 'hires_fix':
            ui_inputs['denoise'] = ui_inputs.get('hires_denoise', 0.55)
        else:
            ui_inputs['denoise'] = 1.0

        temp_files_to_clean: List[str] = []
        # Every step below may create temp files or raise, so cleanup lives in
        # ``finally`` and runs even when preparation fails before GPU dispatch.
        try:
            os.makedirs(INPUT_DIR, exist_ok=True)
            self._stage_task_inputs(ui_inputs, task_type, temp_files_to_clean)
            self._append_embeddings_to_prompt(ui_inputs, progress)
            vae_name_override = self._resolve_vae_name(ui_inputs, progress)
            if vae_name_override:
                ui_inputs['vae_name'] = vae_name_override
            active_conditioning = self._collect_conditioning(ui_inputs.get('conditioning_data', []))
            active_reference_latents = []
            for img_pil in ui_inputs.get('reference_latent_data') or []:
                if img_pil is not None:
                    active_reference_latents.append(self._save_temp_png(img_pil, "temp_ref", temp_files_to_clean))
            loras_string = f"LoRAs: [{', '.join(active_loras_for_meta)}]" if active_loras_for_meta else ""

            progress(0.8, desc="Assembling workflow...")
            if ui_inputs.get('seed') == -1:
                ui_inputs['seed'] = random.randint(0, 2**32 - 1)
            recipe_path = os.path.join(os.path.dirname(__file__), "workflow_recipes", "sd_unified_recipe.yaml")
            assembler = WorkflowAssembler(recipe_path, dynamic_values={'task_type': task_type})
            model_display_name = ui_inputs['model_display_name']
            if model_display_name not in ALL_MODEL_MAP:
                raise gr.Error(f"Model '{model_display_name}' is not configured in model_list.yaml.")
            _, components, _, _ = ALL_MODEL_MAP[model_display_name]
            workflow_inputs = {
                "positive_prompt": ui_inputs['positive_prompt'], "negative_prompt": ui_inputs['negative_prompt'],
                "seed": ui_inputs['seed'], "steps": ui_inputs['num_inference_steps'], "cfg": ui_inputs['guidance_scale'],
                "sampler_name": ui_inputs['sampler'], "scheduler": ui_inputs['scheduler'],
                "batch_size": ui_inputs['batch_size'],
                "denoise": ui_inputs['denoise'],
                "input_image": ui_inputs.get('input_image'),
                "inpaint_image": ui_inputs.get('inpaint_image'),
                "inpaint_mask": ui_inputs.get('inpaint_mask'),
                "left": ui_inputs.get('outpaint_left'), "top": ui_inputs.get('outpaint_top'),
                "right": ui_inputs.get('outpaint_right'), "bottom": ui_inputs.get('outpaint_bottom'),
                "hires_upscaler": ui_inputs.get('hires_upscaler'), "hires_scale_by": ui_inputs.get('hires_scale_by'),
                "unet_name": components['unet'],
                "clip_name": components['clip'],
                "vae_name": ui_inputs.get('vae_name', components['vae']),
                "lora_chain": active_loras_for_gpu,
                "conditioning_chain": active_conditioning,
                "reference_latent_chain": active_reference_latents,
            }
            if task_type == 'txt2img':
                workflow_inputs['width'] = ui_inputs['width']
                workflow_inputs['height'] = ui_inputs['height']
            workflow = assembler.assemble(workflow_inputs)

            progress(1.0, desc="All models ready. Requesting GPU for generation...")
            results = self._execute_gpu_logic(
                self._gpu_logic,
                duration=ui_inputs['zero_gpu_duration'],
                default_duration=60,
                task_name=f"ImageGen ({task_type})",
                ui_inputs=ui_inputs,
                loras_string=loras_string,
                workflow=workflow,
                assembler=assembler,
                progress=progress
            )
            return self._persist_results(results, workflow)
        finally:
            self._cleanup_temp_files(temp_files_to_clean)