Fabrice-TIERCELIN committed on
Commit
67a6ee7
·
verified ·
1 Parent(s): 45c8dfa

Upload 6 files

Browse files
packages/ltx-core/src/ltx_core/components/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Diffusion pipeline components.
3
+ Submodules:
4
+ diffusion_steps - Diffusion stepping algorithms (EulerDiffusionStep)
5
+ guiders - Guidance strategies (CFGGuider, STGGuider, APG variants)
6
+ noisers - Noise samplers (GaussianNoiser)
7
+ patchifiers - Latent patchification (VideoLatentPatchifier, AudioPatchifier)
8
+ protocols - Protocol definitions (Patchifier, etc.)
9
+ schedulers - Sigma schedulers (LTX2Scheduler, LinearQuadraticScheduler)
10
+ """
packages/ltx-core/src/ltx_core/components/guiders.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+
5
+ from ltx_core.components.protocols import GuiderProtocol
6
+
7
+
8
@dataclass(frozen=True)
class CFGGuider(GuiderProtocol):
    """
    Classifier-free guidance (CFG) guider.

    Computes the guidance delta as (scale - 1) * (cond - uncond), steering the
    denoising process toward the conditioned prediction.

    Attributes:
        scale: Guidance strength. 1.0 means no guidance, higher values increase
            adherence to the conditioning.
    """

    scale: float

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        # Standard CFG: push the prediction along the cond-minus-uncond direction.
        direction = cond - uncond
        return direction * (self.scale - 1)

    def enabled(self) -> bool:
        # A scale of exactly 1.0 yields a zero delta, so the guider can be skipped.
        return self.scale != 1.0
26
+
27
+
28
@dataclass(frozen=True)
class CFGStarRescalingGuider(GuiderProtocol):
    """
    CFG guider with norm-rescaled negatives (CFG*).

    Before taking the (cond - uncond) difference, the unconditioned sample is
    rescaled by its projection coefficient onto the conditioned sample, so the
    guidance moves mostly along the conditioning axis within the distribution
    instead of offsetting the denoising direction.

    Attributes:
        scale (float):
            Global guidance strength. 1.0 corresponds to no extra guidance
            beyond the base model prediction; values > 1.0 increase the
            influence of the conditioned sample relative to the
            unconditioned one.
    """

    scale: float

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        # Rescale the negative along the conditioning direction, then apply
        # the usual CFG weighting to the residual.
        rescaled_uncond = uncond * projection_coef(cond, uncond)
        return (cond - rescaled_uncond) * (self.scale - 1)

    def enabled(self) -> bool:
        return self.scale != 1.0
51
+
52
+
53
@dataclass(frozen=True)
class STGGuider(GuiderProtocol):
    """
    Spatio-temporal guidance (STG) guider.

    Computes the delta between a conditioned denoised sample and a perturbed
    denoised sample. Perturbed samples result from running the denoising with
    perturbations, e.g. attentions acting as passthrough for certain layers
    and modalities.

    Attributes:
        scale (float):
            Global strength of the STG guidance. 0.0 disables the guidance;
            larger values increase the correction applied in the direction
            of (pos_denoised - perturbed_denoised).
    """

    scale: float

    def delta(self, pos_denoised: torch.Tensor, perturbed_denoised: torch.Tensor) -> torch.Tensor:
        correction = pos_denoised - perturbed_denoised
        return correction * self.scale

    def enabled(self) -> bool:
        # Zero scale contributes nothing; skip the extra forward pass.
        return self.scale != 0.0
73
+
74
+
75
@dataclass(frozen=True)
class LtxAPGGuider(GuiderProtocol):
    """
    Adaptive projected guidance (APG) guider.

    The (cond - uncond) guidance is decomposed into components parallel and
    orthogonal to the conditioned sample, so the update moves mostly along
    the conditioning axis within the distribution. `eta` weights the
    parallel component; the combined guidance is scaled by (scale - 1).
    An optional norm threshold caps the guidance magnitude.

    Attributes:
        scale (float):
            Overall guidance strength; 1.0 disables the guidance.
        eta (float):
            Weight of the component parallel to the conditioned sample.
            1.0 keeps the full parallel component; values in [0, 1]
            attenuate it, and values > 1.0 amplify motion along the
            conditioning direction.
        norm_threshold (float):
            Cap on the L2 norm of the guidance delta (computed over the
            last three dims); deltas with larger norm are scaled down to
            this threshold. 0 disables the cap.
    """

    scale: float
    eta: float = 1.0
    norm_threshold: float = 0.0

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        guidance = cond - uncond

        # Cap the guidance norm so oversized deltas cannot destabilize the step.
        if self.norm_threshold > 0:
            norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
            clip = torch.minimum(torch.ones_like(guidance), self.norm_threshold / norm)
            guidance = guidance * clip

        # Decompose w.r.t. the conditioned sample and reweight the parallel part.
        parallel = projection_coef(guidance, cond) * cond
        orthogonal = guidance - parallel
        combined = parallel * self.eta + orthogonal

        return combined * (self.scale - 1)

    def enabled(self) -> bool:
        return self.scale != 1.0
124
+
125
+
126
@dataclass(frozen=False)
class LegacyStatefulAPGGuider(GuiderProtocol):
    """
    Stateful APG (adaptive projected guidance) guider with momentum.

    The (cond - uncond) guidance is decomposed into components parallel and
    orthogonal to the conditioned sample: `eta` weights the parallel
    component and `scale` multiplies the combined guidance. The guidance can
    be accumulated across calls with an exponential moving average
    (`momentum`), and its L2 norm is capped at `norm_threshold`.

    Unlike `LtxAPGGuider`, the final delta is multiplied by `scale` (not
    `scale - 1`), `enabled()` compares against 0.0, and the dataclass is
    mutable because it stores the running average between calls.

    Attributes:
        scale (float):
            Overall strength multiplied into the final APG delta; 0.0
            disables the guidance.
        eta (float):
            Weight of the component parallel to the conditioned sample.
            1.0 keeps the full parallel component; values in [0, 1]
            attenuate it, values > 1.0 amplify it.
        norm_threshold (float):
            Cap on the L2 norm of the guidance (over the last three dims);
            larger deltas are scaled down to this threshold. 0 disables the
            cap.
        momentum (float):
            Exponential moving-average coefficient for accumulating guidance
            over time: running_avg = momentum * running_avg + guidance.
        running_avg (torch.Tensor | None):
            Accumulated guidance state; lazily initialized on the first call.
    """

    scale: float
    eta: float
    norm_threshold: float = 5.0
    momentum: float = 0.0
    # It is the user's responsibility not to use the same APGGuider for several
    # denoisings or different modalities, so the accumulated average is not
    # shared across different denoisings or modalities.
    running_avg: torch.Tensor | None = None

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        guidance = cond - uncond
        # Optionally accumulate the guidance across calls (EMA-style).
        if self.momentum != 0:
            if self.running_avg is None:
                self.running_avg = guidance.clone()
            else:
                self.running_avg = self.momentum * self.running_avg + guidance
            guidance = self.running_avg

        # Cap the guidance norm: deltas larger than the threshold are scaled down.
        if self.norm_threshold > 0:
            ones = torch.ones_like(guidance)
            guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
            scale_factor = torch.minimum(ones, self.norm_threshold / guidance_norm)
            guidance = guidance * scale_factor

        # Decompose into components parallel / orthogonal to the conditioned sample.
        proj_coeff = projection_coef(guidance, cond)
        g_parallel = proj_coeff * cond
        g_orth = guidance - g_parallel
        g_apg = g_parallel * self.eta + g_orth

        return g_apg * self.scale

    def enabled(self) -> bool:
        return self.scale != 0.0
190
+
191
+
192
def projection_coef(to_project: torch.Tensor, project_onto: torch.Tensor) -> torch.Tensor:
    """
    Per-sample coefficient of the projection of `to_project` onto `project_onto`.

    Both tensors are flattened per batch element. The result has shape
    (batch, 1) and equals <a, b> / (||b||^2 + 1e-8), where a and b are the
    flattened inputs; the epsilon guards against division by zero.
    """
    batch = to_project.shape[0]
    a = to_project.reshape(batch, -1)
    b = project_onto.reshape(batch, -1)
    numerator = (a * b).sum(dim=1, keepdim=True)
    denominator = (b * b).sum(dim=1, keepdim=True) + 1e-8
    return numerator / denominator
packages/ltx-core/src/ltx_core/components/noisers.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import replace
2
+ from typing import Protocol
3
+
4
+ import torch
5
+
6
+ from ltx_core.types import LatentState
7
+
8
+
9
class Noiser(Protocol):
    """Protocol for adding noise to a latent state during diffusion.

    Implementations return a new LatentState whose latent has noise blended
    in according to `noise_scale`.
    """

    def __call__(self, latent_state: LatentState, noise_scale: float) -> LatentState: ...
13
+
14
+
15
class GaussianNoiser(Noiser):
    """Adds Gaussian noise to a latent state, scaled by the denoise mask."""

    def __init__(self, generator: torch.Generator):
        super().__init__()
        # Dedicated RNG so noise draws are reproducible per noiser instance.
        self.generator = generator

    def __call__(self, latent_state: LatentState, noise_scale: float = 1.0) -> LatentState:
        source = latent_state.latent
        noise = torch.randn(
            *source.shape,
            device=source.device,
            dtype=source.dtype,
            generator=self.generator,
        )
        # Per-element blend: weight 1 -> pure noise, weight 0 -> untouched latent.
        blend = latent_state.denoise_mask * noise_scale
        mixed = noise * blend + source * (1 - blend)
        return replace(latent_state, latent=mixed.to(source.dtype))
packages/ltx-core/src/ltx_core/components/patchifiers.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional, Tuple
3
+
4
+ import einops
5
+ import torch
6
+
7
+ from ltx_core.components.protocols import Patchifier
8
+ from ltx_core.types import AudioLatentShape, SpatioTemporalScaleFactors, VideoLatentShape
9
+
10
+
11
class VideoLatentPatchifier(Patchifier):
    """Patchifier for video latents.

    Folds non-overlapping spatial patches into the feature axis, producing a
    flat token sequence. The temporal patch size is fixed to 1.
    """

    def __init__(self, patch_size: int):
        # Patch sizes for video latents.
        self._patch_size = (
            1,  # temporal dimension
            patch_size,  # height dimension
            patch_size,  # width dimension
        )

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        # (temporal, height, width) patch dimensions.
        return self._patch_size

    def get_token_count(self, tgt_shape: VideoLatentShape) -> int:
        # Tokens = spatio-temporal volume divided by the patch volume.
        # NOTE(review): assumes to_torch_shape() is (batch, channels, f, h, w) — confirm.
        return math.prod(tgt_shape.to_torch_shape()[2:]) // math.prod(self._patch_size)

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        """Flatten a (b, c, f, h, w) latent into (b, tokens, features) patch tokens."""
        latents = einops.rearrange(
            latents,
            "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
            p1=self._patch_size[0],
            p2=self._patch_size[1],
            p3=self._patch_size[2],
        )

        return latents

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: VideoLatentShape,
    ) -> torch.Tensor:
        """Inverse of `patchify`: rebuild the (b, c, f, h, w) latent grid from tokens."""
        assert self._patch_size[0] == 1, "Temporal patch size must be 1 for symmetric patchifier"

        # Number of patches along each axis of the latent grid.
        patch_grid_frames = output_shape.frames // self._patch_size[0]
        patch_grid_height = output_shape.height // self._patch_size[1]
        patch_grid_width = output_shape.width // self._patch_size[2]

        latents = einops.rearrange(
            latents,
            "b (f h w) (c p q) -> b c f (h p) (w q)",
            f=patch_grid_frames,
            h=patch_grid_height,
            w=patch_grid_width,
            p=self._patch_size[1],
            q=self._patch_size[2],
        )

        return latents

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Return the per-dimension bounds [inclusive start, exclusive end) for every
        patch produced by `patchify`. The bounds are expressed in the original
        video grid coordinates: frame/time, height, and width.

        The resulting tensor is shaped `[batch_size, 3, num_patches, 2]`, where:
        - axis 1 (size 3) enumerates (frame/time, height, width) dimensions
        - axis 3 (size 2) stores `[start, end)` indices within each dimension

        Args:
            output_shape: Video grid description containing frames, height, and width.
            device: Device of the latent tensor.

        Raises:
            ValueError: If `output_shape` is not a VideoLatentShape.
        """
        if not isinstance(output_shape, VideoLatentShape):
            raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates")

        frames = output_shape.frames
        height = output_shape.height
        width = output_shape.width
        batch_size = output_shape.batch

        # Validate inputs to ensure positive dimensions
        assert frames > 0, f"frames must be positive, got {frames}"
        assert height > 0, f"height must be positive, got {height}"
        assert width > 0, f"width must be positive, got {width}"
        assert batch_size > 0, f"batch_size must be positive, got {batch_size}"

        # Generate grid coordinates for each dimension (frame, height, width)
        # We use torch.arange to create the starting coordinates for each patch.
        # indexing='ij' ensures the dimensions are in the order (frame, height, width).
        grid_coords = torch.meshgrid(
            torch.arange(start=0, end=frames, step=self._patch_size[0], device=device),
            torch.arange(start=0, end=height, step=self._patch_size[1], device=device),
            torch.arange(start=0, end=width, step=self._patch_size[2], device=device),
            indexing="ij",
        )

        # Stack the grid coordinates to create the start coordinates tensor.
        # Shape becomes (3, grid_f, grid_h, grid_w)
        patch_starts = torch.stack(grid_coords, dim=0)

        # Create a tensor containing the size of a single patch:
        # (frame_patch_size, height_patch_size, width_patch_size).
        # Reshape to (3, 1, 1, 1) to enable broadcasting when adding to the start coordinates.
        patch_size_delta = torch.tensor(
            self._patch_size,
            device=patch_starts.device,
            dtype=patch_starts.dtype,
        ).view(3, 1, 1, 1)

        # Calculate end coordinates: start + patch_size
        # Shape becomes (3, grid_f, grid_h, grid_w)
        patch_ends = patch_starts + patch_size_delta

        # Stack start and end coordinates together along the last dimension
        # Shape becomes (3, grid_f, grid_h, grid_w, 2), where the last dimension is [start, end]
        latent_coords = torch.stack((patch_starts, patch_ends), dim=-1)

        # Broadcast to batch size and flatten all spatial/temporal dimensions into one sequence.
        # Final Shape: (batch_size, 3, num_patches, 2)
        latent_coords = einops.repeat(
            latent_coords,
            "c f h w bounds -> b c (f h w) bounds",
            b=batch_size,
            bounds=2,
        )

        return latent_coords
