Eji-Sensei14's picture
Upload folder using huggingface_hub
c6535db verified
import time
import torch
import threading
from ..comfy_copy.res4lyf_sampling import get_res4lyf_step_with_model
import comfy.sample
from .log import print_step_timing
from .skip import _parse_hs_mode, parse_skip_indices_config
from .extrapolation import SigmaAwareHistory, set_sigma_target, set_current_x
from .samplers.euler import sample_step_euler
from .samplers.res2m import sample_step_res_2m
from .samplers.res2s import sample_step_res_2s
from .samplers.ddim import sample_step_ddim
from .samplers.dpmpp_2m import sample_step_dpmpp_2m
from .samplers.dpmpp_2s import sample_step_dpmpp_2s
from .samplers.lms import sample_step_lms
from .samplers.res_multistep import sample_step_res_multistep
from .samplers.res_multistep_official import sample_step_res_multistep_official
from .samplers.res_multistep_ancestral import sample_step_res_multistep_ancestral
from .samplers.heun import sample_step_heun
from .samplers.gradient_estimation import sample_step_gradient_estimation
# Thread-local storage for metadata (to pass out of ksampler_function without breaking ComfyUI)
_thread_local = threading.local()
def _qwen_variant(text):
    """Return the Qwen model identifier implied by lowercase *text*.

    Returns None when *text* does not mention Qwen at all, so callers can
    fall through to the next detection strategy.
    """
    if 'qwen' not in text:
        return None
    if 'edit' not in text:
        return "qwen-image"
    if '2509' in text:
        return "qwen-image-edit-2509"
    return "qwen-image-edit"


def _model_type_from_class_name(class_name):
    """Map a lowercase model class name to a model type identifier.

    The checks are ordered: the specific families (SDXL, SD3, ...) must be
    tested before the generic 'sd' fallback because their names contain it.
    A name that matches no family is returned unchanged.
    """
    # SDXL detection
    if 'sdxl' in class_name:
        return "sdxl-refiner" if 'refiner' in class_name else "sdxl-base"

    # Flux detection
    if 'flux' in class_name:
        if 'schnell' in class_name:
            return "flux-schnell"
        if 'pro' in class_name:
            return "flux-pro"
        return "flux-dev"

    # SD3 detection
    if 'sd3' in class_name or 'stable_diffusion_3' in class_name:
        if 'turbo' in class_name:
            return "sd3-large-turbo"
        if 'large' in class_name:
            return "sd3-large"
        return "sd3-medium"

    # Cascade detection
    if 'cascade' in class_name:
        if 'stage_c' in class_name or 'stagec' in class_name:
            return "cascade-stage-c"
        if 'stage_b' in class_name or 'stageb' in class_name:
            return "cascade-stage-b"
        return "cascade-stage-a"

    # Hunyuan detection
    if 'hunyuan' in class_name:
        return "hunyuan-video" if 'video' in class_name else "hunyuan-dit"

    # Kolors detection
    if 'kolors' in class_name:
        return "kolors"

    # PixArt detection
    if 'pixart' in class_name:
        return "pixart-sigma" if 'sigma' in class_name else "pixart-alpha"

    # Playground detection
    if 'playground' in class_name:
        if 'v2.5' in class_name or 'v25' in class_name:
            return "playground-v25"
        return "playground-v2"

    # Single-variant families
    if 'auraflow' in class_name:
        return "auraflow"
    if 'cogview' in class_name:
        return "cogview3"
    if 'cosmos' in class_name:
        return "cosmos-1"
    if 'lumina' in class_name:
        return "lumina-next"

    # SD 1.x/2.x detection (fallback)
    if 'sd' in class_name or 'stablediffusion' in class_name:
        if '2.1' in class_name or 'v2_1' in class_name:
            return "sd21"
        if '2.0' in class_name or 'v2_0' in class_name or 'v2' in class_name:
            return "sd20"
        return "sd15"

    # Nothing matched: report the class name itself (may be "unknown")
    return class_name


