Upload 6 files
Browse files- packages/ltx-core/src/ltx_core/components/__init__.py +10 -0
- packages/ltx-core/src/ltx_core/components/guiders.py +198 -0
- packages/ltx-core/src/ltx_core/components/noisers.py +35 -0
- packages/ltx-core/src/ltx_core/components/patchifiers.py +348 -0
- packages/ltx-core/src/ltx_core/components/protocols.py +101 -0
- packages/ltx-core/src/ltx_core/components/schedulers.py +129 -0
packages/ltx-core/src/ltx_core/components/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Diffusion pipeline components.
|
| 3 |
+
Submodules:
|
| 4 |
+
diffusion_steps - Diffusion stepping algorithms (EulerDiffusionStep)
|
| 5 |
+
guiders - Guidance strategies (CFGGuider, STGGuider, APG variants)
|
| 6 |
+
noisers - Noise samplers (GaussianNoiser)
|
| 7 |
+
patchifiers - Latent patchification (VideoLatentPatchifier, AudioPatchifier)
|
| 8 |
+
protocols - Protocol definitions (Patchifier, etc.)
|
| 9 |
+
schedulers - Sigma schedulers (LTX2Scheduler, LinearQuadraticScheduler)
|
| 10 |
+
"""
|
packages/ltx-core/src/ltx_core/components/guiders.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ltx_core.components.protocols import GuiderProtocol
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass(frozen=True)
class CFGGuider(GuiderProtocol):
    """
    Classifier-free guidance (CFG) guider.

    The guidance delta is (scale - 1) * (cond - uncond), which steers the
    denoising trajectory toward the conditioned prediction.

    Attributes:
        scale: Guidance strength. A value of 1.0 means no guidance; larger
            values increase adherence to the conditioning.
    """

    scale: float

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        # Standard CFG formulation, expressed via an explicit weight factor.
        weight = self.scale - 1
        return weight * (cond - uncond)

    def enabled(self) -> bool:
        # A scale of exactly 1.0 contributes a zero delta, so guidance is off.
        return self.scale != 1.0
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass(frozen=True)
class CFGStarRescalingGuider(GuiderProtocol):
    """
    CFG guider with rescaling of the unconditioned sample.

    Computes the CFG delta between conditioned and unconditioned samples, but
    first rescales the unconditioned sample by its projection coefficient onto
    the conditioned sample. This minimizes offset along the denoising direction
    and moves mostly along the conditioning axis within the distribution.

    Attributes:
        scale (float):
            Global guidance strength. 1.0 corresponds to no extra guidance
            beyond the base model prediction; values > 1.0 increase the
            influence of the conditioned sample relative to the unconditioned
            one.
    """

    scale: float

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        # Rescale the negative branch so its norm matches the conditioned
        # sample along the projection direction before taking the difference.
        coef = projection_coef(cond, uncond)
        return (self.scale - 1) * (cond - coef * uncond)

    def enabled(self) -> bool:
        return self.scale != 1.0
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@dataclass(frozen=True)
class STGGuider(GuiderProtocol):
    """
    STG guider.

    Computes the STG delta between the conditioned denoised sample and a
    perturbed denoised sample. Perturbed samples come from running the
    denoiser with perturbations applied, e.g. attention layers acting as
    passthrough for certain layers and modalities.

    Attributes:
        scale (float):
            Global strength of the STG guidance. 0.0 disables the guidance;
            larger values increase the correction applied in the direction of
            (pos_denoised - perturbed_denoised).
    """

    scale: float

    def delta(self, pos_denoised: torch.Tensor, perturbed_denoised: torch.Tensor) -> torch.Tensor:
        correction = pos_denoised - perturbed_denoised
        return correction * self.scale

    def enabled(self) -> bool:
        # Unlike CFG-style guiders, STG is additive: zero scale means disabled.
        return self.scale != 0.0
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@dataclass(frozen=True)
class LtxAPGGuider(GuiderProtocol):
    """
    APG (adaptive projected guidance) guider.

    The (cond - uncond) delta is decomposed into components parallel and
    orthogonal to the conditioned sample. The parallel component is weighted
    by `eta`, while `scale` is applied to the combined result. This minimizes
    offset in the denoising direction and moves mostly along the conditioning
    axis within the distribution. A norm threshold can optionally cap the
    magnitude of the guidance before decomposition.

    Attributes:
        scale (float):
            Strength applied to the guidance. Controls how aggressively the
            trajectory moves in directions consistent with the conditioning
            manifold. 1.0 means no extra guidance.
        eta (float):
            Weight of the component parallel to the conditioned sample.
            1.0 keeps the full parallel component; values in [0, 1] attenuate
            it, and values > 1.0 amplify motion along the conditioning
            direction.
        norm_threshold (float):
            When > 0, the guidance delta is rescaled so its per-sample L2 norm
            (over the last three dimensions) does not exceed this threshold.
            Useful for avoiding noisy or unstable updates.
    """

    scale: float
    eta: float = 1.0
    norm_threshold: float = 0.0

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        diff = cond - uncond

        if self.norm_threshold > 0:
            # Cap the per-sample guidance norm at `norm_threshold`.
            # Norm is taken over the last three dims — assumes at least 4-D
            # tensors (e.g. batch + spatial/temporal axes); TODO confirm.
            diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
            cap = torch.minimum(torch.ones_like(diff), self.norm_threshold / diff_norm)
            diff = diff * cap

        # Split the guidance into components parallel / orthogonal to `cond`.
        parallel = projection_coef(diff, cond) * cond
        orthogonal = diff - parallel
        combined = parallel * self.eta + orthogonal

        return combined * (self.scale - 1)

    def enabled(self) -> bool:
        return self.scale != 1.0
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
@dataclass(frozen=False)
class LegacyStatefulAPGGuider(GuiderProtocol):
    """
    Calculates the APG (adaptive projected guidance) delta between conditioned
    and unconditioned samples.
    To minimize offset in the denoising direction and move mostly along the
    conditioning axis within the distribution, the (cond - uncond) delta is
    decomposed into components parallel and orthogonal to the conditioned
    sample. The `eta` parameter weights the parallel component, while `scale`
    is applied to the orthogonal component. Optionally, a norm threshold can
    be used to suppress guidance when the magnitude of the correction is small.
    Attributes:
        scale (float):
            Strength applied to the component of the guidance that is orthogonal
            to the conditioned sample. Controls how aggressively we move in
            directions that change semantics but stay consistent with the
            conditioning manifold.
        eta (float):
            Weight of the component of the guidance that is parallel to the
            conditioned sample. A value of 1.0 keeps the full parallel
            component; values in [0, 1] attenuate it, and values > 1.0 amplify
            motion along the conditioning direction.
        norm_threshold (float):
            Minimum L2 norm of the guidance delta below which the guidance
            can be reduced or ignored (depending on implementation).
            This is useful for avoiding noisy or unstable updates when the
            guidance signal is very small.
        momentum (float):
            Exponential moving-average coefficient for accumulating guidance
            over time. running_avg = momentum * running_avg + guidance
    """

    scale: float
    eta: float
    norm_threshold: float = 5.0
    momentum: float = 0.0
    # it is user's responsibility not to use same APGGuider for several denoisings or different modalities
    # in order not to share accumulated average across different denoisings or modalities
    # NOTE: mutable state — this is why the dataclass is frozen=False, unlike
    # the other (stateless) guiders in this module.
    running_avg: torch.Tensor | None = None

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        guidance = cond - uncond
        if self.momentum != 0:
            # Accumulate guidance across calls; the first call seeds the
            # running average with a clone so later in-place-free updates
            # never alias the caller's tensors.
            if self.running_avg is None:
                self.running_avg = guidance.clone()
            else:
                self.running_avg = self.momentum * self.running_avg + guidance
            guidance = self.running_avg

        if self.norm_threshold > 0:
            # Cap the per-sample L2 norm (over the last three dims) of the
            # guidance at `norm_threshold`.
            ones = torch.ones_like(guidance)
            guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
            scale_factor = torch.minimum(ones, self.norm_threshold / guidance_norm)
            guidance = guidance * scale_factor

        # Decompose into components parallel/orthogonal to the conditioned
        # sample; `eta` attenuates or amplifies the parallel part.
        proj_coeff = projection_coef(guidance, cond)
        g_parallel = proj_coeff * cond
        g_orth = guidance - g_parallel
        g_apg = g_parallel * self.eta + g_orth

        # Legacy convention: multiply by `scale` directly (not `scale - 1`
        # as in LtxAPGGuider), matching the `enabled()` check below.
        return g_apg * self.scale

    def enabled(self) -> bool:
        return self.scale != 0.0
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def projection_coef(to_project: torch.Tensor, project_onto: torch.Tensor) -> torch.Tensor:
    """
    Per-sample scalar coefficient of the projection of `to_project` onto
    `project_onto`.

    Both tensors are flattened per batch element; the coefficient is
    <a, b> / (||b||^2 + 1e-8), returned with shape (batch, 1).
    """
    # Flatten everything after the batch dimension so the projection is
    # computed over each whole sample.
    flat_src = to_project.flatten(start_dim=1)
    flat_dst = project_onto.flatten(start_dim=1)
    dot = (flat_src * flat_dst).sum(dim=1, keepdim=True)
    # Small epsilon guards against division by zero for all-zero targets.
    denom = (flat_dst * flat_dst).sum(dim=1, keepdim=True) + 1e-8
    return dot / denom
|
packages/ltx-core/src/ltx_core/components/noisers.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import replace
|
| 2 |
+
from typing import Protocol
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from ltx_core.types import LatentState
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Noiser(Protocol):
    """Protocol for adding noise to a latent state during diffusion."""

    # Implementations receive the current latent state and a scalar noise
    # scale and return a new LatentState with noise mixed in.
    def __call__(self, latent_state: LatentState, noise_scale: float) -> LatentState: ...
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class GaussianNoiser(Noiser):
    """Adds Gaussian noise to a latent state, scaled by the denoise mask."""

    def __init__(self, generator: torch.Generator):
        super().__init__()

        # Dedicated RNG so noise sampling is reproducible and isolated from
        # the global torch seed.
        self.generator = generator

    def __call__(self, latent_state: LatentState, noise_scale: float = 1.0) -> LatentState:
        source = latent_state.latent
        # Sample noise matching the latent's shape/device/dtype.
        noise = torch.randn(
            *source.shape,
            device=source.device,
            dtype=source.dtype,
            generator=self.generator,
        )
        # Per-element blend weight: denoise_mask selects where noise applies,
        # noise_scale controls how strongly.
        mask = latent_state.denoise_mask * noise_scale
        blended = mask * noise + (1 - mask) * source
        # Cast back in case the mask arithmetic promoted the dtype.
        return replace(latent_state, latent=blended.to(source.dtype))
|
packages/ltx-core/src/ltx_core/components/patchifiers.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import einops
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from ltx_core.components.protocols import Patchifier
|
| 8 |
+
from ltx_core.types import AudioLatentShape, SpatioTemporalScaleFactors, VideoLatentShape
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class VideoLatentPatchifier(Patchifier):
    """Converts video latents to/from flattened spatial patch tokens."""

    def __init__(self, patch_size: int):
        # (temporal, height, width) patch sizes; temporal patching is never
        # applied (always 1).
        self._patch_size = (1, patch_size, patch_size)

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        return self._patch_size

    def get_token_count(self, tgt_shape: VideoLatentShape) -> int:
        # Tokens = spatio-temporal elements divided by elements per patch.
        # to_torch_shape()[2:] drops the batch and channel dims — assumes the
        # layout is (b, c, f, h, w); TODO confirm against VideoLatentShape.
        spatial_elements = math.prod(tgt_shape.to_torch_shape()[2:])
        return spatial_elements // math.prod(self._patch_size)

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        """Fold each (p1, p2, p3) patch into the channel axis and flatten the grid."""
        pt, ph, pw = self._patch_size
        return einops.rearrange(
            latents,
            "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
            p1=pt,
            p2=ph,
            p3=pw,
        )

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: VideoLatentShape,
    ) -> torch.Tensor:
        """Inverse of `patchify`: restore the dense (b, c, f, h, w) layout."""
        assert self._patch_size[0] == 1, "Temporal patch size must be 1 for symmetric patchifier"

        # Grid extent (in patches) along each axis of the target shape.
        grid_f = output_shape.frames // self._patch_size[0]
        grid_h = output_shape.height // self._patch_size[1]
        grid_w = output_shape.width // self._patch_size[2]

        return einops.rearrange(
            latents,
            "b (f h w) (c p q) -> b c f (h p) (w q)",
            f=grid_f,
            h=grid_h,
            w=grid_w,
            p=self._patch_size[1],
            q=self._patch_size[2],
        )

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Return per-dimension bounds [inclusive start, exclusive end) for every
        patch produced by `patchify`, in original video grid coordinates
        (frame/time, height, width).

        The result is shaped `[batch_size, 3, num_patches, 2]`:
        - axis 1 (size 3) enumerates (frame/time, height, width)
        - axis 3 (size 2) stores `[start, end)` indices per dimension

        Args:
            output_shape: Video grid description containing frames, height, and width.
            device: Device of the latent tensor.
        """
        if not isinstance(output_shape, VideoLatentShape):
            raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates")

        frames = output_shape.frames
        height = output_shape.height
        width = output_shape.width
        batch_size = output_shape.batch

        # Guard against degenerate grids.
        assert frames > 0, f"frames must be positive, got {frames}"
        assert height > 0, f"height must be positive, got {height}"
        assert width > 0, f"width must be positive, got {width}"
        assert batch_size > 0, f"batch_size must be positive, got {batch_size}"

        # Start coordinate of every patch along each axis; indexing='ij' keeps
        # the (frame, height, width) ordering. Stacking gives shape
        # (3, grid_f, grid_h, grid_w).
        starts = torch.stack(
            torch.meshgrid(
                torch.arange(start=0, end=frames, step=self._patch_size[0], device=device),
                torch.arange(start=0, end=height, step=self._patch_size[1], device=device),
                torch.arange(start=0, end=width, step=self._patch_size[2], device=device),
                indexing="ij",
            ),
            dim=0,
        )

        # Per-axis patch extent, shaped (3, 1, 1, 1) for broadcasting.
        extent = torch.tensor(
            self._patch_size,
            device=starts.device,
            dtype=starts.dtype,
        ).view(3, 1, 1, 1)

        # Pair each start with its exclusive end: shape (3, gf, gh, gw, 2).
        bounds = torch.stack((starts, starts + extent), dim=-1)

        # Broadcast over the batch and flatten the grid into one patch axis:
        # final shape (batch_size, 3, num_patches, 2).
        return einops.repeat(
            bounds,
            "c f h w bounds -> b c (f h w) bounds",
            b=batch_size,
            bounds=2,
        )
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def get_pixel_coords(
    latent_coords: torch.Tensor,
    scale_factors: SpatioTemporalScaleFactors,
    causal_fix: bool = False,
) -> torch.Tensor:
    """
    Map latent-space `[start, end)` coordinates to pixel space by multiplying
    each axis (frame/time, height, width) by the corresponding VAE
    downsampling factor.

    Args:
        latent_coords: Latent bounds shaped `(batch, 3, num_patches, 2)`.
        scale_factors: `(temporal, height, width)` integer scale factors.
        causal_fix: When True, adjust the temporal axis of the first frame so
            causal VAEs (which keep frame zero at unit temporal stride) still
            yield non-negative timestamps.
    """
    # View the scale factors as (1, 3, 1, 1) so they broadcast across the
    # (batch, axis, patch, bound) layout, scaling only per axis.
    view_shape = [1] * latent_coords.ndim
    view_shape[1] = -1
    factors = torch.tensor(scale_factors, device=latent_coords.device).view(*view_shape)

    pixel_coords = latent_coords * factors

    if causal_fix:
        # The VAE's temporal stride for the very first frame is 1 instead of
        # scale_factors[0]; shift and clamp to keep timestamps causal and
        # non-negative.
        pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)

    return pixel_coords
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class AudioPatchifier(Patchifier):
    """Patchifier tailored for spectrogram/audio latents."""

    def __init__(
        self,
        patch_size: int,
        sample_rate: int = 16000,
        hop_length: int = 160,
        audio_latent_downsample_factor: int = 4,
        is_causal: bool = True,
        shift: int = 0,
    ):
        """
        Args:
            patch_size: Number of mel bins combined into a single patch;
                controls resolution along the frequency axis.
            sample_rate: Original waveform sampling rate, used to map latent
                indices back to seconds for audio/video alignment.
            hop_length: Spectrogram window hop length; determines how many
                real-time samples separate consecutive latent frames.
            audio_latent_downsample_factor: Ratio between spectrogram frames
                and latent frames, compensating for downsampling inside the
                VAE encoder.
            is_causal: When True, timestamps are shifted to respect causal
                receptive fields (no peeking into the future).
            shift: Integer offset applied to latent indices; enables building
                overlapping windows from the same latent sequence.
        """
        self.hop_length = hop_length
        self.sample_rate = sample_rate
        self.audio_latent_downsample_factor = audio_latent_downsample_factor
        self.is_causal = is_causal
        self.shift = shift
        self._patch_size = (1, patch_size, patch_size)

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        return self._patch_size

    def get_token_count(self, tgt_shape: AudioLatentShape) -> int:
        # One token per latent time step; frequency bins are folded into the
        # channel axis by `patchify`.
        return tgt_shape.frames

    def _get_audio_latent_time_in_sec(
        self,
        start_latent: int,
        end_latent: int,
        dtype: torch.dtype,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Convert latent indices [start_latent, end_latent) to timestamps in
        seconds, honoring the causal offset and configured hop length.

        Args:
            start_latent: Inclusive start index (first timestamp returned).
            end_latent: Exclusive end index (controls how many timestamps).
            dtype: Floating-point dtype of the returned tensor.
            device: Target device; defaults to CPU to avoid surprising GPU
                allocations.
        """
        if device is None:
            device = torch.device("cpu")

        latent_idx = torch.arange(start_latent, end_latent, dtype=dtype, device=device)

        # Latent frame -> mel-spectrogram frame.
        mel_idx = latent_idx * self.audio_latent_downsample_factor

        if self.is_causal:
            # The "+1" anchors each timestamp to the first fully available
            # sample; clip keeps early frames non-negative.
            causal_offset = 1
            mel_idx = (mel_idx + causal_offset - self.audio_latent_downsample_factor).clip(min=0)

        # Mel frame -> seconds.
        return mel_idx * self.hop_length / self.sample_rate

    def _compute_audio_timings(
        self,
        batch_size: int,
        num_steps: int,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Build a `(B, 1, T, 2)` tensor of [start, end) timestamps for each
        latent frame; backs `get_patch_grid_bounds`.

        Args:
            batch_size: Number of sequences to broadcast the timings over.
            num_steps: Number of latent frames to convert into timestamps.
            device: Device for the resulting tensor (CPU when omitted).
        """
        target_device = torch.device("cpu") if device is None else device

        # offset 0 -> start timestamps, offset 1 -> end timestamps.
        per_bound = []
        for offset in (0, 1):
            times = self._get_audio_latent_time_in_sec(
                self.shift + offset,
                num_steps + self.shift + offset,
                torch.float32,
                target_device,
            )
            # (T,) -> (B, 1, T)
            per_bound.append(times.unsqueeze(0).expand(batch_size, -1).unsqueeze(1))

        return torch.stack(per_bound, dim=-1)

    def patchify(
        self,
        audio_latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Flatten the audio latent tensor along time, folding channels and
        frequency bins into one token axis. Use `get_patch_grid_bounds` to
        derive per-frame timestamps.

        Args:
            audio_latents: Latent tensor to patchify.

        Returns:
            Flattened patch tokens tensor.
        """
        return einops.rearrange(
            audio_latents,
            "b c t f -> b t (c f)",
        )

    def unpatchify(
        self,
        audio_latents: torch.Tensor,
        output_shape: AudioLatentShape,
    ) -> torch.Tensor:
        """
        Restore the `(B, C, T, F)` spectrogram tensor from flattened patches.
        Use `get_patch_grid_bounds` to recompute each frame's real-time
        position.

        Args:
            audio_latents: Latent tensor to unpatchify, shaped
                (batch, time, freq * channels).
            output_shape: Shape of the unpatched output tensor.

        Returns:
            Unpatched latent tensor.
        """
        return einops.rearrange(
            audio_latents,
            "b t (c f) -> b c t f",
            c=output_shape.channels,
            f=output_shape.mel_bins,
        )

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Return temporal bounds `[inclusive start, exclusive end)` for every
        patch emitted by `patchify`, as timestamps in seconds aligned with the
        original spectrogram grid.

        The result is shaped `[batch_size, 1, time_steps, 2]`:
        - axis 1 (size 1) is the temporal dimension
        - axis 3 (size 2) stores `[start, end)` timestamps per patch

        Args:
            output_shape: Audio grid specification with the number of time steps.
            device: Target device for the returned tensor.
        """
        if not isinstance(output_shape, AudioLatentShape):
            raise ValueError("AudioPatchifier expects AudioLatentShape when computing coordinates")

        return self._compute_audio_timings(output_shape.batch, output_shape.frames, device)
|
packages/ltx-core/src/ltx_core/components/protocols.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Protocol, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ltx_core.types import AudioLatentShape, VideoLatentShape
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Patchifier(Protocol):
    """
    Protocol for patchifiers that convert latent tensors into patches and assemble them back.
    """

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Convert latent tensors into flattened patch tokens.

        Args:
            latents: Latent tensor to patchify.

        Returns:
            Flattened patch tokens tensor.
        """
        # NOTE: the docstring must precede `...` — placed after the ellipsis it
        # becomes a dead string statement and `__doc__` is None.
        ...

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: AudioLatentShape | VideoLatentShape,
    ) -> torch.Tensor:
        """
        Converts latent tensors between spatio-temporal formats and flattened sequence representations.

        Args:
            latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`.
            output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or
                VideoLatentShape.

        Returns:
            Dense latent tensor restored from the flattened representation.
        """
        ...

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        """
        Returns the patch size as a tuple of (temporal, height, width) dimensions.
        """
        ...

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: torch.device | None = None,
    ) -> torch.Tensor:
        """
        Compute metadata describing where each latent patch resides within the
        grid specified by `output_shape`.

        Args:
            output_shape: Target grid layout for the patches.
            device: Target device for the returned tensor.

        Returns:
            Tensor containing patch coordinate metadata such as spatial or temporal intervals.
        """
        ...
|
| 65 |
+
class SchedulerProtocol(Protocol):
    """
    Protocol for schedulers that produce a sigma schedule tensor for a given
    number of steps. The returned tensor lives on the CPU.
    """

    def execute(self, steps: int, **kwargs) -> torch.FloatTensor:
        ...
+
|
| 74 |
+
class GuiderProtocol(Protocol):
    """
    Protocol for guiders that compute a delta tensor given conditioning inputs.

    The returned delta should be added to the conditional output (cond), enabling
    multiple guiders to be chained together by accumulating their deltas.
    """

    # Guidance strength applied by the concrete guider implementation.
    scale: float

    def delta(
        self,
        cond: torch.Tensor,
        uncond: torch.Tensor,
    ) -> torch.Tensor:
        ...

    def enabled(self) -> bool:
        """
        Returns whether the corresponding perturbation is enabled. E.g. for CFG, this should return False if the scale
        is 1.0.
        """
        ...
| 92 |
+
|
| 93 |
+
class DiffusionStepProtocol(Protocol):
    """
    Protocol for diffusion steps that provide a next sample tensor for a given current sample tensor,
    current denoised sample tensor, and sigmas tensor.
    """

    def step(
        self,
        sample: torch.Tensor,
        denoised_sample: torch.Tensor,
        sigmas: torch.Tensor,
        step_index: int,
    ) -> torch.Tensor:
        ...
|
packages/ltx-core/src/ltx_core/components/schedulers.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
from functools import lru_cache

import numpy
import scipy
import scipy.stats
import torch

from ltx_core.components.protocols import SchedulerProtocol
|
| 9 |
+
|
| 10 |
+
# Token-count anchors between which LTX2Scheduler linearly interpolates its
# sigma shift: `base_shift` applies at BASE_SHIFT_ANCHOR tokens, `max_shift`
# at MAX_SHIFT_ANCHOR tokens.
BASE_SHIFT_ANCHOR = 1024
MAX_SHIFT_ANCHOR = 4096
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class LTX2Scheduler(SchedulerProtocol):
    """
    Default scheduler for LTX-2 diffusion sampling.

    Generates a sigma schedule with token-count-dependent shifting and optional
    stretching to a terminal value.
    """

    def execute(
        self,
        steps: int,
        latent: torch.Tensor | None = None,
        max_shift: float = 2.05,
        base_shift: float = 0.95,
        stretch: bool = True,
        terminal: float = 0.1,
        **_kwargs,
    ) -> torch.FloatTensor:
        """
        Build a sigma schedule of `steps + 1` values running from 1.0 down to 0.0.

        Args:
            steps: Number of diffusion steps; the schedule has `steps + 1` entries.
            latent: Optional latent whose trailing dimensions determine the token
                count controlling the shift. When None, MAX_SHIFT_ANCHOR tokens
                are assumed.
            max_shift: Shift value at MAX_SHIFT_ANCHOR tokens.
            base_shift: Shift value at BASE_SHIFT_ANCHOR tokens.
            stretch: Whether to rescale the non-zero sigmas so the final one
                equals `terminal`.
            terminal: Target value of the last non-zero sigma when stretching.

        Returns:
            Float32 tensor of sigmas ending at 0.0.
        """
        # Token count drives the shift, interpolated linearly between the anchors.
        tokens = math.prod(latent.shape[2:]) if latent is not None else MAX_SHIFT_ANCHOR
        slope = (max_shift - base_shift) / (MAX_SHIFT_ANCHOR - BASE_SHIFT_ANCHOR)
        intercept = base_shift - slope * BASE_SHIFT_ANCHOR
        sigma_shift = tokens * slope + intercept

        sigmas = torch.linspace(1.0, 0.0, steps + 1)

        # Time-shift each sigma; exp(sigma_shift) is hoisted so it is computed
        # once instead of twice per call (the original also raised the
        # denominator term to a constant power of 1, which is a no-op).
        exp_shift = math.exp(sigma_shift)
        sigmas = torch.where(
            sigmas != 0,
            exp_shift / (exp_shift + (1 / sigmas - 1)),
            0,
        )

        # Stretch sigmas so that the final non-zero value matches `terminal`.
        if stretch:
            non_zero_mask = sigmas != 0
            one_minus_z = 1.0 - sigmas[non_zero_mask]
            scale_factor = one_minus_z[-1] / (1.0 - terminal)
            sigmas[non_zero_mask] = 1.0 - (one_minus_z / scale_factor)

        return sigmas.to(torch.float32)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class LinearQuadraticScheduler(SchedulerProtocol):
    """
    Scheduler with linear steps followed by quadratic steps.

    Produces a sigma schedule that transitions linearly up to a threshold,
    then follows a quadratic curve for the remaining steps.
    """

    def execute(
        self, steps: int, threshold_noise: float = 0.025, linear_steps: int | None = None, **_kwargs
    ) -> torch.FloatTensor:
        # A single step degenerates to the trivial full-noise -> clean schedule.
        if steps == 1:
            return torch.FloatTensor([1.0, 0.0])

        n_linear = steps // 2 if linear_steps is None else linear_steps
        n_quadratic = steps - n_linear

        # Evenly spaced noise levels up to the threshold.
        noise_levels = [i * threshold_noise / n_linear for i in range(n_linear)]

        # Quadratic tail chosen to stay continuous at the threshold and to
        # reach 1.0 at the final step.
        step_diff = n_linear - threshold_noise * steps
        if n_quadratic > 0:
            quad = step_diff / (n_linear * n_quadratic**2)
            lin = threshold_noise / n_linear - 2 * step_diff / (n_quadratic**2)
            const = quad * (n_linear**2)
            noise_levels += [quad * (i**2) + lin * i + const for i in range(n_linear, steps)]

        noise_levels.append(1.0)
        # Sigmas are the complement of the noise levels.
        return torch.FloatTensor([1.0 - level for level in noise_levels])
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class BetaScheduler(SchedulerProtocol):
    """
    Scheduler using a beta distribution to sample timesteps.

    Based on: https://arxiv.org/abs/2407.12173
    """

    # Flux-style time shift used when precomputing the sigma lookup table.
    shift = 2.37
    # Resolution of the precomputed sigma lookup table.
    timesteps_length = 10000

    def execute(self, steps: int, alpha: float = 0.6, beta: float = 0.6, **_kwargs) -> torch.FloatTensor:
        """
        Execute the beta scheduler.

        Args:
            steps: The number of steps to execute the scheduler for.
            alpha: The alpha parameter for the beta distribution.
            beta: The beta parameter for the beta distribution.

        Warnings:
            The number of steps within `sigmas` theoretically might be less than `steps+1`,
            because of the deduplication of the identical timesteps

        Returns:
            A tensor of sigmas.
        """
        # `**_kwargs` added so the signature satisfies SchedulerProtocol.execute
        # (and matches the sibling schedulers): generic callers may pass extra
        # keyword arguments such as `latent=`.
        model_sampling_sigmas = _precalculate_model_sampling_sigmas(self.shift, self.timesteps_length)
        total_timesteps = len(model_sampling_sigmas) - 1

        # Map uniform points through the beta inverse CDF to pick timesteps.
        ts = 1 - numpy.linspace(0, 1, steps, endpoint=False)
        ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps).tolist()
        # Deduplicate while preserving order; may shorten the schedule.
        ts = list(dict.fromkeys(ts))

        sigmas = [float(model_sampling_sigmas[int(t)]) for t in ts] + [0.0]
        return torch.FloatTensor(sigmas)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@lru_cache(maxsize=5)
def _precalculate_model_sampling_sigmas(shift: float, timesteps_length: int) -> torch.Tensor:
    """
    Precompute the flux-time-shifted sigma table for `timesteps_length` steps.

    NOTE(review): the result is cached by lru_cache, so callers share one
    tensor object — it must not be mutated in place.
    """
    grid = torch.arange(1, timesteps_length + 1, 1) / timesteps_length
    return torch.Tensor([flux_time_shift(shift, 1.0, t) for t in grid])
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def flux_time_shift(mu: float, sigma: float, t: float) -> float:
    """Apply the flux time-shift transform: exp(mu) / (exp(mu) + (1/t - 1) ** sigma)."""
    e_mu = math.exp(mu)
    return e_mu / (e_mu + (1.0 / t - 1.0) ** sigma)
|