135
+
136
+
137
def get_pixel_coords(
    latent_coords: torch.Tensor,
    scale_factors: SpatioTemporalScaleFactors,
    causal_fix: bool = False,
) -> torch.Tensor:
    """
    Convert latent-space `[start, end)` bounds to pixel-space coordinates.

    Each axis (frame/time, height, width) is multiplied by its corresponding
    VAE downsampling factor. When `causal_fix` is set, the temporal axis is
    shifted by (1 - temporal_scale) and clamped at zero to account for causal
    VAEs whose first frame uses a temporal stride of 1.

    Args:
        latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`.
        scale_factors: SpatioTemporalScaleFactors tuple `(temporal, height, width)`
            with integer scale factors applied per axis.
        causal_fix: When True, rewrites the temporal axis of the first frame so
            causal VAEs that treat frame zero differently still yield
            non-negative timestamps.
    """
    # Shape the scale factors as (1, -1, 1, ..., 1) so they broadcast over the
    # (batch, axis, patch, bound) layout of `latent_coords`.
    view_shape = [1] * latent_coords.ndim
    view_shape[1] = -1  # axis dimension corresponds to (frame/time, height, width)
    factors = torch.tensor(scale_factors, device=latent_coords.device).view(*view_shape)

    # Per-axis scaling converts latent bounds into pixel-space coordinates.
    pixel_coords = latent_coords * factors

    if causal_fix:
        # The VAE temporal stride for the very first frame is 1 instead of
        # scale_factors[0]; shift and clamp to keep timestamps non-negative.
        temporal_axis = pixel_coords[:, 0, ...]
        pixel_coords[:, 0, ...] = (temporal_axis + 1 - scale_factors[0]).clamp(min=0)

    return pixel_coords