def detect_model_type(model_patcher, verbose=False):
    """
    Enhanced model type detection that identifies specific models like Qwen, SDXL, Flux, etc.

    Detection strategies are tried in this order, first hit wins:
      1. Qwen markers in the patcher's checkpoint_path / model_name / filename / name.
      2. Qwen markers in ``model.model_config.unet_config`` (when it is a dict).
      3. Qwen markers in the first 20 ``model.state_dict()`` keys.
      4. Substring patterns in the lowercased class name of ``model_patcher.model``.

    Args:
        model_patcher: ComfyUI ModelPatcher object (any object is accepted; missing
            attributes simply disable the strategies that need them)
        verbose: Print debug information during detection

    Returns:
        str: Detected model type identifier (e.g., "qwen-image", "sdxl-base", "flux-dev").
        Falls back to the lowercased model class name, or "unknown" when there is
        no ``model`` attribute or detection itself raised.
    """
    try:
        has_model = hasattr(model_patcher, 'model')
        model = model_patcher.model if has_model else None

        # Get the model class name as fallback
        class_name = "unknown"
        if has_model:
            class_name = model.__class__.__name__.lower()
            if verbose:
                print(f"[FSampler] Model class name: {class_name}")

        # Check model_patcher attributes for checkpoint info
        checkpoint_info = []
        for attr in ('checkpoint_path', 'model_name', 'filename', 'name'):
            val = getattr(model_patcher, attr, None)
            if val:
                checkpoint_info.append(str(val).lower())
                if verbose:
                    print(f"[FSampler] Found {attr}: {val}")

        # Check checkpoint info for Qwen patterns
        checkpoint_str = ' '.join(checkpoint_info)
        variant = _qwen_variant(checkpoint_str)
        if variant is not None:
            if verbose:
                print(f"[FSampler] Found 'qwen' in checkpoint info: {checkpoint_str}")
            return variant

        # Check the model config for Qwen-specific markers
        if has_model and hasattr(model, 'model_config'):
            config = model.model_config
            if verbose:
                print("[FSampler] Model has config, checking unet_config...")
            unet_config = getattr(config, 'unet_config', None)
            if isinstance(unet_config, dict):
                if verbose:
                    print(f"[FSampler] unet_config keys: {list(unet_config.keys())[:10]}")
                variant = _qwen_variant(str(unet_config).lower())
                if variant is not None:
                    if verbose:
                        print("[FSampler] Found 'qwen' in unet_config")
                    return variant

        # Check model state dict keys for Qwen-specific patterns
        if has_model and hasattr(model, 'state_dict'):
            try:
                state_dict_keys = list(model.state_dict().keys())
                keys_str = ' '.join(state_dict_keys[:20]).lower()  # Check first 20 keys
            except Exception as exc:
                # Best-effort probe: a model that cannot list its weights must not
                # abort detection, but Ctrl-C / SystemExit are no longer swallowed.
                if verbose:
                    print(f"[FSampler] Could not inspect state_dict keys: {exc}")
            else:
                if verbose:
                    print(f"[FSampler] First state dict keys: {state_dict_keys[:5]}")
                variant = _qwen_variant(keys_str)
                if variant is not None:
                    if verbose:
                        print("[FSampler] Found 'qwen' in state_dict keys")
                    return variant

        # Check for common model architecture patterns
        return _model_type_from_class_name(class_name)

    except Exception as e:
        print(f"[FSampler] Model detection error: {e}")
        return "unknown"
