"""Core Pipeline orchestrator for LightDiffusion-Next.
This module provides the main Pipeline class - a clean, linear orchestrator
that coordinates model loading, generation, and post-processing.
The Pipeline is designed to be:
- Simple: <100 lines of core logic
- Modular: Delegates to Models and Processors
- Extensible: Easy to add new processing steps
Architecture:
[Context] -> [Load Model] -> [Encode] -> [Generate] -> [Decode] -> [Processors] -> [Result]
"""
import logging
import os
from dataclasses import dataclass, field
from typing import Any, Callable, Optional
import torch
from src.Core.Context import Context
from src.Core.Models import create_model
from src.Core.AbstractModel import AbstractModel
from src.Processors import HiresFix, Adetailer, AutoHDRProcessor
logger = logging.getLogger(__name__)
@dataclass
class PipelineResult:
"""Result of a pipeline run."""
images: list[torch.Tensor] = field(default_factory=list)
latents: Optional[torch.Tensor] = None
metadata: dict = field(default_factory=dict)
def to_dict(self) -> dict:
"""Convert to dictionary for legacy compatibility."""
return {
"images": self.images,
"latents": self.latents,
**self.metadata,
}
class Pipeline:
"""Main generation pipeline orchestrator.
This class coordinates the entire generation flow in a clean,
linear manner. Each step is isolated and the Context flows through.
Usage:
ctx = Context(prompt="a cat", width=512, height=512)
pipeline = Pipeline()
result = pipeline.run(ctx)
"""
def __init__(
self,
model_factory: Optional[Callable[..., AbstractModel]] = None,
default_lora: Optional[tuple[str, float, float]] = ("add_detail.safetensors", 0.7, 0.7),
):
"""Initialize the pipeline.
Args:
model_factory: Function to create models (default: create_model)
default_lora: Default LoRA to apply as (name, model_strength, clip_strength), or None
"""
self.model_factory = model_factory or create_model
self.default_lora = default_lora
self._model: Optional[AbstractModel] = None
def _apply_runtime_preferences(self, ctx: Context, model: AbstractModel) -> None:
"""Apply request-scoped runtime preferences that should track reused models."""
model.set_vae_autotune(ctx.generation.vae_autotune)
def run(self, ctx: Context) -> Context:
"""Run the full generation pipeline.
Args:
ctx: Configured Context with all parameters
Returns:
Context with generated images in current_image
"""
self._check_interrupt()
# 1. Load base model
model = self._load_model(ctx)
self._apply_runtime_preferences(ctx, model)
# 2. Apply optimizations to base model
mo = getattr(model, 'model', None)
mo_opts = getattr(mo, 'model_options', {}) if mo is not None else {}
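# A set model_function_wrapper indicates optimizations were already applied to this (reused) model.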
if not mo_opts.get("model_function_wrapper"):
self._apply_optimizations(ctx, model)
# 3. Encode prompts for base model
positive, negative = self._encode_prompts(ctx, model)
ctx.positive_cond = positive
ctx.negative_cond = negative
# 4. Handle refiner preparation if enabled (SDXL only)
refiner_model = None
ref_positive, ref_negative = None, None
is_sdxl = getattr(model.capabilities, "uses_dual_clip", False)
use_refiner = bool(
is_sdxl and
ctx.generation.refiner_model_path and
ctx.generation.refiner_switch_step is not None and
0 < ctx.generation.refiner_switch_step < ctx.sampling.steps
)
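# Example: with steps=30 and refiner_switch_step=24, the base model runs the first 24 steps and the refiner the remaining 6.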
if use_refiner:
print(f"Refiner enabled: {os.path.basename(ctx.generation.refiner_model_path)} (Switch at step {ctx.generation.refiner_switch_step})")
# Defer loading the refiner to save VRAM; for now we only need to know whether the base pass stops early and the base model gets unloaded later
# 5. Generate for each seed
from src.FileManaging import ImageSaver
saver = ImageSaver.SaveImage()
for seed in ctx.seeds[:ctx.generation.number]:
self._check_interrupt()
ctx.seed = seed
# Stage 1: Base model generation
if use_refiner:
steps_for_base = ctx.generation.refiner_switch_step
print(f"Stage 1: Running Base model ({steps_for_base}/{ctx.sampling.steps} steps)...")
latents = model.generate(
ctx, positive, negative,
last_step=ctx.generation.refiner_switch_step,
callback=ctx.callback
)
else:
latents = model.generate(ctx, positive, negative, callback=ctx.callback)
ctx.current_latents = latents["samples"]
# Stage 2: Refiner model generation
if use_refiner:
self._check_interrupt()
# Load refiner model (this will unload base model if necessary)
refiner_model = self._load_refiner_model(ctx)
self._apply_optimizations(ctx, refiner_model)
# Encode prompts for refiner (it has different CLIP)
ref_positive, ref_negative = self._encode_prompts(ctx, refiner_model)
# Disable multi-scale for refiner pass (always)
orig_ms = ctx.sampling.enable_multiscale
ctx.sampling.enable_multiscale = False
steps_for_refiner = ctx.sampling.steps - ctx.generation.refiner_switch_step
print(f"Stage 2: Running Refiner model ({steps_for_refiner}/{ctx.sampling.steps} steps)...")
latents = refiner_model.generate(
ctx, ref_positive, ref_negative,
latent_image=latents,
start_step=ctx.generation.refiner_switch_step,
disable_noise=True,
callback=ctx.callback
)
ctx.current_latents = latents["samples"]
ctx.sampling.enable_multiscale = orig_ms
# If we have more seeds, we'll need to reload base model in the next iteration
# _load_model handles this automatically
# Decode latents to image
ctx.current_image = model.decode(ctx.current_latents)
# 6. Post-processing
# Apply HiresFix if enabled. Prefer running hires pass with the base model
# and base prompts for consistency; using a refiner for the hires pass can
# introduce artifacts because its UNet/CLIP can differ from the base model.
current_model = model
# Prefer base prompts for hires pass (refiner prompts tend to mismatch)
hf_pos = positive
hf_neg = negative
if HiresFix.is_enabled(ctx):
self._check_interrupt()
logger.info(f"HiresFix: using base model for hires pass (use_refiner={use_refiner})")
# If a refiner was used earlier we may have unloaded the base model to free VRAM.
# Ensure the base model is reloaded and optimized before running the hires pass so
# downstream code (sampler / CFGGuider) can access model.model_options etc.
if use_refiner and (not model.is_loaded or getattr(model, "model", None) is None):
logger.info("HiresFix: reloading base model for hires pass (was unloaded by refiner)")
model = self._load_model(ctx)
# Re-apply optimizations (LoRA / StableFast / FP8 / DeepCache) to the reloaded model
self._apply_optimizations(ctx, model)
# Re-encode prompts for the reloaded base model to ensure conditioning matches
try:
hf_pos, hf_neg = self._encode_prompts(ctx, model)
except Exception:
# Keep the previously-encoded conditioning if re-encoding fails
logger.warning("HiresFix: prompt re-encoding failed; reusing prior conditioning")
current_model = model
# Run the hires pass with base-model prompts so the conditioning matches the model performing it
latents = HiresFix.apply(latents, ctx, current_model, hf_pos, hf_neg, callback=ctx.callback)
ctx.current_latents = latents["samples"]
if AutoHDRProcessor.is_enabled(ctx):
self._check_interrupt()
ctx.current_image = AutoHDRProcessor.apply(ctx.current_image, ctx)
# Apply Adetailer if enabled (handles its own saving)
if Adetailer.is_enabled(ctx):
self._check_interrupt()
if use_refiner:
# Reload base model for ADetailer - the refiner's UNet/CLIP
# is not suited for text-guided crop enhancement
ad_model = self._load_model(ctx)
ad_pos, ad_neg = self._encode_prompts(ctx, ad_model)
ctx.current_image, _ = Adetailer.apply(
ctx.current_image, ctx, ad_model,
positive=ad_pos, negative=ad_neg,
callback=ctx.callback
)
else:
ctx.current_image, _ = Adetailer.apply(
ctx.current_image, ctx, current_model,
positive=hf_pos, negative=hf_neg,
callback=ctx.callback
)
else:
# Save the image synchronously so the server can reliably find it
prefix = "LD-HF" if ctx.features.hires_fix else "LD"
filename_prefix = f"{ctx.features.request_filename_prefix}_{prefix}" if ctx.features.request_filename_prefix else prefix
images = ctx.current_image if isinstance(ctx.current_image, list) else [ctx.current_image]
saver.save_images(images, filename_prefix=filename_prefix, prompt=str(ctx.prompt), extra_pnginfo=ctx.build_metadata(), store_bytes_prefix=ctx.features.request_filename_prefix)
ctx.save_seed()
return ctx
def run_img2img(self, ctx: Context) -> Context:
"""Run image-to-image generation pipeline.
Supports two modes:
1. Upscale mode: When target dimensions are larger than input (uses USDU)
2. Diffusion mode: True img2img with denoising strength (uses simple_img2img)
Args:
ctx: Context with img2img_image set
Returns:
Context with generated images
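Example (sketch; field names as read by this method):
ctx = Context(prompt="a cat", width=768, height=768)
ctx.features.img2img_image = "input.png"
ctx.features.img2img_denoise = 0.6
result = Pipeline().run_img2img(ctx)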
"""
from src.Processors import Img2Img
from src.FileManaging import ImageSaver
from PIL import Image
import numpy as np
import torch
self._check_interrupt()
model = self._load_model(ctx)
self._apply_optimizations(ctx, model)
positive, negative = self._encode_prompts(ctx, model)
saver = ImageSaver.SaveImage()
# Load input image to determine mode
img_path = ctx.features.img2img_image
if not img_path:
raise ValueError("No input image provided for img2img")
img = Image.open(img_path)
input_w, input_h = img.size
target_w, target_h = ctx.generation.width, ctx.generation.height
# Convert image to tensor [B, H, W, C]
img_array = np.array(img.convert("RGB"))
img_tensor = torch.from_numpy(img_array).float().cpu() / 255.0
if img_tensor.dim() == 3:
img_tensor = img_tensor.unsqueeze(0)
# Determine mode: upscale if target is larger, otherwise diffusion
use_upscale = (target_w > input_w * 1.1) or (target_h > input_h * 1.1)
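# The 1.1 factor adds a ~10% margin so near-matching sizes still take the diffusion path.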
denoise = ctx.features.img2img_denoise
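# denoise is the img2img strength: 1.0 fully re-noises the latent (input mostly ignored); lower values preserve more of the input.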
# Inject SDXL size conditioning if required
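# (width/height, crop_w/crop_h, and target_width/target_height mirror SDXL's size/crop micro-conditioning)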
if getattr(model.capabilities, 'requires_size_conditioning', False):
for cond_list in [positive, negative]:
for cond_item in cond_list:
if len(cond_item) > 1 and isinstance(cond_item[1], dict):
cond_item[1].update({
"width": target_w,
"height": target_h,
"crop_w": 0,
"crop_h": 0,
"target_width": target_w,
"target_height": target_h,
})
logger.info(f"Img2Img: input={input_w}x{input_h}, target={target_w}x{target_h}, denoise={denoise:.2f}, mode={'upscale' if use_upscale else 'diffusion'}")
for seed in ctx.seeds[:ctx.generation.number]:
self._check_interrupt()
ctx.seed = seed
if use_upscale:
# Use USDU upscaler (existing behavior)
# Higher LoRA strength for img2img upscaling
if self.default_lora and getattr(model.capabilities, 'supports_lora', True):
try:
model.apply_lora(self.default_lora[0], 2.0, 2.0)
except Exception as e:
logger.warning(f"LoRA failed: {e}")
result = Img2Img.apply(ctx, model, positive, negative, image_tensor=img_tensor, denoise=denoise, callback=ctx.callback)
ctx.current_image = result
else:
# True diffusion-based img2img with denoising strength
# Resize input image to target dimensions if different
if input_w != target_w or input_h != target_h:
resized_img = img.resize((target_w, target_h), Image.Resampling.LANCZOS)
img_array = np.array(resized_img.convert("RGB"))
img_tensor = torch.from_numpy(img_array).float().cpu() / 255.0
if img_tensor.dim() == 3:
img_tensor = img_tensor.unsqueeze(0)
# Check if refiner is enabled BEFORE running base model (SDXL only)
is_sdxl = getattr(model.capabilities, "uses_dual_clip", False)
use_refiner = bool(
is_sdxl and
ctx.generation.refiner_model_path and
ctx.generation.refiner_switch_step is not None and
0 < ctx.generation.refiner_switch_step < ctx.sampling.steps
)
refiner_model = None
ref_positive, ref_negative = None, None
base_last_step = ctx.generation.refiner_switch_step if use_refiner else None
if use_refiner:
print(f"Stage 1: Running Base model ({ctx.generation.refiner_switch_step}/{ctx.sampling.steps} steps)...")
# Run simple_img2img for true diffusion-based generation
latents = Img2Img.simple_img2img(
ctx, model, positive, negative,
image_tensor=img_tensor,
denoise=denoise,
last_step=base_last_step,
callback=ctx.callback,
)
ctx.current_latents = latents["samples"]
# Apply refiner if enabled
if use_refiner:
self._check_interrupt()
# Load refiner model
refiner_model = self._load_refiner_model(ctx)
self._apply_optimizations(ctx, refiner_model)
# Encode prompts for refiner (it has different CLIP)
ref_positive, ref_negative = self._encode_prompts(ctx, refiner_model)
# Disable multi-scale for refiner pass
orig_ms = ctx.sampling.enable_multiscale
ctx.sampling.enable_multiscale = False
steps_for_refiner = ctx.sampling.steps - ctx.generation.refiner_switch_step
print(f"Img2Img Refiner: Running {steps_for_refiner}/{ctx.sampling.steps} steps...")
refiner_latents = refiner_model.generate(
ctx, ref_positive, ref_negative,
latent_image=latents,
start_step=ctx.generation.refiner_switch_step,
disable_noise=True,
callback=ctx.callback
)
ctx.current_latents = refiner_latents["samples"]
ctx.sampling.enable_multiscale = orig_ms
# Decode using refiner's VAE
image = refiner_model.decode(ctx.current_latents)
else:
# Decode to image using base model
image = model.decode(ctx.current_latents)
ctx.current_image = image
# Apply Adetailer if enabled
from src.Processors import Adetailer
if Adetailer.is_enabled(ctx):
self._check_interrupt()
if not use_upscale and use_refiner:
# Reload base model for ADetailer - the refiner's UNet/CLIP
# is not suited for text-guided crop enhancement
ad_model = self._load_model(ctx)
ad_pos, ad_neg = self._encode_prompts(ctx, ad_model)
ctx.current_image, _ = Adetailer.apply(
ctx.current_image, ctx, ad_model,
positive=ad_pos, negative=ad_neg,
callback=ctx.callback
)
else:
ctx.current_image, _ = Adetailer.apply(
ctx.current_image, ctx, model,
positive=positive, negative=negative,
callback=ctx.callback
)
# Apply AutoHDR if enabled
if AutoHDRProcessor.is_enabled(ctx):
ctx.current_image = AutoHDRProcessor.apply(ctx.current_image, ctx)
# Save the image with metadata including denoise value
filename_prefix = "LD-I2I"
if ctx.features.request_filename_prefix:
filename_prefix = f"{ctx.features.request_filename_prefix}_{filename_prefix}"
images = ctx.current_image if isinstance(ctx.current_image, list) else [ctx.current_image]
saver.save_images(images, filename_prefix=filename_prefix, prompt=str(ctx.prompt), extra_pnginfo=ctx.build_metadata({
"img2img": "True",
"img2img_denoise": str(denoise),
"img2img_mode": "upscale" if use_upscale else "diffusion",
}), store_bytes_prefix=ctx.features.request_filename_prefix)
ctx.save_seed()
return ctx
def run_controlnet(self, ctx: Context) -> Context:
"""Run ControlNet-style generation using Canny edges + img2img.
This uses edge detection to preserve structure while allowing
color and content changes via high-denoise img2img.
Args:
ctx: Context with controlnet_type, controlnet_strength, and img2img_image set
Returns:
Context with generated images
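Example (sketch; "canny" is an assumed preprocessor name, matching the Canny default used below):
ctx = Context(prompt="a watercolor cat", width=512, height=512)
ctx.features.img2img_image = "photo.png"
ctx.features.controlnet_type = "canny"
ctx.features.controlnet_strength = 0.8
result = Pipeline().run_controlnet(ctx)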
"""
from src.Processors import ControlNet as CNProcessor
from src.FileManaging import ImageSaver
from PIL import Image
import numpy as np
self._check_interrupt()
# Validate inputs
if not ctx.features.img2img_image:
raise ValueError("No input image provided for ControlNet")
model = self._load_model(ctx)
self._apply_optimizations(ctx, model)
# Load and preprocess input image
img_path = ctx.features.img2img_image
img = Image.open(img_path)
img = img.resize((ctx.generation.width, ctx.generation.height), Image.Resampling.LANCZOS)
# Convert to tensor [B, H, W, C]
img_array = np.array(img.convert("RGB"))
img_tensor = torch.from_numpy(img_array).float().cpu() / 255.0
if img_tensor.dim() == 3:
img_tensor = img_tensor.unsqueeze(0)
# Apply preprocessor (Canny edge detection by default)
control_image = CNProcessor.ControlNetProcessor.preprocess_image(
img_tensor,
preprocessor=ctx.features.controlnet_type,
)
strength = ctx.features.controlnet_strength
logger.info(f"ControlNet-style: {ctx.features.controlnet_type} edges, strength={strength}")
# Encode prompts
positive, negative = self._encode_prompts(ctx, model)
saver = ImageSaver.SaveImage()
# Check if refiner is enabled (SDXL only)
is_sdxl = getattr(model.capabilities, "uses_dual_clip", False)
use_refiner = bool(
is_sdxl and
ctx.generation.refiner_model_path and
ctx.generation.refiner_switch_step is not None and
0 < ctx.generation.refiner_switch_step < ctx.sampling.steps
)
refiner_model = None
ref_positive, ref_negative = None, None
if use_refiner:
print(f"Refiner enabled for ControlNet: {os.path.basename(ctx.generation.refiner_model_path)} (Switch at step {ctx.generation.refiner_switch_step})")
for seed in ctx.seeds[:ctx.generation.number]:
self._check_interrupt()
ctx.seed = seed
# Use the Canny+img2img approach, passing original image for blending
# When refiner is enabled, stop base model at refiner switch step
base_last_step = ctx.generation.refiner_switch_step if use_refiner else None
if use_refiner:
print(f"Stage 1: Running Base model ({ctx.generation.refiner_switch_step}/{ctx.sampling.steps} steps)...")
latents, ctx = CNProcessor.apply_controlnet_to_img2img(
ctx, model, positive, negative,
control_image=control_image,
strength=strength,
original_image=img_tensor,
last_step=base_last_step,
callback=ctx.callback,
)
ctx.current_latents = latents["samples"]
# Apply refiner if enabled
if use_refiner:
self._check_interrupt()
# Load refiner model
refiner_model = self._load_refiner_model(ctx)
self._apply_optimizations(ctx, refiner_model)
# Encode prompts for refiner (it has different CLIP)
ref_positive, ref_negative = self._encode_prompts(ctx, refiner_model)
# Disable multi-scale for refiner pass
orig_ms = ctx.sampling.enable_multiscale
ctx.sampling.enable_multiscale = False
steps_for_refiner = ctx.sampling.steps - ctx.generation.refiner_switch_step
print(f"ControlNet Refiner: Running {steps_for_refiner}/{ctx.sampling.steps} steps...")
refiner_latents = refiner_model.generate(
ctx, ref_positive, ref_negative,
latent_image=latents,
start_step=ctx.generation.refiner_switch_step,
disable_noise=True,
callback=ctx.callback
)
ctx.current_latents = refiner_latents["samples"]
ctx.sampling.enable_multiscale = orig_ms
# Decode using refiner's VAE
image = refiner_model.decode(ctx.current_latents)
else:
# Decode to image using base model
image = model.decode(ctx.current_latents)
ctx.current_image = image
# Apply Adetailer if enabled
from src.Processors import Adetailer
if Adetailer.is_enabled(ctx):
self._check_interrupt()
if use_refiner:
# Reload base model for ADetailer - the refiner's UNet/CLIP
# is not suited for text-guided crop enhancement
ad_model = self._load_model(ctx)
ad_pos, ad_neg = self._encode_prompts(ctx, ad_model)
ctx.current_image, _ = Adetailer.apply(
ctx.current_image, ctx, ad_model,
positive=ad_pos, negative=ad_neg,
callback=ctx.callback
)
else:
ctx.current_image, _ = Adetailer.apply(
ctx.current_image, ctx, model,
positive=positive, negative=negative,
callback=ctx.callback
)
# Apply AutoHDR if enabled
if AutoHDRProcessor.is_enabled(ctx):
ctx.current_image = AutoHDRProcessor.apply(ctx.current_image, ctx)
# Save with metadata
filename_prefix = "LD-CN"
if ctx.features.request_filename_prefix:
filename_prefix = f"{ctx.features.request_filename_prefix}_{filename_prefix}"
images = ctx.current_image if isinstance(ctx.current_image, list) else [ctx.current_image]
saver.save_images(images, filename_prefix=filename_prefix, prompt=str(ctx.prompt), extra_pnginfo=ctx.build_metadata({
"controlnet_style": "True",
"controlnet_strength": str(strength),
"controlnet_type": ctx.features.controlnet_type,
}), store_bytes_prefix=ctx.features.request_filename_prefix)
ctx.save_seed()
return ctx
def run_batched(self, ctx: Context, per_sample_info: list = None) -> dict:
"""Run batched multi-prompt generation.
Args:
ctx: Context with list of prompts
per_sample_info: Per-sample overrides
Returns:
Dictionary mapping request_ids to results
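Example (sketch; per_sample_info keys mirror those read in this method):
ctx = Context(prompt=["a cat", "a dog"], width=512, height=512)
info = [{"request_id": "r1", "hires_fix": True}, {"request_id": "r2", "adetailer": True}]
results = Pipeline().run_batched(ctx, per_sample_info=info)["batched_results"]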
"""
import uuid
from src.FileManaging import ImageSaver
from src.Utilities import Latent
from src.sample import sampling
from src.hidiffusion import msw_msa_attention
from src.Processors import Img2Img
self._check_interrupt()
prompts = list(ctx.prompt)
total_batch = len(prompts)
per_sample_info = per_sample_info or [{} for _ in range(total_batch)]
# Setup negatives
if isinstance(ctx.negative_prompt, (list, tuple)):
negatives = list(ctx.negative_prompt)
else:
negatives = [ctx.negative_prompt] * total_batch
model = self._load_model(ctx)
self._apply_optimizations(ctx, model)
# Encode all prompts
positive, negative = model.encode_prompt(prompts, negatives)
# Add batch routing so positive and negative conditioning stay aligned.
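# (the sampler is expected to use batch_index to map each conditioning entry to its latent in the batch)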
for cond_list in (positive, negative):
if isinstance(cond_list, list):
for i, entry in enumerate(cond_list):
if len(entry) > 1 and isinstance(entry[1], dict):
entry[1]["batch_index"] = [i]
# Determine latent channels (SD1.5/SDXL=4, SD3/Flux1=16, Flux2=32)
latent_channels = 4
try:
lf = model.get_model_object("latent_format")
if lf and hasattr(lf, "latent_channels"):
latent_channels = lf.latent_channels
except Exception:
pass
# Architecture flags for sampler
is_flux = getattr(model.capabilities, "is_flux", False) or (latent_channels == 16)
is_flux2 = getattr(model.capabilities, "is_flux2", False) or (latent_channels == 32)
# Generate all latents with correct channel count
latent_gen = Latent.EmptyLatentImage()
latent = latent_gen.generate(ctx.width, ctx.height, total_batch, channels=latent_channels)[0]
latent["seeds"] = ctx.seeds[:total_batch]
# Apply HiDiffusion (multiscale) if enabled
# CRITICAL: HiDiffusion MSW-MSA is for UNet (SD1.5/SDXL) only.
# DiT models like Flux will suffer from tiling artifacts if patched.
is_flux_or_flux2 = is_flux or is_flux2
if ctx.sampling.enable_multiscale and not is_flux_or_flux2:
try:
# Clone model before patching to avoid persistent state across batches
base_inner = getattr(model, 'model', model)
patch_model = base_inner.clone() if hasattr(base_inner, 'clone') else base_inner
hidiff = msw_msa_attention.ApplyMSWMSAAttentionSimple()
opt_model = hidiff.go(model_type="auto", model=patch_model)[0]
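# The patched wrapper may not expose helper attributes; proxy them from the original model so downstream code keeps working.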
if not hasattr(opt_model, "get_model_object") and hasattr(model, "get_model_object"):
opt_model.get_model_object = model.get_model_object
if not hasattr(opt_model, "load_device") and hasattr(model, "load_device"):
opt_model.load_device = model.load_device
except Exception as e:
logger.warning(f"Failed to apply HiDiffusion: {e}")
opt_model = model
else:
if ctx.sampling.enable_multiscale and is_flux_or_flux2:
logger.info("HiDiffusion disabled: not compatible with Flux architecture")
opt_model = model
# Determine if refiner is enabled (SDXL only)
is_sdxl = getattr(model.capabilities, "uses_dual_clip", False)
use_refiner = bool(
is_sdxl and
ctx.generation.refiner_model_path and
ctx.generation.refiner_switch_step is not None and
0 < ctx.generation.refiner_switch_step < ctx.sampling.steps
)
ksampler = sampling.KSampler()
# Distilled Flux2 Klein safety defaults
# These models are extremely sensitive to CFG > 1.2 and work best with specific samplers
if is_flux2:
if ctx.sampling.cfg > 1.2:
logger.info(f"Flux2 Klein detected: capping CFG from {ctx.sampling.cfg} to 1.0 for distilled quality")
ctx.sampling.cfg = 1.0
if ctx.sampling.sampler not in ["euler", "euler_ancestral", "dpmpp_2m", "dpmpp_sde", "uni_pc"]:
logger.info(f"Flux2 Klein detected: switching sampler to 'euler' for compatibility")
ctx.sampling.sampler = "euler"
batched_img2img_tensor = None
batched_img2img_denoise = ctx.features.img2img_denoise
if ctx.features.img2img and ctx.features.img2img_image:
from PIL import Image
import numpy as np
input_image = Image.open(ctx.features.img2img_image).convert("RGB")
target_size = (ctx.generation.width, ctx.generation.height)
if input_image.size != target_size:
input_image = input_image.resize(target_size, Image.Resampling.LANCZOS)
input_array = np.array(input_image)
batched_img2img_tensor = torch.from_numpy(input_array).float().cpu() / 255.0
batched_img2img_tensor = batched_img2img_tensor.unsqueeze(0).repeat(total_batch, 1, 1, 1)
if getattr(model.capabilities, "requires_size_conditioning", False):
for cond_list in (positive, negative):
for cond_item in cond_list:
if len(cond_item) > 1 and isinstance(cond_item[1], dict):
cond_item[1].update({
"width": ctx.generation.width,
"height": ctx.generation.height,
"crop_w": 0,
"crop_h": 0,
"target_width": ctx.generation.width,
"target_height": ctx.generation.height,
})
if use_refiner:
print(f"Batched Refiner enabled: {os.path.basename(ctx.generation.refiner_model_path)} (Switch at step {ctx.generation.refiner_switch_step})")
# Stage 1: Base model generation
print(f"Stage 1: Running Base model ({ctx.generation.refiner_switch_step}/{ctx.sampling.steps} steps)...")
if batched_img2img_tensor is not None:
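# Wrap the result in a one-element tuple so batch_latents[0] works uniformly with ksampler.sample's tuple return.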
batch_latents = (
Img2Img.simple_img2img(
ctx,
model,
positive,
negative,
image_tensor=batched_img2img_tensor,
denoise=batched_img2img_denoise,
last_step=ctx.generation.refiner_switch_step,
callback=ctx.callback,
),
)
else:
batch_latents = ksampler.sample(
seed=None,
steps=ctx.sampling.steps,
cfg=ctx.sampling.cfg,
sampler_name=ctx.sampling.sampler,
scheduler=ctx.sampling.scheduler,
denoise=1.0,
pipeline=True,
model=opt_model,
positive=positive,
negative=negative,
latent_image=latent,
last_step=ctx.generation.refiner_switch_step,
enable_multiscale=ctx.sampling.enable_multiscale,
multiscale_factor=ctx.sampling.multiscale_factor,
multiscale_fullres_start=ctx.sampling.multiscale_fullres_start,
multiscale_fullres_end=ctx.sampling.multiscale_fullres_end,
cfg_free_enabled=ctx.sampling.cfg_free_enabled,
cfg_free_start_percent=ctx.sampling.cfg_free_start_percent,
flux=is_flux,
flux2=is_flux2,
callback=ctx.callback,
)
self._check_interrupt()
# Stage 2: Refiner model generation
# Explicitly clear Stage 1 objects to free VRAM for refiner
import gc
if 'opt_model' in locals(): del opt_model
if 'positive' in locals(): del positive
if 'negative' in locals(): del negative
# CRITICAL: The local variable 'model' still holds the Base model.
# We must unload it and delete the reference so refcount hits 0.
if 'model' in locals() and model is not None:
model.unload()
del model
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
refiner_model = self._load_refiner_model(ctx)
# Skip optimizations if already applied (check model_function_wrapper)
mo = getattr(refiner_model, 'model', None)
mo_opts = getattr(mo, 'model_options', {}) if mo is not None else {}
if not mo_opts.get("model_function_wrapper"):
self._apply_optimizations(ctx, refiner_model)
# Encode prompts for refiner
ref_positive, ref_negative = refiner_model.encode_prompt(prompts, negatives)
# Re-apply batch routing to refiner conditioning if needed
if isinstance(ref_positive, list):
for i, entry in enumerate(ref_positive):
if len(entry) > 1 and isinstance(entry[1], dict):
entry[1]["batch_index"] = [i]
# Apply resolution conditioning for SDXL refiner if required
if getattr(refiner_model.capabilities, 'requires_size_conditioning', False):
for cond_list in [ref_positive, ref_negative]:
for cond_item in cond_list:
if len(cond_item) > 1 and isinstance(cond_item[1], dict):
cond_item[1].update({
"width": ctx.generation.width,
"height": ctx.generation.height,
"crop_w": 0,
"crop_h": 0,
"target_width": ctx.generation.width,
"target_height": ctx.generation.height,
})
# HiDiffusion optimization for refiner: NEVER use multi-scale for refiner pass
opt_refiner = getattr(refiner_model, 'model', refiner_model)
# Disable multi-scale for refiner pass
orig_ms = ctx.sampling.enable_multiscale
ctx.sampling.enable_multiscale = False
steps_for_refiner = ctx.sampling.steps - ctx.generation.refiner_switch_step
print(f"Stage 2: Running Refiner model ({steps_for_refiner}/{ctx.sampling.steps} steps)...")
batch_latents = ksampler.sample(
seed=None,
steps=ctx.sampling.steps,
cfg=ctx.sampling.cfg,
sampler_name=ctx.sampling.sampler,
scheduler=ctx.sampling.scheduler,
denoise=1.0,
pipeline=True,
model=opt_refiner,
positive=ref_positive,
negative=ref_negative,
latent_image=batch_latents[0],
start_step=ctx.generation.refiner_switch_step,
disable_noise=True,
callback=ctx.callback,
cfg_free_enabled=ctx.sampling.cfg_free_enabled,
cfg_free_start_percent=ctx.sampling.cfg_free_start_percent,
)
ctx.sampling.enable_multiscale = orig_ms
# Use refiner for decoding
model = refiner_model
else:
# Normal single-stage generation
if batched_img2img_tensor is not None:
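# As above, wrap in a one-element tuple to match ksampler.sample's tuple return.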
batch_latents = (
Img2Img.simple_img2img(
ctx,
model,
positive,
negative,
image_tensor=batched_img2img_tensor,
denoise=batched_img2img_denoise,
callback=ctx.callback,
),
)
else:
batch_latents = ksampler.sample(
seed=None,
steps=ctx.sampling.steps,
cfg=ctx.sampling.cfg,
sampler_name=ctx.sampling.sampler,
scheduler=ctx.sampling.scheduler,
denoise=1.0,
pipeline=True,
model=opt_model,
positive=positive,
negative=negative,
latent_image=latent,
enable_multiscale=ctx.sampling.enable_multiscale,
multiscale_factor=ctx.sampling.multiscale_factor,
multiscale_fullres_start=ctx.sampling.multiscale_fullres_start,
multiscale_fullres_end=ctx.sampling.multiscale_fullres_end,
cfg_free_enabled=ctx.sampling.cfg_free_enabled,
cfg_free_start_percent=ctx.sampling.cfg_free_start_percent,
flux=is_flux,
flux2=is_flux2,
callback=ctx.callback,
)
# Hires/Adetailer prompts - use refiner prompts if refiner was used
if use_refiner:
hf_pos = ref_positive
hf_neg = ref_negative
else:
hf_pos = positive
hf_neg = negative
# Decode all
images = model.decode(batch_latents[0]["samples"])
if AutoHDRProcessor.is_enabled(ctx):
images = AutoHDRProcessor.apply(images, ctx)
# If refiner was used, reload base model for ADetailer.
# The refiner's UNet/CLIP is optimized for short refinement passes,
# not for the text-guided crop enhancement that ADetailer performs.
ad_model = model
ad_pos = hf_pos
ad_neg = hf_neg
if use_refiner:
needs_adetailer = any(
(per_sample_info[j] if j < len(per_sample_info) else {}).get("adetailer", False)
for j in range(total_batch)
)
if needs_adetailer:
ad_model = self._load_model(ctx)
self._apply_optimizations(ctx, ad_model)
ad_pos, ad_neg = ad_model.encode_prompt(prompts, negatives)
if isinstance(ad_pos, list):
for idx, entry in enumerate(ad_pos):
if len(entry) > 1 and isinstance(entry[1], dict):
entry[1]["batch_index"] = [idx]
# Process individually
saver = ImageSaver.SaveImage()
results = {}
for i in range(total_batch):
self._check_interrupt()
info = per_sample_info[i] if i < len(per_sample_info) else {}
req_id = info.get("request_id", uuid.uuid4().hex[:8])
prefix = info.get("filename_prefix", f"LD-REQ-{req_id}")
final = images[i]
# Per-sample HiresFix
if info.get("hires_fix", False):
try:
single_latent = {"samples": batch_latents[0]["samples"][i:i+1]}
single_ctx = ctx.clone()
single_ctx.seed = ctx.seeds[i] if i < len(ctx.seeds) else ctx.seed
# Default to the currently-loaded model (may be refiner)
hires_model = model
hires_pos = [hf_pos[i]] if isinstance(hf_pos, list) else hf_pos
hires_neg = [hf_neg[i]] if isinstance(hf_neg, list) else hf_neg
# If a refiner was used, prefer reloading the base model for the hires pass.
# Attempt to reload + optimize the base model and re-encode the single-sample
# prompts; fall back to existing behavior on any failure.
if use_refiner:
try:
base_model = self._load_model(ctx)
self._apply_optimizations(ctx, base_model)
# Re-encode only the single sample for the reloaded base model
single_pos, single_neg = base_model.encode_prompt([prompts[i]], [negatives[i]])
if isinstance(single_pos, list):
single_pos = single_pos[0]
single_neg = single_neg[0]
hires_model = base_model
hires_pos = [single_pos] if isinstance(hf_pos, list) else single_pos
hires_neg = [single_neg] if isinstance(hf_neg, list) else single_neg
except Exception:
# If reload/encode fails, continue with the previously-loaded model
hires_model = model
hires_pos = [hf_pos[i]] if isinstance(hf_pos, list) else hf_pos
hires_neg = [hf_neg[i]] if isinstance(hf_neg, list) else hf_neg
hires = HiresFix.apply(
single_latent, single_ctx, hires_model,
hires_pos,
hires_neg,
callback=ctx.callback,
)
final = hires_model.decode(hires["samples"])[0]
if AutoHDRProcessor.is_enabled(ctx):
final = AutoHDRProcessor.apply(final, ctx)
except Exception as e:
logger.warning(f"Batch hires_fix failed: {e}")
# Per-sample Adetailer
if info.get("adetailer", False):
try:
single_ctx = ctx.clone()
single_ctx.seed = ctx.seeds[i] if i < len(ctx.seeds) else ctx.seed
final, saved = Adetailer.apply(
final, single_ctx, ad_model,
positive=[ad_pos[i]] if isinstance(ad_pos, list) else ad_pos,
negative=[ad_neg[i]] if isinstance(ad_neg, list) else ad_neg,
callback=ctx.callback
)
for s in saved:
results.setdefault(req_id, []).extend(
s.get("ui", {}).get("images", [s])
)
except Exception as e:
logger.warning(f"Batch adetailer failed: {e}")
# Save
meta = ctx.build_metadata({
"seed": str(ctx.seeds[i] if i < len(ctx.seeds) else ctx.seed),
"prompt": prompts[i],
})
saved = saver.save_images([final], prefix, prompts[i], meta, store_bytes_prefix=prefix)
results.setdefault(req_id, []).extend(
saved.get("ui", {}).get("images", [saved])
)
return {"batched_results": results}
def _clear_model_patches(self, model: AbstractModel) -> None:
"""Clear all patches from the model to ensure a clean state."""
if model and hasattr(model, "model") and model.model:
# Clear transformer patches (HiDiffusion, etc.)
if hasattr(model.model, "model_options"):
to = model.model.model_options.get("transformer_options", {})
if "patches" in to:
logger.debug(f"Clearing {len(to['patches'])} patches from model")
to["patches"] = {}
# Clear Token Merging
if hasattr(model.model, "remove_tome"):
model.model.remove_tome()
def _load_model(self, ctx: Context) -> AbstractModel:
"""Load the model for this context.
Uses ModelFactory for auto-detection when model_path is empty or
set to the special __FLUX2_KLEIN__ marker.
Optimized to reuse existing loaded model if it matches the request.
"""
path = ctx.model_path
# 1. Determine target model type for reuse check
from src.Core.Models.ModelFactory import detect_model_type
target_type = "Flux2Klein" if path == "__FLUX2_KLEIN__" else detect_model_type(path)
# 2. Check if current model can be reused
if self._model is not None and self._model.is_loaded:
current_type = self._model.__class__.__name__.replace("Model", "")
# Match if paths are identical OR if both are Flux2 (auto-detected/marker)
paths_match = (self._model.model_path == path)
types_match = (current_type == target_type)
if paths_match or (not path and types_match) or (path == "__FLUX2_KLEIN__" and target_type == "Flux2Klein" and types_match):
logger.info(f"Reusing currently loaded {current_type} model")
self._clear_model_patches(self._model)
return self._model
# 3. Different model requested: UNLOAD OLD ONE FIRST to free VRAM
logger.info(f"Unloading {current_type} model to load {target_type}")
self._model.unload()
self._model = None
# Clear prompt cache since the CLIP model is changing
try:
from src.Utilities.prompt_cache import clear_prompt_cache
clear_prompt_cache()
except Exception:
pass
# Force cleanup to prevent memory pressure/stuttering during transition
import gc
gc.collect()
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
# 4. Create and load new model instance
if path == "__FLUX2_KLEIN__":
# Explicitly request Flux2 Klein
model = self.model_factory(model_path=None, model_type="Flux2Klein")
elif not path:
# Auto-detect model type (may detect Flux2 components)
model = self.model_factory(model_path=None)
else:
# Specific checkpoint path provided
model = self.model_factory(model_path=path)
model.load()
self._model = model
return model
def _load_refiner_model(self, ctx: Context) -> AbstractModel:
"""Load the refiner model for this context.
Optimized to reuse existing loaded model if it matches the refiner path.
"""
path = ctx.generation.refiner_model_path
if not path:
raise ValueError("refiner_model_path is required for refiner pass")
# 1. Determine target model type
from src.Core.Models.ModelFactory import detect_model_type
target_type = detect_model_type(path)
# 2. Check if current model can be reused
if self._model is not None and self._model.is_loaded:
if self._model.model_path == path:
logger.info(f"Reusing currently loaded model as refiner")
self._clear_model_patches(self._model)
return self._model
# 3. Different model requested: UNLOAD OLD ONE FIRST to free VRAM
logger.info(f"Unloading current model to load refiner {target_type}")
self._model.unload()
# self._model = None # Don't set to None yet, we'll replace it
# Force cleanup
import gc
gc.collect()
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
# 4. Create and load new model instance
model = self.model_factory(model_path=path)
model.load()
self._model = model
return model
def _apply_optimizations(self, ctx: Context, model: AbstractModel) -> None:
"""Apply all configured optimizations to the model."""
self._apply_runtime_preferences(ctx, model)
# LoRA - only if model supports it and matches default LoRA type
# Default LoRA (add_detail) is SD1.5 (context_dim 768)
is_sd15 = False
try:
is_sd15 = model.get_model_object("context_dim") == 768
except Exception:
pass
if self.default_lora and getattr(model.capabilities, 'supports_lora', True):
# Only apply default detailing LoRA to SD1.5 models
if not is_sd15 and self.default_lora[0] == "add_detail.safetensors":
logger.debug(f"Skipping default SD1.5 LoRA for non-SD1.5 model")
else:
try:
model.apply_lora(*self.default_lora)
except Exception as e:
logger.warning(f"LoRA failed: {e}")
# StableFast and torch.compile are mutually exclusive
if ctx.generation.stable_fast:
model.apply_stable_fast(enable_cuda_graph=True)
elif ctx.generation.torch_compile:
model.apply_torch_compile()
# FP8 quantization (hardware-gated, applies independently)
if ctx.generation.fp8_inference or ctx.generation.weight_quantization == "fp8":
model.apply_fp8()
elif ctx.generation.weight_quantization == "nvfp4":
model.apply_nvfp4()
# Token Merging (ToMe)
if ctx.sampling.tome_enabled and getattr(model.capabilities, 'supports_tome', True):
try:
if hasattr(model.model, 'apply_tome'):
model.model.apply_tome(
ratio=ctx.sampling.tome_ratio,
max_downsample=ctx.sampling.tome_max_downsample,
)
except Exception as e:
logger.warning(f"ToMe application failed: {e}")
# DeepCache
if ctx.sampling.deepcache_enabled:
model.apply_deepcache(
ctx.sampling.deepcache_interval,
ctx.sampling.deepcache_depth,
ctx.sampling.deepcache_start_step,
ctx.sampling.deepcache_end_step,
)
def _encode_prompts(self, ctx: Context, model: AbstractModel) -> tuple[Any, Any]:
"""Encode prompts to conditioning tensors."""
return model.encode_prompt(ctx.prompt, ctx.negative_prompt)
def _check_interrupt(self) -> None:
"""Check for user interrupt."""
from src.user import app_instance
app = getattr(app_instance, "app", None)
if app and getattr(app, "interrupt_flag", False):
raise InterruptedError("Generation interrupted")
# Singleton default pipeline
_default_pipeline: Optional[Pipeline] = None
def get_default_pipeline() -> Pipeline:
"""Get the default pipeline instance."""
global _default_pipeline
if _default_pipeline is None:
_default_pipeline = Pipeline()
return _default_pipeline
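# Typical entry point (sketch, mirroring the Pipeline docstring):
#     pipeline = get_default_pipeline()
#     ctx = pipeline.run(Context(prompt="a cat", width=512, height=512))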
def reset_default_pipeline() -> None:
"""Release the singleton pipeline and any loaded model it still owns."""
global _default_pipeline
if _default_pipeline is not None:
try:
if _default_pipeline._model is not None and _default_pipeline._model.is_loaded:
_default_pipeline._model.unload()
except Exception:
pass
_default_pipeline = None