"""
LCM + LTX scheduler combining Latent Consistency Model (LCM) few-step sampling
with a RectifiedFlow (LTX) sigma schedule.

Optimized for Lightning LoRA compatibility and ultra-fast inference.
"""
|
| |
|
| | import torch
|
| | import math
|
| | from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
| |
|
| |
|
class LCMScheduler(SchedulerMixin):
    """
    Few-step flow-matching scheduler ("LCM + LTX") for distilled models.

    Intended for Latent Consistency Model / Lightning LoRA style checkpoints
    that sample in a handful of steps:

    - ``set_timesteps`` builds ``num_inference_steps + 1`` sigmas spaced
      linearly from 1.0 down to ``sigma_min`` (about 0.003), warps them with
      the time shift ``shift * s / (1 + (shift - 1) * s)``, and exposes
      ``timesteps = sigmas[:-1] * num_train_timesteps``. At most
      ``max_inference_steps`` (8 by default) steps are used; larger requests
      are clamped.
    - ``step`` applies a first-order Euler update
      ``x_next = x + (sigma_next - sigma) * model_output``, i.e. the model
      output is treated as the derivative of the sample with respect to
      sigma (presumably a rectified-flow velocity prediction -- confirm for
      the model in use).

    NOTE(review): the schedule ends at the shifted ``sigma_min`` rather than
    0, so the final sample keeps a small residual noise level. Confirm this
    is intentional.
    """

    # One model evaluation per step (first-order Euler). diffusers pipelines
    # commonly read ``scheduler.order`` when computing warm-up steps.
    order = 1

    # Upper bound on the number of inference steps; larger requests are clamped.
    max_inference_steps = 8

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        num_inference_steps: int = 4,
        shift: float = 1.0,
    ):
        """
        Args:
            num_train_timesteps: Factor that rescales sigmas in [0, 1] into
                model timesteps.
            num_inference_steps: Initial value of ``self.num_inference_steps``;
                ``set_timesteps`` overwrites it.
            shift: Default time-shift factor, used when ``set_timesteps`` is
                called with ``shift=None``.
        """
        self.num_train_timesteps = num_train_timesteps
        self.num_inference_steps = num_inference_steps
        self.shift = shift
        # Filled in by set_timesteps(); None means no schedule has been built yet.
        self.timesteps = None
        self.sigmas = None
        # Index of the next step to run; resolved lazily from the first timestep
        # passed to step() after set_timesteps().
        self._step_index = None

    def set_timesteps(self, num_inference_steps: int, device=None, shift: float = None, **kwargs):
        """
        Build the sigma and timestep schedule used by ``step``.

        Args:
            num_inference_steps: Requested number of denoising steps; clamped
                to ``max_inference_steps``.
            device: Optional device that ``timesteps`` and ``sigmas`` are moved to.
            shift: Time-shift factor; falls back to ``self.shift`` when None.
                ``shift == 1`` leaves the linear schedule unchanged.
            **kwargs: Accepted for call-site compatibility and ignored.

        Raises:
            ValueError: If ``num_inference_steps`` is less than 1.
        """
        if num_inference_steps < 1:
            # Zero steps would give an empty timestep list (no denoising at all);
            # negative values would make torch.linspace fail with an opaque error.
            raise ValueError(f"num_inference_steps must be >= 1, got {num_inference_steps}")
        self.num_inference_steps = min(num_inference_steps, self.max_inference_steps)

        if shift is None:
            shift = self.shift

        # num_inference_steps + 1 evenly spaced points: every step needs both a
        # start sigma and an end sigma.
        t = torch.linspace(0, 1, self.num_inference_steps + 1, dtype=torch.float32)

        sigma_max = 1.0
        sigma_min = 0.003 / 1.002
        # Linear ramp from sigma_max (t = 0) down to sigma_min (t = 1).
        sigmas = sigma_min + (sigma_max - sigma_min) * (1 - t)

        # Time shift: shift > 1 pushes intermediate sigmas upward; sigma = 1 is a
        # fixed point, so the schedule always starts at 1.0.
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.sigmas = sigmas
        # Each step's model timestep is its starting sigma rescaled to the
        # training range; the final sigma is only an end point.
        self.timesteps = self.sigmas[:-1] * self.num_train_timesteps

        if device is not None:
            self.timesteps = self.timesteps.to(device)
            self.sigmas = self.sigmas.to(device)
        self._step_index = None

    def step(self, model_output: torch.Tensor, timestep: torch.Tensor, sample: torch.Tensor, **kwargs) -> SchedulerOutput:
        """
        Advance ``sample`` by one Euler step from the current sigma to the next.

        Args:
            model_output: Model prediction, used as d(sample)/d(sigma).
            timestep: Current timestep. Only consulted on the first call after
                ``set_timesteps`` to locate the starting index; later calls
                simply advance the internal step counter.
            sample: Current latent.
            **kwargs: Accepted for call-site compatibility and ignored (for
                example a ``return_dict`` flag; a ``SchedulerOutput`` is always
                returned).

        Returns:
            ``SchedulerOutput`` whose ``prev_sample`` has the same dtype as ``sample``.

        Raises:
            ValueError: If ``set_timesteps`` has not been called.
        """
        if self.sigmas is None:
            raise ValueError("set_timesteps() must be called before step()")
        if self._step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self._step_index]
        if self._step_index + 1 < len(self.sigmas):
            sigma_next = self.sigmas[self._step_index + 1]
        else:
            sigma_next = torch.zeros_like(sigma)

        # Run the update in at least float32 so half-precision latents do not
        # lose accuracy. The step size stays a 0-dim tensor on the sample's
        # device: it broadcasts against any sample shape, and it also works when
        # the schedule was built on the CPU while the latents live elsewhere.
        compute_dtype = torch.promote_types(sample.dtype, torch.float32)
        sigma_diff = (sigma_next - sigma).to(device=sample.device, dtype=compute_dtype)
        prev_sample = sample.to(compute_dtype) + model_output.to(compute_dtype) * sigma_diff
        # Return the latent in its original dtype rather than silently upcasting it.
        prev_sample = prev_sample.to(sample.dtype)

        self._step_index += 1
        return SchedulerOutput(prev_sample=prev_sample)

    def _init_step_index(self, timestep):
        """
        Set ``self._step_index`` to the position of ``timestep`` in ``self.timesteps``.

        Uses the first exact match. If there is none (presumably because of
        floating-point rounding on the caller's side), falls back to the
        nearest scheduled timestep.
        """
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        matches = (self.timesteps == timestep).nonzero()
        if len(matches) > 0:
            self._step_index = matches[0].item()
        else:
            distances = torch.abs(self.timesteps - timestep)
            self._step_index = torch.argmin(distances).item()
|
| |
|