def create_fsampler_ksampler(sampler="euler", adaptive_mode="none",
                             smoothing_beta=0.9, skip_mode="none", add_noise_ratio=0.0,
                             add_noise_type="whitened", scheduler=None, start_at_step=None,
                             end_at_step=None, denoise=None, debug=False,
                             protect_last_steps=4, protect_first_steps=2,
                             anchor_interval=None, max_consecutive_skips=None,
                             use_no_grad=True, official_comfy=False, skip_indices: str = "",
                             seed=None, timestamp_start=None, sigma_aware=False,
                             extrapolate_denoised=False):
    """Create a KSAMPLER with FSampler's skip-aware sampling logic.

    Returns a comfy.samplers.KSAMPLER that can be used with comfy.sample.sample_custom()
    or guider.sample() for the modular SamplerCustomAdvanced workflow.

    The wrapped function walks the sigma schedule one step at a time and delegates
    each step to the ``sample_step_*`` function selected by ``sampler``. After every
    step it records debug metadata (when enabled), remembers the sigma just used and
    invokes the progress ``callback``.

    Args:
        sampler: One of "res_2m", "res_2s", "euler", "ddim", "dpmpp_2m", "dpmpp_2s",
            "lms", "heun", "gradient_estimation", "res_multistep",
            "res_multistep_ancestral". Any other name runs no step function, so the
            latent passes through unchanged.
        adaptive_mode: Forwarded to the step functions that accept it; forced to
            "none" while explicit skip indices are active.
        smoothing_beta, add_noise_ratio, add_noise_type, protect_last_steps,
        protect_first_steps, anchor_interval, max_consecutive_skips: Forwarded
            unchanged to the step function.
        skip_mode: Forwarded to the step function (forced to "none" while explicit
            skip indices are active) and used to pick the predictor type.
        scheduler, start_at_step, end_at_step, denoise: Only reported in the debug
            banner (``scheduler`` is also stored in the run metadata).
        debug: Print settings/statistics and collect per-run metadata.
        use_no_grad: Run the whole loop under ``torch.no_grad()``.
        official_comfy: Forwarded to step functions that accept it; for
            "res_multistep" it selects the official implementation.
        skip_indices: Parsed by ``parse_skip_indices_config``; indices below 2 or
            at/after the number of steps are discarded.
        seed, timestamp_start: Stored in the run metadata only.
        sigma_aware, extrapolate_denoised: Use a ``SigmaAwareHistory`` for the
            epsilon history and publish the current sigma / latent before each step.

    Side effects:
        Each run stores its metadata dict (empty unless ``debug``) in the
        module's thread-local ``_thread_local.last_run_metadata``.
    """
    def ksampler_function(model, x, sigmas, extra_args=None, callback=None, disable=None):
        # Allow runtime control of autograd to match official behavior
        import contextlib

        grad_ctx = torch.no_grad() if use_no_grad else contextlib.nullcontext()
        with grad_ctx:
            extra_args = {} if extra_args is None else extra_args
            s_in = x.new_ones([x.shape[0]])
            error_history = []  # RES2M denoised history (REAL + SKIPPED)
            # REAL-only epsilon history for extrapolation + learning
            epsilon_history = SigmaAwareHistory() if (sigma_aware or extrapolate_denoised) else []
            sigma_previous = None
            smoothed_error_ratio = 1.0  # For RES2M Phase-1 adaptive weights
            learning_ratio = 1.0  # Universal learning stabilizer

            # Initialize metadata tracking (when debug=True)
            per_step_data = [] if debug else None
            l_history = [] if debug else None

            # Normalize predictor selection using history from skip_mode (hN/sK or legacy aliases)
            skip_mode_l = str(skip_mode).lower() if isinstance(skip_mode, str) else "none"
            hs = _parse_hs_mode(skip_mode_l)
            if hs is not None:
                history_order, _ = hs
                if history_order >= 4:
                    predictor_type = "h4"
                elif history_order == 3:
                    predictor_type = "richardson"
                else:
                    predictor_type = "linear"
            elif skip_mode_l == "h4":  # 4-point predictor
                predictor_type = "h4"
            elif skip_mode_l in ("h3", "richardson"):
                predictor_type = "richardson"
            else:
                predictor_type = "linear"

            skip_stats = {
                "total_steps": 0,
                "model_calls": 0,
                "skipped": 0,
                # Adaptive skip controller state
                "consecutive_skips": 0,
                "last_anchor_step": -1,
                # Explicit indices gating state
                "explicit_streak": False,
                "needed_learns": 2,
            }

            if debug:
                print(f"\n{'='*60}")
                print("FSampler Settings:")
                print(f" sampler: {sampler}")
                if scheduler is not None:
                    print(f" scheduler: {scheduler}")
                print(f" adaptive_mode: {adaptive_mode}")
                print(f" smoothing_beta: {smoothing_beta}")
                print(f" skip_mode: {skip_mode}")
                print(f" add_noise_ratio: {add_noise_ratio}")
                print(f" noise_type: {add_noise_type}")
                print(f" official_comfy: {official_comfy}")
                print(f" sigma_aware: {sigma_aware}")
                print(f" extrapolate_denoised: {extrapolate_denoised}")
                if skip_mode != "none":
                    print(f" protect_first_steps: {protect_first_steps}")
                    print(f" protect_last_steps: {protect_last_steps}")
                if denoise is not None:
                    print(f" denoise: {denoise}")
                if start_at_step is not None:
                    print(f" start_at_step: {start_at_step}")
                if end_at_step is not None:
                    print(f" end_at_step: {end_at_step}")
                print(f" steps: {len(sigmas)-1}")
                try:
                    # Tensors and plain sequences are summarized the same way
                    s0 = float(sigmas[0])
                    sL = float(sigmas[-1])
                    f3 = [float(v) for v in sigmas[:3]]
                    l3 = [float(v) for v in sigmas[-3:]]
                    print(f" sigma_range: [{s0:.4f}, {sL:.4f}] len={len(sigmas)}")
                    print(f" sigmas head: {[round(v,4) for v in f3]} tail: {[round(v,4) for v in l3]}")
                except Exception as exc:
                    # The summary is purely informational; never fail sampling over it
                    print(f" sigma summary unavailable: {exc}")
                print(f"{'='*60}\n")

            # Parse explicit skip indices early (indices bounded once total_steps is known)
            explicit_predictor, explicit_indices = parse_skip_indices_config(skip_indices or "")
            explicit_mode = len(explicit_indices) > 0

            res2m_prev_was_skipped = False
            res2m_noise_cooldown = 0
            res2m_prev_sigma_down = None
            resms_prev_sigma_down = None
            _t_start = time.time()
            total_steps = len(sigmas) - 1

            # Determine effective modes and bound explicit indices
            if explicit_mode:
                # Bound and filter per-total-steps; also ensure we never include 0/1
                explicit_indices = {i for i in explicit_indices if 0 <= i < total_steps and i >= 2}
                if len(explicit_indices) == 0:
                    explicit_mode = False
            effective_skip_mode = "none" if explicit_mode else skip_mode
            effective_adaptive_mode = "none" if explicit_mode else adaptive_mode
            if explicit_mode:
                # Override predictor used for learning/hints
                predictor_type = explicit_predictor
                if debug:
                    print("Explicit Skip Mode: ON")
                    print(f" explicit_predictor: {explicit_predictor}")
                    print(f" explicit_indices: {sorted(list(explicit_indices))}")
                    print(" disabled: skip_mode/adaptive/anchor/max_consecutive/protect_*")

            # Explicit-skip keywords are constant for the whole run
            explicit_kwargs = {
                "explicit_skip_indices": explicit_indices if explicit_mode else None,
                "explicit_predictor": explicit_predictor if explicit_mode else None,
            }

            # Samplers returning (x, learning_ratio): step function plus the optional
            # keywords each one takes on top of the shared and explicit-skip sets.
            simple_steps = {
                "res_2s": (sample_step_res_2s, ()),
                "euler": (sample_step_euler, ("official_comfy", "adaptive_mode")),
                "ddim": (sample_step_ddim, ("official_comfy", "adaptive_mode")),
                "dpmpp_2m": (sample_step_dpmpp_2m, ("sigma_previous", "official_comfy", "adaptive_mode")),
                "dpmpp_2s": (sample_step_dpmpp_2s, ("official_comfy",)),
                "lms": (sample_step_lms, ("sigma_previous", "official_comfy", "adaptive_mode")),
                "heun": (sample_step_heun, ("official_comfy", "adaptive_mode")),
                "gradient_estimation": (sample_step_gradient_estimation, ("sigma_previous", "official_comfy")),
            }

            for step_index in range(total_steps):
                sigma_current = sigmas[step_index]
                sigma_next = sigmas[step_index + 1]
                print_step_timing(sampler, step_index, _t_start, total_steps)

                # Track total steps centrally for all samplers
                skip_stats["total_steps"] += 1

                # Capture step start time and model_calls before step (for metadata)
                step_start_time = time.time() if debug else None
                model_calls_before = skip_stats.get("model_calls", 0) if debug else None

                # Set sigma context for sigma-aware extrapolation
                if sigma_aware:
                    epsilon_history.set_pending_sigma(float(sigma_current))
                    set_sigma_target(float(sigma_current))
                # Set denoised context for denoised-mode extrapolation
                if extrapolate_denoised:
                    epsilon_history.set_pending_x(x)
                    set_current_x(x)

                # Keywords every step function receives (rebuilt per step because
                # x and learning_ratio change)
                shared = {
                    "model": model,
                    "noisy_latent": x,
                    "sigma_current": sigma_current,
                    "sigma_next": sigma_next,
                    "s_in": s_in,
                    "extra_args": extra_args,
                    "epsilon_history": epsilon_history,
                    "learning_ratio": learning_ratio,
                    "smoothing_beta": smoothing_beta,
                    "predictor_type": predictor_type,
                    "step_index": step_index,
                    "total_steps": total_steps,
                    "add_noise_ratio": add_noise_ratio,
                    "add_noise_type": add_noise_type,
                    "skip_mode": effective_skip_mode,
                    "skip_stats": skip_stats,
                    "protect_last_steps": protect_last_steps,
                    "debug": debug,
                    "protect_first_steps": protect_first_steps,
                    "anchor_interval": anchor_interval,
                    "max_consecutive_skips": max_consecutive_skips,
                }

                if sampler == "res_2m":
                    # res_2m carries extra state between steps and takes no explicit-skip keywords
                    (x, smoothed_error_ratio, learning_ratio, res2m_prev_was_skipped,
                     res2m_noise_cooldown, res2m_prev_sigma_down) = sample_step_res_2m(
                        sigma_previous=sigma_previous,
                        error_history=error_history,
                        prev_was_skipped=res2m_prev_was_skipped,
                        adaptive_mode=effective_adaptive_mode,
                        smoothed_error_ratio=smoothed_error_ratio,
                        noise_cooldown=res2m_noise_cooldown,
                        old_sigma_down=res2m_prev_sigma_down,
                        **shared,
                    )
                elif sampler in simple_steps:
                    step_fn, extra_names = simple_steps[sampler]
                    optional = {
                        "sigma_previous": sigma_previous,
                        "official_comfy": official_comfy,
                        "adaptive_mode": effective_adaptive_mode,
                    }
                    x, learning_ratio = step_fn(
                        **shared,
                        **explicit_kwargs,
                        **{name: optional[name] for name in extra_names},
                    )
                elif sampler == "res_multistep":
                    # Choose official vs res4lyf implementation
                    step_fn = sample_step_res_multistep_official if official_comfy else sample_step_res_multistep
                    multistep_kwargs = dict(
                        shared,
                        sigma_previous=sigma_previous,
                        old_sigma_down=resms_prev_sigma_down,
                        error_history=error_history,
                    )
                    try:
                        x, learning_ratio, resms_prev_sigma_down = step_fn(
                            adaptive_mode=effective_adaptive_mode,
                            **multistep_kwargs,
                            **explicit_kwargs,
                        )
                    except TypeError:
                        # Back-compat with older function signature (no adaptive/explicit keywords).
                        # NOTE(review): this also retries on a TypeError raised inside the step itself.
                        x, learning_ratio, resms_prev_sigma_down = step_fn(**multistep_kwargs)
                elif sampler == "res_multistep_ancestral":
                    # Dedicated ancestral variant. It now falls through to the shared
                    # bookkeeping below instead of `continue`-ing past it, so the
                    # callback, sigma_previous and debug metadata are updated like
                    # for every other sampler.
                    x, learning_ratio, resms_prev_sigma_down = sample_step_res_multistep_ancestral(
                        sigma_previous=sigma_previous,
                        old_sigma_down=resms_prev_sigma_down,
                        error_history=error_history,
                        official_comfy=official_comfy,
                        adaptive_mode=effective_adaptive_mode,
                        **shared,
                        **explicit_kwargs,
                    )
                # Any other sampler name: no step function runs and x is left unchanged.

                # Track per-step metadata (when debug=True)
                if debug and per_step_data is not None:
                    step_end_time = time.time()
                    l_history.append(float(learning_ratio))
                    # Determine if this step was skipped by checking if model_calls increased
                    model_calls_after = skip_stats.get("model_calls", 0)
                    was_skipped = 1 if model_calls_after == model_calls_before else 0
                    per_step_data.append({
                        "step_index": step_index,
                        "sigma_current": float(sigma_current),
                        "sigma_next": float(sigma_next),
                        "was_skipped": was_skipped,
                        "learning_ratio": float(learning_ratio),
                        "step_time_seconds": step_end_time - step_start_time if step_start_time else 0.0,
                    })

                sigma_previous = sigma_current
                if callback is not None:
                    callback({'x': x, 'i': step_index, 'sigma': sigma_current, 'sigma_next': sigma_next, 'denoised': x})

            if debug and (effective_skip_mode != "none" or explicit_mode):
                total = skip_stats["total_steps"]
                called = skip_stats["model_calls"]
                skipped = skip_stats["skipped"]
                if total > 0:
                    skip_percent = (skipped / total) * 100
                    print(f"\n{'='*60}")
                    print("Skip Statistics:")
                    print(f" Total steps: {total}")
                    print(f" Model calls: {called}")
                    print(f" Skipped: {skipped}")
                    print(f" Reduction: {skip_percent:.1f}%")
                    print(f"{'='*60}\n")

            # Build metadata dict (when debug=True) and store in thread-local
            if debug and per_step_data is not None:
                _t_end = time.time()
                # Detect model type from model_patcher using enhanced detection
                model_type = detect_model_type(model, verbose=True)
                print(f"[FSampler Debug] Detected model type: {model_type}")
                metadata = {
                    "seed": seed,
                    "timestamp_start": timestamp_start if timestamp_start is not None else _t_start,
                    "timestamp_end": _t_end,
                    "model_type": model_type,
                    "sampler": sampler,
                    "scheduler": scheduler if scheduler else "unknown",
                    "skip_mode": skip_mode,
                    "adaptive_mode": adaptive_mode,
                    "smoothing_beta": smoothing_beta,
                    "total_steps": skip_stats["total_steps"],
                    "model_calls": skip_stats["model_calls"],
                    "skipped": skip_stats["skipped"],
                    "reduction_percent": (skip_stats["skipped"] / skip_stats["total_steps"] * 100) if skip_stats["total_steps"] > 0 else 0.0,
                    "total_time_seconds": _t_end - _t_start,
                    "protect_first_steps": protect_first_steps,
                    "protect_last_steps": protect_last_steps,
                    "anchor_interval": anchor_interval if adaptive_mode != "none" else None,
                    "max_consecutive_skips": max_consecutive_skips if adaptive_mode != "none" else None,
                    "l_final": float(learning_ratio),
                    "l_mean": sum(l_history) / len(l_history) if l_history else 1.0,
                    "l_min": min(l_history) if l_history else 1.0,
                    "l_max": max(l_history) if l_history else 1.0,
                    "per_step_data": per_step_data,
                }
                # Store in thread-local storage
                _thread_local.last_run_metadata = metadata
            else:
                _thread_local.last_run_metadata = {}

            # Return just x (not tuple) to keep ComfyUI's sampler code happy
            return x

    from comfy.samplers import KSAMPLER
    return KSAMPLER(ksampler_function)
