| """ |
| LCM + LTX scheduler combining Latent Consistency Model with RectifiedFlow (LTX). |
| Optimized for Lightning LoRA compatibility and ultra-fast inference. |
| """ |
|
|
| import torch |
| import math |
| from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput |
|
|
|
|
| class LCMScheduler(SchedulerMixin): |
| """ |
| LCM + LTX scheduler combining Latent Consistency Model with RectifiedFlow. |
| - LCM: Enables 2-8 step inference with consistency models |
| - LTX: Uses RectifiedFlow for better flow matching dynamics |
| Optimized for Lightning LoRAs and ultra-fast, high-quality generation. |
| """ |
| |
| def __init__(self, num_train_timesteps: int = 1000, num_inference_steps: int = 4, shift: float = 1.0): |
| self.num_train_timesteps = num_train_timesteps |
| self.num_inference_steps = num_inference_steps |
| self.shift = shift |
| self._step_index = None |
| |
| def set_timesteps(self, num_inference_steps: int, device=None, shift: float = None, **kwargs): |
| """Set timesteps for LCM+LTX inference using RectifiedFlow approach""" |
| self.num_inference_steps = min(num_inference_steps, 8) |
| |
| if shift is None: |
| shift = self.shift |
| |
| |
| |
| t = torch.linspace(0, 1, self.num_inference_steps + 1, dtype=torch.float32) |
| |
| |
| |
| sigma_max = 1.0 |
| sigma_min = 0.003 / 1.002 |
| |
| |
| |
| sigmas = sigma_min + (sigma_max - sigma_min) * (1 - t) |
| |
| |
| sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) |
| |
| self.sigmas = sigmas |
| self.timesteps = self.sigmas[:-1] * self.num_train_timesteps |
| |
| if device is not None: |
| self.timesteps = self.timesteps.to(device) |
| self.sigmas = self.sigmas.to(device) |
| self._step_index = None |
| |
| def step(self, model_output: torch.Tensor, timestep: torch.Tensor, sample: torch.Tensor, **kwargs) -> SchedulerOutput: |
| """ |
| Perform LCM + LTX step combining consistency model with rectified flow. |
| - LCM: Direct consistency model prediction for fast inference |
| - LTX: RectifiedFlow dynamics for optimal probability flow path |
| """ |
| if self._step_index is None: |
| self._init_step_index(timestep) |
| |
| |
| sigma = self.sigmas[self._step_index] |
| if self._step_index + 1 < len(self.sigmas): |
| sigma_next = self.sigmas[self._step_index + 1] |
| else: |
| sigma_next = torch.zeros_like(sigma) |
| |
| |
| |
| |
| |
| |
| |
| sigma_diff = (sigma_next - sigma) |
| while len(sigma_diff.shape) < len(sample.shape): |
| sigma_diff = sigma_diff.unsqueeze(-1) |
| |
| |
| |
| prev_sample = sample + model_output * sigma_diff |
| self._step_index += 1 |
| |
| return SchedulerOutput(prev_sample=prev_sample) |
| |
| def _init_step_index(self, timestep): |
| """Initialize step index based on current timestep""" |
| if isinstance(timestep, torch.Tensor): |
| timestep = timestep.to(self.timesteps.device) |
| indices = (self.timesteps == timestep).nonzero() |
| if len(indices) > 0: |
| self._step_index = indices[0].item() |
| else: |
| |
| diffs = torch.abs(self.timesteps - timestep) |
| self._step_index = torch.argmin(diffs).item() |
|
|