| import torch |
| from ..extrapolation import extrapolate_epsilon_linear, extrapolate_epsilon_richardson, extrapolate_epsilon_h4 |
| from ...comfy_copy.res4lyf_sampling import get_res4lyf_step_with_model |
| from ..skip import should_skip_model_call, validate_epsilon_hat, decide_skip_adaptive |
| from ..log import print_step_diag |
|
|
|
|
| from ..noise import get_eps_step_official |
|
|
|
|
def _euler_as_float(value):
    """Convert a 1-element tensor or a plain number to a Python float (used for logging)."""
    if torch.is_tensor(value):
        return float(value.item())
    return float(value)


def _euler_bump(stats, key):
    """Increment the counter ``stats[key]``, treating a missing key as 0."""
    stats[key] = stats.get(key, 0) + 1


def _euler_rms(tensor):
    """Best-effort RMS of ``tensor`` for diagnostics; ``None`` if it cannot be computed."""
    try:
        return float(torch.sqrt(torch.mean(tensor ** 2)).item())
    except Exception:
        return None


def _euler_noise(like, add_noise_type):
    """Draw ``randn_like(like)``; for ``"whitened"`` renormalise it to zero mean / unit std
    computed over the whole tensor."""
    noise = torch.randn_like(like)
    if add_noise_type == "whitened":
        noise = (noise - noise.mean()) / (noise.std() + 1e-12)
    return noise


def _euler_skip_diag(step_index, sigma_current, sigma_next, hat_norm, prev_eps, denoised, skip_method):
    """Emit the ``print_step_diag`` record shared by the explicit and policy-driven skip paths."""
    print_step_diag(
        sampler="euler",
        step_index=step_index,
        sigma_current=sigma_current,
        sigma_next=sigma_next,
        target_sigma=sigma_next,
        sigma_up=None,
        alpha_ratio=None,
        h=None,
        c2=None,
        b1=None,
        b2=None,
        eps_norm=hat_norm,
        eps_prev_norm=float(torch.norm(prev_eps).item()) if prev_eps is not None else None,
        x_rms=_euler_rms(denoised),
        flags=f"SKIPPED-{skip_method}",
    )


