dikdimon commited on
Commit
3b4fdab
·
verified ·
1 Parent(s): 688dd9c

Upload sd_simple_kes_v3_fix? using SD-Hub

Browse files
Files changed (35) hide show
  1. sd_simple_kes_v3_fix?/__pycache__/get_sigmas.cpython-310.pyc +0 -0
  2. sd_simple_kes_v3_fix?/__pycache__/plot_sigma_sequence.cpython-310.pyc +0 -0
  3. sd_simple_kes_v3_fix?/__pycache__/simple_kes_v3.cpython-310.pyc +0 -0
  4. sd_simple_kes_v3_fix?/__pycache__/validate_config.cpython-310.pyc +0 -0
  5. sd_simple_kes_v3_fix?/get_sigmas.py +18 -0
  6. sd_simple_kes_v3_fix?/kes_config/default_config.yaml +196 -0
  7. sd_simple_kes_v3_fix?/kes_config/good_configs.yaml +45 -0
  8. sd_simple_kes_v3_fix?/kes_config/how to use these files +20 -0
  9. sd_simple_kes_v3_fix?/kes_config/simple_kes_requirements.txt +5 -0
  10. sd_simple_kes_v3_fix?/kes_config/simple_kes_scheduler.yaml +146 -0
  11. sd_simple_kes_v3_fix?/kes_config/suggested_scheduling_configs/alternating soft_hard decay.yaml +42 -0
  12. sd_simple_kes_v3_fix?/kes_config/suggested_scheduling_configs/anime_1.yaml +18 -0
  13. sd_simple_kes_v3_fix?/kes_config/suggested_scheduling_configs/cross_style_safe.yaml +18 -0
  14. sd_simple_kes_v3_fix?/kes_config/suggested_scheduling_configs/front_loaded geometric.yaml +42 -0
  15. sd_simple_kes_v3_fix?/kes_config/suggested_scheduling_configs/photo_realistic_1.yaml +18 -0
  16. sd_simple_kes_v3_fix?/kes_config/suggested_scheduling_configs/progressive.yaml +42 -0
  17. sd_simple_kes_v3_fix?/kes_config/user_config.yaml +81 -0
  18. sd_simple_kes_v3_fix?/plot_sigma_sequence.py +42 -0
  19. sd_simple_kes_v3_fix?/requirements.txt +4 -0
  20. sd_simple_kes_v3_fix?/schedulers/__pycache__/euler_advanced_scheduler.cpython-310.pyc +0 -0
  21. sd_simple_kes_v3_fix?/schedulers/__pycache__/exponential_advanced_scheduler.cpython-310.pyc +0 -0
  22. sd_simple_kes_v3_fix?/schedulers/__pycache__/geometric_advanced_scheduler.cpython-310.pyc +0 -0
  23. sd_simple_kes_v3_fix?/schedulers/__pycache__/harmonic_advanced_scheduler.cpython-310.pyc +0 -0
  24. sd_simple_kes_v3_fix?/schedulers/__pycache__/karras_advanced_scheduler.cpython-310.pyc +0 -0
  25. sd_simple_kes_v3_fix?/schedulers/__pycache__/logarithmic_advanced_scheduler.cpython-310.pyc +0 -0
  26. sd_simple_kes_v3_fix?/schedulers/__pycache__/shared.cpython-310.pyc +0 -0
  27. sd_simple_kes_v3_fix?/schedulers/euler_advanced_scheduler.py +53 -0
  28. sd_simple_kes_v3_fix?/schedulers/exponential_advanced_scheduler.py +41 -0
  29. sd_simple_kes_v3_fix?/schedulers/geometric_advanced_scheduler.py +60 -0
  30. sd_simple_kes_v3_fix?/schedulers/harmonic_advanced_scheduler.py +45 -0
  31. sd_simple_kes_v3_fix?/schedulers/karras_advanced_scheduler.py +40 -0
  32. sd_simple_kes_v3_fix?/schedulers/logarithmic_advanced_scheduler.py +46 -0
  33. sd_simple_kes_v3_fix?/schedulers/shared.py +184 -0
  34. sd_simple_kes_v3_fix?/simple_kes_v3.py +2002 -0
  35. sd_simple_kes_v3_fix?/validate_config.py +88 -0
sd_simple_kes_v3_fix?/__pycache__/get_sigmas.cpython-310.pyc ADDED
Binary file (964 Bytes). View file
 
sd_simple_kes_v3_fix?/__pycache__/plot_sigma_sequence.cpython-310.pyc ADDED
Binary file (1.68 kB). View file
 
sd_simple_kes_v3_fix?/__pycache__/simple_kes_v3.cpython-310.pyc ADDED
Binary file (52.7 kB). View file
 
sd_simple_kes_v3_fix?/__pycache__/validate_config.cpython-310.pyc ADDED
Binary file (2.38 kB). View file
 
