| import math |
| import torch |
| from tqdm.auto import trange |
| from lib_es.compat import get_ancestral_step, to_d |
| from lib_es.utils import default_noise_sampler |
|
|
| import lib_es.const as consts |
| from lib_es.utils import sampler_metadata |
|
|
|
|
def _adaptive_phase_weights(progress: float, euler_end: float, dpm_end: float) -> tuple[float, float, float]:
    """Return the ``(euler, multistep, detail)`` blend weights for one step.

    * ``progress < euler_end``: pure Euler, ``(1, 0, 0)``.
    * ``euler_end <= progress < dpm_end``: the Euler weight falls linearly,
      reaching 0 after 40% of this phase; the multistep weight is the rest.
    * ``progress >= dpm_end``: the multistep weight falls linearly, reaching 0
      after two thirds of this phase; the detail weight is the rest.

    The three weights always sum to 1.
    """
    if progress < euler_end:
        return 1.0, 0.0, 0.0
    if progress < dpm_end:
        # Reached only when dpm_end > euler_end, so the divisor is non-zero.
        phase_progress = (progress - euler_end) / (dpm_end - euler_end)
        w_euler = max(0.0, 1.0 - phase_progress * 2.5)
        return w_euler, 1.0 - w_euler, 0.0
    # Reached only when dpm_end <= progress < 1 (progress = i / steps with
    # i < steps), so 1.0 - dpm_end is non-zero.
    phase_progress = (progress - dpm_end) / (1.0 - dpm_end)
    w_multi = max(0.0, 1.0 - phase_progress * 1.5)
    return 0.0, w_multi, 1.0 - w_multi


@sampler_metadata(
    "Adaptive Progressive",
    {"scheduler": "sgm_uniform", "uses_ensd": True},
)
@torch.no_grad()
def sample_adaptive_progressive(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
    noise_sampler=None,
):
    """
    Adaptive progressive sampler that blends three update rules by run progress.

    Each step moves ``x`` along a weighted blend of:

    1. Euler ancestral: the current derivative ``d``.
    2. Multistep (the "DPM++ 2M" phase): ``c1 * d + c2 * prev_d`` with fixed
       coefficients (0.7/0.3 while ``sigma_hat > 2``, otherwise 0.6/0.4).
    3. Detail: while ``sigma_hat < 1``, ``d`` plus a scaled difference between
       this step's and the previous step's denoised predictions.

    The phase boundaries come from ``calc_phase_bounds`` and therefore depend
    on the step count. The ancestral eta is held constant for the first half
    of the run and then ramps linearly to 0. Injected noise is damped
    exponentially after 30% of the run.

    Tuning values are read from ``model.p`` when present, using attribute
    names from ``lib_es.const``. Missing values (or a missing ``model.p``)
    fall back to these defaults:

    * Euler phase end: 0.35
    * DPM++ phase end: 0.75
    * ancestral eta: 0.4
    * detail strength: 1.5

    Args:
        model: Denoiser, called as ``model(x, sigma * s_in, **extra_args)``.
        x: Initial noisy tensor.
        sigmas: Noise schedule; its length is the number of steps plus one.
        extra_args: Extra keyword arguments forwarded to ``model``.
        callback: Optional callable. It receives a dict with keys ``x``,
            ``i``, ``sigma``, ``sigma_hat`` and ``denoised`` after each model
            call.
        disable: Forwarded to ``trange`` to disable the progress bar.
        s_churn: Amount of extra noise. It is only applied while progress
            < 0.4 and ``s_tmin <= sigma <= s_tmax``, and it fades linearly to
            zero at 0.4.
        s_tmin: Minimum sigma at which churn is applied.
        s_tmax: Maximum sigma at which churn is applied.
        s_noise: Scale for churn noise and for ancestral noise.
        noise_sampler: Callable ``(sigma, sigma_next) -> noise``. Defaults to
            ``default_noise_sampler(x)``.

    Returns:
        The sampled tensor. When the schedule ends at sigma 0, this is the
        final step's denoised prediction.
    """
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    steps = len(sigmas) - 1

    # getattr(None, name, default) returns the default, so a model without a
    # ``p`` attribute simply uses the built-in defaults.
    params = getattr(model, "p", None)
    euler_a_end = getattr(params, consts.AP_EULER_A_END, 0.35)
    dpm_2m_end = getattr(params, consts.AP_DPM_2M_END, 0.75)
    ancestral_eta = getattr(params, consts.AP_ANCESTRAL_ETA, 0.4)
    detail_strength = getattr(params, consts.AP_DETAIL_STRENGTH, 1.5)

    prev_d = None
    prev_denoised = None

    euler_end, dpm_end = calc_phase_bounds(steps, euler_a_end, dpm_2m_end)

    for i in trange(steps, disable=disable):
        progress = i / steps
        w_euler, w_multi, w_detail = _adaptive_phase_weights(progress, euler_end, dpm_end)

        # Optional churn: raise the noise level before the model call, only
        # early in the run, with gamma fading linearly to 0 at progress 0.4.
        if s_churn > 0 and s_tmin <= sigmas[i] <= s_tmax and progress < 0.4:
            gamma = min(s_churn / steps, 2**0.5 - 1) * (1.0 - progress / 0.4)
            sigma_hat = sigmas[i] * (gamma + 1)
            eps = torch.randn_like(x) * s_noise
            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2).sqrt()
        else:
            sigma_hat = sigmas[i]

        denoised = model(x, sigma_hat * s_in, **extra_args)

        if callback is not None:
            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised})

        # Ancestral eta: constant for the first half, then a linear ramp to 0.
        if progress < 0.5:
            step_eta = ancestral_eta
        else:
            step_eta = ancestral_eta * (1.0 - min(1.0, (progress - 0.5) * 2.0))
        sigma_down, sigma_up = get_ancestral_step(sigma_hat, sigmas[i + 1], eta=step_eta)

        d = to_d(x, sigma_hat, denoised)
        dt = sigma_down - sigma_hat

        # Last step lands on sigma 0: return the denoised prediction directly.
        if sigmas[i + 1] == 0:
            x = denoised
            break

        if prev_d is None:
            # First step: no history yet, so take a plain Euler step.
            direction = d
        else:
            # Fresh buffer so the in-place += below never aliases d or prev_d.
            direction = torch.zeros_like(d)

            if w_euler > 0:
                direction += w_euler * d

            if w_multi > 0:
                c1, c2 = (0.7, 0.3) if sigma_hat > 2.0 else (0.6, 0.4)
                direction += w_multi * (c1 * d + c2 * prev_d)

            if w_detail > 0 and prev_denoised is not None:
                # dt == 0 happens when consecutive sigmas are equal. Dividing
                # by it would produce inf/NaN; NaN survives both the clamp and
                # the `* dt` below and would corrupt x. In that case fall back
                # to plain d; the step itself is a no-op anyway.
                if sigma_hat < 1.0 and dt != 0:
                    detail_vector = denoised - prev_denoised
                    detail_scale = detail_strength * min(1.0, 0.2 / (sigma_hat + 0.2))
                    # Divided by dt so that the later `direction * dt` adds the
                    # detail vector to x at exactly detail_scale.
                    direction += w_detail * (d + detail_vector * detail_scale / dt)
                else:
                    direction += w_detail * d

        # Guard against extreme update directions.
        direction = torch.clamp(direction, -1e2, 1e2)
        x = x + direction * dt

        if sigma_up > 0:
            noise_scale = s_noise
            if progress > 0.3:
                # Exponentially damp injected noise later in the run.
                noise_scale *= math.exp(-4.0 * (progress - 0.3))
            # NOTE(review): passes sigma_hat rather than sigmas[i] as the
            # interval start; confirm this is intended for interval-dependent
            # noise samplers.
            x = x + noise_sampler(sigma_hat, sigmas[i + 1]) * sigma_up * noise_scale

        prev_d = d
        prev_denoised = denoised

    return x