def sample_step_euler(model, noisy_latent, sigma_current, sigma_next, s_in, extra_args,
                      epsilon_history, learning_ratio, smoothing_beta, predictor_type,
                      step_index, total_steps, add_noise_ratio=0.0, add_noise_type="whitened",
                      skip_mode="none", skip_stats=None, debug=False, protect_last_steps=4,
                      protect_first_steps=2, anchor_interval=None, max_consecutive_skips=None,
                      official_comfy=False, adaptive_mode="none", explicit_skip_indices=None,
                      explicit_predictor=None):
    """Advance one Euler step, optionally replacing the model call by an extrapolated epsilon.

    A step either calls ``model`` at ``sigma_current`` or, when a skip is requested
    and accepted, uses ``denoised = x + epsilon_hat`` where ``epsilon_hat`` is
    extrapolated from ``epsilon_history``. A skip can be requested by:

    * ``explicit_skip_indices`` (a ``set`` of step indices). This path is gated by
      ``skip_stats["explicit_streak"]`` / ``skip_stats["needed_learns"]`` and needs at
      least two history entries;
    * ``skip_mode == "adaptive"``, via ``decide_skip_adaptive``;
    * any other ``skip_mode``, via ``should_skip_model_call``.

    Every extrapolated epsilon is checked with ``validate_epsilon_hat``. A rejected
    extrapolation falls back to a real model call.

    Args:
        model: denoiser, called as ``model(x, sigma_current * s_in, **extra_args)``;
            its result is treated as the denoised prediction.
        noisy_latent: latent ``x`` at ``sigma_current``.
        sigma_current, sigma_next: noise levels of this step (tensor or float).
        s_in: factor multiplied with ``sigma_current`` for the model call
            (presumably a per-batch ones tensor -- confirm with caller).
        extra_args: keyword arguments forwarded to ``model``.
        epsilon_history: list of past ``denoised - x`` values. It is appended to
            in place after every real model call.
        learning_ratio: EMA of ``||epsilon_hat|| / ||epsilon||``, clamped to
            [0.5, 2.0]. Extrapolated epsilons are divided by it.
        smoothing_beta: EMA factor for updating ``learning_ratio``.
        predictor_type: "h4", "richardson", or anything else for linear. Used
            only when updating ``learning_ratio``.
        step_index, total_steps: position in the schedule.
        add_noise_ratio: eta for ancestral noise. When > 0, noise is injected on
            non-skipped steps with ``sigma_next > 0``.
        add_noise_type: "whitened" renormalises the sampled noise; other values
            use raw ``randn``.
        skip_mode: "adaptive" or a mode understood by ``should_skip_model_call``.
        skip_stats: optional dict of counters/state shared across steps;
            mutated in place.
        debug: print per-step diagnostics.
        protect_last_steps, protect_first_steps, anchor_interval,
        max_consecutive_skips: forwarded to the skip deciders.
        official_comfy: use ``get_eps_step_official`` for the ancestral split
            instead of ``get_res4lyf_step_with_model``.
        adaptive_mode: "learning" / "learn+grad_est" rescale policy-driven skips
            by ``learning_ratio``. "grad_est" / "learn+grad_est" add a clamped
            derivative correction on skipped steps.
        explicit_skip_indices: set of step indices to skip explicitly.
        explicit_predictor: "h4", "richardson"/"h3", or linear (default) for
            explicit skips. Lower orders are used when the history is too short.

    Returns:
        ``(x_next, learning_ratio)`` -- the updated latent and the possibly
        updated learning ratio.
    """
    x = noisy_latent
    was_skipped = False
    epsilon = None  # this step's epsilon: extrapolated on skips, measured otherwise
    meta = None     # diagnostics returned by decide_skip_adaptive, if it is used

    if skip_stats is not None:
        _euler_bump(skip_stats, "total_steps")

    # ---- 1. explicit skip requested for this step index -----------------------
    if (explicit_skip_indices is not None
            and isinstance(explicit_skip_indices, set)
            and step_index in explicit_skip_indices):
        es = skip_stats.get("explicit_streak", False) if skip_stats is not None else False
        nl = skip_stats.get("needed_learns", 2) if skip_stats is not None else 2
        # Allowed while an explicit streak is running, or once enough real model
        # calls ("learns") have happened since the last streak ended.
        allowed_by_streak = es or (nl <= 0)
        if allowed_by_streak and len(epsilon_history) >= 2:
            pred = explicit_predictor or "linear"
            # Fall back to lower-order predictors when the history is too short.
            if pred == "h4" and len(epsilon_history) >= 4:
                epsilon = extrapolate_epsilon_h4(epsilon_history)
                skip_method = "explicit-h4"
            elif pred in ("richardson", "h3") and len(epsilon_history) >= 3:
                epsilon = extrapolate_epsilon_richardson(epsilon_history)
                skip_method = "explicit-h3"
            else:
                epsilon = extrapolate_epsilon_linear(epsilon_history)
                skip_method = "explicit-h2"
            prev_eps = epsilon_history[-1]
            ok, reason, hat_norm, _prev_norm = validate_epsilon_hat(epsilon, prev_eps)
            if ok:
                if len(epsilon_history) >= 3:
                    epsilon = epsilon / max(learning_ratio, 1e-8)
                denoised = x + epsilon
                was_skipped = True
                if skip_stats is not None:
                    _euler_bump(skip_stats, "skipped")
                    _euler_bump(skip_stats, "consecutive_skips")
                    skip_stats["explicit_streak"] = True
                    skip_stats["needed_learns"] = 0
                if debug:
                    dt_val = _euler_as_float(sigma_next - sigma_current)
                    print(f"euler step {step_index} [SKIPPED-{skip_method}]: e_norm={hat_norm:.2f}, L={learning_ratio:.4f}, dt={dt_val:.4f}")
                    _euler_skip_diag(step_index, sigma_current, sigma_next, hat_norm,
                                     prev_eps, denoised, skip_method)
            elif debug:
                print(f"euler step {step_index}: skip rejected by validate_epsilon_hat (reason={reason})")
        elif debug:
            reason = "need_two_learns_before_skip" if not allowed_by_streak else "insufficient_history"
            print(f"euler step {step_index}: explicit skip gated ({reason})")

    # ---- 2. ask the configured skip policy -------------------------------------
    if was_skipped:
        # Already skipped explicitly. Keep the extrapolated epsilon so the debug
        # diagnostics below report it rather than nan.
        should_skip, skip_method = False, None
    elif skip_mode == "adaptive":
        should_skip, epsilon, meta = decide_skip_adaptive(
            epsilon_history=epsilon_history,
            step_index=step_index,
            total_steps=total_steps,
            protect_last_steps=protect_last_steps,
            protect_first_steps=protect_first_steps,
            skip_stats=skip_stats,
            x_current=x,
            sigma_current=sigma_current,
            sigma_next=sigma_next,
            sampler_kind="euler",
            anchor_interval=anchor_interval,
            max_consecutive_skips=max_consecutive_skips,
        )
        skip_method = "adaptive"
    else:
        should_skip, skip_method = should_skip_model_call(
            1.0, step_index, total_steps, skip_mode, epsilon_history,
            protect_last_steps, protect_first_steps,
        )
        epsilon = None

    # ---- 3. try to apply a policy-driven skip ------------------------------------
    if not was_skipped and should_skip and skip_method is not None:
        if epsilon is None:
            if skip_method == "richardson":
                epsilon = extrapolate_epsilon_richardson(epsilon_history)
            elif skip_method == "h4":
                epsilon = extrapolate_epsilon_h4(epsilon_history)
            else:
                epsilon = extrapolate_epsilon_linear(epsilon_history)
        prev_eps = epsilon_history[-1] if len(epsilon_history) >= 1 else None
        ok, reason, hat_norm, prev_norm = validate_epsilon_hat(epsilon, prev_eps)
        if not ok:
            should_skip = False
            if debug:
                print(f"euler step {step_index}: skip cancelled (ε̂ invalid: {reason}) hat_norm={hat_norm:.2e}, prev_norm={(prev_norm if prev_norm is not None else float('nan')):.2e}")
        else:
            if len(epsilon_history) >= 3 and adaptive_mode in ("learning", "learn+grad_est"):
                epsilon = epsilon / max(learning_ratio, 1e-8)
            denoised = x + epsilon
            was_skipped = True
            if skip_stats is not None:
                _euler_bump(skip_stats, "skipped")
                _euler_bump(skip_stats, "consecutive_skips")
            if debug:
                dt_val = _euler_as_float(sigma_next - sigma_current)
                if skip_mode == "adaptive":
                    rel = meta.get("relative_error") if isinstance(meta, dict) else None
                    print(f"euler step {step_index} [SKIPPED-adaptive]: err_rel={(rel if rel is not None else float('nan')):.4f}, L={learning_ratio:.4f}, dt={dt_val:.4f}")
                else:
                    print(f"euler step {step_index} [SKIPPED-{skip_method}]: e_norm={hat_norm:.2f}, L={learning_ratio:.4f}, dt={dt_val:.4f}")
                _euler_skip_diag(step_index, sigma_current, sigma_next, hat_norm,
                                 prev_eps, denoised, skip_method)

    # ---- 4. real model call ------------------------------------------------------
    # Gate only on ``was_skipped``. Previously a policy returning should_skip=True
    # with skip_method=None reached neither branch and left ``denoised`` unbound.
    if not was_skipped:
        denoised = model(x, sigma_current * s_in, **extra_args)
        if skip_stats is not None:
            _euler_bump(skip_stats, "model_calls")
            skip_stats["consecutive_skips"] = 0
            skip_stats["last_anchor_step"] = step_index
            if skip_stats.get("explicit_streak", False):
                # A real call ends the explicit streak. One more learn is then
                # needed before the next explicit skip is allowed.
                skip_stats["explicit_streak"] = False
                skip_stats["needed_learns"] = 1
            else:
                nl = skip_stats.get("needed_learns", 2)
                skip_stats["needed_learns"] = max(0, int(nl) - 1)

    # ---- 5. derivative and optional gradient-estimate correction -----------------
    d = (x - denoised) / sigma_current

    d_prev = skip_stats.get("d_prev") if isinstance(skip_stats, dict) else None
    dbar = 0.0
    if was_skipped and adaptive_mode in ("grad_est", "learn+grad_est") and d_prev is not None:
        # Correction from the change in derivative since the last real step.
        dbar = d - d_prev
        try:
            ratio = float(torch.norm(dbar) / (torch.norm(d) + 1e-8))
        except Exception:
            ratio = 0.0
        # Limit the correction to 25% of ||d||.
        if ratio > 0.25:
            dbar = dbar * (0.25 / ratio)

    # ---- 6. position update (ancestral on real steps when requested) ------------
    sigma_up = None
    alpha_ratio = None
    sigma_down = None
    if add_noise_ratio > 0.0 and float(sigma_next) > 0.0 and not was_skipped:
        if official_comfy:
            sigma_up, sigma_down = get_eps_step_official(sigma_current, sigma_next, eta=add_noise_ratio)
            dt = sigma_down - sigma_current
            x = x + d * dt
            x = x + _euler_noise(x, add_noise_type) * sigma_up
        else:
            sigma_up, _sigma_for_calc, sigma_down, alpha_ratio = get_res4lyf_step_with_model(
                model, sigma_current, sigma_next, add_noise_ratio, "hard"
            )
            dt = sigma_down - sigma_current
            x = x + d * dt
            x = alpha_ratio * x + _euler_noise(x, add_noise_type) * sigma_up
    else:
        dt = sigma_next - sigma_current
        x = x + (d + (dbar if was_skipped else 0.0)) * dt

    # ---- 7. record measured epsilon and update the learning ratio ----------------
    if not was_skipped:
        epsilon = denoised - noisy_latent
        epsilon_history.append(epsilon)
        if len(epsilon_history) >= 3:
            if predictor_type == "h4":
                epsilon_hat = extrapolate_epsilon_h4(epsilon_history)
            elif predictor_type == "richardson":
                epsilon_hat = extrapolate_epsilon_richardson(epsilon_history)
            else:
                epsilon_hat = extrapolate_epsilon_linear(epsilon_history)
            if epsilon_hat is not None:
                learn_obs = (torch.norm(epsilon_hat) / (torch.norm(epsilon) + 1e-8)).item()
                learning_ratio = smoothing_beta * learning_ratio + (1.0 - smoothing_beta) * learn_obs
                if learning_ratio < 0.5:
                    learning_ratio = 0.5
                elif learning_ratio > 2.0:
                    learning_ratio = 2.0

        # Remember this real derivative for the next gradient-estimate correction.
        if skip_stats is not None:
            try:
                skip_stats["d_prev"] = d.detach()
            except AttributeError:
                skip_stats["d_prev"] = d

    # ---- 8. diagnostics ---------------------------------------------------------
    if debug:
        d_norm = torch.norm(d).item()
        try:
            e_norm = float(torch.norm(epsilon).item()) if isinstance(epsilon, torch.Tensor) else None
        except Exception:
            e_norm = None
        if not was_skipped:
            # _euler_as_float also handles plain-float sigmas (dt.item() crashed there).
            dt_print = _euler_as_float(dt)
            if len(epsilon_history) >= 3:
                print(f"euler step {step_index}: e_norm={(e_norm if e_norm is not None else float('nan')):.2f}, d_norm={d_norm:.2f}, dt={dt_print:.4f}, L={learning_ratio:.4f}, beta={smoothing_beta}")
            else:
                print(f"euler step {step_index}: e_norm={(e_norm if e_norm is not None else float('nan')):.2f}, d_norm={d_norm:.2f}, dt={dt_print:.4f}")
        else:
            try:
                dbar_norm = float(torch.norm(dbar).item()) if isinstance(dbar, torch.Tensor) else float(dbar)
            except Exception:
                dbar_norm = 0.0
            print(f"euler step {step_index} [SKIP-APPLY]: d_norm={d_norm:.2f}, dbar_norm={dbar_norm:.2f}, mode={adaptive_mode}")

        target_sigma_print = sigma_down if sigma_down is not None else sigma_next
        try:
            h_val = -torch.log(target_sigma_print / sigma_current)
        except Exception:
            h_val = None
            target_sigma_print = sigma_next
        print_step_diag(
            sampler="euler",
            step_index=step_index,
            sigma_current=sigma_current,
            sigma_next=sigma_next,
            target_sigma=target_sigma_print,
            sigma_up=sigma_up,
            alpha_ratio=alpha_ratio,
            h=h_val,
            c2=None,
            b1=None,
            b2=None,
            eps_norm=e_norm,
            eps_prev_norm=float(torch.norm(epsilon_history[-2]).item()) if len(epsilon_history) >= 2 else None,
            x_rms=_euler_rms(x),
            flags="",
        )

    return x, learning_ratio
|
|