sd_simple_kes_v3_fix?/get_sigmas.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.sd_simple_kes_v3.schedulers.karras_advanced_scheduler import get_sigmas_karras
2
+ from modules.sd_simple_kes_v3.schedulers.exponential_advanced_scheduler import get_sigmas_exponential
3
+ from modules.sd_simple_kes_v3.schedulers.geometric_advanced_scheduler import get_sigmas_geometric
4
+ from modules.sd_simple_kes_v3.schedulers.harmonic_advanced_scheduler import get_sigmas_harmonic
5
+ from modules.sd_simple_kes_v3.schedulers.logarithmic_advanced_scheduler import get_sigmas_logarithmic
6
+ from modules.sd_simple_kes_v3.schedulers.euler_advanced_scheduler import get_sigmas_euler, get_sigmas_euler_advanced
7
+
8
+
9
+ scheduler_registry = {
10
+ 'karras': get_sigmas_karras,
11
+ 'exponential': get_sigmas_exponential,
12
+ 'geometric': get_sigmas_geometric,
13
+ 'harmonic': get_sigmas_harmonic,
14
+ 'logarithmic': get_sigmas_logarithmic,
15
+ 'euler': get_sigmas_euler,
16
+ 'euler_advanced': get_sigmas_euler_advanced
17
+ # Add more here - ensure methods are added to get_sigmas then update imports #also update simple_kes
18
+ }
sd_simple_kes_v3_fix?/kes_config/default_config.yaml ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ log_save_directory: "modules/sd_simple_kes_v3/image_generation_data"
2
+ graph_save_directory: "modules/sd_simple_kes_v3/image_generation_data"
3
+ graph_save_enable: false
4
+ load_sigma_cache: false # Master toggle to enable/disable sigma caching
5
+ save_sigma_cache: false # Whether to save new sigma schedules to file
6
+ decay_pattern: "extrapolate" #Options: zero, soft_landing, extrapolate, and fractional
7
+ decay_mode: "append"
8
+ load_prepass_sigmas: false
9
+ save_prepass_sigmas: false
10
+
11
+ #blend methods #Key/Value - Options: "karras", "exponential", "geometric", "harmonic", "logarithmic"
12
+ #blend weights #Any number between 0 and infinity. Explicit will use that number in relation to the other numbers and give it more weight. Softmax will normalize all values in relation to each other and keep values between 0 and 1.
13
+ #if decay_pattern is empty or missing, decay_mode, tail_steps, and decay_rate do not get used.
14
+ #decay rate only affects geometric patterns
15
+ allow_step_expansion: false # Strict A1111 compatibility mode
16
+ apply_tail_steps: false # Append tails from schedulers to sigma sequence
17
+ apply_decay_tail: false # Append decay tails from schedulers to sigma sequence
18
+ apply_blended_tail: false # Blend multiple tails into a single tail, then append
19
+ apply_progressive_decay: false # Gradually apply decay to sigma sequence step-by-step
20
+
21
+ auto_tail_smoothing: true
22
+ auto_tail_threshold: 0.05
23
+ jaggedness_threshold: 0.01
24
+
25
+ auto_stabilization_sequence:
26
+ - smooth_interpolation
27
+ - append_tail
28
+ - blend_tail
29
+ - apply_decay
30
+ - progressive_decay
31
+
32
+ blending_style: 'softmax' # Options: 'explicit' or 'softmax'
33
+ #Valid decay_patterns: 'geometric', 'harmonic', 'extrapolate','fractional', 'logarithmic', 'exponential', 'linear', and 'zero'
34
+ #Valid decay_modes: 'append', 'blend', 'replace'
35
+ #Valid decay modes compatible with A1111: all if tail_steps is not greater than 1. If any methods add steps that increase steps higher than what was requested, it is not compatible
36
+ #decay modes have been tested and they work. However if they increase steps beyond the requested amount, it will not work in the A1111 pipeline. If a pipeline supports increasing steps to have a smoother transition for sigma/noise reduction, then this method would function as intended - to increase steps to have a smoother transition & no jaggedness between steps.
37
+
38
+ blend_methods:
39
+ euler: #if euler was a scheduler -
40
+ weight: 0.3
41
+ decay_pattern: 'harmonic'
42
+ decay_mode: 'blend'
43
+ tail_steps: 1
44
+
45
+
46
+ euler_advanced: #if euler advanced were a scheduler -
47
+ weight: 0.7
48
+ decay_pattern: 'harmonic'
49
+ decay_mode: 'blend'
50
+ tail_steps: 1
51
+
52
+
53
+
54
+
55
+ blending_mode: "default" # Options: "auto", "default" "smooth_blend", "weights"
56
+ #"auto" uses smart weights if more than 2 methods, or smooth blend if exactly two methods
57
+ #"default" is karras + exponential for the standard Simple_KES methods which uses smooth_blend_factor
58
+ #"smooth_blend" enforces use even with weights included
59
+ # "weights" enforces weights even with only 2 methods
60
+
61
+ smooth_blend_factor_rand: false
62
+ smooth_blend_factor_rand_min: 6
63
+ smooth_blend_factor_rand_max: 11
64
+ smooth_blend_factor: 3
65
+ smooth_blend_factor_enable_randomization_type: false
66
+ smooth_blend_factor_randomization_type: "asymmetric"
67
+ smooth_blend_factor_randomization_percent: 0.2
68
+
69
+
70
+
71
+ skip_prepass: true # has no change to image quality - not currently functioning as intended for early stop purposes
72
+ device: "cuda" #cpu or cuda
73
+ debug: true
74
+ global_randomize: false
75
+ #
76
+ sigma_scale_factor: 900
77
+ sigma_auto_enabled: true
78
+ sigma_auto_mode: sigma_min # Options: sigma_min, sigma_max
79
+ #
80
+ rho_rand: false
81
+ rho_rand_min: 3.00 # tested recommended settings threshold
82
+ rho_rand_max: 8.00 # tested recommended settings threshold
83
+ #rho: 7.571656624637901
84
+ rho: 7.959565031107985
85
+ rho_enable_randomization_type: false
86
+ rho_randomization_type: "log"
87
+ rho_randomization_percent: 0.1
88
+ #
89
+ sigma_min_rand: false
90
+ sigma_min_rand_min: 0.001 # tested recommended settings
91
+ sigma_min_rand_max: 0.02 # tested recommended settings threshold
92
+ sigma_min: 0.13757067353874633
93
+ sigma_min_enable_randomization_type: false
94
+ sigma_min_randomization_type: "asymmetric"
95
+ sigma_min_randomization_percent: 0.2
96
+ #
97
+ sigma_max_rand: false
98
+ sigma_max_rand_min: 25
99
+ sigma_max_rand_max: 60
100
+ sigma_max: 47.95768510805332
101
+ sigma_max_enable_randomization_type: false
102
+ sigma_max_randomization_type: "log"
103
+ sigma_max_randomization_percent: 0.25
104
+ #
105
+ start_blend_rand: false
106
+ start_blend_rand_min: 0.04 # tested recommended settings threshold
107
+ start_blend_rand_max: 0.11 # tested recommended settings threshold
108
+ start_blend: 0.05
109
+ start_blend_enable_randomization_type: false
110
+ start_blend_randomization_type: "asymmetric"
111
+ start_blend_randomization_percent: 0.1
112
+ #
113
+ end_blend_rand: false
114
+ end_blend_rand_min: 0.4 # tested recommended settings threshold
115
+ end_blend_rand_max: 0.6 # tested recommended settings threshold
116
+ end_blend: 0.4
117
+ end_blend_enable_randomization_type: false
118
+ end_blend_randomization_type: "asymmetric"
119
+ end_blend_randomization_percent: 0.2
120
+ #
121
+ sharpness_rand: false
122
+ sharpness_rand_min: 0.75 # tested recommended settings threshold
123
+ sharpness_rand_max: 0.95 # tested recommended settings threshold
124
+ sharpness: 0.85 # Note: Visible changes in image between 2-15. Above 15 - notable differences. At 50+ - poor image quality. sharpness not applied above 0.95
125
+ sharpen_variance_threshold: 0.01
126
+ sharpen_last_n_steps: 10
127
+ sharpen_mode: "full" # Options: last_n, full, both
128
+ sharpness_enable_randomization_type: false
129
+ sharpness_randomization_type: "asymmetric"
130
+ sharpness_randomization_percent: 0.2
131
+ #
132
+ step_progress_mode: "sigmoid" # Options supported (default = "linear"), "exponential", "logarithmic", or "sigmoid". If exponential, uses "exp_power"
133
+ exp_power: 2
134
+ #
135
+ initial_step_size_rand: false
136
+ initial_step_size_rand_min: 0.7
137
+ initial_step_size_rand_max: 1.0
138
+ initial_step_size: 0.9
139
+ initial_step_size_enable_randomization_type: false
140
+ initial_step_size_randomization_type: "asymmetric" #assym, symm, log, or exp A/S/L/E
141
+ initial_step_size_randomization_percent: 0.2
142
+ #
143
+ final_step_size_rand: false
144
+ final_step_size_rand_min: 0.1
145
+ final_step_size_rand_max: 0.3
146
+ final_step_size: 0.20
147
+ final_step_size_enable_randomization_type: false
148
+ final_step_size_randomization_type: "asymmetric"
149
+ final_step_size_randomization_percent: 0.2
150
+ #
151
+ step_size_factor_rand: false
152
+ step_size_factor_rand_min: 0.65
153
+ step_size_factor_rand_max: 0.85
154
+ step_size_factor: 0.80814932869181
155
+ step_size_factor_enable_randomization_type: false
156
+ step_size_factor_randomization_type: "asymmetric"
157
+ step_size_factor_randomization_percent: 0.2
158
+ #
159
+ initial_noise_scale_rand: false
160
+ initial_noise_scale_rand_min: 1.0
161
+ initial_noise_scale_rand_max: 1.5
162
+ initial_noise_scale: 1.25
163
+ initial_noise_scale_enable_randomization_type: false
164
+ initial_noise_scale_randomization_type: "asymmetric"
165
+ initial_noise_scale_randomization_percent: 0.2
166
+ #
167
+ final_noise_scale_rand: false
168
+ final_noise_scale_rand_min: 0.6
169
+ final_noise_scale_rand_max: 1.0
170
+ final_noise_scale: 0.80
171
+ final_noise_scale_enable_randomization_type: false
172
+ final_noise_scale_randomization_type: "asymmetric"
173
+ final_noise_scale_randomization_percent: 0.2
174
+
175
+ #
176
+ noise_scale_factor_rand: false
177
+ noise_scale_factor_rand_min: 0.75
178
+ noise_scale_factor_rand_max: 0.95
179
+ noise_scale_factor: 0.8113992828873163
180
+ noise_scale_factor_enable_randomization_type: false
181
+ noise_scale_factor_randomization_type: "asymmetric"
182
+ noise_scale_factor_randomization_percent: 0.2
183
+
184
+ # Experimental settings
185
+ early_stopping_threshold_rand: false
186
+ early_stopping_threshold_rand_min: 0.001
187
+ early_stopping_threshold_rand_max: 0.02
188
+ early_stopping_threshold: 0.06
189
+ early_stopping_method: max # Options: mean, max, sum
190
+ sigma_variance_scale: 0.1 # *100 = % of current sigma, increase to reduce false early stopping, try 0.07 or 0.10
191
+ safety_minimum_stop_step: 10 # means won't consider until past this step, consider increasing this to increase minimum steps to process the image
192
+ recent_change_convergence_delta: 0.6 # this is the change between mean/max variable changes between sigmas. Keep this relatively low. This contributes directly to when we stop.
193
+ #min_visual_sigma: 50 # Increase from 10 to push later into the denoising sequence
194
+ early_stopping_threshold_enable_randomization_type: false
195
+ early_stopping_threshold_randomization_type: "asymmetric"
196
+ early_stopping_threshold_randomization_percent: 0.2
sd_simple_kes_v3_fix?/kes_config/good_configs.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ blend_methods:
2
+ euler:
3
+ weight: 0.0
4
+ decay_pattern: 'geometric'
5
+ decay_mode: 'blend'
6
+ tail_steps: 1
7
+
8
+ euler_advanced:
9
+ weight: 1.895
10
+ decay_pattern: 'harmonic'
11
+ decay_mode: 'blend'
12
+ tail_steps: 1
13
+
14
+ geometric:
15
+ weight: 0.0
16
+ decay_pattern: 'exponential'
17
+ decay_mode: 'blend'
18
+ tail_steps: 1
19
+
20
+ harmonic:
21
+ weight: 0.5
22
+ decay_pattern: 'logarithmic'
23
+ decay_mode: 'blend'
24
+ tail_steps: 1
25
+
26
+ logarithmic:
27
+ weight: 0.5
28
+ decay_pattern: 'fractional'
29
+ decay_mode: 'blend'
30
+ tail_steps: 1
31
+
32
+ karras:
33
+ weight: 0.0
34
+ decay_pattern: 'exponential'
35
+ decay_mode: 'blend'
36
+ tail_steps: 1
37
+
38
+ exponential:
39
+ weight: 0.0
40
+ decay_pattern: 'logarithmic'
41
+ decay_mode: 'blend'
42
+ tail_steps: 1
43
+
44
+ blending_style: 'softmax'
45
+ blending_mode: 'weights'
sd_simple_kes_v3_fix?/kes_config/how to use these files ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ These are only suggested configurations. To give you a starting point for your configs.
3
+
4
+ Do not make changes to the default_config file unless you know what you're doing. If you do decide to tweak it, please copy it so you can revert changes if needed.
5
+
6
+ The best way is to copy whatever configuration changes from the default_config into the user_config since any values in the user_config will override the default_config when the program runs. This means you should never need to change the default_config.
7
+
8
+ I've kept my user_config pretty basic and I've kept the default_config as is. Should you want to edit your user_config with advanced settings you can find those advanced settings inside the default_config.
9
+
10
+ The files in the "suggested_scheduling_configs" folder have suggested values for different scheduler combinations **ONLY**.
11
+
12
+ Simple replace in line or completely your user_config yaml with the suggested config. Tweak as needed.
13
+
14
+ If you decide to replace in line your current file, simply set the weights to 0.0 for any current ones that you're using, or you could comment out the lines that you don't use.
15
+
16
+ If you're happy with your current user_config, I'd recommend you simply make a copy of it, and then tweak the user_config by copying the files in whole. That way is the easiest and cleanest way to ensure you don't delete things that work and you can play around with other values.
17
+
18
+
19
+
20
+
sd_simple_kes_v3_fix?/kes_config/simple_kes_requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ k_diffusion
2
+ requests
3
+ torch
4
+ Pyyaml
5
+ watchdog
sd_simple_kes_v3_fix?/kes_config/simple_kes_scheduler.yaml ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ scheduler:
2
+
3
+ #Optionally print to a log file for debugging. If false, debug is turned off, and no log file will be created.
4
+ #config options: true or false
5
+ debug: true
6
+
7
+ # The minimum value for the noise level (sigma) during image generation.
8
+ # Decreasing this value makes the image clearer but less detailed.
9
+ # Increasing it makes the image noisier but potentially more artistic or abstract.
10
+ sigma_min: 0.04 # Default: 0.01, Suggested range: 0.01 - 0.1
11
+
12
+ # The maximum value for the noise level (sigma) during image generation.
13
+ # Increasing this value can create more variation in the image details.
14
+ # Lower values keep the image more stable and less noisy.
15
+ sigma_max: 50 # Default: 50, Suggested range:10 - 60
16
+
17
+ # The device used for running the scheduler. If you have a GPU, set this to "cuda".
18
+ # Otherwise, use "cpu", but note that it will be significantly slower.
19
+ #device: "cuda" # Options: "cuda" (GPU) or "cpu" (processor)
20
+
21
+ # Initial blend factor between Karras and Exponential noise methods.
22
+ # A higher initial blend makes the image sharper at the start.
23
+ # A lower initial blend makes the image smoother early on.
24
+ start_blend: 0.15 # Default: 0.1, Suggested range: 0.05 - 0.2
25
+
26
+ # Final blend factor between Karras and Exponential noise methods.
27
+ # Higher values blend more noise at the end, possibly adding more detail.
28
+ # Lower values blend less noise for smoother, simpler images at the end.
29
+ end_blend: 0.485 # Default: 0.5, Suggested range: 0.4 - 0.6
30
+
31
+ # Sharpening factor applied to images during generation.
32
+ # Higher values increase sharpness but can add unwanted artifacts.
33
+ # Lower values reduce sharpness but may make the image look blurry.
34
+ sharpness: 0.9 # Default: 0.95, Suggested range: 0.8 - 1.0
35
+
36
+ # Early stopping threshold for stopping the image generation when changes between steps are minimal.
37
+ # Lower values stop early, saving time, but might produce incomplete images.
38
+ # Higher values take longer but may give more detailed results.
39
+ early_stopping_threshold: 0.05 # Default: 0.01, Suggested range: 0.01 - 0.05
40
+
41
+ # The number of steps between updates of the blend factor.
42
+ # Smaller values update the blend more frequently for smoother transitions.
43
+ # Larger values update the blend less frequently for faster processing.
44
+ update_interval: 10 # Default: 10, Suggested range: 5 - 15
45
+
46
+ # Initial step size, which controls how quickly the image evolves early on.
47
+ # Higher values make big changes at the start, possibly generating faster but less refined images.
48
+ # Lower values make smaller changes, giving more control over details.
49
+ initial_step_size: 0.5 # Default, 0.9, Suggested range: 0.5 - 1.0
50
+
51
+ # Final step size, which controls how much the image changes towards the end.
52
+ # Higher values keep details more flexible until the end, which may add complexity.
53
+ # Lower values lock the details earlier, making the image simpler.
54
+ final_step_size: 0.22 # Default: 0.2, Suggested range: 0.1 - 0.3
55
+
56
+ # Initial noise scaling applied to the image generation process.
57
+ # Higher values add more noise early on, making the initial image more random.
58
+ # Lower values reduce noise early on, leading to a smoother initial image.
59
+ initial_noise_scale: 1.1 # Default, 1.25, Suggested range: 1.0 - 1.5
60
+
61
+ # Final noise scaling applied at the end of the image generation.
62
+ # Higher values add noise towards the end, possibly adding fine detail.
63
+ # Lower values reduce noise towards the end, making the final image smoother.
64
+ final_noise_scale: 0.7 # Default, 0.8, Suggested range: 0.6 - 1.0
65
+
66
+
67
+ smooth_blend_factor: 11 #Default: 11, try 6 for more variation
68
+ step_size_factor: 0.75 #suggested value (0.8) to avoid oversmoothing
69
+ noise_scale_factor: 0.95 #suggested value (0.9) to add more variation
70
+
71
+
72
+ # Enables global randomization.
73
+ # If true, all parameters are randomized within specified min/max ranges.
74
+ # If false, individual parameters with _rand flags set to true will still be randomized.
75
+ randomize: false
76
+
77
+ #Sigma values typically start very small. Lowering this could allow more gradual noise reduction. Too large would overwhelm the process.
78
+ sigma_min_rand: false
79
+ sigma_min_rand_min: 0.001
80
+ sigma_min_rand_max: 0.05
81
+
82
+ #Sigma max controls the upper limit of the noise. A lower minimum could allow faster convergence, while a higher max gives more flexibility for noisier images.
83
+ sigma_max_rand: false
84
+ sigma_max_rand_min: 22
85
+ sigma_max_rand_max: 48
86
+
87
+ #Start blend controls how strongly Karras and Exponential are blended at the start. A slightly lower value introduces more variety in the blending at the beginning.
88
+ start_blend_rand: false
89
+ start_blend_rand_min: 0.05
90
+ start_blend_rand_max: 0.2
91
+
92
+ # End blend affects how much the blending changes towards the end. Increasing the upper limit would allow more variation.
93
+ end_blend_rand: false
94
+ end_blend_rand_min: 0.4
95
+ end_blend_rand_max: 0.6
96
+
97
+ # Sharpness controls detail retention. You wouldn’t want to lower it too much, as it might lose detail.
98
+ sharpness_rand: false
99
+ sharpness_rand_min: 0.85
100
+ sharpness_rand_max: 1.0
101
+
102
+ #A smaller early stopping threshold could lead to earlier stopping if the changes between sigma steps become too small, while the upper value would prevent early stopping until larger changes occur.
103
+ early_stopping_rand: false
104
+ early_stopping_rand_min: 0.001
105
+ early_stopping_rand_max: 0.02
106
+
107
+ #Update intervals affect how frequently blending factors are updated. More frequent updates allow more flexibility in blending.
108
+ update_interval_rand: false
109
+ update_interval_rand_min: 5
110
+ update_interval_rand_max: 10
111
+
112
+ # The initial step size defines how large the steps are at the start. A slightly smaller value introduces more gradual transitions.
113
+ initial_step_rand: false
114
+ initial_step_rand_min: 0.7
115
+ initial_step_rand_max: 1.0
116
+
117
+ # The final step size defines how small the steps become towards the end. A slightly larger range gives more control over the final convergence.
118
+ final_step_rand: false
119
+ final_step_rand_min: 0.1
120
+ final_step_rand_max: 0.3
121
+
122
+ #Initial noise scale defines how much noise to introduce initially. Larger values make the process start with more randomness, while smaller values keep it controlled.
123
+ initial_noise_rand: false
124
+ initial_noise_rand_min: 1.0
125
+ initial_noise_rand_max: 1.5
126
+
127
+ # Final noise scale affects how much noise is reduced at the end. A lower minimum allows more noise to persist, while a higher maximum ensures full convergence.
128
+ final_noise_rand: false
129
+ final_noise_rand_min: 0.6
130
+ final_noise_rand_max: 1.0
131
+
132
+ #The smooth blend factor controls how aggressively the blending is smoothed. Lower values allow more abrupt blending changes, while higher values give smoother transitions.
133
+ smooth_blend_factor_rand: false
134
+ smooth_blend_factor_rand_min: 6
135
+ smooth_blend_factor_rand_max: 11
136
+
137
+ #Step size factor adjusts the step size dynamically to avoid oversmoothing. A lower minimum increases variety, while a higher max provides smoother results.
138
+ step_size_factor_rand: false
139
+ step_size_factor_rand_min: 0.65
140
+ step_size_factor_rand_max: 0.85
141
+
142
+ # Noise scale factor controls how noise is scaled throughout the steps. A slightly lower minimum adds more variety, while keeping the maximum value near the suggested ensures more uniform results.
143
+ noise_scale_factor_rand: false
144
+ noise_scale_factor_rand_min: 0.75
145
+ noise_scale_factor_rand_max: 0.95
146
+
sd_simple_kes_v3_fix?/kes_config/suggested_scheduling_configs/alternating soft_hard decay.yaml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ blend_methods:
2
+ euler:
3
+ weight: 0.5
4
+ decay_pattern: 'geometric'
5
+ decay_mode: 'blend'
6
+ tail_steps: 1
7
+
8
+ euler_advanced:
9
+ weight: 0.4
10
+ decay_pattern: 'harmonic'
11
+ decay_mode: 'blend'
12
+ tail_steps: 1
13
+
14
+ geometric:
15
+ weight: 0.7
16
+ decay_pattern: 'exponential'
17
+ decay_mode: 'blend'
18
+ tail_steps: 1
19
+
20
+ harmonic:
21
+ weight: 0.3
22
+ decay_pattern: 'linear'
23
+ decay_mode: 'blend'
24
+ tail_steps: 1
25
+
26
+ logarithmic:
27
+ weight: 0.6
28
+ decay_pattern: 'fractional'
29
+ decay_mode: 'blend'
30
+ tail_steps: 1
31
+
32
+ karras:
33
+ weight: 0.8
34
+ decay_pattern: 'exponential'
35
+ decay_mode: 'blend'
36
+ tail_steps: 1
37
+
38
+ exponential:
39
+ weight: 0.2
40
+ decay_pattern: 'logarithmic'
41
+ decay_mode: 'blend'
42
+ tail_steps: 1
sd_simple_kes_v3_fix?/kes_config/suggested_scheduling_configs/anime_1.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ blend_methods:
2
+ euler_advanced:
3
+ weight: 0.5
4
+ decay_pattern: 'harmonic'
5
+ decay_mode: 'blend'
6
+ tail_steps: 1
7
+
8
+ harmonic:
9
+ weight: 0.7
10
+ decay_pattern: 'logarithmic'
11
+ decay_mode: 'blend'
12
+ tail_steps: 1
13
+
14
+ karras:
15
+ weight: 0.6
16
+ decay_pattern: 'fractional'
17
+ decay_mode: 'blend'
18
+ tail_steps: 1
sd_simple_kes_v3_fix?/kes_config/suggested_scheduling_configs/cross_style_safe.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ blend_methods:
2
+ euler_advanced:
3
+ weight: 0.4
4
+ decay_pattern: 'logarithmic'
5
+ decay_mode: 'blend'
6
+ tail_steps: 1
7
+
8
+ harmonic:
9
+ weight: 0.4
10
+ decay_pattern: 'linear'
11
+ decay_mode: 'blend'
12
+ tail_steps: 1
13
+
14
+ karras:
15
+ weight: 0.5
16
+ decay_pattern: 'exponential'
17
+ decay_mode: 'blend'
18
+ tail_steps: 1
sd_simple_kes_v3_fix?/kes_config/suggested_scheduling_configs/front_loaded geometric.yaml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ blend_methods:
2
+ euler:
3
+ weight: 0.6
4
+ decay_pattern: 'geometric'
5
+ decay_mode: 'blend'
6
+ tail_steps: 1
7
+
8
+ euler_advanced:
9
+ weight: 0.7
10
+ decay_pattern: 'geometric'
11
+ decay_mode: 'blend'
12
+ tail_steps: 1
13
+
14
+ geometric:
15
+ weight: 0.8
16
+ decay_pattern: 'geometric'
17
+ decay_mode: 'blend'
18
+ tail_steps: 1
19
+
20
+ harmonic:
21
+ weight: 0.4
22
+ decay_pattern: 'harmonic'
23
+ decay_mode: 'blend'
24
+ tail_steps: 1
25
+
26
+ logarithmic:
27
+ weight: 0.3
28
+ decay_pattern: 'logarithmic'
29
+ decay_mode: 'blend'
30
+ tail_steps: 1
31
+
32
+ karras:
33
+ weight: 0.5
34
+ decay_pattern: 'fractional'
35
+ decay_mode: 'blend'
36
+ tail_steps: 1
37
+
38
+ exponential:
39
+ weight: 0.2
40
+ decay_pattern: 'linear'
41
+ decay_mode: 'blend'
42
+ tail_steps: 1
sd_simple_kes_v3_fix?/kes_config/suggested_scheduling_configs/photo_realistic_1.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ blend_methods:
2
+ euler:
3
+ weight: 0.6
4
+ decay_pattern: 'geometric'
5
+ decay_mode: 'blend'
6
+ tail_steps: 1
7
+
8
+ geometric:
9
+ weight: 0.7
10
+ decay_pattern: 'exponential'
11
+ decay_mode: 'blend'
12
+ tail_steps: 1
13
+
14
+ karras:
15
+ weight: 0.8
16
+ decay_pattern: 'geometric'
17
+ decay_mode: 'blend'
18
+ tail_steps: 1
sd_simple_kes_v3_fix?/kes_config/suggested_scheduling_configs/progressive.yaml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ blend_methods:
2
+ euler:
3
+ weight: 0.2
4
+ decay_pattern: 'harmonic'
5
+ decay_mode: 'blend'
6
+ tail_steps: 1
7
+
8
+ euler_advanced:
9
+ weight: 0.2
10
+ decay_pattern: 'logarithmic'
11
+ decay_mode: 'blend'
12
+ tail_steps: 1
13
+
14
+ geometric:
15
+ weight: 0.3
16
+ decay_pattern: 'linear'
17
+ decay_mode: 'blend'
18
+ tail_steps: 1
19
+
20
+ harmonic:
21
+ weight: 0.4
22
+ decay_pattern: 'fractional'
23
+ decay_mode: 'blend'
24
+ tail_steps: 1
25
+
26
+ logarithmic:
27
+ weight: 0.5
28
+ decay_pattern: 'geometric'
29
+ decay_mode: 'blend'
30
+ tail_steps: 1
31
+
32
+ karras:
33
+ weight: 0.6
34
+ decay_pattern: 'exponential'
35
+ decay_mode: 'blend'
36
+ tail_steps: 1
37
+
38
+ exponential:
39
+ weight: 0.7
40
+ decay_pattern: 'extrapolate'
41
+ decay_mode: 'blend'
42
+ tail_steps: 1
sd_simple_kes_v3_fix?/kes_config/user_config.yaml ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========================================================
2
+ # ПОЛНЫЙ КОНФИГ (БЕЗ ЗАВИСИМОСТИ ОТ DEFAULT)
3
+ # ========================================================
4
+
5
+ # --- 1. ГЛАВНЫЕ НАСТРОЙКИ СМЕШИВАНИЯ ---
6
+ # Используем "default", чтобы избежать деления на ноль.
7
+ # Это смешивает Karras и Exponential.
8
+ blending_mode: "default"
9
+
10
+ # Управление плавностью перехода между Karras и Exponential.
11
+ # Значение 3 делает переход очень мягким.
12
+ smooth_blend_factor: 9
13
+
14
+ # --- 2. ИСПРАВЛЕНИЕ АРТЕФАКТОВ (ГЛАВНОЕ) ---
15
+ # "linear" убирает "ступеньки" на линиях. (Было "sigmoid").
16
+ step_progress_mode: "exponential"
17
+
18
+ # --- 3. НАСТРОЙКИ ШУМА И ШАГОВ (ВАШ "ХАРАКТЕР") ---
19
+ # Эти числа взяты из вашего старого конфига, чтобы шедулер был "нескучным".
20
+
21
+ # Размеры шагов (динамика скорости)
22
+ initial_step_size: 0.9
23
+ final_step_size: 0.20
24
+ step_size_factor: 0.80814932869181
25
+
26
+ # Масштаб шума (динамика детализации)
27
+ initial_noise_scale: 1.25
28
+ final_noise_scale: 0.80
29
+ noise_scale_factor: 0.8113992828873163
30
+
31
+ # --- 4. НАСТРОЙКИ ГРАНИЦ ШУМА (SIGMA) ---
32
+ # Ваши точные значения.
33
+ sigma_min: 0.13757067353874633
34
+ sigma_max: 47.95768510805332
35
+ rho: 7.959565031107985
36
+
37
+ # Авто-масштабирование (можно оставить включенным)
38
+ sigma_auto_enabled: true
39
+ sigma_auto_mode: "sigma_min"
40
+ sigma_scale_factor: 900
41
+
42
+ # --- 5. РЕЗКОСТЬ (SHARPNESS) ---
43
+ # 0.85 = средняя резкость.
44
+ sharpness: 0.85
45
+
46
+ # ВАЖНО: Применяем только к последним 10 шагам, чтобы не ломать картинку.
47
+ sharpen_mode: "last_n"
48
+ sharpen_last_n_steps: 10
49
+ sharpen_variance_threshold: 0.01
50
+
51
+ # --- 6. СМЕШИВАНИЕ (ТОЛЬКО ДЛЯ СПРАВКИ) ---
52
+ # В режиме "default" эти веса игнорируются, но параметры хвостов важны.
53
+ blend_methods:
54
+ karras:
55
+ weight: 1.0
56
+ decay_pattern: 'zero'
57
+ decay_mode: 'blend'
58
+ tail_steps: 1
59
+ exponential:
60
+ weight: 1.0
61
+ decay_pattern: 'zero'
62
+ decay_mode: 'blend'
63
+ tail_steps: 1
64
+
65
+ # --- 7. СИСТЕМНЫЕ НАСТРОЙКИ ---
66
+ debug: true
67
+ device: "cuda"
68
+ skip_prepass: true # Отключаем пре-проход для скорости
69
+ global_randomize: false
70
+
71
+ # Отключаем рандомизацию отдельных параметров для стабильности
72
+ rho_rand: false
73
+ sigma_min_rand: false
74
+ sigma_max_rand: false
75
+ start_blend_rand: false
76
+ end_blend_rand: false
77
+ sharpness_rand: false
78
+ initial_step_size_rand: false
79
+ final_step_size_rand: false
80
+ initial_noise_scale_rand: false
81
+ final_noise_scale_rand: false
sd_simple_kes_v3_fix?/plot_sigma_sequence.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import os
4
+
5
+ def plot_sigma_sequence(sigs, stopping_index, log_filename, save_directory="modules/sd_simple_kes_v3/image_generation_data", show_plot=False):
6
+ """
7
+ Plot the sigma sequence and mark the early stopping point.
8
+
9
+ Parameters:
10
+ - sigs: The sigma tensor or numpy array (can be truncated if stopping early).
11
+ - stopping_index: The step index where early stopping was triggered.
12
+ - log_filename: The filename of the generation log (used to match the graph name).
13
+ - save_directory: The folder where the plot should be saved.
14
+ - show_plot: Set to True to display the plot interactively.
15
+ """
16
+
17
+ # Extract base name to match log filename
18
+ base_filename = os.path.splitext(os.path.basename(log_filename))[0]
19
+ graph_filename = f"{base_filename}_sigma_plot.png"
20
+ graph_path = os.path.join(save_directory, graph_filename)
21
+
22
+ # Prepare sigma sequence for plotting
23
+ sigs_np = sigs.cpu().numpy() if hasattr(sigs, 'cpu') else np.array(sigs)
24
+ x = np.arange(len(sigs_np))
25
+
26
+ # Plotting
27
+ plt.figure(figsize=(10, 6))
28
+ plt.plot(x, sigs_np, label='Sigma Sequence', marker='o')
29
+ plt.axvline(x=stopping_index, color='red', linestyle='--', label=f'Stopping Point: {stopping_index}')
30
+ plt.xlabel('Step Index')
31
+ plt.ylabel('Sigma Value')
32
+ plt.title('Sigma Sequence with Early Stopping Point')
33
+ plt.legend()
34
+ plt.grid(True)
35
+ plt.tight_layout()
36
+ plt.savefig(graph_path)
37
+
38
+ if show_plot:
39
+ plt.show()
40
+
41
+ plt.close()
42
+ return graph_path
sd_simple_kes_v3_fix?/requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ pyyaml
3
+ numpy
4
+ matplotlib
sd_simple_kes_v3_fix?/schedulers/__pycache__/euler_advanced_scheduler.cpython-310.pyc ADDED
Binary file (2.02 kB). View file
 
sd_simple_kes_v3_fix?/schedulers/__pycache__/exponential_advanced_scheduler.cpython-310.pyc ADDED
Binary file (1.43 kB). View file
 
sd_simple_kes_v3_fix?/schedulers/__pycache__/geometric_advanced_scheduler.cpython-310.pyc ADDED
Binary file (1.65 kB). View file
 
sd_simple_kes_v3_fix?/schedulers/__pycache__/harmonic_advanced_scheduler.cpython-310.pyc ADDED
Binary file (1.43 kB). View file
 
sd_simple_kes_v3_fix?/schedulers/__pycache__/karras_advanced_scheduler.cpython-310.pyc ADDED
Binary file (1.53 kB). View file
 
sd_simple_kes_v3_fix?/schedulers/__pycache__/logarithmic_advanced_scheduler.cpython-310.pyc ADDED
Binary file (1.5 kB). View file
 
sd_simple_kes_v3_fix?/schedulers/__pycache__/shared.cpython-310.pyc ADDED
Binary file (3.83 kB). View file
 
