Spaces:
Running on Zero
Running on Zero
| import os | |
| import random | |
| import shutil | |
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| from typing import List, Dict, Any | |
| from .base_pipeline import BasePipeline | |
| from core.settings import * | |
| from utils.app_utils import sanitize_prompt | |
| from core.workflow_assembler import WorkflowAssembler | |
| from .workflow_executor import WorkflowExecutor | |
| from .pipeline_input_processor import process_pipeline_inputs | |
class SdImagePipeline(BasePipeline):
    """Pipeline that assembles an image-generation workflow, runs it, and saves PNGs.

    ``run`` is the entry point: it makes sure the required model files are
    downloaded, turns the UI inputs into a workflow via ``WorkflowAssembler``,
    hands execution to ``_gpu_logic`` through ``self._execute_gpu_logic`` and
    finally writes the resulting images (with embedded metadata) to
    ``OUTPUT_DIR``.
    """

    # PiD UNet file used when pid.yaml yields no usable entry for the architecture.
    _DEFAULT_PID_UNET = "pid_flux1_1024_to_4096_4step_mxfp8.safetensors"
    # Extra model file that is always requested alongside the PiD UNet.
    _PID_COMPANION_MODEL = "gemma_2_2b_it_elm_fp8_scaled.safetensors"
    # Number of previously saved ``gen_*.png`` files kept when pruning the output dir.
    _MAX_KEPT_OUTPUTS = 50

    def get_required_models(self, model_display_name: str, **kwargs) -> List[str]:
        """Return the model identifiers that must be available for a model.

        Args:
            model_display_name: Key looked up in ``ALL_MODEL_MAP``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            For a component-style entry (second element is a dict): every
            truthy component value except the ``"pixel_space"`` placeholder.
            Otherwise (unknown model or single-path entry): a one-element list
            holding ``model_display_name``.
        """
        model_info = ALL_MODEL_MAP.get(model_display_name)
        if not model_info:
            return [model_display_name]

        path_or_components = model_info[1]
        if not isinstance(path_or_components, dict):
            return [model_display_name]

        return [v for v in path_or_components.values() if v and v != "pixel_space"]

    @staticmethod
    def _format_generation_parameters(ui_inputs: Dict, seed: int, model_display_name: str, loras_string: str) -> str:
        """Build the text stored under the PNG ``parameters`` key for one image."""
        width_for_meta = ui_inputs.get('width', 'N/A')
        height_for_meta = ui_inputs.get('height', 'N/A')

        params_string = f"{ui_inputs['positive_prompt']}\nNegative prompt: {ui_inputs['negative_prompt']}\n"
        params_string += f"Steps: {ui_inputs['num_inference_steps']}, Sampler: {ui_inputs['sampler']}, Scheduler: {ui_inputs['scheduler']}, CFG scale: {ui_inputs['guidance_scale']}, Seed: {seed}, Size: {width_for_meta}x{height_for_meta}, Base Model: {model_display_name}"

        if ui_inputs['task_type'] != 'txt2img':
            params_string += f", Denoise: {ui_inputs['denoise']}"

        # ``run`` stores clip_skip negated, so compare the magnitude: a plain
        # ``!= 1`` test is always true for the default value -1 and would
        # record "Clip skip: 1" on every image.
        clip_skip = ui_inputs.get('clip_skip')
        if clip_skip and abs(clip_skip) != 1:
            params_string += f", Clip skip: {abs(clip_skip)}"

        if loras_string:
            params_string += f", {loras_string}"

        return params_string.strip()

    def _gpu_logic(self, ui_inputs: Dict, loras_string: str, workflow: Dict[str, Any], assembler: WorkflowAssembler, progress=gr.Progress(track_tqdm=True)):
        """Execute the assembled workflow and convert its output into PIL images.

        Args:
            ui_inputs: UI values; read for the seed and for the metadata text.
            loras_string: Pre-formatted LoRA summary appended to the metadata
                (empty string when there is nothing to add).
            workflow: Workflow dict passed to ``WorkflowExecutor.execute_workflow``.
            assembler: Accepted for interface compatibility; unused here.
            progress: Progress callback.

        Returns:
            A list of PIL images, one per index along the first axis of the
            executor's result, each with ``info['parameters']`` set.
        """
        model_display_name = ui_inputs['model_display_name']
        progress(0.4, desc="Executing workflow...")

        initial_objects = {}
        decoded_images_tensor = WorkflowExecutor.execute_workflow(workflow, initial_objects=initial_objects)

        # -1 means "no seed chosen"; pick one so the metadata is still concrete.
        start_seed = ui_inputs['seed'] if ui_inputs['seed'] != -1 else random.randint(0, 2**64 - 1)

        output_images = []
        for i in range(decoded_images_tensor.shape[0]):
            img_tensor = decoded_images_tensor[i]
            # NOTE(review): assumes values are floats in [0, 1] — confirm upstream.
            pil_image = Image.fromarray((img_tensor.cpu().numpy() * 255.0).astype("uint8"))
            # Images in a batch are labelled with consecutive seeds.
            pil_image.info = {
                'parameters': self._format_generation_parameters(
                    ui_inputs, start_seed + i, model_display_name, loras_string
                )
            }
            output_images.append(pil_image)
        return output_images

    def _resolve_pid_unet_name(self, workflow_model_type: str) -> str:
        """Look up the PiD UNet file for an architecture in ``yaml/pid.yaml``.

        Best-effort: any error while reading or parsing the config is printed
        and the default file name is returned instead.
        """
        import yaml

        pid_config_path = os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
            'yaml',
            'pid.yaml',
        )
        pid_unet_name = self._DEFAULT_PID_UNET
        try:
            with open(pid_config_path, 'r', encoding='utf-8') as f:
                pid_config = yaml.safe_load(f) or {}
            for item in pid_config.get("PiD", []):
                if workflow_model_type in item.get("architectures", []):
                    # Keep the default when the entry has no usable filepath,
                    # so ``None`` never ends up in the download list.
                    pid_unet_name = item.get("filepath") or pid_unet_name
                    break
        except Exception as e:
            print(f"Error loading PiD config for download: {e}")
        return pid_unet_name

    @classmethod
    def _prune_output_dir(cls, out_dir: str) -> None:
        """Delete the oldest ``gen_*.png`` files so at most the keep-limit remain.

        Failures are reported and otherwise ignored; pruning must not abort a run.
        """
        import glob

        try:
            existing_files = glob.glob(os.path.join(out_dir, "gen_*.png"))
            existing_files.sort(key=os.path.getmtime)
            surplus = max(0, len(existing_files) - cls._MAX_KEPT_OUTPUTS)
            for stale_file in existing_files[:surplus]:
                os.remove(stale_file)
        except Exception as e:
            print(f"Warning: Failed to cleanup output dir: {e}")

    @staticmethod
    def _new_output_path(out_dir: str) -> str:
        """Return a ``gen_<7 digits>.png`` path in ``out_dir`` that does not exist yet."""
        while True:
            filepath = os.path.join(out_dir, f"gen_{random.randint(1000000, 9999999)}.png")
            # Re-draw on collision instead of overwriting an earlier result.
            if not os.path.exists(filepath):
                return filepath

    def _save_results(self, results, workflow: Dict[str, Any]) -> list:
        """Write PIL images to ``OUTPUT_DIR`` as PNG and return their file paths.

        Each PNG gets a ``prompt`` text chunk holding the JSON-encoded workflow
        and, when present on the image, a ``parameters`` chunk. Items that are
        not PIL images are passed through unchanged.
        """
        import json
        from PIL import PngImagePlugin

        prompt_json = json.dumps(workflow)
        out_dir = os.path.abspath(OUTPUT_DIR)
        os.makedirs(out_dir, exist_ok=True)
        self._prune_output_dir(out_dir)

        final_results = []
        for img in results:
            if not isinstance(img, Image.Image):
                final_results.append(img)
                continue

            metadata = PngImagePlugin.PngInfo()
            params_string = img.info.get("parameters", "")
            if params_string:
                metadata.add_text("parameters", params_string)
            metadata.add_text("prompt", prompt_json)

            filepath = self._new_output_path(out_dir)
            img.save(filepath, "PNG", pnginfo=metadata)
            final_results.append(filepath)
        return final_results

    @staticmethod
    def _cleanup_temp_files(temp_files) -> None:
        """Remove temporary files, reporting but never raising on failure.

        This runs from a ``finally`` block, so an error here must neither mask
        the exception being propagated nor stop the remaining files from being
        removed.
        """
        for temp_file in temp_files:
            if not temp_file:
                continue
            try:
                os.remove(temp_file)
            except FileNotFoundError:
                # Already gone: nothing to clean up.
                continue
            except OSError as e:
                print(f"Warning: Failed to remove temp file {temp_file}: {e}")
            else:
                print(f"✅ Cleaned up temp file: {temp_file}")

    def run(self, ui_inputs: Dict, progress):
        """Run one image-generation request end to end.

        Note that ``ui_inputs`` is modified in place: the prompts are
        sanitized, ``clip_skip`` is stored negated (``-1`` when absent or
        ``None``) and a seed of ``-1`` is replaced by a random 32-bit value.

        Args:
            ui_inputs: UI values for the request (task type, model, prompts,
                sampler settings, ...).
            progress: Progress callback, called as ``progress(fraction, desc=...)``.

        Returns:
            A list with one saved PNG path per generated PIL image; items
            returned by the GPU step that are not PIL images are passed through.
        """
        progress(0, desc="Preparing models...")

        task_type = ui_inputs['task_type']
        model_display_name = ui_inputs['model_display_name']

        model_type = MODEL_TYPE_MAP.get(model_display_name, 'sdxl')
        architectures_dict = ARCHITECTURES_CONFIG.get('architectures', {})
        workflow_model_type = architectures_dict.get(model_type, {}).get(
            "model_type", model_type.lower().replace(" ", "").replace(".", "")
        )

        ui_inputs['positive_prompt'] = sanitize_prompt(ui_inputs.get('positive_prompt', ''))
        ui_inputs['negative_prompt'] = sanitize_prompt(ui_inputs.get('negative_prompt', ''))

        if ui_inputs.get('clip_skip') is not None:
            ui_inputs['clip_skip'] = -int(ui_inputs['clip_skip'])
        else:
            ui_inputs['clip_skip'] = -1

        required_models = self.get_required_models(model_display_name=model_display_name)

        # PiD is only honoured for txt2img requests.
        is_pid_enabled = (ui_inputs.get('pid_settings', 'OFF') == 'ON' and task_type == 'txt2img')
        if is_pid_enabled:
            pid_unet_name = self._resolve_pid_unet_name(workflow_model_type)
            for extra_model in (pid_unet_name, self._PID_COMPANION_MODEL):
                if extra_model not in required_models:
                    required_models.append(extra_model)

        self.model_manager.ensure_models_downloaded(required_models, progress=progress)

        temp_files_to_clean = []
        try:
            processed = process_pipeline_inputs(ui_inputs, progress, workflow_model_type)
            temp_files_to_clean.extend(processed["temp_files_to_clean"])

            active_loras_for_gpu = processed["active_loras_for_gpu"]
            active_loras_for_meta = processed["active_loras_for_meta"]
            active_controlnets = processed["active_controlnets"]
            active_anima_controlnets = processed["active_anima_controlnets"]
            active_diffsynth_controlnets = processed["active_diffsynth_controlnets"]
            active_ipadapters = processed["active_ipadapters"]
            active_flux1_ipadapters = processed["active_flux1_ipadapters"]
            active_sd3_ipadapters = processed["active_sd3_ipadapters"]
            active_styles = processed["active_styles"]
            active_reference_latents = processed["active_reference_latents"]
            active_hidream_o1_reference = processed["active_hidream_o1_reference"]
            active_conditioning = processed["active_conditioning"]

            loras_string = f"LoRAs: [{', '.join(active_loras_for_meta)}]" if active_loras_for_meta else ""

            progress(0.8, desc="Assembling workflow...")

            # Fix the seed now so the workflow and the saved metadata agree.
            if ui_inputs.get('seed') == -1:
                ui_inputs['seed'] = random.randint(0, 2**32 - 1)

            model_info = ALL_MODEL_MAP[model_display_name]
            path_or_components = model_info[1]
            latent_type = model_info[3] if len(model_info) > 3 and model_info[3] else 'latent'

            latent_generator_template = {
                'sd3_latent': "EmptySD3LatentImage",
                'chroma_radiance_latent': "EmptyChromaRadianceLatentImage",
                'hunyuan_latent': "EmptyHunyuanImageLatent",
            }.get(latent_type, "EmptyLatentImage")

            dynamic_values = {
                'task_type': ui_inputs['task_type'],
                'model_type': workflow_model_type,
                'latent_type': latent_type,
                'latent_generator_template': latent_generator_template,
            }

            recipe_path = os.path.join(os.path.dirname(__file__), "workflow_recipes", "sd_unified_recipe.yaml")
            assembler = WorkflowAssembler(recipe_path, dynamic_values=dynamic_values)

            # A single empty entry switches the smoothing chain on for this model.
            hidream_o1_smoothing_data = []
            if workflow_model_type == 'hidream-o1' and model_display_name == "HiDream-O1-Image":
                hidream_o1_smoothing_data.append({})

            workflow_inputs = {
                **ui_inputs,
                "positive_prompt": ui_inputs['positive_prompt'],
                "negative_prompt": ui_inputs['negative_prompt'],
                "seed": ui_inputs['seed'],
                "steps": ui_inputs['num_inference_steps'],
                "cfg": ui_inputs['guidance_scale'],
                "sampler_name": ui_inputs['sampler'],
                "scheduler": ui_inputs['scheduler'],
                "batch_size": ui_inputs['batch_size'],
                "clip_skip": ui_inputs['clip_skip'],
                "denoise": ui_inputs['denoise'],
                "vae_name": ui_inputs.get('vae_name'),
                "guidance": ui_inputs.get('guidance', 3.5),
                "lora_chain": active_loras_for_gpu,
                # The regular ControlNet chain is dropped whenever Anima ControlNets are active.
                "controlnet_chain": active_controlnets if not active_anima_controlnets else [],
                "anima_controlnet_lllite_chain": active_anima_controlnets,
                "diffsynth_controlnet_chain": active_diffsynth_controlnets,
                "ipadapter_chain": active_ipadapters,
                "flux1_ipadapter_chain": active_flux1_ipadapters,
                "sd3_ipadapter_chain": active_sd3_ipadapters,
                "style_chain": active_styles,
                "conditioning_chain": active_conditioning,
                "reference_latent_chain": active_reference_latents,
                "hidream_o1_reference_chain": active_hidream_o1_reference,
                "vae_chain": [ui_inputs.get('vae_name')] if ui_inputs.get('vae_name') else [],
                "hidream_o1_smoothing_chain": hidream_o1_smoothing_data,
                "pid_chain": [ui_inputs.get('pid_settings', 'OFF')] if is_pid_enabled else [],
                "scheduler_width": ui_inputs.get('width', 1024),
                "scheduler_height": ui_inputs.get('height', 1024),
            }

            if isinstance(path_or_components, dict):
                workflow_inputs.update({
                    'unet_name': path_or_components.get('unet'),
                    'unet_uncond_name': path_or_components.get('unet_uncond'),
                    # A VAE chosen in the UI takes precedence over the model's own.
                    'vae_name': ui_inputs.get('vae_name') or path_or_components.get('vae'),
                    'clip_name': path_or_components.get('clip'),
                    'clip1_name': path_or_components.get('clip1'),
                    'clip2_name': path_or_components.get('clip2'),
                    'clip3_name': path_or_components.get('clip3'),
                    'clip4_name': path_or_components.get('clip4'),
                    'lora_name': path_or_components.get('lora'),
                })
            else:
                workflow_inputs['model_name'] = path_or_components

            if task_type == 'txt2img':
                workflow_inputs['width'] = ui_inputs['width']
                workflow_inputs['height'] = ui_inputs['height']

            workflow = assembler.assemble(workflow_inputs)

            progress(1.0, desc="All models ready. Requesting GPU for generation...")

            results = self._execute_gpu_logic(
                self._gpu_logic,
                duration=ui_inputs['zero_gpu_duration'],
                default_duration=60,
                task_name=f"ImageGen ({task_type})",
                ui_inputs=ui_inputs,
                loras_string=loras_string,
                workflow=workflow,
                assembler=assembler,
                progress=progress
            )

            results = self._save_results(results, workflow)
        finally:
            self._cleanup_temp_files(temp_files_to_clean)

        return results