167
+
168
+
169
class AudioPatchifier(Patchifier):
    """Patchifier for spectrogram/audio latents.

    Flattens latent time steps into tokens and maps latent indices to
    real-time second timestamps via the spectrogram hop length and the VAE's
    temporal downsampling factor.
    """

    def __init__(
        self,
        patch_size: int,
        sample_rate: int = 16000,
        hop_length: int = 160,
        audio_latent_downsample_factor: int = 4,
        is_causal: bool = True,
        shift: int = 0,
    ):
        """
        Patchifier tailored for spectrogram/audio latents.

        Args:
            patch_size: Number of mel bins combined into a single patch. This
                controls the resolution along the frequency axis.
            sample_rate: Original waveform sampling rate. Used to map latent
                indices back to seconds so downstream consumers can align audio
                and video cues.
            hop_length: Window hop length used for the spectrogram. Determines
                how many real-time samples separate two consecutive latent frames.
            audio_latent_downsample_factor: Ratio between spectrogram frames and
                latent frames; compensates for additional downsampling inside the
                VAE encoder.
            is_causal: When True, timing is shifted to account for causal
                receptive fields so timestamps do not peek into the future.
            shift: Integer offset applied to the latent indices. Enables
                constructing overlapping windows from the same latent sequence.
        """
        self.hop_length = hop_length
        self.sample_rate = sample_rate
        self.audio_latent_downsample_factor = audio_latent_downsample_factor
        self.is_causal = is_causal
        self.shift = shift
        # Temporal entry is 1; the two spatial entries mirror the video layout.
        self._patch_size = (1, patch_size, patch_size)

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        # (temporal, height, width) patch dimensions.
        return self._patch_size

    def get_token_count(self, tgt_shape: AudioLatentShape) -> int:
        # One token per latent time step; frequency/channels fold into features.
        return tgt_shape.frames

    def _get_audio_latent_time_in_sec(
        self,
        start_latent: int,
        end_latent: int,
        dtype: torch.dtype,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Converts latent indices into real-time seconds while honoring causal
        offsets and the configured hop length.

        Args:
            start_latent: Inclusive start index inside the latent sequence. This
                sets the first timestamp returned.
            end_latent: Exclusive end index. Determines how many timestamps get
                generated.
            dtype: Floating-point dtype used for the returned tensor, allowing
                callers to control precision.
            device: Target device for the timestamp tensor. When omitted the
                computation occurs on CPU to avoid surprising GPU allocations.
        """
        if device is None:
            device = torch.device("cpu")

        audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device)

        # Latent frames -> mel-spectrogram frames.
        audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor

        if self.is_causal:
            # Frame offset for causal alignment.
            # The "+1" ensures the timestamp corresponds to the first sample that is fully available.
            causal_offset = 1
            audio_mel_frame = (audio_mel_frame + causal_offset - self.audio_latent_downsample_factor).clip(min=0)

        # Mel frames -> seconds via the hop length and sampling rate.
        return audio_mel_frame * self.hop_length / self.sample_rate

    def _compute_audio_timings(
        self,
        batch_size: int,
        num_steps: int,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Builds a `(B, 1, T, 2)` tensor containing timestamps for each latent frame.

        This helper method underpins `get_patch_grid_bounds` for the audio patchifier.

        Args:
            batch_size: Number of sequences to broadcast the timings over.
            num_steps: Number of latent frames (time steps) to convert into timestamps.
            device: Device on which the resulting tensor should reside.
        """
        resolved_device = device
        if resolved_device is None:
            resolved_device = torch.device("cpu")

        # Start timestamps for latent indices [shift, shift + num_steps).
        start_timings = self._get_audio_latent_time_in_sec(
            self.shift,
            num_steps + self.shift,
            torch.float32,
            resolved_device,
        )
        start_timings = start_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)

        # End timestamps are the start timestamps of the next latent index.
        end_timings = self._get_audio_latent_time_in_sec(
            self.shift + 1,
            num_steps + self.shift + 1,
            torch.float32,
            resolved_device,
        )
        end_timings = end_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)

        # Pair [start, end) per step along the last axis.
        return torch.stack([start_timings, end_timings], dim=-1)

    def patchify(
        self,
        audio_latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Flattens the audio latent tensor along time. Use `get_patch_grid_bounds`
        to derive timestamps for each latent frame based on the configured hop
        length and downsampling.

        Args:
            audio_latents: Latent tensor to patchify.

        Returns:
            Flattened patch tokens tensor. Use `get_patch_grid_bounds` to compute the
            corresponding timing metadata when needed.
        """
        # (batch, channels, time, freq) -> (batch, time, channels * freq)
        audio_latents = einops.rearrange(
            audio_latents,
            "b c t f -> b t (c f)",
        )

        return audio_latents

    def unpatchify(
        self,
        audio_latents: torch.Tensor,
        output_shape: AudioLatentShape,
    ) -> torch.Tensor:
        """
        Restores the `(B, C, T, F)` spectrogram tensor from flattened patches.
        Use `get_patch_grid_bounds` to recompute the timestamps that describe each
        frame's position in real time.

        Args:
            audio_latents: Latent tensor to unpatchify.
            output_shape: Shape of the unpatched output tensor.

        Returns:
            Unpatched latent tensor. Use `get_patch_grid_bounds` to compute the timing
            metadata associated with the restored latents.
        """
        # audio_latents shape: (batch, time, freq * channels)
        audio_latents = einops.rearrange(
            audio_latents,
            "b t (c f) -> b c t f",
            c=output_shape.channels,
            f=output_shape.mel_bins,
        )

        return audio_latents

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Return the temporal bounds `[inclusive start, exclusive end)` for every
        patch emitted by `patchify`. For audio this corresponds to timestamps in
        seconds aligned with the original spectrogram grid.

        The returned tensor has shape `[batch_size, 1, time_steps, 2]`, where:
        - axis 1 (size 1) represents the temporal dimension
        - axis 3 (size 2) stores the `[start, end)` timestamps per patch

        Args:
            output_shape: Audio grid specification describing the number of time steps.
            device: Target device for the returned tensor.

        Raises:
            ValueError: If `output_shape` is not an AudioLatentShape.
        """
        if not isinstance(output_shape, AudioLatentShape):
            raise ValueError("AudioPatchifier expects AudioLatentShape when computing coordinates")

        return self._compute_audio_timings(output_shape.batch, output_shape.frames, device)
packages/ltx-core/src/ltx_core/components/protocols.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Protocol, Tuple
2
+
3
+ import torch
4
+
5
+ from ltx_core.types import AudioLatentShape, VideoLatentShape
6
+
7
+
8
class Patchifier(Protocol):
    """
    Protocol for patchifiers that convert latent tensors into patches and assemble them back.
    """

    # NOTE: docstrings must be the FIRST statement of each method body; in the
    # previous revision several were placed after `...`, which made them dead
    # string expressions invisible to help() and IDEs.

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Convert latent tensors into flattened patch tokens.

        Args:
            latents: Latent tensor to patchify.

        Returns:
            Flattened patch tokens tensor.
        """
        ...

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: AudioLatentShape | VideoLatentShape,
    ) -> torch.Tensor:
        """
        Converts latent tensors between spatio-temporal formats and flattened sequence representations.

        Args:
            latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`.
            output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or
                VideoLatentShape.

        Returns:
            Dense latent tensor restored from the flattened representation.
        """
        ...

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        """
        Returns the patch size as a tuple of (temporal, height, width) dimensions.
        """
        ...

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: torch.device | None = None,
    ) -> torch.Tensor:
        """
        Compute metadata describing where each latent patch resides within the
        grid specified by `output_shape`.

        Args:
            output_shape: Target grid layout for the patches.
            device: Target device for the returned tensor.

        Returns:
            Tensor containing patch coordinate metadata such as spatial or temporal intervals.
        """
        ...
63
+
64
+
65
class SchedulerProtocol(Protocol):
    """
    Protocol for schedulers that provide a sigmas schedule tensor for a
    given number of steps. Device is cpu.
    """

    # Returns a 1-D tensor of descending sigmas; implementations accept
    # scheduler-specific keyword arguments via **kwargs.
    def execute(self, steps: int, **kwargs) -> torch.FloatTensor: ...
72
+
73
+
74
class GuiderProtocol(Protocol):
    """
    Protocol for guiders that compute a delta tensor given conditioning inputs.

    The returned delta should be added to the conditional output (cond), enabling
    multiple guiders to be chained together by accumulating their deltas.
    """

    # Guidance strength; the "disabled" value depends on the implementation
    # (e.g. 1.0 for CFG-style guiders, 0.0 for STG-style guiders).
    scale: float

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor: ...

    def enabled(self) -> bool:
        """
        Returns whether the corresponding perturbation is enabled. E.g. for CFG, this should return False if the scale
        is 1.0.
        """
        ...
91
+
92
+
93
class DiffusionStepProtocol(Protocol):
    """
    Protocol for diffusion steps that provide a next sample tensor for a given current sample tensor,
    current denoised sample tensor, and sigmas tensor.
    """

    # `step_index` selects the current position within `sigmas`; the return
    # value is the sample advanced to the next noise level.
    def step(
        self, sample: torch.Tensor, denoised_sample: torch.Tensor, sigmas: torch.Tensor, step_index: int
    ) -> torch.Tensor: ...
packages/ltx-core/src/ltx_core/components/schedulers.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from functools import lru_cache
3
+
4
+ import numpy
5
+ import scipy
6
+ import torch
7
+
8
+ from ltx_core.components.protocols import SchedulerProtocol
9
+
10
+ BASE_SHIFT_ANCHOR = 1024
11
+ MAX_SHIFT_ANCHOR = 4096
12
+
13
+
14
class LTX2Scheduler(SchedulerProtocol):
    """
    Default scheduler for LTX-2 diffusion sampling.

    Generates a sigma schedule with token-count-dependent shifting and optional
    stretching so the last non-zero sigma lands on a terminal value.
    """

    def execute(
        self,
        steps: int,
        latent: torch.Tensor | None = None,
        max_shift: float = 2.05,
        base_shift: float = 0.95,
        stretch: bool = True,
        terminal: float = 0.1,
        **_kwargs,
    ) -> torch.FloatTensor:
        # Token count drives the shift; without a latent, use the max anchor.
        token_count = MAX_SHIFT_ANCHOR if latent is None else math.prod(latent.shape[2:])

        # Linearly interpolate the shift between the two token-count anchors.
        slope = (max_shift - base_shift) / (MAX_SHIFT_ANCHOR - BASE_SHIFT_ANCHOR)
        intercept = base_shift - slope * BASE_SHIFT_ANCHOR
        sigma_shift = token_count * slope + intercept

        sigmas = torch.linspace(1.0, 0.0, steps + 1)

        # Apply the sigmoid-like time shift everywhere except at the terminal zero,
        # which is kept exactly zero by the mask.
        power = 1
        shifted = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power)
        sigmas = torch.where(sigmas != 0, shifted, 0)

        # Stretch sigmas so the final non-zero value matches the given terminal value.
        if stretch:
            mask = sigmas != 0
            one_minus = 1.0 - sigmas[mask]
            factor = one_minus[-1] / (1.0 - terminal)
            sigmas[mask] = 1.0 - (one_minus / factor)

        return sigmas.to(torch.float32)
57
+
58
+
59
class LinearQuadraticScheduler(SchedulerProtocol):
    """
    Scheduler with linear steps followed by quadratic steps.

    Noise levels rise linearly toward `threshold_noise` over the first
    `linear_steps` steps, then follow a quadratic curve that reaches 1.0 at
    `steps`; the result is returned as descending sigmas.
    """

    def execute(
        self, steps: int, threshold_noise: float = 0.025, linear_steps: int | None = None, **_kwargs
    ) -> torch.FloatTensor:
        # Degenerate single-step schedule: full noise straight to clean.
        if steps == 1:
            return torch.FloatTensor([1.0, 0.0])

        if linear_steps is None:
            linear_steps = steps // 2

        # Linear ramp from 0 toward threshold_noise.
        schedule = [step * threshold_noise / linear_steps for step in range(linear_steps)]

        remaining = steps - linear_steps
        if remaining > 0:
            # Quadratic segment whose value and slope match the linear ramp at
            # the boundary and which evaluates to exactly 1.0 at `steps`.
            diff = linear_steps - threshold_noise * steps
            quad_coef = diff / (linear_steps * remaining**2)
            lin_coef = threshold_noise / linear_steps - 2 * diff / (remaining**2)
            const_coef = quad_coef * (linear_steps**2)
            schedule.extend(
                quad_coef * (step**2) + lin_coef * step + const_coef for step in range(linear_steps, steps)
            )

        schedule.append(1.0)
        # Convert ascending noise levels into descending sigmas.
        return torch.FloatTensor([1.0 - level for level in schedule])
88
+
89
+
90
class BetaScheduler(SchedulerProtocol):
    """
    Scheduler using a beta distribution to sample timesteps.

    Based on: https://arxiv.org/abs/2407.12173
    """

    # Time-shift used when precomputing the sigma lookup table.
    shift = 2.37
    # Resolution of the precomputed sigma table.
    timesteps_length = 10000

    def execute(self, steps: int, alpha: float = 0.6, beta: float = 0.6, **_kwargs) -> torch.FloatTensor:
        """
        Execute the beta scheduler.

        Args:
            steps: The number of steps to execute the scheduler for.
            alpha: The alpha parameter for the beta distribution.
            beta: The beta parameter for the beta distribution.

        Warnings:
            The number of steps within `sigmas` theoretically might be less than `steps+1`,
            because of the deduplication of the identical timesteps.

        Returns:
            A tensor of sigmas ending in 0.0.
        """
        # `**_kwargs` added so the signature matches SchedulerProtocol.execute
        # and the sibling schedulers, which all tolerate extra keyword args.
        # NOTE(review): relies on `scipy.stats` being reachable through the bare
        # `import scipy` at module top (lazy submodule loading) — confirm, or
        # import scipy.stats explicitly.
        model_sampling_sigmas = _precalculate_model_sampling_sigmas(self.shift, self.timesteps_length)
        total_timesteps = len(model_sampling_sigmas) - 1

        # Evenly spaced quantiles in (0, 1], mapped through the beta PPF onto
        # table indices, then deduplicated while preserving order.
        ts = 1 - numpy.linspace(0, 1, steps, endpoint=False)
        ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps).tolist()
        ts = list(dict.fromkeys(ts))

        sigmas = [float(model_sampling_sigmas[int(t)]) for t in ts] + [0.0]
        return torch.FloatTensor(sigmas)
120
+
121
+
122
@lru_cache(maxsize=5)
def _precalculate_model_sampling_sigmas(shift: float, timesteps_length: int) -> torch.Tensor:
    """Precompute (and cache) the shifted sigma table over normalized timesteps."""
    # Normalized timesteps t in (0, 1], each passed through the flux time shift.
    normalized = torch.arange(1, timesteps_length + 1, 1) / timesteps_length
    values = [flux_time_shift(shift, 1.0, t) for t in normalized]
    return torch.Tensor(values)
126
+
127
+
128
def flux_time_shift(mu: float, sigma: float, t: float) -> float:
    """Sigmoid-like time shift used by flux-style samplers.

    Returns exp(mu) / (exp(mu) + (1/t - 1) ** sigma); with mu == 0 and
    sigma == 1 this is the identity on t.
    """
    shifted = math.exp(mu)
    return shifted / (shifted + (1.0 / t - 1.0) ** sigma)