Spaces:
Running on Zero
Running on Zero
File size: 9,702 Bytes
b701455 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 | """
Align Your Steps (AYS) Scheduler
Optimized noise schedules for faster convergence with same/better quality.
Training-free, lossless (often better quality) optimization.
Reference: "Align Your Steps: Optimizing Sampling Schedules in Diffusion Models" (2024)
https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/
Key insight: Not all timesteps contribute equally to image quality.
AYS finds optimal timestep schedules that allow fewer steps with same quality.
"""
import torch
import numpy as np
import logging
# Pre-computed optimal sigma schedules from AYS paper
# These were found through optimization to minimize reconstruction error
AYS_OPTIMAL_SCHEDULES = {
# SD1.5 schedules
"SD15": {
4: [14.6146, 6.4745, 2.4826, 0.5497, 0.0],
6: [14.6146, 8.0426, 4.4170, 2.2172, 0.9316, 0.2596, 0.0],
8: [14.6146, 9.4663, 5.9384, 3.4759, 1.9696, 1.0417, 0.4598, 0.1328, 0.0],
10: [14.6146, 10.4708, 7.3688, 4.9651, 3.2924, 2.1391, 1.3633, 0.8437, 0.4898, 0.2279, 0.0],
12: [14.6146, 11.2797, 8.5033, 6.2928, 4.5662, 3.2721, 2.3124, 1.6103, 1.1029, 0.7414, 0.4804, 0.2552, 0.0],
15: [14.6146, 12.1652, 9.8581, 7.8194, 6.1302, 4.7577, 3.6611, 2.7994, 2.1271, 1.6053, 1.2033, 0.8917, 0.6525, 0.4672, 0.3196, 0.0],
20: [14.6146, 13.1721, 11.0451, 9.2027, 7.6085, 6.2281, 5.0356, 4.0135, 3.1468, 2.4226, 1.8293, 1.3567, 0.9967, 0.7219, 0.5129, 0.3589, 0.2463, 0.1644, 0.1025, 0.0522, 0.0],
25: [14.6146, 13.8539, 11.9991, 10.3603, 8.9042, 7.6058, 6.4449, 5.4048, 4.4710, 3.6313, 2.8757, 2.1959, 1.5851, 1.0375, 0.5483, 0.0948, 0.0],
},
# SDXL schedules
"SDXL": {
4: [14.6146, 6.8873, 2.7084, 0.6577, 0.0],
6: [14.6146, 8.3767, 4.6699, 2.4175, 1.0643, 0.3262, 0.0],
8: [14.6146, 9.6929, 6.1589, 3.6454, 2.1116, 1.1507, 0.5474, 0.1770, 0.0],
10: [14.6146, 10.7043, 7.5043, 5.1442, 3.4302, 2.2379, 1.4288, 0.8874, 0.5174, 0.2427, 0.0],
12: [14.6146, 11.5222, 8.6124, 6.4254, 4.7084, 3.4034, 2.4252, 1.7028, 1.1721, 0.7930, 0.5182, 0.2782, 0.0],
15: [14.6146, 12.4748, 10.0985, 8.0432, 6.3548, 4.9664, 3.8444, 2.9488, 2.2453, 1.6962, 1.2714, 0.9442, 0.6920, 0.4962, 0.3410, 0.0],
20: [14.6146, 13.4772, 11.6548, 9.9908, 8.4577, 7.0347, 5.7062, 4.4602, 3.2880, 2.1832, 1.1412, 0.1594, 0.0],
# Note: The 20-step schedule above is optimized to 12 actual steps for efficiency.
# If you want true 20 steps, the scheduler will interpolate from this schedule.
},
# Flux schedules (experimental - adapted from SDXL)
"FLUX": {
4: [14.6146, 7.2458, 3.0169, 0.7842, 0.0],
8: [14.6146, 10.1472, 6.5812, 4.0103, 2.4138, 1.3842, 0.6951, 0.2456, 0.0],
10: [14.6146, 11.2058, 8.0185, 5.5842, 3.8102, 2.5741, 1.6745, 1.0596, 0.6288, 0.3056, 0.0],
15: [14.6146, 12.9258, 10.7341, 8.7962, 7.0782, 5.5502, 4.2822, 3.2442, 2.4062, 1.7682, 1.2902, 0.9422, 0.6742, 0.4662, 0.3082, 0.0],
20: [14.6146, 14.0146, 12.4146, 10.8146, 9.2146, 7.6146, 6.0146, 4.4146, 2.8146, 1.2146, 0.0],
},
}
def ays_scheduler(
model_sampling: torch.nn.Module,
steps: int,
model_type: str = "SD15",
denoise: float = 1.0
) -> torch.FloatTensor:
"""Create an Align Your Steps optimized scheduler.
This scheduler uses pre-computed optimal sigma distributions that allow
fewer sampling steps with equivalent or better image quality compared to
uniform schedulers.
Args:
model_sampling (torch.nn.Module): The model sampling module.
steps (int): The number of denoising steps.
model_type (str): Model type - "SD15", "SDXL", or "FLUX". Defaults to "SD15".
denoise (float): Denoise strength (1.0 = full denoise). Defaults to 1.0.
Returns:
torch.FloatTensor: Optimized sigma schedule.
"""
# Get the schedule for this model type
if model_type not in AYS_OPTIMAL_SCHEDULES:
logging.warning(f"Unknown model type '{model_type}' for AYS scheduler, falling back to SD15")
model_type = "SD15"
schedules = AYS_OPTIMAL_SCHEDULES[model_type]
# Use exact schedule if available
if steps in schedules:
sigmas = torch.FloatTensor(schedules[steps])
logging.debug(f"Using AYS optimal schedule for {model_type} @ {steps} steps")
else:
# Interpolate between available schedules
available_steps = sorted(schedules.keys())
if steps < available_steps[0]:
# Use smallest available schedule
use_steps = available_steps[0]
logging.debug(f"Using AYS {use_steps}-step schedule (requested {steps} steps)")
sigmas = torch.FloatTensor(schedules[use_steps])
elif steps > available_steps[-1]:
# Extrapolate from largest schedule
use_steps = available_steps[-1]
logging.debug(f"Using AYS {use_steps}-step schedule (requested {steps} steps)")
base_sigmas = torch.FloatTensor(schedules[use_steps])
# Vectorized interpolation to desired number of steps
sigmas = resample_sigmas(base_sigmas, steps + 1)
else:
# Interpolate between two neighboring schedules
lower_steps = max([s for s in available_steps if s <= steps])
upper_steps = min([s for s in available_steps if s >= steps])
if lower_steps == upper_steps:
sigmas = torch.FloatTensor(schedules[lower_steps])
else:
# Interpolate between schedules
lower_sigmas = torch.FloatTensor(schedules[lower_steps])
upper_sigmas = torch.FloatTensor(schedules[upper_steps])
# Resample both to target step count
lower_resampled = resample_sigmas(lower_sigmas, steps + 1)
upper_resampled = resample_sigmas(upper_sigmas, steps + 1)
# Blend based on distance
weight = (steps - lower_steps) / (upper_steps - lower_steps)
sigmas = lower_resampled * (1 - weight) + upper_resampled * weight
logging.debug(f"Interpolated AYS schedule for {model_type} @ {steps} steps")
# Final guard: Ensure we return exactly the requested number of steps
if len(sigmas) != steps + 1:
sigmas = resample_sigmas(sigmas, steps + 1)
# Apply denoise factor if needed
if denoise < 1.0:
sigmas = apply_denoise_factor(sigmas, denoise)
# Ensure last sigma is exactly 0
sigmas[-1] = 0.0
return sigmas
def resample_sigmas(sigmas: torch.Tensor, target_steps: int) -> torch.Tensor:
"""Resample sigma schedule to different number of steps using linear interpolation.
Args:
sigmas (torch.Tensor): Original sigma schedule.
target_steps (int): Desired number of steps.
Returns:
torch.Tensor: Resampled sigma schedule.
"""
if len(sigmas) == target_steps:
return sigmas
# Vectorized interpolation using PyTorch's native interpolate
# This avoids manual loops and host-device synchronizations on GPU
sigmas_reshaped = sigmas.unsqueeze(0).unsqueeze(0)
resampled = torch.nn.functional.interpolate(
sigmas_reshaped, size=(target_steps,), mode='linear', align_corners=True
)
return resampled.squeeze()
def apply_denoise_factor(sigmas: torch.Tensor, denoise: float) -> torch.Tensor:
"""Apply denoise factor to sigma schedule (for img2img, inpainting, etc.).
Args:
sigmas (torch.Tensor): Original sigma schedule.
denoise (float): Denoise strength (0.0-1.0).
Returns:
torch.Tensor: Modified sigma schedule.
"""
if denoise >= 0.9999:
return sigmas
# Start from a higher sigma based on denoise factor
total_steps = len(sigmas) - 1
start_step = int((1.0 - denoise) * total_steps)
if start_step >= total_steps:
return torch.FloatTensor([0.0])
return sigmas[start_step:]
def get_available_ays_configs(model_type: str = "SD15") -> list:
"""Get list of available step counts for a model type.
Args:
model_type (str): Model type ("SD15", "SDXL", or "FLUX").
Returns:
list: Available step counts with optimal schedules.
"""
if model_type not in AYS_OPTIMAL_SCHEDULES:
return []
return sorted(AYS_OPTIMAL_SCHEDULES[model_type].keys())
def print_ays_info():
"""Print information about available AYS schedules."""
print("\n" + "="*70)
print("Align Your Steps (AYS) Scheduler - Available Configurations")
print("="*70)
for model_type in sorted(AYS_OPTIMAL_SCHEDULES.keys()):
steps = get_available_ays_configs(model_type)
print(f"\n{model_type}:")
print(f" Optimal schedules: {steps}")
print(f" Interpolated: any step count (quality may vary)")
print("\n" + "="*70)
print("Benefits:")
print(" • Same quality with fewer steps (e.g., 10 steps vs 20)")
print(" • Better timestep distribution for image formation")
print(" • Training-free, works with any model")
print(" • Particularly effective for SD1.5 and SDXL")
print("="*70 + "\n")
# Export optimal step counts as constants
RECOMMENDED_STEPS = {
"SD15": 10, # 10 steps with AYS = 20 steps with uniform
"SDXL": 10, # Same for SDXL
"FLUX": 8, # Flux can go lower
}
if __name__ == "__main__":
# Demo usage
print_ays_info()
# Example schedule
print("\nExample SD1.5 10-step schedule:")
sigmas = ays_scheduler(None, 10, "SD15")
print(sigmas)
|