"""K-sampler utilities for diffusion models.""" import collections import logging import numpy as np import scipy import torch from src.sample import sampling_util def calculate_start_end_timesteps(model: torch.nn.Module, conds: list) -> None: """Calculate start/end timesteps for conditions.""" s = model.model_sampling for t in range(len(conds)): x = conds[t] ts, te = x.get("start_percent"), x.get("end_percent") if ts is not None or te is not None: n = x.copy() if ts is not None: n["timestep_start"] = s.percent_to_sigma(ts) if te is not None: n["timestep_end"] = s.percent_to_sigma(te) conds[t] = n def pre_run_control(model: torch.nn.Module, conds: list) -> None: """Pre-run control for conditions.""" s = model.model_sampling for x in conds: if "control" in x: x["control"].pre_run(model, lambda a: s.percent_to_sigma(a)) def apply_empty_x_to_equal_area(conds: list, uncond: list, name: str, uncond_fill_func: callable) -> None: """Apply empty x to equal area.""" cond_cnets, cond_other = [], [] uncond_cnets, uncond_other = [], [] for t, x in enumerate(conds): if "area" not in x: (cond_cnets if name in x and x[name] else cond_other).append((x[name], None) if name in x and x[name] else (x, t)) for t, x in enumerate(uncond): if "area" not in x: (uncond_cnets if name in x and x[name] else uncond_other).append((x[name], None) if name in x and x[name] else (x, t)) if uncond_cnets: return for i, _ in enumerate(cond_cnets): temp = uncond_other[i % len(uncond_other)] n = temp[0].copy() n[name] = uncond_fill_func([c[0] for c in cond_cnets if c[1] is None], i) if temp[1] is not None: uncond[temp[1]] = n else: uncond.append(n) CondObj = collections.namedtuple("cond_obj", ["input_x", "mult", "conditioning", "area", "control", "patches", "batch_indices"]) def get_area_and_mult(conds: dict, x_in: torch.Tensor, timestep_in: int) -> CondObj: """Get area and multiplier from conditions.""" x_shape, device = x_in.shape, x_in.device area = (x_shape[2], x_shape[3], 0, 0) batch_indices = conds.get("batch_index") if isinstance(batch_indices, int): batch_indices = [batch_indices] area_h, area_w = max(0, min(int(area[0]), x_shape[2])), max(0, min(int(area[1]), x_shape[3])) area = (area_h, area_w, 0, 0) if batch_indices is None: input_x = x_in[:, :, :area_h, :area_w] else: try: mapped = [(int(b) if b >= 0 else x_shape[0] + int(b)) for b in batch_indices] valid = [b for b in mapped if 0 <= b < x_shape[0]] if not valid: batch_indices = None input_x = x_in[:, :, :area_h, :area_w] else: input_x = x_in[torch.tensor(valid, dtype=torch.long, device=device), :, :area_h, :area_w] except Exception: batch_indices = None input_x = x_in[:, :, :area_h, :area_w] mult = torch.ones_like(input_x) batch_size = x_shape[0] if batch_indices is None else len(batch_indices) # Handle mock objects in tests if not isinstance(batch_size, int): try: temp = int(batch_size) if isinstance(temp, int): batch_size = temp else: batch_size = 1 except Exception: batch_size = 1 if not isinstance(device, (torch.device, str)): from src.Device import Device device = Device.get_torch_device() conditioning = {c: conds["model_conds"][c].process_cond(batch_size=batch_size, device=device, area=area) for c in conds["model_conds"]} return CondObj(input_x, mult, conditioning, area, conds.get("control"), None, batch_indices) def normal_scheduler(model_sampling, steps: int, sgm: bool = False, floor: bool = False) -> torch.FloatTensor: """Create normal noise scheduler.""" s = model_sampling timesteps = torch.linspace(s.timestep(s.sigma_max), s.timestep(s.sigma_min), steps, device=s.sigmas.device) return torch.cat([s.sigma(timesteps), s.sigmas.new_zeros([1])]).cpu().float() def simple_scheduler(model_sampling, steps: int) -> torch.FloatTensor: """Create simple noise scheduler.""" s = model_sampling if steps <= 0: return torch.FloatTensor([0.0]) indices = (torch.arange(steps, device=s.sigmas.device) * len(s.sigmas) / steps).long() sigs = s.sigmas.flip(0)[indices] return torch.cat([sigs, sigs.new_zeros([1])]).cpu().float() def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6) -> torch.FloatTensor: """Create beta distribution noise scheduler.""" total = len(model_sampling.sigmas) - 1 ts = scipy.stats.beta.ppf(1 - np.linspace(0, 1, steps, endpoint=False), alpha, beta) ts_indices = np.rint(ts * total).astype(np.int32) unique_ts, indices = np.unique(ts_indices, return_index=True) ordered = unique_ts[np.argsort(indices)] sigs = model_sampling.sigmas[torch.from_numpy(ordered).to(model_sampling.sigmas.device, torch.long)] return torch.cat([sigs, sigs.new_zeros([1])]).cpu().float() def _compute_flux2_mu(image_seq_len: int, num_steps: int) -> float: """Compute empirical mu for Flux2 scheduler (matches ComfyUI exactly). This resolution-dependent mu calculation is critical for Flux2 quality. """ a1, b1 = 8.73809524e-05, 1.89833333 a2, b2 = 0.00016927, 0.45666666 if image_seq_len > 4300: return a2 * image_seq_len + b2 m_200 = a2 * image_seq_len + b2 m_10 = a1 * image_seq_len + b1 a = (m_200 - m_10) / 190.0 b = m_200 - 200.0 * a return a * num_steps + b def _flux2_time_shift(t: torch.Tensor, mu: float, sigma: float = 1.0) -> torch.Tensor: """Generalized time SNR shift for Flux2 (matches ComfyUI exactly).""" import math return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) def flux2_scheduler(steps: int, width: int, height: int) -> torch.FloatTensor: """Create Flux2 noise scheduler (matches ComfyUI Flux2Scheduler exactly). This scheduler dynamically computes mu based on image resolution and steps, which is critical for Flux2 image quality. Args: steps: Number of sampling steps width: Image width in pixels height: Image height in pixels Returns: Sigmas tensor of shape (steps + 1,) ending with 0 """ # Calculate sequence length (number of 16x16 patches) seq_len = round((width * height) / (16 * 16)) # Compute resolution/steps-dependent mu mu = _compute_flux2_mu(seq_len, steps) # Create timesteps from 1 to 0 (inclusive) timesteps = torch.linspace(1, 0, steps + 1) # Apply time shift - avoid division by zero at t=0 sigmas = torch.zeros_like(timesteps) mask = timesteps > 0 sigmas[mask] = _flux2_time_shift(timesteps[mask], mu) sigmas[~mask] = 0.0 # t=0 maps to sigma=0 return sigmas.cpu().float() def calculate_sigmas(model_sampling, scheduler_name: str, steps: int, width: int = None, height: int = None, is_flux2: bool = False) -> torch.Tensor: """Calculate sigmas for scheduler. For Flux2 models, use the resolution-aware Flux2Scheduler when width/height are provided. This matches ComfyUI's behavior and is critical for image quality. """ # Robust Flux2 detection if flag not set if not is_flux2 and model_sampling: cls_name = model_sampling.__class__.__name__ if "ModelSamplingFlux2" in cls_name: is_flux2 = True # Handle mock objects in tests if not isinstance(steps, int): try: steps = int(steps) except Exception: steps = 20 # For Flux2 with resolution info, use the dedicated Flux2 scheduler (matches ComfyUI) if is_flux2 and width is not None and height is not None: return flux2_scheduler(steps, width, height) if scheduler_name == "karras": return sampling_util.get_sigmas_karras(steps, float(model_sampling.sigma_min), float(model_sampling.sigma_max)) elif scheduler_name == "normal": return normal_scheduler(model_sampling, steps) elif scheduler_name == "simple": return simple_scheduler(model_sampling, steps) elif scheduler_name == "beta": return beta_scheduler(model_sampling, steps) elif scheduler_name in ["ays", "ays_sd15", "ays_sdxl"]: from src.sample import ays_scheduler as ays model_type = {"ays_sdxl": "SDXL", "ays_sd15": "SD15"}.get(scheduler_name) if not model_type: try: # Robust detection based on class name or config flags cls_name = model_sampling.__class__.__name__.lower() if "flux" in cls_name: model_type = "FLUX" else: config = getattr(model_sampling, 'model_config', None) if config and getattr(config, 'is_flux', False): model_type = "FLUX" elif config and getattr(config, 'uses_dual_clip', False): model_type = "SDXL" else: # Fallback to context_dim check unet_config = getattr(config, 'unet_config', {}) model_type = "SDXL" if unet_config.get('context_dim', 0) == 2048 else "SD15" except: model_type = "SD15" return ays.ays_scheduler(model_sampling, steps, model_type) logging.error(f"Invalid scheduler: {scheduler_name}") return None def prepare_noise(latent_image: torch.Tensor, seed: int, noise_inds: list = None, seeds_per_sample: list | None = None) -> torch.Tensor: """Prepare noise for latent image. NOTE: Noise is generated on CPU for reproducibility across devices (matching ComfyUI behavior). Using a GPU generator produces different random values than CPU even with the same seed. """ target_device = latent_image.device if seeds_per_sample is not None: sps = np.array(seeds_per_sample) if sps.shape[0] != latent_image.size(0): raise ValueError("seeds_per_sample length must match latent batch size") unique_seeds, inverse = np.unique(sps, return_inverse=True) noises = [] for us in unique_seeds: g = torch.Generator(device="cpu") g.manual_seed(int(us)) # Generate on CPU for reproducibility, then move to target device noises.append(torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=g, device="cpu").to(target_device)) return torch.cat([noises[i] for i in inverse], axis=0) generator = torch.Generator(device="cpu") generator.manual_seed(seed) if noise_inds is None: # Generate on CPU for reproducibility (matches ComfyUI), then move to target device return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu").to(target_device) unique_inds, inverse = np.unique(noise_inds, return_inverse=True) noises = [] for i in range(unique_inds[-1] + 1): noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu").to(target_device) if i in unique_inds: noises.append(noise) return torch.cat([noises[i] for i in inverse], axis=0)