"""
Align Your Steps (AYS) Scheduler
Optimized noise schedules for faster convergence with same/better quality.
Training-free, lossless (often better quality) optimization.

Reference: "Align Your Steps: Optimizing Sampling Schedules in Diffusion Models" (2024)
https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/

Key insight: Not all timesteps contribute equally to image quality.
AYS finds optimal timestep schedules that allow fewer steps with same quality.
"""
| import torch | |
| import numpy as np | |
| import logging | |
# Pre-computed optimal sigma schedules from AYS paper
# These were found through optimization to minimize reconstruction error
#
# Layout: {model_type: {step_count: [sigma_0, ..., sigma_n]}} where each list
# runs from the highest noise level (~14.6146) down to 0.0.
#
# NOTE(review): several entries do not contain step_count + 1 values as the
# key implies (SD15@25 has 17 values, SDXL@20 has 13, FLUX@20 has 11).
# ays_scheduler() resamples any schedule whose length mismatches the request,
# so nothing crashes, but these lists should be verified against the paper.
AYS_OPTIMAL_SCHEDULES = {
    # SD1.5 schedules
    "SD15": {
        4: [14.6146, 6.4745, 2.4826, 0.5497, 0.0],
        6: [14.6146, 8.0426, 4.4170, 2.2172, 0.9316, 0.2596, 0.0],
        8: [14.6146, 9.4663, 5.9384, 3.4759, 1.9696, 1.0417, 0.4598, 0.1328, 0.0],
        10: [14.6146, 10.4708, 7.3688, 4.9651, 3.2924, 2.1391, 1.3633, 0.8437, 0.4898, 0.2279, 0.0],
        12: [14.6146, 11.2797, 8.5033, 6.2928, 4.5662, 3.2721, 2.3124, 1.6103, 1.1029, 0.7414, 0.4804, 0.2552, 0.0],
        15: [14.6146, 12.1652, 9.8581, 7.8194, 6.1302, 4.7577, 3.6611, 2.7994, 2.1271, 1.6053, 1.2033, 0.8917, 0.6525, 0.4672, 0.3196, 0.0],
        20: [14.6146, 13.1721, 11.0451, 9.2027, 7.6085, 6.2281, 5.0356, 4.0135, 3.1468, 2.4226, 1.8293, 1.3567, 0.9967, 0.7219, 0.5129, 0.3589, 0.2463, 0.1644, 0.1025, 0.0522, 0.0],
        # NOTE(review): 17 values (16 steps) despite the 25 key — resampled at use time.
        25: [14.6146, 13.8539, 11.9991, 10.3603, 8.9042, 7.6058, 6.4449, 5.4048, 4.4710, 3.6313, 2.8757, 2.1959, 1.5851, 1.0375, 0.5483, 0.0948, 0.0],
    },
    # SDXL schedules
    "SDXL": {
        4: [14.6146, 6.8873, 2.7084, 0.6577, 0.0],
        6: [14.6146, 8.3767, 4.6699, 2.4175, 1.0643, 0.3262, 0.0],
        8: [14.6146, 9.6929, 6.1589, 3.6454, 2.1116, 1.1507, 0.5474, 0.1770, 0.0],
        10: [14.6146, 10.7043, 7.5043, 5.1442, 3.4302, 2.2379, 1.4288, 0.8874, 0.5174, 0.2427, 0.0],
        12: [14.6146, 11.5222, 8.6124, 6.4254, 4.7084, 3.4034, 2.4252, 1.7028, 1.1721, 0.7930, 0.5182, 0.2782, 0.0],
        15: [14.6146, 12.4748, 10.0985, 8.0432, 6.3548, 4.9664, 3.8444, 2.9488, 2.2453, 1.6962, 1.2714, 0.9442, 0.6920, 0.4962, 0.3410, 0.0],
        20: [14.6146, 13.4772, 11.6548, 9.9908, 8.4577, 7.0347, 5.7062, 4.4602, 3.2880, 2.1832, 1.1412, 0.1594, 0.0],
        # Note: The 20-step schedule above is optimized to 12 actual steps for efficiency.
        # If you want true 20 steps, the scheduler will interpolate from this schedule.
    },
    # Flux schedules (experimental - adapted from SDXL)
    "FLUX": {
        4: [14.6146, 7.2458, 3.0169, 0.7842, 0.0],
        8: [14.6146, 10.1472, 6.5812, 4.0103, 2.4138, 1.3842, 0.6951, 0.2456, 0.0],
        10: [14.6146, 11.2058, 8.0185, 5.5842, 3.8102, 2.5741, 1.6745, 1.0596, 0.6288, 0.3056, 0.0],
        15: [14.6146, 12.9258, 10.7341, 8.7962, 7.0782, 5.5502, 4.2822, 3.2442, 2.4062, 1.7682, 1.2902, 0.9422, 0.6742, 0.4662, 0.3082, 0.0],
        # NOTE(review): 11 values (10 steps, evenly spaced by 1.6) despite the 20 key.
        20: [14.6146, 14.0146, 12.4146, 10.8146, 9.2146, 7.6146, 6.0146, 4.4146, 2.8146, 1.2146, 0.0],
    },
}
def ays_scheduler(
    model_sampling: torch.nn.Module,
    steps: int,
    model_type: str = "SD15",
    denoise: float = 1.0
) -> torch.FloatTensor:
    """Create an Align Your Steps optimized scheduler.

    This scheduler uses pre-computed optimal sigma distributions that allow
    fewer sampling steps with equivalent or better image quality compared to
    uniform schedulers.

    Args:
        model_sampling (torch.nn.Module): The model sampling module
            (currently unused by this function; kept for scheduler-API compatibility).
        steps (int): The number of denoising steps. Must be >= 1.
        model_type (str): Model type - "SD15", "SDXL", or "FLUX". Defaults to "SD15".
        denoise (float): Denoise strength (1.0 = full denoise). Defaults to 1.0.

    Returns:
        torch.FloatTensor: Optimized sigma schedule of ``steps + 1`` values ending
            in 0.0 (fewer values when ``denoise < 1.0`` truncates the schedule).

    Raises:
        ValueError: If ``steps`` is less than 1.
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")

    # Unknown model types fall back to SD15 instead of failing hard.
    if model_type not in AYS_OPTIMAL_SCHEDULES:
        logging.warning(f"Unknown model type '{model_type}' for AYS scheduler, falling back to SD15")
        model_type = "SD15"
    schedules = AYS_OPTIMAL_SCHEDULES[model_type]

    if steps in schedules:
        # Exact pre-computed schedule available.
        sigmas = torch.tensor(schedules[steps], dtype=torch.float32)
        logging.debug(f"Using AYS optimal schedule for {model_type} @ {steps} steps")
    else:
        available_steps = sorted(schedules.keys())
        if steps < available_steps[0]:
            # Below the smallest pre-computed schedule: start from it and let
            # the final length guard below shrink it to the requested count.
            use_steps = available_steps[0]
            logging.debug(f"Using AYS {use_steps}-step schedule (requested {steps} steps)")
            sigmas = torch.tensor(schedules[use_steps], dtype=torch.float32)
        elif steps > available_steps[-1]:
            # Above the largest pre-computed schedule: resample it upward.
            use_steps = available_steps[-1]
            logging.debug(f"Using AYS {use_steps}-step schedule (requested {steps} steps)")
            base_sigmas = torch.tensor(schedules[use_steps], dtype=torch.float32)
            sigmas = resample_sigmas(base_sigmas, steps + 1)
        else:
            # Strictly between two pre-computed schedules. The exact-match case
            # was handled above, so lower_steps < steps < upper_steps holds and
            # both neighbor sets below are non-empty.
            lower_steps = max(s for s in available_steps if s < steps)
            upper_steps = min(s for s in available_steps if s > steps)
            lower_sigmas = torch.tensor(schedules[lower_steps], dtype=torch.float32)
            upper_sigmas = torch.tensor(schedules[upper_steps], dtype=torch.float32)
            # Resample both neighbors to the target length, then blend them
            # weighted by how close the requested count is to each neighbor.
            lower_resampled = resample_sigmas(lower_sigmas, steps + 1)
            upper_resampled = resample_sigmas(upper_sigmas, steps + 1)
            weight = (steps - lower_steps) / (upper_steps - lower_steps)
            sigmas = lower_resampled * (1 - weight) + upper_resampled * weight
            logging.debug(f"Interpolated AYS schedule for {model_type} @ {steps} steps")

    # Final guard: some stored schedules have a different length than their
    # key implies, so always normalize to exactly steps + 1 values.
    if len(sigmas) != steps + 1:
        sigmas = resample_sigmas(sigmas, steps + 1)

    # Partial denoise (img2img, inpainting) truncates the front of the schedule.
    if denoise < 1.0:
        sigmas = apply_denoise_factor(sigmas, denoise)

    # Samplers expect the schedule to terminate at exactly zero noise.
    sigmas[-1] = 0.0
    return sigmas
def resample_sigmas(sigmas: torch.Tensor, target_steps: int) -> torch.Tensor:
    """Resample a sigma schedule to a different length using linear interpolation.

    Args:
        sigmas (torch.Tensor): Original 1-D sigma schedule.
        target_steps (int): Desired number of values in the output (>= 1).

    Returns:
        torch.Tensor: 1-D float tensor of length ``target_steps``. The input is
            returned unchanged when it already has the target length.
    """
    if len(sigmas) == target_steps:
        return sigmas
    # F.interpolate requires a floating-point input.
    sigmas = sigmas.float()
    if target_steps == 1:
        # align_corners interpolation is undefined for a single output sample
        # (the (in-1)/(out-1) scale divides by zero); keep the first sigma.
        return sigmas[:1].clone()
    # Vectorized interpolation using PyTorch's native interpolate, which
    # expects an (N, C, L) layout. This avoids manual loops and
    # host-device synchronizations on GPU.
    resampled = torch.nn.functional.interpolate(
        sigmas.unsqueeze(0).unsqueeze(0),
        size=(target_steps,),
        mode='linear',
        align_corners=True,
    )
    # Squeeze only the two dims we added; a bare .squeeze() would collapse a
    # length-1 schedule to a 0-d tensor.
    return resampled.squeeze(0).squeeze(0)
def apply_denoise_factor(sigmas: torch.Tensor, denoise: float) -> torch.Tensor:
    """Apply denoise factor to sigma schedule (for img2img, inpainting, etc.).

    Partial denoise skips the first ``(1 - denoise)`` fraction of steps, so the
    sampler starts from a lower noise level.

    Args:
        sigmas (torch.Tensor): Original 1-D sigma schedule.
        denoise (float): Denoise strength (0.0-1.0).

    Returns:
        torch.Tensor: Truncated sigma schedule; a single-zero tensor when
            nothing is left to denoise.
    """
    # Tolerance absorbs float slop so denoise=1.0 (or near it) is a no-op.
    if denoise >= 0.9999:
        return sigmas
    total_steps = len(sigmas) - 1
    start_step = int((1.0 - denoise) * total_steps)
    if start_step >= total_steps:
        # Nothing left to denoise. new_zeros preserves the input's dtype and
        # device (the previous torch.FloatTensor([0.0]) was always CPU float32,
        # which broke CUDA/non-float32 schedules).
        return sigmas.new_zeros(1)
    return sigmas[start_step:]
def get_available_ays_configs(model_type: str = "SD15") -> list:
    """Return the step counts that have pre-computed AYS schedules.

    Args:
        model_type (str): Model type ("SD15", "SDXL", or "FLUX").

    Returns:
        list: Sorted step counts with optimal schedules; empty for unknown types.
    """
    schedules = AYS_OPTIMAL_SCHEDULES.get(model_type)
    return sorted(schedules) if schedules is not None else []
def print_ays_info():
    """Print information about available AYS schedules."""
    divider = "=" * 70
    print("\n" + divider)
    print("Align Your Steps (AYS) Scheduler - Available Configurations")
    print(divider)
    # One section per model type, listing its pre-computed step counts.
    for model_type in sorted(AYS_OPTIMAL_SCHEDULES):
        configs = get_available_ays_configs(model_type)
        print(f"\n{model_type}:")
        print(f" Optimal schedules: {configs}")
        print(f" Interpolated: any step count (quality may vary)")
    print("\n" + divider)
    print("Benefits:")
    benefit_lines = (
        " • Same quality with fewer steps (e.g., 10 steps vs 20)",
        " • Better timestep distribution for image formation",
        " • Training-free, works with any model",
        " • Particularly effective for SD1.5 and SDXL",
    )
    for line in benefit_lines:
        print(line)
    print(divider + "\n")
# Export optimal step counts as constants
# Recommended default step count per model type; the comments reflect the AYS
# claim that these match the quality of roughly 2x the steps with a uniform
# schedule (see module docstring reference).
RECOMMENDED_STEPS = {
    "SD15": 10,  # 10 steps with AYS = 20 steps with uniform
    "SDXL": 10,  # Same for SDXL
    "FLUX": 8,  # Flux can go lower
}
if __name__ == "__main__":
    # Demo usage: list available configurations, then show one concrete schedule.
    print_ays_info()
    # Example schedule. ays_scheduler never reads model_sampling, so passing
    # None is safe in this standalone demo.
    print("\nExample SD1.5 10-step schedule:")
    sigmas = ays_scheduler(None, 10, "SD15")
    print(sigmas)