# ImageGen/core/pipelines/sd_image_pipeline.py
# RioShiina's picture
# Upload folder using huggingface_hub
# dc0cea5 verified
import os
import random
import shutil
import torch
import gradio as gr
from PIL import Image
from typing import List, Dict, Any
from .base_pipeline import BasePipeline
from core.settings import *
from utils.app_utils import sanitize_prompt
from core.workflow_assembler import WorkflowAssembler
from .workflow_executor import WorkflowExecutor
from .pipeline_input_processor import process_pipeline_inputs
class SdImagePipeline(BasePipeline):
    """Image-generation pipeline: assemble a workflow from UI inputs, run it, save PNGs.

    ``run`` is the entry point. It makes sure the required model files are
    downloaded, builds the workflow through ``WorkflowAssembler``, executes it
    via ``self._execute_gpu_logic`` (provided by the base class) and finally
    writes every returned PIL image to ``OUTPUT_DIR`` with generation
    parameters embedded as PNG text chunks.
    """

    # Number of previously saved ``gen_*.png`` files kept when a run prunes
    # the output directory (oldest files by mtime are removed first).
    _MAX_SAVED_OUTPUTS = 50

    # PiD UNet used when ``yaml/pid.yaml`` is unreadable or has no usable
    # entry for the current architecture.
    _DEFAULT_PID_UNET = "pid_flux1_1024_to_4096_4step_mxfp8.safetensors"

    # Extra model file requested whenever PiD is enabled.
    _PID_COMPANION_MODEL = "gemma_2_2b_it_elm_fp8_scaled.safetensors"

    def get_required_models(self, model_display_name: str, **kwargs) -> List[str]:
        """Return the model identifiers that must be available for a model.

        Args:
            model_display_name: Key looked up in ``ALL_MODEL_MAP``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            For component-based models (entry index 1 is a dict), the truthy
            component values excluding the ``"pixel_space"`` placeholder.
            Otherwise (unknown model or single-path model) a one-element list
            holding ``model_display_name`` itself.
        """
        model_info = ALL_MODEL_MAP.get(model_display_name)
        if not model_info:
            return [model_display_name]

        path_or_components = model_info[1]
        if isinstance(path_or_components, dict):
            return [v for v in path_or_components.values() if v and v != "pixel_space"]
        return [model_display_name]

    def _gpu_logic(self, ui_inputs: Dict, loras_string: str, workflow: Dict[str, Any], assembler: WorkflowAssembler, progress=gr.Progress(track_tqdm=True)):
        """Execute the assembled workflow and wrap each output in a PIL image.

        Args:
            ui_inputs: UI parameters; read for prompts, sampler settings, seed
                and the other values written into the metadata string.
            loras_string: Pre-formatted LoRA summary appended to the metadata
                when non-empty.
            workflow: Assembled workflow passed to ``WorkflowExecutor``.
            assembler: Not used by this method; kept so the call signature
                stays unchanged.
            progress: Gradio progress callback.

        Returns:
            A list of ``PIL.Image.Image`` objects whose ``info['parameters']``
            holds the generation-parameter string.
        """
        model_display_name = ui_inputs['model_display_name']
        progress(0.4, desc="Executing workflow...")
        # NOTE(review): presumably a (batch, H, W, C) float tensor in [0, 1];
        # the code below indexes axis 0 and scales by 255 — confirm with the
        # executor.
        decoded_images_tensor = WorkflowExecutor.execute_workflow(workflow, initial_objects={})

        start_seed = ui_inputs['seed'] if ui_inputs['seed'] != -1 else random.randint(0, 2**64 - 1)
        width_for_meta = ui_inputs.get('width', 'N/A')
        height_for_meta = ui_inputs.get('height', 'N/A')
        clip_skip = ui_inputs.get('clip_skip')

        output_images = []
        for i in range(decoded_images_tensor.shape[0]):
            pixels = decoded_images_tensor[i].cpu().numpy() * 255.0
            # Clip before the uint8 cast: without it, values slightly outside
            # [0, 255] wrap around instead of saturating.
            pil_image = Image.fromarray(pixels.clip(0, 255).astype("uint8"))

            # Metadata lists one consecutive seed per image in the batch.
            current_seed = start_seed + i
            params_string = f"{ui_inputs['positive_prompt']}\nNegative prompt: {ui_inputs['negative_prompt']}\n"
            params_string += f"Steps: {ui_inputs['num_inference_steps']}, Sampler: {ui_inputs['sampler']}, Scheduler: {ui_inputs['scheduler']}, CFG scale: {ui_inputs['guidance_scale']}, Seed: {current_seed}, Size: {width_for_meta}x{height_for_meta}, Base Model: {model_display_name}"
            if ui_inputs['task_type'] != 'txt2img':
                params_string += f", Denoise: {ui_inputs['denoise']}"
            # ``run`` stores clip_skip as a negative number, so compare the
            # magnitude; a plain ``!= 1`` would always report "Clip skip: 1".
            if clip_skip and abs(clip_skip) != 1:
                params_string += f", Clip skip: {abs(clip_skip)}"
            if loras_string:
                params_string += f", {loras_string}"

            pil_image.info = {'parameters': params_string.strip()}
            output_images.append(pil_image)
        return output_images

    @classmethod
    def _resolve_pid_unet_name(cls, workflow_model_type: str) -> str:
        """Look up the PiD UNet file for an architecture in ``yaml/pid.yaml``.

        Falls back to ``_DEFAULT_PID_UNET`` when the config cannot be read, no
        entry lists ``workflow_model_type``, or the matching entry has no
        ``filepath``.
        """
        import yaml

        # The config lives three directory levels above this file.
        project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        pid_config_path = os.path.join(project_root, 'yaml', 'pid.yaml')
        try:
            with open(pid_config_path, 'r', encoding='utf-8') as f:
                pid_config = yaml.safe_load(f) or {}
            for item in pid_config.get("PiD", []):
                if workflow_model_type in item.get("architectures", []):
                    # A matching entry without a filepath must not turn into
                    # a ``None`` model name.
                    return item.get("filepath") or cls._DEFAULT_PID_UNET
        except Exception as e:
            # Deliberately best-effort: any config problem falls back to the
            # default UNet name.
            print(f"Error loading PiD config for download: {e}")
        return cls._DEFAULT_PID_UNET

    @classmethod
    def _prune_saved_outputs(cls, out_dir: str) -> None:
        """Delete the oldest ``gen_*.png`` files beyond ``_MAX_SAVED_OUTPUTS`` (best-effort)."""
        import glob

        try:
            existing_files = glob.glob(os.path.join(out_dir, "gen_*.png"))
            existing_files.sort(key=os.path.getmtime)
            excess = max(0, len(existing_files) - cls._MAX_SAVED_OUTPUTS)
            for stale_file in existing_files[:excess]:
                os.remove(stale_file)
        except Exception as e:
            print(f"Warning: Failed to cleanup output dir: {e}")

    @classmethod
    def _save_results_as_png(cls, results, workflow: Dict[str, Any]) -> List[Any]:
        """Write PIL images to ``OUTPUT_DIR`` and return their file paths.

        Each PNG receives a ``prompt`` text chunk with the JSON-serialized
        workflow and, when present, a ``parameters`` chunk copied from
        ``img.info``. Items that are not PIL images are passed through
        unchanged.
        """
        import json
        from PIL import PngImagePlugin

        prompt_json = json.dumps(workflow)
        out_dir = os.path.abspath(OUTPUT_DIR)
        os.makedirs(out_dir, exist_ok=True)
        cls._prune_saved_outputs(out_dir)

        final_results = []
        for img in results:
            if not isinstance(img, Image.Image):
                final_results.append(img)
                continue

            metadata = PngImagePlugin.PngInfo()
            params_string = img.info.get("parameters", "")
            if params_string:
                metadata.add_text("parameters", params_string)
            # NOTE(review): the workflow is embedded regardless of whether a
            # parameters string exists — confirm against the original layout.
            metadata.add_text("prompt", prompt_json)

            # Re-roll on the rare name collision so an existing output is
            # never silently overwritten.
            while True:
                filepath = os.path.join(out_dir, f"gen_{random.randint(1000000, 9999999)}.png")
                if not os.path.exists(filepath):
                    break
            img.save(filepath, "PNG", pnginfo=metadata)
            final_results.append(filepath)
        return final_results

    @staticmethod
    def _remove_temp_files(temp_files) -> None:
        """Remove temporary files, logging (not raising) on individual failures."""
        for temp_file in temp_files:
            if not temp_file or not os.path.exists(temp_file):
                continue
            try:
                os.remove(temp_file)
            except OSError as e:
                # Runs inside ``finally``: raising here would mask the real
                # error and skip the remaining files.
                print(f"Warning: Failed to remove temp file {temp_file}: {e}")
            else:
                print(f"✅ Cleaned up temp file: {temp_file}")

    def run(self, ui_inputs: Dict, progress):
        """Run one generation request end to end.

        Args:
            ui_inputs: UI parameters. This dict is updated in place: prompts
                are sanitized, ``clip_skip`` is stored as a negative number
                and a ``seed`` of ``-1`` is replaced by a random seed.
            progress: Progress callback invoked as ``progress(fraction, desc=...)``.

        Returns:
            A list with one saved PNG file path per generated PIL image (any
            non-image results are passed through unchanged).

        Raises:
            KeyError: If a required key is missing from ``ui_inputs`` or the
                model is not present in ``ALL_MODEL_MAP``.
        """
        progress(0, desc="Preparing models...")

        task_type = ui_inputs['task_type']
        model_display_name = ui_inputs['model_display_name']

        model_type = MODEL_TYPE_MAP.get(model_display_name, 'sdxl')
        architectures_dict = ARCHITECTURES_CONFIG.get('architectures', {})
        workflow_model_type = architectures_dict.get(model_type, {}).get("model_type", model_type.lower().replace(" ", "").replace(".", ""))

        ui_inputs['positive_prompt'] = sanitize_prompt(ui_inputs.get('positive_prompt', ''))
        ui_inputs['negative_prompt'] = sanitize_prompt(ui_inputs.get('negative_prompt', ''))

        if 'clip_skip' in ui_inputs and ui_inputs['clip_skip'] is not None:
            # ``-abs`` keeps the stored value negative even if the same dict
            # goes through ``run`` twice.
            ui_inputs['clip_skip'] = -abs(int(ui_inputs['clip_skip']))
        else:
            ui_inputs['clip_skip'] = -1

        required_models = self.get_required_models(model_display_name=model_display_name)

        is_pid_enabled = (ui_inputs.get('pid_settings', 'OFF') == 'ON' and task_type == 'txt2img')
        if is_pid_enabled:
            # NOTE(review): the companion model is requested only when PiD is
            # enabled — confirm against the original indentation.
            for extra_model in (self._resolve_pid_unet_name(workflow_model_type), self._PID_COMPANION_MODEL):
                if extra_model not in required_models:
                    required_models.append(extra_model)

        self.model_manager.ensure_models_downloaded(required_models, progress=progress)

        temp_files_to_clean = []
        try:
            processed = process_pipeline_inputs(ui_inputs, progress, workflow_model_type)
            temp_files_to_clean.extend(processed["temp_files_to_clean"])

            active_loras_for_gpu = processed["active_loras_for_gpu"]
            active_loras_for_meta = processed["active_loras_for_meta"]
            active_controlnets = processed["active_controlnets"]
            active_anima_controlnets = processed["active_anima_controlnets"]
            active_diffsynth_controlnets = processed["active_diffsynth_controlnets"]
            active_ipadapters = processed["active_ipadapters"]
            active_flux1_ipadapters = processed["active_flux1_ipadapters"]
            active_sd3_ipadapters = processed["active_sd3_ipadapters"]
            active_styles = processed["active_styles"]
            active_reference_latents = processed["active_reference_latents"]
            active_hidream_o1_reference = processed["active_hidream_o1_reference"]
            active_conditioning = processed["active_conditioning"]

            loras_string = f"LoRAs: [{', '.join(active_loras_for_meta)}]" if active_loras_for_meta else ""

            progress(0.8, desc="Assembling workflow...")

            if ui_inputs.get('seed') == -1:
                ui_inputs['seed'] = random.randint(0, 2**32 - 1)

            model_info = ALL_MODEL_MAP[model_display_name]
            path_or_components = model_info[1]
            latent_type = model_info[3] if len(model_info) > 3 and model_info[3] else 'latent'

            # Latent types without a dedicated generator use EmptyLatentImage.
            latent_generator_template = {
                'sd3_latent': "EmptySD3LatentImage",
                'chroma_radiance_latent': "EmptyChromaRadianceLatentImage",
                'hunyuan_latent': "EmptyHunyuanImageLatent",
            }.get(latent_type, "EmptyLatentImage")

            dynamic_values = {
                'task_type': ui_inputs['task_type'],
                'model_type': workflow_model_type,
                'latent_type': latent_type,
                'latent_generator_template': latent_generator_template
            }

            recipe_path = os.path.join(os.path.dirname(__file__), "workflow_recipes", "sd_unified_recipe.yaml")
            assembler = WorkflowAssembler(recipe_path, dynamic_values=dynamic_values)

            # A single empty entry enables the smoothing chain for this model.
            hidream_o1_smoothing_data = []
            if workflow_model_type == 'hidream-o1' and model_display_name == "HiDream-O1-Image":
                hidream_o1_smoothing_data.append({})

            workflow_inputs = {
                **ui_inputs,
                "positive_prompt": ui_inputs['positive_prompt'], "negative_prompt": ui_inputs['negative_prompt'],
                "seed": ui_inputs['seed'], "steps": ui_inputs['num_inference_steps'], "cfg": ui_inputs['guidance_scale'],
                "sampler_name": ui_inputs['sampler'], "scheduler": ui_inputs['scheduler'],
                "batch_size": ui_inputs['batch_size'],
                "clip_skip": ui_inputs['clip_skip'],
                "denoise": ui_inputs['denoise'],
                "vae_name": ui_inputs.get('vae_name'),
                "guidance": ui_inputs.get('guidance', 3.5),
                "lora_chain": active_loras_for_gpu,
                # The regular ControlNet chain is dropped whenever Anima
                # ControlNets are active.
                "controlnet_chain": active_controlnets if not active_anima_controlnets else [],
                "anima_controlnet_lllite_chain": active_anima_controlnets,
                "diffsynth_controlnet_chain": active_diffsynth_controlnets,
                "ipadapter_chain": active_ipadapters,
                "flux1_ipadapter_chain": active_flux1_ipadapters,
                "sd3_ipadapter_chain": active_sd3_ipadapters,
                "style_chain": active_styles,
                "conditioning_chain": active_conditioning,
                "reference_latent_chain": active_reference_latents,
                "hidream_o1_reference_chain": active_hidream_o1_reference,
                "vae_chain": [ui_inputs.get('vae_name')] if ui_inputs.get('vae_name') else [],
                "hidream_o1_smoothing_chain": hidream_o1_smoothing_data,
                "pid_chain": [ui_inputs.get('pid_settings', 'OFF')] if is_pid_enabled else [],
            }

            if isinstance(path_or_components, dict):
                # Component-based model: pass each part by name; a VAE chosen
                # in the UI overrides the model's own VAE.
                workflow_inputs.update({
                    'unet_name': path_or_components.get('unet'),
                    'vae_name': ui_inputs.get('vae_name') or path_or_components.get('vae'),
                    'clip_name': path_or_components.get('clip'),
                    'clip1_name': path_or_components.get('clip1'),
                    'clip2_name': path_or_components.get('clip2'),
                    'clip3_name': path_or_components.get('clip3'),
                    'clip4_name': path_or_components.get('clip4'),
                    'lora_name': path_or_components.get('lora'),
                })
            else:
                workflow_inputs['model_name'] = path_or_components

            if task_type == 'txt2img':
                workflow_inputs['width'] = ui_inputs['width']
                workflow_inputs['height'] = ui_inputs['height']

            workflow = assembler.assemble(workflow_inputs)

            progress(1.0, desc="All models ready. Requesting GPU for generation...")

            results = self._execute_gpu_logic(
                self._gpu_logic,
                duration=ui_inputs['zero_gpu_duration'],
                default_duration=60,
                task_name=f"ImageGen ({task_type})",
                ui_inputs=ui_inputs,
                loras_string=loras_string,
                workflow=workflow,
                assembler=assembler,
                progress=progress
            )

            results = self._save_results_as_png(results, workflow)
        finally:
            self._remove_temp_files(temp_files_to_clean)

        return results