def sample_fsampler(model_patcher, noise, sigmas, positive_conditioning, negative_conditioning,
                    cfg_scale, latent_image, sampler="euler", adaptive_mode="none",
                    smoothing_beta=0.9, skip_mode="none", add_noise_ratio=0.0, add_noise_type="whitened",
                    scheduler=None, start_at_step=None, end_at_step=None, denoise=None,
                    debug=False, callback=None, protect_last_steps=4, protect_first_steps=2,
                    anchor_interval=None, max_consecutive_skips=None, use_no_grad=True,
                    official_comfy=False, skip_indices: str = "", seed=None, timestamp_start=None,
                    sigma_aware=False, extrapolate_denoised=False):
    """Run one FSampler pass and return ``(samples, metadata)``.

    A skip-aware KSAMPLER is built with create_fsampler_ksampler() from the
    sampler options given here, then driven through comfy.sample.sample_custom()
    with the supplied model, noise, sigmas, conditioning and CFG scale.

    ``metadata`` is whatever the sampler run left in this thread's
    ``_thread_local.last_run_metadata``; an empty dict is returned when that
    attribute has never been set on the current thread.
    """
    # Everything that configures the step loop goes to the KSAMPLER factory.
    ksampler_options = {
        "sampler": sampler,
        "adaptive_mode": adaptive_mode,
        "smoothing_beta": smoothing_beta,
        "skip_mode": skip_mode,
        "add_noise_ratio": add_noise_ratio,
        "add_noise_type": add_noise_type,
        "scheduler": scheduler,
        "start_at_step": start_at_step,
        "end_at_step": end_at_step,
        "denoise": denoise,
        "debug": debug,
        "protect_last_steps": protect_last_steps,
        "protect_first_steps": protect_first_steps,
        "anchor_interval": anchor_interval,
        "max_consecutive_skips": max_consecutive_skips,
        "use_no_grad": use_no_grad,
        "official_comfy": official_comfy,
        "skip_indices": skip_indices,
        "seed": seed,
        "timestamp_start": timestamp_start,
        "sigma_aware": sigma_aware,
        "extrapolate_denoised": extrapolate_denoised,
    }
    fsampler_ksampler = create_fsampler_ksampler(**ksampler_options)

    # No noise mask, progress bar left enabled, and no seed handed to ComfyUI.
    latents = comfy.sample.sample_custom(
        model=model_patcher,
        noise=noise,
        cfg=cfg_scale,
        sampler=fsampler_ksampler,
        sigmas=sigmas,
        positive=positive_conditioning,
        negative=negative_conditioning,
        latent_image=latent_image,
        noise_mask=None,
        callback=callback,
        disable_pbar=False,
        seed=None,
    )

    # Retrieve metadata from thread-local storage
    try:
        run_metadata = _thread_local.last_run_metadata
    except AttributeError:
        run_metadata = {}
    return latents, run_metadata