|
|
|
|
def calc_phase_bounds(steps: int, custom_euler_end: float = 0.25, custom_dpm_end: float = 0.7) -> tuple[float, float]:
    """
    Calculate phase boundaries for the adaptive progressive sampler.

    Both boundaries are fractions of the run (0.0-1.0). They are processed in
    this order:

    1. Clamped to [0, 1].
    2. Forced into order (Euler end before DPM++ end).
    3. Capped for low step counts, or grown slightly for high step counts.
    4. Separated by at least 0.1 where the 1.0 ceiling allows.

    Args:
        steps: Total number of steps.
        custom_euler_end: Requested end point of the Euler phase (0.0-1.0).
        custom_dpm_end: Requested end point of the DPM++ phase (0.0-1.0).

    Returns:
        Tuple of phase boundaries ``(euler_end, dpm_end)``.
    """
    euler_end = max(0.0, min(1.0, custom_euler_end))
    dpm_end = max(0.0, min(1.0, custom_dpm_end))

    # The Euler phase must end before the DPM++ phase does.
    if euler_end >= dpm_end:
        euler_end = max(0.0, dpm_end - 0.2)

    if steps < 10:
        # Very few steps: shorten the early phases. The caps relax slightly
        # for every step above 4.
        euler_end = min(euler_end, 0.15 + (steps - 4) * 0.01)
        dpm_end = min(dpm_end, 0.5 + (steps - 4) * 0.02)
    elif steps < 20:
        euler_end = min(euler_end, 0.2)
        dpm_end = min(dpm_end, 0.65)
    elif steps > 50:
        # Many steps: grow both boundaries slowly toward their caps.
        # The outer max() ensures the growth never *shrinks* a boundary
        # that is already above its cap. Without it, the sampler's
        # defaults (0.35 / 0.75) would jump from 0.35 at 50 steps to 0.3
        # at 51 steps.
        growth = (steps - 50) * 0.0005
        euler_end = max(euler_end, min(0.3, euler_end + growth))
        dpm_end = max(dpm_end, min(0.8, dpm_end + growth))

    # Keep the DPM++ phase at least 0.1 wide (limited by the 1.0 ceiling).
    if dpm_end - euler_end < 0.1:
        dpm_end = min(1.0, euler_end + 0.1)

    return euler_end, dpm_end
|
|