sd_simple_kes_v3_fix?/schedulers/euler_advanced_scheduler.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from modules.sd_simple_kes_v3.schedulers.shared import apply_last_tail, apply_decay_tail, valid_decay_patterns, valid_decay_modes, blend_decay_tail, replace_tail
3
+ import math
4
+
5
+ def get_sigmas_euler(steps, sigma_min, sigma_max, device='cpu'):
6
+ """
7
+ Sigma schedule designed to work well with Euler sampling.
8
+ Logarithmic spacing with a smooth transition.
9
+ """
10
+ def _to_tensor(val, device):
11
+ return val.to(device) if isinstance(val, torch.Tensor) else torch.tensor(val, device=device)
12
+
13
+ # Convert sigma_min and sigma_max to tensors safely
14
+ sigma_min = _to_tensor(sigma_min, device)
15
+ sigma_max = _to_tensor(sigma_max, device)
16
+
17
+ sigmas = torch.exp(torch.linspace(math.log(sigma_max.item()), math.log(sigma_min.item()), steps, device=device))
18
+
19
+ tails = None
20
+ decay = None
21
+
22
+
23
+ return tails, decay, sigmas
24
+
25
+ def get_sigmas_euler_advanced(steps, sigma_min, sigma_max, device='cpu', blend_factor=0.5, decay_pattern=None, decay_mode=None, tail_steps=None):
26
+ def _to_tensor(val, device):
27
+ return val.to(device) if isinstance(val, torch.Tensor) else torch.tensor(val, device=device)
28
+ ramp = torch.linspace(0, 1, steps, device=device)
29
+ sigmas_exp = torch.exp(torch.linspace(math.log(sigma_max), math.log(sigma_min), steps, device=device))
30
+ sigmas_karras = (sigma_max ** (1 / 7) + ramp * (sigma_min ** (1 / 7) - sigma_max ** (1 / 7))) ** 7
31
+
32
+ # Blend Karras and Exponential schedules for a hybrid Euler-friendly progression
33
+ sigmas = (1 - blend_factor) * sigmas_exp + blend_factor * sigmas_karras
34
+ tails = None
35
+ decay = None
36
+
37
+ if decay_pattern:
38
+ if decay_pattern in valid_decay_patterns:
39
+ tails = apply_last_tail(sigmas, device, decay_pattern)
40
+ elif decay_pattern not in valid_decay_patterns:
41
+ print(f"[Warning] decay_pattern: {decay_pattern} not in valid decay patterns: {valid_decay_patterns}")
42
+ if decay_mode:
43
+ if decay_mode in valid_decay_modes:
44
+ if decay_mode == 'append': # <--- ИСПРАВЛЕНО
45
+ sigmas = apply_decay_tail(sigmas, device, decay_pattern)
46
+ elif decay_mode == 'blend': # <--- ИСПРАВЛЕНО
47
+ sigmas = blend_decay_tail(sigmas, device, decay_pattern, tail_steps)
48
+ elif decay_mode == 'replace': # <--- ИСПРАВЛЕНО
49
+ sigmas = replace_tail(sigmas, device, decay_pattern, tail_steps)
50
+ elif decay_mode not in valid_decay_modes:
51
+ print(f"[Warning] decay_mode: {decay_mode} not in valid decay modes: {valid_decay_modes}")
52
+
53
+ return tails, decay, sigmas
sd_simple_kes_v3_fix?/schedulers/exponential_advanced_scheduler.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ from modules.sd_simple_kes_v3.schedulers.shared import apply_last_tail, apply_decay_tail, valid_decay_patterns, valid_decay_modes, blend_decay_tail, replace_tail
4
+
5
+
6
+ def get_sigmas_exponential(steps, sigma_min, sigma_max, device='cpu', decay_pattern=None, decay_mode=None, tail_steps=None):
7
+ """Constructs an exponential noise schedule."""
8
+
9
+ def _to_tensor(val, device):
10
+ return val.to(device) if isinstance(val, torch.Tensor) else torch.tensor(val, device=device)
11
+
12
+ # Convert sigma_min and sigma_max to tensors safely
13
+ sigma_min = _to_tensor(sigma_min, device)
14
+ sigma_max = _to_tensor(sigma_max, device)
15
+
16
+ tail_steps = tail_steps or 5
17
+
18
+
19
+ # Exponential progression (correct)
20
+ sigmas = torch.linspace(math.log(sigma_max.item()), math.log(sigma_min.item()), steps, device=device).exp()
21
+
22
+ tails = None
23
+ decay = None
24
+
25
+ if decay_pattern:
26
+ if decay_pattern in valid_decay_patterns:
27
+ tails = apply_last_tail(sigmas, device, decay_pattern)
28
+ elif decay_pattern not in valid_decay_patterns:
29
+ print(f"[Warning] decay_pattern: {decay_pattern} not in valid decay patterns: {valid_decay_patterns}")
30
+ if decay_mode:
31
+ if decay_mode in valid_decay_modes:
32
+ if decay_mode == 'append': # <--- ИСПРАВЛЕНО
33
+ sigmas = apply_decay_tail(sigmas, device, decay_pattern)
34
+ elif decay_mode == 'blend': # <--- ИСПРАВЛЕНО
35
+ sigmas = blend_decay_tail(sigmas, device, decay_pattern, tail_steps)
36
+ elif decay_mode == 'replace': # <--- ИСПРАВЛЕНО
37
+ sigmas = replace_tail(sigmas, device, decay_pattern, tail_steps)
38
+ elif decay_mode not in valid_decay_modes:
39
+ print(f"[Warning] decay_mode: {decay_mode} not in valid decay modes: {valid_decay_modes}")
40
+
41
+ return tails, decay, sigmas
sd_simple_kes_v3_fix?/schedulers/geometric_advanced_scheduler.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from modules.sd_simple_kes_v3.schedulers.shared import apply_last_tail, apply_decay_tail, valid_decay_patterns, valid_decay_modes, blend_decay_tail, replace_tail
3
+
4
+
5
+
6
+ def get_sigmas_geometric(steps, sigma_min, sigma_max, device='cpu', decay_pattern=None, decay_mode=None, tail_steps=None):
7
+ def _to_tensor(val, device):
8
+ return val.to(device) if isinstance(val, torch.Tensor) else torch.tensor(val, device=device)
9
+
10
+ # Convert sigma_min and sigma_max to tensors safely
11
+ sigma_min = _to_tensor(sigma_min, device)
12
+ sigma_max = _to_tensor(sigma_max, device)
13
+
14
+ sigmas = [sigma_max]
15
+ tail_steps = tail_steps or 5
16
+
17
+ for step in range(1, steps):
18
+ if len(sigmas) >= 2:
19
+ deltas = torch.abs(torch.tensor(sigmas[:-1]) - torch.tensor(sigmas[1:]))
20
+ avg_delta = torch.mean(deltas).item()
21
+ else:
22
+ avg_delta = sigmas[-1].item() * 0.1
23
+
24
+ last_sigma = sigmas[-1]
25
+ dynamic_decay_rate = max(1 - (avg_delta / (last_sigma + 1e-5)), 0.85) # Clamp for stability
26
+
27
+ next_sigma = max(last_sigma * dynamic_decay_rate, sigma_min.item())
28
+ sigmas.append(next_sigma)
29
+
30
+ sigmas = torch.tensor(sigmas, device=device)
31
+
32
+ tails = None
33
+ decay = None
34
+
35
+ if decay_pattern:
36
+ if decay_pattern in valid_decay_patterns:
37
+ tails = apply_last_tail(sigmas, device, decay_pattern)
38
+ elif decay_pattern not in valid_decay_patterns:
39
+ print(f"[Warning] decay_pattern: {decay_pattern} not in valid decay patterns: {valid_decay_patterns}")
40
+ if decay_mode:
41
+ if decay_mode in valid_decay_modes:
42
+ if decay_mode == 'append': # <--- ИСПРАВЛЕНО
43
+ sigmas = apply_decay_tail(sigmas, device, decay_pattern)
44
+ elif decay_mode == 'blend': # <--- ИСПРАВЛЕНО
45
+ sigmas = blend_decay_tail(sigmas, device, decay_pattern, tail_steps)
46
+ elif decay_mode == 'replace': # <--- ИСПРАВЛЕНО
47
+ sigmas = replace_tail(sigmas, device, decay_pattern, tail_steps)
48
+ elif decay_mode not in valid_decay_modes:
49
+ print(f"[Warning] decay_mode: {decay_mode} not in valid decay modes: {valid_decay_modes}")
50
+
51
+ return tails, decay, sigmas
52
+
53
+
54
+
55
+
56
+
57
+
58
+
59
+
60
+
sd_simple_kes_v3_fix?/schedulers/harmonic_advanced_scheduler.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from modules.sd_simple_kes_v3.schedulers.shared import apply_last_tail, apply_decay_tail, valid_decay_patterns, valid_decay_modes, blend_decay_tail, replace_tail
3
+
4
+
5
+ def get_sigmas_harmonic(steps, sigma_min, sigma_max, device='cpu', decay_pattern=None, decay_mode=None, tail_steps=None):
6
+ def _to_tensor(val, device):
7
+ return val.to(device) if isinstance(val, torch.Tensor) else torch.tensor(val, device=device)
8
+
9
+ # Convert sigma_min and sigma_max to tensors safely
10
+ sigma_min = _to_tensor(sigma_min, device)
11
+ sigma_max = _to_tensor(sigma_max, device)
12
+
13
+ sigmas = []
14
+ tail_steps = tail_steps or 5
15
+
16
+
17
+ # Harmonic decay calculation
18
+ for i in range(1, steps + 1):
19
+ next_sigma = sigma_max.item() / (i+1) # Harmonic sequence
20
+ next_sigma = max(next_sigma, sigma_min.item())
21
+ sigmas.append(next_sigma)
22
+
23
+ # Convert list to tensor
24
+ sigmas = torch.tensor(sigmas, device=device)
25
+
26
+ tails = None
27
+ decay = None
28
+
29
+ if decay_pattern:
30
+ if decay_pattern in valid_decay_patterns:
31
+ tails = apply_last_tail(sigmas, device, decay_pattern)
32
+ elif decay_pattern not in valid_decay_patterns:
33
+ print(f"[Warning] decay_pattern: {decay_pattern} not in valid decay patterns: {valid_decay_patterns}")
34
+ if decay_mode:
35
+ if decay_mode in valid_decay_modes:
36
+ if decay_mode == 'append': # <--- ИСПРАВЛЕНО
37
+ sigmas = apply_decay_tail(sigmas, device, decay_pattern)
38
+ elif decay_mode == 'blend': # <--- ИСПРАВЛЕНО
39
+ sigmas = blend_decay_tail(sigmas, device, decay_pattern, tail_steps)
40
+ elif decay_mode == 'replace': # <--- ИСПРАВЛЕНО
41
+ sigmas = replace_tail(sigmas, device, decay_pattern, tail_steps)
42
+ elif decay_mode not in valid_decay_modes:
43
+ print(f"[Warning] decay_mode: {decay_mode} not in valid decay modes: {valid_decay_modes}")
44
+
45
+ return tails, decay, sigmas
sd_simple_kes_v3_fix?/schedulers/karras_advanced_scheduler.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import linspace, tensor
3
+ from modules.sd_simple_kes_v3.schedulers.shared import apply_last_tail, apply_decay_tail, valid_decay_patterns, valid_decay_modes, blend_decay_tail, replace_tail
4
+
5
+
6
+ def get_sigmas_karras(steps, sigma_min, sigma_max, rho=7., device='cpu', decay_pattern=None, decay_mode = None, tail_steps=None):
7
+ """Constructs the noise schedule of Karras et al. (2022)."""
8
+ ramp = linspace(0, 1, steps, device=device)
9
+ tail_steps = tail_steps or 5
10
+ def _to_tensor(val, device):
11
+ return val.to(device) if isinstance(val, torch.Tensor) else torch.tensor(val, device=device)
12
+
13
+ sigma_min = _to_tensor(sigma_min, device)
14
+ sigma_max = _to_tensor(sigma_max, device)
15
+
16
+ min_inv_rho = sigma_min.item() ** (1 / rho)
17
+ max_inv_rho = sigma_max.item() ** (1 / rho)
18
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
19
+
20
+ tails = None
21
+ decay = None
22
+
23
+
24
+ if decay_pattern:
25
+ if decay_pattern in valid_decay_patterns:
26
+ tails = apply_last_tail(sigmas, device, decay_pattern)
27
+ elif decay_pattern not in valid_decay_patterns:
28
+ print(f"[Warning] decay_pattern: {decay_pattern} not in valid decay patterns: {valid_decay_patterns}")
29
+ if decay_mode:
30
+ if decay_mode in valid_decay_modes:
31
+ if decay_mode == 'append': # <--- ИСПРАВЛЕНО
32
+ sigmas = apply_decay_tail(sigmas, device, decay_pattern)
33
+ elif decay_mode == 'blend': # <--- ИСПРАВЛЕНО
34
+ sigmas = blend_decay_tail(sigmas, device, decay_pattern, tail_steps)
35
+ elif decay_mode == 'replace': # <--- ИСПРАВЛЕНО
36
+ sigmas = replace_tail(sigmas, device, decay_pattern, tail_steps)
37
+ elif decay_mode not in valid_decay_modes:
38
+ print(f"[Warning] decay_mode: {decay_mode} not in valid decay modes: {valid_decay_modes}")
39
+
40
+ return tails, decay, sigmas
sd_simple_kes_v3_fix?/schedulers/logarithmic_advanced_scheduler.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ from modules.sd_simple_kes_v3.schedulers.shared import apply_last_tail, apply_decay_tail, valid_decay_patterns, valid_decay_modes, blend_decay_tail, replace_tail
4
+
5
+
6
+
7
+ def get_sigmas_logarithmic(steps, sigma_min, sigma_max, device='cpu', decay_pattern=None, decay_mode=None, tail_steps=None):
8
+ def _to_tensor(val, device):
9
+ return val.to(device) if isinstance(val, torch.Tensor) else torch.tensor(val, device=device)
10
+
11
+ # Convert sigma_min and sigma_max to tensors safely
12
+ sigma_min = _to_tensor(sigma_min, device)
13
+ sigma_max = _to_tensor(sigma_max, device)
14
+
15
+ sigmas = []
16
+ tail_steps = tail_steps or 5
17
+
18
+ # Build the sigma list
19
+ for i in range(1, steps + 1):
20
+ next_sigma = sigma_max.item() - (math.log(i + 1) / math.log(steps + 1)) * (sigma_max.item() - sigma_min.item())
21
+ next_sigma = max(next_sigma, sigma_min.item())
22
+ sigmas.append(next_sigma)
23
+
24
+ # Convert list to tensor on the correct device
25
+ sigmas = torch.tensor(sigmas, device=device)
26
+
27
+ tails = None
28
+ decay = None
29
+
30
+ if decay_pattern:
31
+ if decay_pattern in valid_decay_patterns:
32
+ tails = apply_last_tail(sigmas, device, decay_pattern)
33
+ elif decay_pattern not in valid_decay_patterns:
34
+ print(f"[Warning] decay_pattern: {decay_pattern} not in valid decay patterns: {valid_decay_patterns}")
35
+ if decay_mode:
36
+ if decay_mode in valid_decay_modes:
37
+ if decay_mode == 'append': # <--- ИСПРАВЛЕНО
38
+ sigmas = apply_decay_tail(sigmas, device, decay_pattern)
39
+ elif decay_mode == 'blend': # <--- ИСПРАВЛЕНО
40
+ sigmas = blend_decay_tail(sigmas, device, decay_pattern, tail_steps)
41
+ elif decay_mode == 'replace': # <--- ИСПРАВЛЕНО
42
+ sigmas = replace_tail(sigmas, device, decay_pattern, tail_steps)
43
+ elif decay_mode not in valid_decay_modes:
44
+ print(f"[Warning] decay_mode: {decay_mode} not in valid decay modes: {valid_decay_modes}")
45
+
46
+ return tails, decay, sigmas
sd_simple_kes_v3_fix?/schedulers/shared.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import linspace, tensor
4
+
5
+
6
+ valid_decay_patterns = [
7
+ 'zero', 'geometric', 'harmonic', 'logarithmic',
8
+ 'extrapolate', 'fractional', 'exponential', 'linear'
9
+ ]
10
+ valid_decay_modes = [
11
+ 'append',
12
+ 'blend',
13
+ 'replace'
14
+ ]
15
+
16
+ def append_zero(x):
17
+ return torch.cat([x, x.new_zeros([1])])
18
+
19
+
20
+ def apply_last_tail(sigmas, device, decay_pattern='zero'):
21
+ """
22
+ Applies a single final sigma step using context-aware dynamic decay based on the sequence trend.
23
+ """
24
+ if decay_pattern == 'zero':
25
+ return append_zero(sigmas)
26
+
27
+ last_sigma = sigmas[-1]
28
+
29
+ if len(sigmas) >= 2:
30
+ deltas = torch.abs(sigmas[:-1] - sigmas[1:])
31
+ avg_delta = torch.mean(deltas).item()
32
+ else:
33
+ avg_delta = last_sigma.item() * 0.1
34
+
35
+ if decay_pattern == 'geometric':
36
+ dynamic_decay_rate = max(1 - (avg_delta / (last_sigma + 1e-5)), 0.85)
37
+ next_sigma = max(last_sigma * dynamic_decay_rate, 1e-5)
38
+
39
+ elif decay_pattern == 'harmonic':
40
+ next_sigma = max(last_sigma - avg_delta, 1e-5)
41
+
42
+ elif decay_pattern == 'logarithmic':
43
+ next_sigma = max(last_sigma - (avg_delta / math.log(len(sigmas) + 2)), 1e-5)
44
+
45
+ elif decay_pattern == 'extrapolate':
46
+ if len(sigmas) >= 2:
47
+ last_delta = sigmas[-2] - sigmas[-1]
48
+ else:
49
+ last_delta = avg_delta
50
+ next_sigma = max(last_sigma - last_delta, 1e-5)
51
+
52
+ elif decay_pattern == 'fractional':
53
+ next_sigma = max(last_sigma * 0.1, 1e-5)
54
+
55
+ elif decay_pattern == 'exponential':
56
+ next_sigma = max(last_sigma * math.exp(-avg_delta), 1e-5)
57
+
58
+ elif decay_pattern == 'linear':
59
+ next_sigma = max(last_sigma - (avg_delta * 0.5), 1e-5)
60
+
61
+ else:
62
+ raise ValueError(f"Unknown decay pattern: {decay_pattern}. Valid decay patterns are: {valid_decay_patterns}")
63
+
64
+ tail_tensor = torch.tensor([next_sigma], device=device)
65
+ return torch.cat([sigmas, tail_tensor])
66
+
67
+
68
+ def apply_decay_tail(sigmas, device, decay_pattern='geometric', tail_steps=5):
69
+ """
70
+ Applies a context-aware multi-step decay tail based on the progression of the entire sigma sequence.
71
+ """
72
+ tail = []
73
+
74
+ if decay_pattern == 'zero':
75
+ return append_zero(sigmas)
76
+
77
+ if len(sigmas) >= 2:
78
+ deltas = torch.abs(sigmas[:-1] - sigmas[1:])
79
+ avg_delta = torch.mean(deltas).item()
80
+ else:
81
+ avg_delta = sigmas[-1].item() * 0.1
82
+
83
+ last_sigma = sigmas[-1]
84
+
85
+ for step in range(tail_steps):
86
+ if decay_pattern == 'geometric':
87
+ dynamic_decay_rate = max(1 - (avg_delta / (last_sigma + 1e-5)), 0.85)
88
+ next_sigma = max(last_sigma * dynamic_decay_rate, 1e-5)
89
+
90
+ elif decay_pattern == 'harmonic':
91
+ next_sigma = max(last_sigma - (avg_delta / (step + 1)), 1e-5)
92
+
93
+ elif decay_pattern == 'logarithmic':
94
+ next_sigma = max(last_sigma - (avg_delta / math.log(len(sigmas) + step + 2)), 1e-5)
95
+
96
+ elif decay_pattern == 'extrapolate':
97
+ if len(sigmas) >= 2:
98
+ last_delta = sigmas[-2] - sigmas[-1]
99
+ else:
100
+ last_delta = avg_delta
101
+ next_sigma = max(last_sigma - last_delta, 1e-5)
102
+
103
+ elif decay_pattern == 'fractional':
104
+ next_sigma = max(last_sigma * 0.1, 1e-5)
105
+
106
+ elif decay_pattern == 'exponential':
107
+ next_sigma = max(last_sigma * math.exp(-avg_delta * (step + 1)), 1e-5)
108
+
109
+ elif decay_pattern == 'linear':
110
+ next_sigma = max(last_sigma - (avg_delta * 0.5 * (step + 1)), 1e-5)
111
+
112
+ else:
113
+ raise ValueError(f"Unknown decay pattern: {decay_pattern}. Valid decay patterns are: {valid_decay_patterns}")
114
+
115
+ tail.append(next_sigma)
116
+ last_sigma = next_sigma
117
+
118
+ tail_tensor = torch.tensor(tail, device=device)
119
+ return torch.cat([sigmas, tail_tensor])
120
+
121
+
122
+
123
+
124
+ def blend_decay_tail(sigmas, device, decay_pattern='geometric', tail_steps=5):
125
+ """
126
+ Applies in-place blending on the last N steps using decay patterns.
127
+ """
128
+ for i in range(1, tail_steps + 1):
129
+ idx = -i
130
+ base_sigma = sigmas[idx]
131
+
132
+ if len(sigmas) >= 2:
133
+ deltas = torch.abs(sigmas[:-1] - sigmas[1:])
134
+ avg_delta = torch.mean(deltas).item()
135
+ else:
136
+ avg_delta = base_sigma.item() * 0.1
137
+ if decay_pattern == 'zero':
138
+ sigmas[idx] = 0.0
139
+ continue
140
+ if decay_pattern == 'geometric':
141
+ dynamic_decay_rate = max(1 - (avg_delta / (base_sigma + 1e-5)), 0.85)
142
+ new_sigma = max(base_sigma * dynamic_decay_rate, 1e-5)
143
+
144
+ elif decay_pattern == 'harmonic':
145
+ new_sigma = max(base_sigma - (avg_delta / i), 1e-5)
146
+
147
+ elif decay_pattern == 'logarithmic':
148
+ new_sigma = max(base_sigma - (avg_delta / math.log(len(sigmas) + 2 - i)), 1e-5)
149
+
150
+ elif decay_pattern == 'extrapolate':
151
+ if len(sigmas) >= 2:
152
+ last_delta = sigmas[-2] - sigmas[-1]
153
+ else:
154
+ last_delta = avg_delta
155
+ new_sigma = max(base_sigma - last_delta, 1e-5)
156
+
157
+ elif decay_pattern == 'fractional':
158
+ new_sigma = max(base_sigma * 0.1, 1e-5)
159
+
160
+ elif decay_pattern == 'exponential':
161
+ new_sigma = max(base_sigma * math.exp(-avg_delta), 1e-5)
162
+
163
+ elif decay_pattern == 'linear':
164
+ new_sigma = max(base_sigma - (avg_delta * 0.5), 1e-5)
165
+
166
+ else:
167
+ raise ValueError(f"Unknown decay pattern: {decay_pattern}. Valid decay patterns are: {valid_decay_patterns}")
168
+
169
+ sigmas[idx] = (base_sigma + new_sigma) / 2
170
+
171
+ return sigmas
172
+
173
+
174
+ def replace_tail(sigmas, device, decay_pattern='geometric', tail_steps=5):
175
+ available_steps = len(sigmas)
176
+
177
+ # Clamp tail_steps to the number of available steps
178
+ if tail_steps >= available_steps:
179
+ print(f"[Replace Tail] Requested {tail_steps} steps, but only {available_steps} available. Adjusting to {available_steps - 1} steps.")
180
+ tail_steps = available_steps - 1 # Ensure we leave at least one sigma
181
+
182
+ sigmas = sigmas[:-tail_steps] # Remove the last N steps
183
+ return apply_decay_tail(sigmas, device, decay_pattern, tail_steps)
184
+
sd_simple_kes_v3_fix?/simple_kes_v3.py ADDED
@@ -0,0 +1,2002 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.sd_simple_kes_v3.get_sigmas import scheduler_registry
2
+ from modules.sd_simple_kes_v3.validate_config import validate_config
3
+ from modules.sd_simple_kes_v3.plot_sigma_sequence import plot_sigma_sequence
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import logging
7
+ import os
8
+ import yaml
9
+ import random
10
+ from datetime import datetime
11
+ import warnings
12
+ import math
13
+ from typing import Optional
14
+ import json
15
+ import numpy as np
16
+ import hashlib
17
+ import glob
18
+ import re
19
+ import inspect
20
+ import copy
21
+
22
+
23
+ def simple_kes_scheduler_v3(n: int, sigma_min: float, sigma_max: float, device: torch.device) -> torch.Tensor:
24
+ scheduler = SimpleKEScheduler(n=n, sigma_min=sigma_min, sigma_max=sigma_max, device=device)
25
+ return scheduler()
26
+
27
+ class SharedLogger:
28
+ def __init__(self, debug=False):
29
+ self.debug = debug
30
+ self.log_buffer = []
31
+ self.prepass_log_buffer=[]
32
+
33
+ def log(self, message):
34
+ if self.debug:
35
+ self.log_buffer.append(message)
36
+ def prepass_log(self, message):
37
+ if self.debug:
38
+ self.prepass_log_buffer.append(message)
39
+
40
+ class SimpleKEScheduler:
41
+ """
42
+ SimpleKEScheduler
43
+ ------------------
44
+ A hybrid scheduler that combines Karras-style sigma sampling
45
+ with exponential decay and blending controls. Supports parameterized
46
+ customization for use in advanced diffusion pipelines.
47
+
48
+ Parameters:
49
+ - steps (int): Number of inference steps.
50
+ - device (torch.device): Target device (e.g. 'cuda').
51
+ - config (dict): Scheduler-specific configuration options.
52
+
53
+ Usage:
54
+ scheduler = SimpleKEScheduler(steps=30, device='cuda', config=config_dict)
55
+ sigmas = scheduler.get_sigmas()
56
+ """
57
+
58
+ def __init__(self, n: int, sigma_min: Optional[float] = None, sigma_max: Optional[float] = None, device: torch.device = "cpu", logger=None, **kwargs)->torch.Tensor:
59
+ self.steps = n if n is not None else 10
60
+ self.original_steps = n
61
+ self.device = torch.device(device if isinstance(device, str) else device)
62
+ self.sigma_min = sigma_min
63
+ self.sigma_max = sigma_max
64
+ self.scheduler_registry = scheduler_registry
65
+ self.RANDOMIZATION_TYPE_ALIASES = {
66
+ 'symmetric': 'symmetric', 'sym': 'symmetric', 's': 'symmetric',
67
+ 'asymmetric': 'asymmetric', 'assym': 'asymmetric', 'a': 'asymmetric',
68
+ 'logarithmic': 'logarithmic', 'log': 'logarithmic', 'l': 'logarithmic',
69
+ 'exponential': 'exponential', 'exp': 'exponential', 'e': 'exponential'
70
+ }
71
+ self._config_schema = {
72
+ 'min_visual_sigma': (int, 10),
73
+ 'safety_minimum_stop_step': (int, 10),
74
+ 'auto_tail_smoothing': (bool, False),
75
+ 'auto_stabilization_sequence': (list, [
76
+ 'smooth_interpolation', 'append_tail', 'blend_tail', 'apply_decay', 'progressive_decay'
77
+ ]),
78
+ 'sharpen_variance_threshold': (float, 0.01),
79
+ 'sharpen_last_n_steps': (int, 10),
80
+ 'decay_pattern': (str, 'zero'),
81
+ 'sigma_save_subfolder': (str, 'saved_sigmas'),
82
+ 'load_sigma_cache': (bool, False),
83
+ 'save_sigma_cache': (bool, False),
84
+ 'graph_save_directory': (str, 'modules/sd_simple_kes_v3/image_generation_data'),
85
+ 'graph_save_enable': (bool, False),
86
+ 'exp_power': (int, 2),
87
+ 'recent_change_convergence_delta': (float, 0.02),
88
+ 'sigma_variance_scale': (float, 0.05),
89
+ 'allow_step_expansion': (bool, False),
90
+ 'sharpen_mode': (str, 'full'),
91
+ 'blend_midpoint': (float, 0.5),
92
+ 'early_stopping_method': (str, 'mean'),
93
+ 'save_prepass_sigmas': (bool, False),
94
+ 'global_randomize': (bool, False),
95
+ 'skip_prepass': (bool, False),
96
+ 'load_prepass_sigmas': (bool, False)
97
+ }
98
+ self._overrides = kwargs.copy() #Temporarily hold overrides from kwargs
99
+ default_config_path = os.path.abspath(os.path.normpath(os.path.join("modules", "sd_simple_kes_v3", "kes_config", "default_config.yaml")))
100
+ self.default_config = self._load_config(default_config_path)
101
+ user_config_path = os.path.abspath(os.path.normpath(os.path.join("modules", "sd_simple_kes_v3", "kes_config", "user_config.yaml")))
102
+ self.user_config = self._load_config(user_config_path)
103
+ self.config_data = {**self.default_config, **self.user_config}
104
+ self.config = self.config_data.copy()
105
+ self.settings = self.config.copy()
106
+ for key, value in self.settings.items():
107
+ setattr(self, key, value)
108
+ if self.global_randomize:
109
+ self.apply_global_randomization()
110
+ self.re_randomizable_keys = [
111
+ "sigma_min", "sigma_max", "start_blend", "end_blend", "sharpness",
112
+ "early_stopping_threshold",
113
+ "initial_step_size", "final_step_size",
114
+ "initial_noise_scale", "final_noise_scale",
115
+ "smooth_blend_factor", "step_size_factor", "noise_scale_factor", "rho"
116
+ ]
117
+ for key in self.re_randomizable_keys:
118
+ value = self.settings.get(key)
119
+ if value is None:
120
+ raise KeyError(f"[KEScheduler] Missing required setting: {key}")
121
+ setattr(self, key, value)
122
+ self.debug = self.settings.get('debug', False)
123
+ #setup logging
124
+ logger = SharedLogger(debug=kwargs.get('debug', False))
125
+ self.logger=logger
126
+ self.log = self.logger.log
127
+ self.prepass_log = self.logger.prepass_log
128
+ self._validate_config_types()
129
+ validate_config(self.config, logger=self.logger)
130
+ # Apply overrides from kwargs if present
131
+ for k, v in self._overrides.items():
132
+ if k in self.settings:
133
+ self.settings[k] = v
134
+ setattr(self, k, v)
135
+ self.auto_mode_enabled = self.settings.get('auto_tail_smoothing', False)
136
+ self.initialize_generation_filename()
137
+ self.relative_converged = False
138
+ self.max_converged = False
139
+ self.delta_converged = False
140
+ self.early_stop_triggered = False
141
+ self.sigma_cache = {}
142
+ self.BASE_DIR = os.path.dirname(os.path.abspath(__file__))
143
+ self.cache_dir = os.path.join(self.BASE_DIR, 'cache')
144
+ self.sigma_save_folder = os.path.join(self.cache_dir, self.sigma_save_subfolder)
145
+ self.blend_method_dict = self.settings.get('blend_methods', {
146
+ 'karras': {'weight': 1.0, 'decay_pattern': 'zero', 'decay_mode': 'append', 'tail_steps': 1},
147
+ 'exponential': {'weight': 1.0, 'decay_pattern': 'zero', 'decay_mode': 'append', 'tail_steps': 1}
148
+ })
149
+ self.blend_methods = list(self.blend_method_dict.keys())
150
+ self.blend_weights = [self.blend_method_dict[method]['weight'] for method in self.blend_methods]
151
+ self.loaded_sigmas = None
152
+ self.sigma_sequences = {}
153
+
154
+ self.schedule_type = None
155
+ self.suffix = None
156
+ self.ext = None
157
+ self._create_directories()
158
+ self._finalize_init()
159
+
160
+ def _create_directories(self):
161
+
162
+ os.makedirs(self.cache_dir, exist_ok=True)
163
+ os.makedirs(self.sigma_save_folder, exist_ok=True)
164
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
165
+ self.extras_log_filename = os.path.join(
166
+ self.settings.get('log_save_directory', 'modules/sd_simple_kes_v3/image_generation_data'),
167
+ f'all_extras_log_{timestamp}.txt'
168
+ )
169
+
170
+
171
+
172
+ def _finalize_init(self):
173
+
174
+ self.prepass_save_file = self.build_sigma_cache_filename(
175
+ steps=self.steps,
176
+ sigma_min=self.sigma_min,
177
+ sigma_max=self.sigma_max,
178
+ rho=self.rho,
179
+ schedule_type='karras',
180
+ decay_pattern=self.decay_pattern,
181
+ cache_dir=self.sigma_save_folder,
182
+ suffix='prepass',
183
+ ext = 'pt'
184
+ )
185
+ self.final_save_file = self.build_sigma_cache_filename(
186
+ steps=self.steps,
187
+ sigma_min=self.sigma_min,
188
+ sigma_max=self.sigma_max,
189
+ rho=self.rho,
190
+ schedule_type='karras',
191
+ decay_pattern=self.decay_pattern,
192
+ cache_dir=self.sigma_save_folder,
193
+ suffix='final',
194
+ ext = 'pt'
195
+ )
196
+ self.load_blend_method_sigmas()
197
+ def _load_config(self, config_path, **kwargs):
198
+ self.logger = SharedLogger(debug=kwargs.get('debug', False))
199
+
200
+
201
+ try:
202
+ with open(config_path, 'r', encoding='utf-8') as f:
203
+ user_config = yaml.safe_load(f)
204
+ return user_config or {} # Always return a dict, even if empty
205
+ except FileNotFoundError:
206
+ self.logger.log(f"Config file not found: {config_path}. Using empty config.")
207
+ return {}
208
+ except yaml.YAMLError as e:
209
+ self.logger.log(f"Error loading config file: {e}")
210
+ return {}
211
+
212
+
213
+ def _validate_config_types(self):
214
+ '''
215
+ Both corrects self.settings with an updated validated config, and also writes a corrected_user_config file with the correct types
216
+ '''
217
+ validated_settings = {}
218
+ corrected_lines = []
219
+
220
+ corrected_lines.append("# Corrected User Config (Invalid entries auto-corrected)\n")
221
+
222
+ for key, (expected_type, default_value) in self._config_schema.items():
223
+ value = self.settings.get(key, default_value)
224
+ if isinstance(value, expected_type):
225
+ validated_settings[key] = value
226
+ corrected_lines.append(f"{key}: {value}")
227
+
228
+ else:
229
+ self.log(f"[Config Warning] Invalid type for '{key}': Expected {expected_type.__name__}, got {type(value).__name__}. Using default: {default_value}")
230
+ validated_settings[key] = default_value
231
+ corrected_lines.append(f"{key}: {default_value} # Invalid type: {type(value).__name__}, replaced with default")
232
+
233
+ # Save the corrected config with comments
234
+ with open('corrected_user_config.yaml', 'w', encoding='utf-8') as f:
235
+ f.write('\n'.join(corrected_lines))
236
+
237
+ self.settings.update(validated_settings)
238
+
239
+ for key, value in self.settings.items():
240
+ setattr(self, key, value)
241
+
242
+
243
+ def _log_extras_to_file(self, all_extras):
244
+ """
245
+ Logs the extras returned by each scheduler to a dedicated 'all_extras' log file.
246
+
247
+ This method iterates through the list of extras for each blend method and writes them
248
+ to a separate log file for easier tracking, debugging, and future analysis.
249
+
250
+ Parameters:
251
+ ----------
252
+ all_extras : list
253
+ A list of extras returned by each scheduler, aligned with the blend_methods list.
254
+ Each item in the list corresponds to the extras provided by a specific scheduler.
255
+
256
+ Notes:
257
+ -----
258
+ - If extras are present, they are logged under their respective scheduler names.
259
+ - If extras contain complex objects, the method attempts to serialize them using JSON.
260
+ - Non-serializable extras are logged as raw text.
261
+
262
+ Purpose:
263
+ -------
264
+ This log file is intended for developers to track additional outputs that are not
265
+ directly part of the sigma, tails, or decay sequences but may be useful for diagnostics,
266
+ metadata, or advanced scheduler behaviors.
267
+ """
268
+ with open(self.extras_log_filename, 'a', encoding='utf-8') as f:
269
+ f.write("\n=== New Scheduler Extras ===\n")
270
+ for method, extras in zip(self.blend_methods, all_extras):
271
+ if extras:
272
+ try:
273
+ f.write(f"\nScheduler: {method}\n")
274
+ f.write(json.dumps(extras, indent=2))
275
+ f.write("\n")
276
+ except TypeError:
277
+ f.write(f"\nScheduler: {method}\n")
278
+ f.write(f"Extras (non-serializable): {extras}\n")
279
+ f.write("\n============================\n")
280
+ def __call__(self):
281
+ # First pass: Run prepass to determine predicted_stop_step
282
+ if not self.skip_prepass:
283
+ self.prepass_compute_sigmas(
284
+ steps=self.steps,
285
+ sigma_min=self.sigma_min,
286
+ sigma_max=self.sigma_max,
287
+ rho=self.rho,
288
+ device=self.device,
289
+ skip_prepass=self.skip_prepass
290
+ )
291
+
292
+
293
+ if self.load_prepass_sigmas:
294
+ self.generate_sigmas_schedule(mode='prepass')
295
+
296
+ if self.load_sigma_cache:
297
+ self.generate_sigmas_schedule(mode='final')
298
+
299
+ else:
300
+ # Build sigma sequence directly (without prepass)
301
+ self.config_values()
302
+ self.generate_sigmas_schedule()
303
+
304
+ if self.blending_mode == 'default':
305
+ self.blend_sigma_sequence(
306
+ sigmas_karras= self.scheduler_registry.get('karras')(
307
+ steps=self.steps,
308
+ sigma_min=self.sigma_min,
309
+ sigma_max=self.sigma_max,
310
+ device=self.device,
311
+ decay_pattern=self.decay_pattern
312
+ )[2],
313
+ sigmas_exponential=self.scheduler_registry.get('exponential')(
314
+ steps=self.steps,
315
+ sigma_min=self.sigma_min,
316
+ sigma_max=self.sigma_max,
317
+ device=self.device,
318
+ decay_pattern=self.decay_pattern
319
+ )[2],
320
+ pre_pass=False,
321
+ blend_methods=self.blend_methods,
322
+ blend_weights=self.blend_weights
323
+ )
324
+
325
+ else:
326
+ # For multi-method blending
327
+ self.blend_sigma_sequence(
328
+ sigmas_karras=None, # Not used in non-default mode
329
+ sigmas_exponential=None, # Not used in non-default mode
330
+ pre_pass=False,
331
+ blend_methods=self.blend_methods,
332
+ blend_weights=self.blend_weights
333
+ )
334
+
335
+ sigmas = self.compute_sigmas(steps=self.steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max, rho=self.rho, device=self.device)
336
+
337
+
338
+ # Safety checks
339
+ if torch.isnan(sigmas).any():
340
+ raise ValueError("[SimpleKEScheduler] NaN detected in sigmas")
341
+ if torch.isinf(sigmas).any():
342
+ raise ValueError("[SimpleKEScheduler] Inf detected in sigmas")
343
+ if (sigmas <= 0).all():
344
+ raise ValueError("[SimpleKEScheduler] All sigma values are <= 0")
345
+ if (sigmas > 1000).all():
346
+ raise ValueError("[SimpleKEScheduler] Sigma values are extremely large — might explode the model")
347
+
348
+ # Save logs to file
349
+ if self.debug:
350
+ self.save_generation_settings()
351
+
352
+ return sigmas
353
+
354
+ def _safe_sigma_loader(self, cache_key):
355
+ cache_folder = self.sigma_save_folder
356
+
357
+ # Check if the folder exists and has files
358
+ if not os.path.exists(cache_folder) or not os.listdir(cache_folder):
359
+ self.log(f"[Cache Check] Cache folder {cache_folder} is empty or missing. Skipping load.")
360
+ return None # Signal to recompute
361
+
362
+ # Check if matching file exists
363
+ matching_files = [f for f in os.listdir(cache_folder) if cache_key in f and f.endswith('.pt')]
364
+
365
+ if not matching_files:
366
+ self.log(f"[Cache Check] No matching cache file found for key: {cache_key}. Skipping load.")
367
+ return None # Signal to recompute
368
+
369
+ # If matching file found, load it
370
+ filename = os.path.join(cache_folder, matching_files[0])
371
+ self.log(f"[Cache Hit] Loading sigma cache from: {filename}")
372
+ loaded_data = torch.load(filename, map_location=self.device)
373
+ return loaded_data['sigma_values'].to(self.device)
374
+
375
+ def call_scheduler(self, method_name, *args, **kwargs):
376
+ sigma_sequence = getattr(self, f"sigmas_{method_name}")
377
+ if sigma_sequence is None:
378
+ self.log(f"No sigma sequence found for method: {method_name}")
379
+ return None
380
+ return sigma_sequence
381
+
382
+ def is_sigma_randomized(self):
383
+ return (
384
+ self.settings.get('sigma_min_rand', False) or
385
+ self.settings.get('sigma_max_rand', False) or
386
+ self.settings.get('rho_rand', False) or
387
+ self.settings.get('sigma_max_enable_randomization_type', False) or
388
+ self.settings.get('sigma_min_enable_randomization_type', False) or
389
+ self.settings.get('rho_enable_randomization_type', False)
390
+ )
391
+
392
+
393
+ def save_sigmas_as_csv(self, sigmas, filename):
394
+ np.savetxt(filename, sigmas.cpu().numpy(), delimiter=",")
395
+
396
+ def build_sigma_cache_filename(self, steps, sigma_min, sigma_max, rho=None, schedule_type='karras', decay_pattern='zero', cache_dir=r'modules\sd_simple_kes_v3\cache', suffix=None, ext = None or 'txt'):
397
+ if cache_dir is None:
398
+ cache_dir = r'modules\sd_simple_kes_v3\cache'
399
+ if schedule_type == 'karras':
400
+ base_filename = f'sigma_{schedule_type}_{steps}steps_rho{rho}_min{sigma_min}_max{sigma_max}_{decay_pattern}'
401
+ else:
402
+ base_filename = f'sigma_{schedule_type}_{steps}steps_min{sigma_min}_max{sigma_max}_{decay_pattern}'
403
+
404
+ # If a suffix is provided, versioning applies
405
+ if suffix:
406
+ base_filename += f'_{suffix}'
407
+ version = self.get_next_version_number(cache_dir, base_filename)
408
+ if ext:
409
+ version = self.get_next_version_number(cache_dir, base_filename, ext)
410
+ filename = f'{version:03d}_{base_filename}.{ext}'
411
+ else:
412
+ # No versioning if suffix is not provided
413
+ filename = f'{base_filename}.{ext}'
414
+
415
+ return os.path.join(cache_dir, filename)
416
+
417
+ def get_next_version_number(self, cache_dir, base_filename,ext=None):
418
+ pattern = os.path.join(cache_dir, f'*_{base_filename}')
419
+ if ext:
420
+ pattern= os.path.join(cache_dir, f'*_{base_filename}.{ext}')
421
+ existing_files = glob.glob(pattern)
422
+
423
+ version_numbers = []
424
+ for file in existing_files:
425
+ match = re.search(r'(\d{3})_' + re.escape(base_filename), os.path.basename(file))
426
+ if match:
427
+ version_numbers.append(int(match.group(1)))
428
+
429
+ if version_numbers:
430
+ return max(version_numbers) + 1
431
+ else:
432
+ return 1
433
+
434
+ def get_sigma_with_cache(self, steps, sigma_min, sigma_max, rho=7.0, device='cpu',
435
+ schedule_type='karras', decay_pattern=None, cache_dir=None, cache_file=None,
436
+ suffix=None, ext=None, mode=None, cache_key = None):
437
+ self.steps = steps
438
+ self.sigma_min = sigma_min
439
+ self.sigma_max = sigma_max
440
+ self.rho = rho
441
+ self.device = device
442
+ self.schedule_type = schedule_type
443
+ self.decay_pattern = decay_pattern
444
+ self.cache_dir = cache_dir
445
+ self.cache_file = cache_file
446
+ self.suffix = suffix
447
+ self.ext = ext
448
+ self.mode = mode
449
+ self.cache_key = cache_key
450
+
451
+ # Try to retrieve from in-memory cache first
452
+ cached_sigmas = self.get_sigma_from_cache(cache_key)
453
+
454
+ if cached_sigmas is not None:
455
+ return cached_sigmas
456
+
457
+ # If sigma is randomized → always generate new
458
+ if self.is_sigma_randomized():
459
+ _, _, _, sigmas = self._generate_sigmas(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern)
460
+ self.sigma_cache[cache_key] = sigmas
461
+ return sigmas
462
+
463
+ # If nothing is loaded yet → generate and cache
464
+ if self.loaded_sigmas is None:
465
+ _, _, _, sigmas = self._generate_sigmas(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern)
466
+ self.loaded_sigmas = sigmas
467
+ self.sigma_cache[cache_key] = sigmas
468
+ return sigmas
469
+
470
+ # Handle file cache modes
471
+ if mode == 'prepass':
472
+ self.cache_file = self.prepass_save_file
473
+ elif mode == 'final':
474
+ self.cache_file = self.final_save_file
475
+ else:
476
+ self.cache_file = self.build_sigma_cache_filename(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, cache_dir)
477
+
478
+ # Load from file cache if enabled
479
+ if mode in ['prepass', 'final'] and self.load_prepass_sigmas:
480
+ loaded_sigmas = self.load_sigmas_with_hash_validation(
481
+ filename=self.cache_file,
482
+ steps=steps,
483
+ sigma_min=sigma_min,
484
+ sigma_max=sigma_max,
485
+ rho=rho,
486
+ device=device,
487
+ schedule_type=schedule_type,
488
+ decay_pattern=decay_pattern,
489
+ cache_key = cache_key
490
+ )
491
+
492
+ if loaded_sigmas is not None:
493
+ self.loaded_sigmas = loaded_sigmas
494
+ self.sigma_cache[cache_key] = loaded_sigmas
495
+ return loaded_sigmas.to(device)
496
+ else:
497
+ self.log("[Cache Recovery] Cache load failed. Recalculating sigma schedule.")
498
+ '''
499
+ # Cache miss → recalculate
500
+ self.log(f"[Cache Miss] Recalculating sigma schedule for: {self.cache_file}")
501
+ _, _, _, sigmas = self._generate_sigmas(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern)
502
+ self.sigma_cache[cache_key] = sigmas
503
+ '''
504
+ return sigmas
505
+
506
+
507
+ def load_sigmas_with_hash_validation(self, filename, steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, save_data=None, cache_key = None, suffix=None):
508
+ if self.load_prepass_sigmas:
509
+ if cache_key:
510
+ try:
511
+ loaded_data = torch.load(filename, map_location=self.device)
512
+ self.loaded_sigmas = loaded_data['sigma_values'].to(self.device)
513
+ loaded_hash = loaded_data['sigma_hash']
514
+
515
+ expected_hash = self.generate_sigma_hash(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, save_data, suffix)
516
+
517
+ if loaded_hash != expected_hash:
518
+ self.log(f"[Sigma Validator] Hash mismatch. Expected: {expected_hash}, Found: {loaded_hash}. Recalculating.")
519
+ return None # Return None to signal the scheduler to recalculate
520
+ else:
521
+ self.log(f"[Sigma Validator] Hash validated successfully for file: {filename}")
522
+ return self.loaded_sigmas
523
+
524
+ except Exception as e:
525
+ self.log("[Cache Recovery] Sigma cache invalid or missing. Recalculating sigmas.")
526
+ _, _, _, sigmas = self._generate_sigmas(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern)
527
+ return sigmas
528
+
529
+
530
+ def generate_sigma_hash(self, steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, save_data=None, suffix=None):
531
+ data_string = f'{steps}_{sigma_min}_{sigma_max}_{rho}_{device}_{schedule_type}_{decay_pattern}_{suffix}'
532
+ hash_object = hashlib.sha256(data_string.encode())
533
+ return hash_object.hexdigest()[:12] # Use first 12 characters for compact ID
534
+
535
+ def _generate_sigmas(self, steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern=None, decay_mode=None, tail_steps=None):
536
+ scheduler_func = self.scheduler_registry.get(schedule_type)
537
+
538
+ if scheduler_func is None:
539
+ raise ValueError(f"Unknown schedule type: {schedule_type}")
540
+
541
+ tails, decay, extras, sigmas = self.call_scheduler_function(
542
+ scheduler_func,
543
+ steps=steps,
544
+ sigma_min=sigma_min,
545
+ sigma_max=sigma_max,
546
+ rho=rho,
547
+ device=device,
548
+ decay_pattern=decay_pattern,
549
+ decay_mode=decay_mode,
550
+ tail_steps=tail_steps
551
+ )
552
+
553
+ return tails, decay, extras, sigmas
554
+
555
+
556
+
557
+ def initialize_generation_filename(self, folder=None, base_name="generation_log", ext="txt"):
558
+ """
559
+ Initialize the log filename early so it can be used throughout the process.
560
+ """
561
+ if folder is None:
562
+ folder = self.settings.get('log_save_directory', 'modules/sd_simple_kes_v3/image_generation_data')
563
+ folder = os.path.abspath(os.path.normpath(folder))
564
+
565
+ os.makedirs(folder, exist_ok=True)
566
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
567
+
568
+ self.log_filename = os.path.join(folder, f"{base_name}_{timestamp}.{ext}")
569
+
570
+ def save_generation_settings(self):
571
+ """
572
+ Save the generation log with configurable directory, base name, and extension.
573
+
574
+ Parameters:
575
+ - folder (str): Optional custom directory to save the log file.
576
+ - base_name (str): The base name for the file (default is 'generation_log').
577
+ - ext (str): The file extension to use (default is 'txt').
578
+ """
579
+ with open(self.log_filename, "w", encoding = 'utf-8') as f:
580
+ for line in self.logger.log_buffer:
581
+ f.write(f"{line}\n")
582
+ for line in self.logger.prepass_log_buffer:
583
+ f.write(f"{line}\n")
584
+ self.log(f"[SimpleKEScheduler] Generation settings saved to {self.log_filename}")
585
+
586
+ self.logger.log_buffer.clear()
587
+ self.logger.prepass_log_buffer.clear()
588
+
589
+ def save_image_plot(self, sigs, i):
590
+ graph_plot = plot_sigma_sequence(
591
+ self.sigs[:i + 1],
592
+ i,
593
+ self.log_filename,
594
+ self.graph_save_directory,
595
+ self.graph_save_enable
596
+ )
597
+ self.log(f"Sigma sequence plot saved to {graph_plot}")
598
+
599
+
600
+
601
+
602
+ def apply_global_randomization(self):
603
+ """Force randomization for all eligible settings by enabling _rand flags and re-randomizing values."""
604
+ # First pass: turn on all _rand flags if corresponding _rand_min/_rand_max exists
605
+ for key in list(self.settings.keys()):
606
+ if key.endswith("_rand_min") or key.endswith("_rand_max"):
607
+ base_key = key.rsplit("_rand_", 1)[0]
608
+ rand_flag_key = f"{base_key}_rand"
609
+ self.settings[rand_flag_key] = True
610
+ # Step 2: If global_randomize is active, re-randomize all eligible keys
611
+ if self.global_randomize:
612
+ if key not in self.settings:
613
+ raise KeyError(f"[apply_global_randomization] Missing required key: {key}")
614
+
615
+ default_val = self.settings[key]
616
+ randomized_val = self.get_random_or_default(key, default_val)
617
+ self.settings[key] = randomized_val
618
+ setattr(self, key, randomized_val)
619
+
620
+ def get_randomization_type(self, key_prefix):
621
+ """
622
+ Retrieves the randomization type for a given key, with fallback to 'asymmetric' if missing.
623
+ """
624
+ randomization_type_raw = self.settings.get(f'{key_prefix}_randomization_type', 'asymmetric')
625
+ randomization_type = self.RANDOMIZATION_TYPE_ALIASES.get(randomization_type_raw.lower(), 'asymmetric')
626
+ return randomization_type
627
+
628
+ def get_randomization_percent(self, key_prefix):
629
+ """
630
+ Retrieves the randomization percent for a given key, with fallback to 0.2 if missing.
631
+ """
632
+ return self.settings.get(f'{key_prefix}_randomization_percent', 0.2)
633
+
634
+
635
+ def get_random_between_min_max(self, key_prefix, default_value):
636
+ """
637
+ Picks a random value between _rand_min and _rand_max if _rand is True.
638
+ Otherwise, returns the base value.
639
+ """
640
+ randomize_flag = self.settings.get(f'{key_prefix}_rand', False)
641
+
642
+ if randomize_flag:
643
+ rand_min = self.settings.get(f'{key_prefix}_rand_min', default_value)
644
+ rand_max = self.settings.get(f'{key_prefix}_rand_max', default_value)
645
+
646
+ if rand_min == rand_max:
647
+ self.log(f"[Random Range] {key_prefix}: min and max are equal ({rand_min}). Using single value.")
648
+ return rand_min
649
+
650
+ value = random.uniform(rand_min, rand_max)
651
+ self.log(f"[Random Range] {key_prefix}: Picked random value {value} between {rand_min} and {rand_max}")
652
+ return value
653
+ else:
654
+ self.log(f"[Random Range] {key_prefix}: Randomization is OFF. Using base value {default_value}")
655
+ return default_value
656
+
657
+ def get_random_by_type(self, key_prefix, default_value):
658
+ randomization_enabled = self.settings.get(f'{key_prefix}_enable_randomization_type', False)
659
+
660
+ if not randomization_enabled:
661
+ self.log(f"[Randomization Type] {key_prefix}: Randomization type is OFF. Using base value {default_value}")
662
+ return default_value
663
+
664
+ randomization_type = self.get_randomization_type(key_prefix)
665
+ randomization_percent = self.get_randomization_percent(key_prefix)
666
+
667
+ if randomization_type == 'symmetric':
668
+ rand_min = default_value * (1 - randomization_percent)
669
+ rand_max = default_value * (1 + randomization_percent)
670
+ self.log(f"[Symmetric Randomization] {key_prefix}: Range {rand_min} to {rand_max}")
671
+
672
+ elif randomization_type == 'asymmetric':
673
+ rand_min = default_value * (1 - randomization_percent)
674
+ rand_max = default_value * (1 + (randomization_percent * 2))
675
+ self.log(f"[Asymmetric Randomization] {key_prefix}: Range {rand_min} to {rand_max}")
676
+
677
+ elif randomization_type == 'logarithmic':
678
+ rand_min = math.log(default_value * (1 - randomization_percent))
679
+ rand_max = math.log(default_value * (1 + randomization_percent))
680
+ value = math.exp(random.uniform(rand_min, rand_max))
681
+ self.log(f"[Logarithmic Randomization] {key_prefix}: Log-space randomization resulted in {value}")
682
+ return value
683
+
684
+ elif randomization_type == 'exponential':
685
+ rand_min = default_value * (1 - randomization_percent)
686
+ rand_max = default_value * (1 + randomization_percent)
687
+ base_value = random.uniform(rand_min, rand_max)
688
+ value = math.exp(base_value)
689
+ self.log(f"[Exponential Randomization] {key_prefix}: Randomized exponential value {value}")
690
+ return value
691
+
692
+ else:
693
+ self.log(f"[Randomization Type] {key_prefix}: Invalid randomization type {randomization_type}. Using base value.")
694
+ return default_value
695
+
696
+ value = random.uniform(rand_min, rand_max)
697
+
698
+ self.log(f"[Randomization Type] {key_prefix}: Randomized value {value}")
699
+ return value
700
+
701
+ def get_random_or_default(self, key_prefix, default_value):
702
+ """
703
+ Selects randomization method based on active flags:
704
+ - If both enabled → prioritize randomization type (or min/max if you prefer).
705
+ - If only one enabled → apply that one.
706
+ - If neither → return default value.
707
+ """
708
+ rand_type_enabled = self.settings.get(f'{key_prefix}_enable_randomization_type', False)
709
+ min_max_enabled = self.settings.get(f'{key_prefix}_rand', False)
710
+
711
+ if rand_type_enabled and min_max_enabled:
712
+ self.log(f"[Randomization Policy] Both min/max and randomization type enabled for {key_prefix}. System will prioritize randomization type.")
713
+ result_value = self.get_random_by_type(key_prefix, default_value)
714
+
715
+ elif rand_type_enabled:
716
+ result_value = self.get_random_by_type(key_prefix, default_value)
717
+ self.log(f"[Randomization] {key_prefix}: Applied randomization type. Final value: {result_value}")
718
+
719
+ elif min_max_enabled:
720
+ result_value = self.get_random_between_min_max(key_prefix, default_value)
721
+ self.log(f"[Randomization] {key_prefix}: Applied min/max randomization. Final value: {result_value}")
722
+
723
+ else:
724
+ result_value = default_value
725
+ self.log(f"[Randomization] {key_prefix}: No randomization applied. Using default value: {result_value}")
726
+
727
+ return result_value
728
+
729
+
730
+ def resolve_blend_weights(self, blend_weights, blending_style):
731
+ if blending_style == 'softmax':
732
+ # Softmax automatically normalizes weights per step
733
+ blend_weights = torch.tensor(blend_weights)
734
+ normalized_weights = torch.softmax(blend_weights, dim=0)
735
+ return normalized_weights.tolist()
736
+
737
+ elif blending_style == 'explicit':
738
+ # Return raw weights, will manually normalize in blending step
739
+ return blend_weights
740
+
741
+ else:
742
+ raise ValueError(f"Unknown blending_style: {blending_style}")
743
+
744
+ def extract_scalar(self, value):
745
+ if isinstance(value, torch.Tensor):
746
+ if value.numel() > 1:
747
+ return value.mean().item() # or first element
748
+ else:
749
+ return value.item()
750
+ return value # Already a float
751
+
752
+ def _call_legacy_mode(self, schedule_type):
753
+ # Validate schedule_type
754
+ if schedule_type not in ['karras', 'exponential']:
755
+ self.log(f"[Legacy Mode] Unsupported schedule_type: {schedule_type}")
756
+ return
757
+
758
+ # Dynamically set target attribute
759
+ target_attr = f"sigmas_{schedule_type}"
760
+
761
+ scheduler_func = self.scheduler_registry.get(schedule_type)
762
+
763
+ tails, decay, extras, sigmas = self.call_scheduler_function(
764
+ scheduler_func,
765
+ steps=self.steps,
766
+ sigma_min=self.sigma_min,
767
+ sigma_max=self.sigma_max,
768
+ rho=self.rho,
769
+ device=self.device,
770
+ decay_pattern=self.decay_pattern
771
+ )
772
+
773
+ # Assign to self.sigmas_karras or self.sigmas_exponential dynamically
774
+ setattr(self, target_attr, sigmas)
775
+
776
+ self.log(f"[Legacy Mode] Loaded sigma sequence for {schedule_type}. Assigned to self.{target_attr}")
777
+
778
+
779
+ def blend_sigma_sequence(self, sigmas_karras=None, sigmas_exponential=None, pre_pass=False, blend_methods=None, blend_weights=None):
780
+ # Filter out schedulers with zero weight
781
+ active_methods = [
782
+ method for method, config in self.blend_method_dict.items() if config.get('weight', 1.0) > 0.0
783
+ ]
784
+ # Fallback if all weights are zero
785
+ if not active_methods:
786
+ self.log("[Blend Config] All weights are zero. Falling back to default blend (karras + exponential).")
787
+ '''
788
+ # Values set in init, placed here for reference
789
+ self.blend_method_dict = {
790
+ 'karras': {'weight': 1.0, 'decay_pattern': 'zero', 'decay_mode': 'append', 'tail_steps': 1},
791
+ 'exponential': {'weight': 1.0, 'decay_pattern': 'zero', 'decay_mode': 'append', 'tail_steps': 1}
792
+ }
793
+ '''
794
+ active_methods = list(self.blend_method_dict.keys())
795
+
796
+ self.blend_methods = active_methods
797
+
798
+ # Rebuild sigma lists
799
+ self.blend_weights = [self.blend_method_dict[m]['weight'] for m in self.blend_methods]
800
+
801
+ # Edge Case: No active schedulers
802
+ if len(self.blend_methods) == 0:
803
+ raise ValueError("[SimpleKEScheduler] No active schedulers selected. Please check your blend configuration.")
804
+ #Should never happen because we default to init value for default if 0.
805
+
806
+ # Edge Case: Only one scheduler (direct usage)
807
+ if len(self.blend_methods) == 1:
808
+ self.log(f"[Blend] Only one active scheduler: {self.blend_methods[0]}. Skipping blending, using it directly.")
809
+ self.sigs = self.sigma_sequences[self.blend_methods[0]]['sigmas']
810
+ #return self.sigs # Do not Exit early!! Continue function
811
+
812
+ if len(self.blend_methods) == 2:
813
+ self.blending_mode = 'smooth_blend'
814
+ elif len(self.blend_methods) > 2:
815
+ self.blending_mode = 'weights'
816
+
817
+
818
+ if not self.allow_step_expansion and self.auto_mode_enabled:
819
+ self.auto_mode_enabled = False
820
+ self.log("[Auto Mode] Step expansion disallowed. Auto mode forcibly disabled.")
821
+
822
+ self.progress = torch.linspace(0, 1, len(self.sigs)).to(self.device)
823
+ self.blended_sigmas = []
824
+ self.change_log = []
825
+ self.relative_converged = False
826
+ self.max_converged = False
827
+ self.delta_converged = False
828
+ self.early_stop_triggered = False
829
+
830
+ """
831
+ Computes the blended sigma sequence using adaptive step sizes, dynamic blend factors,
832
+ and noise scaling across the progress of the diffusion process.
833
+
834
+ This method blends sigma values from the Karras and Exponential schedules using
835
+ a smooth, progress-dependent interpolation. It applies adaptive scaling based on
836
+ step size and noise scale factors to each sigma in the sequence.
837
+
838
+ Parameters:
839
+ -----------
840
+ sigs : torch.Tensor
841
+ A pre-allocated tensor where the computed sigma sequence will be stored.
842
+ This tensor must match the shape of the sigma schedules.
843
+
844
+ sigmas_karras : torch.Tensor
845
+ The sigma sequence generated using the Karras schedule.
846
+
847
+ sigmas_exponential : torch.Tensor
848
+ The sigma sequence generated using the Exponential schedule.
849
+
850
+ Returns:
851
+ --------
852
+ sigs : torch.Tensor
853
+ The final blended and scaled sigma sequence.
854
+
855
+ Notes:
856
+ ------
857
+ - This method is used in both the prepass and final pass of the scheduler.
858
+ - The progress tensor is computed linearly from 0 to 1 over the length of the sequence.
859
+ - The method uses class attributes for step size factors, blend factors, and noise scaling.
860
+ - This method modifies `sigs` in place.
861
+ """
862
+ if self.sigmas_exponential is None:
863
+ self._call_legacy_mode(schedule_type='exponential')
864
+
865
+ if self.sigmas_karras is None:
866
+ self._call_legacy_mode(schedule_type='karras')
867
+
868
+ self.prepass_blended_sigmas = []
869
+ self.blended_sigma = None
870
+ self.blended_sigmas=[]
871
+ for i in range(len(self.sigs)):
872
+ if self.step_progress_mode == "linear":
873
+ progress_value = self.progress[i]
874
+ elif self.step_progress_mode == "exponential":
875
+ progress_value = self.progress[i] ** self.exp_power
876
+ elif self.step_progress_mode == "logarithmic":
877
+ progress_value = torch.log1p(self.progress[i] * (torch.exp(torch.tensor(1.0)) - 1))
878
+ elif self.step_progress_mode == "sigmoid":
879
+ progress_value = 1 / (1 + torch.exp(-12 * (self.progress[i] - 0.5)))
880
+ else:
881
+ progress_value = self.progress[i] # Fallback to linear (previous version used)
882
+
883
+ self.dynamic_blend_factor = self.start_blend * (1 - self.progress[i]) + self.end_blend * self.progress[i]
884
+ self.smooth_blend = torch.sigmoid((self.dynamic_blend_factor - self.blend_midpoint) * self.smooth_blend_factor)
885
+ self.noise_scale = self.initial_noise_scale * (1 - self.progress[i]) + self.final_noise_scale * self.progress[i] * self.noise_scale_factor
886
+ self.step_size = self.initial_step_size * (1 - progress_value) + self.final_step_size * progress_value * self.step_size_factor
887
+ if self.blending_mode == 'default':
888
+ # Classic default: Karras + Exponential only
889
+ self.blended_sigma = self.sigmas_karras[i] * (1 - self.smooth_blend) + self.sigmas_exponential[i] * self.smooth_blend
890
+
891
+ if self.blending_mode == 'smooth_blend' or (self.blending_mode == 'auto' and len(self.blend_methods) == 2):
892
+ # Smooth blend between exactly two methods
893
+ sigma_seq_a = self.sigma_sequences[self.blend_methods[0]]['sigmas']
894
+ sigma_seq_b = self.sigma_sequences[self.blend_methods[1]]['sigmas']
895
+
896
+
897
+ self.blended_sigma = sigma_seq_a[i] * (1 - self.smooth_blend) + sigma_seq_b[i] * self.smooth_blend
898
+
899
+
900
+ elif self.blending_mode == 'weights' or (self.blending_mode == 'auto' and len(self.blend_methods) > 2):
901
+
902
+
903
+ if self.blend_weights is None:
904
+ self.blend_weights = [1.0] * len(self.all_sigmas)
905
+ if self.blending_style is None:
906
+ self.blending_style = 'soft_max'
907
+
908
+ # Resolve weights based on blending style
909
+ resolved_blend_weights = self.resolve_blend_weights(self.blend_weights, self.blending_style)
910
+
911
+ weighted_sum = sum(w * self.extract_scalar(s[i]) for w, s in zip(resolved_blend_weights, self.all_sigmas))
912
+
913
+
914
+ total_weight = sum(resolved_blend_weights)
915
+ self.blended_sigma = weighted_sum / total_weight
916
+
917
+ for s in self.all_sigmas:
918
+ self.log(f"[DEBUG]sigma sequence shape: {s.shape}")
919
+
920
+
921
+ self.sigs[i] = self.blended_sigma * self.step_size * self.noise_scale
922
+ self.change = torch.abs(self.sigs[i] - self.sigs[i - 1])
923
+ # Safely extract scalar for both tensor and float
924
+ self.change_log.append(self.extract_scalar(self.change))
925
+ relative_sigma_progress = (self.blended_sigma - self.sigs[-1].item()) / self.blended_sigma
926
+ recent_changes = torch.abs(torch.tensor(self.change_log[-5:]))
927
+ max_change = torch.max(recent_changes).item()
928
+ mean_change = torch.mean(recent_changes).item()
929
+ #percent_of_threshold = (max_change / self.early_stopping_threshold) * 100
930
+ self.delta_change = abs(max_change - mean_change)
931
+ #self.blended_sigmas.append(self.blended_sigma.item())
932
+ self.blended_sigmas.append(self.extract_scalar(self.blended_sigma))
933
+
934
+ # Check 1: Relative sigma progress
935
+ self.relative_converged = relative_sigma_progress < 0.05
936
+ # Check 2: Max recent sigma change
937
+ self.max_converged = max_change < self.early_stopping_threshold
938
+ # Check 3: Max-mean difference converged
939
+ self.delta_converged = self.delta_change < self.recent_change_convergence_delta
940
+
941
+ if pre_pass:
942
+ self.prepass_blended_sigmas=self.blended_sigmas.copy()
943
+ self.prepass_blended_sigma = self.blended_sigma
944
+ if i >= 2:
945
+
946
+ sigma_rate = abs(self.prepass_blended_sigmas[i] - self.prepass_blended_sigmas[i - 1])
947
+ previous_sigma_rate = abs(self.prepass_blended_sigmas[i - 1] - self.prepass_blended_sigmas[i - 2])
948
+ if sigma_rate > previous_sigma_rate:
949
+ self.prepass_log(f"Sigma decline is slowing down → possible plateau at step {i+1}.")
950
+
951
+ if i == 0:
952
+ self.prepass_log("\n--- Starting Pre-Pass Blending ---\n")
953
+ step_label = "Prepass First Step"
954
+ elif i == len(self.sigs) - 1:
955
+ step_label = "Prepass Last Step"
956
+ else:
957
+ step_label = None
958
+
959
+ if step_label:
960
+ self.prepass_log(f"[{step_label} - Step {i}/{len(self.sigs)}] Prepass Blended Sigma: {self.prepass_blended_sigma:.6f}, Final Sigma: {self.sigs[i]:.6f}")
961
+ self.prepass_log(f"{step_label} Delta Converged: {self.delta_converged} delta_change: {self.delta_change:.6f}, Target Default Settings:{self.recent_change_convergence_delta}")
962
+
963
+ # Start checking for early stopping after minimum steps
964
+ if i > self.safety_minimum_stop_step and len(self.change_log) > 10:
965
+ # Calculate variance and dynamic threshold
966
+ self.blended_tensor = torch.tensor(self.prepass_blended_sigmas)
967
+ if self.device == 'cpu':
968
+ self.sigma_variance = np.var(self.prepass_blended_sigmas)
969
+ else:
970
+ self.sigma_variance = torch.var(self.sigs).item()
971
+
972
+ self.min_sigma_threshold = self.sigma_variance * self.sigma_variance_scale # scale factor can be tuned
973
+ self.prepass_log(f"\n--- Early Stopping Evaluation at Step {i} ---")
974
+ self.prepass_log(f"Current Blended Prepass Sigma: {self.prepass_blended_sigma:.6f}")
975
+ self.prepass_log(f"Sigma Variance: {self.sigma_variance:.6f}")
976
+ self.prepass_log(f"Relative Sigma Progress: {relative_sigma_progress:.6f}")
977
+ self.prepass_log(f"Max Recent Sigma Change: {max_change:.6f}")
978
+ self.prepass_log(f"Mean Recent Sigma Change: {mean_change:.6f}")
979
+
980
+
981
+ # Reason for continuing (sigma still too high)
982
+ if self.prepass_blended_sigma > self.min_sigma_threshold:
983
+ self.prepass_log(f"Prepass Blended Sigma {self.prepass_blended_sigma:.6f} exceeds min sigma threshold {self.min_sigma_threshold:.6f} → Continuing.\n")
984
+
985
+ # Start Early Stopping Checks
986
+ if self.early_stopping_method == "mean":
987
+ mean_change = sum(self.change_log) / len(self.change_log)
988
+ if mean_change < self.early_stopping_threshold:
989
+ skipped_steps = len(self.sigs) - (i)
990
+ self.prepass_log(f"Early stopping triggered by mean at step {i}. Mean change: {mean_change:.6f}. Steps used: {i}/{len(self.sigs)}, steps skipped: {skipped_steps}")
991
+
992
+ elif self.early_stopping_method == "max":
993
+ #max_change = max(self.change_log)
994
+ if max_change < self.early_stopping_threshold:
995
+ skipped_steps = len(self.sigs) - (i)
996
+ self.prepass_log(f"Early stopping triggered by mean at step {i}. Mean change: {max_change:.6f}. Steps used: {i}/{len(self.sigs)}, steps skipped: {skipped_steps}")
997
+
998
+ elif self.early_stopping_method == "sum":
999
+ stable_steps = sum(
1000
+ 1 for j in range(1, len(self.change_log))
1001
+ if abs(self.change_log[j]) < self.early_stopping_threshold * abs(self.sigs[j])
1002
+ )
1003
+ if stable_steps >= 0.8 * len(self.change_log):
1004
+ skipped_steps = len(self.sigs) - (i)
1005
+ self.prepass_log(f"Early stopping triggered by sum at step {i}. Stable steps: {stable_steps}/{len(self.change_log)}. Steps used: {i}/{len(self.sigs)}, steps skipped: {skipped_steps}")
1006
+
1007
+ if self.relative_converged and self.max_converged and self.delta_converged:
1008
+ self.early_stop_triggered = True
1009
+ self.prepass_log(f"\n--- Early Stopping Evaluation at Step {i+1} ---")
1010
+ self.prepass_log(f"Relative Sigma Progress: {relative_sigma_progress:.6f}")
1011
+ self.prepass_log(f"Max Recent Sigma Change: {max_change:.6f}")
1012
+ self.prepass_log(f"Mean Recent Sigma Change: {mean_change:.6f}")
1013
+ self.prepass_log(f"Delta Change: {delta_change:.6f} (Target: {self.recent_change_convergence_delta})")
1014
+ self.prepass_log(f"Early stopping criteria met at step {i+1} based on all convergence checks.")
1015
+ self.predicted_stop_step = i
1016
+ #self.steps = self.predicted_stop_step
1017
+ self.save_image_plot(self.sigs, i)
1018
+ break
1019
+
1020
+
1021
+ # === Final Pass ===
1022
+ if not pre_pass:
1023
+
1024
+ if i == 0:
1025
+ step_label = "First Step"
1026
+ self.log("\n" + "=" * 10 + "\n[Start of Sigma Sequence Logging]\n" + "=" * 10)
1027
+ self.log(f"[{step_label} - Step {i}/{len(self.sigs)}]"
1028
+ f"\nStep Size: {self.step_size:.6f}"
1029
+ f"\nDynamic Blend Factor: {self.dynamic_blend_factor:.6f}"
1030
+ f"\nNoise Scale: {self.noise_scale:.6f}"
1031
+ f"\nSmooth Blend: {self.smooth_blend:.6f}"
1032
+ f"\nBlended Sigma: {self.blended_sigma:.6f}"
1033
+ f"\nFinal Sigma: {self.sigs[i]:.6f}")
1034
+ elif i == len(self.sigs) // 2:
1035
+ step_label = "Middle Step"
1036
+ self.log(f"[{step_label} - Step {i}/{len(self.sigs)}]"
1037
+ f"\nStep Size: {self.step_size:.6f}"
1038
+ f"\nDynamic Blend Factor: {self.dynamic_blend_factor:.6f}"
1039
+ f"\nNoise Scale: {self.noise_scale:.6f}"
1040
+ f"\nSmooth Blend: {self.smooth_blend:.6f}"
1041
+ f"\nBlended Sigma: {self.blended_sigma:.6f}"
1042
+ f"\nFinal Sigma: {self.sigs[i]:.6f}")
1043
+ elif i == len(self.sigs) - 1:
1044
+ step_label = "Last Step"
1045
+ self.log(f"[{step_label} - Step {i}/{len(self.sigs)}]"
1046
+ f"\nStep Size: {self.step_size:.6f}"
1047
+ f"\nDynamic Blend Factor: {self.dynamic_blend_factor:.6f}"
1048
+ f"\nNoise Scale: {self.noise_scale:.6f}"
1049
+ f"\nSmooth Blend: {self.smooth_blend:.6f}"
1050
+ f"\nBlended Sigma: {self.blended_sigma:.6f}"
1051
+ f"\nFinal Sigma: {self.sigs[i]:.6f}")
1052
+ self.log("\n" + "=" * 10 + "\n[End of Sigma Sequence Logging]\n" + "=" * 10)
1053
+ else:
1054
+ step_label = None
1055
+
1056
+ if i > 0:
1057
+ self.change = torch.abs(self.sigs[i] - self.sigs[i - 1])
1058
+ #self.change_log.append(self.change.item())
1059
+ self.change_log.append(self.extract_scalar(self.change))
1060
+
1061
+ # Early Stopping Evaluation
1062
+ if i > self.safety_minimum_stop_step and len(self.change_log) > 5:
1063
+ final_target_sigma = self.sigs[-1].item() # or use min(self.sigmas) if preferred
1064
+ if self.blended_sigma != 0:
1065
+ relative_sigma_progress = (self.blended_sigma - final_target_sigma) / self.blended_sigma
1066
+ else:
1067
+ relative_sigma_progress = 0 # Assume fully converged if blended_sigma is 0
1068
+ # Optional: Show variance but no need to stop on it
1069
+ self.sigma_variance = torch.var(self.sigs).item() if self.device != 'cpu' else np.var(self.blended_sigmas)
1070
+ self.log(f"Sigma Variance: {self.sigma_variance:.6f}")
1071
+ if self.graph_save_enable:
1072
+ self.save_image_plot(self.sigs, i)
1073
+
1074
+ #apply tails and decay after the loop finishes
1075
+ # Finished core sigma blending
1076
+ if not self.auto_mode_enabled:
1077
+ if not pre_pass: # Only extend in the final pass
1078
+ if self.apply_tail_steps:
1079
+ for i, tail in enumerate(self.all_tails):
1080
+ if tail is not None:
1081
+ self.log(f"Appending tail from method: {self.blend_methods[i]}")
1082
+ self.sigs = torch.cat([self.sigs, tail])
1083
+
1084
+ if self.apply_decay_tail:
1085
+ for i, decay in enumerate(self.all_decays):
1086
+ if decay is not None:
1087
+ self.log(f"Appending decay from method: {self.blend_methods[i]}")
1088
+ self.sigs = torch.cat([self.sigs, decay])
1089
+
1090
+ if self.apply_progressive_decay:
1091
+ progressive_decay = None
1092
+ total_weight = 0
1093
+
1094
+ for w, decay in zip(resolved_blend_weights, self.all_decays):
1095
+ if decay is not None:
1096
+ decay = decay[:len(self.sigs)] # Ensure matching length
1097
+ if progressive_decay is None:
1098
+ progressive_decay = w * decay
1099
+ else:
1100
+ progressive_decay += w * decay
1101
+ total_weight += w
1102
+
1103
+ if progressive_decay is not None and total_weight > 0:
1104
+ progressive_decay /= total_weight
1105
+ self.log("Applying progressive decay to sigma sequence.")
1106
+ self.sigs = self.sigs * progressive_decay
1107
+
1108
+ if self.apply_blended_tail:
1109
+ blended_tail = None
1110
+ total_weight = 0
1111
+
1112
+ for w, tail in zip(resolved_blend_weights, self.all_tails):
1113
+ if tail is not None:
1114
+ if blended_tail is None:
1115
+ blended_tail = w * tail
1116
+ else:
1117
+ blended_tail += w * tail
1118
+ total_weight += w
1119
+
1120
+ if blended_tail is not None and total_weight > 0:
1121
+ blended_tail /= total_weight
1122
+ self.log("Appending blended tail to sigma sequence.")
1123
+ self.sigs = torch.cat([self.sigs, blended_tail])
1124
+
1125
+ else:
1126
+ # Run Auto Mode stabilization sequence
1127
+ if len(self.sigs) > self.steps:
1128
+ self.auto_stabilization_sequence = []
1129
+ self.log(f"[Auto Mode] Sigma sequence length {len(self.sigs)} exceeds requested steps {self.steps}. Disabling auto stabilization.")
1130
+ self.auto_mode_enabled = False
1131
+ self.sigs = self.sigs[:self.steps] # Force truncate to requested step count
1132
+ return self.sigs
1133
+ self.run_auto_stabilization(self.sigs)
1134
+
1135
+ if pre_pass and self.early_stop_triggered:
1136
+ return self.sigs[:self.predicted_stop_step] # Return only the usable sequence
1137
+ else:
1138
+ return self.sigs
1139
+ def run_auto_stabilization(self):
1140
+ #This function works as intended, but is blocked by default if programs don't let schedulers create a sigma schedule longer than requested steps.
1141
+ if not self.allow_step_expansion:
1142
+ self.log("[Auto Mode] Step expansion is disabled by configuration. Skipping auto stabilization.")
1143
+ return self.sigs
1144
+ if self.allow_step_expansion:
1145
+ unstable = self.detect_sequence_instability()
1146
+
1147
+ if not unstable:
1148
+ self.log("[Auto Mode] Sigma sequence is already stable.")
1149
+ return
1150
+
1151
+ self.log("[Auto Mode] Detected instability in sigma sequence. Starting stabilization sequence.")
1152
+
1153
+ for method in self.auto_stabilization_sequence:
1154
+ if not unstable:
1155
+ self.log(f"[Auto Mode] Sequence stabilized after {method}. Stopping further corrections.")
1156
+ break
1157
+
1158
+ if method == 'smooth_interpolation':
1159
+ unstable = self.smooth_interpolation()
1160
+
1161
+ elif method == 'append_tail':
1162
+ unstable = self.append_tail()
1163
+
1164
+ elif method == 'blend_tail':
1165
+ unstable = self.blend_tail()
1166
+
1167
+ elif method == 'apply_decay':
1168
+ unstable = self.apply_decay()
1169
+
1170
+ elif method == 'progressive_decay':
1171
+ unstable = self.progressive_decay()
1172
+
1173
+ else:
1174
+ self.log(f"[Auto Mode] Unknown stabilization method: {method}")
1175
+ def detect_sequence_instability(self):
1176
+ delta_sigmas = self.sigs[:-1] - self.sigs[1:]
1177
+ second_deltas = torch.diff(delta_sigmas)
1178
+
1179
+ steep_drop_detected = torch.any(delta_sigmas > self.auto_tail_threshold)
1180
+ jaggedness_score = torch.var(second_deltas[-5:]) if len(second_deltas) >= 5 else 0
1181
+ jagged_transition_detected = jaggedness_score > self.jaggedness_threshold
1182
+
1183
+ if steep_drop_detected:
1184
+ self.log(f"[Auto Mode] Steep drop detected. Max drop: {torch.max(delta_sigmas).item():.6f}")
1185
+ if jagged_transition_detected:
1186
+ self.log(f"[Auto Mode] Jagged transition detected. Jaggedness score: {jaggedness_score:.6f}")
1187
+
1188
+ return steep_drop_detected or jagged_transition_detected
1189
+ def smooth_interpolation(self):
1190
+ self.log("[Auto Mode] Applying smooth interpolation to last 5 steps.")
1191
+ if len(self.sigs) >= 5:
1192
+ start = self.sigs[-6].item()
1193
+ end = self.sigs[-1].item()
1194
+ interpolated = torch.linspace(start, end, steps=6, device=self.device)[1:]
1195
+ self.sigs[-5:] = interpolated
1196
+
1197
+ return self.detect_sequence_instability()
1198
+
1199
+ def append_tail(self):
1200
+ self.log("[Auto Mode] Attempting to append available tail.")
1201
+ if hasattr(self, 'all_tails') and self.all_tails:
1202
+ for tail in self.all_tails:
1203
+ if tail is not None:
1204
+ tail = tail.to(self.device)
1205
+ # If tail is longer than remaining sequence, trim
1206
+ if tail.shape[0] > self.sigs.shape[0]:
1207
+ tail = tail[:len(self.sigs)]
1208
+ self.sigs = torch.cat([self.sigs, tail])
1209
+ self.log("[Auto Mode] Appended tail to sigma sequence.")
1210
+ break
1211
+
1212
+ return self.detect_sequence_instability()
1213
+
1214
+
1215
+ def blend_tail(self):
1216
+ if not hasattr(self, 'all_tails') or not self.all_tails:
1217
+ self.log("[Auto Mode] No available tails to blend.")
1218
+ return self.detect_sequence_instability()
1219
+
1220
+ self.log("[Auto Mode] Attempting to blend multiple tails.")
1221
+ blended_tail = None
1222
+ total_weight = 0
1223
+
1224
+ for w, tail in zip(self.blend_weights, self.all_tails):
1225
+ if tail is not None:
1226
+ tail = tail.to(self.device)
1227
+
1228
+ # Align length if needed
1229
+ if tail.shape[0] > self.sigs.shape[0]:
1230
+ tail = tail[:len(self.sigs)]
1231
+
1232
+ if blended_tail is None:
1233
+ blended_tail = w * tail
1234
+ else:
1235
+ blended_tail += w * tail
1236
+ total_weight += w
1237
+
1238
+ if blended_tail is not None and total_weight > 0:
1239
+ blended_tail /= total_weight
1240
+ self.sigs = torch.cat([self.sigs, blended_tail])
1241
+ self.log("[Auto Mode] Appended blended tail to sigma sequence.")
1242
+
1243
+ return self.detect_sequence_instability()
1244
+
1245
+
1246
+ def apply_decay(self):
1247
+ self.log("[Auto Mode] Attempting to append decay tails.")
1248
+ if hasattr(self, 'all_decays') and self.all_decays:
1249
+ for decay in self.all_decays:
1250
+ if decay is not None:
1251
+ decay = decay.to(self.device)
1252
+
1253
+ # Align length if needed
1254
+ if decay.shape[0] > self.sigs.shape[0]:
1255
+ decay = decay[:len(self.sigs)]
1256
+
1257
+ self.sigs = torch.cat([self.sigs, decay])
1258
+ self.log("[Auto Mode] Appended decay tail to sigma sequence.")
1259
+ break
1260
+
1261
+ return self.detect_sequence_instability()
1262
+
1263
+
1264
+ def progressive_decay(self):
1265
+ self.log("[Auto Mode] Applying progressive decay to sigma sequence.")
1266
+ progressive_decay = None
1267
+ total_weight = 0
1268
+
1269
+ for w, decay in zip(self.blend_weights, self.all_decays):
1270
+ if decay is not None:
1271
+ decay = decay.to(self.device)
1272
+
1273
+ # If the decay is too short, interpolate to match self.sigs length
1274
+ if decay.shape[0] != self.sigs.shape[0]:
1275
+ decay = decay.view(1, 1, -1) # Shape for interpolation
1276
+ decay = F.interpolate(decay, size=self.sigs.shape[0], mode='linear', align_corners=False)
1277
+ decay = decay.view(-1)
1278
+
1279
+ if progressive_decay is None:
1280
+ progressive_decay = w * decay
1281
+ else:
1282
+ progressive_decay += w * decay
1283
+
1284
+ total_weight += w
1285
+
1286
+ if progressive_decay is not None and total_weight > 0:
1287
+ progressive_decay /= total_weight
1288
+ self.sigs = self.sigs * progressive_decay
1289
+ self.log("[Auto Mode] Applied progressive decay to sigma sequence.")
1290
+
1291
+ return self.detect_sequence_instability()
1292
+
1293
+
1294
+ def load_blend_method_sigmas(self, mode=None):
1295
+ """Loads all sigma sequences for the blend_methods list based on current settings and mode."""
1296
+ self.all_sigmas = []
1297
+
1298
+
1299
+ for method in self.blend_methods:
1300
+ self.method_config = self.blend_method_dict[method]
1301
+ self.method_config[method] = {
1302
+ 'decay_pattern': self.method_config.get('decay_pattern', 'zero'),
1303
+ 'decay_mode': self.method_config.get('decay_mode', 'blend'),
1304
+ 'tail_steps': self.method_config.get('tail_steps', 1)
1305
+ }
1306
+ self.current_config = self.method_config[method]
1307
+
1308
+
1309
+ sigma_func = self.scheduler_registry[method]
1310
+ tails, decay, extras, sigmas = self.call_scheduler_function(
1311
+ self.scheduler_registry.get(method),
1312
+ steps=self.steps,
1313
+ sigma_min=self.sigma_min,
1314
+ sigma_max=self.sigma_max,
1315
+ rho=self.rho, # Only passed if the scheduler accepts it
1316
+ device=self.device,
1317
+ decay_pattern=self.current_config['decay_pattern'], # Method-specific
1318
+ decay_mode=self.current_config['decay_mode'], # Method-specific
1319
+ tail_steps=self.current_config['tail_steps'] # Method-specific
1320
+ )
1321
+ self.sigma_sequences[method] = {
1322
+ 'sigmas': sigmas,
1323
+ 'tails': tails,
1324
+ 'decay': decay,
1325
+ 'extras': extras
1326
+ }
1327
+ setattr(self, f"sigmas_{method}", sigmas)
1328
+
1329
+ self.all_sigmas = [self.sigma_sequences[method]['sigmas'] for method in self.blend_methods]
1330
+ self.all_tails = [self.sigma_sequences[method]['tails'] for method in self.blend_methods]
1331
+ self.all_decays = [self.sigma_sequences[method]['decay'] for method in self.blend_methods]
1332
+ self.all_extras = [self.sigma_sequences[method].get('extras', []) for method in self.blend_methods]
1333
+ self._log_extras_to_file(self.all_extras)
1334
+ self.all_sigmas.append(sigmas)
1335
+
1336
+ # Optionally log which schedules were loaded
1337
+ self.log(f"Loaded sigma schedules for blend methods: {self.blend_methods} using mode: {mode}")
1338
+ def validate_and_align_sigmas(self):
1339
+ """
1340
+ Ensures all sigma sequences in self.all_sigmas are valid and have the same length.
1341
+ Pads shorter sequences with their last sigma.
1342
+ """
1343
+ if not self.all_sigmas or len(self.all_sigmas) == 0:
1344
+ raise ValueError("No sigma sequences were loaded for blending.")
1345
+
1346
+ target_length = max(len(s) for s in self.all_sigmas)
1347
+
1348
+ for idx, sigmas in enumerate(self.all_sigmas):
1349
+ if sigmas is None or len(sigmas) == 0:
1350
+ raise ValueError(f"Sigma sequence at index {idx} is invalid or empty: {sigmas}")
1351
+
1352
+ if len(sigmas) < target_length:
1353
+ padding = torch.full((target_length - len(sigmas),), sigmas[-1]).to(sigmas.device)
1354
+ self.all_sigmas[idx] = torch.cat([sigmas, padding])
1355
+
1356
+ self.log(f"Validated and aligned all sigma sequences to length {target_length}.")
1357
+
1358
+ def generate_sigmas_schedule(self, mode=None):
1359
+ """
1360
+ Generates the sigma schedules required for the hybrid blending process.
1361
+
1362
+ The Karras and Exponential sigma sequences are created to provide two distinct
1363
+ noise scaling strategies:
1364
+ - The Karras sequence offers a more aggressive noise decay, commonly used in
1365
+ modern schedulers for improved image quality and denoising stability.
1366
+ - The Exponential sequence provides a traditional log-space noise schedule.
1367
+
1368
+ These two sequences are dynamically blended in later steps using progress-dependent
1369
+ weights to produce a custom sigma path that combines the advantages of both approaches.
1370
+
1371
+ This blending process is critical to the scheduler's ability to:
1372
+ - Adapt noise scaling across steps.
1373
+ - Control the sharpness and smoothness of transitions.
1374
+ - Support early stopping based on sigma convergence patterns.
1375
+
1376
+ These sigma sequences must be regenerated in both the prepass (for early stopping detection)
1377
+ and the final pass (for polished sigma application), ensuring both passes are synchronized
1378
+ with the current step count and randomization settings.
1379
+ """
1380
+ #self.cache_key= self.generate_sigma_hash(self.steps, self.sigma_min, self.sigma_max, self.rho, self.device, self.schedule_type, self.decay_pattern, self.suffix or None)
1381
+ # ✅ Clean Mode Selection
1382
+ if mode == 'prepass':
1383
+ if self.load_prepass_sigmas:
1384
+ self.cache_file = self.prepass_save_file
1385
+ self.mode = 'prepass'
1386
+
1387
+ elif mode == 'final':
1388
+ if self.load_sigma_cache:
1389
+ self.cache_file = self.final_save_file
1390
+ self.mode = 'final'
1391
+
1392
+ else:
1393
+ self.mode = None
1394
+ self.cache_file = None # Optional, for safety
1395
+
1396
+ '''
1397
+ if self.cache_file:
1398
+ sigmas = self.get_sigma_with_cache(
1399
+ steps=self.steps,
1400
+ sigma_min=self.sigma_min,
1401
+ sigma_max=self.sigma_max,
1402
+ rho=self.rho,
1403
+ device=self.device,
1404
+ decay_pattern=self.decay_pattern,
1405
+ cache_file=self.cache_file,
1406
+ mode=self.mode
1407
+ #cache_key = self.cache_key
1408
+ )
1409
+ return sigmas
1410
+ '''
1411
+
1412
+ #else:
1413
+ #logic for multiple schedulers
1414
+ #self.load_blend_method_sigmas(mode=self.mode)
1415
+ self.load_blend_method_sigmas(mode=self.mode)
1416
+ self.blend_pairs = []
1417
+ self.active_methods = [method for method in self.blend_methods if self.blend_method_dict[method].get('weight', 1.0) > 0]
1418
+
1419
+ if self.blending_mode == 'default':
1420
+ self._call_legacy_mode(schedule_type='exponential')
1421
+ self._call_legacy_mode(schedule_type='karras')
1422
+
1423
+ self.blend_pairs = []
1424
+ self.blend_pairs.append({
1425
+ 'method_label': 'method_a',
1426
+ 'method': 'karras',
1427
+ 'sigmas': self.sigmas_karras
1428
+ })
1429
+ self.blend_pairs.append({
1430
+ 'method_label': 'method_b',
1431
+ 'method': 'exponential',
1432
+ 'sigmas': self.sigmas_exponential
1433
+ })
1434
+
1435
+ # ✅ Optional: Pad if needed (if you know some might be misaligned)
1436
+ max_length = max(len(pair['sigmas']) for pair in self.blend_pairs)
1437
+
1438
+ for pair in self.blend_pairs:
1439
+ if len(pair['sigmas']) < max_length:
1440
+ padding = torch.full((max_length - len(pair['sigmas']),), pair['sigmas'][-1]).to(pair['sigmas'].device)
1441
+ pair['sigmas'] = torch.cat([pair['sigmas'], padding])
1442
+
1443
+ self.log(f"All sigma sequences aligned to length: {max_length}")
1444
+
1445
+ # ✅ For legacy compatibility
1446
+ sigmas_a = self.blend_pairs[0]['sigmas']
1447
+ sigmas_b = self.blend_pairs[1]['sigmas']
1448
+ label_a = self.blend_pairs[0]['method']
1449
+ label_b = self.blend_pairs[1]['method']
1450
+ if sigmas_a is None:
1451
+ raise ValueError(f"Sigmas {label_a} failed to generate or assign correctly.")
1452
+ if sigmas_b is None:
1453
+ raise ValueError(f"Sigmas {label_b} failed to generate or assign correctly.")
1454
+ else:
1455
+ if len(self.active_methods) == 1:
1456
+ # Only one method, assign both to the same method
1457
+ self.blend_pairs = []
1458
+ method = self.active_methods[0]
1459
+ self.blend_pairs.append({
1460
+ 'method_label': 'method_a',
1461
+ 'method': method,
1462
+ 'sigmas': self.sigma_sequences[method]['sigmas']
1463
+ })
1464
+
1465
+ elif len(self.active_methods) >= 2:
1466
+ # Build blend_pairs dynamically for all active methods
1467
+ self.blend_pairs = []
1468
+ for idx, method in enumerate(self.active_methods):
1469
+ self.blend_pairs.append({
1470
+ 'method_label': f'method_{chr(97 + idx)}', # method_a, method_b, etc.
1471
+ 'method': method,
1472
+ 'sigmas': self.sigma_sequences[method]['sigmas']
1473
+ })
1474
+
1475
+ # Validation checks (loop version)
1476
+ for pair in self.blend_pairs:
1477
+ if pair['sigmas'] is None:
1478
+ raise ValueError(f"Sigmas {pair['method']} failed to generate or assign correctly.")
1479
+
1480
+
1481
+ # Skip length matching if only 1 method is enabled
1482
+ if len(self.blend_pairs) > 1:
1483
+ target_length = min(len(pair['sigmas']) for pair in self.blend_pairs)
1484
+
1485
+ # Trim all to target length
1486
+ for pair in self.blend_pairs:
1487
+ pair['sigmas'] = pair['sigmas'][:target_length]
1488
+
1489
+ # Find the max length (in case any sequence is longer after trimming)
1490
+ max_length = max(len(pair['sigmas']) for pair in self.blend_pairs)
1491
+
1492
+ for pair in self.blend_pairs:
1493
+ if len(pair['sigmas']) < max_length:
1494
+ padding = torch.full((max_length - len(pair['sigmas']),), pair['sigmas'][-1]).to(pair['sigmas'].device)
1495
+ pair['sigmas'] = torch.cat([pair['sigmas'], padding])
1496
+
1497
+ self.log(f"All sigma sequences aligned to length: {max_length}")
1498
+
1499
+ self.sigs = torch.zeros(target_length, device=self.blend_pairs[0]['sigmas'].device)
1500
+ else:
1501
+ # Only one method, set self.sigs directly
1502
+ self.sigs = self.blend_pairs[0]['sigmas'].clone()
1503
+
1504
+
1505
+ '''
1506
+ # Now it's safe to compute sigs
1507
+ start = math.log(self.sigma_max)
1508
+ end = math.log(self.sigma_min)
1509
+ #self.sigs = torch.linspace(start, end, self.steps, device=self.device).exp()
1510
+ if self.sigs is None or self.force_rebuild_sigs:
1511
+ self.sigs = torch.linspace(start, end, self.steps, device=self.device).exp()
1512
+
1513
+
1514
+
1515
+ # Ensure sigs contain valid values before using them
1516
+ if torch.any(self.sigs > 0):
1517
+ self.sigma_min, self.sigma_max = self.sigs[self.sigs > 0].min(), self.sigs.max()
1518
+ else:
1519
+ # If sigs are all invalid, set a safe fallback
1520
+ self.sigma_min, self.sigma_max = self.min_threshold, self.min_threshold
1521
+ self.log(f"Debugging Warning: No positive sigma values found! Setting fallback sigma_min={self.sigma_min}, sigma_max={self.sigma_max}")
1522
+
1523
+ return {
1524
+ 'karras': self.sigmas_karras,
1525
+ 'exponential': self.sigmas_exponential,
1526
+ 'blend_methods': self.blend_methods,
1527
+ 'all_sigmas': self.all_sigmas,
1528
+ 'sigs': self.sigs
1529
+ }
1530
+
1531
+ #sigma_lengths = [len(self.sigma_sequences[method]['sigmas']) for method in self.blend_methods]
1532
+ #if len(set(sigma_lengths)) > 1: # There are mismatched lengths
1533
+ #self.validate_and_align_sigmas()
1534
+ #self.sigs = torch.zeros(self.steps, device=self.device)
1535
+ sigma_lengths = [len(self.sigma_sequences[method]['sigmas']) for method in self.blend_methods]
1536
+ if len(set(sigma_lengths)) > 1:
1537
+ self.log("[Sigma Alignment] Detected mismatched sigma sequence lengths. Aligning...")
1538
+ self.validate_and_align_sigmas()
1539
+
1540
+ return {
1541
+ 'blend_methods': self.blend_methods,
1542
+ 'all_sigmas': self.all_sigmas,
1543
+ 'sigs': self.sigs
1544
+ }
1545
+ '''
1546
+ # Now it's safe to compute sigs
1547
+ #start = math.log(self.sigma_max)
1548
+ #end = math.log(self.sigma_min)
1549
+ #self.sigs = torch.linspace(start, end, self.steps, device=self.device).exp()
1550
+
1551
+
1552
+ '''
1553
+ if torch.any(self.sigs > 0):
1554
+ self.sigma_min, self.sigma_max = self.sigs[self.sigs > 0].min(), self.sigs.max()
1555
+ else:
1556
+ # If sigs are all invalid, set a safe fallback
1557
+ self.sigma_min = self.min_threshold
1558
+ self.sigma_max = self.min_threshold
1559
+ self.log(f"Debugging Warning: No positive sigma values found! Setting fallback sigma_min={self.sigma_min}, sigma_max={self.sigma_max}")
1560
+
1561
+ return {
1562
+ 'blend_methods': self.blend_methods,
1563
+ 'all_sigmas': self.all_sigmas,
1564
+ 'sigs': self.sigs
1565
+ }
1566
+
1567
+ return {
1568
+ 'blend_methods': self.blend_methods,
1569
+ 'all_sigmas': self.all_sigmas,
1570
+ 'sigs': self.sigs
1571
+ }
1572
+ '''
1573
+
1574
+ if not torch.any(self.sigs > 0):
1575
+ self.sigma_min = self.min_threshold
1576
+ self.sigma_max = self.min_threshold
1577
+ self.log(f"Debugging Warning: No positive sigma values found! Setting fallback sigma_min={self.sigma_min}, sigma_max={self.sigma_max}")
1578
+ else:
1579
+ self.sigma_min = self.sigs[self.sigs > 0].min()
1580
+ self.sigma_max = self.sigs.max()
1581
+
1582
+ return {
1583
+ 'blend_methods': self.blend_methods,
1584
+ 'all_sigmas': self.all_sigmas,
1585
+ 'sigs': self.sigs
1586
+ }
1587
+
1588
+
1589
+
1590
+ def call_scheduler_function(self, scheduler_func, **kwargs):
1591
+ """
1592
+ Safely calls a scheduler function with dynamic argument filtering and flexible return handling.
1593
+
1594
+ This method ensures that only the parameters accepted by the scheduler function are passed.
1595
+ It automatically handles scheduler functions that may return:
1596
+ - Only the sigma sequence
1597
+ - A tuple with (tails, sigmas)
1598
+ - A tuple with (tails, decay, sigmas)
1599
+ - A tuple with additional items (extras) before sigmas
1600
+
1601
+ The method always assumes the last returned item is the sigma sequence, with optional
1602
+ items preceding it.
1603
+
1604
+ Parameters:
1605
+ ----------
1606
+ scheduler_func : callable
1607
+ The scheduler function to be invoked. It may accept various arguments such as steps, sigma_min,
1608
+ sigma_max, rho, device, decay_pattern, etc.
1609
+ **kwargs : dict
1610
+ Arbitrary keyword arguments. Only those accepted by the scheduler function will be passed.
1611
+
1612
+ Returns:
1613
+ -------
1614
+ tuple
1615
+ A 4-tuple containing:
1616
+ - tails : Any (optional, can be None)
1617
+ The tail component of the sigma schedule, if provided.
1618
+ - decay : Any (optional, can be None)
1619
+ The decay component of the sigma schedule, if provided.
1620
+ - extras : list
1621
+ Any additional return values provided by the scheduler function, beyond tails and decay.
1622
+ - sigmas : Any
1623
+ The sigma sequence, always assumed to be the last item returned by the scheduler function.
1624
+
1625
+ Raises:
1626
+ ------
1627
+ ValueError
1628
+ If the scheduler function returns an empty tuple.
1629
+
1630
+ Notes:
1631
+ -----
1632
+ This method allows future schedulers to return additional optional data without breaking the calling pattern.
1633
+ """
1634
+ valid_params = inspect.signature(scheduler_func).parameters
1635
+ filtered_args = {k: v for k, v in kwargs.items() if k in valid_params}
1636
+
1637
+ result = scheduler_func(**filtered_args)
1638
+
1639
+ if isinstance(result, dict):
1640
+ tails = result.get('tails', None)
1641
+ decay = result.get('decay', None)
1642
+ sigmas = result.get('sigmas')
1643
+ extras = result.get('extras', [])
1644
+
1645
+ if sigmas is None:
1646
+ raise ValueError("Scheduler function must return a 'sigmas' key.")
1647
+
1648
+ return tails, decay, extras, sigmas
1649
+
1650
+ # Legacy support: fallback to tuple unpacking if needed
1651
+ if not isinstance(result, tuple):
1652
+ return None, None, [], result
1653
+
1654
+ if len(result) == 0:
1655
+ raise ValueError(f"Scheduler function returned an empty tuple. This is not allowed.")
1656
+
1657
+ sigmas = result[-1]
1658
+ optional_returns = result[:-1]
1659
+
1660
+ tails = optional_returns[0] if len(optional_returns) > 0 else None
1661
+ decay = optional_returns[1] if len(optional_returns) > 1 else None
1662
+ extras = optional_returns[2:] if len(optional_returns) > 2 else []
1663
+
1664
+ return tails, decay, extras, sigmas
1665
+
1666
+ def config_values(self):
1667
+ #Ensures sigma_min is always less than sigma_max for edge cases
1668
+ if self.sigma_min >= self.sigma_max:
1669
+ correction_factor = random.uniform(0.01, 0.99)
1670
+ old_sigma_min = self.sigma_min
1671
+ self.sigma_min = self.sigma_max * correction_factor
1672
+ self.log(f"[Correction] sigma_min ({old_sigma_min}) was >= sigma_max ({self.sigma_max}). Adjusted sigma_min to {self.sigma_min} using correction factor {correction_factor}.")
1673
+
1674
+ self.log(f"Final sigmas: sigma_min={self.sigma_min}, sigma_max={self.sigma_max}")
1675
+
1676
+
1677
+ if self.sigma_auto_enabled:
1678
+ if self.sigma_auto_mode not in ["sigma_min", "sigma_max"]:
1679
+ raise ValueError(f"[Config Error] Invalid sigma_auto_mode: {self.sigma_auto_mode}. Must be 'sigma_min' or 'sigma_max'.")
1680
+
1681
+ if self.sigma_auto_mode == "sigma_min":
1682
+ self.sigma_min = self.sigma_max / self.sigma_scale_factor
1683
+ self.log(f"[Auto Sigma Min] sigma_min set to {self.sigma_min} using scale factor {self.sigma_scale_factor}")
1684
+
1685
+ elif self.sigma_auto_mode == "sigma_max":
1686
+ self.sigma_max = self.sigma_min * self.sigma_scale_factor
1687
+ self.log(f"[Auto Sigma Max] sigma_max set to {self.sigma_max} using scale factor {self.sigma_scale_factor} and using a multiplier of {sigma_max_multipier} to account for smoother transitions")
1688
+
1689
+ # Always apply min_threshold AFTER auto scaling
1690
+ self.min_threshold = random.uniform(1e-5, 5e-5)
1691
+
1692
+ if self.sigma_min < self.min_threshold:
1693
+ self.log(f"[Threshold Enforcement] sigma_min was too low: {self.sigma_min} < min_threshold {self.min_threshold}")
1694
+ self.sigma_min = self.min_threshold
1695
+
1696
+ if self.sigma_max < self.min_threshold:
1697
+ self.log(f"[Threshold Enforcement] sigma_max was too low: {self.sigma_max} < min_threshold {self.min_threshold}")
1698
+ self.sigma_max = self.min_threshold
1699
+
1700
+
1701
+ valid_methods = ['mean', 'max', 'sum']
1702
+ if self.early_stopping_method not in valid_methods:
1703
+ self.log(f"[Config Correction] Invalid early_stopping_method: {self.early_stopping_method}. Defaulting to 'mean'.")
1704
+ self.early_stopping_method = 'mean'
1705
+
1706
+ def prepass_compute_sigmas(self, steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, suffix=None, cache_key = None, skip_prepass = False)->torch.Tensor:
1707
+
1708
+ '''
1709
+ if self.load_prepass_sigmas:
1710
+ if cache_key:
1711
+ self._safe_sigma_loader(cache_key)
1712
+ #self.load_or_regenerate_sigmas(cache_key)
1713
+ loaded_data = torch.load(self.cache_file.replace('.txt', '.pt'), map_location=self.device)
1714
+
1715
+ sigmas = loaded_data['sigma_values'].to(self.device)
1716
+ self.loaded_sigmas = sigmas # No need to call torch.tensor again
1717
+
1718
+ loaded_hash = loaded_data['sigma_hash']
1719
+
1720
+ steps = loaded_data['steps']
1721
+ sigma_min = loaded_data['sigma_min']
1722
+ sigma_max = loaded_data['sigma_max']
1723
+ rho = loaded_data['rho']
1724
+ device = loaded_data['device']
1725
+ schedule_type = loaded_data['schedule_type']
1726
+ decay_pattern = loaded_data['decay_pattern']
1727
+
1728
+ restored_config = loaded_data['full_config']
1729
+
1730
+
1731
+ # Optionally overwrite current settings with restored settings
1732
+ self.settings.update(restored_config)
1733
+
1734
+ #self.load_sigmas_with_hash_validation(self, loaded_data, steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, cache_key, suffix=None)
1735
+ self.log(f"[Cache Loaded] Sigma schedule, hash, and config loaded from: {self.cache_file.replace('.pt', '.txt')}")
1736
+ '''
1737
+ acceptable_keys = [
1738
+ "sigma_min", "sigma_max", "start_blend", "end_blend", "sharpness",
1739
+ "early_stopping_threshold", "initial_step_size",
1740
+ "final_step_size", "initial_noise_scale", "final_noise_scale",
1741
+ "smooth_blend_factor", "step_size_factor", "noise_scale_factor", "rho"
1742
+ ]
1743
+
1744
+ for key in acceptable_keys:
1745
+ default_val = self.settings[key]
1746
+ value = self.get_random_or_default(key, default_val)
1747
+ setattr(self, key, value)
1748
+
1749
+ if self.steps is None:
1750
+ raise ValueError("Number of steps must be provided.")
1751
+ if isinstance(self.device, str):
1752
+ self.device = torch.device(self.device)
1753
+ self.config_values()
1754
+ self.generate_sigmas_schedule(mode='prepass')
1755
+
1756
+ self.predicted_stop_step = self.steps if None else self.original_steps
1757
+ if self.sharpen_last_n_steps > len(self.sigs):
1758
+ self.sharpen_last_n_steps = len(self.sigs)
1759
+ self.log(f"[Sharpening Notice] Requested last {self.sharpen_last_n_steps} steps exceeds sequence length. Using entire sequence instead.")
1760
+
1761
+ self.visual_sigma = max(0.8, self.sigma_min * self.min_visual_sigma)
1762
+
1763
+ self.blend_sigma_sequence(
1764
+ sigmas_karras=None,
1765
+ sigmas_exponential=None,
1766
+ pre_pass = True,
1767
+ blend_methods=self.blend_methods,
1768
+ blend_weights = self.blend_weights
1769
+ )
1770
+ if torch.isnan(self.sigs).any() or torch.isinf(self.sigs).any():
1771
+ raise ValueError("Invalid sigma values detected (NaN or Inf).")
1772
+ final_steps = self.sigs[:self.predicted_stop_step].to(self.device)
1773
+ # Store the results for later use in compute_sigmas
1774
+ self.final_steps = final_steps
1775
+ if self.blending_mode == 'default':
1776
+ self.final_sigmas_karras = self.sigmas_karras
1777
+ self.final_sigmas_exponential = self.sigmas_exponential
1778
+ self.log(f" Final Steps = {self.final_steps}. Predicted_stop_step = {self.predicted_stop_step}. Original requested steps = {self.steps}")
1779
+ self.log(f"final sigmas karras: {self.final_sigmas_karras}")
1780
+ else:
1781
+ # For multi-method blending
1782
+ self.final_sigmas_blended = torch.tensor(self.blended_sigmas, device=self.device)
1783
+
1784
+ self.log(f" Final Steps = {self.final_steps}. Predicted_stop_step = {self.predicted_stop_step}. Original requested steps = {self.steps}")
1785
+ self.log(f"final blended sigmas: {self.final_sigmas_blended}")
1786
+
1787
+ # Optionally log the contributing sigma sequences for debugging
1788
+ for idx, (method, sigmas) in enumerate(zip(self.blend_methods, self.all_sigmas)):
1789
+ self.log(f"Method: {method}, Sigma sequence: {sigmas}")
1790
+
1791
+
1792
+ '''
1793
+ # Build cache key
1794
+ sigma_hash = self.generate_sigma_hash(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, suffix=None)
1795
+
1796
+ if self.save_prepass_sigmas:
1797
+ save_data = {
1798
+ 'sigma_values': sigmas.cpu(), # Keep as tensor
1799
+ 'sigma_hash': sigma_hash,
1800
+ 'steps': steps,
1801
+ 'sigma_min': sigma_min,
1802
+ 'sigma_max': sigma_max,
1803
+ 'rho': rho,
1804
+ 'device': device,
1805
+ 'schedule_type': schedule_type,
1806
+ 'decay_pattern': decay_pattern,
1807
+ 'full_config': self.settings # Save as raw dict
1808
+ }
1809
+
1810
+ # Save directly with torch.save in .pt format
1811
+ torch.save(save_data, self.cache_file) # Assuming self.cache_file has .pt extension
1812
+ self.log(f"[Sigma Saver] Final sigmas saved to: {self.cache_file}")
1813
+ '''
1814
+ def load_or_regenerate_sigmas(self, cache_key):
1815
+ if self.load_sigma_cache and cache_key:
1816
+ try:
1817
+ loaded_data = torch.load(self.cache_file, map_location=self.device)
1818
+ sigmas = loaded_data['sigma_values'].to(self.device)
1819
+
1820
+ except FileNotFoundError:
1821
+ self.log(f"[Cache Warning] Cache file not found: {self.cache_file}")
1822
+ self.log(f"[Cache Recovery] Automatically recomputing sigma schedule.")
1823
+ _, _, _, sigmas = self._generate_sigmas(
1824
+ self.steps,
1825
+ self.sigma_min,
1826
+ self.sigma_max,
1827
+ self.rho,
1828
+ self.device,
1829
+ self.schedule_type,
1830
+ self.decay_pattern
1831
+ )
1832
+
1833
+ # Recompute if cache is disabled or failed
1834
+ _, _, _, sigmas = self._generate_sigmas(
1835
+ self.steps,
1836
+ self.sigma_min,
1837
+ self.sigma_max,
1838
+ self.rho,
1839
+ self.device,
1840
+ self.schedule_type,
1841
+ self.decay_pattern
1842
+ )
1843
+ '''
1844
+ if self.save_prepass_sigmas:
1845
+ # Optional: Cache the recomputed sigma schedule
1846
+ torch.save(save_data, self.prepass_save_file)
1847
+ self.log(f"[Sigma Saver] Final sigmas saved to: {self.prepass_save_file}")
1848
+
1849
+ #self.sigma_cache[cache_key] = sigmas
1850
+ '''
1851
+ return sigmas
1852
+
1853
+ def compute_sigmas(self, steps, sigma_min, sigma_max, rho, device, schedule_type=None, decay_pattern=None, cache_key=None)->torch.Tensor:
1854
+ """
1855
+ Scheduler function that blends sigma sequences using Karras and Exponential methods with adaptive parameters.
1856
+
1857
+ Parameters:
1858
+ n (int): Number of steps.
1859
+ sigma_min (float): Minimum sigma value.
1860
+ sigma_max (float): Maximum sigma value.
1861
+ device (torch.device): The device on which to perform computations (e.g., 'cuda' or 'cpu').
1862
+ start_blend (float): Initial blend factor for dynamic blending.
1863
+ end_bend (float): Final blend factor for dynamic blending.
1864
+ sharpen_factor (float): Sharpening factor to be applied adaptively.
1865
+ early_stopping_threshold (float): Threshold to trigger early stopping.
1866
+ initial_step_size (float): Initial step size for adaptive step size calculation.
1867
+ final_step_size (float): Final step size for adaptive step size calculation.
1868
+ initial_noise_scale (float): Initial noise scale factor.
1869
+ final_noise_scale (float): Final noise scale factor.
1870
+ step_size_factor: Adjust to compensate for oversmoothing
1871
+ noise_scale_factor: Adjust to provide more variation
1872
+
1873
+ Returns:
1874
+ torch.Tensor: A tensor of blended sigma values.
1875
+ """
1876
+ '''
1877
+ if self.load_sigma_cache and cache_key:
1878
+ sigmas = self._safe_sigma_loader(cache_key)
1879
+ if sigmas is None:
1880
+ self.log(f"[Cache Recovery] No valid cache found for key: {cache_key}. Recomputing sigma schedule.")
1881
+ _, _, _, sigmas = self._generate_sigmas(
1882
+ self.steps,
1883
+ self.sigma_min,
1884
+ self.sigma_max,
1885
+ self.rho,
1886
+ self.device,
1887
+ self.schedule_type,
1888
+ self.decay_pattern
1889
+ )
1890
+ else:
1891
+ self.log(f"[Cache Hit] Sigma schedule successfully loaded from cache.")
1892
+ self.sigs = sigmas
1893
+ '''
1894
+ acceptable_keys = [
1895
+ "sigma_min", "sigma_max", "start_blend", "end_blend", "sharpness",
1896
+ "early_stopping_threshold", "initial_step_size",
1897
+ "final_step_size", "initial_noise_scale", "final_noise_scale",
1898
+ "smooth_blend_factor", "step_size_factor", "noise_scale_factor", "rho"
1899
+ ]
1900
+
1901
+ if self.skip_prepass:
1902
+ for key in acceptable_keys:
1903
+ default_val = self.settings[key]
1904
+ value = self.get_random_or_default(key, default_val)
1905
+ setattr(self, key, value)
1906
+
1907
+ self.log(f"Using device: {self.device}")
1908
+ self.config_values()
1909
+ self.generate_sigmas_schedule(mode='final')
1910
+ if hasattr(self, 'final_sigmas_karras'):
1911
+ self.sigs = torch.zeros_like(self.final_sigmas_karras).to(self.device)
1912
+ else:
1913
+ self.sigs = torch.zeros_like(self.sigmas_karras).to(self.device)
1914
+
1915
+ self.blend_sigma_sequence(
1916
+ sigmas_karras=self.final_sigmas_karras if hasattr(self, 'final_sigmas_karras') else self.sigmas_karras,
1917
+ sigmas_exponential=self.final_sigmas_exponential if hasattr(self, 'final_sigmas_exponential') else self.sigmas_exponential,
1918
+ pre_pass=False,
1919
+ blend_methods=self.blend_methods,
1920
+ blend_weights = self.blend_weights
1921
+
1922
+ )
1923
+ self.sigma_variance = torch.var(self.sigs).item()
1924
+ if self.sharpen_mode in ['last_n', 'both']:
1925
+ if self.sigma_variance < self.sharpen_variance_threshold:
1926
+ # Apply full sharpening
1927
+ self.sharpen_mask = torch.where(self.sigs < self.sigma_min * 1.5, self.sharpness, 1.0).to(self.device)
1928
+ sharpen_indices = torch.where(self.sharpen_mask < 1.0)[0].tolist()
1929
+ self.sigs = self.sigs * self.sharpen_mask
1930
+ self.log(f"[Sharpen Mask] Full sharpening applied (low variance). Steps: {sharpen_indices}")
1931
+ else:
1932
+ # Apply sharpening only to the last N steps
1933
+ recent_sigs = self.sigs[-self.sharpen_last_n_steps:]
1934
+ sharpen_mask = torch.where(recent_sigs < self.sigma_min * 1.5, self.sharpness, 1.0).to(self.device)
1935
+ sharpen_indices = torch.where(sharpen_mask < 1.0)[0].tolist()
1936
+ self.sigs[-self.sharpen_last_n_steps:] = recent_sigs * sharpen_mask
1937
+
1938
+ # Now loop per step if desired (safely inside this block)
1939
+ for j in range(len(self.sigs) - self.sharpen_last_n_steps, len(self.sigs)):
1940
+ if self.sigs[j] < self.sigma_min * 1.5:
1941
+ old_value = self.sigs[j].item()
1942
+ self.sigs[j] = self.sigs[j] * self.sharpness
1943
+ self.log(f"[Sharpening] Step {j+1}: Applied sharpening. Sigma changed from {old_value:.6f} to {self.sigs[j].item():.6f}")
1944
+ else:
1945
+ self.log(f"[Sharpening] Step {j+1}: No sharpening applied. Sigma: {self.sigs[j].item():.6f}")
1946
+
1947
+ if self.sharpen_mode in ['full', 'both']:
1948
+ # Optional: Additional full sharpening (if needed)
1949
+ self.sharpen_mask = torch.where(self.sigs < self.sigma_min * 1.5, self.sharpness, 1.0).to(self.device)
1950
+ sharpen_indices = torch.where(self.sharpen_mask < 1.0)[0].tolist()
1951
+ self.sigs = self.sigs * self.sharpen_mask
1952
+ self.log(f"[Sharpen Mask] Full sharpening applied at steps: {sharpen_indices}")
1953
+
1954
+ '''
1955
+ sigma_hash = self.generate_sigma_hash(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, suffix=None)
1956
+ if self.settings.get('save_sigma_cache', False):
1957
+ save_data = {
1958
+ 'sigma_values': self.sigs.cpu().tolist(),
1959
+ 'sigma_hash': sigma_hash,
1960
+ 'steps': steps,
1961
+ 'sigma_min': sigma_min,
1962
+ 'sigma_max': sigma_max,
1963
+ 'rho': rho,
1964
+ 'device': device,
1965
+ 'schedule_type': schedule_type,
1966
+ 'decay_pattern': decay_pattern,
1967
+ 'full_config': json.dumps(self.settings)
1968
+ }
1969
+
1970
+
1971
+ if self.save_sigma_cache:
1972
+ torch.save(save_data, self.final_save_file)
1973
+ self.log(f"[Sigma Saver] Final sigmas saved to: {self.final_save_file}")
1974
+ '''
1975
+ #self.log(f"[DEBUG]Final Output: Skip Prepass: {self.skip_prepass}. Original requested steps: {self.original_steps}. Self.steps = {self.steps} for tensor sigs: {self.sigs})")
1976
+ return self.sigs.to(self.device)
1977
+
1978
+ def get_sigma_from_cache(self, cache_key):
1979
+ """
1980
+ Safely retrieves a sigma sequence from cache.
1981
+ Always returns a detached copy to prevent in-place modification of cached data.
1982
+ """
1983
+ if cache_key in self.sigma_cache:
1984
+ cached_sigmas = self.sigma_cache[cache_key]
1985
+ self.log(f"[Cache Hit] Returning cached sigma sequence for key: {cache_key}")
1986
+
1987
+ # If tensor, clone and detach
1988
+ if isinstance(cached_sigmas, torch.Tensor):
1989
+ return cached_sigmas.clone().detach().to(self.device)
1990
+
1991
+ # If list, deep copy
1992
+ elif isinstance(cached_sigmas, list):
1993
+ import copy
1994
+ return copy.deepcopy(cached_sigmas)
1995
+
1996
+ # If other types, return as-is (optional: tighten later)
1997
+ else:
1998
+ return cached_sigmas
1999
+
2000
+ else:
2001
+ self.log(f"[Cache Miss] Cache key not found: {cache_key}")
2002
+ return None
sd_simple_kes_v3_fix?/validate_config.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Dict, Any
2
+
3
+ RANDOMIZATION_TYPE_ALIASES = {
4
+ # Asymmetric
5
+ 'asymmetric': 'asymmetric', 'assym': 'asymmetric', 'a': 'asymmetric', 'asym': 'asymmetric', 'A': 'asymmetric',
6
+ # Symmetric
7
+ 'symmetric': 'symmetric', 'sym': 'symmetric', 's': 'symmetric', 'S': 'symmetric',
8
+ # Logarithmic
9
+ 'logarithmic': 'logarithmic', 'log': 'logarithmic', 'l': 'logarithmic', 'L': 'logarithmic',
10
+ # Exponential
11
+ 'exponential': 'exponential', 'exp': 'exponential', 'e': 'exponential', 'E': 'exponential',
12
+ }
13
+
14
+ DEFAULT_RANDOMIZATION_TYPE = 'asymmetric'
15
+ DEFAULT_RANDOMIZATION_PERCENT = 0.2
16
+
17
+ # Base default values
18
+ BASE_DEFAULTS = {
19
+ 'sigma_min': 0.05,
20
+ 'sigma_max': 27.5,
21
+ 'start_blend': 0.1,
22
+ 'end_blend': 0.5,
23
+ 'sharpness': 1.0,
24
+ 'early_stopping_threshold': 0.01,
25
+ 'initial_step_size': 0.9,
26
+ 'final_step_size': 0.2,
27
+ 'initial_noise_scale': 1.25,
28
+ 'final_noise_scale': 0.8,
29
+ 'smooth_blend_factor': 9.0,
30
+ 'step_size_factor': 0.8,
31
+ 'noise_scale_factor': 0.8,
32
+ 'rho': 8.0
33
+ }
34
+
35
+ def validate_config(config: Dict[str, Any], logger: Optional[Any] = None) -> Dict[str, Any]:
36
+ updated_config = config.copy()
37
+
38
+ def log(message):
39
+ if logger:
40
+ logger.log(message)
41
+ else:
42
+ print(message)
43
+
44
+ # Step 1: Set all base defaults if missing
45
+ for key, base_value in BASE_DEFAULTS.items():
46
+ if key not in updated_config:
47
+ updated_config[key] = base_value
48
+ log(f"[Config Correction] {key} missing. Set to base default: {base_value}")
49
+
50
+ # Ensure _rand flag exists and is a boolean
51
+ rand_flag = f"{key}_rand"
52
+ if rand_flag not in updated_config or not isinstance(updated_config.get(rand_flag), bool):
53
+ updated_config[rand_flag] = False
54
+ log(f"[Config Correction] {rand_flag} missing or invalid. Set to False.")
55
+
56
+ # Ensure _enable_randomization_type flag exists and is a boolean
57
+ randomization_flag = f"{key}_enable_randomization_type"
58
+ if randomization_flag not in updated_config or not isinstance(updated_config.get(randomization_flag), bool):
59
+ updated_config[randomization_flag] = False
60
+ log(f"[Config Correction] {randomization_flag} missing or invalid. Set to False.")
61
+
62
+ # Ensure randomization_type exists
63
+ randomization_type_key = f"{key}_randomization_type"
64
+ if randomization_type_key not in updated_config:
65
+ updated_config[randomization_type_key] = DEFAULT_RANDOMIZATION_TYPE
66
+ log(f"[Config Correction] {randomization_type_key} missing. Set to '{DEFAULT_RANDOMIZATION_TYPE}'.")
67
+
68
+ # Ensure randomization_percent exists
69
+ randomization_percent_key = f"{key}_randomization_percent"
70
+ if randomization_percent_key not in updated_config:
71
+ updated_config[randomization_percent_key] = DEFAULT_RANDOMIZATION_PERCENT
72
+ log(f"[Config Correction] {randomization_percent_key} missing. Set to {DEFAULT_RANDOMIZATION_PERCENT}.")
73
+
74
+ # Ensure _rand_min and _rand_max exist
75
+ min_key = f"{key}_rand_min"
76
+ max_key = f"{key}_rand_max"
77
+ percent = updated_config[randomization_percent_key]
78
+
79
+ if min_key not in updated_config:
80
+ updated_config[min_key] = updated_config[key] * (1 - percent)
81
+ log(f"[Config Correction] {min_key} missing. Auto-calculated from base.")
82
+
83
+ if max_key not in updated_config:
84
+ updated_config[max_key] = updated_config[key] * (1 + percent)
85
+ log(f"[Config Correction] {max_key} missing. Auto-calculated from base.")
86
+
87
+ log("[Config Validation] Config validated and missing values filled successfully.")
88
+ return updated_config