from .engine import sample_fsampler def sample_step_euler(model, noisy_latent, sigma_current, sigma_next, s_in, extra_args, epsilon_history, learning_ratio, smoothing_beta, predictor_type, step_index, total_steps, skip_mode="none", skip_stats=None, debug=False): """Standard Euler step using Karras ODE derivative formulation. Implements the standard k-diffusion Euler method: - Converts denoised to ODE derivative: d = (x - denoised) / sigma - Takes Euler step: x = x + d * dt, where dt = sigma_next - sigma_current Supports model call skipping via epsilon extrapolation. """ x = noisy_latent # Update skip statistics if skip_stats is not None: skip_stats["total_steps"] += 1 # Check if we should skip the model call should_skip, skip_method = should_skip_model_call( 1.0, # error_ratio - Euler doesn't track this, use neutral value step_index, total_steps, skip_mode, epsilon_history ) # Get denoised: either from model call or extrapolation was_skipped = False if should_skip and skip_method is not None: # SKIP: Use epsilon extrapolation if skip_method == "linear": epsilon = extrapolate_epsilon_linear(epsilon_history) elif skip_method == "richardson": epsilon = extrapolate_epsilon_richardson(epsilon_history) else: epsilon = None # Safety check: if extrapolation failed, fall back to model call if epsilon is None or torch.isnan(epsilon).any(): should_skip = False if debug: print(f"euler step {step_index}: extrapolation failed, falling back to model call") if should_skip and epsilon is not None: # Successful skip - reconstruct denoised from extrapolated epsilon # Apply universal learning stabilizer if we have enough REAL history (>=3) if len(epsilon_history) >= 3: epsilon = epsilon / max(learning_ratio, 1e-8) denoised = x + epsilon was_skipped = True if skip_stats is not None: skip_stats["skipped"] += 1 if debug: e_norm = torch.norm(epsilon).item() dt_val = (sigma_next - sigma_current).item() if torch.is_tensor(sigma_next) else float(sigma_next - sigma_current) print(f"euler step {step_index} [SKIPPED-{skip_method}]: e_norm={e_norm:.2f}, L={learning_ratio:.4f}, dt={dt_val:.4f}") if not should_skip: # CALL MODEL: Normal path denoised = model(x, sigma_current * s_in, **extra_args) if skip_stats is not None: skip_stats["model_calls"] += 1 # Karras ODE derivative: d = (x - denoised) / sigma # This is the standard k-diffusion formulation d = (x - denoised) / sigma_current # Euler step in sigma space: x = x + d * dt dt = sigma_next - sigma_current x = x + d * dt # Store REAL epsilon for extrapolation/learning (append full history for this run) if not was_skipped: epsilon = denoised - noisy_latent epsilon_history.append(epsilon) # Universal learning update only when enough REAL history exists (>=3) if len(epsilon_history) >= 3: # Compute predictor-matched epsilon_hat from REAL history if predictor_type == "richardson": epsilon_hat = extrapolate_epsilon_richardson(epsilon_history) else: epsilon_hat = extrapolate_epsilon_linear(epsilon_history) if epsilon_hat is not None: learn_obs = (torch.norm(epsilon_hat) / (torch.norm(epsilon) + 1e-8)).item() # EMA update with smoothing_beta and clamp learning_ratio = smoothing_beta * learning_ratio + (1.0 - smoothing_beta) * learn_obs # clamps (hidden constants) if learning_ratio < 0.5: learning_ratio = 0.5 elif learning_ratio > 2.0: learning_ratio = 2.0 if debug and not was_skipped: d_norm = torch.norm(d).item() e_norm = torch.norm(epsilon).item() if len(epsilon_history) >= 3: print(f"euler step {step_index}: e_norm={e_norm:.2f}, d_norm={d_norm:.2f}, dt={dt.item():.4f}, L={learning_ratio:.4f}, beta={smoothing_beta}") else: print(f"euler step {step_index}: e_norm={e_norm:.2f}, d_norm={d_norm:.2f}, dt={dt.item():.4f}") return x, learning_ratio def sample_step_ddim(model, noisy_latent, sigma_current, sigma_next, s_in, extra_args, epsilon_history, learning_ratio, smoothing_beta, predictor_type, step_index, total_steps, skip_mode="none", debug=False): """DDIM deterministic step (eta=0) with optional skipping. Formula: x_next = x0 + (sigma_next / sigma_current) * (x - x0), where x0 = denoised. On skips, use extrapolated epsilon_hat to form x0_hat = x + epsilon_hat_scaled. """ x = noisy_latent # Decide skip should_skip, skip_method = should_skip_model_call( 1.0, step_index, total_steps, skip_mode, epsilon_history ) was_skipped = False if should_skip and skip_method is not None: # Predictor from REAL history if skip_method == "richardson": epsilon_hat = extrapolate_epsilon_richardson(epsilon_history) else: epsilon_hat = extrapolate_epsilon_linear(epsilon_history) if epsilon_hat is None or torch.isnan(epsilon_hat).any(): should_skip = False else: if len(epsilon_history) >= 3: epsilon_hat = epsilon_hat / max(learning_ratio, 1e-8) x0_hat = x + epsilon_hat scale = (sigma_next / sigma_current) x = x0_hat + scale * (x - x0_hat) was_skipped = True if debug: e_norm = torch.norm(epsilon_hat).item() print(f"ddim step {step_index} [SKIPPED-{skip_method}]: e_norm={e_norm:.2f}, L={learning_ratio:.4f}") if not should_skip: # REAL call denoised = model(x, sigma_current * s_in, **extra_args) # Update: x_next = denoised + (sigma_next / sigma_current) * (x - denoised) scale = (sigma_next / sigma_current) x = denoised + scale * (x - denoised) # Learning update (append REAL epsilon and update L if ≥3 REAL eps) epsilon_real = denoised - noisy_latent epsilon_history.append(epsilon_real) if len(epsilon_history) >= 3: if predictor_type == "richardson": epsilon_hat = extrapolate_epsilon_richardson(epsilon_history) else: epsilon_hat = extrapolate_epsilon_linear(epsilon_history) if epsilon_hat is not None: learn_obs = (torch.norm(epsilon_hat) / (torch.norm(epsilon_real) + 1e-8)).item() learning_ratio = smoothing_beta * learning_ratio + (1.0 - smoothing_beta) * learn_obs if learning_ratio < 0.5: learning_ratio = 0.5 elif learning_ratio > 2.0: learning_ratio = 2.0 if debug: print(f"ddim step {step_index} [LEARN]: learn_obs={learn_obs:.4f}, L={learning_ratio:.4f}, beta={smoothing_beta}") return x, learning_ratio def sample_step_dpmpp_2m(model, noisy_latent, sigma_current, sigma_next, sigma_previous, s_in, extra_args, epsilon_history, learning_ratio, smoothing_beta, predictor_type, step_index, total_steps, skip_mode="none", skip_stats=None, debug=False): """DPM++ 2M (second-order multistep) with learning + skip. Update: x_next = x + dt * [ (3/2)·d_n − (1/2)·d_{n−1} ], with d = (x − denoised)/sigma. First step falls back to Euler. On skip, use epsilon_hat (scaled by 1/L) to form d_n. """ x = noisy_latent # Count step if skip_stats is not None: skip_stats["total_steps"] += 1 # Skip decision should_skip, skip_method = should_skip_model_call(1.0, step_index, total_steps, skip_mode, epsilon_history) d_prev = None if sigma_previous is not None and len(epsilon_history) >= 1: eps_prev = epsilon_history[-1] d_prev = -(eps_prev) / sigma_previous if should_skip and skip_method is not None: if skip_method == "richardson": epsilon_hat = extrapolate_epsilon_richardson(epsilon_history) else: epsilon_hat = extrapolate_epsilon_linear(epsilon_history) if epsilon_hat is None or torch.isnan(epsilon_hat).any(): should_skip = False else: if len(epsilon_history) >= 3: epsilon_hat = epsilon_hat / max(learning_ratio, 1e-8) d_curr = -(epsilon_hat) / sigma_current dt = sigma_next - sigma_current if d_prev is not None: x = x + dt * (1.5 * d_curr - 0.5 * d_prev) else: x = x + dt * d_curr if skip_stats is not None: skip_stats["skipped"] += 1 if debug: d_norm = torch.norm(d_curr).item() print(f"dpmpp_2m step {step_index} [SKIPPED-{skip_method}]: d_norm={d_norm:.2f}, L={learning_ratio:.4f}") return x, learning_ratio # REAL call denoised = model(x, sigma_current * s_in, **extra_args) eps_curr = denoised - x d_curr = -eps_curr / sigma_current dt = sigma_next - sigma_current if d_prev is not None: x = x + dt * (1.5 * d_curr - 0.5 * d_prev) else: x = x + dt * d_curr if skip_stats is not None: skip_stats["model_calls"] += 1 # Learning update epsilon_history.append(eps_curr) if len(epsilon_history) >= 3: if predictor_type == "richardson": epsilon_hat = extrapolate_epsilon_richardson(epsilon_history) else: epsilon_hat = extrapolate_epsilon_linear(epsilon_history) if epsilon_hat is not None: learn_obs = (torch.norm(epsilon_hat) / (torch.norm(eps_curr) + 1e-8)).item() learning_ratio = smoothing_beta * learning_ratio + (1.0 - smoothing_beta) * learn_obs if learning_ratio < 0.5: learning_ratio = 0.5 elif learning_ratio > 2.0: learning_ratio = 2.0 if debug: print(f"dpmpp_2m step {step_index} [LEARN]: learn_obs={learn_obs:.4f}, L={learning_ratio:.4f}, beta={smoothing_beta}") if debug and d_prev is not None: d_norm = torch.norm(d_curr).item() print(f"dpmpp_2m step {step_index}: d_norm={d_norm:.2f}, AB2") elif debug: d_norm = torch.norm(d_curr).item() print(f"dpmpp_2m step {step_index}: d_norm={d_norm:.2f}, Euler") return x, learning_ratio def sample_step_dpmpp_2s(model, noisy_latent, sigma_current, sigma_next, s_in, extra_args, epsilon_history, learning_ratio, smoothing_beta, predictor_type, step_index, total_steps, skip_mode="none", debug=False): """DPM++ 2S (two-stage ODE) with learning + skip. Two real evaluations: d1 at (x, sigma_current), predictor x_pred = x + dt*d1 d2 at (x_pred, sigma_next), corrector: x_next = x + dt*0.5*(d1 + d2) On skip, use Euler-like inter-step update with epsilon_hat. """ x = noisy_latent # Final step: avoid division by zero at sigma_next ~ 0; land on denoised sigma_next_value = sigma_next.item() if torch.is_tensor(sigma_next) else float(sigma_next) if abs(sigma_next_value) <= 1e-8: den = model(x, sigma_current * s_in, **extra_args) x = den # Learning update (REAL) eps_real = den - noisy_latent epsilon_history.append(eps_real) if len(epsilon_history) >= 3: if predictor_type == "richardson": epsilon_hat = extrapolate_epsilon_richardson(epsilon_history) else: epsilon_hat = extrapolate_epsilon_linear(epsilon_history) if epsilon_hat is not None: learn_obs = (torch.norm(epsilon_hat) / (torch.norm(eps_real) + 1e-8)).item() learning_ratio = smoothing_beta * learning_ratio + (1.0 - smoothing_beta) * learn_obs if learning_ratio < 0.5: learning_ratio = 0.5 elif learning_ratio > 2.0: learning_ratio = 2.0 if debug: print(f"dpmpp_2s step {step_index} (final) [LEARN]: learn_obs={learn_obs:.4f}, L={learning_ratio:.4f}, beta={smoothing_beta}") if debug: print(f"dpmpp_2s step {step_index} (final step): landing on denoised") return x, learning_ratio # Skip decision should_skip, skip_method = should_skip_model_call(1.0, step_index, total_steps, skip_mode, epsilon_history) if should_skip and skip_method is not None: if skip_method == "richardson": epsilon_hat = extrapolate_epsilon_richardson(epsilon_history) else: epsilon_hat = extrapolate_epsilon_linear(epsilon_history) if epsilon_hat is None or torch.isnan(epsilon_hat).any(): should_skip = False else: if len(epsilon_history) >= 3: epsilon_hat = epsilon_hat / max(learning_ratio, 1e-8) d = -(epsilon_hat) / sigma_current dt = sigma_next - sigma_current x = x + dt * d if debug: e_norm = torch.norm(epsilon_hat).item() print(f"dpmpp_2s step {step_index} [SKIPPED-{skip_method}]: e_norm={e_norm:.2f}, L={learning_ratio:.4f}") return x, learning_ratio # REAL evaluations den1 = model(x, sigma_current * s_in, **extra_args) d1 = (x - den1) / sigma_current dt = sigma_next - sigma_current x_pred = x + dt * d1 den2 = model(x_pred, sigma_next * s_in, **extra_args) d2 = (x_pred - den2) / sigma_next x = x + dt * 0.5 * (d1 + d2) # Learning update from stage-1 epsilon eps_real = den1 - noisy_latent epsilon_history.append(eps_real) if len(epsilon_history) >= 3: if predictor_type == "richardson": epsilon_hat = extrapolate_epsilon_richardson(epsilon_history) else: epsilon_hat = extrapolate_epsilon_linear(epsilon_history) if epsilon_hat is not None: learn_obs = (torch.norm(epsilon_hat) / (torch.norm(eps_real) + 1e-8)).item() learning_ratio = smoothing_beta * learning_ratio + (1.0 - smoothing_beta) * learn_obs if learning_ratio < 0.5: learning_ratio = 0.5 elif learning_ratio > 2.0: learning_ratio = 2.0 if debug: print(f"dpmpp_2s step {step_index} [LEARN]: learn_obs={learn_obs:.4f}, L={learning_ratio:.4f}, beta={smoothing_beta}") if debug: d1n = torch.norm(d1).item(); d2n = torch.norm(d2).item() print(f"dpmpp_2s step {step_index}: d1_norm={d1n:.2f}, d2_norm={d2n:.2f}") return x, learning_ratio def _ab2_update(x, dt, d_curr, d_prev=None): if d_prev is not None: return x + dt * (1.5 * d_curr - 0.5 * d_prev) else: return x + dt * d_curr def sample_step_lms(model, noisy_latent, sigma_current, sigma_next, sigma_previous, s_in, extra_args, epsilon_history, learning_ratio, smoothing_beta, predictor_type, step_index, total_steps, skip_mode="none", skip_stats=None, debug=False): """LMS (AB2 baseline) with learning + skip. d = (x - denoised)/sigma; x_next = x + dt * [ (3/2)·d_n − (1/2)·d_{n−1} ] """ x = noisy_latent if skip_stats is not None: skip_stats["total_steps"] += 1 should_skip, skip_method = should_skip_model_call(1.0, step_index, total_steps, skip_mode, epsilon_history) d_prev = None if sigma_previous is not None and len(epsilon_history) >= 1: d_prev = -(epsilon_history[-1]) / sigma_previous if should_skip and skip_method is not None: epsilon_hat = extrapolate_epsilon_richardson(epsilon_history) if skip_method == "richardson" else extrapolate_epsilon_linear(epsilon_history) if epsilon_hat is None or torch.isnan(epsilon_hat).any(): should_skip = False else: if len(epsilon_history) >= 3: epsilon_hat = epsilon_hat / max(learning_ratio, 1e-8) d_curr = -epsilon_hat / sigma_current dt = sigma_next - sigma_current x = _ab2_update(x, dt, d_curr, d_prev) if skip_stats is not None: skip_stats["skipped"] += 1 if debug: d_norm = torch.norm(d_curr).item() print(f"lms step {step_index} [SKIPPED-{skip_method}]: d_norm={d_norm:.2f}, L={learning_ratio:.4f}") return x, learning_ratio # REAL call den = model(x, sigma_current * s_in, **extra_args) eps = den - x d_curr = -eps / sigma_current dt = sigma_next - sigma_current x = _ab2_update(x, dt, d_curr, d_prev) if skip_stats is not None: skip_stats["model_calls"] += 1 # Learning epsilon_history.append(eps) if len(epsilon_history) >= 3: epsilon_hat = extrapolate_epsilon_richardson(epsilon_history) if predictor_type == "richardson" else extrapolate_epsilon_linear(epsilon_history) if epsilon_hat is not None: learn_obs = (torch.norm(epsilon_hat) / (torch.norm(eps) + 1e-8)).item() learning_ratio = smoothing_beta * learning_ratio + (1.0 - smoothing_beta) * learn_obs if learning_ratio < 0.5: learning_ratio = 0.5 elif learning_ratio > 2.0: learning_ratio = 2.0 if debug: print(f"lms step {step_index} [LEARN]: learn_obs={learn_obs:.4f}, L={learning_ratio:.4f}, beta={smoothing_beta}") if debug: dn = torch.norm(d_curr).item() print(f"lms step {step_index}: d_norm={dn:.2f}{', AB2' if d_prev is not None else ', Euler'}") return x, learning_ratio def sample_step_plms(model, noisy_latent, sigma_current, sigma_next, sigma_previous, s_in, extra_args, epsilon_history, learning_ratio, smoothing_beta, predictor_type, step_index, total_steps, skip_mode="none", skip_stats=None, debug=False): """PLMS (baseline AB2 for now) with learning + skip. Note: For a full PLMS (PNDM) 4-step, we'd need sigma history for prior steps. This baseline uses AB2 until we thread sigma history; still useful and consistent with LMS. """ # For now, mirror LMS AB2 behavior return sample_step_lms(model, noisy_latent, sigma_current, sigma_next, sigma_previous, s_in, extra_args, epsilon_history, learning_ratio, smoothing_beta, predictor_type, step_index, total_steps, skip_mode, skip_stats, debug) # Rebind local names to refactored implementations (ensures imports take precedence) from .samplers.euler import sample_step_euler as _euler_impl from .samplers.res2m import sample_step_res_2m as _res2m_impl from .samplers.res2s import sample_step_res_2s as _res2s_impl from .samplers.ddim import sample_step_ddim as _ddim_impl from .samplers.dpmpp_2m import sample_step_dpmpp_2m as _dpmpp2m_impl from .samplers.dpmpp_2s import sample_step_dpmpp_2s as _dpmpp2s_impl from .samplers.lms import sample_step_lms as _lms_impl sample_step_euler = _euler_impl sample_step_res_2m = _res2m_impl sample_step_res_2s = _res2s_impl sample_step_ddim = _ddim_impl sample_step_dpmpp_2m = _dpmpp2m_impl sample_step_dpmpp_2s = _dpmpp2s_impl sample_step_lms = _lms_impl def sample_step_res_2m(model, noisy_latent, sigma_current, sigma_next, sigma_previous, s_in, extra_args, error_history, epsilon_history, prev_was_skipped, step_index, total_steps, adaptive_mode="none", smoothing_beta=0.9, smoothed_error_ratio=1.0, learning_ratio=1.0, predictor_type="linear", skip_mode="none", skip_stats=None, debug=False): """res_2m: 2-multistep method using history from previous steps. Matches RES4LYF implementation: - Stores denoised predictions in history (not epsilon directly) - Recomputes epsilon from stored denoised each step - Uses c2 = (-h_prev / h) for multistep coefficients """ x_0 = noisy_latent # Starting point for this step # Update skip statistics if skip_stats is not None: skip_stats["total_steps"] += 1 # Check if we should skip the model call should_skip, skip_method = should_skip_model_call( smoothed_error_ratio, step_index, total_steps, skip_mode, epsilon_history ) # Get epsilon: either from model call or extrapolation was_skipped = False # Track if this step used extrapolation if should_skip and skip_method is not None: # SKIP: Use extrapolation if skip_method == "linear": epsilon_current = extrapolate_epsilon_linear(epsilon_history) elif skip_method == "richardson": epsilon_current = extrapolate_epsilon_richardson(epsilon_history) else: epsilon_current = None # Safety check: if extrapolation failed, fall back to model call if epsilon_current is None or torch.isnan(epsilon_current).any(): should_skip = False if debug: print(f"res_2m step {step_index}: extrapolation failed, falling back to model call") if should_skip and epsilon_current is not None: # Successful skip - reconstruct denoised from extrapolated epsilon if len(epsilon_history) >= 3: epsilon_current = epsilon_current / max(learning_ratio, 1e-8) denoised = x_0 + epsilon_current was_skipped = True if skip_stats is not None: skip_stats["skipped"] += 1 if debug: e_norm = torch.norm(epsilon_current).item() print(f"res_2m step {step_index} [SKIPPED-{skip_method}]: e_norm={e_norm:.2f}, L={learning_ratio:.4f}") if not should_skip: # CALL MODEL: Normal path denoised = model(noisy_latent, sigma_current * s_in, **extra_args) epsilon_current = denoised - x_0 if skip_stats is not None: skip_stats["model_calls"] += 1 # Step size in log space: h = -log(sigma_next / sigma_current) h = -torch.log(sigma_next / sigma_current) # Check if this is the final step (sigma_next = 0) # RES4LYF line 178: if sigma_next == 0 sigma_next_value = sigma_next.item() if torch.is_tensor(sigma_next) else sigma_next is_final_step = (sigma_next_value == 0) # Check if we have history and can use multistep # RES4LYF stores denoised in data_[] array, loads it as: eps_[1] = -(x_0 - data_[1]) if len(error_history) >= 1 and sigma_previous is not None and not is_final_step: # Load previous denoised from history and compute epsilon from it # RES4LYF line 215: eps_[1] = -(x_0 - data_[1]) = data_[1] - x_0 denoised_previous = error_history[-1] epsilon_previous = denoised_previous - x_0 # Multistep coefficient: RES4LYF line 808: c2 = (-h_prev / h).item() h_prev = -torch.log(sigma_current / sigma_previous) c2 = (-h_prev / h).item() # Phi function weights: RES4LYF lines 889-890 # b2 = φ(2)/c2, b1 = φ(1) - b2 phi_1 = phi_function(order=1, step_size=-h) phi_2 = phi_function(order=2, step_size=-h) b2_base = phi_2 / c2 b1_base = phi_1 - b2_base # Adaptive weight adjustment based on error ratio # IMPORTANT: Only calculate error_ratio on real model calls, not extrapolated epsilon if adaptive_mode != "none" and not was_skipped: # Calculate error ratio (only on real model calls) error_curr = torch.norm(epsilon_current).item() error_prev = torch.norm(epsilon_previous).item() error_ratio = error_curr / (error_prev + 1e-8) # Avoid division by zero if adaptive_mode == "learning": # MODE 2: EMA smoothed adjustment (learned pattern) smoothed_error_ratio_next = (smoothing_beta * smoothed_error_ratio + (1 - smoothing_beta) * error_ratio) adjustment = 1.0 / smoothed_error_ratio_next adjustment = max(0.5, min(2.0, adjustment)) # Clamp to [0.5, 2.0] else: adjustment = 1.0 smoothed_error_ratio_next = 1.0 # Apply adjustment to weights b1_adjusted = b1_base * adjustment b2_adjusted = b2_base / adjustment # Normalize to preserve sum (maintains phi_1 constraint) sum_adjusted = b1_adjusted + b2_adjusted sum_target = b1_base + b2_base # Should equal phi_1 scale = sum_target / sum_adjusted b1 = b1_adjusted * scale b2 = b2_adjusted * scale elif adaptive_mode != "none" and was_skipped: # Skipped step: preserve previous smoothed_error_ratio, use baseline weights # Don't poison the adaptive system with extrapolated epsilon b1 = b1_base b2 = b2_base adjustment = 1.0 smoothed_error_ratio_next = smoothed_error_ratio # Preserve previous value error_ratio = None # Mark as not calculated else: # No adaptation (baseline RES2M) b1 = b1_base b2 = b2_base adjustment = 1.0 smoothed_error_ratio_next = 1.0 error_ratio = None # Integration: RES4LYF line 364: x = x_0 + h * rk.b_k_sum(eps_, 0) # For 2-multistep: b = [b1, b2], eps_ = [eps_current, eps_previous] # So: b_k_sum = b1*eps_current + b2*eps_previous x = x_0 + h * (b1 * epsilon_current + b2 * epsilon_previous) if debug: eps_prev_norm = torch.norm(epsilon_previous).item() eps_curr_norm = torch.norm(epsilon_current).item() if adaptive_mode != "none": # Only print immediate EXTRAPOLATED case here; REAL case is printed after learning update if error_ratio is None: print( f"res_2m step {step_index} [learning] [EXTRAPOLATED]: " f"baseline φ-weights (adaptive error_ratio preserved); ε̂ scaled by 1/L={learning_ratio:.4f}; " f"b1={b1.item():.4f}, b2={b2.item():.4f}" ) else: print(f"res_2m step {step_index}: eps_prev_norm={eps_prev_norm:.2f}, eps_curr_norm={eps_curr_norm:.2f}, " f"c2={c2:.4f}, b1={b1.item():.4f}, b2={b2.item():.4f}") else: # First step / post-skip reanchor / final step if is_final_step: # Final step: sigma_next = 0 # Return denoised directly (Euler method for final step) # Note: Computing h = -log(0/sigma) would give infinity, causing NaN # Full DEIS final step would require porting get_deis_coeff_list() from res4lyf # For now, standard Euler works perfectly for the final step x = denoised if debug: print(f"res_2m step {step_index} (final step): using Euler") else: # Use standard Euler integration when we cannot form a valid previous step # Reason classification improves log clarity if prev_was_skipped: reason = "post-skip reanchor" elif sigma_previous is None or len(error_history) == 0: reason = "first step" else: reason = "no-history reanchor" x = x_0 + h * epsilon_current if debug: print(f"res_2m step {step_index} ({reason}): using Euler") # No adaptation on first/final steps smoothed_error_ratio_next = 1.0 # Store denoised for NEXT step (include SKIPPED to preserve multistep continuity) error_history.append(denoised) if len(error_history) > 2: error_history.pop(0) # Store REAL epsilon only for extrapolation/learning; keep full history (no cap) if not was_skipped: epsilon_history.append(epsilon_current) # Universal learning update when enough REAL history exists (>=3) if len(epsilon_history) >= 3: if predictor_type == "richardson": epsilon_hat = extrapolate_epsilon_richardson(epsilon_history) else: epsilon_hat = extrapolate_epsilon_linear(epsilon_history) if epsilon_hat is not None: learn_obs = (torch.norm(epsilon_hat) / (torch.norm(epsilon_current) + 1e-8)).item() learning_ratio = smoothing_beta * learning_ratio + (1.0 - smoothing_beta) * learn_obs if learning_ratio < 0.5: learning_ratio = 0.5 elif learning_ratio > 2.0: learning_ratio = 2.0 if debug: if adaptive_mode != "none" and 'error_ratio' in locals() and error_ratio is not None: # Combined one-line print for learning mode on REAL step print( f"res_2m step {step_index} [learning] [REAL]: " f"err_ratio={error_ratio:.4f}, adjust={adjustment:.4f}, " f"b1={b1.item():.4f}({b1_base.item():.4f}), b2={b2.item():.4f}({b2_base.item():.4f})" f" | learn_obs={learn_obs:.4f}, L={learning_ratio:.4f}, beta={smoothing_beta}" ) elif adaptive_mode != "none": # If for any reason error_ratio wasn't available, still show learning update succinctly print(f"res_2m step {step_index} [LEARN]: learn_obs={learn_obs:.4f}, L={learning_ratio:.4f}, beta={smoothing_beta}") return x, smoothed_error_ratio_next, learning_ratio, was_skipped def sample_step_res_2s(model, noisy_latent, sigma_current, sigma_next, s_in, extra_args, epsilon_history, learning_ratio, smoothing_beta, predictor_type, step_index, total_steps, debug=False, skip_mode="none"): """res_2s: 2-stage exponential integrator (baseline, no skipping). - Stage 1: Evaluate at current sigma - Stage 2: Evaluate at midpoint sigma (geometric in log-sigma) - Combine with phi-based weights - Update universal learning ratio on REAL steps (epsilon_history REAL-only) """ noisy_latent_at_step_start = noisy_latent # Inter-step skip support (baseline: Euler-like update with ε̂) should_skip, skip_method = should_skip_model_call( 1.0, # res_2s doesn't track error_ratio; adaptive uses bands but we'll pass 1.0 step_index, total_steps, skip_mode, epsilon_history ) # Note: should_skip_model_call internally checks first <2 and last 4 guards and history length. if should_skip and skip_method is not None: # Build epsilon_hat from REAL history if skip_method == "richardson": epsilon_hat = extrapolate_epsilon_richardson(epsilon_history) else: epsilon_hat = extrapolate_epsilon_linear(epsilon_history) # Fallback if missing/NaN if epsilon_hat is None or torch.isnan(epsilon_hat).any(): should_skip = False else: # Scale by learning ratio if we have ≥3 REAL eps in history if len(epsilon_history) >= 3: epsilon_hat = epsilon_hat / max(learning_ratio, 1e-8) # Euler-like update using epsilon_hat d = -(epsilon_hat) / sigma_current dt = sigma_next - sigma_current noisy_latent = noisy_latent + d * dt if debug: e_norm = torch.norm(epsilon_hat).item() dt_val = (sigma_next - sigma_current).item() if torch.is_tensor(sigma_next) else float(sigma_next - sigma_current) print(f"res_2s step {step_index} [SKIPPED-{skip_method}]: e_norm={e_norm:.2f}, L={learning_ratio:.4f}, dt={dt_val:.4f}") return noisy_latent, learning_ratio # Step size in log space step_size = -torch.log(sigma_next / sigma_current) # Check if this is the final step (sigma_next = 0) # When sigma_next = 0, step_size → ∞, causing numerical issues # RES4LYF switches to ralston for final step; we use Euler for simplicity sigma_next_value = sigma_next.item() if torch.is_tensor(sigma_next) else sigma_next is_final_step = (sigma_next_value == 0) if is_final_step: # Final step: land on denoised directly (avoid infinite step size) model_prediction = model(noisy_latent, sigma_current * s_in, **extra_args) noisy_latent = model_prediction if debug: print(f"res_2s step {step_index} (final step): using Euler") # Learning update on REAL call epsilon_real = model_prediction - noisy_latent_at_step_start epsilon_history.append(epsilon_real) if len(epsilon_history) >= 3: if predictor_type == "richardson": epsilon_hat = extrapolate_epsilon_richardson(epsilon_history) else: epsilon_hat = extrapolate_epsilon_linear(epsilon_history) if epsilon_hat is not None: learn_obs = (torch.norm(epsilon_hat) / (torch.norm(epsilon_real) + 1e-8)).item() learning_ratio = smoothing_beta * learning_ratio + (1.0 - smoothing_beta) * learn_obs if learning_ratio < 0.5: learning_ratio = 0.5 elif learning_ratio > 2.0: learning_ratio = 2.0 if debug: print(f"res_2s step {step_index} [LEARN]: learn_obs={learn_obs:.4f}, L={learning_ratio:.4f}, beta={smoothing_beta}") return noisy_latent, learning_ratio midpoint_fraction = 0.5 # Evaluate at midpoint # Phi function weights for 2-stage method phi_1_value = phi_function(order=1, step_size=-step_size) phi_2_value = phi_function(order=2, step_size=-step_size) # Weights for final integration weight_stage2 = phi_2_value / midpoint_fraction weight_stage1 = phi_1_value - weight_stage2 # Weight for advancing to stage 2 phi_1_at_midpoint = phi_function(order=1, step_size=-step_size * midpoint_fraction) stage2_advance_weight = midpoint_fraction * phi_1_at_midpoint # Stage 1: Evaluate at current sigma model_prediction_stage1 = model(noisy_latent, sigma_current * s_in, **extra_args) error_stage1 = -(noisy_latent_at_step_start - model_prediction_stage1) # epsilon at current sigma # Stage 2: Evaluate at midpoint sigma sigma_midpoint = torch.exp(-(-torch.log(sigma_current) + step_size * midpoint_fraction)) noisy_latent_midpoint = noisy_latent_at_step_start + (step_size * stage2_advance_weight) * error_stage1 model_prediction_stage2 = model(noisy_latent_midpoint, sigma_midpoint * s_in, **extra_args) error_stage2 = -(noisy_latent_at_step_start - model_prediction_stage2) # epsilon at midpoint # Final integration with weighted stages noisy_latent = noisy_latent_at_step_start + step_size * ( weight_stage1 * error_stage1 + weight_stage2 * error_stage2 ) if debug: stage1_norm = torch.norm(error_stage1).item() stage2_norm = torch.norm(error_stage2).item() print(f"res_2s step {step_index}: stage1_norm={stage1_norm:.2f}, stage2_norm={stage2_norm:.2f}, " f"weight_s1={weight_stage1.item():.4f}, weight_s2={weight_stage2.item():.4f}") # Learning update on REAL call (use epsilon at current sigma: error_stage1) epsilon_real = error_stage1 epsilon_history.append(epsilon_real) if len(epsilon_history) >= 3: if predictor_type == "richardson": epsilon_hat = extrapolate_epsilon_richardson(epsilon_history) else: epsilon_hat = extrapolate_epsilon_linear(epsilon_history) if epsilon_hat is not None: learn_obs = (torch.norm(epsilon_hat) / (torch.norm(epsilon_real) + 1e-8)).item() learning_ratio = smoothing_beta * learning_ratio + (1.0 - smoothing_beta) * learn_obs if learning_ratio < 0.5: learning_ratio = 0.5 elif learning_ratio > 2.0: learning_ratio = 2.0 if debug: print(f"res_2s step {step_index} [LEARN]: learn_obs={learn_obs:.4f}, L={learning_ratio:.4f}, beta={smoothing_beta}") return noisy_latent, learning_ratio