| from modules.sd_simple_kes.get_sigmas import scheduler_registry
|
| from modules.sd_simple_kes.validate_config import validate_config
|
| from modules.sd_simple_kes.plot_sigma_sequence import plot_sigma_sequence
|
| import torch
|
| import torch.nn.functional as F
|
| import logging
|
| import os
|
| import yaml
|
| import random
|
| from datetime import datetime
|
| import warnings
|
| import math
|
| from typing import Optional
|
| import json
|
| import numpy as np
|
| import hashlib
|
| import glob
|
| import re
|
| import inspect
|
| import copy
|
|
|
|
|
| def simple_kes_scheduler(n: int, sigma_min: float, sigma_max: float, device: torch.device) -> torch.Tensor:
|
| scheduler = SimpleKEScheduler(n=n, sigma_min=sigma_min, sigma_max=sigma_max, device=device)
|
| return scheduler()
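|
| # Usage sketch (illustrative only; not executed at import time). The sigma
| # bounds below are typical SD-style values and are assumptions, not
| # requirements of this module:
| #
| #     sigmas = simple_kes_scheduler(
| #         n=30, sigma_min=0.0292, sigma_max=14.6146, device=torch.device("cpu")
| #     )
| #     assert sigmas.ndim == 1  # a 1-D tensor of noise levels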
|
|
|
| class SharedLogger:
|
| def __init__(self, debug=False):
|
| self.debug = debug
|
| self.log_buffer = []
|
| self.prepass_log_buffer = []
|
|
|
| def log(self, message):
|
| if self.debug:
|
| self.log_buffer.append(message)
|
| def prepass_log(self, message):
|
| if self.debug:
|
| self.prepass_log_buffer.append(message)
|
|
|
| class SimpleKEScheduler:
|
| """
|
| SimpleKEScheduler
|
| ------------------
|
| A hybrid scheduler that combines Karras-style sigma sampling
|
| with exponential decay and blending controls. Supports parameterized
|
| customization for use in advanced diffusion pipelines.
|
|
|
| Parameters:
|
| - n (int): Number of inference steps.
|
| - sigma_min / sigma_max (float, optional): Noise bounds; when omitted they come from the config.
|
| - device (torch.device or str): Target device (e.g. 'cuda').
|
| - **kwargs: Scheduler-specific configuration overrides applied on top of the YAML config.
|
|
|
| Usage:
|
| scheduler = SimpleKEScheduler(n=30, sigma_min=0.03, sigma_max=14.6, device='cuda')
|
| sigmas = scheduler()
|
| """
|
|
|
| def __init__(self, n: int, sigma_min: Optional[float] = None, sigma_max: Optional[float] = None, device: "torch.device | str" = "cpu", logger=None, **kwargs) -> None:
|
| self.steps = n if n is not None else 10
|
| self.original_steps = n
|
| self.device = torch.device(device)  # torch.device() accepts both str and torch.device instances
|
| self.sigma_min = sigma_min
|
| self.sigma_max = sigma_max
|
| self.scheduler_registry = scheduler_registry
|
| self.RANDOMIZATION_TYPE_ALIASES = {
|
| 'symmetric': 'symmetric', 'sym': 'symmetric', 's': 'symmetric',
|
| 'asymmetric': 'asymmetric', 'assym': 'asymmetric', 'a': 'asymmetric',
|
| 'logarithmic': 'logarithmic', 'log': 'logarithmic', 'l': 'logarithmic',
|
| 'exponential': 'exponential', 'exp': 'exponential', 'e': 'exponential'
|
| }
|
| self._config_schema = {
|
| 'min_visual_sigma': (int, 10),
|
| 'safety_minimum_stop_step': (int, 10),
|
| 'auto_tail_smoothing': (bool, False),
|
| 'auto_stabilization_sequence': (list, [
|
| 'smooth_interpolation', 'append_tail', 'blend_tail', 'apply_decay', 'progressive_decay'
|
| ]),
|
| 'sharpen_variance_threshold': (float, 0.01),
|
| 'sharpen_last_n_steps': (int, 10),
|
| 'decay_pattern': (str, 'zero'),
|
| 'sigma_save_subfolder': (str, 'saved_sigmas'),
|
| 'load_sigma_cache': (bool, False),
|
| 'save_sigma_cache': (bool, False),
|
| 'graph_save_directory': (str, 'modules/sd_simple_kes/image_generation_data'),
|
| 'graph_save_enable': (bool, False),
|
| 'exp_power': (int, 2),
|
| 'recent_change_convergence_delta': (float, 0.02),
|
| 'sigma_variance_scale': (float, 0.05),
|
| 'allow_step_expansion': (bool, False),
|
| 'sharpen_mode': (str, 'full'),
|
| 'blend_midpoint': (float, 0.5),
|
| 'early_stopping_method': (str, 'mean'),
|
| 'save_prepass_sigmas': (bool, False),
|
| 'global_randomize': (bool, False),
|
| 'skip_prepass': (bool, False),
|
| 'load_prepass_sigmas': (bool, False)
|
| }
|
| self._overrides = kwargs.copy()
|
| default_config_path = os.path.abspath(os.path.normpath(os.path.join("modules", "sd_simple_kes", "kes_config", "default_config.yaml")))
|
| self.default_config = self._load_config(default_config_path)
|
| user_config_path = os.path.abspath(os.path.normpath(os.path.join("modules", "sd_simple_kes", "kes_config", "user_config.yaml")))
|
| self.user_config = self._load_config(user_config_path)
|
| self.config_data = {**self.default_config, **self.user_config}
|
| self.config = self.config_data.copy()
|
| self.settings = self.config.copy()
|
| for key, value in self.settings.items():
|
| setattr(self, key, value)
|
| if self.settings.get('global_randomize', False):
|
| self.apply_global_randomization()
|
| self.re_randomizable_keys = [
|
| "sigma_min", "sigma_max", "start_blend", "end_blend", "sharpness",
|
| "early_stopping_threshold",
|
| "initial_step_size", "final_step_size",
|
| "initial_noise_scale", "final_noise_scale",
|
| "smooth_blend_factor", "step_size_factor", "noise_scale_factor", "rho"
|
| ]
|
| for key in self.re_randomizable_keys:
|
| value = self.settings.get(key)
|
| if value is None:
|
| raise KeyError(f"[KEScheduler] Missing required setting: {key}")
|
| setattr(self, key, value)
|
| self.debug = self.settings.get('debug', False)
|
|
|
| logger = logger or SharedLogger(debug=self.debug)
|
| self.logger = logger
|
| self.log = self.logger.log
|
| self.prepass_log = self.logger.prepass_log
|
| self._validate_config_types()
|
| validate_config(self.config, logger=self.logger)
|
|
|
| for k, v in self._overrides.items():
|
| if k in self.settings:
|
| self.settings[k] = v
|
| setattr(self, k, v)
|
| self.auto_mode_enabled = self.settings.get('auto_tail_smoothing', False)
|
| self.initialize_generation_filename()
|
| self.relative_converged = False
|
| self.max_converged = False
|
| self.delta_converged = False
|
| self.early_stop_triggered = False
|
| self.sigma_cache = {}
|
| self.BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| self.cache_dir = os.path.join(self.BASE_DIR, 'cache')
|
| self.sigma_save_folder = os.path.join(self.cache_dir, self.sigma_save_subfolder)
|
| self.blend_method_dict = self.settings.get('blend_methods', {
|
| 'karras': {'weight': 1.0, 'decay_pattern': 'zero', 'decay_mode': 'append', 'tail_steps': 1},
|
| 'exponential': {'weight': 1.0, 'decay_pattern': 'zero', 'decay_mode': 'append', 'tail_steps': 1}
|
| })
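|
| # The same structure can be supplied as 'blend_methods' in user_config.yaml.
| # The YAML shape below is an illustrative assumption mirroring the defaults
| # above, not a documented schema:
| #
| #     blend_methods:
| #       karras:      {weight: 1.0, decay_pattern: zero, decay_mode: append, tail_steps: 1}
| #       exponential: {weight: 1.0, decay_pattern: zero, decay_mode: append, tail_steps: 1}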
|
| self.blend_methods = list(self.blend_method_dict.keys())
|
| self.blend_weights = [self.blend_method_dict[method]['weight'] for method in self.blend_methods]
|
| self.loaded_sigmas = None
|
| self.sigma_sequences = {}
|
|
|
| self.schedule_type = None
|
| self.suffix = None
|
| self.ext = None
|
| self._create_directories()
|
| self._finalize_init()
|
|
|
| def _create_directories(self):
|
|
|
| os.makedirs(self.cache_dir, exist_ok=True)
|
| os.makedirs(self.sigma_save_folder, exist_ok=True)
|
| timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| self.extras_log_filename = os.path.join(
|
| self.settings.get('log_save_directory', 'modules/sd_simple_kes/image_generation_data'),
|
| f'all_extras_log_{timestamp}.txt'
|
| )
|
|
|
|
|
|
|
| def _finalize_init(self):
|
|
|
| self.prepass_save_file = self.build_sigma_cache_filename(
|
| steps=self.steps,
|
| sigma_min=self.sigma_min,
|
| sigma_max=self.sigma_max,
|
| rho=self.rho,
|
| schedule_type='karras',
|
| decay_pattern=self.decay_pattern,
|
| cache_dir=self.sigma_save_folder,
|
| suffix='prepass',
|
| ext = 'pt'
|
| )
|
| self.final_save_file = self.build_sigma_cache_filename(
|
| steps=self.steps,
|
| sigma_min=self.sigma_min,
|
| sigma_max=self.sigma_max,
|
| rho=self.rho,
|
| schedule_type='karras',
|
| decay_pattern=self.decay_pattern,
|
| cache_dir=self.sigma_save_folder,
|
| suffix='final',
|
| ext = 'pt'
|
| )
|
| self.load_blend_method_sigmas()
|
| def _load_config(self, config_path, **kwargs):
|
| # Reuse the shared logger when it already exists; only build a fallback on first use.
|
| if not hasattr(self, 'logger'):
|
| self.logger = SharedLogger(debug=kwargs.get('debug', False))
|
|
|
|
|
| try:
|
| with open(config_path, 'r', encoding='utf-8') as f:
|
| user_config = yaml.safe_load(f)
|
| return user_config or {}
|
| except FileNotFoundError:
|
| self.logger.log(f"Config file not found: {config_path}. Using empty config.")
|
| return {}
|
| except yaml.YAMLError as e:
|
| self.logger.log(f"Error loading config file: {e}")
|
| return {}
|
|
|
|
|
| def _validate_config_types(self):
|
| '''
|
| Validates self.settings against the config schema: entries with an invalid type
|
| are replaced by their defaults, and a corrected_user_config.yaml is written
|
| recording every correction.
|
| '''
|
| validated_settings = {}
|
| corrected_lines = []
|
|
|
| corrected_lines.append("# Corrected User Config (Invalid entries auto-corrected)\n")
|
|
|
| for key, (expected_type, default_value) in self._config_schema.items():
|
| value = self.settings.get(key, default_value)
|
| if isinstance(value, expected_type):
|
| validated_settings[key] = value
|
| corrected_lines.append(f"{key}: {value}")
|
|
|
| else:
|
| self.log(f"[Config Warning] Invalid type for '{key}': Expected {expected_type.__name__}, got {type(value).__name__}. Using default: {default_value}")
|
| validated_settings[key] = default_value
|
| corrected_lines.append(f"{key}: {default_value} # Invalid type: {type(value).__name__}, replaced with default")
|
|
|
|
|
| with open('corrected_user_config.yaml', 'w', encoding='utf-8') as f:
|
| f.write('\n'.join(corrected_lines))
|
|
|
| self.settings.update(validated_settings)
|
|
|
| for key, value in self.settings.items():
|
| setattr(self, key, value)
|
|
|
|
|
| def _log_extras_to_file(self, all_extras):
|
| """
|
| Logs the extras returned by each scheduler to a dedicated 'all_extras' log file.
|
|
|
| This method iterates through the list of extras for each blend method and writes them
|
| to a separate log file for easier tracking, debugging, and future analysis.
|
|
|
| Parameters:
|
| ----------
|
| all_extras : list
|
| A list of extras returned by each scheduler, aligned with the blend_methods list.
|
| Each item in the list corresponds to the extras provided by a specific scheduler.
|
|
|
| Notes:
|
| -----
|
| - If extras are present, they are logged under their respective scheduler names.
|
| - If extras contain complex objects, the method attempts to serialize them using JSON.
|
| - Non-serializable extras are logged as raw text.
|
|
|
| Purpose:
|
| -------
|
| This log file is intended for developers to track additional outputs that are not
|
| directly part of the sigma, tails, or decay sequences but may be useful for diagnostics,
|
| metadata, or advanced scheduler behaviors.
|
| """
|
| with open(self.extras_log_filename, 'a', encoding='utf-8') as f:
|
| f.write("\n=== New Scheduler Extras ===\n")
|
| for method, extras in zip(self.blend_methods, all_extras):
|
| if extras:
|
| try:
|
| f.write(f"\nScheduler: {method}\n")
|
| f.write(json.dumps(extras, indent=2))
|
| f.write("\n")
|
| except TypeError:
|
| f.write(f"\nScheduler: {method}\n")
|
| f.write(f"Extras (non-serializable): {extras}\n")
|
| f.write("\n============================\n")
|
| def __call__(self):
|
|
|
| if not self.skip_prepass:
|
| self.prepass_compute_sigmas(
|
| steps=self.steps,
|
| sigma_min=self.sigma_min,
|
| sigma_max=self.sigma_max,
|
| rho=self.rho,
|
| device=self.device,
|
| skip_prepass=self.skip_prepass
|
| )
|
|
|
|
|
| if self.load_prepass_sigmas:
|
| self.generate_sigmas_schedule(mode='prepass')
|
|
|
| if self.load_sigma_cache:
|
| self.generate_sigmas_schedule(mode='final')
|
|
|
| else:
|
|
|
| self.config_values()
|
| self.generate_sigmas_schedule()
|
|
|
| if self.blending_mode == 'default':
|
| self.blend_sigma_sequence(
|
| sigmas_karras= self.scheduler_registry.get('karras')(
|
| steps=self.steps,
|
| sigma_min=self.sigma_min,
|
| sigma_max=self.sigma_max,
|
| device=self.device,
|
| decay_pattern=self.decay_pattern
|
| )[2],
|
| sigmas_exponential=self.scheduler_registry.get('exponential')(
|
| steps=self.steps,
|
| sigma_min=self.sigma_min,
|
| sigma_max=self.sigma_max,
|
| device=self.device,
|
| decay_pattern=self.decay_pattern
|
| )[2],
|
| pre_pass=False,
|
| blend_methods=self.blend_methods,
|
| blend_weights=self.blend_weights
|
| )
|
|
|
| else:
|
|
|
| self.blend_sigma_sequence(
|
| sigmas_karras=None,
|
| sigmas_exponential=None,
|
| pre_pass=False,
|
| blend_methods=self.blend_methods,
|
| blend_weights=self.blend_weights
|
| )
|
|
|
| sigmas = self.compute_sigmas(steps=self.steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max, rho=self.rho, device=self.device)
|
|
|
|
|
|
|
| if torch.isnan(sigmas).any():
|
| raise ValueError("[SimpleKEScheduler] NaN detected in sigmas")
|
| if torch.isinf(sigmas).any():
|
| raise ValueError("[SimpleKEScheduler] Inf detected in sigmas")
|
| if (sigmas <= 0).all():
|
| raise ValueError("[SimpleKEScheduler] All sigma values are <= 0")
|
| if (sigmas > 1000).all():
|
| raise ValueError("[SimpleKEScheduler] Sigma values are extremely large — might explode the model")
|
|
|
|
|
| if self.debug:
|
| self.save_generation_settings()
|
|
|
| return sigmas
|
|
|
| def _safe_sigma_loader(self, cache_key):
|
| cache_folder = self.sigma_save_folder
|
|
|
|
|
| if not os.path.exists(cache_folder) or not os.listdir(cache_folder):
|
| self.log(f"[Cache Check] Cache folder {cache_folder} is empty or missing. Skipping load.")
|
| return None
|
|
|
|
|
| matching_files = [f for f in os.listdir(cache_folder) if cache_key in f and f.endswith('.pt')]
|
|
|
| if not matching_files:
|
| self.log(f"[Cache Check] No matching cache file found for key: {cache_key}. Skipping load.")
|
| return None
|
|
|
|
|
| filename = os.path.join(cache_folder, matching_files[0])
|
| self.log(f"[Cache Hit] Loading sigma cache from: {filename}")
|
| loaded_data = torch.load(filename, map_location=self.device)
|
| return loaded_data['sigma_values'].to(self.device)
|
|
|
| def call_scheduler(self, method_name, *args, **kwargs):
|
| sigma_sequence = getattr(self, f"sigmas_{method_name}", None)
|
| if sigma_sequence is None:
|
| self.log(f"No sigma sequence found for method: {method_name}")
|
| return None
|
| return sigma_sequence
|
|
|
| def is_sigma_randomized(self):
|
| return (
|
| self.settings.get('sigma_min_rand', False) or
|
| self.settings.get('sigma_max_rand', False) or
|
| self.settings.get('rho_rand', False) or
|
| self.settings.get('sigma_max_enable_randomization_type', False) or
|
| self.settings.get('sigma_min_enable_randomization_type', False) or
|
| self.settings.get('rho_enable_randomization_type', False)
|
| )
|
|
|
|
|
| def save_sigmas_as_csv(self, sigmas, filename):
|
| np.savetxt(filename, sigmas.cpu().numpy(), delimiter=",")
|
|
|
| def build_sigma_cache_filename(self, steps, sigma_min, sigma_max, rho=None, schedule_type='karras', decay_pattern='zero', cache_dir=None, suffix=None, ext='txt'):
|
| if cache_dir is None:
|
| cache_dir = os.path.join('modules', 'sd_simple_kes', 'cache')
|
| if schedule_type == 'karras':
|
| base_filename = f'sigma_{schedule_type}_{steps}steps_rho{rho}_min{sigma_min}_max{sigma_max}_{decay_pattern}'
|
| else:
|
| base_filename = f'sigma_{schedule_type}_{steps}steps_min{sigma_min}_max{sigma_max}_{decay_pattern}'
|
|
|
|
|
| if suffix:
|
| base_filename += f'_{suffix}'
|
|
|
| if ext:
|
| version = self.get_next_version_number(cache_dir, base_filename, ext)
|
| filename = f'{version:03d}_{base_filename}.{ext}'
|
| else:
|
| version = self.get_next_version_number(cache_dir, base_filename)
|
| filename = f'{version:03d}_{base_filename}'
|
|
|
| return os.path.join(cache_dir, filename)
|
|
|
| def get_next_version_number(self, cache_dir, base_filename,ext=None):
|
| pattern = os.path.join(cache_dir, f'*_{base_filename}')
|
| if ext:
|
| pattern= os.path.join(cache_dir, f'*_{base_filename}.{ext}')
|
| existing_files = glob.glob(pattern)
|
|
|
| version_numbers = []
|
| for file in existing_files:
|
| match = re.search(r'(\d{3})_' + re.escape(base_filename), os.path.basename(file))
|
| if match:
|
| version_numbers.append(int(match.group(1)))
|
|
|
| if version_numbers:
|
| return max(version_numbers) + 1
|
| else:
|
| return 1
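|
| # Versioning sketch (filenames are illustrative): if the cache folder already
| # holds '001_sigma_karras_30steps_rho7.0_min0.03_max14.6_zero_final.pt' and
| # '002_...', the next call returns 3 and build_sigma_cache_filename() will
| # emit an '003_...' file.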
|
|
|
| def get_sigma_with_cache(self, steps, sigma_min, sigma_max, rho=7.0, device='cpu',
|
| schedule_type='karras', decay_pattern=None, cache_dir=None, cache_file=None,
|
| suffix=None, ext=None, mode=None, cache_key = None):
|
| self.steps = steps
|
| self.sigma_min = sigma_min
|
| self.sigma_max = sigma_max
|
| self.rho = rho
|
| self.device = device
|
| self.schedule_type = schedule_type
|
| self.decay_pattern = decay_pattern
|
| self.cache_dir = cache_dir
|
| self.cache_file = cache_file
|
| self.suffix = suffix
|
| self.ext = ext
|
| self.mode = mode
|
| self.cache_key = cache_key
|
|
|
|
|
| # Check the in-memory cache first, then fall back to the on-disk loader.
|
| cached_sigmas = self.sigma_cache.get(cache_key)
|
| if cached_sigmas is None:
|
| cached_sigmas = self._safe_sigma_loader(cache_key)
|
|
|
| if cached_sigmas is not None:
|
| return cached_sigmas
|
|
|
|
|
| if self.is_sigma_randomized():
|
| _, _, _, sigmas = self._generate_sigmas(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern)
|
| self.sigma_cache[cache_key] = sigmas
|
| return sigmas
|
|
|
|
|
| if self.loaded_sigmas is None:
|
| _, _, _, sigmas = self._generate_sigmas(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern)
|
| self.loaded_sigmas = sigmas
|
| self.sigma_cache[cache_key] = sigmas
|
| return sigmas
|
|
|
|
|
| if mode == 'prepass':
|
| self.cache_file = self.prepass_save_file
|
| elif mode == 'final':
|
| self.cache_file = self.final_save_file
|
| else:
|
| self.cache_file = self.build_sigma_cache_filename(steps, sigma_min, sigma_max, rho=rho, schedule_type=schedule_type, decay_pattern=decay_pattern, cache_dir=cache_dir, suffix=suffix, ext=ext)
|
|
|
|
|
| if mode in ['prepass', 'final'] and self.load_prepass_sigmas:
|
| loaded_sigmas = self.load_sigmas_with_hash_validation(
|
| filename=self.cache_file,
|
| steps=steps,
|
| sigma_min=sigma_min,
|
| sigma_max=sigma_max,
|
| rho=rho,
|
| device=device,
|
| schedule_type=schedule_type,
|
| decay_pattern=decay_pattern,
|
| cache_key = cache_key
|
| )
|
|
|
| if loaded_sigmas is not None:
|
| self.loaded_sigmas = loaded_sigmas
|
| self.sigma_cache[cache_key] = loaded_sigmas
|
| return loaded_sigmas.to(device)
|
| else:
|
| self.log("[Cache Recovery] Cache load failed. Recalculating sigma schedule.")
|
|
|
| # Cache miss: recalculate so the return below is always defined.
|
| self.log(f"[Cache Miss] Recalculating sigma schedule for: {self.cache_file}")
|
| _, _, _, sigmas = self._generate_sigmas(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern)
|
| self.sigma_cache[cache_key] = sigmas
|
| return sigmas
|
|
|
|
|
| def load_sigmas_with_hash_validation(self, filename, steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, save_data=None, cache_key = None, suffix=None):
|
| if self.load_prepass_sigmas:
|
| if cache_key:
|
| try:
|
| loaded_data = torch.load(filename, map_location=self.device)
|
| self.loaded_sigmas = loaded_data['sigma_values'].to(self.device)
|
| loaded_hash = loaded_data['sigma_hash']
|
|
|
| expected_hash = self.generate_sigma_hash(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, save_data, suffix)
|
|
|
| if loaded_hash != expected_hash:
|
| self.log(f"[Sigma Validator] Hash mismatch. Expected: {expected_hash}, Found: {loaded_hash}. Recalculating.")
|
| return None
|
| else:
|
| self.log(f"[Sigma Validator] Hash validated successfully for file: {filename}")
|
| return self.loaded_sigmas
|
|
|
| except Exception as e:
|
| self.log(f"[Cache Recovery] Sigma cache invalid or missing ({e}). Recalculating sigmas.")
|
| _, _, _, sigmas = self._generate_sigmas(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern)
|
| return sigmas
|
|
|
|
|
| def generate_sigma_hash(self, steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, save_data=None, suffix=None):
|
| data_string = f'{steps}_{sigma_min}_{sigma_max}_{rho}_{device}_{schedule_type}_{decay_pattern}_{suffix}'
|
| hash_object = hashlib.sha256(data_string.encode())
|
| return hash_object.hexdigest()[:12]
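|
| # Hash sketch (argument values are illustrative): generate_sigma_hash(30,
| # 0.03, 14.6, 7.0, 'cpu', 'karras', 'zero') SHA-256 hashes the string
| # '30_0.03_14.6_7.0_cpu_karras_zero_None' and keeps the first 12 hex chars.
| # Note that save_data is accepted but does not affect the hash.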
|
|
|
| def _generate_sigmas(self, steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern=None, decay_mode=None, tail_steps=None):
|
| scheduler_func = self.scheduler_registry.get(schedule_type)
|
|
|
| if scheduler_func is None:
|
| raise ValueError(f"Unknown schedule type: {schedule_type}")
|
|
|
| tails, decay, extras, sigmas = self.call_scheduler_function(
|
| scheduler_func,
|
| steps=steps,
|
| sigma_min=sigma_min,
|
| sigma_max=sigma_max,
|
| rho=rho,
|
| device=device,
|
| decay_pattern=decay_pattern,
|
| decay_mode=decay_mode,
|
| tail_steps=tail_steps
|
| )
|
|
|
| return tails, decay, extras, sigmas
|
|
|
|
|
|
|
| def initialize_generation_filename(self, folder=None, base_name="generation_log", ext="txt"):
|
| """
|
| Initialize the log filename early so it can be used throughout the process.
|
| """
|
| if folder is None:
|
| folder = self.settings.get('log_save_directory', 'modules/sd_simple_kes/image_generation_data')
|
| folder = os.path.abspath(os.path.normpath(folder))
|
|
|
| os.makedirs(folder, exist_ok=True)
|
| timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
|
| self.log_filename = os.path.join(folder, f"{base_name}_{timestamp}.{ext}")
|
|
|
| def save_generation_settings(self):
|
| """
|
| Save the generation log with configurable directory, base name, and extension.
|
|
|
| Parameters:
|
| - folder (str): Optional custom directory to save the log file.
|
| - base_name (str): The base name for the file (default is 'generation_log').
|
| - ext (str): The file extension to use (default is 'txt').
|
| """
|
| with open(self.log_filename, "w", encoding = 'utf-8') as f:
|
| for line in self.logger.log_buffer:
|
| f.write(f"{line}\n")
|
| for line in self.logger.prepass_log_buffer:
|
| f.write(f"{line}\n")
|
| self.log(f"[SimpleKEScheduler] Generation settings saved to {self.log_filename}")
|
|
|
| self.logger.log_buffer.clear()
|
| self.logger.prepass_log_buffer.clear()
|
|
|
| def save_image_plot(self, sigs, i):
|
| graph_plot = plot_sigma_sequence(
|
| sigs[:i + 1],
|
| i,
|
| self.log_filename,
|
| self.graph_save_directory,
|
| self.graph_save_enable
|
| )
|
| self.log(f"Sigma sequence plot saved to {graph_plot}")
|
|
|
|
|
|
|
|
|
| def apply_global_randomization(self):
|
| """Force randomization for all eligible settings by enabling their _rand flags and re-randomizing the base values."""
|
|
|
| for key in list(self.settings.keys()):
|
| if key.endswith("_rand_min") or key.endswith("_rand_max"):
|
| base_key = key.rsplit("_rand_", 1)[0]
|
| rand_flag_key = f"{base_key}_rand"
|
| self.settings[rand_flag_key] = True
|
|
|
| if base_key not in self.settings:
|
| raise KeyError(f"[apply_global_randomization] Missing required key: {base_key}")
|
|
|
| default_val = self.settings[base_key]
|
| randomized_val = self.get_random_or_default(base_key, default_val)
|
| self.settings[base_key] = randomized_val
|
| setattr(self, base_key, randomized_val)
|
|
|
| def get_randomization_type(self, key_prefix):
|
| """
|
| Retrieves the randomization type for a given key, with fallback to 'asymmetric' if missing.
|
| """
|
| randomization_type_raw = self.settings.get(f'{key_prefix}_randomization_type', 'asymmetric')
|
| randomization_type = self.RANDOMIZATION_TYPE_ALIASES.get(randomization_type_raw.lower(), 'asymmetric')
|
| return randomization_type
|
|
|
| def get_randomization_percent(self, key_prefix):
|
| """
|
| Retrieves the randomization percent for a given key, with fallback to 0.2 if missing.
|
| """
|
| return self.settings.get(f'{key_prefix}_randomization_percent', 0.2)
|
|
|
|
|
| def get_random_between_min_max(self, key_prefix, default_value):
|
| """
|
| Picks a random value between _rand_min and _rand_max if _rand is True.
|
| Otherwise, returns the base value.
|
| """
|
| randomize_flag = self.settings.get(f'{key_prefix}_rand', False)
|
|
|
| if randomize_flag:
|
| rand_min = self.settings.get(f'{key_prefix}_rand_min', default_value)
|
| rand_max = self.settings.get(f'{key_prefix}_rand_max', default_value)
|
|
|
| if rand_min == rand_max:
|
| self.log(f"[Random Range] {key_prefix}: min and max are equal ({rand_min}). Using single value.")
|
| return rand_min
|
|
|
| value = random.uniform(rand_min, rand_max)
|
| self.log(f"[Random Range] {key_prefix}: Picked random value {value} between {rand_min} and {rand_max}")
|
| return value
|
| else:
|
| self.log(f"[Random Range] {key_prefix}: Randomization is OFF. Using base value {default_value}")
|
| return default_value
|
|
|
| def get_random_by_type(self, key_prefix, default_value):
|
| randomization_enabled = self.settings.get(f'{key_prefix}_enable_randomization_type', False)
|
|
|
| if not randomization_enabled:
|
| self.log(f"[Randomization Type] {key_prefix}: Randomization type is OFF. Using base value {default_value}")
|
| return default_value
|
|
|
| randomization_type = self.get_randomization_type(key_prefix)
|
| randomization_percent = self.get_randomization_percent(key_prefix)
|
|
|
| if randomization_type == 'symmetric':
|
| rand_min = default_value * (1 - randomization_percent)
|
| rand_max = default_value * (1 + randomization_percent)
|
| self.log(f"[Symmetric Randomization] {key_prefix}: Range {rand_min} to {rand_max}")
|
|
|
| elif randomization_type == 'asymmetric':
|
| rand_min = default_value * (1 - randomization_percent)
|
| rand_max = default_value * (1 + (randomization_percent * 2))
|
| self.log(f"[Asymmetric Randomization] {key_prefix}: Range {rand_min} to {rand_max}")
|
|
|
| elif randomization_type == 'logarithmic':
|
| if default_value <= 0 or randomization_percent >= 1:
|
| self.log(f"[Logarithmic Randomization] {key_prefix}: Requires a positive value and percent < 1. Using base value {default_value}")
|
| return default_value
|
| rand_min = math.log(default_value * (1 - randomization_percent))
|
| rand_max = math.log(default_value * (1 + randomization_percent))
|
| value = math.exp(random.uniform(rand_min, rand_max))
|
| self.log(f"[Logarithmic Randomization] {key_prefix}: Log-space randomization resulted in {value}")
|
| return value
|
|
|
| elif randomization_type == 'exponential':
|
| rand_min = default_value * (1 - randomization_percent)
|
| rand_max = default_value * (1 + randomization_percent)
|
| base_value = random.uniform(rand_min, rand_max)
|
| value = math.exp(base_value)
|
| self.log(f"[Exponential Randomization] {key_prefix}: Randomized exponential value {value}")
|
| return value
|
|
|
| else:
|
| self.log(f"[Randomization Type] {key_prefix}: Invalid randomization type {randomization_type}. Using base value.")
|
| return default_value
|
|
|
| value = random.uniform(rand_min, rand_max)
|
|
|
| self.log(f"[Randomization Type] {key_prefix}: Randomized value {value}")
|
| return value
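|
| # Worked example (assumed setting, e.g. rho=7.0 with the default 0.2 percent):
| # 'symmetric'  -> uniform over [7.0 * 0.8, 7.0 * 1.2] = [5.6, 8.4]
| # 'asymmetric' -> uniform over [7.0 * 0.8, 7.0 * 1.4] = [5.6, 9.8]
| # 'logarithmic' samples uniformly in log-space over the symmetric range,
| # which biases results toward the lower end of [5.6, 8.4].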
|
|
|
| def get_random_or_default(self, key_prefix, default_value):
|
| """
|
| Selects randomization method based on active flags:
|
| - If both enabled → prioritize randomization type (or min/max if you prefer).
|
| - If only one enabled → apply that one.
|
| - If neither → return default value.
|
| """
|
| rand_type_enabled = self.settings.get(f'{key_prefix}_enable_randomization_type', False)
|
| min_max_enabled = self.settings.get(f'{key_prefix}_rand', False)
|
|
|
| if rand_type_enabled and min_max_enabled:
|
| self.log(f"[Randomization Policy] Both min/max and randomization type enabled for {key_prefix}. System will prioritize randomization type.")
|
| result_value = self.get_random_by_type(key_prefix, default_value)
|
|
|
| elif rand_type_enabled:
|
| result_value = self.get_random_by_type(key_prefix, default_value)
|
| self.log(f"[Randomization] {key_prefix}: Applied randomization type. Final value: {result_value}")
|
|
|
| elif min_max_enabled:
|
| result_value = self.get_random_between_min_max(key_prefix, default_value)
|
| self.log(f"[Randomization] {key_prefix}: Applied min/max randomization. Final value: {result_value}")
|
|
|
| else:
|
| result_value = default_value
|
| self.log(f"[Randomization] {key_prefix}: No randomization applied. Using default value: {result_value}")
|
|
|
| return result_value
|
|
|
|
|
| def resolve_blend_weights(self, blend_weights, blending_style):
|
| if blending_style == 'softmax':
|
|
|
| blend_weights = torch.tensor(blend_weights)
|
| normalized_weights = torch.softmax(blend_weights, dim=0)
|
| return normalized_weights.tolist()
|
|
|
| elif blending_style == 'explicit':
|
|
|
| return blend_weights
|
|
|
| else:
|
| raise ValueError(f"Unknown blending_style: {blending_style}")
|
|
|
| def extract_scalar(self, value):
|
| if isinstance(value, torch.Tensor):
|
| if value.numel() > 1:
|
| return value.mean().item()
|
| else:
|
| return value.item()
|
| return value
|
|
|
| def _call_legacy_mode(self, schedule_type):
|
|
|
| if schedule_type not in ['karras', 'exponential']:
|
| self.log(f"[Legacy Mode] Unsupported schedule_type: {schedule_type}")
|
| return
|
|
|
|
|
| target_attr = f"sigmas_{schedule_type}"
|
|
|
| scheduler_func = self.scheduler_registry.get(schedule_type)
|
|
|
| tails, decay, extras, sigmas = self.call_scheduler_function(
|
| scheduler_func,
|
| steps=self.steps,
|
| sigma_min=self.sigma_min,
|
| sigma_max=self.sigma_max,
|
| rho=self.rho,
|
| device=self.device,
|
| decay_pattern=self.decay_pattern
|
| )
|
|
|
|
|
| setattr(self, target_attr, sigmas)
|
|
|
| self.log(f"[Legacy Mode] Loaded sigma sequence for {schedule_type}. Assigned to self.{target_attr}")
|
|
|
|
|
| def blend_sigma_sequence(self, sigmas_karras=None, sigmas_exponential=None, pre_pass=False, blend_methods=None, blend_weights=None):
|
|
|
| active_methods = [
|
| method for method, config in self.blend_method_dict.items() if config.get('weight', 1.0) > 0.0
|
| ]
|
|
|
| if not active_methods:
|
| self.log("[Blend Config] All weights are zero. Falling back to default blend (karras + exponential).")
|
| '''
|
| # Values set in init, placed here for reference
|
| self.blend_method_dict = {
|
| 'karras': {'weight': 1.0, 'decay_pattern': 'zero', 'decay_mode': 'append', 'tail_steps': 1},
|
| 'exponential': {'weight': 1.0, 'decay_pattern': 'zero', 'decay_mode': 'append', 'tail_steps': 1}
|
| }
|
| '''
|
| active_methods = list(self.blend_method_dict.keys())
|
|
|
| self.blend_methods = active_methods
|
|
|
|
|
| self.blend_weights = [self.blend_method_dict[m]['weight'] for m in self.blend_methods]
|
|
|
|
|
| if len(self.blend_methods) == 0:
|
| raise ValueError("[SimpleKEScheduler] No active schedulers selected. Please check your blend configuration.")
|
|
|
|
|
|
|
| if len(self.blend_methods) == 1:
|
| self.log(f"[Blend] Only one active scheduler: {self.blend_methods[0]}. Skipping blending, using it directly.")
|
| self.sigs = self.sigma_sequences[self.blend_methods[0]]['sigmas']
|
|
|
|
|
| if len(self.blend_methods) == 2:
|
| self.blending_mode = 'smooth_blend'
|
| elif len(self.blend_methods) > 2:
|
| self.blending_mode = 'weights'
|
|
|
|
|
| if not self.allow_step_expansion and self.auto_mode_enabled:
|
| self.auto_mode_enabled = False
|
| self.log("[Auto Mode] Step expansion disallowed. Auto mode forcibly disabled.")
|
|
|
| self.progress = torch.linspace(0, 1, len(self.sigs)).to(self.device)
|
| self.blended_sigmas = []
|
| self.change_log = []
|
| self.relative_converged = False
|
| self.max_converged = False
|
| self.delta_converged = False
|
| self.early_stop_triggered = False
|
|
|
| """
|
| Computes the blended sigma sequence using adaptive step sizes, dynamic blend factors,
|
| and noise scaling across the progress of the diffusion process.
|
|
|
| This method blends sigma values from the Karras and Exponential schedules using
|
| a smooth, progress-dependent interpolation. It applies adaptive scaling based on
|
| step size and noise scale factors to each sigma in the sequence.
|
|
|
| Parameters:
|
| -----------
|
| self.sigs : torch.Tensor (attribute, not an argument)
|
| A pre-allocated tensor where the computed sigma sequence is stored.
|
| Its length must match the sigma schedules being blended.
|
|
|
| sigmas_karras : torch.Tensor, optional
|
| The sigma sequence generated using the Karras schedule.
|
|
|
| sigmas_exponential : torch.Tensor, optional
|
| The sigma sequence generated using the Exponential schedule.
|
|
|
| Returns:
|
| --------
|
| sigs : torch.Tensor
|
| The final blended and scaled sigma sequence.
|
|
|
| Notes:
|
| ------
|
| - This method is used in both the prepass and final pass of the scheduler.
|
| - The progress tensor is computed linearly from 0 to 1 over the length of the sequence.
|
| - The method uses class attributes for step size factors, blend factors, and noise scaling.
|
| - This method modifies `sigs` in place.
|
| """
|
| if getattr(self, 'sigmas_exponential', None) is None:
|
| self._call_legacy_mode(schedule_type='exponential')
|
|
|
| if getattr(self, 'sigmas_karras', None) is None:
|
| self._call_legacy_mode(schedule_type='karras')
|
|
|
| self.prepass_blended_sigmas = []
|
| self.blended_sigma = None
|
| self.blended_sigmas=[]
|
| for i in range(len(self.sigs)):
|
| if self.step_progress_mode == "linear":
|
| progress_value = self.progress[i]
|
| elif self.step_progress_mode == "exponential":
|
| progress_value = self.progress[i] ** self.exp_power
|
| elif self.step_progress_mode == "logarithmic":
|
| progress_value = torch.log1p(self.progress[i] * (torch.exp(torch.tensor(1.0)) - 1))
|
| elif self.step_progress_mode == "sigmoid":
|
| progress_value = 1 / (1 + torch.exp(-12 * (self.progress[i] - 0.5)))
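| # Slope 12 and midpoint 0.5 give an S-curve that stays near 0 early on,
| # crosses 0.5 halfway, and saturates near 1 late (e.g. p=0.25 -> ~0.047,
| # p=0.75 -> ~0.953); these constants are the fixed choices used here.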
|
| else:
|
| progress_value = self.progress[i]
|
|
|
| self.dynamic_blend_factor = self.start_blend * (1 - self.progress[i]) + self.end_blend * self.progress[i]
|
| self.smooth_blend = torch.sigmoid((self.dynamic_blend_factor - self.blend_midpoint) * self.smooth_blend_factor)
|
| self.noise_scale = self.initial_noise_scale * (1 - self.progress[i]) + self.final_noise_scale * self.progress[i] * self.noise_scale_factor
|
| self.step_size = self.initial_step_size * (1 - progress_value) + self.final_step_size * progress_value * self.step_size_factor
|
| if self.blending_mode == 'default':
|
|
|
| self.blended_sigma = self.sigmas_karras[i] * (1 - self.smooth_blend) + self.sigmas_exponential[i] * self.smooth_blend
|
|
|
| if self.blending_mode == 'smooth_blend' or (self.blending_mode == 'auto' and len(self.blend_methods) == 2):
|
|
|
| sigma_seq_a = self.sigma_sequences[self.blend_methods[0]]['sigmas']
|
| sigma_seq_b = self.sigma_sequences[self.blend_methods[1]]['sigmas']
|
|
|
|
|
| self.blended_sigma = sigma_seq_a[i] * (1 - self.smooth_blend) + sigma_seq_b[i] * self.smooth_blend
|
|
|
|
|
| elif self.blending_mode == 'weights' or (self.blending_mode == 'auto' and len(self.blend_methods) > 2):
|
|
|
|
|
| if self.blend_weights is None:
|
| self.blend_weights = [1.0] * len(self.all_sigmas)
|
| if self.blending_style is None:
|
| self.blending_style = 'softmax'  # must match the styles accepted by resolve_blend_weights()
|
|
|
|
|
| resolved_blend_weights = self.resolve_blend_weights(self.blend_weights, self.blending_style)
|
|
|
| weighted_sum = sum(w * self.extract_scalar(s[i]) for w, s in zip(resolved_blend_weights, self.all_sigmas))
|
|
|
|
|
| total_weight = sum(resolved_blend_weights)
|
| self.blended_sigma = weighted_sum / total_weight
|
|
|
| for s in self.all_sigmas:
|
| self.log(f"[DEBUG]sigma sequence shape: {s.shape}")
|
|
|
|
|
| self.sigs[i] = self.blended_sigma * self.step_size * self.noise_scale
|
| if i > 0:
|
| self.change = torch.abs(self.sigs[i] - self.sigs[i - 1])
|
| self.change_log.append(self.extract_scalar(self.change))
|
| if self.blended_sigma != 0:
|
| relative_sigma_progress = (self.blended_sigma - self.sigs[-1].item()) / self.blended_sigma
|
| else:
|
| relative_sigma_progress = 0
|
| recent_changes = torch.abs(torch.tensor(self.change_log[-5:])) if self.change_log else torch.zeros(1)
|
| max_change = torch.max(recent_changes).item()
|
| mean_change = torch.mean(recent_changes).item()
|
|
|
| self.delta_change = abs(max_change - mean_change)
|
|
|
| self.blended_sigmas.append(self.extract_scalar(self.blended_sigma))
|
|
|
|
|
| self.relative_converged = relative_sigma_progress < 0.05
|
|
|
| self.max_converged = max_change < self.early_stopping_threshold
|
|
|
| self.delta_converged = self.delta_change < self.recent_change_convergence_delta
|
|
|
| if pre_pass:
|
| self.prepass_blended_sigmas=self.blended_sigmas.copy()
|
| self.prepass_blended_sigma = self.blended_sigma
|
| if i >= 2:
|
|
|
| sigma_rate = abs(self.prepass_blended_sigmas[i] - self.prepass_blended_sigmas[i - 1])
|
| previous_sigma_rate = abs(self.prepass_blended_sigmas[i - 1] - self.prepass_blended_sigmas[i - 2])
|
| if sigma_rate > previous_sigma_rate:
|
| self.prepass_log(f"Sigma decline is slowing down → possible plateau at step {i+1}.")
|
|
|
| if i == 0:
|
| self.prepass_log("\n--- Starting Pre-Pass Blending ---\n")
|
| step_label = "Prepass First Step"
|
| elif i == len(self.sigs) - 1:
|
| step_label = "Prepass Last Step"
|
| else:
|
| step_label = None
|
|
|
| if step_label:
|
| self.prepass_log(f"[{step_label} - Step {i}/{len(self.sigs)}] Prepass Blended Sigma: {self.prepass_blended_sigma:.6f}, Final Sigma: {self.sigs[i]:.6f}")
|
| self.prepass_log(f"{step_label} Delta Converged: {self.delta_converged} delta_change: {self.delta_change:.6f}, Target Default Settings:{self.recent_change_convergence_delta}")
|
|
|
|
|
| if i > self.safety_minimum_stop_step and len(self.change_log) > 10:
|
|
|
| self.blended_tensor = torch.tensor(self.prepass_blended_sigmas)
|
| if str(self.device) == 'cpu':  # normalize: self.device may be a torch.device or a plain string
|
| self.sigma_variance = np.var(self.prepass_blended_sigmas)
|
| else:
|
| self.sigma_variance = torch.var(self.sigs).item()
|
|
|
| self.min_sigma_threshold = self.sigma_variance * self.sigma_variance_scale
|
| self.prepass_log(f"\n--- Early Stopping Evaluation at Step {i} ---")
|
| self.prepass_log(f"Current Blended Prepass Sigma: {self.prepass_blended_sigma:.6f}")
|
| self.prepass_log(f"Sigma Variance: {self.sigma_variance:.6f}")
|
| self.prepass_log(f"Relative Sigma Progress: {relative_sigma_progress:.6f}")
|
| self.prepass_log(f"Max Recent Sigma Change: {max_change:.6f}")
|
| self.prepass_log(f"Mean Recent Sigma Change: {mean_change:.6f}")
|
|
|
|
|
|
|
| if self.prepass_blended_sigma > self.min_sigma_threshold:
|
| self.prepass_log(f"Prepass Blended Sigma {self.prepass_blended_sigma:.6f} exceeds min sigma threshold {self.min_sigma_threshold:.6f} → Continuing.\n")
|
|
|
|
|
| if self.early_stopping_method == "mean":
|
| mean_change = sum(self.change_log) / len(self.change_log)
|
| if mean_change < self.early_stopping_threshold:
|
| skipped_steps = len(self.sigs) - (i)
|
| self.prepass_log(f"Early stopping triggered by mean at step {i}. Mean change: {mean_change:.6f}. Steps used: {i}/{len(self.sigs)}, steps skipped: {skipped_steps}")
|
|
|
| elif self.early_stopping_method == "max":
|
|
|
| if max_change < self.early_stopping_threshold:
|
| skipped_steps = len(self.sigs) - (i)
|
| self.prepass_log(f"Early stopping triggered by mean at step {i}. Mean change: {max_change:.6f}. Steps used: {i}/{len(self.sigs)}, steps skipped: {skipped_steps}")
|
|
|
| elif self.early_stopping_method == "sum":
|
| stable_steps = sum(
|
| 1 for j in range(1, len(self.change_log))
|
| if abs(self.change_log[j]) < self.early_stopping_threshold * abs(self.sigs[j])
|
| )
|
| if stable_steps >= 0.8 * len(self.change_log):
|
| skipped_steps = len(self.sigs) - (i)
|
| self.prepass_log(f"Early stopping triggered by sum at step {i}. Stable steps: {stable_steps}/{len(self.change_log)}. Steps used: {i}/{len(self.sigs)}, steps skipped: {skipped_steps}")
|
|
|
| if self.relative_converged and self.max_converged and self.delta_converged:
|
| self.early_stop_triggered = True
|
| self.prepass_log(f"\n--- Early Stopping Evaluation at Step {i+1} ---")
|
| self.prepass_log(f"Relative Sigma Progress: {relative_sigma_progress:.6f}")
|
| self.prepass_log(f"Max Recent Sigma Change: {max_change:.6f}")
|
| self.prepass_log(f"Mean Recent Sigma Change: {mean_change:.6f}")
|
| self.prepass_log(f"Delta Change: {delta_change:.6f} (Target: {self.recent_change_convergence_delta})")
|
| self.prepass_log(f"Early stopping criteria met at step {i+1} based on all convergence checks.")
|
| self.predicted_stop_step = i
|
|
|
| self.save_image_plot(self.sigs, i)
|
| break
|
|
|
|
|
|
|
| if not pre_pass:
|
|
|
| if i == 0:
|
| step_label = "First Step"
|
| self.log("\n" + "=" * 10 + "\n[Start of Sigma Sequence Logging]\n" + "=" * 10)
|
| elif i == len(self.sigs) // 2:
|
| step_label = "Middle Step"
|
| elif i == len(self.sigs) - 1:
|
| step_label = "Last Step"
|
| else:
|
| step_label = None
|
|
|
| if step_label:
|
| self.log(f"[{step_label} - Step {i}/{len(self.sigs)}]"
|
| f"\nStep Size: {self.step_size:.6f}"
|
| f"\nDynamic Blend Factor: {self.dynamic_blend_factor:.6f}"
|
| f"\nNoise Scale: {self.noise_scale:.6f}"
|
| f"\nSmooth Blend: {self.smooth_blend:.6f}"
|
| f"\nBlended Sigma: {self.blended_sigma:.6f}"
|
| f"\nFinal Sigma: {self.sigs[i]:.6f}")
|
|
|
| if i == len(self.sigs) - 1:
|
| self.log("\n" + "=" * 10 + "\n[End of Sigma Sequence Logging]\n" + "=" * 10)
|
|
|
|
|
|
|
|
| if i > self.safety_minimum_stop_step and len(self.change_log) > 5:
|
| final_target_sigma = self.sigs[-1].item()
|
| if self.blended_sigma != 0:
|
| relative_sigma_progress = (self.blended_sigma - final_target_sigma) / self.blended_sigma
|
| else:
|
| relative_sigma_progress = 0
|
|
|
| self.sigma_variance = torch.var(self.sigs).item() if str(self.device) != 'cpu' else np.var(self.blended_sigmas)
|
| self.log(f"Sigma Variance: {self.sigma_variance:.6f}")
|
| if self.graph_save_enable:
|
| self.save_image_plot(self.sigs, i)
|
|
|
|
|
|
|
| if not self.auto_mode_enabled:
|
| if not pre_pass:
|
| if self.apply_tail_steps:
|
| for i, tail in enumerate(self.all_tails):
|
| if tail is not None:
|
| self.log(f"Appending tail from method: {self.blend_methods[i]}")
|
| self.sigs = torch.cat([self.sigs, tail])
|
|
|
| if self.apply_decay_tail:
|
| for i, decay in enumerate(self.all_decays):
|
| if decay is not None:
|
| self.log(f"Appending decay from method: {self.blend_methods[i]}")
|
| self.sigs = torch.cat([self.sigs, decay])
|
|
|
| if self.apply_progressive_decay:
|
| progressive_decay = None
|
| total_weight = 0
|
|
|
| for w, decay in zip(self.blend_weights, self.all_decays):  # raw weights: resolved weights only exist in 'weights' blending mode
|
| if decay is not None:
|
| decay = decay[:len(self.sigs)]
|
| if progressive_decay is None:
|
| progressive_decay = w * decay
|
| else:
|
| progressive_decay += w * decay
|
| total_weight += w
|
|
|
| if progressive_decay is not None and total_weight > 0:
|
| progressive_decay /= total_weight
|
| self.log("Applying progressive decay to sigma sequence.")
|
| self.sigs = self.sigs * progressive_decay
|
|
|
| if self.apply_blended_tail:
|
| blended_tail = None
|
| total_weight = 0
|
|
|
| for w, tail in zip(self.blend_weights, self.all_tails):  # raw weights, matching blend_tail()
|
| if tail is not None:
|
| if blended_tail is None:
|
| blended_tail = w * tail
|
| else:
|
| blended_tail += w * tail
|
| total_weight += w
|
|
|
| if blended_tail is not None and total_weight > 0:
|
| blended_tail /= total_weight
|
| self.log("Appending blended tail to sigma sequence.")
|
| self.sigs = torch.cat([self.sigs, blended_tail])
|
|
|
| else:
|
|
|
| if len(self.sigs) > self.steps:
|
| self.auto_stabilization_sequence = []
|
| self.log(f"[Auto Mode] Sigma sequence length {len(self.sigs)} exceeds requested steps {self.steps}. Disabling auto stabilization.")
|
| self.auto_mode_enabled = False
|
| self.sigs = self.sigs[:self.steps]
|
| return self.sigs
|
| self.run_auto_stabilization()
|
|
|
| if pre_pass and self.early_stop_triggered:
|
| return self.sigs[:self.predicted_stop_step]
|
| else:
|
| return self.sigs
|
| def run_auto_stabilization(self):
|
|
|
| if not self.allow_step_expansion:
|
| self.log("[Auto Mode] Step expansion is disabled by configuration. Skipping auto stabilization.")
|
| return self.sigs
|
| unstable = self.detect_sequence_instability()
|
|
|
| if not unstable:
|
| self.log("[Auto Mode] Sigma sequence is already stable.")
|
| return
|
|
|
| self.log("[Auto Mode] Detected instability in sigma sequence. Starting stabilization sequence.")
|
|
|
| for method in self.auto_stabilization_sequence:
|
| if not unstable:
|
| self.log(f"[Auto Mode] Sequence stabilized after {method}. Stopping further corrections.")
|
| break
|
|
|
| if method == 'smooth_interpolation':
|
| unstable = self.smooth_interpolation()
|
|
|
| elif method == 'append_tail':
|
| unstable = self.append_tail()
|
|
|
| elif method == 'blend_tail':
|
| unstable = self.blend_tail()
|
|
|
| elif method == 'apply_decay':
|
| unstable = self.apply_decay()
|
|
|
| elif method == 'progressive_decay':
|
| unstable = self.progressive_decay()
|
|
|
| else:
|
| self.log(f"[Auto Mode] Unknown stabilization method: {method}")
|
| def detect_sequence_instability(self):
|
| delta_sigmas = self.sigs[:-1] - self.sigs[1:]
|
| second_deltas = torch.diff(delta_sigmas)
|
|
|
| steep_drop_detected = torch.any(delta_sigmas > self.auto_tail_threshold)
|
| jaggedness_score = torch.var(second_deltas[-5:]) if len(second_deltas) >= 5 else 0
|
| jagged_transition_detected = jaggedness_score > self.jaggedness_threshold
|
|
|
| if steep_drop_detected:
|
| self.log(f"[Auto Mode] Steep drop detected. Max drop: {torch.max(delta_sigmas).item():.6f}")
|
| if jagged_transition_detected:
|
| self.log(f"[Auto Mode] Jagged transition detected. Jaggedness score: {jaggedness_score:.6f}")
|
|
|
| return steep_drop_detected or jagged_transition_detected
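|
| # Illustrative reading (thresholds come from config and are assumed here):
| # for sigmas [10.0, 9.0, 2.0, 1.9] the first differences are [1.0, 7.0, 0.1],
| # so a drop of 7.0 against an auto_tail_threshold of, say, 2.0 flags a steep
| # drop; a high variance in the second differences flags a jagged transition.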
|
| def smooth_interpolation(self):
|
| self.log("[Auto Mode] Applying smooth interpolation to last 5 steps.")
|
| if len(self.sigs) >= 5:
|
| start = self.sigs[-6].item()
|
| end = self.sigs[-1].item()
|
| interpolated = torch.linspace(start, end, steps=6, device=self.device)[1:]
|
| self.sigs[-5:] = interpolated
|
|
|
| return self.detect_sequence_instability()
|
|
|
| def append_tail(self):
|
| self.log("[Auto Mode] Attempting to append available tail.")
|
| if hasattr(self, 'all_tails') and self.all_tails:
|
| for tail in self.all_tails:
|
| if tail is not None:
|
| tail = tail.to(self.device)
|
|
|
| if tail.shape[0] > self.sigs.shape[0]:
|
| tail = tail[:len(self.sigs)]
|
| self.sigs = torch.cat([self.sigs, tail])
|
| self.log("[Auto Mode] Appended tail to sigma sequence.")
|
| break
|
|
|
| return self.detect_sequence_instability()
|
|
|
|
|
| def blend_tail(self):
|
| if not hasattr(self, 'all_tails') or not self.all_tails:
|
| self.log("[Auto Mode] No available tails to blend.")
|
| return self.detect_sequence_instability()
|
|
|
| self.log("[Auto Mode] Attempting to blend multiple tails.")
|
| blended_tail = None
|
| total_weight = 0
|
|
|
| for w, tail in zip(self.blend_weights, self.all_tails):
|
| if tail is not None:
|
| tail = tail.to(self.device)
|
|
|
|
|
| if tail.shape[0] > self.sigs.shape[0]:
|
| tail = tail[:len(self.sigs)]
|
|
|
| if blended_tail is None:
|
| blended_tail = w * tail
|
| else:
|
| blended_tail += w * tail
|
| total_weight += w
|
|
|
| if blended_tail is not None and total_weight > 0:
|
| blended_tail /= total_weight
|
| self.sigs = torch.cat([self.sigs, blended_tail])
|
| self.log("[Auto Mode] Appended blended tail to sigma sequence.")
|
|
|
| return self.detect_sequence_instability()
|
|
|
|
|
| def apply_decay(self):
|
| self.log("[Auto Mode] Attempting to append decay tails.")
|
| if hasattr(self, 'all_decays') and self.all_decays:
|
| for decay in self.all_decays:
|
| if decay is not None:
|
| decay = decay.to(self.device)
|
|
|
|
|
| if decay.shape[0] > self.sigs.shape[0]:
|
| decay = decay[:len(self.sigs)]
|
|
|
| self.sigs = torch.cat([self.sigs, decay])
|
| self.log("[Auto Mode] Appended decay tail to sigma sequence.")
|
| break
|
|
|
| return self.detect_sequence_instability()
|
|
|
|
|
| def progressive_decay(self):
|
| self.log("[Auto Mode] Applying progressive decay to sigma sequence.")
|
| progressive_decay = None
|
| total_weight = 0
|
|
|
| for w, decay in zip(self.blend_weights, self.all_decays):
|
| if decay is not None:
|
| decay = decay.to(self.device)
|
|
|
|
|
| if decay.shape[0] != self.sigs.shape[0]:
|
| decay = decay.view(1, 1, -1)
|
| decay = F.interpolate(decay, size=self.sigs.shape[0], mode='linear', align_corners=False)
|
| decay = decay.view(-1)
|
|
|
| if progressive_decay is None:
|
| progressive_decay = w * decay
|
| else:
|
| progressive_decay += w * decay
|
|
|
| total_weight += w
|
|
|
| if progressive_decay is not None and total_weight > 0:
|
| progressive_decay /= total_weight
|
| self.sigs = self.sigs * progressive_decay
|
| self.log("[Auto Mode] Applied progressive decay to sigma sequence.")
|
|
|
| return self.detect_sequence_instability()
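|
| # Resizing sketch: a short decay such as torch.tensor([1.0, 0.8, 0.5, 0.2, 0.0])
| # is viewed as (1, 1, 5) and linearly interpolated by F.interpolate to the
| # length of self.sigs (e.g. 30), giving a smooth curve spanning roughly 1.0
| # down to 0.0 before being multiplied into the sigma sequence.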
|
|
|
|
|
| def load_blend_method_sigmas(self, mode=None):
|
| """Loads all sigma sequences for the blend_methods list based on current settings and mode."""
|
| self.all_sigmas = []
|
|
|
|
|
| for method in self.blend_methods:
|
| method_config = self.blend_method_dict[method]
|
| # Build a normalized per-method config without mutating blend_method_dict.
|
| self.current_config = {
|
| 'decay_pattern': method_config.get('decay_pattern', 'zero'),
|
| 'decay_mode': method_config.get('decay_mode', 'blend'),
|
| 'tail_steps': method_config.get('tail_steps', 1)
|
| }
|
|
|
|
|
| sigma_func = self.scheduler_registry[method]
|
| tails, decay, extras, sigmas = self.call_scheduler_function(
|
| sigma_func,
|
| steps=self.steps,
|
| sigma_min=self.sigma_min,
|
| sigma_max=self.sigma_max,
|
| rho=self.rho,
|
| device=self.device,
|
| decay_pattern=self.current_config['decay_pattern'],
|
| decay_mode=self.current_config['decay_mode'],
|
| tail_steps=self.current_config['tail_steps']
|
| )
|
| self.sigma_sequences[method] = {
|
| 'sigmas': sigmas,
|
| 'tails': tails,
|
| 'decay': decay,
|
| 'extras': extras
|
| }
|
| setattr(self, f"sigmas_{method}", sigmas)
|
|
|
| self.all_sigmas = [self.sigma_sequences[method]['sigmas'] for method in self.blend_methods]
|
| self.all_tails = [self.sigma_sequences[method]['tails'] for method in self.blend_methods]
|
| self.all_decays = [self.sigma_sequences[method]['decay'] for method in self.blend_methods]
|
| self.all_extras = [self.sigma_sequences[method].get('extras', []) for method in self.blend_methods]
|
| self._log_extras_to_file(self.all_extras)
|
|
|
|
|
|
| self.log(f"Loaded sigma schedules for blend methods: {self.blend_methods} using mode: {mode}")
|
| def validate_and_align_sigmas(self):
|
| """
|
| Ensures all sigma sequences in self.all_sigmas are valid and have the same length.
|
| Pads shorter sequences with their last sigma.
|
| """
|
| if not self.all_sigmas:
|
| raise ValueError("No sigma sequences were loaded for blending.")
|
|
|
| target_length = max(len(s) for s in self.all_sigmas)
|
|
|
| for idx, sigmas in enumerate(self.all_sigmas):
|
| if sigmas is None or len(sigmas) == 0:
|
| raise ValueError(f"Sigma sequence at index {idx} is invalid or empty: {sigmas}")
|
|
|
| if len(sigmas) < target_length:
|
| padding = torch.full((target_length - len(sigmas),), sigmas[-1]).to(sigmas.device)
|
| self.all_sigmas[idx] = torch.cat([sigmas, padding])
|
|
|
| self.log(f"Validated and aligned all sigma sequences to length {target_length}.")
|
|
|
| def generate_sigmas_schedule(self, mode=None):
|
| """
|
| Generates the sigma schedules required for the hybrid blending process.
|
|
|
| The Karras and Exponential sigma sequences are created to provide two distinct
|
| noise scaling strategies:
|
| - The Karras sequence offers a more aggressive noise decay, commonly used in
|
| modern schedulers for improved image quality and denoising stability.
|
| - The Exponential sequence provides a traditional log-space noise schedule.
|
|
|
| These two sequences are dynamically blended in later steps using progress-dependent
|
| weights to produce a custom sigma path that combines the advantages of both approaches.
|
|
|
| This blending process is critical to the scheduler's ability to:
|
| - Adapt noise scaling across steps.
|
| - Control the sharpness and smoothness of transitions.
|
| - Support early stopping based on sigma convergence patterns.
|
|
|
| These sigma sequences must be regenerated in both the prepass (for early stopping detection)
|
| and the final pass (for polished sigma application), ensuring both passes are synchronized
|
| with the current step count and randomization settings.
|
| """
|
|
|
|
|
| if mode == 'prepass':
|
| if self.load_prepass_sigmas:
|
| self.cache_file = self.prepass_save_file
|
| self.mode = 'prepass'
|
|
|
| elif mode == 'final':
|
| if self.load_sigma_cache:
|
| self.cache_file = self.final_save_file
|
| self.mode = 'final'
|
|
|
| else:
|
| self.mode = None
|
| self.cache_file = None
|
|
|
| '''
|
| if self.cache_file:
|
| sigmas = self.get_sigma_with_cache(
|
| steps=self.steps,
|
| sigma_min=self.sigma_min,
|
| sigma_max=self.sigma_max,
|
| rho=self.rho,
|
| device=self.device,
|
| decay_pattern=self.decay_pattern,
|
| cache_file=self.cache_file,
|
| mode=self.mode
|
| #cache_key = self.cache_key
|
| )
|
| return sigmas
|
| '''
|
|
|
|
|
|
|
|
|
| self.load_blend_method_sigmas(mode=self.mode)
|
| self.blend_pairs = []
|
| self.active_methods = [method for method in self.blend_methods if self.blend_method_dict[method].get('weight', 1.0) > 0]
|
|
|
| if self.blending_mode == 'default':
|
| self._call_legacy_mode(schedule_type='exponential')
|
| self._call_legacy_mode(schedule_type='karras')
|
|
|
| self.blend_pairs = []
|
| self.blend_pairs.append({
|
| 'method_label': 'method_a',
|
| 'method': 'karras',
|
| 'sigmas': self.sigmas_karras
|
| })
|
| self.blend_pairs.append({
|
| 'method_label': 'method_b',
|
| 'method': 'exponential',
|
| 'sigmas': self.sigmas_exponential
|
| })
|
|
|
|
|
| max_length = max(len(pair['sigmas']) for pair in self.blend_pairs)
|
|
|
| for pair in self.blend_pairs:
|
| if len(pair['sigmas']) < max_length:
|
| padding = torch.full((max_length - len(pair['sigmas']),), pair['sigmas'][-1]).to(pair['sigmas'].device)
|
| pair['sigmas'] = torch.cat([pair['sigmas'], padding])
|
|
|
| self.log(f"All sigma sequences aligned to length: {max_length}")
|
|
|
|
|
| sigmas_a = self.blend_pairs[0]['sigmas']
|
| sigmas_b = self.blend_pairs[1]['sigmas']
|
| label_a = self.blend_pairs[0]['method']
|
| label_b = self.blend_pairs[1]['method']
|
| if sigmas_a is None:
|
| raise ValueError(f"Sigmas {label_a} failed to generate or assign correctly.")
|
| if sigmas_b is None:
|
| raise ValueError(f"Sigmas {label_b} failed to generate or assign correctly.")
|
| else:
|
| if len(self.active_methods) == 1:
|
|
|
| self.blend_pairs = []
|
| method = self.active_methods[0]
|
| self.blend_pairs.append({
|
| 'method_label': 'method_a',
|
| 'method': method,
|
| 'sigmas': self.sigma_sequences[method]['sigmas']
|
| })
|
|
|
| elif len(self.active_methods) >= 2:
|
|
|
| self.blend_pairs = []
|
| for idx, method in enumerate(self.active_methods):
|
| self.blend_pairs.append({
|
| 'method_label': f'method_{chr(97 + idx)}',
|
| 'method': method,
|
| 'sigmas': self.sigma_sequences[method]['sigmas']
|
| })
|
|
|
|
|
| for pair in self.blend_pairs:
|
| if pair['sigmas'] is None:
|
| raise ValueError(f"Sigmas {pair['method']} failed to generate or assign correctly.")
|
|
|
|
|
|
|
| if len(self.blend_pairs) > 1:
|
| # Truncate every sequence to the shortest so they can be blended element-wise;
|
| # after truncation all sequences share target_length, so no padding pass is needed.
|
| target_length = min(len(pair['sigmas']) for pair in self.blend_pairs)
|
|
|
| for pair in self.blend_pairs:
|
| pair['sigmas'] = pair['sigmas'][:target_length]
|
|
|
| self.log(f"All sigma sequences aligned to length: {target_length}")
|
|
|
| self.sigs = torch.zeros(target_length, device=self.blend_pairs[0]['sigmas'].device)
|
| else:
|
|
|
| self.sigs = self.blend_pairs[0]['sigmas'].clone()
|
|
|
|
|
| '''
|
| # Now it's safe to compute sigs
|
| start = math.log(self.sigma_max)
|
| end = math.log(self.sigma_min)
|
| #self.sigs = torch.linspace(start, end, self.steps, device=self.device).exp()
|
| if self.sigs is None or self.force_rebuild_sigs:
|
| self.sigs = torch.linspace(start, end, self.steps, device=self.device).exp()
|
|
|
|
|
|
|
| # Ensure sigs contain valid values before using them
|
| if torch.any(self.sigs > 0):
|
| self.sigma_min, self.sigma_max = self.sigs[self.sigs > 0].min(), self.sigs.max()
|
| else:
|
| # If sigs are all invalid, set a safe fallback
|
| self.sigma_min, self.sigma_max = self.min_threshold, self.min_threshold
|
| self.log(f"Debugging Warning: No positive sigma values found! Setting fallback sigma_min={self.sigma_min}, sigma_max={self.sigma_max}")
|
|
|
| return {
|
| 'karras': self.sigmas_karras,
|
| 'exponential': self.sigmas_exponential,
|
| 'blend_methods': self.blend_methods,
|
| 'all_sigmas': self.all_sigmas,
|
| 'sigs': self.sigs
|
| }
|
|
|
| #sigma_lengths = [len(self.sigma_sequences[method]['sigmas']) for method in self.blend_methods]
|
| #if len(set(sigma_lengths)) > 1: # There are mismatched lengths
|
| #self.validate_and_align_sigmas()
|
| #self.sigs = torch.zeros(self.steps, device=self.device)
|
| sigma_lengths = [len(self.sigma_sequences[method]['sigmas']) for method in self.blend_methods]
|
| if len(set(sigma_lengths)) > 1:
|
| self.log("[Sigma Alignment] Detected mismatched sigma sequence lengths. Aligning...")
|
| self.validate_and_align_sigmas()
|
|
|
| return {
|
| 'blend_methods': self.blend_methods,
|
| 'all_sigmas': self.all_sigmas,
|
| 'sigs': self.sigs
|
| }
|
| '''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| if not torch.any(self.sigs > 0):
|
| self.sigma_min = self.min_threshold
|
| self.sigma_max = self.min_threshold
|
| self.log(f"Debugging Warning: No positive sigma values found! Setting fallback sigma_min={self.sigma_min}, sigma_max={self.sigma_max}")
|
| else:
|
| self.sigma_min = self.sigs[self.sigs > 0].min()
|
| self.sigma_max = self.sigs.max()
|
|
|
| return {
|
| 'blend_methods': self.blend_methods,
|
| 'all_sigmas': self.all_sigmas,
|
| 'sigs': self.sigs
|
| }
|
|
|
|
|
|
|
| def call_scheduler_function(self, scheduler_func, **kwargs):
|
| """
|
| Safely calls a scheduler function with dynamic argument filtering and flexible return handling.
|
|
|
| This method ensures that only the parameters accepted by the scheduler function are passed.
|
| It automatically handles scheduler functions that may return:
|
| - Only the sigma sequence
|
| - A tuple with (tails, sigmas)
|
| - A tuple with (tails, decay, sigmas)
|
| - A tuple with additional items (extras) before sigmas
|
|
|
| The method always assumes the last returned item is the sigma sequence, with optional
|
| items preceding it.
|
|
|
| Parameters:
|
| ----------
|
| scheduler_func : callable
|
| The scheduler function to be invoked. It may accept various arguments such as steps, sigma_min,
|
| sigma_max, rho, device, decay_pattern, etc.
|
| **kwargs : dict
|
| Arbitrary keyword arguments. Only those accepted by the scheduler function will be passed.
|
|
|
| Returns:
|
| -------
|
| tuple
|
| A 4-tuple containing:
|
| - tails : Any (optional, can be None)
|
| The tail component of the sigma schedule, if provided.
|
| - decay : Any (optional, can be None)
|
| The decay component of the sigma schedule, if provided.
|
| - extras : list
|
| Any additional return values provided by the scheduler function, beyond tails and decay.
|
| - sigmas : Any
|
| The sigma sequence, always assumed to be the last item returned by the scheduler function.
|
|
|
| Raises:
|
| ------
|
| ValueError
|
| If the scheduler function returns an empty tuple.
|
|
|
| Notes:
|
| -----
|
| This method allows future schedulers to return additional optional data without breaking the calling pattern.
|
| """
|
| valid_params = inspect.signature(scheduler_func).parameters
|
| filtered_args = {k: v for k, v in kwargs.items() if k in valid_params}
|
|
|
| result = scheduler_func(**filtered_args)
|
|
|
| if isinstance(result, dict):
|
| tails = result.get('tails', None)
|
| decay = result.get('decay', None)
|
| sigmas = result.get('sigmas')
|
| extras = result.get('extras', [])
|
|
|
| if sigmas is None:
|
| raise ValueError("Scheduler function must return a 'sigmas' key.")
|
|
|
| return tails, decay, extras, sigmas
|
|
|
|
|
| if not isinstance(result, tuple):
|
| return None, None, [], result
|
|
|
| if len(result) == 0:
|
| raise ValueError("Scheduler function returned an empty tuple. This is not allowed.")
|
|
|
| sigmas = result[-1]
|
| optional_returns = result[:-1]
|
|
|
| tails = optional_returns[0] if len(optional_returns) > 0 else None
|
| decay = optional_returns[1] if len(optional_returns) > 1 else None
|
| extras = list(optional_returns[2:]) if len(optional_returns) > 2 else []
|
|
|
| return tails, decay, extras, sigmas
|
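| # Illustrative sketch of how the normalization above behaves; the scheduler
| # functions named here are hypothetical and exist only for demonstration.
| #
| #   def plain_scheduler(steps): ...        # returns sigmas
| #   def tailed_scheduler(steps): ...       # returns (tails, sigmas)
| #   def full_scheduler(steps): ...         # returns (tails, decay, sigmas)
| #
| #   tails, decay, extras, sigmas = self.call_scheduler_function(
| #       plain_scheduler, steps=30, decay_pattern='zero'
| #   )
| #   # -> (None, None, [], sigmas); decay_pattern is silently dropped because
| #   #    inspect.signature shows plain_scheduler does not accept it.
|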
|
|
| def config_values(self):
|
|
|
| if self.sigma_min >= self.sigma_max:
|
| correction_factor = random.uniform(0.01, 0.99)
|
| old_sigma_min = self.sigma_min
|
| self.sigma_min = self.sigma_max * correction_factor
|
| self.log(f"[Correction] sigma_min ({old_sigma_min}) was >= sigma_max ({self.sigma_max}). Adjusted sigma_min to {self.sigma_min} using correction factor {correction_factor}.")
|
|
|
| self.log(f"Final sigmas: sigma_min={self.sigma_min}, sigma_max={self.sigma_max}")
|
|
|
|
|
| if self.sigma_auto_enabled:
|
| if self.sigma_auto_mode not in ["sigma_min", "sigma_max"]:
|
| raise ValueError(f"[Config Error] Invalid sigma_auto_mode: {self.sigma_auto_mode}. Must be 'sigma_min' or 'sigma_max'.")
|
|
|
| if self.sigma_auto_mode == "sigma_min":
|
| self.sigma_min = self.sigma_max / self.sigma_scale_factor
|
| self.log(f"[Auto Sigma Min] sigma_min set to {self.sigma_min} using scale factor {self.sigma_scale_factor}")
|
|
|
| elif self.sigma_auto_mode == "sigma_max":
|
| self.sigma_max = self.sigma_min * self.sigma_scale_factor
|
| self.log(f"[Auto Sigma Max] sigma_max set to {self.sigma_max} using scale factor {self.sigma_scale_factor} and using a multiplier of {sigma_max_multipier} to account for smoother transitions")
|
|
|
|
|
| self.min_threshold = random.uniform(1e-5, 5e-5)
|
|
|
| if self.sigma_min < self.min_threshold:
|
| self.log(f"[Threshold Enforcement] sigma_min was too low: {self.sigma_min} < min_threshold {self.min_threshold}")
|
| self.sigma_min = self.min_threshold
|
|
|
| if self.sigma_max < self.min_threshold:
|
| self.log(f"[Threshold Enforcement] sigma_max was too low: {self.sigma_max} < min_threshold {self.min_threshold}")
|
| self.sigma_max = self.min_threshold
|
|
|
|
|
| valid_methods = ['mean', 'max', 'sum']
|
| if self.early_stopping_method not in valid_methods:
|
| self.log(f"[Config Correction] Invalid early_stopping_method: {self.early_stopping_method}. Defaulting to 'mean'.")
|
| self.early_stopping_method = 'mean'
|
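| # Worked example of the corrections above (numbers illustrative only):
| #   sigma_min=2.0, sigma_max=1.0  -> sigma_min resampled to 1.0 * U(0.01, 0.99), e.g. 0.42
| #   sigma_auto_mode='sigma_min', sigma_max=14.6, sigma_scale_factor=100 -> sigma_min = 0.146
| #   any sigma below min_threshold ~ U(1e-5, 5e-5) is clamped up to the threshold
|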
|
|
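| # Prepass: builds a provisional sigma schedule (mode='prepass') and blends it
| # so values such as predicted_stop_step and visual_sigma can be estimated
| # before compute_sigmas generates the final schedule.
|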
| def prepass_compute_sigmas(self, steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, suffix=None, cache_key=None, skip_prepass=False):
|
|
|
| '''
|
| if self.load_prepass_sigmas:
|
| if cache_key:
|
| self._safe_sigma_loader(cache_key)
|
| #self.load_or_regenerate_sigmas(cache_key)
|
| loaded_data = torch.load(self.cache_file.replace('.txt', '.pt'), map_location=self.device)
|
|
|
| sigmas = loaded_data['sigma_values'].to(self.device)
|
| self.loaded_sigmas = sigmas # No need to call torch.tensor again
|
|
|
| loaded_hash = loaded_data['sigma_hash']
|
|
|
| steps = loaded_data['steps']
|
| sigma_min = loaded_data['sigma_min']
|
| sigma_max = loaded_data['sigma_max']
|
| rho = loaded_data['rho']
|
| device = loaded_data['device']
|
| schedule_type = loaded_data['schedule_type']
|
| decay_pattern = loaded_data['decay_pattern']
|
|
|
| restored_config = loaded_data['full_config']
|
|
|
|
|
| # Optionally overwrite current settings with restored settings
|
| self.settings.update(restored_config)
|
|
|
| #self.load_sigmas_with_hash_validation(self, loaded_data, steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, cache_key, suffix=None)
|
| self.log(f"[Cache Loaded] Sigma schedule, hash, and config loaded from: {self.cache_file.replace('.pt', '.txt')}")
|
| '''
|
| acceptable_keys = [
|
| "sigma_min", "sigma_max", "start_blend", "end_blend", "sharpness",
|
| "early_stopping_threshold", "initial_step_size",
|
| "final_step_size", "initial_noise_scale", "final_noise_scale",
|
| "smooth_blend_factor", "step_size_factor", "noise_scale_factor", "rho"
|
| ]
|
|
|
| for key in acceptable_keys:
|
| default_val = self.settings[key]
|
| value = self.get_random_or_default(key, default_val)
|
| setattr(self, key, value)
|
|
|
| if self.steps is None:
|
| raise ValueError("Number of steps must be provided.")
|
| if isinstance(self.device, str):
|
| self.device = torch.device(self.device)
|
| self.config_values()
|
| self.generate_sigmas_schedule(mode='prepass')
|
|
|
| self.predicted_stop_step = self.steps if self.steps is not None else self.original_steps
|
| if self.sharpen_last_n_steps > len(self.sigs):
|
| self.log(f"[Sharpening Notice] Requested last {self.sharpen_last_n_steps} steps exceeds sequence length {len(self.sigs)}. Using entire sequence instead.")
|
| self.sharpen_last_n_steps = len(self.sigs)
|
|
|
| self.visual_sigma = max(0.8, self.sigma_min * self.min_visual_sigma)
|
|
|
| self.blend_sigma_sequence(
|
| sigmas_karras=None,
|
| sigmas_exponential=None,
|
| pre_pass=True,
|
| blend_methods=self.blend_methods,
|
| blend_weights=self.blend_weights
|
| )
|
| if torch.isnan(self.sigs).any() or torch.isinf(self.sigs).any():
|
| raise ValueError("Invalid sigma values detected (NaN or Inf).")
|
| final_steps = self.sigs[:self.predicted_stop_step].to(self.device)
|
|
|
| self.final_steps = final_steps
|
| if self.blending_mode == 'default':
|
| self.final_sigmas_karras = self.sigmas_karras
|
| self.final_sigmas_exponential = self.sigmas_exponential
|
| self.log(f" Final Steps = {self.final_steps}. Predicted_stop_step = {self.predicted_stop_step}. Original requested steps = {self.steps}")
|
| self.log(f"final sigmas karras: {self.final_sigmas_karras}")
|
| else:
|
|
|
| self.final_sigmas_blended = torch.tensor(self.blended_sigmas, device=self.device)
|
|
|
| self.log(f" Final Steps = {self.final_steps}. Predicted_stop_step = {self.predicted_stop_step}. Original requested steps = {self.steps}")
|
| self.log(f"final blended sigmas: {self.final_sigmas_blended}")
|
|
|
|
|
| for method, sigmas in zip(self.blend_methods, self.all_sigmas):
|
| self.log(f"Method: {method}, Sigma sequence: {sigmas}")
|
|
|
|
|
| '''
|
| # Build cache key
|
| sigma_hash = self.generate_sigma_hash(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, suffix=None)
|
|
|
| if self.save_prepass_sigmas:
|
| save_data = {
|
| 'sigma_values': sigmas.cpu(), # Keep as tensor
|
| 'sigma_hash': sigma_hash,
|
| 'steps': steps,
|
| 'sigma_min': sigma_min,
|
| 'sigma_max': sigma_max,
|
| 'rho': rho,
|
| 'device': device,
|
| 'schedule_type': schedule_type,
|
| 'decay_pattern': decay_pattern,
|
| 'full_config': self.settings # Save as raw dict
|
| }
|
|
|
| # Save directly with torch.save in .pt format
|
| torch.save(save_data, self.cache_file) # Assuming self.cache_file has .pt extension
|
| self.log(f"[Sigma Saver] Final sigmas saved to: {self.cache_file}")
|
| '''
|
| def load_or_regenerate_sigmas(self, cache_key):
|
| sigmas = None
|
| if self.load_sigma_cache and cache_key:
|
| try:
|
| loaded_data = torch.load(self.cache_file, map_location=self.device)
|
| sigmas = loaded_data['sigma_values'].to(self.device)
|
| except FileNotFoundError:
|
| self.log(f"[Cache Warning] Cache file not found: {self.cache_file}")
|
| self.log("[Cache Recovery] Automatically recomputing sigma schedule.")
|
|
|
| # Regenerate the schedule only when no cached sigmas were loaded.
|
| if sigmas is None:
|
| _, _, _, sigmas = self._generate_sigmas(
|
| self.steps,
|
| self.sigma_min,
|
| self.sigma_max,
|
| self.rho,
|
| self.device,
|
| self.schedule_type,
|
| self.decay_pattern
|
| )
|
| '''
|
| if self.save_prepass_sigmas:
|
| # Optional: Cache the recomputed sigma schedule
|
| torch.save(save_data, self.prepass_save_file)
|
| self.log(f"[Sigma Saver] Final sigmas saved to: {self.prepass_save_file}")
|
|
|
| #self.sigma_cache[cache_key] = sigmas
|
| '''
|
| return sigmas
|
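| # Usage sketch (the cache key format here is illustrative only):
| #
| #   sigmas = self.load_or_regenerate_sigmas(cache_key='karras_30_0.0292_14.6146')
| #   # falls back to _generate_sigmas when caching is disabled or the file is missing
|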
|
|
| def compute_sigmas(self, steps, sigma_min, sigma_max, rho, device, schedule_type=None, decay_pattern=None, cache_key=None) -> torch.Tensor:
|
| """
|
| Scheduler function that blends sigma sequences using Karras and Exponential methods with adaptive parameters.
|
|
|
| Parameters:
|
| steps (int): Number of steps.
|
| sigma_min (float): Minimum sigma value.
|
| sigma_max (float): Maximum sigma value.
|
| rho (float): Karras rho value controlling schedule curvature.
|
| device (torch.device): The device on which to perform computations (e.g., 'cuda' or 'cpu').
|
| start_blend (float): Initial blend factor for dynamic blending.
|
| end_blend (float): Final blend factor for dynamic blending.
|
| sharpness (float): Sharpening factor, applied adaptively.
|
| early_stopping_threshold (float): Threshold to trigger early stopping.
|
| initial_step_size (float): Initial step size for adaptive step size calculation.
|
| final_step_size (float): Final step size for adaptive step size calculation.
|
| initial_noise_scale (float): Initial noise scale factor.
|
| final_noise_scale (float): Final noise scale factor.
|
| step_size_factor (float): Adjustment applied to compensate for over-smoothing.
|
| noise_scale_factor (float): Adjustment applied to provide more variation.
|
|
|
| Returns:
|
| torch.Tensor: A tensor of blended sigma values.
|
| """
|
| '''
|
| if self.load_sigma_cache and cache_key:
|
| sigmas = self._safe_sigma_loader(cache_key)
|
| if sigmas is None:
|
| self.log(f"[Cache Recovery] No valid cache found for key: {cache_key}. Recomputing sigma schedule.")
|
| _, _, _, sigmas = self._generate_sigmas(
|
| self.steps,
|
| self.sigma_min,
|
| self.sigma_max,
|
| self.rho,
|
| self.device,
|
| self.schedule_type,
|
| self.decay_pattern
|
| )
|
| else:
|
| self.log(f"[Cache Hit] Sigma schedule successfully loaded from cache.")
|
| self.sigs = sigmas
|
| '''
|
| acceptable_keys = [
|
| "sigma_min", "sigma_max", "start_blend", "end_blend", "sharpness",
|
| "early_stopping_threshold", "initial_step_size",
|
| "final_step_size", "initial_noise_scale", "final_noise_scale",
|
| "smooth_blend_factor", "step_size_factor", "noise_scale_factor", "rho"
|
| ]
|
|
|
| for key in acceptable_keys:
|
| default_val = self.settings[key]
|
| value = self.get_random_or_default(key, default_val)
|
| setattr(self, key, value)
|
|
|
| self.log(f"Using device: {self.device}")
|
| self.config_values()
|
| self.generate_sigmas_schedule(mode='final')
|
| if hasattr(self, 'final_sigmas_karras'):
|
| self.sigs = torch.zeros_like(self.final_sigmas_karras).to(self.device)
|
| else:
|
| self.sigs = torch.zeros_like(self.sigmas_karras).to(self.device)
|
|
|
| self.blend_sigma_sequence(
|
| sigmas_karras=self.final_sigmas_karras if hasattr(self, 'final_sigmas_karras') else self.sigmas_karras,
|
| sigmas_exponential=self.final_sigmas_exponential if hasattr(self, 'final_sigmas_exponential') else self.sigmas_exponential,
|
| pre_pass=False,
|
| blend_methods=self.blend_methods,
|
| blend_weights = self.blend_weights
|
|
|
| )
|
| self.sigma_variance = torch.var(self.sigs).item()
|
| if self.sharpen_mode in ['last_n', 'both']:
|
| if self.sigma_variance < self.sharpen_variance_threshold:
|
|
|
| self.sharpen_mask = torch.where(self.sigs < self.sigma_min * 1.5, self.sharpness, 1.0).to(self.device)
|
| sharpen_indices = torch.where(self.sharpen_mask < 1.0)[0].tolist()
|
| self.sigs = self.sigs * self.sharpen_mask
|
| self.log(f"[Sharpen Mask] Full sharpening applied (low variance). Steps: {sharpen_indices}")
|
| else:
|
|
|
| # Sharpen only the last N steps, logging each adjustment exactly once.
|
| start_idx = len(self.sigs) - self.sharpen_last_n_steps
|
| for j in range(start_idx, len(self.sigs)):
|
| if self.sigs[j] < self.sigma_min * 1.5:
|
| old_value = self.sigs[j].item()
|
| self.sigs[j] = self.sigs[j] * self.sharpness
|
| self.log(f"[Sharpening] Step {j+1}: Applied sharpening. Sigma changed from {old_value:.6f} to {self.sigs[j].item():.6f}")
|
| else:
|
| self.log(f"[Sharpening] Step {j+1}: No sharpening applied. Sigma: {self.sigs[j].item():.6f}")
|
|
|
| if self.sharpen_mode in ['full', 'both']:
|
|
|
| self.sharpen_mask = torch.where(self.sigs < self.sigma_min * 1.5, self.sharpness, 1.0).to(self.device)
|
| sharpen_indices = torch.where(self.sharpen_mask < 1.0)[0].tolist()
|
| self.sigs = self.sigs * self.sharpen_mask
|
| self.log(f"[Sharpen Mask] Full sharpening applied at steps: {sharpen_indices}")
|
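| # Numeric illustration of the sharpen mask (values illustrative): with
| # sigma_min=0.1 and sharpness=0.95, a sigma of 0.12 (< 0.15 = 1.5 * sigma_min)
| # becomes 0.12 * 0.95 = 0.114, while a sigma of 0.5 keeps its mask value of 1.0.
|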
|
|
| '''
|
| sigma_hash = self.generate_sigma_hash(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, suffix=None)
|
| if self.settings.get('save_sigma_cache', False):
|
| save_data = {
|
| 'sigma_values': self.sigs.cpu().tolist(),
|
| 'sigma_hash': sigma_hash,
|
| 'steps': steps,
|
| 'sigma_min': sigma_min,
|
| 'sigma_max': sigma_max,
|
| 'rho': rho,
|
| 'device': device,
|
| 'schedule_type': schedule_type,
|
| 'decay_pattern': decay_pattern,
|
| 'full_config': json.dumps(self.settings)
|
| }
|
|
|
|
|
| if self.save_sigma_cache:
|
| torch.save(save_data, self.final_save_file)
|
| self.log(f"[Sigma Saver] Final sigmas saved to: {self.final_save_file}")
|
| '''
|
|
|
| return self.sigs.to(self.device)
|
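| # Hypothetical end-to-end usage (argument values illustrative only):
| #
| #   scheduler = SimpleKEScheduler(n=30, sigma_min=0.0292, sigma_max=14.6146, device='cuda')
| #   sigmas = scheduler.compute_sigmas(
| #       steps=30, sigma_min=0.0292, sigma_max=14.6146, rho=7.0,
| #       device=torch.device('cuda'), schedule_type='karras', decay_pattern='zero'
| #   )
| #   # sigmas: 1-D tensor of blended sigma values on the requested device
|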
|
|
| def get_sigma_from_cache(self, cache_key):
|
| """
|
| Safely retrieves a sigma sequence from cache.
|
| Always returns a detached copy to prevent in-place modification of cached data.
|
| """
|
| if cache_key in self.sigma_cache:
|
| cached_sigmas = self.sigma_cache[cache_key]
|
| self.log(f"[Cache Hit] Returning cached sigma sequence for key: {cache_key}")
|
|
|
|
|
| if isinstance(cached_sigmas, torch.Tensor):
|
| return cached_sigmas.clone().detach().to(self.device)
|
|
|
|
|
| elif isinstance(cached_sigmas, list):
|
| return copy.deepcopy(cached_sigmas)
|
|
|
|
|
| else:
|
| return cached_sigmas
|
|
|
| else:
|
| self.log(f"[Cache Miss] Cache key not found: {cache_key}")
|
| return None
|
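| # Because get_sigma_from_cache returns a detached clone (or deep copy), callers
| # may modify the result freely; a sketch with a hypothetical key:
| #
| #   sigmas = self.get_sigma_from_cache('karras_30')
| #   if sigmas is not None:
| #       sigmas *= 0.5  # does not mutate the cached original
|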
|
|