import glob
import json
import os
import random
import shutil
import uuid
from typing import List, Dict, Any

import gradio as gr
import torch
from PIL import Image

from .base_pipeline import BasePipeline
from core.settings import *
from utils.app_utils import sanitize_prompt
from core.workflow_assembler import WorkflowAssembler
from .workflow_executor import WorkflowExecutor
from .pipeline_input_processor import process_pipeline_inputs


class SdImagePipeline(BasePipeline):
    """Pipeline that prepares models, assembles a workflow, executes it and saves PNG results.

    ``run`` is the entry point; ``_gpu_logic`` is the part handed to
    ``self._execute_gpu_logic`` (defined outside this file).
    """

    # Number of ``gen_*.png`` files left in the output directory after pruning.
    _MAX_KEPT_OUTPUTS = 50
    # File name used when pid.yaml yields no usable entry for the architecture.
    _DEFAULT_PID_UNET = "pid_flux1_1024_to_4096_4step_mxfp8.safetensors"
    # Additional model file required whenever PiD is enabled.
    _PID_GEMMA_MODEL = "gemma_2_2b_it_elm_fp8_scaled.safetensors"
    # Latent type -> empty-latent template name; anything else uses "EmptyLatentImage".
    _LATENT_GENERATOR_TEMPLATES = {
        'sd3_latent': "EmptySD3LatentImage",
        'chroma_radiance_latent': "EmptyChromaRadianceLatentImage",
        'hunyuan_latent': "EmptyHunyuanImageLatent",
    }

    def get_required_models(self, model_display_name: str, **kwargs) -> List[str]:
        """Return the model identifiers that must be available for ``model_display_name``.

        When the ``ALL_MODEL_MAP`` entry stores a dict of components, every
        non-empty component value except the ``"pixel_space"`` placeholder is
        returned. In every other case (unknown model, or a single path) the
        display name itself is returned as the only requirement.

        Args:
            model_display_name: Key looked up in ``ALL_MODEL_MAP``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            A new list of required model names.
        """
        model_info = ALL_MODEL_MAP.get(model_display_name)
        if model_info and isinstance(model_info[1], dict):
            return [v for v in model_info[1].values() if v and v != "pixel_space"]
        return [model_display_name]

    @staticmethod
    def _format_generation_params(ui_inputs: Dict, seed: int, loras_string: str) -> str:
        """Build the text stored under the PNG ``parameters`` key for one image."""
        width_for_meta = ui_inputs.get('width', 'N/A')
        height_for_meta = ui_inputs.get('height', 'N/A')

        params_string = f"{ui_inputs['positive_prompt']}\nNegative prompt: {ui_inputs['negative_prompt']}\n"
        params_string += (
            f"Steps: {ui_inputs['num_inference_steps']}, Sampler: {ui_inputs['sampler']}, "
            f"Scheduler: {ui_inputs['scheduler']}, CFG scale: {ui_inputs['guidance_scale']}, "
            f"Seed: {seed}, Size: {width_for_meta}x{height_for_meta}, "
            f"Base Model: {ui_inputs['model_display_name']}"
        )
        if ui_inputs['task_type'] != 'txt2img':
            params_string += f", Denoise: {ui_inputs['denoise']}"

        # ``run`` stores clip_skip negated (default -1), so compare the
        # magnitude; comparing the raw value against 1 would always emit
        # "Clip skip: 1" for the default setting.
        clip_skip = ui_inputs.get('clip_skip')
        if clip_skip and abs(clip_skip) != 1:
            params_string += f", Clip skip: {abs(clip_skip)}"
        if loras_string:
            params_string += f", {loras_string}"
        return params_string.strip()

    def _gpu_logic(self, ui_inputs: Dict, loras_string: str, workflow: Dict[str, Any], assembler: WorkflowAssembler, progress=gr.Progress(track_tqdm=True)):
        """Execute the assembled workflow and convert each decoded image to a PIL image.

        Args:
            ui_inputs: UI values; read for the seed and for the metadata text.
            loras_string: Pre-formatted LoRA summary appended to the metadata
                (empty string for none).
            workflow: Assembled workflow passed to ``WorkflowExecutor``.
            assembler: Unused here; kept so the call signature stays stable.
            progress: Progress callback.

        Returns:
            A list of PIL images whose ``info['parameters']`` holds the
            generation-parameter text.
        """
        progress(0.4, desc="Executing workflow...")

        decoded_images_tensor = WorkflowExecutor.execute_workflow(workflow, initial_objects={})

        # ``run`` resolves a -1 seed before calling; this fallback only applies
        # when the method is invoked with an unresolved seed.
        start_seed = ui_inputs['seed'] if ui_inputs['seed'] != -1 else random.randint(0, 2**64 - 1)

        output_images = []
        for i in range(decoded_images_tensor.shape[0]):
            # NOTE(review): assumes each item is an HxWxC float image scaled
            # to 0..1 -- confirm against WorkflowExecutor.
            pixels = decoded_images_tensor[i].cpu().numpy() * 255.0
            # Clamp before the cast so out-of-range values cannot wrap around.
            pil_image = Image.fromarray(pixels.clip(0, 255).astype("uint8"))
            # Images in a batch are labelled with consecutive seeds.
            pil_image.info = {
                'parameters': self._format_generation_params(ui_inputs, start_seed + i, loras_string)
            }
            output_images.append(pil_image)

        return output_images

    @classmethod
    def _resolve_pid_unet_name(cls, workflow_model_type: str) -> str:
        """Return the PiD model file listed in ``yaml/pid.yaml`` for the architecture.

        Falls back to ``_DEFAULT_PID_UNET`` when the config cannot be read, has
        no entry for ``workflow_model_type``, or the matching entry has no
        ``filepath``.
        """
        import yaml

        pid_config_path = os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
            'yaml', 'pid.yaml',
        )
        try:
            with open(pid_config_path, 'r', encoding='utf-8') as f:
                pid_config = yaml.safe_load(f) or {}
            for item in pid_config.get("PiD", []):
                if workflow_model_type in item.get("architectures", []):
                    # The first matching entry wins; never hand back None.
                    return item.get("filepath") or cls._DEFAULT_PID_UNET
        except Exception as e:
            # Best effort: any read/parse/shape problem falls back to the default.
            print(f"Error loading PiD config for download: {e}")
        return cls._DEFAULT_PID_UNET

    @staticmethod
    def _prune_output_dir(out_dir: str, keep: int) -> None:
        """Delete the oldest ``gen_*.png`` files so that at most ``keep`` remain."""
        try:
            existing_files = sorted(
                glob.glob(os.path.join(out_dir, "gen_*.png")), key=os.path.getmtime
            )
            for stale in existing_files[:max(0, len(existing_files) - keep)]:
                os.remove(stale)
        except OSError as e:
            print(f"Warning: Failed to cleanup output dir: {e}")

    def _save_results_as_png(self, results, workflow: Dict[str, Any]) -> list:
        """Write PIL images to ``OUTPUT_DIR`` as PNGs and return the resulting list.

        Each PNG gets the image's ``parameters`` text (when present) and the
        JSON-serialised workflow under ``prompt``. Items that are not PIL
        images are passed through unchanged.
        """
        from PIL import PngImagePlugin

        prompt_json = json.dumps(workflow)

        out_dir = os.path.abspath(OUTPUT_DIR)
        os.makedirs(out_dir, exist_ok=True)
        # Pruning happens before saving, so the directory may briefly hold
        # more than the limit until the next run.
        self._prune_output_dir(out_dir, self._MAX_KEPT_OUTPUTS)

        final_results = []
        for img in results:
            if not isinstance(img, Image.Image):
                final_results.append(img)
                continue

            metadata = PngImagePlugin.PngInfo()
            params_string = img.info.get("parameters", "")
            if params_string:
                metadata.add_text("parameters", params_string)
            metadata.add_text("prompt", prompt_json)

            # uuid4 avoids silently overwriting an earlier image, which a
            # short random number could do; the name still matches gen_*.png.
            filepath = os.path.join(out_dir, f"gen_{uuid.uuid4().hex}.png")
            img.save(filepath, "PNG", pnginfo=metadata)
            final_results.append(filepath)
        return final_results

    @staticmethod
    def _remove_temp_files(temp_files) -> None:
        """Delete temporary files, reporting (not raising) per-file failures."""
        for temp_file in temp_files:
            if not temp_file:
                continue
            try:
                os.remove(temp_file)
            except FileNotFoundError:
                continue
            except OSError as e:
                # Runs inside ``finally``: raising here would mask the original
                # error and skip the remaining files.
                print(f"Warning: Failed to remove temp file {temp_file}: {e}")
            else:
                print(f"✅ Cleaned up temp file: {temp_file}")

    def run(self, ui_inputs: Dict, progress):
        """Generate images for ``ui_inputs`` and return the saved results.

        Steps: sanitise the prompts, make sure the required models are
        downloaded, process the auxiliary inputs, assemble the workflow,
        execute it through ``self._execute_gpu_logic`` and save the images as
        PNG files with embedded metadata. Temporary files reported by
        ``process_pipeline_inputs`` are removed afterwards, even on failure.

        Args:
            ui_inputs: UI values. Mutated in place: prompts are sanitised,
                ``clip_skip`` is stored negated and a ``-1`` seed is replaced
                by a random one.
            progress: Progress callback.

        Returns:
            A list with one saved file path per generated PIL image; any
            non-image result items are passed through unchanged.
        """
        progress(0, desc="Preparing models...")

        task_type = ui_inputs['task_type']
        model_display_name = ui_inputs['model_display_name']
        model_type = MODEL_TYPE_MAP.get(model_display_name, 'sdxl')

        architectures_dict = ARCHITECTURES_CONFIG.get('architectures', {})
        workflow_model_type = architectures_dict.get(model_type, {}).get(
            "model_type", model_type.lower().replace(" ", "").replace(".", "")
        )

        ui_inputs['positive_prompt'] = sanitize_prompt(ui_inputs.get('positive_prompt', ''))
        ui_inputs['negative_prompt'] = sanitize_prompt(ui_inputs.get('negative_prompt', ''))

        # clip_skip is stored negated from here on; a missing value becomes -1.
        if ui_inputs.get('clip_skip') is not None:
            ui_inputs['clip_skip'] = -int(ui_inputs['clip_skip'])
        else:
            ui_inputs['clip_skip'] = -1

        required_models = self.get_required_models(model_display_name=model_display_name)

        # PiD is only considered for txt2img.
        is_pid_enabled = ui_inputs.get('pid_settings', 'OFF') == 'ON' and task_type == 'txt2img'
        if is_pid_enabled:
            pid_unet_name = self._resolve_pid_unet_name(workflow_model_type)
            for extra_model in (pid_unet_name, self._PID_GEMMA_MODEL):
                if extra_model not in required_models:
                    required_models.append(extra_model)

        self.model_manager.ensure_models_downloaded(required_models, progress=progress)

        temp_files_to_clean = []
        try:
            processed = process_pipeline_inputs(ui_inputs, progress, workflow_model_type)
            temp_files_to_clean.extend(processed["temp_files_to_clean"])

            active_loras_for_meta = processed["active_loras_for_meta"]
            active_anima_controlnets = processed["active_anima_controlnets"]

            loras_string = f"LoRAs: [{', '.join(active_loras_for_meta)}]" if active_loras_for_meta else ""

            progress(0.8, desc="Assembling workflow...")

            if ui_inputs.get('seed') == -1:
                ui_inputs['seed'] = random.randint(0, 2**32 - 1)

            model_info = ALL_MODEL_MAP[model_display_name]
            path_or_components = model_info[1]
            latent_type = model_info[3] if len(model_info) > 3 and model_info[3] else 'latent'
            latent_generator_template = self._LATENT_GENERATOR_TEMPLATES.get(latent_type, "EmptyLatentImage")

            dynamic_values = {
                'task_type': ui_inputs['task_type'],
                'model_type': workflow_model_type,
                'latent_type': latent_type,
                'latent_generator_template': latent_generator_template,
            }

            recipe_path = os.path.join(os.path.dirname(__file__), "workflow_recipes", "sd_unified_recipe.yaml")
            assembler = WorkflowAssembler(recipe_path, dynamic_values=dynamic_values)

            # A single empty entry enables the smoothing chain for this one model.
            use_hidream_smoothing = (
                workflow_model_type == 'hidream-o1' and model_display_name == "HiDream-O1-Image"
            )
            hidream_o1_smoothing_data = [{}] if use_hidream_smoothing else []

            vae_name = ui_inputs.get('vae_name')
            workflow_inputs = {
                **ui_inputs,
                "positive_prompt": ui_inputs['positive_prompt'],
                "negative_prompt": ui_inputs['negative_prompt'],
                "seed": ui_inputs['seed'],
                "steps": ui_inputs['num_inference_steps'],
                "cfg": ui_inputs['guidance_scale'],
                "sampler_name": ui_inputs['sampler'],
                "scheduler": ui_inputs['scheduler'],
                "batch_size": ui_inputs['batch_size'],
                "clip_skip": ui_inputs['clip_skip'],
                "denoise": ui_inputs['denoise'],
                "vae_name": vae_name,
                "guidance": ui_inputs.get('guidance', 3.5),
                "lora_chain": processed["active_loras_for_gpu"],
                # The regular ControlNet chain is dropped when Anima ControlNets are active.
                "controlnet_chain": processed["active_controlnets"] if not active_anima_controlnets else [],
                "anima_controlnet_lllite_chain": active_anima_controlnets,
                "diffsynth_controlnet_chain": processed["active_diffsynth_controlnets"],
                "ipadapter_chain": processed["active_ipadapters"],
                "flux1_ipadapter_chain": processed["active_flux1_ipadapters"],
                "sd3_ipadapter_chain": processed["active_sd3_ipadapters"],
                "style_chain": processed["active_styles"],
                "conditioning_chain": processed["active_conditioning"],
                "reference_latent_chain": processed["active_reference_latents"],
                "hidream_o1_reference_chain": processed["active_hidream_o1_reference"],
                "vae_chain": [vae_name] if vae_name else [],
                "hidream_o1_smoothing_chain": hidream_o1_smoothing_data,
                "pid_chain": [ui_inputs.get('pid_settings', 'OFF')] if is_pid_enabled else [],
            }

            if isinstance(path_or_components, dict):
                workflow_inputs.update({
                    'unet_name': path_or_components.get('unet'),
                    # A VAE chosen in the UI takes precedence over the model's own.
                    'vae_name': vae_name or path_or_components.get('vae'),
                    'clip_name': path_or_components.get('clip'),
                    'clip1_name': path_or_components.get('clip1'),
                    'clip2_name': path_or_components.get('clip2'),
                    'clip3_name': path_or_components.get('clip3'),
                    'clip4_name': path_or_components.get('clip4'),
                    'lora_name': path_or_components.get('lora'),
                })
            else:
                workflow_inputs['model_name'] = path_or_components

            if task_type == 'txt2img':
                # Indexing (not .get) makes a missing size fail here for txt2img.
                workflow_inputs['width'] = ui_inputs['width']
                workflow_inputs['height'] = ui_inputs['height']

            workflow = assembler.assemble(workflow_inputs)

            progress(1.0, desc="All models ready. Requesting GPU for generation...")

            results = self._execute_gpu_logic(
                self._gpu_logic,
                duration=ui_inputs['zero_gpu_duration'],
                default_duration=60,
                task_name=f"ImageGen ({task_type})",
                ui_inputs=ui_inputs,
                loras_string=loras_string,
                workflow=workflow,
                assembler=assembler,
                progress=progress
            )

            results = self._save_results_as_png(results, workflow)
        finally:
            self._remove_temp_files(temp_files_to_clean)

        return results