"""Core Pipeline orchestrator for LightDiffusion-Next. This module provides the main Pipeline class - a clean, linear orchestrator that coordinates model loading, generation, and post-processing. The Pipeline is designed to be: - Simple: <100 lines of core logic - Modular: Delegates to Models and Processors - Extensible: Easy to add new processing steps Architecture: [Context] -> [Load Model] -> [Encode] -> [Generate] -> [Decode] -> [Processors] -> [Result] """ import logging import os from dataclasses import dataclass, field from typing import Any, Callable, Optional, Union import torch from src.Core.Context import Context from src.Core.Models import create_model from src.Core.AbstractModel import AbstractModel from src.Processors import HiresFix, Adetailer, AutoHDRProcessor logger = logging.getLogger(__name__) @dataclass class PipelineResult: """Result of a pipeline run.""" images: list[torch.Tensor] = field(default_factory=list) latents: Optional[torch.Tensor] = None metadata: dict = field(default_factory=dict) def to_dict(self) -> dict: """Convert to dictionary for legacy compatibility.""" return { "images": self.images, "latents": self.latents, **self.metadata, } class Pipeline: """Main generation pipeline orchestrator. This class coordinates the entire generation flow in a clean, linear manner. Each step is isolated and the Context flows through. Usage: ctx = Context(prompt="a cat", width=512, height=512) pipeline = Pipeline() result = pipeline.run(ctx) """ def __init__( self, model_factory: Callable[[str], AbstractModel] = None, default_lora: Optional[tuple[str, float, float]] = ("add_detail.safetensors", 0.7, 0.7), ): """Initialize the pipeline. Args: model_factory: Function to create models (default: create_model) default_lora: Default LoRA to apply (name, model_str, clip_str) or None """ self.model_factory = model_factory or create_model self.default_lora = default_lora self._model: Optional[AbstractModel] = None def _apply_runtime_preferences(self, ctx: Context, model: AbstractModel) -> None: """Apply request-scoped runtime preferences that should track reused models.""" model.set_vae_autotune(ctx.generation.vae_autotune) def run(self, ctx: Context) -> Context: """Run the full generation pipeline. Args: ctx: Configured Context with all parameters Returns: Context with generated images in current_image """ self._check_interrupt() # 1. Load base model model = self._load_model(ctx) self._apply_runtime_preferences(ctx, model) # 2. Apply optimizations to base model mo = getattr(model, 'model', None) mo_opts = getattr(mo, 'model_options', {}) if mo is not None else {} if not mo_opts.get("model_function_wrapper"): self._apply_optimizations(ctx, model) # 3. Encode prompts for base model positive, negative = self._encode_prompts(ctx, model) ctx.positive_cond = positive ctx.negative_cond = negative # 4. Handle refiner preparation if enabled (SDXL only) refiner_model = None ref_positive, ref_negative = None, None is_sdxl = getattr(model.capabilities, "uses_dual_clip", False) use_refiner = bool( is_sdxl and ctx.generation.refiner_model_path and ctx.generation.refiner_switch_step is not None and 0 < ctx.generation.refiner_switch_step < ctx.sampling.steps ) if use_refiner: print(f"Refiner enabled: {os.path.basename(ctx.generation.refiner_model_path)} (Switch at step {ctx.generation.refiner_switch_step})") # We don't load it yet to save VRAM, but we need to know if we should unload base later # 5. 
        # 5. Generate for each seed
        from src.FileManaging import ImageSaver

        saver = ImageSaver.SaveImage()
        for i, seed in enumerate(ctx.seeds[:ctx.generation.number]):
            self._check_interrupt()
            ctx.seed = seed

            # Stage 1: Base model generation
            if use_refiner:
                steps_for_base = ctx.generation.refiner_switch_step
                print(f"Stage 1: Running Base model ({steps_for_base}/{ctx.sampling.steps} steps)...")
                latents = model.generate(
                    ctx,
                    positive,
                    negative,
                    last_step=ctx.generation.refiner_switch_step,
                    callback=ctx.callback,
                )
            else:
                latents = model.generate(ctx, positive, negative, callback=ctx.callback)
            ctx.current_latents = latents["samples"]

            # Stage 2: Refiner model generation
            if use_refiner:
                self._check_interrupt()

                # Load the refiner model (this will unload the base model if necessary)
                refiner_model = self._load_refiner_model(ctx)
                self._apply_optimizations(ctx, refiner_model)

                # Encode prompts for the refiner (it has a different CLIP)
                ref_positive, ref_negative = self._encode_prompts(ctx, refiner_model)

                # Always disable multi-scale for the refiner pass
                orig_ms = ctx.sampling.enable_multiscale
                ctx.sampling.enable_multiscale = False

                steps_for_refiner = ctx.sampling.steps - ctx.generation.refiner_switch_step
                print(f"Stage 2: Running Refiner model ({steps_for_refiner}/{ctx.sampling.steps} steps)...")
                latents = refiner_model.generate(
                    ctx,
                    ref_positive,
                    ref_negative,
                    latent_image=latents,
                    start_step=ctx.generation.refiner_switch_step,
                    disable_noise=True,
                    callback=ctx.callback,
                )
                ctx.current_latents = latents["samples"]
                ctx.sampling.enable_multiscale = orig_ms

                # If we have more seeds, the base model will need to be reloaded
                # on the next iteration; _load_model handles this automatically.

            # Decode latents to image
            ctx.current_image = model.decode(ctx.current_latents)
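            # Latent/image contract (a sketch inferred from the tensor handling
            # in this module, not a verified spec): latents travel as
            # {"samples": tensor} dicts, and decoded images are float tensors
            # in [0, 1] with [B, H, W, C] layout, mirroring the img2img input
            # conversion further down.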
            # 6. Post-processing
            # Apply HiresFix if enabled. Prefer running the hires pass with the
            # base model and base prompts for consistency; using a refiner for
            # the hires pass can introduce artifacts because its UNet/CLIP can
            # differ from the base model.
            current_model = model
            # Prefer base prompts for the hires pass (refiner prompts tend to mismatch)
            hf_pos = positive
            hf_neg = negative

            if HiresFix.is_enabled(ctx):
                self._check_interrupt()
                logger.info(f"HiresFix: using base model for hires pass (use_refiner={use_refiner})")
                # If a refiner was used earlier we may have unloaded the base
                # model to free VRAM. Ensure the base model is reloaded and
                # optimized before running the hires pass so downstream code
                # (sampler / CFGGuider) can access model.model_options etc.
                if use_refiner and (not model.is_loaded or getattr(model, "model", None) is None):
                    logger.info("HiresFix: reloading base model for hires pass (was unloaded by refiner)")
                    model = self._load_model(ctx)
                    # Re-apply optimizations (LoRA / StableFast / FP8 / DeepCache) to the reloaded model
                    self._apply_optimizations(ctx, model)
                    # Re-encode prompts for the reloaded base model so the conditioning matches
                    try:
                        hf_pos, hf_neg = self._encode_prompts(ctx, model)
                    except Exception:
                        # Keep the previously-encoded conditioning if re-encoding fails
                        pass
                    current_model = model

                # HiresFix might still need base model prompts if it was trained on them
                latents = HiresFix.apply(latents, ctx, current_model, hf_pos, hf_neg, callback=ctx.callback)
                ctx.current_latents = latents["samples"]

            if AutoHDRProcessor.is_enabled(ctx):
                self._check_interrupt()
                ctx.current_image = AutoHDRProcessor.apply(ctx.current_image, ctx)

            # Apply Adetailer if enabled (handles its own saving)
            if Adetailer.is_enabled(ctx):
                self._check_interrupt()
                if use_refiner:
                    # Reload the base model for ADetailer - the refiner's UNet/CLIP
                    # is not suited for text-guided crop enhancement
                    ad_model = self._load_model(ctx)
                    ad_pos, ad_neg = self._encode_prompts(ctx, ad_model)
                    ctx.current_image, _ = Adetailer.apply(
                        ctx.current_image,
                        ctx,
                        ad_model,
                        positive=ad_pos,
                        negative=ad_neg,
                        callback=ctx.callback,
                    )
                else:
                    ctx.current_image, _ = Adetailer.apply(
                        ctx.current_image,
                        ctx,
                        current_model,
                        positive=hf_pos,
                        negative=hf_neg,
                        callback=ctx.callback,
                    )
            else:
                # Save the image synchronously so the server can reliably find it
                prefix = "LD-HF" if ctx.features.hires_fix else "LD"
                filename_prefix = (
                    f"{ctx.features.request_filename_prefix}_{prefix}"
                    if ctx.features.request_filename_prefix
                    else prefix
                )
                images = ctx.current_image if isinstance(ctx.current_image, list) else [ctx.current_image]
                saver.save_images(
                    images,
                    filename_prefix=filename_prefix,
                    prompt=str(ctx.prompt),
                    extra_pnginfo=ctx.build_metadata(),
                    store_bytes_prefix=ctx.features.request_filename_prefix,
                )

        ctx.save_seed()
        return ctx
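    # The img2img entry point below picks between two paths with a 10%
    # tolerance. Worked example (hypothetical sizes): with a 512x512 input,
    # a 1024x1024 target exceeds 512 * 1.1 = 563 on both axes and takes the
    # USDU upscale path; a target at or below 563 on both axes falls through
    # to true diffusion img2img, where img2img_denoise controls how much of
    # the input survives.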
    def run_img2img(self, ctx: Context) -> Context:
        """Run image-to-image generation pipeline.

        Supports two modes:
        1. Upscale mode: When target dimensions are larger than input (uses USDU)
        2. Diffusion mode: True img2img with denoising strength (uses simple_img2img)

        Args:
            ctx: Context with img2img_image set

        Returns:
            Context with generated images
        """
        from src.Processors import Img2Img
        from src.FileManaging import ImageSaver
        from PIL import Image
        import numpy as np

        self._check_interrupt()

        model = self._load_model(ctx)
        self._apply_optimizations(ctx, model)
        positive, negative = self._encode_prompts(ctx, model)
        saver = ImageSaver.SaveImage()

        # Load the input image to determine the mode
        img_path = ctx.features.img2img_image
        if not img_path:
            raise ValueError("No input image provided for img2img")
        img = Image.open(img_path)
        input_w, input_h = img.size
        target_w, target_h = ctx.generation.width, ctx.generation.height

        # Convert the image to a tensor [B, H, W, C]
        img_array = np.array(img.convert("RGB"))
        img_tensor = torch.from_numpy(img_array).float().cpu() / 255.0
        if img_tensor.dim() == 3:
            img_tensor = img_tensor.unsqueeze(0)

        # Determine the mode: upscale if the target is larger, otherwise diffusion
        use_upscale = (target_w > input_w * 1.1) or (target_h > input_h * 1.1)
        denoise = ctx.features.img2img_denoise

        # Inject SDXL size conditioning if required
        if getattr(model.capabilities, "requires_size_conditioning", False):
            for cond_list in [positive, negative]:
                for cond_item in cond_list:
                    if len(cond_item) > 1 and isinstance(cond_item[1], dict):
                        cond_item[1].update({
                            "width": target_w,
                            "height": target_h,
                            "crop_w": 0,
                            "crop_h": 0,
                            "target_width": target_w,
                            "target_height": target_h,
                        })

        logger.info(
            f"Img2Img: input={input_w}x{input_h}, target={target_w}x{target_h}, "
            f"denoise={denoise:.2f}, mode={'upscale' if use_upscale else 'diffusion'}"
        )

        for seed in ctx.seeds[:ctx.generation.number]:
            self._check_interrupt()
            ctx.seed = seed

            if use_upscale:
                # Use the USDU upscaler (existing behavior) with a higher LoRA
                # strength for img2img upscaling
                if self.default_lora and getattr(model.capabilities, "supports_lora", True):
                    try:
                        model.apply_lora(self.default_lora[0], 2.0, 2.0)
                    except Exception as e:
                        logger.warning(f"LoRA failed: {e}")
                result = Img2Img.apply(
                    ctx,
                    model,
                    positive,
                    negative,
                    image_tensor=img_tensor,
                    denoise=denoise,
                    callback=ctx.callback,
                )
                ctx.current_image = result
            else:
                # True diffusion-based img2img with denoising strength.
                # Resize the input image to the target dimensions if they differ.
                if input_w != target_w or input_h != target_h:
                    resized_img = img.resize((target_w, target_h), Image.Resampling.LANCZOS)
                    img_array = np.array(resized_img.convert("RGB"))
                    img_tensor = torch.from_numpy(img_array).float().cpu() / 255.0
                    if img_tensor.dim() == 3:
                        img_tensor = img_tensor.unsqueeze(0)

                # Check whether the refiner is enabled BEFORE running the base model (SDXL only)
                is_sdxl = getattr(model.capabilities, "uses_dual_clip", False)
                use_refiner = bool(
                    is_sdxl
                    and ctx.generation.refiner_model_path
                    and ctx.generation.refiner_switch_step is not None
                    and 0 < ctx.generation.refiner_switch_step < ctx.sampling.steps
                )
                refiner_model = None
                ref_negative = None
                base_last_step = ctx.generation.refiner_switch_step if use_refiner else None
                if use_refiner:
                    print(f"Stage 1: Running Base model ({ctx.generation.refiner_switch_step}/{ctx.sampling.steps} steps)...")

                # Run simple_img2img for true diffusion-based generation
                latents = Img2Img.simple_img2img(
                    ctx,
                    model,
                    positive,
                    negative,
                    image_tensor=img_tensor,
                    denoise=denoise,
                    last_step=base_last_step,
                    callback=ctx.callback,
                )
                ctx.current_latents = latents["samples"]
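                # Denoise semantics (standard img2img convention; the exact
                # noise scheduling lives in Img2Img.simple_img2img): low values
                # such as 0.3 keep most of the input structure, while values
                # approaching 1.0 behave like a full regeneration from noise.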
                # Apply the refiner if enabled
                if use_refiner:
                    self._check_interrupt()

                    # Load the refiner model
                    refiner_model = self._load_refiner_model(ctx)
                    self._apply_optimizations(ctx, refiner_model)

                    # Encode prompts for the refiner (it has a different CLIP)
                    ref_positive, ref_negative = self._encode_prompts(ctx, refiner_model)

                    # Disable multi-scale for the refiner pass
                    orig_ms = ctx.sampling.enable_multiscale
                    ctx.sampling.enable_multiscale = False

                    steps_for_refiner = ctx.sampling.steps - ctx.generation.refiner_switch_step
                    print(f"Img2Img Refiner: Running {steps_for_refiner}/{ctx.sampling.steps} steps...")
                    refiner_latents = refiner_model.generate(
                        ctx,
                        ref_positive,
                        ref_negative,
                        latent_image=latents,
                        start_step=ctx.generation.refiner_switch_step,
                        disable_noise=True,
                        callback=ctx.callback,
                    )
                    ctx.current_latents = refiner_latents["samples"]
                    ctx.sampling.enable_multiscale = orig_ms

                    # Decode using the refiner's VAE
                    image = refiner_model.decode(ctx.current_latents)
                else:
                    # Decode to image using the base model
                    image = model.decode(ctx.current_latents)
                ctx.current_image = image

            # Apply Adetailer if enabled
            from src.Processors import Adetailer
            if Adetailer.is_enabled(ctx):
                self._check_interrupt()
                if not use_upscale and use_refiner:
                    # Reload the base model for ADetailer - the refiner's UNet/CLIP
                    # is not suited for text-guided crop enhancement
                    ad_model = self._load_model(ctx)
                    ad_pos, ad_neg = self._encode_prompts(ctx, ad_model)
                    ctx.current_image, _ = Adetailer.apply(
                        ctx.current_image,
                        ctx,
                        ad_model,
                        positive=ad_pos,
                        negative=ad_neg,
                        callback=ctx.callback,
                    )
                else:
                    ctx.current_image, _ = Adetailer.apply(
                        ctx.current_image,
                        ctx,
                        model,
                        positive=positive,
                        negative=negative,
                        callback=ctx.callback,
                    )

            # Apply AutoHDR if enabled
            if AutoHDRProcessor.is_enabled(ctx):
                ctx.current_image = AutoHDRProcessor.apply(ctx.current_image, ctx)

            # Save the image with metadata, including the denoise value
            filename_prefix = "LD-I2I"
            if ctx.features.request_filename_prefix:
                filename_prefix = f"{ctx.features.request_filename_prefix}_{filename_prefix}"
            images = ctx.current_image if isinstance(ctx.current_image, list) else [ctx.current_image]
            saver.save_images(
                images,
                filename_prefix=filename_prefix,
                prompt=str(ctx.prompt),
                extra_pnginfo=ctx.build_metadata({
                    "img2img": "True",
                    "img2img_denoise": str(denoise),
                    "img2img_mode": "upscale" if use_upscale else "diffusion",
                }),
                store_bytes_prefix=ctx.features.request_filename_prefix,
            )

        ctx.save_seed()
        return ctx
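    # Minimal usage sketch for the img2img path (field names as used above;
    # concrete values are hypothetical):
    #
    #   ctx = Context(prompt="a watercolor house", width=768, height=768)
    #   ctx.features.img2img_image = "input.png"
    #   ctx.features.img2img_denoise = 0.55
    #   ctx = get_default_pipeline().run_img2img(ctx)
    #   # the decoded result is now in ctx.current_image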
    def run_controlnet(self, ctx: Context) -> Context:
        """Run ControlNet-style generation using Canny edges + img2img.

        This uses edge detection to preserve structure while allowing color
        and content changes via high-denoise img2img.

        Args:
            ctx: Context with controlnet_model, img2img_image set

        Returns:
            Context with generated images
        """
        from src.Processors import ControlNet as CNProcessor
        from src.FileManaging import ImageSaver
        from PIL import Image
        import numpy as np

        self._check_interrupt()

        # Validate inputs
        if not ctx.features.img2img_image:
            raise ValueError("No input image provided for ControlNet")

        model = self._load_model(ctx)
        self._apply_optimizations(ctx, model)

        # Load and preprocess the input image
        img_path = ctx.features.img2img_image
        img = Image.open(img_path)
        img = img.resize((ctx.generation.width, ctx.generation.height), Image.Resampling.LANCZOS)

        # Convert to a tensor [B, H, W, C]
        img_array = np.array(img.convert("RGB"))
        img_tensor = torch.from_numpy(img_array).float().cpu() / 255.0
        if img_tensor.dim() == 3:
            img_tensor = img_tensor.unsqueeze(0)

        # Apply the preprocessor (Canny edge detection by default)
        control_image = CNProcessor.ControlNetProcessor.preprocess_image(
            img_tensor,
            preprocessor=ctx.features.controlnet_type,
        )
        strength = ctx.features.controlnet_strength
        logger.info(f"ControlNet-style: {ctx.features.controlnet_type} edges, strength={strength}")

        # Encode prompts
        positive, negative = self._encode_prompts(ctx, model)
        saver = ImageSaver.SaveImage()
        is_flux2 = getattr(model.capabilities, "is_flux2", False)

        # Check whether the refiner is enabled (SDXL only)
        is_sdxl = getattr(model.capabilities, "uses_dual_clip", False)
        use_refiner = bool(
            is_sdxl
            and ctx.generation.refiner_model_path
            and ctx.generation.refiner_switch_step is not None
            and 0 < ctx.generation.refiner_switch_step < ctx.sampling.steps
        )
        refiner_model = None
        ref_negative = None
        if use_refiner:
            print(
                f"Refiner enabled for ControlNet: {os.path.basename(ctx.generation.refiner_model_path)} "
                f"(Switch at step {ctx.generation.refiner_switch_step})"
            )

        for seed in ctx.seeds[:ctx.generation.number]:
            self._check_interrupt()
            ctx.seed = seed

            # Use the Canny+img2img approach, passing the original image for
            # blending. When the refiner is enabled, stop the base model at the
            # refiner switch step.
            base_last_step = ctx.generation.refiner_switch_step if use_refiner else None
            if use_refiner:
                print(f"Stage 1: Running Base model ({ctx.generation.refiner_switch_step}/{ctx.sampling.steps} steps)...")

            latents, ctx = CNProcessor.apply_controlnet_to_img2img(
                ctx,
                model,
                positive,
                negative,
                control_image=control_image,
                strength=strength,
                original_image=img_tensor,
                last_step=base_last_step,
                callback=ctx.callback,
            )
            ctx.current_latents = latents["samples"]

            # Apply the refiner if enabled
            if use_refiner:
                self._check_interrupt()

                # Load the refiner model
                refiner_model = self._load_refiner_model(ctx)
                self._apply_optimizations(ctx, refiner_model)

                # Encode prompts for the refiner (it has a different CLIP)
                ref_positive, ref_negative = self._encode_prompts(ctx, refiner_model)

                # Disable multi-scale for the refiner pass
                orig_ms = ctx.sampling.enable_multiscale
                ctx.sampling.enable_multiscale = False

                steps_for_refiner = ctx.sampling.steps - ctx.generation.refiner_switch_step
                print(f"ControlNet Refiner: Running {steps_for_refiner}/{ctx.sampling.steps} steps...")
                refiner_latents = refiner_model.generate(
                    ctx,
                    ref_positive,
                    ref_negative,
                    latent_image=latents,
                    start_step=ctx.generation.refiner_switch_step,
                    disable_noise=True,
                    callback=ctx.callback,
                )
                ctx.current_latents = refiner_latents["samples"]
                ctx.sampling.enable_multiscale = orig_ms

                # Decode using the refiner's VAE
                image = refiner_model.decode(ctx.current_latents)
            else:
                # Decode to image using the base model
                image = model.decode(ctx.current_latents)
            ctx.current_image = image
            # Apply Adetailer if enabled
            from src.Processors import Adetailer
            if Adetailer.is_enabled(ctx):
                self._check_interrupt()
                if use_refiner:
                    # Reload the base model for ADetailer - the refiner's UNet/CLIP
                    # is not suited for text-guided crop enhancement
                    ad_model = self._load_model(ctx)
                    ad_pos, ad_neg = self._encode_prompts(ctx, ad_model)
                    ctx.current_image, _ = Adetailer.apply(
                        ctx.current_image,
                        ctx,
                        ad_model,
                        positive=ad_pos,
                        negative=ad_neg,
                        callback=ctx.callback,
                    )
                else:
                    ctx.current_image, _ = Adetailer.apply(
                        ctx.current_image,
                        ctx,
                        model,
                        positive=positive,
                        negative=negative,
                        callback=ctx.callback,
                    )

            # Apply AutoHDR if enabled
            if AutoHDRProcessor.is_enabled(ctx):
                ctx.current_image = AutoHDRProcessor.apply(ctx.current_image, ctx)

            # Save with metadata
            filename_prefix = "LD-CN"
            if ctx.features.request_filename_prefix:
                filename_prefix = f"{ctx.features.request_filename_prefix}_{filename_prefix}"
            images = ctx.current_image if isinstance(ctx.current_image, list) else [ctx.current_image]
            saver.save_images(
                images,
                filename_prefix=filename_prefix,
                prompt=str(ctx.prompt),
                extra_pnginfo=ctx.build_metadata({
                    "controlnet_style": "True",
                    "controlnet_strength": str(strength),
                    "controlnet_type": ctx.features.controlnet_type,
                }),
                store_bytes_prefix=ctx.features.request_filename_prefix,
            )

        ctx.save_seed()
        return ctx

    def run_batched(self, ctx: Context, per_sample_info: list = None) -> dict:
        """Run batched multi-prompt generation.

        Args:
            ctx: Context with a list of prompts
            per_sample_info: Per-sample overrides

        Returns:
            Dictionary mapping request_ids to results
        """
        import uuid
        from src.FileManaging import ImageSaver
        from src.Utilities import Latent
        from src.sample import sampling
        from src.hidiffusion import msw_msa_attention
        from src.Processors import Img2Img

        self._check_interrupt()

        prompts = list(ctx.prompt)
        total_batch = len(prompts)
        per_sample_info = per_sample_info or [{} for _ in range(total_batch)]

        # Set up negatives
        if isinstance(ctx.negative_prompt, (list, tuple)):
            negatives = list(ctx.negative_prompt)
        else:
            negatives = [ctx.negative_prompt] * total_batch

        model = self._load_model(ctx)
        self._apply_optimizations(ctx, model)

        # Encode all prompts
        positive, negative = model.encode_prompt(prompts, negatives)

        # Add batch routing so positive and negative conditioning stay aligned
        for cond_list in (positive, negative):
            if isinstance(cond_list, list):
                for i, entry in enumerate(cond_list):
                    if len(entry) > 1 and isinstance(entry[1], dict):
                        entry[1]["batch_index"] = [i]

        # Determine latent channels (SD1.5/SDXL=4, SD3/Flux1=16, Flux2=32)
        latent_channels = 4
        try:
            lf = model.get_model_object("latent_format")
            if lf and hasattr(lf, "latent_channels"):
                latent_channels = lf.latent_channels
        except Exception:
            pass

        # Architecture flags for the sampler
        is_flux = getattr(model.capabilities, "is_flux", False) or (latent_channels == 16)
        is_flux2 = getattr(model.capabilities, "is_flux2", False) or (latent_channels == 32)

        # Generate all latents with the correct channel count
        latent_gen = Latent.EmptyLatentImage()
        latent = latent_gen.generate(ctx.width, ctx.height, total_batch, channels=latent_channels)[0]
        latent["seeds"] = ctx.seeds[:total_batch]
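        # Data shapes at this point (a sketch; only the keys used in this file
        # are guaranteed): each conditioning entry is [tensor, options_dict]
        # with "batch_index" routing in options_dict, and the latent dict holds
        # "samples" (typically a [B, C, H/8, W/8] tensor for latent-space
        # models) plus the per-sample "seeds" list attached above.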
        # Apply HiDiffusion (multiscale) if enabled.
        # CRITICAL: HiDiffusion MSW-MSA is for UNet (SD1.5/SDXL) only.
        # DiT models like Flux will suffer from tiling artifacts if patched.
        is_flux_or_flux2 = is_flux or is_flux2
        if ctx.sampling.enable_multiscale and not is_flux_or_flux2:
            try:
                # Clone the model before patching to avoid persistent state across batches
                base_inner = getattr(model, "model", model)
                patch_model = base_inner.clone() if hasattr(base_inner, "clone") else base_inner
                hidiff = msw_msa_attention.ApplyMSWMSAAttentionSimple()
                opt_model = hidiff.go(model_type="auto", model=patch_model)[0]
                if not hasattr(opt_model, "get_model_object") and hasattr(model, "get_model_object"):
                    opt_model.get_model_object = model.get_model_object
                if not hasattr(opt_model, "load_device") and hasattr(model, "load_device"):
                    opt_model.load_device = model.load_device
            except Exception as e:
                logger.warning(f"Failed to apply HiDiffusion: {e}")
                opt_model = model
        else:
            if ctx.sampling.enable_multiscale and is_flux_or_flux2:
                logger.info("HiDiffusion disabled: not compatible with Flux architecture")
            opt_model = model

        # Determine whether the refiner is enabled (SDXL only)
        is_sdxl = getattr(model.capabilities, "uses_dual_clip", False)
        use_refiner = bool(
            is_sdxl
            and ctx.generation.refiner_model_path
            and ctx.generation.refiner_switch_step is not None
            and 0 < ctx.generation.refiner_switch_step < ctx.sampling.steps
        )

        ksampler = sampling.KSampler()

        # Distilled Flux2 Klein safety defaults: these models are extremely
        # sensitive to CFG > 1.2 and work best with specific samplers
        if is_flux2:
            if ctx.sampling.cfg > 1.2:
                logger.info(f"Flux2 Klein detected: capping CFG from {ctx.sampling.cfg} to 1.0 for distilled quality")
                ctx.sampling.cfg = 1.0
            if ctx.sampling.sampler not in ["euler", "euler_ancestral", "dpmpp_2m", "dpmpp_sde", "uni_pc"]:
                logger.info("Flux2 Klein detected: switching sampler to 'euler' for compatibility")
                ctx.sampling.sampler = "euler"

        batched_img2img_tensor = None
        batched_img2img_denoise = ctx.features.img2img_denoise
        if ctx.features.img2img and ctx.features.img2img_image:
            from PIL import Image
            import numpy as np

            input_image = Image.open(ctx.features.img2img_image).convert("RGB")
            target_size = (ctx.generation.width, ctx.generation.height)
            if input_image.size != target_size:
                input_image = input_image.resize(target_size, Image.Resampling.LANCZOS)
            input_array = np.array(input_image)
            batched_img2img_tensor = torch.from_numpy(input_array).float().cpu() / 255.0
            batched_img2img_tensor = batched_img2img_tensor.unsqueeze(0).repeat(total_batch, 1, 1, 1)

        if getattr(model.capabilities, "requires_size_conditioning", False):
            for cond_list in (positive, negative):
                for cond_item in cond_list:
                    if len(cond_item) > 1 and isinstance(cond_item[1], dict):
                        cond_item[1].update({
                            "width": ctx.generation.width,
                            "height": ctx.generation.height,
                            "crop_w": 0,
                            "crop_h": 0,
                            "target_width": ctx.generation.width,
                            "target_height": ctx.generation.height,
                        })

        if use_refiner:
            print(
                f"Batched Refiner enabled: {os.path.basename(ctx.generation.refiner_model_path)} "
                f"(Switch at step {ctx.generation.refiner_switch_step})"
            )

            # Stage 1: Base model generation
            print(f"Stage 1: Running Base model ({ctx.generation.refiner_switch_step}/{ctx.sampling.steps} steps)...")
            if batched_img2img_tensor is not None:
                batch_latents = (
                    Img2Img.simple_img2img(
                        ctx,
                        model,
                        positive,
                        negative,
                        image_tensor=batched_img2img_tensor,
                        denoise=batched_img2img_denoise,
                        last_step=ctx.generation.refiner_switch_step,
                        callback=ctx.callback,
                    ),
                )
            else:
                batch_latents = ksampler.sample(
                    seed=None,
                    steps=ctx.sampling.steps,
                    cfg=ctx.sampling.cfg,
                    sampler_name=ctx.sampling.sampler,
                    scheduler=ctx.sampling.scheduler,
                    denoise=1.0,
                    pipeline=True,
                    model=opt_model,
                    positive=positive,
                    negative=negative,
                    latent_image=latent,
                    last_step=ctx.generation.refiner_switch_step,
                    enable_multiscale=ctx.sampling.enable_multiscale,
                    multiscale_factor=ctx.sampling.multiscale_factor,
                    multiscale_fullres_start=ctx.sampling.multiscale_fullres_start,
                    multiscale_fullres_end=ctx.sampling.multiscale_fullres_end,
                    cfg_free_enabled=ctx.sampling.cfg_free_enabled,
                    cfg_free_start_percent=ctx.sampling.cfg_free_start_percent,
                    flux=is_flux,
                    flux2=is_flux2,
                    callback=ctx.callback,
                )
            self._check_interrupt()

            # Stage 2: Refiner model generation.
            # Explicitly clear Stage 1 objects to free VRAM for the refiner.
            import gc
            if "opt_model" in locals():
                del opt_model
            if "positive" in locals():
                del positive
            if "negative" in locals():
                del negative
            # CRITICAL: the local variable 'model' still holds the Base model.
            # We must unload it and delete the reference so the refcount hits 0.
            if "model" in locals() and model is not None:
                model.unload()
                del model
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()

            refiner_model = self._load_refiner_model(ctx)

            # Skip optimizations if already applied (check model_function_wrapper)
            mo = getattr(refiner_model, "model", None)
            mo_opts = getattr(mo, "model_options", {}) if mo is not None else {}
            if not mo_opts.get("model_function_wrapper"):
                self._apply_optimizations(ctx, refiner_model)

            # Encode prompts for the refiner
            ref_positive, ref_negative = refiner_model.encode_prompt(prompts, negatives)

            # Re-apply batch routing to the refiner conditioning if needed
            if isinstance(ref_positive, list):
                for i, entry in enumerate(ref_positive):
                    if len(entry) > 1 and isinstance(entry[1], dict):
                        entry[1]["batch_index"] = [i]

            # Apply resolution conditioning for the SDXL refiner if required
            if getattr(refiner_model.capabilities, "requires_size_conditioning", False):
                for cond_list in [ref_positive, ref_negative]:
                    for cond_item in cond_list:
                        if len(cond_item) > 1 and isinstance(cond_item[1], dict):
                            cond_item[1].update({
                                "width": ctx.generation.width,
                                "height": ctx.generation.height,
                                "crop_w": 0,
                                "crop_h": 0,
                                "target_width": ctx.generation.width,
                                "target_height": ctx.generation.height,
                            })

            # HiDiffusion optimization for the refiner: NEVER use multi-scale
            # for the refiner pass
            opt_refy = getattr(refiner_model, "model", refiner_model)

            # Disable multi-scale for the refiner pass
            orig_ms = ctx.sampling.enable_multiscale
            ctx.sampling.enable_multiscale = False

            steps_for_refiner = ctx.sampling.steps - ctx.generation.refiner_switch_step
            print(f"Stage 2: Running Refiner model ({steps_for_refiner}/{ctx.sampling.steps} steps)...")
            batch_latents = ksampler.sample(
                seed=None,
                steps=ctx.sampling.steps,
                cfg=ctx.sampling.cfg,
                sampler_name=ctx.sampling.sampler,
                scheduler=ctx.sampling.scheduler,
                denoise=1.0,
                pipeline=True,
                model=opt_refy,
                positive=ref_positive,
                negative=ref_negative,
                latent_image=batch_latents[0],
                start_step=ctx.generation.refiner_switch_step,
                disable_noise=True,
                callback=ctx.callback,
                cfg_free_enabled=ctx.sampling.cfg_free_enabled,
                cfg_free_start_percent=ctx.sampling.cfg_free_start_percent,
            )
            ctx.sampling.enable_multiscale = orig_ms

            # Use the refiner for decoding
            model = refiner_model
        else:
            # Normal single-stage generation
            if batched_img2img_tensor is not None:
                batch_latents = (
                    Img2Img.simple_img2img(
                        ctx,
                        model,
                        positive,
                        negative,
                        image_tensor=batched_img2img_tensor,
                        denoise=batched_img2img_denoise,
                        callback=ctx.callback,
                    ),
                )
            else:
                batch_latents = ksampler.sample(
                    seed=None,
                    steps=ctx.sampling.steps,
                    cfg=ctx.sampling.cfg,
                    sampler_name=ctx.sampling.sampler,
                    scheduler=ctx.sampling.scheduler,
                    denoise=1.0,
                    pipeline=True,
                    model=opt_model,
                    positive=positive,
                    negative=negative,
                    latent_image=latent,
                    enable_multiscale=ctx.sampling.enable_multiscale,
                    multiscale_factor=ctx.sampling.multiscale_factor,
                    multiscale_fullres_start=ctx.sampling.multiscale_fullres_start,
                    multiscale_fullres_end=ctx.sampling.multiscale_fullres_end,
                    cfg_free_enabled=ctx.sampling.cfg_free_enabled,
                    cfg_free_start_percent=ctx.sampling.cfg_free_start_percent,
                    flux=is_flux,
                    flux2=is_flux2,
                    callback=ctx.callback,
                )

        # Hires/Adetailer prompts - use the refiner prompts if the refiner was used
        if use_refiner:
            hf_pos = ref_positive
            hf_neg = ref_negative
        else:
            hf_pos = positive
            hf_neg = negative

        # Decode all
        images = model.decode(batch_latents[0]["samples"])

        if AutoHDRProcessor.is_enabled(ctx):
            images = AutoHDRProcessor.apply(images, ctx)

        # If the refiner was used, reload the base model for ADetailer.
        # The refiner's UNet/CLIP is optimized for short refinement passes,
        # not for the text-guided crop enhancement that ADetailer performs.
        ad_model = model
        ad_pos = hf_pos
        ad_neg = hf_neg
        if use_refiner:
            needs_adetailer = any(
                (per_sample_info[j] if j < len(per_sample_info) else {}).get("adetailer", False)
                for j in range(total_batch)
            )
            if needs_adetailer:
                ad_model = self._load_model(ctx)
                self._apply_optimizations(ctx, ad_model)
                ad_pos, ad_neg = ad_model.encode_prompt(prompts, negatives)
                if isinstance(ad_pos, list):
                    for idx, entry in enumerate(ad_pos):
                        if len(entry) > 1 and isinstance(entry[1], dict):
                            entry[1]["batch_index"] = [idx]
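        # Shape of per_sample_info as read below (all keys optional; the
        # values shown are hypothetical):
        #   [{"request_id": "ab12cd34", "filename_prefix": "LD-REQ-ab12cd34",
        #     "hires_fix": True, "adetailer": False}, ...]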
        # Process individually
        saver = ImageSaver.SaveImage()
        results = {}
        for i in range(total_batch):
            self._check_interrupt()
            info = per_sample_info[i] if i < len(per_sample_info) else {}
            req_id = info.get("request_id", uuid.uuid4().hex[:8])
            prefix = info.get("filename_prefix", f"LD-REQ-{req_id}")
            final = images[i]

            # Per-sample HiresFix
            if info.get("hires_fix", False):
                try:
                    single_latent = {"samples": batch_latents[0]["samples"][i:i+1]}
                    single_ctx = ctx.clone()
                    single_ctx.seed = ctx.seeds[i] if i < len(ctx.seeds) else ctx.seed

                    # Default to the currently-loaded model (may be the refiner)
                    hires_model = model
                    hires_pos = [hf_pos[i]] if isinstance(hf_pos, list) else hf_pos
                    hires_neg = [hf_neg[i]] if isinstance(hf_neg, list) else hf_neg

                    # If a refiner was used, prefer reloading the base model for
                    # the hires pass. Attempt to reload + optimize the base model
                    # and re-encode the single-sample prompts; fall back to the
                    # existing behavior on any failure.
                    if use_refiner:
                        try:
                            base_model = self._load_model(ctx)
                            self._apply_optimizations(ctx, base_model)
                            # Re-encode only the single sample for the reloaded base model
                            single_pos, single_neg = base_model.encode_prompt([prompts[i]], [negatives[i]])
                            if isinstance(single_pos, list):
                                single_pos = single_pos[0]
                                single_neg = single_neg[0]
                            hires_model = base_model
                            hires_pos = [single_pos] if isinstance(hf_pos, list) else single_pos
                            hires_neg = [single_neg] if isinstance(hf_neg, list) else single_neg
                        except Exception:
                            # If the reload/encode fails, continue with the previously-loaded model
                            hires_model = model
                            hires_pos = [hf_pos[i]] if isinstance(hf_pos, list) else hf_pos
                            hires_neg = [hf_neg[i]] if isinstance(hf_neg, list) else hf_neg

                    hires = HiresFix.apply(
                        single_latent,
                        single_ctx,
                        hires_model,
                        hires_pos,
                        hires_neg,
                        callback=ctx.callback,
                    )
                    final = hires_model.decode(hires["samples"])[0]
                    if AutoHDRProcessor.is_enabled(ctx):
                        final = AutoHDRProcessor.apply(final, ctx)
                except Exception as e:
                    logger.warning(f"Batch hires_fix failed: {e}")

            # Per-sample Adetailer
            if info.get("adetailer", False):
                try:
                    single_ctx = ctx.clone()
                    single_ctx.seed = ctx.seeds[i] if i < len(ctx.seeds) else ctx.seed
                    final, saved = Adetailer.apply(
                        final,
                        single_ctx,
                        ad_model,
                        positive=[ad_pos[i]] if isinstance(ad_pos, list) else ad_pos,
                        negative=[ad_neg[i]] if isinstance(ad_neg, list) else ad_neg,
                        callback=ctx.callback,
                    )
                    for s in saved:
                        results.setdefault(req_id, []).extend(
                            s.get("ui", {}).get("images", [s])
                        )
                except Exception as e:
                    logger.warning(f"Batch adetailer failed: {e}")

            # Save
            meta = ctx.build_metadata({
                "seed": str(ctx.seeds[i] if i < len(ctx.seeds) else ctx.seed),
                "prompt": prompts[i],
            })
            saved = saver.save_images([final], prefix, prompts[i], meta, store_bytes_prefix=prefix)
            results.setdefault(req_id, []).extend(
                saved.get("ui", {}).get("images", [saved])
            )

        return {"batched_results": results}

    def _clear_model_patches(self, model: AbstractModel) -> None:
        """Clear all patches from the model to ensure a clean state."""
        if model and hasattr(model, "model") and model.model:
            # Clear transformer patches (HiDiffusion, etc.)
            if hasattr(model.model, "model_options"):
                to = model.model.model_options.get("transformer_options", {})
                if "patches" in to:
                    logger.debug(f"Clearing {len(to['patches'])} patches from model")
                    to["patches"] = {}
            # Clear Token Merging
            if hasattr(model.model, "remove_tome"):
                model.model.remove_tome()

    def _load_model(self, ctx: Context) -> AbstractModel:
        """Load the model for this context.

        Uses ModelFactory for auto-detection when model_path is empty or set
        to the special __FLUX2_KLEIN__ marker. Optimized to reuse the existing
        loaded model if it matches the request.
        """
        path = ctx.model_path

        # 1. Determine the target model type for the reuse check
        from src.Core.Models.ModelFactory import detect_model_type
        target_type = "Flux2Klein" if path == "__FLUX2_KLEIN__" else detect_model_type(path)

        # 2. Check whether the current model can be reused
        if self._model is not None and self._model.is_loaded:
            current_type = self._model.__class__.__name__.replace("Model", "")
            # Match if the paths are identical OR if both are Flux2 (auto-detected/marker)
            paths_match = (self._model.model_path == path)
            types_match = (current_type == target_type)
            if paths_match or (not path and types_match) or (
                path == "__FLUX2_KLEIN__" and target_type == "Flux2Klein" and types_match
            ):
                logger.info(f"Reusing currently loaded {current_type} model")
                self._clear_model_patches(self._model)
                return self._model
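            # Example transitions handled below (checkpoint names hypothetical):
            # switching from "sdxl_base.safetensors" to "flux_dev.safetensors"
            # unloads the SDXL weights, clears the prompt cache, and frees CUDA
            # memory before the new checkpoint is constructed.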
            # 3. Different model requested: UNLOAD THE OLD ONE FIRST to free VRAM
            logger.info(f"Unloading {current_type} model to load {target_type}")
            self._model.unload()
            self._model = None

            # Clear the prompt cache since the CLIP model is changing
            try:
                from src.Utilities.prompt_cache import clear_prompt_cache
                clear_prompt_cache()
            except Exception:
                pass

            # Force cleanup to prevent memory pressure/stuttering during the transition
            import gc
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()

        # 4. Create and load a new model instance
        if path == "__FLUX2_KLEIN__":
            # Explicitly request Flux2 Klein
            model = self.model_factory(model_path=None, model_type="Flux2Klein")
        elif not path:
            # Auto-detect the model type (may detect Flux2 components)
            model = self.model_factory(model_path=None)
        else:
            # Specific checkpoint path provided
            model = self.model_factory(model_path=path)
        model.load()
        self._model = model
        return model

    def _load_refiner_model(self, ctx: Context) -> AbstractModel:
        """Load the refiner model for this context.

        Optimized to reuse the existing loaded model if it matches the refiner path.
        """
        path = ctx.generation.refiner_model_path
        if not path:
            raise ValueError("refiner_model_path is required for refiner pass")

        # 1. Determine the target model type
        from src.Core.Models.ModelFactory import detect_model_type
        target_type = detect_model_type(path)

        # 2. Check whether the current model can be reused
        if self._model is not None and self._model.is_loaded:
            if self._model.model_path == path:
                logger.info("Reusing currently loaded model as refiner")
                self._clear_model_patches(self._model)
                return self._model

            # 3. Different model requested: UNLOAD THE OLD ONE FIRST to free VRAM
            logger.info(f"Unloading current model to load refiner {target_type}")
            self._model.unload()
            # self._model = None  # Don't set to None yet, we'll replace it

            # Force cleanup
            import gc
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()
        # 4. Create and load a new model instance
        model = self.model_factory(model_path=path)
        model.load()
        self._model = model
        return model

    def _apply_optimizations(self, ctx: Context, model: AbstractModel) -> None:
        """Apply all configured optimizations to the model."""
        self._apply_runtime_preferences(ctx, model)

        # LoRA - only if the model supports it and matches the default LoRA type.
        # The default LoRA (add_detail) is SD1.5 (context_dim 768).
        is_sd15 = False
        try:
            is_sd15 = model.get_model_object("context_dim") == 768
        except Exception:
            pass
        if self.default_lora and getattr(model.capabilities, "supports_lora", True):
            # Only apply the default detailing LoRA to SD1.5 models
            if not is_sd15 and self.default_lora[0] == "add_detail.safetensors":
                logger.debug("Skipping default SD1.5 LoRA for non-SD1.5 model")
            else:
                try:
                    model.apply_lora(*self.default_lora)
                except Exception as e:
                    logger.warning(f"LoRA failed: {e}")

        # StableFast and torch.compile are mutually exclusive
        if ctx.generation.stable_fast:
            model.apply_stable_fast(enable_cuda_graph=True)
        elif ctx.generation.torch_compile:
            model.apply_torch_compile()

        # FP8 quantization (hardware-gated, applies independently)
        if ctx.generation.fp8_inference or ctx.generation.weight_quantization == "fp8":
            model.apply_fp8()
        elif ctx.generation.weight_quantization == "nvfp4":
            model.apply_nvfp4()

        # Token Merging (ToMe)
        if ctx.sampling.tome_enabled and getattr(model.capabilities, "supports_tome", True):
            try:
                if hasattr(model.model, "apply_tome"):
                    model.model.apply_tome(
                        ratio=ctx.sampling.tome_ratio,
                        max_downsample=ctx.sampling.tome_max_downsample,
                    )
            except Exception as e:
                logger.warning(f"ToMe application failed: {e}")

        # DeepCache
        if ctx.sampling.deepcache_enabled:
            model.apply_deepcache(
                ctx.sampling.deepcache_interval,
                ctx.sampling.deepcache_depth,
                ctx.sampling.deepcache_start_step,
                ctx.sampling.deepcache_end_step,
            )

    def _encode_prompts(self, ctx: Context, model: AbstractModel) -> tuple[Any, Any]:
        """Encode prompts to conditioning tensors."""
        return model.encode_prompt(ctx.prompt, ctx.negative_prompt)

    def _check_interrupt(self) -> None:
        """Check for a user interrupt."""
        from src.user import app_instance
        app = getattr(app_instance, "app", None)
        if app and getattr(app, "interrupt_flag", False):
            raise InterruptedError("Generation interrupted")


# Singleton default pipeline
_default_pipeline: Optional[Pipeline] = None


def get_default_pipeline() -> Pipeline:
    """Get the default pipeline instance."""
    global _default_pipeline
    if _default_pipeline is None:
        _default_pipeline = Pipeline()
    return _default_pipeline


def reset_default_pipeline() -> None:
    """Release the singleton pipeline and any loaded model it still owns."""
    global _default_pipeline
    if _default_pipeline is not None:
        try:
            if _default_pipeline._model is not None and _default_pipeline._model.is_loaded:
                _default_pipeline._model.unload()
        except Exception:
            pass
        _default_pipeline = None
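# Minimal end-to-end usage sketch (mirrors the Pipeline class docstring; the
# module import path is assumed, values hypothetical):
#
#   from src.Core.Pipeline import get_default_pipeline, reset_default_pipeline
#   from src.Core.Context import Context
#
#   ctx = Context(prompt="a cat", width=512, height=512)
#   ctx = get_default_pipeline().run(ctx)   # images land in ctx.current_image
#   reset_default_pipeline()                # release the cached model/VRAM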