Upload sd_simple_kes_v3_fix? using SD-Hub
- sd_simple_kes_v3_fix?/__pycache__/get_sigmas.cpython-310.pyc +0 -0
- sd_simple_kes_v3_fix?/__pycache__/plot_sigma_sequence.cpython-310.pyc +0 -0
- sd_simple_kes_v3_fix?/__pycache__/simple_kes_v3.cpython-310.pyc +0 -0
- sd_simple_kes_v3_fix?/__pycache__/validate_config.cpython-310.pyc +0 -0
- sd_simple_kes_v3_fix?/get_sigmas.py +18 -0
- sd_simple_kes_v3_fix?/kes_config/default_config.yaml +196 -0
- sd_simple_kes_v3_fix?/kes_config/good_configs.yaml +45 -0
- sd_simple_kes_v3_fix?/kes_config/how to use these files +20 -0
- sd_simple_kes_v3_fix?/kes_config/simple_kes_requirements.txt +5 -0
- sd_simple_kes_v3_fix?/kes_config/simple_kes_scheduler.yaml +146 -0
- sd_simple_kes_v3_fix?/kes_config/suggested_scheduling_configs/alternating soft_hard decay.yaml +42 -0
- sd_simple_kes_v3_fix?/kes_config/suggested_scheduling_configs/anime_1.yaml +18 -0
- sd_simple_kes_v3_fix?/kes_config/suggested_scheduling_configs/cross_style_safe.yaml +18 -0
- sd_simple_kes_v3_fix?/kes_config/suggested_scheduling_configs/front_loaded geometric.yaml +42 -0
- sd_simple_kes_v3_fix?/kes_config/suggested_scheduling_configs/photo_realistic_1.yaml +18 -0
- sd_simple_kes_v3_fix?/kes_config/suggested_scheduling_configs/progressive.yaml +42 -0
- sd_simple_kes_v3_fix?/kes_config/user_config.yaml +81 -0
- sd_simple_kes_v3_fix?/plot_sigma_sequence.py +42 -0
- sd_simple_kes_v3_fix?/requirements.txt +4 -0
- sd_simple_kes_v3_fix?/schedulers/__pycache__/euler_advanced_scheduler.cpython-310.pyc +0 -0
- sd_simple_kes_v3_fix?/schedulers/__pycache__/exponential_advanced_scheduler.cpython-310.pyc +0 -0
- sd_simple_kes_v3_fix?/schedulers/__pycache__/geometric_advanced_scheduler.cpython-310.pyc +0 -0
- sd_simple_kes_v3_fix?/schedulers/__pycache__/harmonic_advanced_scheduler.cpython-310.pyc +0 -0
- sd_simple_kes_v3_fix?/schedulers/__pycache__/karras_advanced_scheduler.cpython-310.pyc +0 -0
- sd_simple_kes_v3_fix?/schedulers/__pycache__/logarithmic_advanced_scheduler.cpython-310.pyc +0 -0
- sd_simple_kes_v3_fix?/schedulers/__pycache__/shared.cpython-310.pyc +0 -0
- sd_simple_kes_v3_fix?/schedulers/euler_advanced_scheduler.py +53 -0
- sd_simple_kes_v3_fix?/schedulers/exponential_advanced_scheduler.py +41 -0
- sd_simple_kes_v3_fix?/schedulers/geometric_advanced_scheduler.py +60 -0
- sd_simple_kes_v3_fix?/schedulers/harmonic_advanced_scheduler.py +45 -0
- sd_simple_kes_v3_fix?/schedulers/karras_advanced_scheduler.py +40 -0
- sd_simple_kes_v3_fix?/schedulers/logarithmic_advanced_scheduler.py +46 -0
- sd_simple_kes_v3_fix?/schedulers/shared.py +184 -0
- sd_simple_kes_v3_fix?/simple_kes_v3.py +2002 -0
- sd_simple_kes_v3_fix?/validate_config.py +88 -0
sd_simple_kes_v3_fix?/__pycache__/get_sigmas.cpython-310.pyc
ADDED
Binary file (964 Bytes)
sd_simple_kes_v3_fix?/__pycache__/plot_sigma_sequence.cpython-310.pyc
ADDED
Binary file (1.68 kB)
sd_simple_kes_v3_fix?/__pycache__/simple_kes_v3.cpython-310.pyc
ADDED
Binary file (52.7 kB)
sd_simple_kes_v3_fix?/__pycache__/validate_config.cpython-310.pyc
ADDED
Binary file (2.38 kB)
sd_simple_kes_v3_fix?/get_sigmas.py
ADDED
@@ -0,0 +1,18 @@
+from modules.sd_simple_kes_v3.schedulers.karras_advanced_scheduler import get_sigmas_karras
+from modules.sd_simple_kes_v3.schedulers.exponential_advanced_scheduler import get_sigmas_exponential
+from modules.sd_simple_kes_v3.schedulers.geometric_advanced_scheduler import get_sigmas_geometric
+from modules.sd_simple_kes_v3.schedulers.harmonic_advanced_scheduler import get_sigmas_harmonic
+from modules.sd_simple_kes_v3.schedulers.logarithmic_advanced_scheduler import get_sigmas_logarithmic
+from modules.sd_simple_kes_v3.schedulers.euler_advanced_scheduler import get_sigmas_euler, get_sigmas_euler_advanced
+
+
+scheduler_registry = {
+    'karras': get_sigmas_karras,
+    'exponential': get_sigmas_exponential,
+    'geometric': get_sigmas_geometric,
+    'harmonic': get_sigmas_harmonic,
+    'logarithmic': get_sigmas_logarithmic,
+    'euler': get_sigmas_euler,
+    'euler_advanced': get_sigmas_euler_advanced
+    # Add more schedulers here - import them above in this file, then register them in simple_kes as well
+}
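
Every function in this registry shares the (steps, sigma_min, sigma_max, ...) calling convention and returns a (tails, decay, sigmas) tuple, so a caller can dispatch by name. A minimal sketch of that lookup; the method_name variable and the error handling are illustrative, not part of this upload:

from modules.sd_simple_kes_v3.get_sigmas import scheduler_registry

method_name = 'karras'  # hypothetical caller-supplied key
if method_name not in scheduler_registry:
    raise ValueError(f"unknown scheduler '{method_name}', expected one of {list(scheduler_registry)}")
# Every registered function returns a (tails, decay, sigmas) tuple.
tails, decay, sigmas = scheduler_registry[method_name](steps=30, sigma_min=0.01, sigma_max=50, device='cpu')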
sd_simple_kes_v3_fix?/kes_config/default_config.yaml
ADDED
@@ -0,0 +1,196 @@
+log_save_directory: "modules/sd_simple_kes_v3/image_generation_data"
+graph_save_directory: "modules/sd_simple_kes_v3/image_generation_data"
+graph_save_enable: false
+load_sigma_cache: false # Master toggle to enable/disable sigma caching
+save_sigma_cache: false # Whether to save new sigma schedules to file
+decay_pattern: "extrapolate" # Options: zero, soft_landing, extrapolate, and fractional
+decay_mode: "append"
+load_prepass_sigmas: false
+save_prepass_sigmas: false
+
+# blend methods # Key/Value - Options: "karras", "exponential", "geometric", "harmonic", "logarithmic"
+# blend weights # Any number between 0 and infinity. Explicit uses each number as-is relative to the others, so larger numbers get more weight. Softmax normalizes all values relative to each other, keeping them between 0 and 1.
+# If decay_pattern is empty or missing, decay_mode, tail_steps, and decay_rate are not used.
+# decay_rate only affects geometric patterns
+allow_step_expansion: false # Strict A1111 compatibility mode
+apply_tail_steps: false # Append tails from schedulers to the sigma sequence
+apply_decay_tail: false # Append decay tails from schedulers to the sigma sequence
+apply_blended_tail: false # Blend multiple tails into a single tail, then append
+apply_progressive_decay: false # Gradually apply decay to the sigma sequence step-by-step
+
+auto_tail_smoothing: true
+auto_tail_threshold: 0.05
+jaggedness_threshold: 0.01
+
+auto_stabilization_sequence:
+  - smooth_interpolation
+  - append_tail
+  - blend_tail
+  - apply_decay
+  - progressive_decay
+
+blending_style: 'softmax' # Options: 'explicit' or 'softmax'
+# Valid decay_patterns: 'geometric', 'harmonic', 'extrapolate', 'fractional', 'logarithmic', 'exponential', 'linear', and 'zero'
+# Valid decay_modes: 'append', 'blend', 'replace'
+# All decay modes are compatible with A1111 as long as tail_steps is not greater than 1; any method that pushes the step count above what was requested is not compatible.
+# The decay modes have been tested and work. However, if they increase steps beyond the requested amount, the result will not work in the A1111 pipeline. In a pipeline that supports adding steps, this method would function as intended: extra steps give a smoother sigma/noise reduction with no jaggedness between steps.
+
+blend_methods:
+  euler: # if euler were a scheduler -
+    weight: 0.3
+    decay_pattern: 'harmonic'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+
+  euler_advanced: # if euler_advanced were a scheduler -
+    weight: 0.7
+    decay_pattern: 'harmonic'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+
+
+
+blending_mode: "default" # Options: "auto", "default", "smooth_blend", "weights"
+# "auto" uses smart weights with more than 2 methods, or smooth blend with exactly two methods
+# "default" is karras + exponential for the standard Simple_KES methods, which use smooth_blend_factor
+# "smooth_blend" enforces smooth blending even when weights are included
+# "weights" enforces weights even with only 2 methods
+
+smooth_blend_factor_rand: false
+smooth_blend_factor_rand_min: 6
+smooth_blend_factor_rand_max: 11
+smooth_blend_factor: 3
+smooth_blend_factor_enable_randomization_type: false
+smooth_blend_factor_randomization_type: "asymmetric"
+smooth_blend_factor_randomization_percent: 0.2
+
+
+
+skip_prepass: true # No effect on image quality - not currently functioning as intended for early-stop purposes
+device: "cuda" # cpu or cuda
+debug: true
+global_randomize: false
+#
+sigma_scale_factor: 900
+sigma_auto_enabled: true
+sigma_auto_mode: sigma_min # Options: sigma_min, sigma_max
+#
+rho_rand: false
+rho_rand_min: 3.00 # tested recommended settings threshold
+rho_rand_max: 8.00 # tested recommended settings threshold
+#rho: 7.571656624637901
+rho: 7.959565031107985
+rho_enable_randomization_type: false
+rho_randomization_type: "log"
+rho_randomization_percent: 0.1
+#
+sigma_min_rand: false
+sigma_min_rand_min: 0.001 # tested recommended settings
+sigma_min_rand_max: 0.02 # tested recommended settings threshold
+sigma_min: 0.13757067353874633
+sigma_min_enable_randomization_type: false
+sigma_min_randomization_type: "asymmetric"
+sigma_min_randomization_percent: 0.2
+#
+sigma_max_rand: false
+sigma_max_rand_min: 25
+sigma_max_rand_max: 60
+sigma_max: 47.95768510805332
+sigma_max_enable_randomization_type: false
+sigma_max_randomization_type: "log"
+sigma_max_randomization_percent: 0.25
+#
+start_blend_rand: false
+start_blend_rand_min: 0.04 # tested recommended settings threshold
+start_blend_rand_max: 0.11 # tested recommended settings threshold
+start_blend: 0.05
+start_blend_enable_randomization_type: false
+start_blend_randomization_type: "asymmetric"
+start_blend_randomization_percent: 0.1
+#
+end_blend_rand: false
+end_blend_rand_min: 0.4 # tested recommended settings threshold
+end_blend_rand_max: 0.6 # tested recommended settings threshold
+end_blend: 0.4
+end_blend_enable_randomization_type: false
+end_blend_randomization_type: "asymmetric"
+end_blend_randomization_percent: 0.2
+#
+sharpness_rand: false
+sharpness_rand_min: 0.75 # tested recommended settings threshold
+sharpness_rand_max: 0.95 # tested recommended settings threshold
+sharpness: 0.85 # Note: visible changes in the image between 2-15; notable differences above 15; poor image quality at 50+. Sharpness is not applied above 0.95.
+sharpen_variance_threshold: 0.01
+sharpen_last_n_steps: 10
+sharpen_mode: "full" # Options: last_n, full, both
+sharpness_enable_randomization_type: false
+sharpness_randomization_type: "asymmetric"
+sharpness_randomization_percent: 0.2
+#
+step_progress_mode: "sigmoid" # Options: "linear" (default), "exponential", "logarithmic", or "sigmoid". If exponential, uses "exp_power"
+exp_power: 2
+#
+initial_step_size_rand: false
+initial_step_size_rand_min: 0.7
+initial_step_size_rand_max: 1.0
+initial_step_size: 0.9
+initial_step_size_enable_randomization_type: false
+initial_step_size_randomization_type: "asymmetric" # asymmetric, symmetric, log, or exp
+initial_step_size_randomization_percent: 0.2
+#
+final_step_size_rand: false
+final_step_size_rand_min: 0.1
+final_step_size_rand_max: 0.3
+final_step_size: 0.20
+final_step_size_enable_randomization_type: false
+final_step_size_randomization_type: "asymmetric"
+final_step_size_randomization_percent: 0.2
+#
+step_size_factor_rand: false
+step_size_factor_rand_min: 0.65
+step_size_factor_rand_max: 0.85
+step_size_factor: 0.80814932869181
+step_size_factor_enable_randomization_type: false
+step_size_factor_randomization_type: "asymmetric"
+step_size_factor_randomization_percent: 0.2
+#
+initial_noise_scale_rand: false
+initial_noise_scale_rand_min: 1.0
+initial_noise_scale_rand_max: 1.5
+initial_noise_scale: 1.25
+initial_noise_scale_enable_randomization_type: false
+initial_noise_scale_randomization_type: "asymmetric"
+initial_noise_scale_randomization_percent: 0.2
+#
+final_noise_scale_rand: false
+final_noise_scale_rand_min: 0.6
+final_noise_scale_rand_max: 1.0
+final_noise_scale: 0.80
+final_noise_scale_enable_randomization_type: false
+final_noise_scale_randomization_type: "asymmetric"
+final_noise_scale_randomization_percent: 0.2
+
+#
+noise_scale_factor_rand: false
+noise_scale_factor_rand_min: 0.75
+noise_scale_factor_rand_max: 0.95
+noise_scale_factor: 0.8113992828873163
+noise_scale_factor_enable_randomization_type: false
+noise_scale_factor_randomization_type: "asymmetric"
+noise_scale_factor_randomization_percent: 0.2
+
+# Experimental settings
+early_stopping_threshold_rand: false
+early_stopping_threshold_rand_min: 0.001
+early_stopping_threshold_rand_max: 0.02
+early_stopping_threshold: 0.06
+early_stopping_method: max # Options: mean, max, sum
+sigma_variance_scale: 0.1 # *100 = % of the current sigma; increase to reduce false early stopping, try 0.07 or 0.10
+safety_minimum_stop_step: 10 # Early stopping is not considered until past this step; increase it to raise the minimum number of steps used to process the image
+recent_change_convergence_delta: 0.6 # The change between mean/max variable changes between sigmas. Keep this relatively low; it contributes directly to when we stop.
+#min_visual_sigma: 50 # Increase from 10 to push later into the denoising sequence
+early_stopping_threshold_enable_randomization_type: false
+early_stopping_threshold_randomization_type: "asymmetric"
+early_stopping_threshold_randomization_percent: 0.2
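
On the 'softmax' blending_style: the comments above only promise that per-method weights are normalized relative to each other into the 0-1 range. A sketch of one such normalization, assuming a literal softmax; the real normalization lives in simple_kes_v3.py, which this listing does not show:

import math

def softmax_weights(blend_methods):
    # Turn raw per-method weights into normalized weights that sum to 1.
    raw = {name: cfg.get('weight', 0.0) for name, cfg in blend_methods.items()}
    exps = {name: math.exp(w) for name, w in raw.items()}
    total = sum(exps.values())
    return {name: e / total for name, e in exps.items()}

# With the euler/euler_advanced weights above (0.3 and 0.7):
# softmax_weights(...) -> {'euler': ~0.40, 'euler_advanced': ~0.60}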
sd_simple_kes_v3_fix?/kes_config/good_configs.yaml
ADDED
@@ -0,0 +1,45 @@
+blend_methods:
+  euler:
+    weight: 0.0
+    decay_pattern: 'geometric'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  euler_advanced:
+    weight: 1.895
+    decay_pattern: 'harmonic'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  geometric:
+    weight: 0.0
+    decay_pattern: 'exponential'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  harmonic:
+    weight: 0.5
+    decay_pattern: 'logarithmic'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  logarithmic:
+    weight: 0.5
+    decay_pattern: 'fractional'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  karras:
+    weight: 0.0
+    decay_pattern: 'exponential'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  exponential:
+    weight: 0.0
+    decay_pattern: 'logarithmic'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+blending_style: 'softmax'
+blending_mode: 'weights'
sd_simple_kes_v3_fix?/kes_config/how to use these files
ADDED
@@ -0,0 +1,20 @@
+
+These are only suggested configurations, meant to give you a starting point for your own configs.
+
+Do not make changes to the default_config file unless you know what you're doing. If you do decide to tweak it, please copy it first so you can revert your changes if needed.
+
+The best way is to copy whatever configuration changes you want from the default_config into the user_config, since any values in the user_config override the default_config when the program runs (a sketch of this override behavior follows this file). This means you should never need to change the default_config.
+
+I've kept my user_config pretty basic and the default_config as is. Should you want to edit your user_config with advanced settings, you can find those advanced settings inside the default_config.
+
+The files in the "suggested_scheduling_configs" folder have suggested values for different scheduler combinations **ONLY**.
+
+Simply replace your user_config yaml with the suggested config, either inline or completely, and tweak as needed.
+
+If you decide to replace your current file inline, simply set the weights to 0.0 for any entries you're currently using, or comment out the lines you don't use.
+
+If you're happy with your current user_config, I'd recommend making a copy of it and then tweaking the user_config by copying the suggested files in whole. That is the easiest and cleanest way to make sure you don't delete things that work, while still being able to play around with other values.
+
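
The override rule described above (every key present in user_config wins over default_config) amounts to loading both YAML files and merging the dictionaries. A minimal sketch under that assumption; the paths and the flat, single-level merge are guesses about the real loader in simple_kes_v3.py:

import yaml

def load_config(default_path='kes_config/default_config.yaml',
                user_path='kes_config/user_config.yaml'):
    # Start from the defaults, then let every key present in user_config win.
    with open(default_path) as f:
        config = yaml.safe_load(f) or {}
    with open(user_path) as f:
        config.update(yaml.safe_load(f) or {})
    return config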
sd_simple_kes_v3_fix?/kes_config/simple_kes_requirements.txt
ADDED
@@ -0,0 +1,5 @@
+k_diffusion
+requests
+torch
+Pyyaml
+watchdog
sd_simple_kes_v3_fix?/kes_config/simple_kes_scheduler.yaml
ADDED
@@ -0,0 +1,146 @@
+scheduler:
+
+  # Optionally print to a log file for debugging. If false, debug is turned off, and no log file will be created.
+  # config options: true or false
+  debug: true
+
+  # The minimum value for the noise level (sigma) during image generation.
+  # Decreasing this value makes the image clearer but less detailed.
+  # Increasing it makes the image noisier but potentially more artistic or abstract.
+  sigma_min: 0.04 # Default: 0.01, Suggested range: 0.01 - 0.1
+
+  # The maximum value for the noise level (sigma) during image generation.
+  # Increasing this value can create more variation in the image details.
+  # Lower values keep the image more stable and less noisy.
+  sigma_max: 50 # Default: 50, Suggested range: 10 - 60
+
+  # The device used for running the scheduler. If you have a GPU, set this to "cuda".
+  # Otherwise, use "cpu", but note that it will be significantly slower.
+  #device: "cuda" # Options: "cuda" (GPU) or "cpu" (processor)
+
+  # Initial blend factor between Karras and Exponential noise methods.
+  # A higher initial blend makes the image sharper at the start.
+  # A lower initial blend makes the image smoother early on.
+  start_blend: 0.15 # Default: 0.1, Suggested range: 0.05 - 0.2
+
+  # Final blend factor between Karras and Exponential noise methods.
+  # Higher values blend more noise at the end, possibly adding more detail.
+  # Lower values blend less noise for smoother, simpler images at the end.
+  end_blend: 0.485 # Default: 0.5, Suggested range: 0.4 - 0.6
+
+  # Sharpening factor applied to images during generation.
+  # Higher values increase sharpness but can add unwanted artifacts.
+  # Lower values reduce sharpness but may make the image look blurry.
+  sharpness: 0.9 # Default: 0.95, Suggested range: 0.8 - 1.0
+
+  # Early stopping threshold for stopping the image generation when changes between steps are minimal.
+  # Lower values stop early, saving time, but might produce incomplete images.
+  # Higher values take longer but may give more detailed results.
+  early_stopping_threshold: 0.05 # Default: 0.01, Suggested range: 0.01 - 0.05
+
+  # The number of steps between updates of the blend factor.
+  # Smaller values update the blend more frequently for smoother transitions.
+  # Larger values update the blend less frequently for faster processing.
+  update_interval: 10 # Default: 10, Suggested range: 5 - 15
+
+  # Initial step size, which controls how quickly the image evolves early on.
+  # Higher values make big changes at the start, possibly generating faster but less refined images.
+  # Lower values make smaller changes, giving more control over details.
+  initial_step_size: 0.5 # Default: 0.9, Suggested range: 0.5 - 1.0
+
+  # Final step size, which controls how much the image changes towards the end.
+  # Higher values keep details more flexible until the end, which may add complexity.
+  # Lower values lock the details earlier, making the image simpler.
+  final_step_size: 0.22 # Default: 0.2, Suggested range: 0.1 - 0.3
+
+  # Initial noise scaling applied to the image generation process.
+  # Higher values add more noise early on, making the initial image more random.
+  # Lower values reduce noise early on, leading to a smoother initial image.
+  initial_noise_scale: 1.1 # Default: 1.25, Suggested range: 1.0 - 1.5
+
+  # Final noise scaling applied at the end of the image generation.
+  # Higher values add noise towards the end, possibly adding fine detail.
+  # Lower values reduce noise towards the end, making the final image smoother.
+  final_noise_scale: 0.7 # Default: 0.8, Suggested range: 0.6 - 1.0
+
+
+  smooth_blend_factor: 11 # Default: 11, try 6 for more variation
+  step_size_factor: 0.75 # Suggested value (0.8) to avoid oversmoothing
+  noise_scale_factor: 0.95 # Suggested value (0.9) to add more variation
+
+
+  # Enables global randomization.
+  # If true, all parameters are randomized within the specified min/max ranges.
+  # If false, individual parameters with _rand flags set to true will still be randomized.
+  randomize: false
+
+  # Sigma values typically start very small. Lowering this could allow more gradual noise reduction. Too large would overwhelm the process.
+  sigma_min_rand: false
+  sigma_min_rand_min: 0.001
+  sigma_min_rand_max: 0.05
+
+  # Sigma max controls the upper limit of the noise. A lower minimum could allow faster convergence, while a higher max gives more flexibility for noisier images.
+  sigma_max_rand: false
+  sigma_max_rand_min: 22
+  sigma_max_rand_max: 48
+
+  # Start blend controls how strongly Karras and Exponential are blended at the start. A slightly lower value introduces more variety in the blending at the beginning.
+  start_blend_rand: false
+  start_blend_rand_min: 0.05
+  start_blend_rand_max: 0.2
+
+  # End blend affects how much the blending changes towards the end. Increasing the upper limit would allow more variation.
+  end_blend_rand: false
+  end_blend_rand_min: 0.4
+  end_blend_rand_max: 0.6
+
+  # Sharpness controls detail retention. You wouldn't want to lower it too much, as it might lose detail.
+  sharpness_rand: false
+  sharpness_rand_min: 0.85
+  sharpness_rand_max: 1.0
+
+  # A smaller early stopping threshold could lead to earlier stopping if the changes between sigma steps become too small, while the upper value would prevent early stopping until larger changes occur.
+  early_stopping_rand: false
+  early_stopping_rand_min: 0.001
+  early_stopping_rand_max: 0.02
+
+  # Update intervals affect how frequently blending factors are updated. More frequent updates allow more flexibility in blending.
+  update_interval_rand: false
+  update_interval_rand_min: 5
+  update_interval_rand_max: 10
+
+  # The initial step size defines how large the steps are at the start. A slightly smaller value introduces more gradual transitions.
+  initial_step_rand: false
+  initial_step_rand_min: 0.7
+  initial_step_rand_max: 1.0
+
+  # The final step size defines how small the steps become towards the end. A slightly larger range gives more control over the final convergence.
+  final_step_rand: false
+  final_step_rand_min: 0.1
+  final_step_rand_max: 0.3
+
+  # Initial noise scale defines how much noise to introduce initially. Larger values make the process start with more randomness, while smaller values keep it controlled.
+  initial_noise_rand: false
+  initial_noise_rand_min: 1.0
+  initial_noise_rand_max: 1.5
+
+  # Final noise scale affects how much noise is reduced at the end. A lower minimum allows more noise to persist, while a higher maximum ensures full convergence.
+  final_noise_rand: false
+  final_noise_rand_min: 0.6
+  final_noise_rand_max: 1.0
+
+  # The smooth blend factor controls how aggressively the blending is smoothed. Lower values allow more abrupt blending changes, while higher values give smoother transitions.
+  smooth_blend_factor_rand: false
+  smooth_blend_factor_rand_min: 6
+  smooth_blend_factor_rand_max: 11
+
+  # Step size factor adjusts the step size dynamically to avoid oversmoothing. A lower minimum increases variety, while a higher max provides smoother results.
+  step_size_factor_rand: false
+  step_size_factor_rand_min: 0.65
+  step_size_factor_rand_max: 0.85
+
+  # Noise scale factor controls how noise is scaled throughout the steps. A slightly lower minimum adds more variety, while keeping the maximum near the suggested value ensures more uniform results.
+  noise_scale_factor_rand: false
+  noise_scale_factor_rand_min: 0.75
+  noise_scale_factor_rand_max: 0.95
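
start_blend and end_blend describe a blend factor that moves between the Karras and Exponential schedules over the course of sampling. A sketch of one such progression, assuming a plain linear ramp; the shipped scheduler shapes the ramp with step_progress_mode and smooth_blend_factor, so this is only an approximation:

import torch

def blended_sigmas(sigmas_karras, sigmas_exp, start_blend=0.15, end_blend=0.485):
    # Ramp the blend factor across the steps, then mix the two schedules element-wise.
    blend = torch.linspace(start_blend, end_blend, len(sigmas_karras))
    return (1 - blend) * sigmas_karras + blend * sigmas_exp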
sd_simple_kes_v3_fix?/kes_config/suggested_scheduling_configs/alternating soft_hard decay.yaml
ADDED
@@ -0,0 +1,42 @@
+blend_methods:
+  euler:
+    weight: 0.5
+    decay_pattern: 'geometric'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  euler_advanced:
+    weight: 0.4
+    decay_pattern: 'harmonic'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  geometric:
+    weight: 0.7
+    decay_pattern: 'exponential'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  harmonic:
+    weight: 0.3
+    decay_pattern: 'linear'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  logarithmic:
+    weight: 0.6
+    decay_pattern: 'fractional'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  karras:
+    weight: 0.8
+    decay_pattern: 'exponential'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  exponential:
+    weight: 0.2
+    decay_pattern: 'logarithmic'
+    decay_mode: 'blend'
+    tail_steps: 1
sd_simple_kes_v3_fix?/kes_config/suggested_scheduling_configs/anime_1.yaml
ADDED
@@ -0,0 +1,18 @@
+blend_methods:
+  euler_advanced:
+    weight: 0.5
+    decay_pattern: 'harmonic'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  harmonic:
+    weight: 0.7
+    decay_pattern: 'logarithmic'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  karras:
+    weight: 0.6
+    decay_pattern: 'fractional'
+    decay_mode: 'blend'
+    tail_steps: 1
sd_simple_kes_v3_fix?/kes_config/suggested_scheduling_configs/cross_style_safe.yaml
ADDED
@@ -0,0 +1,18 @@
+blend_methods:
+  euler_advanced:
+    weight: 0.4
+    decay_pattern: 'logarithmic'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  harmonic:
+    weight: 0.4
+    decay_pattern: 'linear'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  karras:
+    weight: 0.5
+    decay_pattern: 'exponential'
+    decay_mode: 'blend'
+    tail_steps: 1
sd_simple_kes_v3_fix?/kes_config/suggested_scheduling_configs/front_loaded geometric.yaml
ADDED
@@ -0,0 +1,42 @@
+blend_methods:
+  euler:
+    weight: 0.6
+    decay_pattern: 'geometric'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  euler_advanced:
+    weight: 0.7
+    decay_pattern: 'geometric'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  geometric:
+    weight: 0.8
+    decay_pattern: 'geometric'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  harmonic:
+    weight: 0.4
+    decay_pattern: 'harmonic'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  logarithmic:
+    weight: 0.3
+    decay_pattern: 'logarithmic'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  karras:
+    weight: 0.5
+    decay_pattern: 'fractional'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  exponential:
+    weight: 0.2
+    decay_pattern: 'linear'
+    decay_mode: 'blend'
+    tail_steps: 1
sd_simple_kes_v3_fix?/kes_config/suggested_scheduling_configs/photo_realistic_1.yaml
ADDED
@@ -0,0 +1,18 @@
+blend_methods:
+  euler:
+    weight: 0.6
+    decay_pattern: 'geometric'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  geometric:
+    weight: 0.7
+    decay_pattern: 'exponential'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  karras:
+    weight: 0.8
+    decay_pattern: 'geometric'
+    decay_mode: 'blend'
+    tail_steps: 1
sd_simple_kes_v3_fix?/kes_config/suggested_scheduling_configs/progressive.yaml
ADDED
@@ -0,0 +1,42 @@
+blend_methods:
+  euler:
+    weight: 0.2
+    decay_pattern: 'harmonic'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  euler_advanced:
+    weight: 0.2
+    decay_pattern: 'logarithmic'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  geometric:
+    weight: 0.3
+    decay_pattern: 'linear'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  harmonic:
+    weight: 0.4
+    decay_pattern: 'fractional'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  logarithmic:
+    weight: 0.5
+    decay_pattern: 'geometric'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  karras:
+    weight: 0.6
+    decay_pattern: 'exponential'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+  exponential:
+    weight: 0.7
+    decay_pattern: 'extrapolate'
+    decay_mode: 'blend'
+    tail_steps: 1
sd_simple_kes_v3_fix?/kes_config/user_config.yaml
ADDED
@@ -0,0 +1,81 @@
+# ========================================================
+# FULL CONFIG (NO DEPENDENCY ON DEFAULT)
+# ========================================================
+
+# --- 1. MAIN BLENDING SETTINGS ---
+# Use "default" to avoid division by zero.
+# This blends Karras and Exponential.
+blending_mode: "default"
+
+# Controls the smoothness of the transition between Karras and Exponential.
+# A value of 3 makes the transition very soft.
+smooth_blend_factor: 9
+
+# --- 2. ARTIFACT FIX (THE MAIN ONE) ---
+# "linear" removes the "staircase" effect on lines. (Was "sigmoid".)
+step_progress_mode: "exponential"
+
+# --- 3. NOISE AND STEP SETTINGS (YOUR "CHARACTER") ---
+# These numbers are taken from your old config so the scheduler stays "non-boring".
+
+# Step sizes (speed dynamics)
+initial_step_size: 0.9
+final_step_size: 0.20
+step_size_factor: 0.80814932869181
+
+# Noise scale (detail dynamics)
+initial_noise_scale: 1.25
+final_noise_scale: 0.80
+noise_scale_factor: 0.8113992828873163
+
+# --- 4. NOISE BOUNDARY (SIGMA) SETTINGS ---
+# Your exact values.
+sigma_min: 0.13757067353874633
+sigma_max: 47.95768510805332
+rho: 7.959565031107985
+
+# Auto-scaling (can stay enabled)
+sigma_auto_enabled: true
+sigma_auto_mode: "sigma_min"
+sigma_scale_factor: 900
+
+# --- 5. SHARPNESS ---
+# 0.85 = medium sharpness.
+sharpness: 0.85
+
+# IMPORTANT: apply only to the last 10 steps so the image doesn't break.
+sharpen_mode: "last_n"
+sharpen_last_n_steps: 10
+sharpen_variance_threshold: 0.01
+
+# --- 6. BLENDING (FOR REFERENCE ONLY) ---
+# In "default" mode these weights are ignored, but the tail parameters matter.
+blend_methods:
+  karras:
+    weight: 1.0
+    decay_pattern: 'zero'
+    decay_mode: 'blend'
+    tail_steps: 1
+  exponential:
+    weight: 1.0
+    decay_pattern: 'zero'
+    decay_mode: 'blend'
+    tail_steps: 1
+
+# --- 7. SYSTEM SETTINGS ---
+debug: true
+device: "cuda"
+skip_prepass: true # Skip the pre-pass for speed
+global_randomize: false
+
+# Disable randomization of individual parameters for stability
+rho_rand: false
+sigma_min_rand: false
+sigma_max_rand: false
+start_blend_rand: false
+end_blend_rand: false
+sharpness_rand: false
+initial_step_size_rand: false
+final_step_size_rand: false
+initial_noise_scale_rand: false
+final_noise_scale_rand: false
sd_simple_kes_v3_fix?/plot_sigma_sequence.py
ADDED
@@ -0,0 +1,42 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+
+def plot_sigma_sequence(sigs, stopping_index, log_filename, save_directory="modules/sd_simple_kes_v3/image_generation_data", show_plot=False):
+    """
+    Plot the sigma sequence and mark the early stopping point.
+
+    Parameters:
+    - sigs: The sigma tensor or numpy array (can be truncated if stopping early).
+    - stopping_index: The step index where early stopping was triggered.
+    - log_filename: The filename of the generation log (used to match the graph name).
+    - save_directory: The folder where the plot should be saved.
+    - show_plot: Set to True to display the plot interactively.
+    """
+
+    # Extract base name to match log filename
+    base_filename = os.path.splitext(os.path.basename(log_filename))[0]
+    graph_filename = f"{base_filename}_sigma_plot.png"
+    graph_path = os.path.join(save_directory, graph_filename)
+
+    # Prepare sigma sequence for plotting
+    sigs_np = sigs.cpu().numpy() if hasattr(sigs, 'cpu') else np.array(sigs)
+    x = np.arange(len(sigs_np))
+
+    # Plotting
+    plt.figure(figsize=(10, 6))
+    plt.plot(x, sigs_np, label='Sigma Sequence', marker='o')
+    plt.axvline(x=stopping_index, color='red', linestyle='--', label=f'Stopping Point: {stopping_index}')
+    plt.xlabel('Step Index')
+    plt.ylabel('Sigma Value')
+    plt.title('Sigma Sequence with Early Stopping Point')
+    plt.legend()
+    plt.grid(True)
+    plt.tight_layout()
+    plt.savefig(graph_path)
+
+    if show_plot:
+        plt.show()
+
+    plt.close()
+    return graph_path
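
A quick usage sketch for plot_sigma_sequence; the sigma values, stopping index, and log filename below are made up for illustration:

import torch
from modules.sd_simple_kes_v3.plot_sigma_sequence import plot_sigma_sequence

sigmas = torch.linspace(47.96, 0.14, 30)  # stand-in schedule
path = plot_sigma_sequence(sigmas, stopping_index=24,
                           log_filename='example_generation.log',
                           save_directory='.')  # write next to the script for the demo
print(f'Sigma plot written to {path}')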
sd_simple_kes_v3_fix?/requirements.txt
ADDED
@@ -0,0 +1,4 @@
+torch
+pyyaml
+numpy
+matplotlib
sd_simple_kes_v3_fix?/schedulers/__pycache__/euler_advanced_scheduler.cpython-310.pyc
ADDED
Binary file (2.02 kB)
sd_simple_kes_v3_fix?/schedulers/__pycache__/exponential_advanced_scheduler.cpython-310.pyc
ADDED
Binary file (1.43 kB)
sd_simple_kes_v3_fix?/schedulers/__pycache__/geometric_advanced_scheduler.cpython-310.pyc
ADDED
Binary file (1.65 kB)
sd_simple_kes_v3_fix?/schedulers/__pycache__/harmonic_advanced_scheduler.cpython-310.pyc
ADDED
Binary file (1.43 kB)
sd_simple_kes_v3_fix?/schedulers/__pycache__/karras_advanced_scheduler.cpython-310.pyc
ADDED
Binary file (1.53 kB)
sd_simple_kes_v3_fix?/schedulers/__pycache__/logarithmic_advanced_scheduler.cpython-310.pyc
ADDED
Binary file (1.5 kB)
sd_simple_kes_v3_fix?/schedulers/__pycache__/shared.cpython-310.pyc
ADDED
Binary file (3.83 kB)
sd_simple_kes_v3_fix?/schedulers/euler_advanced_scheduler.py
ADDED
@@ -0,0 +1,53 @@
+import torch
+from modules.sd_simple_kes_v3.schedulers.shared import apply_last_tail, apply_decay_tail, valid_decay_patterns, valid_decay_modes, blend_decay_tail, replace_tail
+import math
+
+def get_sigmas_euler(steps, sigma_min, sigma_max, device='cpu'):
+    """
+    Sigma schedule designed to work well with Euler sampling.
+    Logarithmic spacing with a smooth transition.
+    """
+    def _to_tensor(val, device):
+        return val.to(device) if isinstance(val, torch.Tensor) else torch.tensor(val, device=device)
+
+    # Convert sigma_min and sigma_max to tensors safely
+    sigma_min = _to_tensor(sigma_min, device)
+    sigma_max = _to_tensor(sigma_max, device)
+
+    sigmas = torch.exp(torch.linspace(math.log(sigma_max.item()), math.log(sigma_min.item()), steps, device=device))
+
+    tails = None
+    decay = None
+
+    return tails, decay, sigmas
+
+def get_sigmas_euler_advanced(steps, sigma_min, sigma_max, device='cpu', blend_factor=0.5, decay_pattern=None, decay_mode=None, tail_steps=None):
+    def _to_tensor(val, device):
+        return val.to(device) if isinstance(val, torch.Tensor) else torch.tensor(val, device=device)
+
+    # Convert inputs to tensors so both float and tensor arguments work with math.log below
+    sigma_min = _to_tensor(sigma_min, device)
+    sigma_max = _to_tensor(sigma_max, device)
+
+    ramp = torch.linspace(0, 1, steps, device=device)
+    sigmas_exp = torch.exp(torch.linspace(math.log(sigma_max.item()), math.log(sigma_min.item()), steps, device=device))
+    sigmas_karras = (sigma_max ** (1 / 7) + ramp * (sigma_min ** (1 / 7) - sigma_max ** (1 / 7))) ** 7
+
+    # Blend Karras and Exponential schedules for a hybrid Euler-friendly progression
+    sigmas = (1 - blend_factor) * sigmas_exp + blend_factor * sigmas_karras
+    tails = None
+    decay = None
+
+    if decay_pattern:
+        if decay_pattern in valid_decay_patterns:
+            tails = apply_last_tail(sigmas, device, decay_pattern)
+        else:
+            print(f"[Warning] decay_pattern: {decay_pattern} not in valid decay patterns: {valid_decay_patterns}")
+    if decay_mode:
+        if decay_mode in valid_decay_modes:
+            if decay_mode == 'append':  # <--- FIXED
+                sigmas = apply_decay_tail(sigmas, device, decay_pattern)
+            elif decay_mode == 'blend':  # <--- FIXED
+                sigmas = blend_decay_tail(sigmas, device, decay_pattern, tail_steps)
+            elif decay_mode == 'replace':  # <--- FIXED
+                sigmas = replace_tail(sigmas, device, decay_pattern, tail_steps)
+        else:
+            print(f"[Warning] decay_mode: {decay_mode} not in valid decay modes: {valid_decay_modes}")
+
+    return tails, decay, sigmas
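
A quick check of the hybrid schedule above (arbitrary endpoints; with no decay_pattern or decay_mode the tail handling is skipped entirely):

from modules.sd_simple_kes_v3.schedulers.euler_advanced_scheduler import get_sigmas_euler_advanced

# blend_factor=0.5 weights the exponential and Karras schedules equally.
tails, decay, sigmas = get_sigmas_euler_advanced(steps=30, sigma_min=0.01, sigma_max=50,
                                                 device='cpu', blend_factor=0.5)
assert tails is None and decay is None
assert sigmas[0] > sigmas[-1]  # schedules run from high noise to low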
sd_simple_kes_v3_fix?/schedulers/exponential_advanced_scheduler.py
ADDED
@@ -0,0 +1,41 @@
+import torch
+import math
+from modules.sd_simple_kes_v3.schedulers.shared import apply_last_tail, apply_decay_tail, valid_decay_patterns, valid_decay_modes, blend_decay_tail, replace_tail
+
+
+def get_sigmas_exponential(steps, sigma_min, sigma_max, device='cpu', decay_pattern=None, decay_mode=None, tail_steps=None):
+    """Constructs an exponential noise schedule."""
+
+    def _to_tensor(val, device):
+        return val.to(device) if isinstance(val, torch.Tensor) else torch.tensor(val, device=device)
+
+    # Convert sigma_min and sigma_max to tensors safely
+    sigma_min = _to_tensor(sigma_min, device)
+    sigma_max = _to_tensor(sigma_max, device)
+
+    tail_steps = tail_steps or 5
+
+    # Exponential progression: linear in log-sigma space, then exponentiated
+    sigmas = torch.linspace(math.log(sigma_max.item()), math.log(sigma_min.item()), steps, device=device).exp()
+
+    tails = None
+    decay = None
+
+    if decay_pattern:
+        if decay_pattern in valid_decay_patterns:
+            tails = apply_last_tail(sigmas, device, decay_pattern)
+        else:
+            print(f"[Warning] decay_pattern: {decay_pattern} not in valid decay patterns: {valid_decay_patterns}")
+    if decay_mode:
+        if decay_mode in valid_decay_modes:
+            if decay_mode == 'append':  # <--- FIXED
+                sigmas = apply_decay_tail(sigmas, device, decay_pattern)
+            elif decay_mode == 'blend':  # <--- FIXED
+                sigmas = blend_decay_tail(sigmas, device, decay_pattern, tail_steps)
+            elif decay_mode == 'replace':  # <--- FIXED
+                sigmas = replace_tail(sigmas, device, decay_pattern, tail_steps)
+        else:
+            print(f"[Warning] decay_mode: {decay_mode} not in valid decay modes: {valid_decay_modes}")
+
+    return tails, decay, sigmas
sd_simple_kes_v3_fix?/schedulers/geometric_advanced_scheduler.py
ADDED
@@ -0,0 +1,60 @@
+import torch
+from modules.sd_simple_kes_v3.schedulers.shared import apply_last_tail, apply_decay_tail, valid_decay_patterns, valid_decay_modes, blend_decay_tail, replace_tail
+
+
+def get_sigmas_geometric(steps, sigma_min, sigma_max, device='cpu', decay_pattern=None, decay_mode=None, tail_steps=None):
+    def _to_tensor(val, device):
+        return val.to(device) if isinstance(val, torch.Tensor) else torch.tensor(val, device=device)
+
+    # Convert sigma_min and sigma_max to tensors safely
+    sigma_min = _to_tensor(sigma_min, device)
+    sigma_max = _to_tensor(sigma_max, device)
+
+    # Build the sequence as plain floats so list entries are never a mix of tensors and floats
+    sigmas = [sigma_max.item()]
+    tail_steps = tail_steps or 5
+
+    for step in range(1, steps):
+        if len(sigmas) >= 2:
+            deltas = torch.abs(torch.tensor(sigmas[:-1]) - torch.tensor(sigmas[1:]))
+            avg_delta = torch.mean(deltas).item()
+        else:
+            avg_delta = sigmas[-1] * 0.1
+
+        last_sigma = sigmas[-1]
+        dynamic_decay_rate = max(1 - (avg_delta / (last_sigma + 1e-5)), 0.85)  # Clamp for stability
+
+        next_sigma = max(last_sigma * dynamic_decay_rate, sigma_min.item())
+        sigmas.append(next_sigma)
+
+    sigmas = torch.tensor(sigmas, device=device)
+
+    tails = None
+    decay = None
+
+    if decay_pattern:
+        if decay_pattern in valid_decay_patterns:
+            tails = apply_last_tail(sigmas, device, decay_pattern)
+        else:
+            print(f"[Warning] decay_pattern: {decay_pattern} not in valid decay patterns: {valid_decay_patterns}")
+    if decay_mode:
+        if decay_mode in valid_decay_modes:
+            if decay_mode == 'append':  # <--- FIXED
+                sigmas = apply_decay_tail(sigmas, device, decay_pattern)
+            elif decay_mode == 'blend':  # <--- FIXED
+                sigmas = blend_decay_tail(sigmas, device, decay_pattern, tail_steps)
+            elif decay_mode == 'replace':  # <--- FIXED
+                sigmas = replace_tail(sigmas, device, decay_pattern, tail_steps)
+        else:
+            print(f"[Warning] decay_mode: {decay_mode} not in valid decay modes: {valid_decay_modes}")
+
+    return tails, decay, sigmas
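
A small numeric check of the dynamic decay rate above. On the first iteration there are no deltas yet, so avg_delta falls back to 10% of the current sigma, which makes the decay rate max(1 - 0.1, 0.85) = 0.9 regardless of sigma_max:

from modules.sd_simple_kes_v3.schedulers.geometric_advanced_scheduler import get_sigmas_geometric

tails, decay, sigmas = get_sigmas_geometric(steps=10, sigma_min=0.01, sigma_max=50)
# First fallback: avg_delta = 50 * 0.1, decay rate ~= 0.9, so the second sigma is ~45.0.
# After that the rate tracks the running average step, clamped at 0.85.
print(sigmas[:3])  # tensor([50.0000, 45.0000, 40.0000]) up to float rounding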
sd_simple_kes_v3_fix?/schedulers/harmonic_advanced_scheduler.py
ADDED
@@ -0,0 +1,45 @@
+import torch
+from modules.sd_simple_kes_v3.schedulers.shared import apply_last_tail, apply_decay_tail, valid_decay_patterns, valid_decay_modes, blend_decay_tail, replace_tail
+
+
+def get_sigmas_harmonic(steps, sigma_min, sigma_max, device='cpu', decay_pattern=None, decay_mode=None, tail_steps=None):
+    def _to_tensor(val, device):
+        return val.to(device) if isinstance(val, torch.Tensor) else torch.tensor(val, device=device)
+
+    # Convert sigma_min and sigma_max to tensors safely
+    sigma_min = _to_tensor(sigma_min, device)
+    sigma_max = _to_tensor(sigma_max, device)
+
+    sigmas = []
+    tail_steps = tail_steps or 5
+
+    # Harmonic decay calculation: sigma_max / 2, sigma_max / 3, ..., clamped at sigma_min
+    for i in range(1, steps + 1):
+        next_sigma = sigma_max.item() / (i + 1)  # Harmonic sequence
+        next_sigma = max(next_sigma, sigma_min.item())
+        sigmas.append(next_sigma)
+
+    # Convert list to tensor
+    sigmas = torch.tensor(sigmas, device=device)
+
+    tails = None
+    decay = None
+
+    if decay_pattern:
+        if decay_pattern in valid_decay_patterns:
+            tails = apply_last_tail(sigmas, device, decay_pattern)
+        else:
+            print(f"[Warning] decay_pattern: {decay_pattern} not in valid decay patterns: {valid_decay_patterns}")
+    if decay_mode:
+        if decay_mode in valid_decay_modes:
+            if decay_mode == 'append':  # <--- FIXED
+                sigmas = apply_decay_tail(sigmas, device, decay_pattern)
+            elif decay_mode == 'blend':  # <--- FIXED
+                sigmas = blend_decay_tail(sigmas, device, decay_pattern, tail_steps)
+            elif decay_mode == 'replace':  # <--- FIXED
+                sigmas = replace_tail(sigmas, device, decay_pattern, tail_steps)
+        else:
+            print(f"[Warning] decay_mode: {decay_mode} not in valid decay modes: {valid_decay_modes}")
+
+    return tails, decay, sigmas
sd_simple_kes_v3_fix?/schedulers/karras_advanced_scheduler.py
ADDED
@@ -0,0 +1,40 @@
+import torch
+from torch import linspace, tensor
+from modules.sd_simple_kes_v3.schedulers.shared import apply_last_tail, apply_decay_tail, valid_decay_patterns, valid_decay_modes, blend_decay_tail, replace_tail
+
+
+def get_sigmas_karras(steps, sigma_min, sigma_max, rho=7., device='cpu', decay_pattern=None, decay_mode=None, tail_steps=None):
+    """Constructs the noise schedule of Karras et al. (2022)."""
+    ramp = linspace(0, 1, steps, device=device)
+    tail_steps = tail_steps or 5
+
+    def _to_tensor(val, device):
+        return val.to(device) if isinstance(val, torch.Tensor) else torch.tensor(val, device=device)
+
+    sigma_min = _to_tensor(sigma_min, device)
+    sigma_max = _to_tensor(sigma_max, device)
+
+    # Interpolate linearly in sigma^(1/rho) space, then raise back to the rho power
+    min_inv_rho = sigma_min.item() ** (1 / rho)
+    max_inv_rho = sigma_max.item() ** (1 / rho)
+    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+
+    tails = None
+    decay = None
+
+    if decay_pattern:
+        if decay_pattern in valid_decay_patterns:
+            tails = apply_last_tail(sigmas, device, decay_pattern)
+        else:
+            print(f"[Warning] decay_pattern: {decay_pattern} not in valid decay patterns: {valid_decay_patterns}")
+    if decay_mode:
+        if decay_mode in valid_decay_modes:
+            if decay_mode == 'append':  # <--- FIXED
+                sigmas = apply_decay_tail(sigmas, device, decay_pattern)
+            elif decay_mode == 'blend':  # <--- FIXED
+                sigmas = blend_decay_tail(sigmas, device, decay_pattern, tail_steps)
+            elif decay_mode == 'replace':  # <--- FIXED
+                sigmas = replace_tail(sigmas, device, decay_pattern, tail_steps)
+        else:
+            print(f"[Warning] decay_mode: {decay_mode} not in valid decay modes: {valid_decay_modes}")
+
+    return tails, decay, sigmas
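
The construction interpolates linearly in sigma^(1/rho) space: sigma_i = (sigma_max^(1/rho) + t_i * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho with t_i ramping from 0 to 1, so both endpoints land exactly on sigma_max and sigma_min. A quick check with arbitrary values:

from modules.sd_simple_kes_v3.schedulers.karras_advanced_scheduler import get_sigmas_karras

tails, decay, sigmas = get_sigmas_karras(steps=10, sigma_min=0.01, sigma_max=50, rho=7.)
print(float(sigmas[0]), float(sigmas[-1]))  # 50.0 and 0.01, up to float rounding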
sd_simple_kes_v3_fix?/schedulers/logarithmic_advanced_scheduler.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
import math
from modules.sd_simple_kes_v3.schedulers.shared import apply_last_tail, apply_decay_tail, valid_decay_patterns, valid_decay_modes, blend_decay_tail, replace_tail


def get_sigmas_logarithmic(steps, sigma_min, sigma_max, device='cpu', decay_pattern=None, decay_mode=None, tail_steps=None):
    def _to_tensor(val, device):
        return val.to(device) if isinstance(val, torch.Tensor) else torch.tensor(val, device=device)

    # Convert sigma_min and sigma_max to tensors safely
    sigma_min = _to_tensor(sigma_min, device)
    sigma_max = _to_tensor(sigma_max, device)

    sigmas = []
    tail_steps = tail_steps or 5

    # Build the sigma list
    for i in range(1, steps + 1):
        next_sigma = sigma_max.item() - (math.log(i + 1) / math.log(steps + 1)) * (sigma_max.item() - sigma_min.item())
        next_sigma = max(next_sigma, sigma_min.item())
        sigmas.append(next_sigma)

    # Convert the list to a tensor on the correct device
    sigmas = torch.tensor(sigmas, device=device)

    tails = None
    decay = None

    if decay_pattern:
        if decay_pattern in valid_decay_patterns:
            tails = apply_last_tail(sigmas, device, decay_pattern)
        elif decay_pattern not in valid_decay_patterns:
            print(f"[Warning] decay_pattern: {decay_pattern} not in valid decay patterns: {valid_decay_patterns}")
        if decay_mode:
            if decay_mode in valid_decay_modes:
                if decay_mode == 'append':  # <--- FIXED
                    sigmas = apply_decay_tail(sigmas, device, decay_pattern)
                elif decay_mode == 'blend':  # <--- FIXED
                    sigmas = blend_decay_tail(sigmas, device, decay_pattern, tail_steps)
                elif decay_mode == 'replace':  # <--- FIXED
                    sigmas = replace_tail(sigmas, device, decay_pattern, tail_steps)
            elif decay_mode not in valid_decay_modes:
                print(f"[Warning] decay_mode: {decay_mode} not in valid decay modes: {valid_decay_modes}")

    return tails, decay, sigmas
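Note: because math.log(i + 1) / math.log(steps + 1) grows fastest at small i, this schedule is front-loaded: most of the sigma range is consumed in the first few steps. A quick numeric check (illustrative values, not from any shipped config):

    import math
    steps, sigma_max, sigma_min = 4, 10.0, 0.1
    fracs = [math.log(i + 1) / math.log(steps + 1) for i in range(1, steps + 1)]
    sigmas = [max(sigma_max - f * (sigma_max - sigma_min), sigma_min) for f in fracs]
    # fracs  ≈ [0.431, 0.683, 0.861, 1.0]
    # sigmas ≈ [5.74, 3.24, 1.47, 0.1]  -> large drops early, small refinements late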
sd_simple_kes_v3_fix?/schedulers/shared.py ADDED
@@ -0,0 +1,184 @@
import math
import torch
from torch import linspace, tensor


valid_decay_patterns = [
    'zero', 'geometric', 'harmonic', 'logarithmic',
    'extrapolate', 'fractional', 'exponential', 'linear'
]
valid_decay_modes = [
    'append',
    'blend',
    'replace'
]

def append_zero(x):
    return torch.cat([x, x.new_zeros([1])])


def apply_last_tail(sigmas, device, decay_pattern='zero'):
    """
    Applies a single final sigma step using context-aware dynamic decay based on the sequence trend.
    """
    if decay_pattern == 'zero':
        return append_zero(sigmas)

    last_sigma = sigmas[-1]

    if len(sigmas) >= 2:
        deltas = torch.abs(sigmas[:-1] - sigmas[1:])
        avg_delta = torch.mean(deltas).item()
    else:
        avg_delta = last_sigma.item() * 0.1

    if decay_pattern == 'geometric':
        dynamic_decay_rate = max(1 - (avg_delta / (last_sigma + 1e-5)), 0.85)
        next_sigma = max(last_sigma * dynamic_decay_rate, 1e-5)

    elif decay_pattern == 'harmonic':
        next_sigma = max(last_sigma - avg_delta, 1e-5)

    elif decay_pattern == 'logarithmic':
        next_sigma = max(last_sigma - (avg_delta / math.log(len(sigmas) + 2)), 1e-5)

    elif decay_pattern == 'extrapolate':
        if len(sigmas) >= 2:
            last_delta = sigmas[-2] - sigmas[-1]
        else:
            last_delta = avg_delta
        next_sigma = max(last_sigma - last_delta, 1e-5)

    elif decay_pattern == 'fractional':
        next_sigma = max(last_sigma * 0.1, 1e-5)

    elif decay_pattern == 'exponential':
        next_sigma = max(last_sigma * math.exp(-avg_delta), 1e-5)

    elif decay_pattern == 'linear':
        next_sigma = max(last_sigma - (avg_delta * 0.5), 1e-5)

    else:
        raise ValueError(f"Unknown decay pattern: {decay_pattern}. Valid decay patterns are: {valid_decay_patterns}")

    tail_tensor = torch.tensor([next_sigma], device=device)
    return torch.cat([sigmas, tail_tensor])


def apply_decay_tail(sigmas, device, decay_pattern='geometric', tail_steps=5):
    """
    Applies a context-aware multi-step decay tail based on the progression of the entire sigma sequence.
    """
    tail = []

    if decay_pattern == 'zero':
        return append_zero(sigmas)

    if len(sigmas) >= 2:
        deltas = torch.abs(sigmas[:-1] - sigmas[1:])
        avg_delta = torch.mean(deltas).item()
    else:
        avg_delta = sigmas[-1].item() * 0.1

    last_sigma = sigmas[-1]

    for step in range(tail_steps):
        if decay_pattern == 'geometric':
            dynamic_decay_rate = max(1 - (avg_delta / (last_sigma + 1e-5)), 0.85)
            next_sigma = max(last_sigma * dynamic_decay_rate, 1e-5)

        elif decay_pattern == 'harmonic':
            next_sigma = max(last_sigma - (avg_delta / (step + 1)), 1e-5)

        elif decay_pattern == 'logarithmic':
            next_sigma = max(last_sigma - (avg_delta / math.log(len(sigmas) + step + 2)), 1e-5)

        elif decay_pattern == 'extrapolate':
            if len(sigmas) >= 2:
                last_delta = sigmas[-2] - sigmas[-1]
            else:
                last_delta = avg_delta
            next_sigma = max(last_sigma - last_delta, 1e-5)

        elif decay_pattern == 'fractional':
            next_sigma = max(last_sigma * 0.1, 1e-5)

        elif decay_pattern == 'exponential':
            next_sigma = max(last_sigma * math.exp(-avg_delta * (step + 1)), 1e-5)

        elif decay_pattern == 'linear':
            next_sigma = max(last_sigma - (avg_delta * 0.5 * (step + 1)), 1e-5)

        else:
            raise ValueError(f"Unknown decay pattern: {decay_pattern}. Valid decay patterns are: {valid_decay_patterns}")

        tail.append(next_sigma)
        last_sigma = next_sigma

    tail_tensor = torch.tensor(tail, device=device)
    return torch.cat([sigmas, tail_tensor])


def blend_decay_tail(sigmas, device, decay_pattern='geometric', tail_steps=5):
    """
    Applies in-place blending on the last N steps using decay patterns.
    """
    for i in range(1, tail_steps + 1):
        idx = -i
        base_sigma = sigmas[idx]

        if len(sigmas) >= 2:
            deltas = torch.abs(sigmas[:-1] - sigmas[1:])
            avg_delta = torch.mean(deltas).item()
        else:
            avg_delta = base_sigma.item() * 0.1
        if decay_pattern == 'zero':
            sigmas[idx] = 0.0
            continue
        if decay_pattern == 'geometric':
            dynamic_decay_rate = max(1 - (avg_delta / (base_sigma + 1e-5)), 0.85)
            new_sigma = max(base_sigma * dynamic_decay_rate, 1e-5)

        elif decay_pattern == 'harmonic':
            new_sigma = max(base_sigma - (avg_delta / i), 1e-5)

        elif decay_pattern == 'logarithmic':
            new_sigma = max(base_sigma - (avg_delta / math.log(len(sigmas) + 2 - i)), 1e-5)

        elif decay_pattern == 'extrapolate':
            if len(sigmas) >= 2:
                last_delta = sigmas[-2] - sigmas[-1]
            else:
                last_delta = avg_delta
            new_sigma = max(base_sigma - last_delta, 1e-5)

        elif decay_pattern == 'fractional':
            new_sigma = max(base_sigma * 0.1, 1e-5)

        elif decay_pattern == 'exponential':
            new_sigma = max(base_sigma * math.exp(-avg_delta), 1e-5)

        elif decay_pattern == 'linear':
            new_sigma = max(base_sigma - (avg_delta * 0.5), 1e-5)

        else:
            raise ValueError(f"Unknown decay pattern: {decay_pattern}. Valid decay patterns are: {valid_decay_patterns}")

        sigmas[idx] = (base_sigma + new_sigma) / 2

    return sigmas


def replace_tail(sigmas, device, decay_pattern='geometric', tail_steps=5):
    available_steps = len(sigmas)

    # Clamp tail_steps to the number of available steps
    if tail_steps >= available_steps:
        print(f"[Replace Tail] Requested {tail_steps} steps, but only {available_steps} available. Adjusting to {available_steps - 1} steps.")
        tail_steps = available_steps - 1  # Ensure we leave at least one sigma

    sigmas = sigmas[:-tail_steps]  # Remove the last N steps
    return apply_decay_tail(sigmas, device, decay_pattern, tail_steps)
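Note: the three decay modes compose with any decay pattern. As a rough sketch of how they differ in effect (toy sigma values chosen for illustration; the import path is the one used by the schedulers above):

    import torch
    from modules.sd_simple_kes_v3.schedulers.shared import (
        apply_decay_tail, blend_decay_tail, replace_tail
    )

    sigmas = torch.tensor([8.0, 4.0, 2.0, 1.0, 0.5])
    appended = apply_decay_tail(sigmas.clone(), 'cpu', 'geometric', tail_steps=3)  # 5 -> 8 values
    blended  = blend_decay_tail(sigmas.clone(), 'cpu', 'geometric', tail_steps=3)  # still 5 values
    replaced = replace_tail(sigmas.clone(), 'cpu', 'geometric', tail_steps=3)      # drop 3, regrow 3

'append' lengthens the schedule, 'blend' softens the existing tail in place, and 'replace' rebuilds the last N steps from the decay pattern, so only 'append' changes the step count.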
sd_simple_kes_v3_fix?/simple_kes_v3.py ADDED
@@ -0,0 +1,2002 @@
from modules.sd_simple_kes_v3.get_sigmas import scheduler_registry
from modules.sd_simple_kes_v3.validate_config import validate_config
from modules.sd_simple_kes_v3.plot_sigma_sequence import plot_sigma_sequence
import torch
import torch.nn.functional as F
import logging
import os
import yaml
import random
from datetime import datetime
import warnings
import math
from typing import Optional
import json
import numpy as np
import hashlib
import glob
import re
import inspect
import copy


def simple_kes_scheduler_v3(n: int, sigma_min: float, sigma_max: float, device: torch.device) -> torch.Tensor:
    scheduler = SimpleKEScheduler(n=n, sigma_min=sigma_min, sigma_max=sigma_max, device=device)
    return scheduler()

class SharedLogger:
    def __init__(self, debug=False):
        self.debug = debug
        self.log_buffer = []
        self.prepass_log_buffer = []

    def log(self, message):
        if self.debug:
            self.log_buffer.append(message)
    def prepass_log(self, message):
        if self.debug:
            self.prepass_log_buffer.append(message)

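Note: simple_kes_scheduler_v3 above is the functional entry point; everything else hangs off the SimpleKEScheduler class it constructs. A minimal invocation sketch (the argument values are placeholders, not recommendations, and the module must be importable from the webui's modules path):

    import torch
    from modules.sd_simple_kes_v3.simple_kes_v3 import simple_kes_scheduler_v3

    sigmas = simple_kes_scheduler_v3(n=30, sigma_min=0.1, sigma_max=14.6,
                                     device=torch.device('cpu'))
    # returns a 1-D tensor of noise levels, validated for NaN/Inf before being handed back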
class SimpleKEScheduler:
    """
    SimpleKEScheduler
    ------------------
    A hybrid scheduler that combines Karras-style sigma sampling
    with exponential decay and blending controls. Supports parameterized
    customization for use in advanced diffusion pipelines.

    Parameters:
    - steps (int): Number of inference steps.
    - device (torch.device): Target device (e.g. 'cuda').
    - config (dict): Scheduler-specific configuration options.

    Usage:
    scheduler = SimpleKEScheduler(steps=30, device='cuda', config=config_dict)
    sigmas = scheduler.get_sigmas()
    """

    def __init__(self, n: int, sigma_min: Optional[float] = None, sigma_max: Optional[float] = None, device: torch.device = "cpu", logger=None, **kwargs):
        self.steps = n if n is not None else 10
        self.original_steps = n
        self.device = torch.device(device if isinstance(device, str) else device)
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.scheduler_registry = scheduler_registry
        self.RANDOMIZATION_TYPE_ALIASES = {
            'symmetric': 'symmetric', 'sym': 'symmetric', 's': 'symmetric',
            'asymmetric': 'asymmetric', 'assym': 'asymmetric', 'a': 'asymmetric',
            'logarithmic': 'logarithmic', 'log': 'logarithmic', 'l': 'logarithmic',
            'exponential': 'exponential', 'exp': 'exponential', 'e': 'exponential'
        }
        self._config_schema = {
            'min_visual_sigma': (int, 10),
            'safety_minimum_stop_step': (int, 10),
            'auto_tail_smoothing': (bool, False),
            'auto_stabilization_sequence': (list, [
                'smooth_interpolation', 'append_tail', 'blend_tail', 'apply_decay', 'progressive_decay'
            ]),
            'sharpen_variance_threshold': (float, 0.01),
            'sharpen_last_n_steps': (int, 10),
            'decay_pattern': (str, 'zero'),
            'sigma_save_subfolder': (str, 'saved_sigmas'),
            'load_sigma_cache': (bool, False),
            'save_sigma_cache': (bool, False),
            'graph_save_directory': (str, 'modules/sd_simple_kes_v3/image_generation_data'),
            'graph_save_enable': (bool, False),
            'exp_power': (int, 2),
            'recent_change_convergence_delta': (float, 0.02),
            'sigma_variance_scale': (float, 0.05),
            'allow_step_expansion': (bool, False),
            'sharpen_mode': (str, 'full'),
            'blend_midpoint': (float, 0.5),
            'early_stopping_method': (str, 'mean'),
            'save_prepass_sigmas': (bool, False),
            'global_randomize': (bool, False),
            'skip_prepass': (bool, False),
            'load_prepass_sigmas': (bool, False)
        }
        self._overrides = kwargs.copy()  # Temporarily hold overrides from kwargs
        default_config_path = os.path.abspath(os.path.normpath(os.path.join("modules", "sd_simple_kes_v3", "kes_config", "default_config.yaml")))
        self.default_config = self._load_config(default_config_path)
        user_config_path = os.path.abspath(os.path.normpath(os.path.join("modules", "sd_simple_kes_v3", "kes_config", "user_config.yaml")))
        self.user_config = self._load_config(user_config_path)
        self.config_data = {**self.default_config, **self.user_config}
        self.config = self.config_data.copy()
        self.settings = self.config.copy()
        for key, value in self.settings.items():
            setattr(self, key, value)
        if self.global_randomize:
            self.apply_global_randomization()
        self.re_randomizable_keys = [
            "sigma_min", "sigma_max", "start_blend", "end_blend", "sharpness",
            "early_stopping_threshold",
            "initial_step_size", "final_step_size",
            "initial_noise_scale", "final_noise_scale",
            "smooth_blend_factor", "step_size_factor", "noise_scale_factor", "rho"
        ]
        for key in self.re_randomizable_keys:
            value = self.settings.get(key)
            if value is None:
                raise KeyError(f"[KEScheduler] Missing required setting: {key}")
            setattr(self, key, value)
        self.debug = self.settings.get('debug', False)
        # Set up logging
        logger = SharedLogger(debug=kwargs.get('debug', False))
        self.logger = logger
        self.log = self.logger.log
        self.prepass_log = self.logger.prepass_log
        self._validate_config_types()
        validate_config(self.config, logger=self.logger)
        # Apply overrides from kwargs if present
        for k, v in self._overrides.items():
            if k in self.settings:
                self.settings[k] = v
                setattr(self, k, v)
        self.auto_mode_enabled = self.settings.get('auto_tail_smoothing', False)
        self.initialize_generation_filename()
        self.relative_converged = False
        self.max_converged = False
        self.delta_converged = False
        self.early_stop_triggered = False
        self.sigma_cache = {}
        self.BASE_DIR = os.path.dirname(os.path.abspath(__file__))
        self.cache_dir = os.path.join(self.BASE_DIR, 'cache')
        self.sigma_save_folder = os.path.join(self.cache_dir, self.sigma_save_subfolder)
        self.blend_method_dict = self.settings.get('blend_methods', {
            'karras': {'weight': 1.0, 'decay_pattern': 'zero', 'decay_mode': 'append', 'tail_steps': 1},
            'exponential': {'weight': 1.0, 'decay_pattern': 'zero', 'decay_mode': 'append', 'tail_steps': 1}
        })
        self.blend_methods = list(self.blend_method_dict.keys())
        self.blend_weights = [self.blend_method_dict[method]['weight'] for method in self.blend_methods]
        self.loaded_sigmas = None
        self.sigma_sequences = {}

        self.schedule_type = None
        self.suffix = None
        self.ext = None
        self._create_directories()
        self._finalize_init()

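Note: configuration resolves in three layers: default_config.yaml is loaded first, user_config.yaml is merged over it, and any keyword arguments passed to the constructor win over both. A sketch of the same precedence in isolation (the keys exist in this codebase, but the values here are purely illustrative):

    default_config = {'rho': 7.0, 'debug': False}
    user_config = {'debug': True}
    overrides = {'rho': 5.0}                       # kwargs passed to SimpleKEScheduler(...)

    settings = {**default_config, **user_config}   # user file wins over defaults
    settings.update(overrides)                     # explicit kwargs win over both
    # settings == {'rho': 5.0, 'debug': True}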
    def _create_directories(self):
        os.makedirs(self.cache_dir, exist_ok=True)
        os.makedirs(self.sigma_save_folder, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.extras_log_filename = os.path.join(
            self.settings.get('log_save_directory', 'modules/sd_simple_kes_v3/image_generation_data'),
            f'all_extras_log_{timestamp}.txt'
        )

    def _finalize_init(self):
        self.prepass_save_file = self.build_sigma_cache_filename(
            steps=self.steps,
            sigma_min=self.sigma_min,
            sigma_max=self.sigma_max,
            rho=self.rho,
            schedule_type='karras',
            decay_pattern=self.decay_pattern,
            cache_dir=self.sigma_save_folder,
            suffix='prepass',
            ext='pt'
        )
        self.final_save_file = self.build_sigma_cache_filename(
            steps=self.steps,
            sigma_min=self.sigma_min,
            sigma_max=self.sigma_max,
            rho=self.rho,
            schedule_type='karras',
            decay_pattern=self.decay_pattern,
            cache_dir=self.sigma_save_folder,
            suffix='final',
            ext='pt'
        )
        self.load_blend_method_sigmas()

    def _load_config(self, config_path, **kwargs):
        self.logger = SharedLogger(debug=kwargs.get('debug', False))

        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                user_config = yaml.safe_load(f)
                return user_config or {}  # Always return a dict, even if empty
        except FileNotFoundError:
            self.logger.log(f"Config file not found: {config_path}. Using empty config.")
            return {}
        except yaml.YAMLError as e:
            self.logger.log(f"Error loading config file: {e}")
            return {}

    def _validate_config_types(self):
        '''
        Corrects self.settings with a validated config, and also writes a
        corrected_user_config file with the correct types.
        '''
        validated_settings = {}
        corrected_lines = []

        corrected_lines.append("# Corrected User Config (Invalid entries auto-corrected)\n")

        for key, (expected_type, default_value) in self._config_schema.items():
            value = self.settings.get(key, default_value)
            if isinstance(value, expected_type):
                validated_settings[key] = value
                corrected_lines.append(f"{key}: {value}")

            else:
                self.log(f"[Config Warning] Invalid type for '{key}': Expected {expected_type.__name__}, got {type(value).__name__}. Using default: {default_value}")
                validated_settings[key] = default_value
                corrected_lines.append(f"{key}: {default_value} # Invalid type: {type(value).__name__}, replaced with default")

        # Save the corrected config with comments
        with open('corrected_user_config.yaml', 'w', encoding='utf-8') as f:
            f.write('\n'.join(corrected_lines))

        self.settings.update(validated_settings)

        for key, value in self.settings.items():
            setattr(self, key, value)

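Note: each schema entry is an (expected_type, default) pair, so a wrong-typed value in user_config.yaml is replaced and the substitution is recorded in corrected_user_config.yaml. Reduced to its core (hypothetical config value, real schema key):

    _config_schema = {'exp_power': (int, 2)}
    settings = {'exp_power': 'two'}                # invalid: str instead of int

    expected_type, default = _config_schema['exp_power']
    if not isinstance(settings['exp_power'], expected_type):
        settings['exp_power'] = default            # falls back to 2, with a logged warning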
    def _log_extras_to_file(self, all_extras):
        """
        Logs the extras returned by each scheduler to a dedicated 'all_extras' log file.

        This method iterates through the list of extras for each blend method and writes them
        to a separate log file for easier tracking, debugging, and future analysis.

        Parameters:
        ----------
        all_extras : list
            A list of extras returned by each scheduler, aligned with the blend_methods list.
            Each item in the list corresponds to the extras provided by a specific scheduler.

        Notes:
        -----
        - If extras are present, they are logged under their respective scheduler names.
        - If extras contain complex objects, the method attempts to serialize them using JSON.
        - Non-serializable extras are logged as raw text.

        Purpose:
        -------
        This log file is intended for developers to track additional outputs that are not
        directly part of the sigma, tails, or decay sequences but may be useful for diagnostics,
        metadata, or advanced scheduler behaviors.
        """
        with open(self.extras_log_filename, 'a', encoding='utf-8') as f:
            f.write("\n=== New Scheduler Extras ===\n")
            for method, extras in zip(self.blend_methods, all_extras):
                if extras:
                    try:
                        f.write(f"\nScheduler: {method}\n")
                        f.write(json.dumps(extras, indent=2))
                        f.write("\n")
                    except TypeError:
                        f.write(f"\nScheduler: {method}\n")
                        f.write(f"Extras (non-serializable): {extras}\n")
            f.write("\n============================\n")
    def __call__(self):
        # First pass: run the prepass to determine predicted_stop_step
        if not self.skip_prepass:
            self.prepass_compute_sigmas(
                steps=self.steps,
                sigma_min=self.sigma_min,
                sigma_max=self.sigma_max,
                rho=self.rho,
                device=self.device,
                skip_prepass=self.skip_prepass
            )

            if self.load_prepass_sigmas:
                self.generate_sigmas_schedule(mode='prepass')

            if self.load_sigma_cache:
                self.generate_sigmas_schedule(mode='final')

        else:
            # Build the sigma sequence directly (without prepass)
            self.config_values()
            self.generate_sigmas_schedule()

        if self.blending_mode == 'default':
            self.blend_sigma_sequence(
                sigmas_karras=self.scheduler_registry.get('karras')(
                    steps=self.steps,
                    sigma_min=self.sigma_min,
                    sigma_max=self.sigma_max,
                    device=self.device,
                    decay_pattern=self.decay_pattern
                )[2],
                sigmas_exponential=self.scheduler_registry.get('exponential')(
                    steps=self.steps,
                    sigma_min=self.sigma_min,
                    sigma_max=self.sigma_max,
                    device=self.device,
                    decay_pattern=self.decay_pattern
                )[2],
                pre_pass=False,
                blend_methods=self.blend_methods,
                blend_weights=self.blend_weights
            )

        else:
            # For multi-method blending
            self.blend_sigma_sequence(
                sigmas_karras=None,  # Not used in non-default mode
                sigmas_exponential=None,  # Not used in non-default mode
                pre_pass=False,
                blend_methods=self.blend_methods,
                blend_weights=self.blend_weights
            )

        sigmas = self.compute_sigmas(steps=self.steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max, rho=self.rho, device=self.device)

        # Safety checks
        if torch.isnan(sigmas).any():
            raise ValueError("[SimpleKEScheduler] NaN detected in sigmas")
        if torch.isinf(sigmas).any():
            raise ValueError("[SimpleKEScheduler] Inf detected in sigmas")
        if (sigmas <= 0).all():
            raise ValueError("[SimpleKEScheduler] All sigma values are <= 0")
        if (sigmas > 1000).all():
            raise ValueError("[SimpleKEScheduler] Sigma values are extremely large and might destabilize the model")

        # Save logs to file
        if self.debug:
            self.save_generation_settings()

        return sigmas

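Note: the safety block at the end of __call__ rejects schedules wholesale rather than trying to repair them. The same guards in isolation (standalone sketch, mirroring the scheduler's post-conditions):

    import torch

    def check_sigmas(sigmas: torch.Tensor) -> torch.Tensor:
        # mirror of the scheduler's post-conditions on the final schedule
        if torch.isnan(sigmas).any() or torch.isinf(sigmas).any():
            raise ValueError("non-finite sigma detected")
        if (sigmas <= 0).all():
            raise ValueError("all sigma values are <= 0")
        return sigmas

    check_sigmas(torch.tensor([10.0, 1.0, 0.1]))  # passes through unchanged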
    def _safe_sigma_loader(self, cache_key):
        cache_folder = self.sigma_save_folder

        # Check if the folder exists and has files
        if not os.path.exists(cache_folder) or not os.listdir(cache_folder):
            self.log(f"[Cache Check] Cache folder {cache_folder} is empty or missing. Skipping load.")
            return None  # Signal to recompute

        # Check if a matching file exists
        matching_files = [f for f in os.listdir(cache_folder) if cache_key in f and f.endswith('.pt')]

        if not matching_files:
            self.log(f"[Cache Check] No matching cache file found for key: {cache_key}. Skipping load.")
            return None  # Signal to recompute

        # If a matching file is found, load it
        filename = os.path.join(cache_folder, matching_files[0])
        self.log(f"[Cache Hit] Loading sigma cache from: {filename}")
        loaded_data = torch.load(filename, map_location=self.device)
        return loaded_data['sigma_values'].to(self.device)

    def call_scheduler(self, method_name, *args, **kwargs):
        sigma_sequence = getattr(self, f"sigmas_{method_name}")
        if sigma_sequence is None:
            self.log(f"No sigma sequence found for method: {method_name}")
            return None
        return sigma_sequence

    def is_sigma_randomized(self):
        return (
            self.settings.get('sigma_min_rand', False) or
            self.settings.get('sigma_max_rand', False) or
            self.settings.get('rho_rand', False) or
            self.settings.get('sigma_max_enable_randomization_type', False) or
            self.settings.get('sigma_min_enable_randomization_type', False) or
            self.settings.get('rho_enable_randomization_type', False)
        )

    def save_sigmas_as_csv(self, sigmas, filename):
        np.savetxt(filename, sigmas.cpu().numpy(), delimiter=",")

    def build_sigma_cache_filename(self, steps, sigma_min, sigma_max, rho=None, schedule_type='karras', decay_pattern='zero', cache_dir=r'modules\sd_simple_kes_v3\cache', suffix=None, ext='txt'):
        if cache_dir is None:
            cache_dir = r'modules\sd_simple_kes_v3\cache'
        if schedule_type == 'karras':
            base_filename = f'sigma_{schedule_type}_{steps}steps_rho{rho}_min{sigma_min}_max{sigma_max}_{decay_pattern}'
        else:
            base_filename = f'sigma_{schedule_type}_{steps}steps_min{sigma_min}_max{sigma_max}_{decay_pattern}'

        # If a suffix is provided, versioning applies
        if suffix:
            base_filename += f'_{suffix}'
            version = self.get_next_version_number(cache_dir, base_filename)
            if ext:
                version = self.get_next_version_number(cache_dir, base_filename, ext)
            filename = f'{version:03d}_{base_filename}.{ext}'
        else:
            # No versioning if a suffix is not provided
            filename = f'{base_filename}.{ext}'

        return os.path.join(cache_dir, filename)

    def get_next_version_number(self, cache_dir, base_filename, ext=None):
        pattern = os.path.join(cache_dir, f'*_{base_filename}')
        if ext:
            pattern = os.path.join(cache_dir, f'*_{base_filename}.{ext}')
        existing_files = glob.glob(pattern)

        version_numbers = []
        for file in existing_files:
            match = re.search(r'(\d{3})_' + re.escape(base_filename), os.path.basename(file))
            if match:
                version_numbers.append(int(match.group(1)))

        if version_numbers:
            return max(version_numbers) + 1
        else:
            return 1

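Note: with a suffix, cache files are version-prefixed so repeated runs never overwrite each other. For example (illustrative parameter values), build_sigma_cache_filename(steps=30, sigma_min=0.1, sigma_max=10.0, rho=7.0, schedule_type='karras', decay_pattern='zero', suffix='prepass', ext='pt') would yield a path ending in:

    001_sigma_karras_30steps_rho7.0_min0.1_max10.0_zero_prepass.pt

and the next run with the same parameters bumps the prefix to 002_.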
    def get_sigma_with_cache(self, steps, sigma_min, sigma_max, rho=7.0, device='cpu',
                             schedule_type='karras', decay_pattern=None, cache_dir=None, cache_file=None,
                             suffix=None, ext=None, mode=None, cache_key=None):
        self.steps = steps
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho
        self.device = device
        self.schedule_type = schedule_type
        self.decay_pattern = decay_pattern
        self.cache_dir = cache_dir
        self.cache_file = cache_file
        self.suffix = suffix
        self.ext = ext
        self.mode = mode
        self.cache_key = cache_key

        # Try to retrieve from the in-memory cache first
        cached_sigmas = self.get_sigma_from_cache(cache_key)

        if cached_sigmas is not None:
            return cached_sigmas

        # If sigma is randomized -> always generate new
        if self.is_sigma_randomized():
            _, _, _, sigmas = self._generate_sigmas(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern)
            self.sigma_cache[cache_key] = sigmas
            return sigmas

        # If nothing is loaded yet -> generate and cache
        if self.loaded_sigmas is None:
            _, _, _, sigmas = self._generate_sigmas(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern)
            self.loaded_sigmas = sigmas
            self.sigma_cache[cache_key] = sigmas
            return sigmas

        # Handle file cache modes
        if mode == 'prepass':
            self.cache_file = self.prepass_save_file
        elif mode == 'final':
            self.cache_file = self.final_save_file
        else:
            self.cache_file = self.build_sigma_cache_filename(steps, sigma_min, sigma_max, rho=rho, schedule_type=schedule_type, decay_pattern=decay_pattern, cache_dir=cache_dir)

        # Load from the file cache if enabled
        if mode in ['prepass', 'final'] and self.load_prepass_sigmas:
            loaded_sigmas = self.load_sigmas_with_hash_validation(
                filename=self.cache_file,
                steps=steps,
                sigma_min=sigma_min,
                sigma_max=sigma_max,
                rho=rho,
                device=device,
                schedule_type=schedule_type,
                decay_pattern=decay_pattern,
                cache_key=cache_key
            )

            if loaded_sigmas is not None:
                self.loaded_sigmas = loaded_sigmas
                self.sigma_cache[cache_key] = loaded_sigmas
                return loaded_sigmas.to(device)
            else:
                self.log("[Cache Recovery] Cache load failed. Recalculating sigma schedule.")

        # Cache miss -> recalculate (restored from the commented-out block so `sigmas` is always defined)
        self.log(f"[Cache Miss] Recalculating sigma schedule for: {self.cache_file}")
        _, _, _, sigmas = self._generate_sigmas(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern)
        self.sigma_cache[cache_key] = sigmas
        return sigmas

    def load_sigmas_with_hash_validation(self, filename, steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, save_data=None, cache_key=None, suffix=None):
        if self.load_prepass_sigmas:
            if cache_key:
                try:
                    loaded_data = torch.load(filename, map_location=self.device)
                    self.loaded_sigmas = loaded_data['sigma_values'].to(self.device)
                    loaded_hash = loaded_data['sigma_hash']

                    expected_hash = self.generate_sigma_hash(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, save_data, suffix)

                    if loaded_hash != expected_hash:
                        self.log(f"[Sigma Validator] Hash mismatch. Expected: {expected_hash}, Found: {loaded_hash}. Recalculating.")
                        return None  # Return None to signal the scheduler to recalculate
                    else:
                        self.log(f"[Sigma Validator] Hash validated successfully for file: {filename}")
                        return self.loaded_sigmas

                except Exception as e:
                    self.log(f"[Cache Recovery] Sigma cache invalid or missing ({e}). Recalculating sigmas.")
                    _, _, _, sigmas = self._generate_sigmas(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern)
                    return sigmas

    def generate_sigma_hash(self, steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, save_data=None, suffix=None):
        data_string = f'{steps}_{sigma_min}_{sigma_max}_{rho}_{device}_{schedule_type}_{decay_pattern}_{suffix}'
        hash_object = hashlib.sha256(data_string.encode())
        return hash_object.hexdigest()[:12]  # Use the first 12 characters for a compact ID

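Note: the hash key is derived purely from the generation parameters, so any change to steps, sigma range, rho, device, schedule type, decay pattern, or suffix invalidates the cached file. A standalone equivalent of the same derivation:

    import hashlib

    def sigma_hash(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, suffix=None):
        data_string = f'{steps}_{sigma_min}_{sigma_max}_{rho}_{device}_{schedule_type}_{decay_pattern}_{suffix}'
        return hashlib.sha256(data_string.encode()).hexdigest()[:12]

    sigma_hash(30, 0.1, 10.0, 7.0, 'cpu', 'karras', 'zero')  # 12-character compact ID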
    def _generate_sigmas(self, steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern=None, decay_mode=None, tail_steps=None):
        scheduler_func = self.scheduler_registry.get(schedule_type)

        if scheduler_func is None:
            raise ValueError(f"Unknown schedule type: {schedule_type}")

        tails, decay, extras, sigmas = self.call_scheduler_function(
            scheduler_func,
            steps=steps,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            rho=rho,
            device=device,
            decay_pattern=decay_pattern,
            decay_mode=decay_mode,
            tail_steps=tail_steps
        )

        return tails, decay, extras, sigmas

    def initialize_generation_filename(self, folder=None, base_name="generation_log", ext="txt"):
        """
        Initialize the log filename early so it can be used throughout the process.
        """
        if folder is None:
            folder = self.settings.get('log_save_directory', 'modules/sd_simple_kes_v3/image_generation_data')
        folder = os.path.abspath(os.path.normpath(folder))

        os.makedirs(folder, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        self.log_filename = os.path.join(folder, f"{base_name}_{timestamp}.{ext}")

    def save_generation_settings(self):
        """
        Save the generation log (debug and prepass buffers) to the filename prepared by
        initialize_generation_filename(), then clear both buffers.
        """
        with open(self.log_filename, "w", encoding='utf-8') as f:
            for line in self.logger.log_buffer:
                f.write(f"{line}\n")
            for line in self.logger.prepass_log_buffer:
                f.write(f"{line}\n")
        self.log(f"[SimpleKEScheduler] Generation settings saved to {self.log_filename}")

        self.logger.log_buffer.clear()
        self.logger.prepass_log_buffer.clear()

    def save_image_plot(self, sigs, i):
        graph_plot = plot_sigma_sequence(
            sigs[:i + 1],  # plot the sequence up to and including step i
            i,
            self.log_filename,
            self.graph_save_directory,
            self.graph_save_enable
        )
        self.log(f"Sigma sequence plot saved to {graph_plot}")

    def apply_global_randomization(self):
        """Force randomization for all eligible settings by enabling _rand flags and re-randomizing values."""
        # First pass: turn on the _rand flag wherever a corresponding _rand_min/_rand_max exists
        for key in list(self.settings.keys()):
            if key.endswith("_rand_min") or key.endswith("_rand_max"):
                base_key = key.rsplit("_rand_", 1)[0]
                rand_flag_key = f"{base_key}_rand"
                self.settings[rand_flag_key] = True
            # Step 2: if global_randomize is active, re-randomize all eligible keys
            if self.global_randomize:
                if key not in self.settings:
                    raise KeyError(f"[apply_global_randomization] Missing required key: {key}")

                default_val = self.settings[key]
                randomized_val = self.get_random_or_default(key, default_val)
                self.settings[key] = randomized_val
                setattr(self, key, randomized_val)

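Note: global randomization is driven by key naming: any setting that ships a *_rand_min/*_rand_max pair gets its *_rand flag forced on. A reduced sketch of that discovery step (hypothetical config values, same naming convention as the YAML files):

    settings = {'sigma_min': 0.1, 'sigma_min_rand_min': 0.05, 'sigma_min_rand_max': 0.2}

    for key in list(settings.keys()):
        if key.endswith('_rand_min') or key.endswith('_rand_max'):
            base_key = key.rsplit('_rand_', 1)[0]
            settings[f'{base_key}_rand'] = True   # sets sigma_min_rand: True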
    def get_randomization_type(self, key_prefix):
        """
        Retrieves the randomization type for a given key, with fallback to 'asymmetric' if missing.
        """
        randomization_type_raw = self.settings.get(f'{key_prefix}_randomization_type', 'asymmetric')
        randomization_type = self.RANDOMIZATION_TYPE_ALIASES.get(randomization_type_raw.lower(), 'asymmetric')
        return randomization_type

    def get_randomization_percent(self, key_prefix):
        """
        Retrieves the randomization percent for a given key, with fallback to 0.2 if missing.
        """
        return self.settings.get(f'{key_prefix}_randomization_percent', 0.2)

    def get_random_between_min_max(self, key_prefix, default_value):
        """
        Picks a random value between _rand_min and _rand_max if _rand is True.
        Otherwise, returns the base value.
        """
        randomize_flag = self.settings.get(f'{key_prefix}_rand', False)

        if randomize_flag:
            rand_min = self.settings.get(f'{key_prefix}_rand_min', default_value)
            rand_max = self.settings.get(f'{key_prefix}_rand_max', default_value)

            if rand_min == rand_max:
                self.log(f"[Random Range] {key_prefix}: min and max are equal ({rand_min}). Using single value.")
                return rand_min

            value = random.uniform(rand_min, rand_max)
            self.log(f"[Random Range] {key_prefix}: Picked random value {value} between {rand_min} and {rand_max}")
            return value
        else:
            self.log(f"[Random Range] {key_prefix}: Randomization is OFF. Using base value {default_value}")
            return default_value

    def get_random_by_type(self, key_prefix, default_value):
        randomization_enabled = self.settings.get(f'{key_prefix}_enable_randomization_type', False)

        if not randomization_enabled:
            self.log(f"[Randomization Type] {key_prefix}: Randomization type is OFF. Using base value {default_value}")
            return default_value

        randomization_type = self.get_randomization_type(key_prefix)
        randomization_percent = self.get_randomization_percent(key_prefix)

        if randomization_type == 'symmetric':
            rand_min = default_value * (1 - randomization_percent)
            rand_max = default_value * (1 + randomization_percent)
            self.log(f"[Symmetric Randomization] {key_prefix}: Range {rand_min} to {rand_max}")

        elif randomization_type == 'asymmetric':
            rand_min = default_value * (1 - randomization_percent)
            rand_max = default_value * (1 + (randomization_percent * 2))
            self.log(f"[Asymmetric Randomization] {key_prefix}: Range {rand_min} to {rand_max}")

        elif randomization_type == 'logarithmic':
            rand_min = math.log(default_value * (1 - randomization_percent))
            rand_max = math.log(default_value * (1 + randomization_percent))
            value = math.exp(random.uniform(rand_min, rand_max))
            self.log(f"[Logarithmic Randomization] {key_prefix}: Log-space randomization resulted in {value}")
            return value

        elif randomization_type == 'exponential':
            rand_min = default_value * (1 - randomization_percent)
            rand_max = default_value * (1 + randomization_percent)
            base_value = random.uniform(rand_min, rand_max)
            value = math.exp(base_value)
            self.log(f"[Exponential Randomization] {key_prefix}: Randomized exponential value {value}")
            return value

        else:
            self.log(f"[Randomization Type] {key_prefix}: Invalid randomization type {randomization_type}. Using base value.")
            return default_value

        value = random.uniform(rand_min, rand_max)

        self.log(f"[Randomization Type] {key_prefix}: Randomized value {value}")
        return value

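Note: with the default randomization percent of 0.2, 'symmetric' draws uniformly from [0.8x, 1.2x] of the base value, while 'asymmetric' skews upward to [0.8x, 1.4x] because the upper bound doubles the percent. For a base value of 10.0:

    # symmetric:  rand_min = 10.0 * (1 - 0.2) = 8.0,  rand_max = 10.0 * (1 + 0.2)     = 12.0
    # asymmetric: rand_min = 10.0 * (1 - 0.2) = 8.0,  rand_max = 10.0 * (1 + 0.2 * 2) = 14.0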
    def get_random_or_default(self, key_prefix, default_value):
        """
        Selects the randomization method based on active flags:
        - If both are enabled -> prioritize randomization type (or min/max if you prefer).
        - If only one is enabled -> apply that one.
        - If neither -> return the default value.
        """
        rand_type_enabled = self.settings.get(f'{key_prefix}_enable_randomization_type', False)
        min_max_enabled = self.settings.get(f'{key_prefix}_rand', False)

        if rand_type_enabled and min_max_enabled:
            self.log(f"[Randomization Policy] Both min/max and randomization type enabled for {key_prefix}. System will prioritize randomization type.")
            result_value = self.get_random_by_type(key_prefix, default_value)

        elif rand_type_enabled:
            result_value = self.get_random_by_type(key_prefix, default_value)
            self.log(f"[Randomization] {key_prefix}: Applied randomization type. Final value: {result_value}")

        elif min_max_enabled:
            result_value = self.get_random_between_min_max(key_prefix, default_value)
            self.log(f"[Randomization] {key_prefix}: Applied min/max randomization. Final value: {result_value}")

        else:
            result_value = default_value
            self.log(f"[Randomization] {key_prefix}: No randomization applied. Using default value: {result_value}")

        return result_value

    def resolve_blend_weights(self, blend_weights, blending_style):
        if blending_style == 'softmax':
            # Softmax automatically normalizes weights per step
            blend_weights = torch.tensor(blend_weights)
            normalized_weights = torch.softmax(blend_weights, dim=0)
            return normalized_weights.tolist()

        elif blending_style == 'explicit':
            # Return raw weights; they are normalized manually in the blending step
            return blend_weights

        else:
            raise ValueError(f"Unknown blending_style: {blending_style}")
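    # Editor's sketch of what the 'softmax' branch does to a raw weight list
    # (hypothetical helper): equal weights map to equal shares, and larger
    # weights dominate smoothly because torch.softmax exponentiates first.
    @staticmethod
    def _demo_softmax_weights():
        raw = torch.tensor([1.0, 1.0, 2.0])
        normalized = torch.softmax(raw, dim=0)  # ~[0.2119, 0.2119, 0.5761], sums to 1
        return normalized.tolist()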
    def extract_scalar(self, value):
        if isinstance(value, torch.Tensor):
            if value.numel() > 1:
                return value.mean().item()  # or first element
            else:
                return value.item()
        return value  # Already a float
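    # Editor's usage note for extract_scalar: it accepts plain floats, 0-dim
    # tensors, and multi-element tensors (reduced by mean), e.g.
    #   extract_scalar(0.5)                      -> 0.5
    #   extract_scalar(torch.tensor(0.5))        -> 0.5
    #   extract_scalar(torch.tensor([1.0, 3.0])) -> 2.0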
    def _call_legacy_mode(self, schedule_type):
        # Validate schedule_type
        if schedule_type not in ['karras', 'exponential']:
            self.log(f"[Legacy Mode] Unsupported schedule_type: {schedule_type}")
            return

        # Dynamically set the target attribute
        target_attr = f"sigmas_{schedule_type}"

        scheduler_func = self.scheduler_registry.get(schedule_type)

        tails, decay, extras, sigmas = self.call_scheduler_function(
            scheduler_func,
            steps=self.steps,
            sigma_min=self.sigma_min,
            sigma_max=self.sigma_max,
            rho=self.rho,
            device=self.device,
            decay_pattern=self.decay_pattern
        )

        # Assign to self.sigmas_karras or self.sigmas_exponential dynamically
        setattr(self, target_attr, sigmas)

        self.log(f"[Legacy Mode] Loaded sigma sequence for {schedule_type}. Assigned to self.{target_attr}")
    def blend_sigma_sequence(self, sigmas_karras=None, sigmas_exponential=None, pre_pass=False, blend_methods=None, blend_weights=None):
        """
        Computes the blended sigma sequence using adaptive step sizes, dynamic blend factors,
        and noise scaling across the progress of the diffusion process.

        This method blends sigma values from the active schedules (classically Karras and
        Exponential) using a smooth, progress-dependent interpolation. It applies adaptive
        scaling based on step size and noise scale factors to each sigma in the sequence.

        Parameters:
        -----------
        sigmas_karras : torch.Tensor, optional
            The sigma sequence generated using the Karras schedule.

        sigmas_exponential : torch.Tensor, optional
            The sigma sequence generated using the Exponential schedule.

        pre_pass : bool
            When True, runs the early-stopping prepass instead of the final pass.

        blend_methods : list, optional
            Names of the active scheduler methods.

        blend_weights : list, optional
            Per-method blend weights.

        Returns:
        --------
        self.sigs : torch.Tensor
            The final blended and scaled sigma sequence. `self.sigs` must be pre-allocated
            to match the shape of the sigma schedules and is modified in place.

        Notes:
        ------
        - This method is used in both the prepass and final pass of the scheduler.
        - The progress tensor is computed linearly from 0 to 1 over the length of the sequence.
        - The method uses class attributes for step size factors, blend factors, and noise scaling.
        """
        # Filter out schedulers with zero weight
        active_methods = [
            method for method, config in self.blend_method_dict.items() if config.get('weight', 1.0) > 0.0
        ]
        # Fallback if all weights are zero
        if not active_methods:
            self.log("[Blend Config] All weights are zero. Falling back to default blend (karras + exponential).")
            '''
            # Values set in __init__, placed here for reference:
            self.blend_method_dict = {
                'karras': {'weight': 1.0, 'decay_pattern': 'zero', 'decay_mode': 'append', 'tail_steps': 1},
                'exponential': {'weight': 1.0, 'decay_pattern': 'zero', 'decay_mode': 'append', 'tail_steps': 1}
            }
            '''
            active_methods = list(self.blend_method_dict.keys())

        self.blend_methods = active_methods

        # Rebuild the per-method weight list
        self.blend_weights = [self.blend_method_dict[m]['weight'] for m in self.blend_methods]
        # Default for the post-loop tail/decay blending; the 'weights' branch in the
        # loop below overwrites this with style-resolved weights when it runs
        # (assumed fix: previously this name only existed inside that branch).
        resolved_blend_weights = self.blend_weights

        # Edge case: no active schedulers
        if len(self.blend_methods) == 0:
            raise ValueError("[SimpleKEScheduler] No active schedulers selected. Please check your blend configuration.")
            # Should never happen, because we fall back to the init defaults when all weights are zero.

        # Edge case: only one scheduler (direct usage)
        if len(self.blend_methods) == 1:
            self.log(f"[Blend] Only one active scheduler: {self.blend_methods[0]}. Skipping blending, using it directly.")
            self.sigs = self.sigma_sequences[self.blend_methods[0]]['sigmas']
            # Do not exit early here; the rest of the function still has to run.

        if len(self.blend_methods) == 2:
            self.blending_mode = 'smooth_blend'
        elif len(self.blend_methods) > 2:
            self.blending_mode = 'weights'

        if not self.allow_step_expansion and self.auto_mode_enabled:
            self.auto_mode_enabled = False
            self.log("[Auto Mode] Step expansion disallowed. Auto mode forcibly disabled.")

        self.progress = torch.linspace(0, 1, len(self.sigs)).to(self.device)
        self.blended_sigmas = []
        self.change_log = []
        self.relative_converged = False
        self.max_converged = False
        self.delta_converged = False
        self.early_stop_triggered = False
        if self.sigmas_exponential is None:
            self._call_legacy_mode(schedule_type='exponential')

        if self.sigmas_karras is None:
            self._call_legacy_mode(schedule_type='karras')

        self.prepass_blended_sigmas = []
        self.blended_sigma = None
        self.blended_sigmas = []
        for i in range(len(self.sigs)):
            if self.step_progress_mode == "linear":
                progress_value = self.progress[i]
            elif self.step_progress_mode == "exponential":
                progress_value = self.progress[i] ** self.exp_power
            elif self.step_progress_mode == "logarithmic":
                progress_value = torch.log1p(self.progress[i] * (torch.exp(torch.tensor(1.0)) - 1))
            elif self.step_progress_mode == "sigmoid":
                progress_value = 1 / (1 + torch.exp(-12 * (self.progress[i] - 0.5)))
            else:
                progress_value = self.progress[i]  # Fallback to linear (matches the previous version)

            self.dynamic_blend_factor = self.start_blend * (1 - self.progress[i]) + self.end_blend * self.progress[i]
            self.smooth_blend = torch.sigmoid((self.dynamic_blend_factor - self.blend_midpoint) * self.smooth_blend_factor)
            self.noise_scale = self.initial_noise_scale * (1 - self.progress[i]) + self.final_noise_scale * self.progress[i] * self.noise_scale_factor
            self.step_size = self.initial_step_size * (1 - progress_value) + self.final_step_size * progress_value * self.step_size_factor
            if self.blending_mode == 'default':
                # Classic default: Karras + Exponential only
                self.blended_sigma = self.sigmas_karras[i] * (1 - self.smooth_blend) + self.sigmas_exponential[i] * self.smooth_blend

            if self.blending_mode == 'smooth_blend' or (self.blending_mode == 'auto' and len(self.blend_methods) == 2):
                # Smooth blend between exactly two methods
                sigma_seq_a = self.sigma_sequences[self.blend_methods[0]]['sigmas']
                sigma_seq_b = self.sigma_sequences[self.blend_methods[1]]['sigmas']

                self.blended_sigma = sigma_seq_a[i] * (1 - self.smooth_blend) + sigma_seq_b[i] * self.smooth_blend

            elif self.blending_mode == 'weights' or (self.blending_mode == 'auto' and len(self.blend_methods) > 2):

                if self.blend_weights is None:
                    self.blend_weights = [1.0] * len(self.all_sigmas)
                if self.blending_style is None:
                    self.blending_style = 'softmax'  # 'soft_max' is not a style resolve_blend_weights accepts

                # Resolve weights based on blending style
                resolved_blend_weights = self.resolve_blend_weights(self.blend_weights, self.blending_style)

                weighted_sum = sum(w * self.extract_scalar(s[i]) for w, s in zip(resolved_blend_weights, self.all_sigmas))

                total_weight = sum(resolved_blend_weights)
                self.blended_sigma = weighted_sum / total_weight

                for s in self.all_sigmas:
                    self.log(f"[DEBUG] sigma sequence shape: {s.shape}")

            self.sigs[i] = self.blended_sigma * self.step_size * self.noise_scale
            # Safely extract a scalar for both tensor and float. The i == 0 guard
            # keeps the first step from wrapping around and comparing against sigs[-1].
            self.change = torch.abs(self.sigs[i] - self.sigs[i - 1]) if i > 0 else torch.zeros_like(self.sigs[i])
            self.change_log.append(self.extract_scalar(self.change))
            relative_sigma_progress = (self.blended_sigma - self.sigs[-1].item()) / self.blended_sigma
            recent_changes = torch.abs(torch.tensor(self.change_log[-5:]))
            max_change = torch.max(recent_changes).item()
            mean_change = torch.mean(recent_changes).item()
            #percent_of_threshold = (max_change / self.early_stopping_threshold) * 100
            self.delta_change = abs(max_change - mean_change)
            #self.blended_sigmas.append(self.blended_sigma.item())
            self.blended_sigmas.append(self.extract_scalar(self.blended_sigma))

            # Check 1: Relative sigma progress
            self.relative_converged = relative_sigma_progress < 0.05
            # Check 2: Max recent sigma change
            self.max_converged = max_change < self.early_stopping_threshold
            # Check 3: Max-mean difference converged
            self.delta_converged = self.delta_change < self.recent_change_convergence_delta
            if pre_pass:
                self.prepass_blended_sigmas = self.blended_sigmas.copy()
                self.prepass_blended_sigma = self.blended_sigma
                if i >= 2:
                    sigma_rate = abs(self.prepass_blended_sigmas[i] - self.prepass_blended_sigmas[i - 1])
                    previous_sigma_rate = abs(self.prepass_blended_sigmas[i - 1] - self.prepass_blended_sigmas[i - 2])
                    if sigma_rate > previous_sigma_rate:
                        self.prepass_log(f"Sigma decline is slowing down → possible plateau at step {i+1}.")

                if i == 0:
                    self.prepass_log("\n--- Starting Pre-Pass Blending ---\n")
                    step_label = "Prepass First Step"
                elif i == len(self.sigs) - 1:
                    step_label = "Prepass Last Step"
                else:
                    step_label = None

                if step_label:
                    self.prepass_log(f"[{step_label} - Step {i}/{len(self.sigs)}] Prepass Blended Sigma: {self.prepass_blended_sigma:.6f}, Final Sigma: {self.sigs[i]:.6f}")
                    self.prepass_log(f"{step_label} Delta Converged: {self.delta_converged} delta_change: {self.delta_change:.6f}, Target Default Settings: {self.recent_change_convergence_delta}")

                # Start checking for early stopping after the minimum number of steps
                if i > self.safety_minimum_stop_step and len(self.change_log) > 10:
                    # Calculate variance and dynamic threshold
                    self.blended_tensor = torch.tensor(self.prepass_blended_sigmas)
                    if self.device == 'cpu':
                        self.sigma_variance = np.var(self.prepass_blended_sigmas)
                    else:
                        self.sigma_variance = torch.var(self.sigs).item()

                    self.min_sigma_threshold = self.sigma_variance * self.sigma_variance_scale  # scale factor can be tuned
                    self.prepass_log(f"\n--- Early Stopping Evaluation at Step {i} ---")
                    self.prepass_log(f"Current Blended Prepass Sigma: {self.prepass_blended_sigma:.6f}")
                    self.prepass_log(f"Sigma Variance: {self.sigma_variance:.6f}")
                    self.prepass_log(f"Relative Sigma Progress: {relative_sigma_progress:.6f}")
                    self.prepass_log(f"Max Recent Sigma Change: {max_change:.6f}")
                    self.prepass_log(f"Mean Recent Sigma Change: {mean_change:.6f}")

                    # Reason for continuing (sigma still too high)
                    if self.prepass_blended_sigma > self.min_sigma_threshold:
                        self.prepass_log(f"Prepass Blended Sigma {self.prepass_blended_sigma:.6f} exceeds min sigma threshold {self.min_sigma_threshold:.6f} → Continuing.\n")

                    # Start early stopping checks
                    if self.early_stopping_method == "mean":
                        mean_change = sum(self.change_log) / len(self.change_log)
                        if mean_change < self.early_stopping_threshold:
                            skipped_steps = len(self.sigs) - i
                            self.prepass_log(f"Early stopping triggered by mean at step {i}. Mean change: {mean_change:.6f}. Steps used: {i}/{len(self.sigs)}, steps skipped: {skipped_steps}")

                    elif self.early_stopping_method == "max":
                        #max_change = max(self.change_log)
                        if max_change < self.early_stopping_threshold:
                            skipped_steps = len(self.sigs) - i
                            self.prepass_log(f"Early stopping triggered by max at step {i}. Max change: {max_change:.6f}. Steps used: {i}/{len(self.sigs)}, steps skipped: {skipped_steps}")

                    elif self.early_stopping_method == "sum":
                        stable_steps = sum(
                            1 for j in range(1, len(self.change_log))
                            if abs(self.change_log[j]) < self.early_stopping_threshold * abs(self.sigs[j])
                        )
                        if stable_steps >= 0.8 * len(self.change_log):
                            skipped_steps = len(self.sigs) - i
                            self.prepass_log(f"Early stopping triggered by sum at step {i}. Stable steps: {stable_steps}/{len(self.change_log)}. Steps used: {i}/{len(self.sigs)}, steps skipped: {skipped_steps}")

                    if self.relative_converged and self.max_converged and self.delta_converged:
                        self.early_stop_triggered = True
                        self.prepass_log(f"\n--- Early Stopping Evaluation at Step {i+1} ---")
                        self.prepass_log(f"Relative Sigma Progress: {relative_sigma_progress:.6f}")
                        self.prepass_log(f"Max Recent Sigma Change: {max_change:.6f}")
                        self.prepass_log(f"Mean Recent Sigma Change: {mean_change:.6f}")
                        self.prepass_log(f"Delta Change: {self.delta_change:.6f} (Target: {self.recent_change_convergence_delta})")
                        self.prepass_log(f"Early stopping criteria met at step {i+1} based on all convergence checks.")
                        self.predicted_stop_step = i
                        #self.steps = self.predicted_stop_step
                        self.save_image_plot(self.sigs, i)
                        break
            # === Final Pass ===
            if not pre_pass:

                if i == 0:
                    step_label = "First Step"
                    self.log("\n" + "=" * 10 + "\n[Start of Sigma Sequence Logging]\n" + "=" * 10)
                elif i == len(self.sigs) // 2:
                    step_label = "Middle Step"
                elif i == len(self.sigs) - 1:
                    step_label = "Last Step"
                else:
                    step_label = None

                if step_label:
                    self.log(f"[{step_label} - Step {i}/{len(self.sigs)}]"
                             f"\nStep Size: {self.step_size:.6f}"
                             f"\nDynamic Blend Factor: {self.dynamic_blend_factor:.6f}"
                             f"\nNoise Scale: {self.noise_scale:.6f}"
                             f"\nSmooth Blend: {self.smooth_blend:.6f}"
                             f"\nBlended Sigma: {self.blended_sigma:.6f}"
                             f"\nFinal Sigma: {self.sigs[i]:.6f}")
                if i == len(self.sigs) - 1:
                    self.log("\n" + "=" * 10 + "\n[End of Sigma Sequence Logging]\n" + "=" * 10)

                if i > 0:
                    self.change = torch.abs(self.sigs[i] - self.sigs[i - 1])
                    #self.change_log.append(self.change.item())
                    self.change_log.append(self.extract_scalar(self.change))

                # Early Stopping Evaluation
                if i > self.safety_minimum_stop_step and len(self.change_log) > 5:
                    final_target_sigma = self.sigs[-1].item()  # or use min(self.sigmas) if preferred
                    if self.blended_sigma != 0:
                        relative_sigma_progress = (self.blended_sigma - final_target_sigma) / self.blended_sigma
                    else:
                        relative_sigma_progress = 0  # Assume fully converged if blended_sigma is 0
                    # Optional: show the variance, but there is no need to stop on it
                    self.sigma_variance = torch.var(self.sigs).item() if self.device != 'cpu' else np.var(self.blended_sigmas)
                    self.log(f"Sigma Variance: {self.sigma_variance:.6f}")
                    if self.graph_save_enable:
                        self.save_image_plot(self.sigs, i)
        # Apply tails and decay after the loop finishes
        # Finished core sigma blending
        if not self.auto_mode_enabled:
            if not pre_pass:  # Only extend in the final pass
                if self.apply_tail_steps:
                    for i, tail in enumerate(self.all_tails):
                        if tail is not None:
                            self.log(f"Appending tail from method: {self.blend_methods[i]}")
                            self.sigs = torch.cat([self.sigs, tail])

                if self.apply_decay_tail:
                    for i, decay in enumerate(self.all_decays):
                        if decay is not None:
                            self.log(f"Appending decay from method: {self.blend_methods[i]}")
                            self.sigs = torch.cat([self.sigs, decay])

                if self.apply_progressive_decay:
                    progressive_decay = None
                    total_weight = 0

                    for w, decay in zip(resolved_blend_weights, self.all_decays):
                        if decay is not None:
                            decay = decay[:len(self.sigs)]  # Ensure matching length
                            if progressive_decay is None:
                                progressive_decay = w * decay
                            else:
                                progressive_decay += w * decay
                            total_weight += w

                    if progressive_decay is not None and total_weight > 0:
                        progressive_decay /= total_weight
                        self.log("Applying progressive decay to sigma sequence.")
                        self.sigs = self.sigs * progressive_decay

                if self.apply_blended_tail:
                    blended_tail = None
                    total_weight = 0

                    for w, tail in zip(resolved_blend_weights, self.all_tails):
                        if tail is not None:
                            if blended_tail is None:
                                blended_tail = w * tail
                            else:
                                blended_tail += w * tail
                            total_weight += w

                    if blended_tail is not None and total_weight > 0:
                        blended_tail /= total_weight
                        self.log("Appending blended tail to sigma sequence.")
                        self.sigs = torch.cat([self.sigs, blended_tail])

        else:
            # Run the Auto Mode stabilization sequence
            if len(self.sigs) > self.steps:
                self.auto_stabilization_sequence = []
                self.log(f"[Auto Mode] Sigma sequence length {len(self.sigs)} exceeds requested steps {self.steps}. Disabling auto stabilization.")
                self.auto_mode_enabled = False
                self.sigs = self.sigs[:self.steps]  # Force truncation to the requested step count
                return self.sigs
            self.run_auto_stabilization()  # takes no argument; it operates on self.sigs directly

        if pre_pass and self.early_stop_triggered:
            return self.sigs[:self.predicted_stop_step]  # Return only the usable sequence
        else:
            return self.sigs
    def run_auto_stabilization(self):
        # This function works as intended, but it is blocked by default when host
        # programs do not let schedulers create a sigma schedule longer than the
        # requested step count.
        if not self.allow_step_expansion:
            self.log("[Auto Mode] Step expansion is disabled by configuration. Skipping auto stabilization.")
            return self.sigs

        unstable = self.detect_sequence_instability()

        if not unstable:
            self.log("[Auto Mode] Sigma sequence is already stable.")
            return

        self.log("[Auto Mode] Detected instability in sigma sequence. Starting stabilization sequence.")

        for method in self.auto_stabilization_sequence:
            if not unstable:
                self.log(f"[Auto Mode] Sequence stabilized after {method}. Stopping further corrections.")
                break

            if method == 'smooth_interpolation':
                unstable = self.smooth_interpolation()

            elif method == 'append_tail':
                unstable = self.append_tail()

            elif method == 'blend_tail':
                unstable = self.blend_tail()

            elif method == 'apply_decay':
                unstable = self.apply_decay()

            elif method == 'progressive_decay':
                unstable = self.progressive_decay()

            else:
                self.log(f"[Auto Mode] Unknown stabilization method: {method}")
    def detect_sequence_instability(self):
        delta_sigmas = self.sigs[:-1] - self.sigs[1:]
        second_deltas = torch.diff(delta_sigmas)

        steep_drop_detected = torch.any(delta_sigmas > self.auto_tail_threshold)
        jaggedness_score = torch.var(second_deltas[-5:]) if len(second_deltas) >= 5 else 0
        jagged_transition_detected = jaggedness_score > self.jaggedness_threshold

        if steep_drop_detected:
            self.log(f"[Auto Mode] Steep drop detected. Max drop: {torch.max(delta_sigmas).item():.6f}")
        if jagged_transition_detected:
            self.log(f"[Auto Mode] Jagged transition detected. Jaggedness score: {jaggedness_score:.6f}")

        return steep_drop_detected or jagged_transition_detected
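    # Editor's sketch of the instability test on a toy sequence (hypothetical
    # helper with assumed threshold values): a large first-step drop trips the
    # steep-drop check, and the uneven drop profile trips the jaggedness check.
    @staticmethod
    def _demo_instability_check(auto_tail_threshold=0.5, jaggedness_threshold=0.01):
        sigs = torch.tensor([10.0, 6.0, 5.5, 5.4, 5.39])
        delta_sigmas = sigs[:-1] - sigs[1:]       # [4.0, 0.5, 0.1, 0.01]
        second_deltas = torch.diff(delta_sigmas)  # [-3.5, -0.4, -0.09]
        steep = torch.any(delta_sigmas > auto_tail_threshold)
        jagged = torch.var(second_deltas) > jaggedness_threshold
        return bool(steep), bool(jagged)          # (True, True) for this toy input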
    def smooth_interpolation(self):
        self.log("[Auto Mode] Applying smooth interpolation to last 5 steps.")
        if len(self.sigs) >= 6:  # at least 6 values are needed so sigs[-6] exists
            start = self.sigs[-6].item()
            end = self.sigs[-1].item()
            interpolated = torch.linspace(start, end, steps=6, device=self.device)[1:]
            self.sigs[-5:] = interpolated

        return self.detect_sequence_instability()
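    # Editor's sketch of the tail interpolation above (hypothetical helper):
    # the last five sigmas are replaced with a straight ramp between sigs[-6]
    # and sigs[-1], smoothing out any jagged ending.
    @staticmethod
    def _demo_tail_interpolation():
        sigs = torch.tensor([3.0, 2.0, 1.9, 0.4, 1.1, 0.2, 0.1])
        interpolated = torch.linspace(sigs[-6].item(), sigs[-1].item(), steps=6)[1:]
        sigs[-5:] = interpolated  # jagged tail becomes an even ramp from 2.0 down to 0.1
        return sigs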
    def append_tail(self):
        self.log("[Auto Mode] Attempting to append available tail.")
        if hasattr(self, 'all_tails') and self.all_tails:
            for tail in self.all_tails:
                if tail is not None:
                    tail = tail.to(self.device)
                    # If the tail is longer than the remaining sequence, trim it
                    if tail.shape[0] > self.sigs.shape[0]:
                        tail = tail[:len(self.sigs)]
                    self.sigs = torch.cat([self.sigs, tail])
                    self.log("[Auto Mode] Appended tail to sigma sequence.")
                    break

        return self.detect_sequence_instability()
    def blend_tail(self):
        if not hasattr(self, 'all_tails') or not self.all_tails:
            self.log("[Auto Mode] No available tails to blend.")
            return self.detect_sequence_instability()

        self.log("[Auto Mode] Attempting to blend multiple tails.")
        blended_tail = None
        total_weight = 0

        for w, tail in zip(self.blend_weights, self.all_tails):
            if tail is not None:
                tail = tail.to(self.device)

                # Align length if needed
                if tail.shape[0] > self.sigs.shape[0]:
                    tail = tail[:len(self.sigs)]

                if blended_tail is None:
                    blended_tail = w * tail
                else:
                    blended_tail += w * tail
                total_weight += w

        if blended_tail is not None and total_weight > 0:
            blended_tail /= total_weight
            self.sigs = torch.cat([self.sigs, blended_tail])
            self.log("[Auto Mode] Appended blended tail to sigma sequence.")

        return self.detect_sequence_instability()
    def apply_decay(self):
        self.log("[Auto Mode] Attempting to append decay tails.")
        if hasattr(self, 'all_decays') and self.all_decays:
            for decay in self.all_decays:
                if decay is not None:
                    decay = decay.to(self.device)

                    # Align length if needed
                    if decay.shape[0] > self.sigs.shape[0]:
                        decay = decay[:len(self.sigs)]

                    self.sigs = torch.cat([self.sigs, decay])
                    self.log("[Auto Mode] Appended decay tail to sigma sequence.")
                    break

        return self.detect_sequence_instability()
    def progressive_decay(self):
        self.log("[Auto Mode] Applying progressive decay to sigma sequence.")
        progressive_decay = None
        total_weight = 0

        for w, decay in zip(self.blend_weights, self.all_decays):
            if decay is not None:
                decay = decay.to(self.device)

                # If the decay length does not match, interpolate to the length of
                # self.sigs (relies on the module-level torch.nn.functional alias F)
                if decay.shape[0] != self.sigs.shape[0]:
                    decay = decay.view(1, 1, -1)  # Shape for interpolation
                    decay = F.interpolate(decay, size=self.sigs.shape[0], mode='linear', align_corners=False)
                    decay = decay.view(-1)

                if progressive_decay is None:
                    progressive_decay = w * decay
                else:
                    progressive_decay += w * decay

                total_weight += w

        if progressive_decay is not None and total_weight > 0:
            progressive_decay /= total_weight
            self.sigs = self.sigs * progressive_decay
            self.log("[Auto Mode] Applied progressive decay to sigma sequence.")

        return self.detect_sequence_instability()
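    # Editor's sketch of the length alignment used above (hypothetical helper;
    # self-contained, so the functional import is done locally rather than
    # assuming the module-level alias):
    @staticmethod
    def _demo_decay_resample(target_len=8):
        import torch.nn.functional as F
        decay = torch.linspace(1.0, 0.0, steps=4)  # too short for the sigma sequence
        decay = decay.view(1, 1, -1)               # (batch, channel, length) for interpolation
        decay = F.interpolate(decay, size=target_len, mode='linear', align_corners=False)
        return decay.view(-1)                      # length 8, same descending ramp shape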
    def load_blend_method_sigmas(self, mode=None):
        """Loads all sigma sequences for the blend_methods list based on current settings and mode."""
        self.all_sigmas = []

        for method in self.blend_methods:
            self.method_config = self.blend_method_dict[method]
            self.current_config = {
                'decay_pattern': self.method_config.get('decay_pattern', 'zero'),
                'decay_mode': self.method_config.get('decay_mode', 'blend'),
                'tail_steps': self.method_config.get('tail_steps', 1)
            }

            sigma_func = self.scheduler_registry[method]
            tails, decay, extras, sigmas = self.call_scheduler_function(
                sigma_func,
                steps=self.steps,
                sigma_min=self.sigma_min,
                sigma_max=self.sigma_max,
                rho=self.rho,  # Only passed if the scheduler accepts it
                device=self.device,
                decay_pattern=self.current_config['decay_pattern'],  # Method-specific
                decay_mode=self.current_config['decay_mode'],  # Method-specific
                tail_steps=self.current_config['tail_steps']  # Method-specific
            )
            self.sigma_sequences[method] = {
                'sigmas': sigmas,
                'tails': tails,
                'decay': decay,
                'extras': extras
            }
            setattr(self, f"sigmas_{method}", sigmas)

        self.all_sigmas = [self.sigma_sequences[method]['sigmas'] for method in self.blend_methods]
        self.all_tails = [self.sigma_sequences[method]['tails'] for method in self.blend_methods]
        self.all_decays = [self.sigma_sequences[method]['decay'] for method in self.blend_methods]
        self.all_extras = [self.sigma_sequences[method].get('extras', []) for method in self.blend_methods]
        self._log_extras_to_file(self.all_extras)
        # Note: do not append `sigmas` again here; all_sigmas is already fully
        # populated by the comprehension above, and a stray duplicate would
        # desynchronize it from blend_methods and blend_weights.

        # Optionally log which schedules were loaded
        self.log(f"Loaded sigma schedules for blend methods: {self.blend_methods} using mode: {mode}")
    def validate_and_align_sigmas(self):
        """
        Ensures all sigma sequences in self.all_sigmas are valid and have the same length.
        Pads shorter sequences with their last sigma.
        """
        if not self.all_sigmas or len(self.all_sigmas) == 0:
            raise ValueError("No sigma sequences were loaded for blending.")

        target_length = max(len(s) for s in self.all_sigmas)

        for idx, sigmas in enumerate(self.all_sigmas):
            if sigmas is None or len(sigmas) == 0:
                raise ValueError(f"Sigma sequence at index {idx} is invalid or empty: {sigmas}")

            if len(sigmas) < target_length:
                padding = torch.full((target_length - len(sigmas),), sigmas[-1]).to(sigmas.device)
                self.all_sigmas[idx] = torch.cat([sigmas, padding])

        self.log(f"Validated and aligned all sigma sequences to length {target_length}.")
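    # Editor's sketch of the padding rule above (hypothetical helper): shorter
    # sequences are extended by repeating their final sigma, so the endpoint
    # of each schedule is preserved.
    @staticmethod
    def _demo_pad_to_length(target_length=5):
        sigmas = torch.tensor([3.0, 1.0, 0.5])
        padding = torch.full((target_length - len(sigmas),), sigmas[-1].item())
        return torch.cat([sigmas, padding])  # tensor([3.0, 1.0, 0.5, 0.5, 0.5])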
    def generate_sigmas_schedule(self, mode=None):
        """
        Generates the sigma schedules required for the hybrid blending process.

        The Karras and Exponential sigma sequences are created to provide two distinct
        noise scaling strategies:
        - The Karras sequence offers a more aggressive noise decay, commonly used in
          modern schedulers for improved image quality and denoising stability.
        - The Exponential sequence provides a traditional log-space noise schedule.

        These two sequences are dynamically blended in later steps using progress-dependent
        weights to produce a custom sigma path that combines the advantages of both approaches.

        This blending process is critical to the scheduler's ability to:
        - Adapt noise scaling across steps.
        - Control the sharpness and smoothness of transitions.
        - Support early stopping based on sigma convergence patterns.

        These sigma sequences must be regenerated in both the prepass (for early stopping detection)
        and the final pass (for polished sigma application), ensuring both passes are synchronized
        with the current step count and randomization settings.
        """
        #self.cache_key = self.generate_sigma_hash(self.steps, self.sigma_min, self.sigma_max, self.rho, self.device, self.schedule_type, self.decay_pattern, self.suffix or None)
        # ✅ Clean mode selection
        if mode == 'prepass':
            if self.load_prepass_sigmas:
                self.cache_file = self.prepass_save_file
            self.mode = 'prepass'

        elif mode == 'final':
            if self.load_sigma_cache:
                self.cache_file = self.final_save_file
            self.mode = 'final'

        else:
            self.mode = None
            self.cache_file = None  # Optional, for safety

        '''
        if self.cache_file:
            sigmas = self.get_sigma_with_cache(
                steps=self.steps,
                sigma_min=self.sigma_min,
                sigma_max=self.sigma_max,
                rho=self.rho,
                device=self.device,
                decay_pattern=self.decay_pattern,
                cache_file=self.cache_file,
                mode=self.mode
                #cache_key=self.cache_key
            )
            return sigmas
        '''

        #else:
        #    logic for multiple schedulers
        #    self.load_blend_method_sigmas(mode=self.mode)
        self.load_blend_method_sigmas(mode=self.mode)
        self.blend_pairs = []
        self.active_methods = [method for method in self.blend_methods if self.blend_method_dict[method].get('weight', 1.0) > 0]

        if self.blending_mode == 'default':
            self._call_legacy_mode(schedule_type='exponential')
            self._call_legacy_mode(schedule_type='karras')

            self.blend_pairs = []
            self.blend_pairs.append({
                'method_label': 'method_a',
                'method': 'karras',
                'sigmas': self.sigmas_karras
            })
            self.blend_pairs.append({
                'method_label': 'method_b',
                'method': 'exponential',
                'sigmas': self.sigmas_exponential
            })

            # ✅ Optional: pad if needed (if you know some might be misaligned)
            max_length = max(len(pair['sigmas']) for pair in self.blend_pairs)

            for pair in self.blend_pairs:
                if len(pair['sigmas']) < max_length:
                    padding = torch.full((max_length - len(pair['sigmas']),), pair['sigmas'][-1]).to(pair['sigmas'].device)
                    pair['sigmas'] = torch.cat([pair['sigmas'], padding])

            self.log(f"All sigma sequences aligned to length: {max_length}")

            # ✅ For legacy compatibility
            sigmas_a = self.blend_pairs[0]['sigmas']
            sigmas_b = self.blend_pairs[1]['sigmas']
            label_a = self.blend_pairs[0]['method']
            label_b = self.blend_pairs[1]['method']
            if sigmas_a is None:
                raise ValueError(f"Sigmas {label_a} failed to generate or assign correctly.")
            if sigmas_b is None:
                raise ValueError(f"Sigmas {label_b} failed to generate or assign correctly.")
        else:
            if len(self.active_methods) == 1:
                # Only one method; assign it directly
                self.blend_pairs = []
                method = self.active_methods[0]
                self.blend_pairs.append({
                    'method_label': 'method_a',
                    'method': method,
                    'sigmas': self.sigma_sequences[method]['sigmas']
                })

            elif len(self.active_methods) >= 2:
                # Build blend_pairs dynamically for all active methods
                self.blend_pairs = []
                for idx, method in enumerate(self.active_methods):
                    self.blend_pairs.append({
                        'method_label': f'method_{chr(97 + idx)}',  # method_a, method_b, etc.
                        'method': method,
                        'sigmas': self.sigma_sequences[method]['sigmas']
                    })

            # Validation checks (loop version)
            for pair in self.blend_pairs:
                if pair['sigmas'] is None:
                    raise ValueError(f"Sigmas {pair['method']} failed to generate or assign correctly.")

            # Skip length matching if only one method is enabled
            if len(self.blend_pairs) > 1:
                target_length = min(len(pair['sigmas']) for pair in self.blend_pairs)

                # Trim all to the target length
                for pair in self.blend_pairs:
                    pair['sigmas'] = pair['sigmas'][:target_length]

                # Find the max length (in case any sequence is longer after trimming)
                max_length = max(len(pair['sigmas']) for pair in self.blend_pairs)

                for pair in self.blend_pairs:
                    if len(pair['sigmas']) < max_length:
                        padding = torch.full((max_length - len(pair['sigmas']),), pair['sigmas'][-1]).to(pair['sigmas'].device)
                        pair['sigmas'] = torch.cat([pair['sigmas'], padding])

                self.log(f"All sigma sequences aligned to length: {max_length}")

                self.sigs = torch.zeros(target_length, device=self.blend_pairs[0]['sigmas'].device)
            else:
                # Only one method; set self.sigs directly
                self.sigs = self.blend_pairs[0]['sigmas'].clone()

        '''
        # Now it's safe to compute sigs
        start = math.log(self.sigma_max)
        end = math.log(self.sigma_min)
        #self.sigs = torch.linspace(start, end, self.steps, device=self.device).exp()
        if self.sigs is None or self.force_rebuild_sigs:
            self.sigs = torch.linspace(start, end, self.steps, device=self.device).exp()

        # Ensure sigs contain valid values before using them
        if torch.any(self.sigs > 0):
            self.sigma_min, self.sigma_max = self.sigs[self.sigs > 0].min(), self.sigs.max()
        else:
            # If sigs are all invalid, set a safe fallback
            self.sigma_min, self.sigma_max = self.min_threshold, self.min_threshold
            self.log(f"Debugging Warning: No positive sigma values found! Setting fallback sigma_min={self.sigma_min}, sigma_max={self.sigma_max}")

        return {
            'karras': self.sigmas_karras,
            'exponential': self.sigmas_exponential,
            'blend_methods': self.blend_methods,
            'all_sigmas': self.all_sigmas,
            'sigs': self.sigs
        }

        #sigma_lengths = [len(self.sigma_sequences[method]['sigmas']) for method in self.blend_methods]
        #if len(set(sigma_lengths)) > 1:  # There are mismatched lengths
        #    self.validate_and_align_sigmas()
        #self.sigs = torch.zeros(self.steps, device=self.device)
        sigma_lengths = [len(self.sigma_sequences[method]['sigmas']) for method in self.blend_methods]
        if len(set(sigma_lengths)) > 1:
            self.log("[Sigma Alignment] Detected mismatched sigma sequence lengths. Aligning...")
            self.validate_and_align_sigmas()

        return {
            'blend_methods': self.blend_methods,
            'all_sigmas': self.all_sigmas,
            'sigs': self.sigs
        }
        '''
        # Now it's safe to compute sigs
        #start = math.log(self.sigma_max)
        #end = math.log(self.sigma_min)
        #self.sigs = torch.linspace(start, end, self.steps, device=self.device).exp()

        '''
        if torch.any(self.sigs > 0):
            self.sigma_min, self.sigma_max = self.sigs[self.sigs > 0].min(), self.sigs.max()
        else:
            # If sigs are all invalid, set a safe fallback
            self.sigma_min = self.min_threshold
            self.sigma_max = self.min_threshold
            self.log(f"Debugging Warning: No positive sigma values found! Setting fallback sigma_min={self.sigma_min}, sigma_max={self.sigma_max}")

        return {
            'blend_methods': self.blend_methods,
            'all_sigmas': self.all_sigmas,
            'sigs': self.sigs
        }

        return {
            'blend_methods': self.blend_methods,
            'all_sigmas': self.all_sigmas,
            'sigs': self.sigs
        }
        '''

        if not torch.any(self.sigs > 0):
            self.sigma_min = self.min_threshold
            self.sigma_max = self.min_threshold
            self.log(f"Debugging Warning: No positive sigma values found! Setting fallback sigma_min={self.sigma_min}, sigma_max={self.sigma_max}")
        else:
            self.sigma_min = self.sigs[self.sigs > 0].min()
            self.sigma_max = self.sigs.max()

        return {
            'blend_methods': self.blend_methods,
            'all_sigmas': self.all_sigmas,
            'sigs': self.sigs
        }
    def call_scheduler_function(self, scheduler_func, **kwargs):
        """
        Safely calls a scheduler function with dynamic argument filtering and flexible return handling.

        This method ensures that only the parameters accepted by the scheduler function are passed.
        It automatically handles scheduler functions that may return:
        - Only the sigma sequence
        - A tuple with (tails, sigmas)
        - A tuple with (tails, decay, sigmas)
        - A tuple with additional items (extras) before sigmas

        The method always assumes the last returned item is the sigma sequence, with optional
        items preceding it.

        Parameters:
        ----------
        scheduler_func : callable
            The scheduler function to be invoked. It may accept various arguments such as steps,
            sigma_min, sigma_max, rho, device, decay_pattern, etc.
        **kwargs : dict
            Arbitrary keyword arguments. Only those accepted by the scheduler function will be passed.

        Returns:
        -------
        tuple
            A 4-tuple containing:
            - tails : Any (optional, can be None)
                The tail component of the sigma schedule, if provided.
            - decay : Any (optional, can be None)
                The decay component of the sigma schedule, if provided.
            - extras : list
                Any additional return values provided by the scheduler function, beyond tails and decay.
            - sigmas : Any
                The sigma sequence, always assumed to be the last item returned by the scheduler function.

        Raises:
        ------
        ValueError
            If the scheduler function returns an empty tuple.

        Notes:
        -----
        This method allows future schedulers to return additional optional data without breaking the calling pattern.
        """
        valid_params = inspect.signature(scheduler_func).parameters
        filtered_args = {k: v for k, v in kwargs.items() if k in valid_params}

        result = scheduler_func(**filtered_args)

        if isinstance(result, dict):
            tails = result.get('tails', None)
            decay = result.get('decay', None)
            sigmas = result.get('sigmas')
            extras = result.get('extras', [])

            if sigmas is None:
                raise ValueError("Scheduler function must return a 'sigmas' key.")

            return tails, decay, extras, sigmas

        # Legacy support: fall back to tuple unpacking if needed
        if not isinstance(result, tuple):
            return None, None, [], result

        if len(result) == 0:
            raise ValueError("Scheduler function returned an empty tuple. This is not allowed.")

        sigmas = result[-1]
        optional_returns = result[:-1]

        tails = optional_returns[0] if len(optional_returns) > 0 else None
        decay = optional_returns[1] if len(optional_returns) > 1 else None
        extras = optional_returns[2:] if len(optional_returns) > 2 else []

        return tails, decay, extras, sigmas
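    # Editor's sketch of a scheduler that satisfies the dict contract described
    # above (hypothetical function): only the keyword arguments it declares
    # survive the inspect-based filtering, so extras such as rho or
    # decay_pattern are silently dropped for schedulers that do not accept them.
    @staticmethod
    def _demo_dict_scheduler(steps=4, sigma_min=0.1, sigma_max=10.0, device='cpu'):
        sigmas = torch.linspace(sigma_max, sigma_min, steps, device=device)
        return {'sigmas': sigmas, 'tails': None, 'decay': None, 'extras': []}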
    def config_values(self):
        # Ensures sigma_min is always less than sigma_max for edge cases
        if self.sigma_min >= self.sigma_max:
            correction_factor = random.uniform(0.01, 0.99)
            old_sigma_min = self.sigma_min
            self.sigma_min = self.sigma_max * correction_factor
            self.log(f"[Correction] sigma_min ({old_sigma_min}) was >= sigma_max ({self.sigma_max}). Adjusted sigma_min to {self.sigma_min} using correction factor {correction_factor}.")

        self.log(f"Final sigmas: sigma_min={self.sigma_min}, sigma_max={self.sigma_max}")

        if self.sigma_auto_enabled:
            if self.sigma_auto_mode not in ["sigma_min", "sigma_max"]:
                raise ValueError(f"[Config Error] Invalid sigma_auto_mode: {self.sigma_auto_mode}. Must be 'sigma_min' or 'sigma_max'.")

            if self.sigma_auto_mode == "sigma_min":
                self.sigma_min = self.sigma_max / self.sigma_scale_factor
                self.log(f"[Auto Sigma Min] sigma_min set to {self.sigma_min} using scale factor {self.sigma_scale_factor}")

            elif self.sigma_auto_mode == "sigma_max":
                self.sigma_max = self.sigma_min * self.sigma_scale_factor
                self.log(f"[Auto Sigma Max] sigma_max set to {self.sigma_max} using scale factor {self.sigma_scale_factor}")

        # Always apply min_threshold AFTER auto scaling
        self.min_threshold = random.uniform(1e-5, 5e-5)

        if self.sigma_min < self.min_threshold:
            self.log(f"[Threshold Enforcement] sigma_min was too low: {self.sigma_min} < min_threshold {self.min_threshold}")
            self.sigma_min = self.min_threshold

        if self.sigma_max < self.min_threshold:
            self.log(f"[Threshold Enforcement] sigma_max was too low: {self.sigma_max} < min_threshold {self.min_threshold}")
            self.sigma_max = self.min_threshold

        valid_methods = ['mean', 'max', 'sum']
        if self.early_stopping_method not in valid_methods:
            self.log(f"[Config Correction] Invalid early_stopping_method: {self.early_stopping_method}. Defaulting to 'mean'.")
            self.early_stopping_method = 'mean'
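    # Editor's worked example for the auto-scaling above: with
    # sigma_scale_factor = 100,
    #   sigma_auto_mode == 'sigma_min'  ->  sigma_min = sigma_max / 100  (e.g. 14.6 -> 0.146)
    #   sigma_auto_mode == 'sigma_max'  ->  sigma_max = sigma_min * 100  (e.g. 0.03 -> 3.0)
    # after which the randomized min_threshold clamp keeps both values above
    # roughly 1e-5 .. 5e-5.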
    def prepass_compute_sigmas(self, steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, suffix=None, cache_key=None, skip_prepass=False) -> torch.Tensor:

        '''
        if self.load_prepass_sigmas:
            if cache_key:
                self._safe_sigma_loader(cache_key)
                #self.load_or_regenerate_sigmas(cache_key)
            loaded_data = torch.load(self.cache_file.replace('.txt', '.pt'), map_location=self.device)

            sigmas = loaded_data['sigma_values'].to(self.device)
            self.loaded_sigmas = sigmas  # No need to call torch.tensor again

            loaded_hash = loaded_data['sigma_hash']

            steps = loaded_data['steps']
            sigma_min = loaded_data['sigma_min']
            sigma_max = loaded_data['sigma_max']
            rho = loaded_data['rho']
            device = loaded_data['device']
            schedule_type = loaded_data['schedule_type']
            decay_pattern = loaded_data['decay_pattern']

            restored_config = loaded_data['full_config']

            # Optionally overwrite current settings with restored settings
            self.settings.update(restored_config)

            #self.load_sigmas_with_hash_validation(self, loaded_data, steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, cache_key, suffix=None)
            self.log(f"[Cache Loaded] Sigma schedule, hash, and config loaded from: {self.cache_file.replace('.pt', '.txt')}")
        '''
        acceptable_keys = [
            "sigma_min", "sigma_max", "start_blend", "end_blend", "sharpness",
            "early_stopping_threshold", "initial_step_size",
            "final_step_size", "initial_noise_scale", "final_noise_scale",
            "smooth_blend_factor", "step_size_factor", "noise_scale_factor", "rho"
        ]

        for key in acceptable_keys:
            default_val = self.settings[key]
            value = self.get_random_or_default(key, default_val)
            setattr(self, key, value)

        if self.steps is None:
            raise ValueError("Number of steps must be provided.")
        if isinstance(self.device, str):
            self.device = torch.device(self.device)
        self.config_values()
        self.generate_sigmas_schedule(mode='prepass')

        # Default the stop step to the requested steps, falling back to the original
        # step count when steps is unset (the earlier `self.steps if None else ...`
        # unconditionally picked self.original_steps).
        self.predicted_stop_step = self.steps if self.steps is not None else self.original_steps
        if self.sharpen_last_n_steps > len(self.sigs):
            self.log(f"[Sharpening Notice] Requested last {self.sharpen_last_n_steps} steps exceeds sequence length. Using entire sequence instead.")
            self.sharpen_last_n_steps = len(self.sigs)

        self.visual_sigma = max(0.8, self.sigma_min * self.min_visual_sigma)

        self.blend_sigma_sequence(
            sigmas_karras=None,
            sigmas_exponential=None,
            pre_pass=True,
            blend_methods=self.blend_methods,
            blend_weights=self.blend_weights
        )
        if torch.isnan(self.sigs).any() or torch.isinf(self.sigs).any():
            raise ValueError("Invalid sigma values detected (NaN or Inf).")
        final_steps = self.sigs[:self.predicted_stop_step].to(self.device)
        # Store the results for later use in compute_sigmas
        self.final_steps = final_steps
        if self.blending_mode == 'default':
            self.final_sigmas_karras = self.sigmas_karras
            self.final_sigmas_exponential = self.sigmas_exponential
            self.log(f"Final Steps = {self.final_steps}. Predicted_stop_step = {self.predicted_stop_step}. Original requested steps = {self.steps}")
            self.log(f"final sigmas karras: {self.final_sigmas_karras}")
        else:
            # For multi-method blending
            self.final_sigmas_blended = torch.tensor(self.blended_sigmas, device=self.device)

            self.log(f"Final Steps = {self.final_steps}. Predicted_stop_step = {self.predicted_stop_step}. Original requested steps = {self.steps}")
            self.log(f"final blended sigmas: {self.final_sigmas_blended}")

            # Optionally log the contributing sigma sequences for debugging
            for idx, (method, sigmas) in enumerate(zip(self.blend_methods, self.all_sigmas)):
                self.log(f"Method: {method}, Sigma sequence: {sigmas}")

        '''
        # Build cache key
        sigma_hash = self.generate_sigma_hash(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, suffix=None)

        if self.save_prepass_sigmas:
            save_data = {
                'sigma_values': sigmas.cpu(),  # Keep as tensor
                'sigma_hash': sigma_hash,
                'steps': steps,
                'sigma_min': sigma_min,
                'sigma_max': sigma_max,
                'rho': rho,
                'device': device,
                'schedule_type': schedule_type,
                'decay_pattern': decay_pattern,
                'full_config': self.settings  # Save as raw dict
            }

            # Save directly with torch.save in .pt format
            torch.save(save_data, self.cache_file)  # Assuming self.cache_file has .pt extension
            self.log(f"[Sigma Saver] Final sigmas saved to: {self.cache_file}")
        '''
    def load_or_regenerate_sigmas(self, cache_key):
        sigmas = None
        if self.load_sigma_cache and cache_key:
            try:
                loaded_data = torch.load(self.cache_file, map_location=self.device)
                sigmas = loaded_data['sigma_values'].to(self.device)

            except FileNotFoundError:
                self.log(f"[Cache Warning] Cache file not found: {self.cache_file}")
                self.log("[Cache Recovery] Automatically recomputing sigma schedule.")

        # Recompute only when the cache is disabled or loading failed, so a
        # successful cache hit is not unconditionally overwritten.
        if sigmas is None:
            _, _, _, sigmas = self._generate_sigmas(
                self.steps,
                self.sigma_min,
                self.sigma_max,
                self.rho,
                self.device,
                self.schedule_type,
                self.decay_pattern
            )
        '''
        if self.save_prepass_sigmas:
            # Optional: cache the recomputed sigma schedule
            torch.save(save_data, self.prepass_save_file)
            self.log(f"[Sigma Saver] Final sigmas saved to: {self.prepass_save_file}")

        #self.sigma_cache[cache_key] = sigmas
        '''
        return sigmas
| 1853 |
+
def compute_sigmas(self, steps, sigma_min, sigma_max, rho, device, schedule_type=None, decay_pattern=None, cache_key=None)->torch.Tensor:
|
| 1854 |
+
"""
|
| 1855 |
+
Scheduler function that blends sigma sequences using Karras and Exponential methods with adaptive parameters.
|
| 1856 |
+
|
| 1857 |
+
Parameters:
|
| 1858 |
+
n (int): Number of steps.
|
| 1859 |
+
sigma_min (float): Minimum sigma value.
|
| 1860 |
+
sigma_max (float): Maximum sigma value.
|
| 1861 |
+
device (torch.device): The device on which to perform computations (e.g., 'cuda' or 'cpu').
|
| 1862 |
+
start_blend (float): Initial blend factor for dynamic blending.
|
| 1863 |
+
end_bend (float): Final blend factor for dynamic blending.
|
| 1864 |
+
sharpen_factor (float): Sharpening factor to be applied adaptively.
|
| 1865 |
+
early_stopping_threshold (float): Threshold to trigger early stopping.
|
| 1866 |
+
initial_step_size (float): Initial step size for adaptive step size calculation.
|
| 1867 |
+
final_step_size (float): Final step size for adaptive step size calculation.
|
| 1868 |
+
initial_noise_scale (float): Initial noise scale factor.
|
| 1869 |
+
final_noise_scale (float): Final noise scale factor.
|
| 1870 |
+
step_size_factor: Adjust to compensate for oversmoothing
|
| 1871 |
+
noise_scale_factor: Adjust to provide more variation
|
| 1872 |
+
|
| 1873 |
+
Returns:
|
| 1874 |
+
torch.Tensor: A tensor of blended sigma values.
|
| 1875 |
+
"""
|
| 1876 |
+
        '''
        if self.load_sigma_cache and cache_key:
            sigmas = self._safe_sigma_loader(cache_key)
            if sigmas is None:
                self.log(f"[Cache Recovery] No valid cache found for key: {cache_key}. Recomputing sigma schedule.")
                _, _, _, sigmas = self._generate_sigmas(
                    self.steps,
                    self.sigma_min,
                    self.sigma_max,
                    self.rho,
                    self.device,
                    self.schedule_type,
                    self.decay_pattern
                )
            else:
                self.log(f"[Cache Hit] Sigma schedule successfully loaded from cache.")
                self.sigs = sigmas
        '''
        acceptable_keys = [
            "sigma_min", "sigma_max", "start_blend", "end_blend", "sharpness",
            "early_stopping_threshold", "initial_step_size",
            "final_step_size", "initial_noise_scale", "final_noise_scale",
            "smooth_blend_factor", "step_size_factor", "noise_scale_factor", "rho"
        ]

        if self.skip_prepass:
            # Resolve each setting to its configured default or a randomized value.
            for key in acceptable_keys:
                default_val = self.settings[key]
                value = self.get_random_or_default(key, default_val)
                setattr(self, key, value)

            self.log(f"Using device: {self.device}")
            self.config_values()
            self.generate_sigmas_schedule(mode='final')
            if hasattr(self, 'final_sigmas_karras'):
                self.sigs = torch.zeros_like(self.final_sigmas_karras).to(self.device)
            else:
                self.sigs = torch.zeros_like(self.sigmas_karras).to(self.device)

            self.blend_sigma_sequence(
                sigmas_karras=self.final_sigmas_karras if hasattr(self, 'final_sigmas_karras') else self.sigmas_karras,
                sigmas_exponential=self.final_sigmas_exponential if hasattr(self, 'final_sigmas_exponential') else self.sigmas_exponential,
                pre_pass=False,
                blend_methods=self.blend_methods,
                blend_weights=self.blend_weights
            )
            self.sigma_variance = torch.var(self.sigs).item()
            if self.sharpen_mode in ['last_n', 'both']:
                if self.sigma_variance < self.sharpen_variance_threshold:
                    # Apply full sharpening (low variance)
                    self.sharpen_mask = torch.where(self.sigs < self.sigma_min * 1.5, self.sharpness, 1.0).to(self.device)
                    sharpen_indices = torch.where(self.sharpen_mask < 1.0)[0].tolist()
                    self.sigs = self.sigs * self.sharpen_mask
                    self.log(f"[Sharpen Mask] Full sharpening applied (low variance). Steps: {sharpen_indices}")
                else:
                    # Apply sharpening only to the last N steps
                    recent_sigs = self.sigs[-self.sharpen_last_n_steps:]
                    sharpen_mask = torch.where(recent_sigs < self.sigma_min * 1.5, self.sharpness, 1.0).to(self.device)
                    sharpen_indices = torch.where(sharpen_mask < 1.0)[0].tolist()
                    self.sigs[-self.sharpen_last_n_steps:] = recent_sigs * sharpen_mask

                    # Per-step logging only; the mask above has already applied the sharpening.
                    for j in range(len(self.sigs) - self.sharpen_last_n_steps, len(self.sigs)):
                        if self.sigs[j] < self.sigma_min * 1.5:
                            self.log(f"[Sharpening] Step {j+1}: Sharpening applied. Sigma: {self.sigs[j].item():.6f}")
                        else:
                            self.log(f"[Sharpening] Step {j+1}: No sharpening applied. Sigma: {self.sigs[j].item():.6f}")

            if self.sharpen_mode in ['full', 'both']:
                # Optional: Additional full sharpening (if needed)
                self.sharpen_mask = torch.where(self.sigs < self.sigma_min * 1.5, self.sharpness, 1.0).to(self.device)
                sharpen_indices = torch.where(self.sharpen_mask < 1.0)[0].tolist()
                self.sigs = self.sigs * self.sharpen_mask
                self.log(f"[Sharpen Mask] Full sharpening applied at steps: {sharpen_indices}")

        '''
        sigma_hash = self.generate_sigma_hash(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, suffix=None)
        if self.settings.get('save_sigma_cache', False):
            save_data = {
                'sigma_values': self.sigs.cpu().tolist(),
                'sigma_hash': sigma_hash,
                'steps': steps,
                'sigma_min': sigma_min,
                'sigma_max': sigma_max,
                'rho': rho,
                'device': device,
                'schedule_type': schedule_type,
                'decay_pattern': decay_pattern,
                'full_config': json.dumps(self.settings)
            }

        if self.save_sigma_cache:
            torch.save(save_data, self.final_save_file)
            self.log(f"[Sigma Saver] Final sigmas saved to: {self.final_save_file}")
        '''
        #self.log(f"[DEBUG] Final Output: Skip Prepass: {self.skip_prepass}. Original requested steps: {self.original_steps}. Self.steps = {self.steps} for tensor sigs: {self.sigs})")
        return self.sigs.to(self.device)

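blend_sigma_sequence itself is defined elsewhere in simple_kes_v3.py; as a rough illustration of the dynamic blending the docstring describes, a minimal standalone sketch (assumed names and ramp formula, not the class's exact implementation) could look like:

import torch

def blend_sketch(sigmas_karras, sigmas_exponential,
                 start_blend=0.1, end_blend=0.5, smooth_blend_factor=9.0):
    # Blend weight ramps from start_blend to end_blend over the schedule,
    # smoothed by a sigmoid centred on the midpoint.
    t = torch.linspace(0.0, 1.0, len(sigmas_karras), device=sigmas_karras.device)
    ramp = torch.sigmoid(smooth_blend_factor * (t - 0.5))
    blend = start_blend + (end_blend - start_blend) * ramp
    return (1.0 - blend) * sigmas_karras + blend * sigmas_exponential
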
    def get_sigma_from_cache(self, cache_key):
        """
        Safely retrieves a sigma sequence from cache.
        Always returns a detached copy to prevent in-place modification of cached data.
        """
        if cache_key in self.sigma_cache:
            cached_sigmas = self.sigma_cache[cache_key]
            self.log(f"[Cache Hit] Returning cached sigma sequence for key: {cache_key}")

            # If tensor, clone and detach
            if isinstance(cached_sigmas, torch.Tensor):
                return cached_sigmas.clone().detach().to(self.device)

            # If list, deep copy
            elif isinstance(cached_sigmas, list):
                import copy
                return copy.deepcopy(cached_sigmas)

            # If other types, return as-is (optional: tighten later)
            else:
                return cached_sigmas

        else:
            self.log(f"[Cache Miss] Cache key not found: {cache_key}")
            return None
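The sharpening mask in compute_sigmas is a single torch.where multiply; a worked example with illustrative values (sigma_min = 0.05, sharpness = 0.95):

import torch

sigs = torch.tensor([3.0, 1.0, 0.12, 0.07])
sigma_min, sharpness = 0.05, 0.95
mask = torch.where(sigs < sigma_min * 1.5, sharpness, 1.0)
# mask == tensor([1.0000, 1.0000, 1.0000, 0.9500]); only the final step
# falls below the 0.075 threshold, so only it is scaled down.
sharpened = sigs * mask  # tensor([3.0000, 1.0000, 0.1200, 0.0665])
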
sd_simple_kes_v3_fix?/validate_config.py
ADDED
@@ -0,0 +1,88 @@
from typing import Optional, Dict, Any

RANDOMIZATION_TYPE_ALIASES = {
    # Asymmetric
    'asymmetric': 'asymmetric', 'assym': 'asymmetric', 'a': 'asymmetric', 'asym': 'asymmetric', 'A': 'asymmetric',
    # Symmetric
    'symmetric': 'symmetric', 'sym': 'symmetric', 's': 'symmetric', 'S': 'symmetric',
    # Logarithmic
    'logarithmic': 'logarithmic', 'log': 'logarithmic', 'l': 'logarithmic', 'L': 'logarithmic',
    # Exponential
    'exponential': 'exponential', 'exp': 'exponential', 'e': 'exponential', 'E': 'exponential',
}
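The alias map normalizes user-supplied shorthand to canonical names; for instance, a hypothetical lookup:

RANDOMIZATION_TYPE_ALIASES.get('log', 'asymmetric')  # -> 'logarithmic'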
|
DEFAULT_RANDOMIZATION_TYPE = 'asymmetric'
DEFAULT_RANDOMIZATION_PERCENT = 0.2

# Base default values
BASE_DEFAULTS = {
    'sigma_min': 0.05,
    'sigma_max': 27.5,
    'start_blend': 0.1,
    'end_blend': 0.5,
    'sharpness': 1.0,
    'early_stopping_threshold': 0.01,
    'initial_step_size': 0.9,
    'final_step_size': 0.2,
    'initial_noise_scale': 1.25,
    'final_noise_scale': 0.8,
    'smooth_blend_factor': 9.0,
    'step_size_factor': 0.8,
    'noise_scale_factor': 0.8,
    'rho': 8.0
}

def validate_config(config: Dict[str, Any], logger: Optional[Any] = None) -> Dict[str, Any]:
    updated_config = config.copy()

    def log(message):
        if logger:
            logger.log(message)
        else:
            print(message)

    # Step 1: Set all base defaults if missing
    for key, base_value in BASE_DEFAULTS.items():
        if key not in updated_config:
            updated_config[key] = base_value
            log(f"[Config Correction] {key} missing. Set to base default: {base_value}")

        # Ensure _rand flag exists and is a boolean
        rand_flag = f"{key}_rand"
        if rand_flag not in updated_config or not isinstance(updated_config.get(rand_flag), bool):
            updated_config[rand_flag] = False
            log(f"[Config Correction] {rand_flag} missing or invalid. Set to False.")

        # Ensure _enable_randomization_type flag exists and is a boolean
        randomization_flag = f"{key}_enable_randomization_type"
        if randomization_flag not in updated_config or not isinstance(updated_config.get(randomization_flag), bool):
            updated_config[randomization_flag] = False
            log(f"[Config Correction] {randomization_flag} missing or invalid. Set to False.")

        # Ensure randomization_type exists
        randomization_type_key = f"{key}_randomization_type"
        if randomization_type_key not in updated_config:
            updated_config[randomization_type_key] = DEFAULT_RANDOMIZATION_TYPE
            log(f"[Config Correction] {randomization_type_key} missing. Set to '{DEFAULT_RANDOMIZATION_TYPE}'.")

        # Ensure randomization_percent exists
        randomization_percent_key = f"{key}_randomization_percent"
        if randomization_percent_key not in updated_config:
            updated_config[randomization_percent_key] = DEFAULT_RANDOMIZATION_PERCENT
            log(f"[Config Correction] {randomization_percent_key} missing. Set to {DEFAULT_RANDOMIZATION_PERCENT}.")

        # Ensure _rand_min and _rand_max exist
        min_key = f"{key}_rand_min"
        max_key = f"{key}_rand_max"
        percent = updated_config[randomization_percent_key]

        if min_key not in updated_config:
            updated_config[min_key] = updated_config[key] * (1 - percent)
            log(f"[Config Correction] {min_key} missing. Auto-calculated from base.")

        if max_key not in updated_config:
            updated_config[max_key] = updated_config[key] * (1 + percent)
            log(f"[Config Correction] {max_key} missing. Auto-calculated from base.")

    log("[Config Validation] Config validated and missing values filled successfully.")
    return updated_config
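A hypothetical call showing the derived keys the validator fills in for a partial config (values follow the defaults above):

cfg = validate_config({'sigma_min': 0.03})
cfg['sigma_min']           # 0.03 (user value kept)
cfg['sigma_min_rand']      # False
cfg['sigma_min_rand_min']  # 0.024 = 0.03 * (1 - 0.2)
cfg['sigma_min_rand_max']  # 0.036 = 0.03 * (1 + 0.2)
cfg['sigma_max']           # 27.5 (filled from BASE_DEFAULTS)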