# LightDiffusion-Next / src/sample/BaseSampler.py
# Snapshot: "Deploy ZeroGPU Gradio Space" (commit b701455) by Aatricks
"""Simplified base sampler infrastructure for LightDiffusion-Next."""
import threading
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Optional
import torch
from tqdm.auto import trange
from src.Device import Device
from src.AutoEncoders import taesd
from src.sample import sampling_util
from src.user import app_instance
from src.Utilities import util
@dataclass
class MultiscaleConfig:
    """Configuration for multi-resolution sampling.

    Steps outside the [fullres_start, n_steps - fullres_end) window run at
    full resolution; the rest run at a scaled-down resolution.
    """
    # Master switch; may still be disabled downstream (e.g. for Flux latents).
    enabled: bool = True
    # Spatial scale factor for low-res steps; values outside 0.1..1.0
    # deactivate multi-scale in MultiscaleManager.
    factor: float = 0.5
    # Number of initial steps forced to full resolution.
    fullres_start: int = 3
    # Number of final steps forced to full resolution.
    fullres_end: int = 8
    # If True, alternate full/low resolution inside the low-res window.
    intermittent_fullres: bool = False
class MultiscaleManager:
    """Handles resolution switching during sampling."""

    @staticmethod
    def _safe_int(value, fallback: int) -> int:
        # Test doubles (MagicMock etc.) may not be real ints; coerce or fall back.
        if isinstance(value, int):
            return value
        try:
            return int(value)
        except Exception:
            return fallback

    def __init__(self, shape: tuple, n_steps: int, config: MultiscaleConfig):
        self.orig_h = self._safe_int(shape[2], 512)
        self.orig_w = self._safe_int(shape[3], 512)
        self.n_steps = n_steps
        self.config = config
        # CRITICAL: multi-scale is incompatible with Flux latents (16/32 channels).
        flux_latent = shape[1] in (16, 32)
        config_ok = (
            config.enabled
            and 0.1 <= config.factor <= 1.0
            and config.fullres_start >= 0
            and config.fullres_end >= 0
        )
        self.active = config_ok and not flux_latent
        if self.active:
            # Scaled dimensions snapped down to multiples of 8, floor of 8.
            self.scale_h = int(max(8, ((self.orig_h * config.factor) // 8) * 8))
            self.scale_w = int(max(8, ((self.orig_w * config.factor) // 8) * 8))
            # Nothing to do if rounding lands back on the original size.
            self.active = (self.scale_h, self.scale_w) != (self.orig_h, self.orig_w)
        else:
            self.scale_h, self.scale_w = self.orig_h, self.orig_w
        if self.active:
            print(f"Multi-scale: {self.orig_h}x{self.orig_w} -> {self.scale_h}x{self.scale_w}")
        elif config.enabled and flux_latent:
            print("Multi-scale disabled: not compatible with Flux architecture")
        # Precompute the full/low resolution decision for every step.
        self._schedule = [self._should_fullres(i) for i in range(n_steps)]

    def _should_fullres(self, step: int) -> bool:
        """Decide whether *step* runs at full resolution."""
        if not self.active:
            return True
        if step < self.config.fullres_start or step >= self.n_steps - self.config.fullres_end:
            return True
        if self.config.intermittent_fullres:
            # Alternate full/low resolution inside the low-res window.
            return (step - self.config.fullres_start) % 2 == 0
        return False

    def use_fullres(self, step: int) -> bool:
        """Schedule lookup; steps past the schedule default to full resolution."""
        try:
            return self._schedule[step]
        except IndexError:
            return True

    def _coerce_to_4d(self, t: torch.Tensor) -> torch.Tensor:
        """Best-effort conversion of *t* into a 4D (N, C, H, W) tensor.

        Tests can hand us non-tensors or tensors of unexpected rank (0-dim,
        1-dim, MagicMock-like objects); fail gracefully with zeros of the
        expected shape rather than raising hard errors.
        """
        if not isinstance(t, torch.Tensor):
            try:
                t = torch.as_tensor(t)
            except Exception:
                return torch.zeros((1, 4, self.scale_h, self.scale_w))
        if t.ndim >= 4:
            return t
        try:
            if t.ndim == 3:
                return t.unsqueeze(0)
            if t.ndim == 2:
                return t.unsqueeze(0).unsqueeze(0)
            if t.ndim == 1:
                return t.view(1, 1, 1, -1)
            # 0-dim or otherwise unexpected — fall back to zeros.
            return torch.zeros((1, 4, self.scale_h, self.scale_w),
                               dtype=t.dtype, device=getattr(t, 'device', None))
        except Exception:
            return torch.zeros((1, 4, self.scale_h, self.scale_w),
                               dtype=t.dtype, device=getattr(t, 'device', None))

    def _resize(self, t: torch.Tensor, h: int, w: int) -> torch.Tensor:
        # Shared implementation for downscale/upscale; no-op when inactive
        # or when the tensor is already at the target size.
        if not self.active:
            return t
        t = self._coerce_to_4d(t)
        if t.shape[-2:] == (h, w):
            return t
        return torch.nn.functional.interpolate(t, (h, w), mode="bilinear", align_corners=False)

    def downscale(self, t: torch.Tensor) -> torch.Tensor:
        """Resize *t* to the low-res working size."""
        return self._resize(t, self.scale_h, self.scale_w)

    def upscale(self, t: torch.Tensor) -> torch.Tensor:
        """Resize *t* back to the original latent size."""
        return self._resize(t, self.orig_h, self.orig_w)
class SamplerCallback:
    """Handles progress, interruption, and preview.
    Optimized for minimal per-step overhead:
    - App reference cached once at init
    - Fast path when app is None or in pipeline mode
    - Preview checks minimized
    """
    # __slots__ keeps per-instance memory low and attribute access fast;
    # a SamplerCallback is created once per sampling run.
    __slots__ = ('n_steps', 'pipeline', '_preview_lock', '_preview_thread',
                 '_app', '_has_app', '_preview_enabled', '_preview_interval')

    def __init__(self, n_steps: int, pipeline: bool = False):
        """Cache the app handle and precompute preview pacing.

        Args:
            n_steps: Total number of sampling steps.
            pipeline: When True, progress updates and previews are disabled.
        """
        self.n_steps = n_steps
        self.pipeline = pipeline
        # Held for the lifetime of a running preview thread (released by the
        # thread itself); ensures only one preview decode runs at a time.
        self._preview_lock = threading.Lock()
        self._preview_thread = None
        # Cache app reference once (avoid getattr chain every step)
        self._app = getattr(app_instance, "app", None)
        self._has_app = self._app is not None
        # Pre-compute preview settings
        if self._has_app and not pipeline:
            try:
                # previewer_var.get() looks like a Tk-style BooleanVar — TODO confirm.
                self._preview_enabled = self._app.previewer_var.get()
            except Exception:
                self._preview_enabled = False
            # Adaptive interval: at least 5 previews, max every 5 steps
            self._preview_interval = min(5, max(1, n_steps // 5))
        else:
            self._preview_enabled = False
            self._preview_interval = n_steps + 1  # Never trigger

    def check_interrupt(self) -> bool:
        """Fast interrupt check with cached app reference."""
        if not self._has_app:
            return False
        return getattr(self._app, "interrupt_flag", False)

    def update_progress(self, step: int):
        """Update progress bar (skipped in pipeline mode)."""
        if self.pipeline or not self._has_app:
            return
        try:
            # Fraction in 0..1; assumes app.progress accepts that range — TODO confirm.
            self._app.progress.set(step / self.n_steps)
        except Exception:
            # Best-effort UI update; never let progress reporting kill sampling.
            pass

    def preview(self, x: torch.Tensor, step: int):
        """Generate preview if enabled and at appropriate interval."""
        if not self._preview_enabled:
            return
        # Check if this is a significant step (interval hit or final step).
        is_significant = (step % self._preview_interval == 0) or (step == self.n_steps - 1)
        if not is_significant:
            return
        # Only start a new preview thread if the previous one is finished.
        # Non-blocking acquire: if a preview is still running (it holds the
        # lock), skip this preview instead of waiting.
        if not self._preview_lock.acquire(blocking=False):
            return
        try:
            if self._preview_thread is not None and self._preview_thread.is_alive():
                self._preview_lock.release()
                return
            def run_preview():
                # Lock stays held while decoding; released in finally so a
                # failing preview can't wedge future previews.
                try:
                    # If channels == 16, it's Flux1. Flux2 uses 128 channels.
                    is_flux = (x.shape[1] == 16)
                    taesd.taesd_preview(x.clone(), flux=is_flux, step=step, total_steps=self.n_steps)
                finally:
                    self._preview_lock.release()
            self._preview_thread = threading.Thread(target=run_preview)
            self._preview_thread.start()
        except Exception:
            # Thread creation failed before run_preview could release —
            # make sure the lock doesn't stay held forever.
            if self._preview_lock.locked():
                self._preview_lock.release()
def set_model_options_post_cfg_function(opts: dict, fn: Callable, disable_cfg1_optimization: bool = False) -> dict:
    """Return a copy of *opts* with *fn* appended to the post-CFG hook list.

    The input dict (and its hook list) is never mutated. The CFG=1.0
    optimization — skipping the unconditional pass for speed — is left
    enabled unless the caller explicitly asks to disable it.
    """
    updated = dict(opts)
    hooks = list(updated.get("sampler_post_cfg_function", []))
    hooks.append(fn)
    updated["sampler_post_cfg_function"] = hooks
    if disable_cfg1_optimization:
        updated["disable_cfg1_optimization"] = True
    return updated
@dataclass
class CFGState:
    """Rolling history of denoised predictions used by CFG++ momentum."""
    old_denoised: Optional[torch.Tensor] = None
    old_uncond: Optional[torch.Tensor] = None

    def capture(self, args: dict) -> torch.Tensor:
        """Post-CFG hook: record the uncond prediction, pass denoised through."""
        denoised = args["denoised"]
        self.old_uncond = args.get("uncond_denoised")
        return denoised

    def update(self, denoised: torch.Tensor, uncond: torch.Tensor):
        """Store this step's predictions for the next step's momentum."""
        self.old_denoised = denoised
        self.old_uncond = uncond
class BaseSampler(ABC):
    """Abstract base for all samplers."""

    def __init__(self, enable_multiscale: bool = True, multiscale_factor: float = 0.5,
                 multiscale_fullres_start: int = 3, multiscale_fullres_end: int = 8,
                 multiscale_intermittent_fullres: bool = False, cfg_scale: float = 7.5,
                 cfg_min: float = 1.0, cfg_x0_scale: float = 1.0, pipeline: bool = False,
                 use_momentum: bool = False):
        self.ms_config = MultiscaleConfig(
            enable_multiscale,
            multiscale_factor,
            multiscale_fullres_start,
            multiscale_fullres_end,
            multiscale_intermittent_fullres,
        )
        self.cfg_scale = cfg_scale
        self.cfg_min = cfg_min
        self.cfg_x0_scale = cfg_x0_scale
        self.pipeline = pipeline
        self.use_momentum = use_momentum

    def get_cfg(self, step: int, n_steps: int) -> float:
        """Linearly anneal CFG from cfg_scale (step 0) to cfg_min (last step)."""
        progress = step / max(1, n_steps - 1)
        return self.cfg_scale + (self.cfg_min - self.cfg_scale) * progress

    def apply_cfg(self, denoised: torch.Tensor, uncond: torch.Tensor, cfg: float,
                  state: CFGState, h_ratio: Optional[float] = None) -> torch.Tensor:
        """Apply CFG++ momentum if enabled and history exists, else pass through.

        The model (CFGGuider) has already applied CFG scaling; this only
        blends in the previous step's denoised prediction.
        """
        if not self.use_momentum or state.old_denoised is None or h_ratio is None:
            return denoised
        # CFG++ momentum correction only — no additional CFG scaling here.
        return (1 + h_ratio) * denoised - h_ratio * state.old_denoised

    @torch.inference_mode()
    def sample(self, model: Any, x: torch.Tensor, sigmas: torch.Tensor,
               extra_args: Optional[dict] = None, callback: Optional[Callable] = None,
               disable: Optional[bool] = None, **kwargs) -> torch.Tensor:
        """Run the sampler under inference_mode for optimal performance.

        Sets up multi-scale management, UI callbacks, and the CFG++ state
        hook, then delegates the actual integration to the subclass _loop.
        """
        # Work on a private copy so the caller's dict is never mutated.
        extra_args = dict(extra_args) if extra_args else {}
        n_steps = len(sigmas) - 1
        if n_steps <= 0:
            return x
        device = x.device
        # Test doubles may not expose a real device; fall back to the default.
        if not isinstance(device, (torch.device, str)):
            device = Device.get_torch_device()
        ms = MultiscaleManager(x.shape, n_steps, self.ms_config)
        cb = SamplerCallback(n_steps, self.pipeline)
        s_in = torch.ones((x.shape[0],), device=device)
        state = CFGState()
        # Register the capture hook; keep the CFG=1.0 fast path enabled so
        # the unconditional pass can be skipped when it contributes nothing.
        extra_args["model_options"] = set_model_options_post_cfg_function(
            extra_args.get("model_options", {}), state.capture,
            disable_cfg1_optimization=False)
        return self._loop(model, x, sigmas, extra_args, callback, disable,
                          n_steps, device, ms, cb, s_in, state, **kwargs)

    @abstractmethod
    def _loop(self, model, x, sigmas, extra_args, callback, disable,
              n_steps, device, ms, cb, s_in, state, **kwargs) -> torch.Tensor:
        """Sampler-specific integration loop implemented by subclasses."""
class EulerSampler(BaseSampler):
    """Explicit Euler sampler with optional churn noise and multi-scale steps."""

    def _loop(self, model, x, sigmas, extra_args, callback, disable,
              n_steps, device, ms, cb, s_in, state, s_churn=0.0, s_tmin=0.0,
              s_tmax=float("inf"), s_noise=1.0, **kwargs):
        """Integrate x along the sigma schedule with Euler steps.

        s_churn/s_tmin/s_tmax/s_noise follow the usual stochastic-sampling
        convention: within [s_tmin, s_tmax], sigma is temporarily raised and
        matching noise is injected before each step.
        """
        # Per-step churn capped at sqrt(2)-1, as in k-diffusion.
        gamma_max = min(s_churn / n_steps, 2**0.5 - 1) if s_churn > 0 else 0
        ms_active = ms.active
        for i in trange(n_steps, disable=disable):
            if cb.check_interrupt():
                return x
            cb.update_progress(i)
            sigma_hat = sigmas[i]
            if gamma_max > 0 and s_tmin <= sigmas[i] <= s_tmax:
                # Raise sigma and add the noise that accounts for the increase.
                sigma_hat = sigmas[i] * (1 + gamma_max)
                x = x + torch.randn_like(x) * s_noise * (sigma_hat**2 - sigmas[i]**2)**0.5
            if not ms_active or ms.use_fullres(i):
                denoised = model(x, sigma_hat * s_in, **extra_args)
            else:
                # Compute the low-res latent once per step (previously the
                # downscale interpolation ran twice: once for the model input
                # and once just to read the batch size).
                x_ds = ms.downscale(x)
                sigma_in = sigma_hat * torch.ones((x_ds.shape[0],), device=device)
                denoised = ms.upscale(model(x_ds, sigma_in, **extra_args))
            # CFG is already applied by CFGGuider, just apply momentum if available
            cfg_denoised = self.apply_cfg(denoised, None, 0, state)
            state.update(denoised, None)
            x = x + util.to_d(x, sigma_hat, cfg_denoised) * (sigmas[i + 1] - sigma_hat)
            if callback:
                callback({"x": x, "i": i, "sigma": sigmas[i], "denoised": denoised, "total_steps": n_steps})
            cb.preview(x, i)
        return x
class EulerAncestralSampler(BaseSampler):
    """Euler ancestral sampler: Euler step to sigma_down, then fresh noise up."""

    def _loop(self, model, x, sigmas, extra_args, callback, disable,
              n_steps, device, ms, cb, s_in, state, eta=1.0, s_noise=1.0,
              noise_sampler=None, **kwargs):
        """Integrate x with ancestral (noise re-injecting) Euler steps.

        eta scales how much of each step is taken stochastically; s_noise
        scales the injected noise; noise_sampler defaults to plain Gaussian.
        """
        noise_sampler = noise_sampler or sampling_util.default_noise_sampler(x)
        ms_active = ms.active
        for i in trange(n_steps, disable=disable):
            if cb.check_interrupt():
                return x
            cb.update_progress(i)
            if not ms_active or ms.use_fullres(i):
                denoised = model(x, sigmas[i] * s_in, **extra_args)
            else:
                # Compute the low-res latent once per step (previously the
                # downscale interpolation ran twice per step).
                x_ds = ms.downscale(x)
                sigma_in = sigmas[i] * torch.ones((x_ds.shape[0],), device=device)
                denoised = ms.upscale(model(x_ds, sigma_in, **extra_args))
            # CFG is already applied by CFGGuider, just apply momentum if available
            cfg_denoised = self.apply_cfg(denoised, None, 0, state)
            state.update(denoised, None)
            sigma_down, sigma_up = sampling_util.get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
            x = x + util.to_d(x, sigmas[i], cfg_denoised) * (sigma_down - sigmas[i])
            if sigmas[i + 1] > 0:
                # Ancestral part: re-inject noise unless this is the final step.
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
            if callback:
                callback({"x": x, "i": i, "sigma": sigmas[i], "denoised": denoised, "total_steps": n_steps})
            cb.preview(x, i)
        return x
class DPMPP2MSampler(BaseSampler):
    """DPM-Solver++(2M): multistep second-order solver in t = -log(sigma)."""

    def _loop(self, model, x, sigmas, extra_args, callback, disable,
              n_steps, device, ms, cb, s_in, state, **kwargs):
        """Integrate x with the 2M update x <- r*x - expm1(-h)*denoised.

        The second-order correction enters through h_ratio, which apply_cfg
        uses to blend the previous step's denoised prediction (when momentum
        is enabled for *_cfgpp variants).
        """
        # Precompute per-step ratios and log-sigma step sizes once.
        t_steps = -torch.log(sigmas)
        sigma_steps = torch.exp(-t_steps)
        ratios = sigma_steps[1:] / sigma_steps[:-1]
        h_steps = t_steps[1:] - t_steps[:-1]
        ms_active = ms.active
        for i in trange(n_steps, disable=disable):
            if cb.check_interrupt():
                return x
            cb.update_progress(i)
            if not ms_active or ms.use_fullres(i):
                denoised = model(x, sigmas[i] * s_in, **extra_args)
            else:
                # Compute the low-res latent once per step (previously the
                # downscale interpolation ran twice per step).
                x_ds = ms.downscale(x)
                sigma_in = sigmas[i] * torch.ones((x_ds.shape[0],), device=device)
                denoised = ms.upscale(model(x_ds, sigma_in, **extra_args))
            # CFG is already applied by CFGGuider, just apply momentum if available
            h_ratio = h_steps[i - 1] / (2 * h_steps[i]) if i > 0 and state.old_denoised is not None else None
            cfg_denoised = self.apply_cfg(denoised, None, 0, state, h_ratio)
            state.update(denoised, None)
            x = ratios[i] * x - torch.expm1(-h_steps[i]) * cfg_denoised
            if callback:
                callback({"x": x, "i": i, "sigma": sigmas[i], "denoised": denoised, "total_steps": n_steps})
            cb.preview(x, i)
        return x
class DPMPPSDESampler(BaseSampler):
    """DPM-Solver++ (SDE): two-stage stochastic solver with Brownian-tree noise."""

    def _loop(self, model, x, sigmas, extra_args, callback, disable,
              n_steps, device, ms, cb, s_in, state, eta=1.0, s_noise=1.0,
              noise_sampler=None, r=0.5, seed=None, **kwargs):
        """Integrate x with the two-stage SDE update.

        r places the midpoint evaluation between t and t_next; eta/s_noise
        scale the injected Brownian noise; seed makes the noise deterministic.
        """
        sigma_fn = lambda t: (-t).exp()
        t_fn = lambda s: -s.log()
        ms_active = ms.active
        if noise_sampler is None:
            # Brownian tree needs the positive sigma range; build it on CPU.
            sigmas_cpu = sigmas.cpu()
            noise_sampler = sampling_util.BrownianTreeNoiseSampler(
                x, sigmas_cpu[sigmas_cpu > 0].min(), sigmas_cpu.max(), seed=seed, cpu=True)

        def _denoise(latent, sigma_scalar, step):
            # Model call with optional multi-scale handling. The downscale is
            # computed once here (previously it ran twice per model call).
            if not ms_active or ms.use_fullres(step):
                return model(latent, sigma_scalar * s_in, **extra_args)
            lat_ds = ms.downscale(latent)
            sigma_in = sigma_scalar * torch.ones((lat_ds.shape[0],), device=device)
            return ms.upscale(model(lat_ds, sigma_in, **extra_args))

        for i in trange(n_steps, disable=disable):
            if cb.check_interrupt():
                return x
            cb.update_progress(i)
            denoised = _denoise(x, sigmas[i], i)
            # CFG is already applied by CFGGuider
            if sigmas[i + 1] == 0:
                # Final step: plain Euler straight to sigma = 0.
                cfg_denoised = self.apply_cfg(denoised, None, 0, state)
                x = x + util.to_d(x, sigmas[i], cfg_denoised) * (sigmas[i + 1] - sigmas[i])
            else:
                t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
                s = t + (t_next - t) * r
                # Stage 1: ancestral sub-step from t to the midpoint s.
                sd, su = sampling_util.get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
                s_ = t_fn(sd)
                h_ratio = (t - s_) / (2 * (t - t_next)) if state.old_denoised is not None else None
                cfg_denoised = self.apply_cfg(denoised, None, 0, state, h_ratio)
                noise1 = noise_sampler(sigma_fn(t), sigma_fn(s)).to(device) * s_noise * su
                x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * cfg_denoised + noise1
                # Stage 2: second model evaluation at the midpoint.
                denoised_2 = _denoise(x_2, sigma_fn(s), i)
                cfg_denoised_2 = self.apply_cfg(denoised_2, None, 0, state, h_ratio)
                sd, su = sampling_util.get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
                t_next_ = t_fn(sd)
                noise_final = noise_sampler(sigma_fn(t), sigma_fn(t_next)).to(device) * s_noise * su
                # Combine the two evaluations with weights (1 - 1/2r, 1/2r).
                x = ((sigma_fn(t_next_) / sigma_fn(t)) * x
                     - (t - t_next_).expm1() * ((1 - 1/(2*r)) * cfg_denoised + (1/(2*r)) * cfg_denoised_2)
                     + noise_final)
            state.update(denoised, None)
            if callback:
                callback({"x": x, "i": i, "sigma": sigmas[i], "denoised": denoised, "total_steps": n_steps})
            cb.preview(x, i)
        return x
# Registry mapping public sampler names to their implementations. The
# *_cfgpp variants reuse a class and differ only in the momentum flag
# that get_sampler derives from the name.
SAMPLERS = {
    "euler": EulerSampler,
    "euler_ancestral": EulerAncestralSampler,
    "dpmpp_2m": DPMPP2MSampler,
    "dpmpp_2m_cfgpp": DPMPP2MSampler,
    "dpmpp_sde": DPMPPSDESampler,
    "dpmpp_sde_cfgpp": DPMPPSDESampler,
}


def get_sampler(name: str, **kwargs) -> BaseSampler:
    """Instantiate the sampler registered under *name*.

    CFG++ momentum is enabled automatically for the ``_cfgpp`` variants.
    Raises ValueError when *name* is not in the registry.
    """
    sampler_cls = SAMPLERS.get(name)
    if sampler_cls is None:
        raise ValueError(f"Unknown sampler: {name}. Available: {list(SAMPLERS.keys())}")
    return sampler_cls(use_momentum="_cfgpp" in name, **kwargs)