Fabrice-TIERCELIN committed on
Commit
4f5fe16
·
verified ·
1 Parent(s): 2d24b22

Upload 6 files

Browse files
packages/ltx-pipelines/src/ltx_pipelines/utils/args.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+
4
+ from ltx_core.loader import LTXV_LORA_COMFY_RENAMING_MAP, LoraPathStrengthAndSDOps
5
+ from ltx_pipelines.utils.constants import (
6
+ DEFAULT_1_STAGE_HEIGHT,
7
+ DEFAULT_1_STAGE_WIDTH,
8
+ DEFAULT_2_STAGE_HEIGHT,
9
+ DEFAULT_2_STAGE_WIDTH,
10
+ DEFAULT_CFG_GUIDANCE_SCALE,
11
+ DEFAULT_FRAME_RATE,
12
+ DEFAULT_LORA_STRENGTH,
13
+ DEFAULT_NEGATIVE_PROMPT,
14
+ DEFAULT_NUM_FRAMES,
15
+ DEFAULT_NUM_INFERENCE_STEPS,
16
+ DEFAULT_SEED,
17
+ )
18
+
19
+
20
class VideoConditioningAction(argparse.Action):
    """Accumulate ``(resolved_path, strength)`` pairs for video conditioning options."""

    def __call__(
        self,
        parser: argparse.ArgumentParser,  # noqa: ARG002
        namespace: argparse.Namespace,
        values: list[str],
        option_string: str | None = None,  # noqa: ARG002
    ) -> None:
        raw_path, raw_strength = values
        entry = (resolve_path(raw_path), float(raw_strength))
        # Repeated uses of the option append to the same destination list.
        collected = getattr(namespace, self.dest) or []
        collected.append(entry)
        setattr(namespace, self.dest, collected)
34
+
35
+
36
class ImageAction(argparse.Action):
    """Accumulate ``(resolved_path, frame_idx, strength)`` triples for image conditioning."""

    def __call__(
        self,
        parser: argparse.ArgumentParser,  # noqa: ARG002
        namespace: argparse.Namespace,
        values: list[str],
        option_string: str | None = None,  # noqa: ARG002
    ) -> None:
        raw_path, raw_frame, raw_strength = values
        entry = (resolve_path(raw_path), int(raw_frame), float(raw_strength))
        # Repeated uses of the option append to the same destination list.
        collected = getattr(namespace, self.dest) or []
        collected.append(entry)
        setattr(namespace, self.dest, collected)
51
+
52
+
53
class LoraAction(argparse.Action):
    """Accumulate LoRA specs (PATH plus optional STRENGTH) as ``LoraPathStrengthAndSDOps``."""

    def __call__(
        self,
        parser: argparse.ArgumentParser,  # noqa: ARG002
        namespace: argparse.Namespace,
        values: list[str],
        option_string: str | None = None,
    ) -> None:
        # nargs="+" lets argparse hand us any number of values; enforce the 1-2 limit here.
        if len(values) > 2:
            msg = f"{option_string} accepts at most 2 arguments (PATH and optional STRENGTH), got {len(values)} values"
            raise argparse.ArgumentError(self, msg)

        raw_path = values[0]
        strength = float(values[1]) if len(values) > 1 else float(str(DEFAULT_LORA_STRENGTH))

        collected = getattr(namespace, self.dest) or []
        collected.append(LoraPathStrengthAndSDOps(resolve_path(raw_path), strength, LTXV_LORA_COMFY_RENAMING_MAP))
        setattr(namespace, self.dest, collected)
74
+
75
+
76
def resolve_path(path: str) -> str:
    """Expand ``~``, resolve *path* to an absolute path, and return it in POSIX form.

    ``PurePath.as_posix()`` already returns ``str``, so the previous extra
    ``str(...)`` wrapper was redundant and has been removed.
    """
    return Path(path).expanduser().resolve().as_posix()
78
+
79
+
80
def basic_arg_parser() -> argparse.ArgumentParser:
    """Build the argument parser shared by all LTX pipelines.

    Covers required model locations (checkpoint, Gemma text encoder), the
    prompt and output path, common sampling parameters (seed, resolution,
    frame count/rate, inference steps), and repeatable image / LoRA
    conditioning options.

    Returns:
        An ``argparse.ArgumentParser`` pre-populated with the shared options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--checkpoint-path",
        type=resolve_path,
        required=True,
        help="Path to LTX-2 model checkpoint (.safetensors file).",
    )
    parser.add_argument(
        "--gemma-root",
        type=resolve_path,
        required=True,
        help="Path to the root directory containing the Gemma text encoder model files.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Text prompt describing the desired video content to be generated by the model.",
    )
    parser.add_argument(
        "--output-path",
        type=resolve_path,
        required=True,
        help="Path to the output video file (MP4 format).",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=DEFAULT_SEED,
        help=(
            f"Random seed value used to initialize the noise tensor for "
            f"reproducible generation (default: {DEFAULT_SEED})."
        ),
    )
    parser.add_argument(
        "--height",
        type=int,
        default=DEFAULT_1_STAGE_HEIGHT,
        help=f"Height of the generated video in pixels, should be divisible by 32 (default: {DEFAULT_1_STAGE_HEIGHT}).",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=DEFAULT_1_STAGE_WIDTH,
        help=f"Width of the generated video in pixels, should be divisible by 32 (default: {DEFAULT_1_STAGE_WIDTH}).",
    )
    parser.add_argument(
        "--num-frames",
        type=int,
        default=DEFAULT_NUM_FRAMES,
        # Fixed "where k" -> "where K" so the variable case matches the formula above.
        help=f"Number of frames to generate in the output video sequence, num-frames = (8 x K) + 1, "
        f"where K is a non-negative integer (default: {DEFAULT_NUM_FRAMES}).",
    )
    parser.add_argument(
        "--frame-rate",
        type=float,
        default=DEFAULT_FRAME_RATE,
        help=f"Frame rate of the generated video (fps) (default: {DEFAULT_FRAME_RATE}).",
    )
    parser.add_argument(
        "--num-inference-steps",
        type=int,
        default=DEFAULT_NUM_INFERENCE_STEPS,
        help=(
            f"Number of denoising steps in the diffusion sampling process. "
            f"Higher values improve quality but increase generation time (default: {DEFAULT_NUM_INFERENCE_STEPS})."
        ),
    )
    parser.add_argument(
        "--image",
        dest="images",
        action=ImageAction,
        nargs=3,
        metavar=("PATH", "FRAME_IDX", "STRENGTH"),
        default=[],
        help=(
            "Image conditioning input: path to image file, target frame index, "
            "and conditioning strength (all three required). Default: empty list [] (no image conditioning). "
            "Can be specified multiple times. Example: --image path/to/image1.jpg 0 0.8 "
            "--image path/to/image2.jpg 160 0.9"
        ),
    )
    parser.add_argument(
        "--lora",
        dest="lora",
        action=LoraAction,
        nargs="+",  # Accept 1-2 arguments per use (path and optional strength); validation is handled in LoraAction
        metavar=("PATH", "STRENGTH"),
        default=[],
        help=(
            "LoRA (Low-Rank Adaptation) model: path to model file and optional strength "
            f"(default strength: {DEFAULT_LORA_STRENGTH}). Can be specified multiple times. "
            "Example: --lora path/to/lora1.safetensors 0.8 --lora path/to/lora2.safetensors"
        ),
    )
    parser.add_argument(
        "--enable-fp8",
        action="store_true",
        help="Enable FP8 mode to reduce memory footprint by keeping model in lower precision. "
        "Note that calculations are still performed in bfloat16 precision.",
    )
    # Previously had no help text at all; keep the description minimal since the
    # enhancement behavior is implemented elsewhere.
    parser.add_argument("--enhance-prompt", action="store_true", help="Enable prompt enhancement.")
    return parser
184
+
185
+
186
def default_1_stage_arg_parser() -> argparse.ArgumentParser:
    """Extend the basic parser with the CFG and negative-prompt options used by the 1-stage pipeline."""
    stage_parser = basic_arg_parser()
    stage_parser.add_argument(
        "--cfg-guidance-scale",
        type=float,
        default=DEFAULT_CFG_GUIDANCE_SCALE,
        help=(
            f"Classifier-free guidance (CFG) scale controlling how strongly "
            f"the model adheres to the prompt. Higher values increase prompt "
            f"adherence but may reduce diversity (default: {DEFAULT_CFG_GUIDANCE_SCALE})."
        ),
    )
    stage_parser.add_argument(
        "--negative-prompt",
        type=str,
        default=DEFAULT_NEGATIVE_PROMPT,
        help=(
            "Negative prompt describing what should not appear in the generated video, "
            "used to guide the diffusion process away from unwanted content. "
            "Default: a comprehensive negative prompt covering common artifacts and quality issues."
        ),
    )
    return stage_parser
210
+
211
+
212
def default_2_stage_arg_parser() -> argparse.ArgumentParser:
    """Build the 2-stage parser: 1-stage options plus distilled LoRA and spatial upsampler."""
    parser = default_1_stage_arg_parser()
    parser.set_defaults(height=DEFAULT_2_STAGE_HEIGHT, width=DEFAULT_2_STAGE_WIDTH)
    # Rewrite the --height/--width help so it reflects the 2-stage defaults.
    overridden_help = {
        "--height": (
            f"Height of the generated video in pixels, should be divisible by 64 "
            f"(default: {DEFAULT_2_STAGE_HEIGHT})."
        ),
        "--width": f"Width of the generated video in pixels, should be divisible by 64 (default: {DEFAULT_2_STAGE_WIDTH}).",
    }
    for action in parser._actions:  # noqa: SLF001 - argparse exposes no public help-override API
        for option, text in overridden_help.items():
            if option in action.option_strings:
                action.help = text
    parser.add_argument(
        "--distilled-lora",
        dest="distilled_lora",
        action=LoraAction,
        nargs="+",  # Accept 1-2 arguments per use (path and optional strength); validation is handled in LoraAction
        metavar=("PATH", "STRENGTH"),
        required=True,
        help=(
            "Distilled LoRA (Low-Rank Adaptation) model used in the second stage (upscaling and refinement): "
            f"path to model file and optional strength (default strength: {DEFAULT_LORA_STRENGTH}). "
            "The second stage upsamples the video by 2x resolution and refines it using a distilled "
            "denoising schedule (fewer steps, no CFG). The distilled LoRA is specifically trained "
            "for this refinement process to improve quality at higher resolutions. "
            "Example: --distilled-lora path/to/distilled_lora.safetensors 0.8"
        ),
    )
    parser.add_argument(
        "--spatial-upsampler-path",
        type=resolve_path,
        required=True,
        help=(
            "Path to the spatial upsampler model used to increase the resolution "
            "of the generated video in the latent space."
        ),
    )
    return parser
252
+
253
+
254
def default_2_stage_distilled_arg_parser() -> argparse.ArgumentParser:
    """Build the parser for the fully distilled 2-stage pipeline (no CFG/negative-prompt options)."""
    parser = basic_arg_parser()
    parser.set_defaults(height=DEFAULT_2_STAGE_HEIGHT, width=DEFAULT_2_STAGE_WIDTH)
    # Rewrite the --height/--width help so it reflects the 2-stage defaults.
    overridden_help = {
        "--height": (
            f"Height of the generated video in pixels, should be divisible by 64 "
            f"(default: {DEFAULT_2_STAGE_HEIGHT})."
        ),
        "--width": f"Width of the generated video in pixels, should be divisible by 64 (default: {DEFAULT_2_STAGE_WIDTH}).",
    }
    for action in parser._actions:  # noqa: SLF001 - argparse exposes no public help-override API
        for option, text in overridden_help.items():
            if option in action.option_strings:
                action.help = text
    parser.add_argument(
        "--spatial-upsampler-path",
        type=resolve_path,
        required=True,
        help=(
            "Path to the spatial upsampler model used to increase the resolution "
            "of the generated video in the latent space."
        ),
    )
    return parser
packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Imports first (PEP 8): this import was previously buried under the
# "Diffusion Schedule" section heading.
from ltx_core.types import SpatioTemporalScaleFactors

# =============================================================================
# Diffusion Schedule
# =============================================================================

# Noise schedule for the distilled pipeline. These sigma values control noise
# levels at each denoising step and were tuned to match the distillation process.
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]

# Reduced schedule for super-resolution stage 2 (subset of distilled values)
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875, 0.0]


# =============================================================================
# Video Generation Defaults
# =============================================================================

DEFAULT_SEED = 10
DEFAULT_1_STAGE_HEIGHT = 512
DEFAULT_1_STAGE_WIDTH = 768
# Stage 2 upsamples by 2x in each spatial dimension.
DEFAULT_2_STAGE_HEIGHT = DEFAULT_1_STAGE_HEIGHT * 2
DEFAULT_2_STAGE_WIDTH = DEFAULT_1_STAGE_WIDTH * 2
DEFAULT_NUM_FRAMES = 121
DEFAULT_FRAME_RATE = 24.0
DEFAULT_NUM_INFERENCE_STEPS = 40
DEFAULT_CFG_GUIDANCE_SCALE = 4.0


# =============================================================================
# Audio
# =============================================================================

AUDIO_SAMPLE_RATE = 24000


# =============================================================================
# LoRA
# =============================================================================

DEFAULT_LORA_STRENGTH = 1.0


# =============================================================================
# Video VAE Architecture
# =============================================================================

VIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default()
VIDEO_LATENT_CHANNELS = 128


# =============================================================================
# Image Preprocessing
# =============================================================================

# CRF (Constant Rate Factor) for H.264 encoding used in image conditioning.
# Lower = higher quality, 0 = lossless. This mimics compression artifacts.
DEFAULT_IMAGE_CRF = 33


# =============================================================================
# Prompts
# =============================================================================

DEFAULT_NEGATIVE_PROMPT = (
    "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
    "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
    "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
    "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
    "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
    "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
    "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
    "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
    "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
    "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
    "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)
packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py ADDED
@@ -0,0 +1,507 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import logging
3
+ from dataclasses import replace
4
+
5
+ import torch
6
+ from tqdm import tqdm
7
+
8
+ from ltx_core.components.noisers import Noiser
9
+ from ltx_core.components.protocols import DiffusionStepProtocol, GuiderProtocol
10
+ from ltx_core.conditioning import (
11
+ ConditioningItem,
12
+ VideoConditionByKeyframeIndex,
13
+ VideoConditionByLatentIndex,
14
+ )
15
+ from ltx_core.model.transformer import Modality, X0Model
16
+ from ltx_core.model.video_vae import VideoEncoder
17
+ from ltx_core.text_encoders.gemma import GemmaTextEncoderModelBase
18
+ from ltx_core.tools import AudioLatentTools, LatentTools, VideoLatentTools
19
+ from ltx_core.types import AudioLatentShape, LatentState, VideoLatentShape, VideoPixelShape
20
+ from ltx_core.utils import to_denoised, to_velocity
21
+ from ltx_pipelines.utils.media_io import decode_image, load_image_conditioning, resize_aspect_ratio_preserving
22
+ from ltx_pipelines.utils.types import (
23
+ DenoisingFunc,
24
+ DenoisingLoopFunc,
25
+ PipelineComponents,
26
+ )
27
+
28
+
29
def get_device() -> torch.device:
    """Return the CUDA device when one is available, otherwise the CPU device."""
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device_type)
33
+
34
+
35
def cleanup_memory() -> None:
    """Collect Python garbage and release cached CUDA memory.

    The CUDA calls are guarded by ``torch.cuda.is_available()`` so this is
    safe on CPU-only machines, where ``torch.cuda.synchronize()`` would
    otherwise raise because CUDA cannot be initialized.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
39
+
40
+
41
def image_conditionings_by_replacing_latent(
    images: list[tuple[str, int, float]],
    height: int,
    width: int,
    video_encoder: VideoEncoder,
    dtype: torch.dtype,
    device: torch.device,
) -> list[ConditioningItem]:
    """Build latent-replacement conditionings, one per (path, frame_idx, strength) spec.

    Each image is loaded at the target resolution, encoded to the latent
    space, and wrapped in a ``VideoConditionByLatentIndex`` item.
    """

    def _to_conditioning(spec: tuple[str, int, float]) -> ConditioningItem:
        image_path, frame_idx, strength = spec
        pixels = load_image_conditioning(
            image_path=image_path,
            height=height,
            width=width,
            dtype=dtype,
            device=device,
        )
        return VideoConditionByLatentIndex(
            latent=video_encoder(pixels),
            strength=strength,
            latent_idx=frame_idx,
        )

    return [_to_conditioning(spec) for spec in images]
68
+
69
+
70
def image_conditionings_by_adding_guiding_latent(
    images: list[tuple[str, int, float]],
    height: int,
    width: int,
    video_encoder: VideoEncoder,
    dtype: torch.dtype,
    device: torch.device,
) -> list[ConditioningItem]:
    """Build keyframe-guidance conditionings, one per (path, frame_idx, strength) spec.

    Each image is loaded at the target resolution, encoded to the latent
    space, and wrapped in a ``VideoConditionByKeyframeIndex`` item.
    """

    def _to_conditioning(spec: tuple[str, int, float]) -> ConditioningItem:
        image_path, frame_idx, strength = spec
        pixels = load_image_conditioning(
            image_path=image_path,
            height=height,
            width=width,
            dtype=dtype,
            device=device,
        )
        return VideoConditionByKeyframeIndex(keyframes=video_encoder(pixels), frame_idx=frame_idx, strength=strength)

    return [_to_conditioning(spec) for spec in images]
92
+
93
+
94
def euler_denoising_loop(
    sigmas: torch.Tensor,
    video_state: LatentState,
    audio_state: LatentState,
    stepper: DiffusionStepProtocol,
    denoise_fn: DenoisingFunc,
) -> tuple[LatentState, LatentState]:
    """Run the joint audio/video Euler sampling loop over a sigma schedule.

    For every noise level except the last, ``denoise_fn`` is invoked as
    ``denoise_fn(video_state, audio_state, sigmas, step_index)`` and must
    return ``(denoised_video, denoised_audio)`` tensors matching the latent
    shapes. The estimates are re-blended with each state's clean latent
    through its denoise mask, then ``stepper`` advances the noisy latents
    one step along the schedule.

    Parameters
    ----------
    sigmas:
        1D tensor of noise levels; all but the final entry are iterated.
    video_state, audio_state:
        The current :class:`LatentState` for each modality (noisy latent,
        clean reference latent, and denoise mask).
    stepper:
        :class:`DiffusionStepProtocol` implementation that advances a latent
        given its denoised estimate, the schedule, and the step index.
    denoise_fn:
        :class:`DenoisingFunc` producing denoised estimates per step.

    Returns
    -------
    tuple[LatentState, LatentState]
        The final ``(video_state, audio_state)`` pair.
    """
    step_count = len(sigmas) - 1
    for step in tqdm(range(step_count)):
        est_video, est_audio = denoise_fn(video_state, audio_state, sigmas, step)

        # Keep conditioned (mask == 0) regions pinned to the clean latents.
        est_video = post_process_latent(est_video, video_state.denoise_mask, video_state.clean_latent)
        est_audio = post_process_latent(est_audio, audio_state.denoise_mask, audio_state.clean_latent)

        video_state = replace(video_state, latent=stepper.step(video_state.latent, est_video, sigmas, step))
        audio_state = replace(audio_state, latent=stepper.step(audio_state.latent, est_audio, sigmas, step))

    return video_state, audio_state
142
+
143
+
144
def gradient_estimating_euler_denoising_loop(
    sigmas: torch.Tensor,
    video_state: LatentState,
    audio_state: LatentState,
    stepper: DiffusionStepProtocol,
    denoise_fn: DenoisingFunc,
    ge_gamma: float = 2.0,
) -> tuple[LatentState, LatentState]:
    """
    Perform the joint audio-video denoising loop using gradient-estimation sampling.
    This function is similar to :func:`euler_denoising_loop`, but applies
    gradient estimation to improve the denoised estimates by tracking velocity
    changes across steps. See the referenced function for detailed parameter
    documentation.
    ### Parameters
    ge_gamma:
        Gradient estimation coefficient controlling the velocity correction term.
        Default is 2.0. Paper: https://openreview.net/pdf?id=o2ND9v0CeK
    sigmas, video_state, audio_state, stepper, denoise_fn:
        See :func:`euler_denoising_loop` for parameter descriptions.
    ### Returns
    tuple[LatentState, LatentState]
        See :func:`euler_denoising_loop` for return value description.
    """

    # Velocity estimates from the previous step, used for the gradient
    # correction; None on the first iteration (no correction is applied then).
    previous_audio_velocity = None
    previous_video_velocity = None

    def update_velocity_and_sample(
        noisy_sample: torch.Tensor, denoised_sample: torch.Tensor, sigma: float, previous_velocity: torch.Tensor | None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Current velocity implied by the model's denoised estimate at this sigma.
        current_velocity = to_velocity(noisy_sample, sigma, denoised_sample)
        if previous_velocity is not None:
            # Gradient-estimation correction: extrapolate along the change in
            # velocity between consecutive steps, scaled by ge_gamma, and
            # recompute the denoised sample from the corrected velocity.
            delta_v = current_velocity - previous_velocity
            total_velocity = ge_gamma * delta_v + previous_velocity
            denoised_sample = to_denoised(noisy_sample, total_velocity, sigma)
        # Return the UNcorrected velocity so the next step's delta is computed
        # against the raw model estimate, alongside the (possibly corrected) sample.
        return current_velocity, denoised_sample

    for step_idx, _ in enumerate(tqdm(sigmas[:-1])):
        denoised_video, denoised_audio = denoise_fn(video_state, audio_state, sigmas, step_idx)

        # Keep conditioned (mask == 0) regions pinned to the clean latents.
        denoised_video = post_process_latent(denoised_video, video_state.denoise_mask, video_state.clean_latent)
        denoised_audio = post_process_latent(denoised_audio, audio_state.denoise_mask, audio_state.clean_latent)

        # If the next sigma is 0, the denoised estimate IS the final sample:
        # return it directly instead of taking another (zero-length) step.
        if sigmas[step_idx + 1] == 0:
            return replace(video_state, latent=denoised_video), replace(audio_state, latent=denoised_audio)

        previous_video_velocity, denoised_video = update_velocity_and_sample(
            video_state.latent, denoised_video, sigmas[step_idx], previous_video_velocity
        )
        previous_audio_velocity, denoised_audio = update_velocity_and_sample(
            audio_state.latent, denoised_audio, sigmas[step_idx], previous_audio_velocity
        )

        video_state = replace(video_state, latent=stepper.step(video_state.latent, denoised_video, sigmas, step_idx))
        audio_state = replace(audio_state, latent=stepper.step(audio_state.latent, denoised_audio, sigmas, step_idx))

    return (video_state, audio_state)
202
+
203
+
204
def noise_video_state(
    output_shape: VideoPixelShape,
    noiser: Noiser,
    conditionings: list[ConditioningItem],
    components: PipelineComponents,
    dtype: torch.dtype,
    device: torch.device,
    noise_scale: float = 1.0,
    initial_latent: torch.Tensor | None = None,
) -> tuple[LatentState, VideoLatentTools]:
    """Build video latent tools for ``output_shape`` and return a noised latent state.

    Derives the latent shape from the pixel shape via the pipeline's scale
    factors, applies the conditionings, and noises the state with ``noiser``.
    When ``initial_latent`` is provided it seeds the initial state; otherwise
    an empty state is created.
    """
    latent_shape = VideoLatentShape.from_pixel_shape(
        shape=output_shape,
        latent_channels=components.video_latent_channels,
        scale_factors=components.video_scale_factors,
    )
    tools = VideoLatentTools(components.video_patchifier, latent_shape, output_shape.fps)
    noised_state = create_noised_state(
        tools=tools,
        conditionings=conditionings,
        noiser=noiser,
        dtype=dtype,
        device=device,
        noise_scale=noise_scale,
        initial_latent=initial_latent,
    )
    return noised_state, tools
237
+
238
+
239
def noise_audio_state(
    output_shape: VideoPixelShape,
    noiser: Noiser,
    conditionings: list[ConditioningItem],
    components: PipelineComponents,
    dtype: torch.dtype,
    device: torch.device,
    noise_scale: float = 1.0,
    initial_latent: torch.Tensor | None = None,
    denoise_mask: torch.Tensor | None = None,
) -> tuple[LatentState, AudioLatentTools]:
    """Build audio latent tools for ``output_shape`` and return a noised latent state.

    The audio latent shape is derived from the video pixel shape. Conditionings
    are applied before noising; ``initial_latent`` seeds the state when given,
    and ``denoise_mask`` is forwarded to the state creation.
    """
    latent_shape = AudioLatentShape.from_video_pixel_shape(output_shape)
    tools = AudioLatentTools(components.audio_patchifier, latent_shape)
    noised_state = create_noised_state(
        tools=tools,
        conditionings=conditionings,
        noiser=noiser,
        dtype=dtype,
        device=device,
        noise_scale=noise_scale,
        initial_latent=initial_latent,
        denoise_mask=denoise_mask,
    )
    return noised_state, tools
270
+
271
+
272
def create_noised_state(
    tools: LatentTools,
    conditionings: list[ConditioningItem],
    noiser: Noiser,
    dtype: torch.dtype,
    device: torch.device,
    noise_scale: float = 1.0,
    initial_latent: torch.Tensor | None = None,
    denoise_mask: torch.Tensor | None = None,
) -> LatentState:
    """Create an initial latent state, apply conditionings, and noise it.

    When ``denoise_mask`` is provided, it is collapsed to a single scalar
    (the mean, for tensor masks) and broadcast over the state's internal mask
    shape; after noising, the latent is re-blended with the clean latent so
    low-mask regions stay close to the clean content.
    """
    state = tools.create_initial_state(device, dtype, initial_latent)
    state = state_with_conditionings(state, conditionings, tools)

    if denoise_mask is not None:
        # Collapse any tensor mask into a single scalar (solid-mask behavior).
        # NOTE(review): spatial mask structure is deliberately discarded here.
        if isinstance(denoise_mask, torch.Tensor):
            mask_value = float(denoise_mask.mean().item())
        else:
            mask_value = float(denoise_mask)

        state = replace(
            state,
            # Snapshot the pre-noise latent so it can be re-blended below.
            clean_latent=state.latent.clone(),
            denoise_mask=torch.full_like(state.denoise_mask, mask_value),  # matches the internal mask shape
        )

    state = noiser(state, noise_scale)

    if denoise_mask is not None:
        # Re-blend: latent = noisy * mask + clean * (1 - mask), in the latent's dtype/device.
        m = state.denoise_mask.to(dtype=state.latent.dtype, device=state.latent.device)
        clean = state.clean_latent.to(dtype=state.latent.dtype, device=state.latent.device)
        state = replace(state, latent=state.latent * m + clean * (1 - m))

    return state
306
+
307
+
308
+
309
def state_with_conditionings(
    latent_state: LatentState, conditioning_items: list[ConditioningItem], latent_tools: LatentTools
) -> LatentState:
    """Fold each conditioning item into the latent state, in order.

    Every item's ``apply_to`` produces the next state; the fully conditioned
    state is returned.
    """
    conditioned = latent_state
    for item in conditioning_items:
        conditioned = item.apply_to(latent_state=conditioned, latent_tools=latent_tools)
    return conditioned
320
+
321
+
322
def post_process_latent(denoised: torch.Tensor, denoise_mask: torch.Tensor, clean: torch.Tensor) -> torch.Tensor:
    """Mix denoised and clean latents: mask=1 keeps the denoised value, mask=0 keeps the clean one."""
    mask = denoise_mask.to(dtype=denoised.dtype)
    clean_cast = clean.to(dtype=denoised.dtype)
    return denoised * mask + clean_cast * (1 - mask)
327
+
328
+
329
def modality_from_latent_state(
    state: LatentState, context: torch.Tensor, sigma: float | torch.Tensor, enabled: bool = True
) -> Modality:
    """Package a latent state and its text context into a transformer ``Modality``.

    Timesteps are derived from the state's denoise mask scaled by ``sigma``;
    no context mask is attached.
    """
    timesteps = timesteps_from_mask(state.denoise_mask, sigma)
    return Modality(
        enabled=enabled,
        latent=state.latent,
        timesteps=timesteps,
        positions=state.positions,
        context=context,
        context_mask=None,
    )
344
+
345
+
346
def timesteps_from_mask(denoise_mask: torch.Tensor, sigma: float | torch.Tensor) -> torch.Tensor:
    """Scale the denoise mask by sigma; positions where the mask is 0 get timestep 0."""
    return sigma * denoise_mask
352
+
353
+
354
def simple_denoising_func(
    video_context: torch.Tensor, audio_context: torch.Tensor, transformer: X0Model
) -> DenoisingFunc:
    """Return a denoising step that runs the transformer once per step (no guidance)."""

    def _step(
        video_state: LatentState, audio_state: LatentState, sigmas: torch.Tensor, step_index: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        sigma = sigmas[step_index]
        video_in = modality_from_latent_state(video_state, video_context, sigma)
        audio_in = modality_from_latent_state(audio_state, audio_context, sigma)
        return transformer(video=video_in, audio=audio_in, perturbations=None)

    return _step
368
+
369
+
370
def guider_denoising_func(
    guider: GuiderProtocol,
    v_context_p: torch.Tensor,
    v_context_n: torch.Tensor,
    a_context_p: torch.Tensor,
    a_context_n: torch.Tensor,
    transformer: X0Model,
) -> DenoisingFunc:
    """Return a denoising step that applies guidance through ``guider``.

    The positive contexts drive the main transformer pass. When the guider is
    enabled, a second pass with the negative contexts is run and combined into
    the prediction via ``guider.delta``.
    """

    def _guided_step(
        video_state: LatentState, audio_state: LatentState, sigmas: torch.Tensor, step_index: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        sigma = sigmas[step_index]
        pos_v = modality_from_latent_state(video_state, v_context_p, sigma)
        pos_a = modality_from_latent_state(audio_state, a_context_p, sigma)
        video_out, audio_out = transformer(video=pos_v, audio=pos_a, perturbations=None)

        if guider.enabled():
            neg_v = modality_from_latent_state(video_state, v_context_n, sigma)
            neg_a = modality_from_latent_state(audio_state, a_context_n, sigma)
            neg_video_out, neg_audio_out = transformer(video=neg_v, audio=neg_a, perturbations=None)
            video_out = video_out + guider.delta(video_out, neg_video_out)
            audio_out = audio_out + guider.delta(audio_out, neg_audio_out)

        return video_out, audio_out

    return _guided_step
398
+
399
+
400
def denoise_audio_video(  # noqa: PLR0913
    output_shape: VideoPixelShape,
    conditionings: list[ConditioningItem],
    noiser: Noiser,
    sigmas: torch.Tensor,
    stepper: DiffusionStepProtocol,
    denoising_loop_fn: DenoisingLoopFunc,
    components: PipelineComponents,
    dtype: torch.dtype,
    device: torch.device,
    audio_conditionings: list[ConditioningItem] | None = None,
    noise_scale: float = 1.0,
    initial_video_latent: torch.Tensor | None = None,
    initial_audio_latent: torch.Tensor | None = None,
    # mask_context: MaskInjection | None = None,
) -> tuple[LatentState | None, LatentState | None]:
    """Run the full audio+video denoising loop and return clean latent states.

    Builds noised initial latent states for both modalities (optionally seeded
    with ``initial_video_latent`` / ``initial_audio_latent`` and conditioned on
    the given conditioning items), runs ``denoising_loop_fn`` over the
    ``sigmas`` schedule using ``stepper``, then strips conditioning tokens and
    unpatchifies the results.

    Returns:
        ``(video_state, audio_state)``, or ``(None, None)`` when the loop
        returned ``None`` for either modality.
    """
    # Build the noised initial state plus per-modality tools (conditioning
    # removal / unpatchify helpers used after the loop).
    video_state, video_tools = noise_video_state(
        output_shape=output_shape,
        noiser=noiser,
        conditionings=conditionings,
        components=components,
        dtype=dtype,
        device=device,
        noise_scale=noise_scale,
        initial_latent=initial_video_latent,
    )
    audio_state, audio_tools = noise_audio_state(
        output_shape=output_shape,
        noiser=noiser,
        conditionings=audio_conditionings or [],
        components=components,
        dtype=dtype,
        device=device,
        noise_scale=noise_scale,
        initial_latent=initial_audio_latent,
    )

    # Optional loop kwargs are currently disabled; kept for reference.
    loop_kwargs = {}
    # if "preview_tools" in inspect.signature(denoising_loop_fn).parameters:
    #     loop_kwargs["preview_tools"] = video_tools
    # if "mask_context" in inspect.signature(denoising_loop_fn).parameters:
    #     loop_kwargs["mask_context"] = mask_context
    video_state, audio_state = denoising_loop_fn(
        sigmas,
        video_state,
        audio_state,
        stepper,
        **loop_kwargs,
    )

    # The loop may abort and return None states; propagate that to the caller.
    if video_state is None or audio_state is None:
        return None, None

    # Remove conditioning tokens and restore the unpatchified latent layout.
    video_state = video_tools.clear_conditioning(video_state)
    video_state = video_tools.unpatchify(video_state)
    audio_state = audio_tools.clear_conditioning(audio_state)
    audio_state = audio_tools.unpatchify(audio_state)

    return video_state, audio_state
459
+
460
+
461
+
462
+ _UNICODE_REPLACEMENTS = str.maketrans("\u2018\u2019\u201c\u201d\u2014\u2013\u00a0\u2032\u2212", "''\"\"-- '-")
463
+
464
+
465
+ def clean_response(text: str) -> str:
466
+ """Clean a response from curly quotes and leading non-letter characters which Gemma tends to insert."""
467
+ text = text.translate(_UNICODE_REPLACEMENTS)
468
+
469
+ # Remove leading non-letter characters
470
+ for i, char in enumerate(text):
471
+ if char.isalpha():
472
+ return text[i:]
473
+ return text
474
+
475
+
476
def generate_enhanced_prompt(
    text_encoder: GemmaTextEncoderModelBase,
    prompt: str,
    image_path: str | None = None,
    image_long_side: int = 896,
    seed: int = 42,
) -> str:
    """Generate an enhanced prompt from a text encoder and a prompt.

    When an image path is given, the image is decoded, resized so its long side
    matches ``image_long_side``, and used for image-to-video enhancement;
    otherwise text-to-video enhancement is used.
    """
    if image_path:
        raw_image = torch.tensor(decode_image(image_path=image_path))
        conditioning_image = resize_aspect_ratio_preserving(raw_image, image_long_side).to(torch.uint8)
        enhanced = text_encoder.enhance_i2v(prompt, conditioning_image, seed=seed)
    else:
        enhanced = text_encoder.enhance_t2v(prompt, seed=seed)
    logging.info(f"Enhanced prompt: {enhanced}")
    # Strip typographic artifacts the model may have introduced.
    return clean_response(enhanced)
494
+
495
+
496
def assert_resolution(height: int, width: int, is_two_stage: bool) -> None:
    """Assert that the resolution is divisible by the required divisor.

    For two-stage pipelines, the resolution must be divisible by 64.
    For one-stage pipelines, the resolution must be divisible by 32.

    Raises:
        ValueError: If either dimension is not a multiple of the required divisor.
    """
    divisor = 64 if is_two_stage else 32
    if height % divisor == 0 and width % divisor == 0:
        return
    raise ValueError(
        f"Resolution ({height}x{width}) is not divisible by {divisor}. "
        f"For {'two-stage' if is_two_stage else 'one-stage'} pipelines, "
        f"height and width must be multiples of {divisor}."
    )
packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from collections.abc import Generator, Iterator
3
+ from fractions import Fraction
4
+ from io import BytesIO
5
+
6
+ import av
7
+ import numpy as np
8
+ import torch
9
+ from einops import rearrange
10
+ from PIL import Image
11
+ from torch._prims_common import DeviceLikeType
12
+ from tqdm import tqdm
13
+
14
+ from ltx_pipelines.utils.constants import DEFAULT_IMAGE_CRF
15
+
16
+
17
def resize_aspect_ratio_preserving(image: torch.Tensor, long_side: int) -> torch.Tensor:
    """
    Resize image preserving aspect ratio (filling target long side).
    Preserves the input dimensions order.
    Args:
        image: Input image tensor with shape (F (optional), H, W, C)
        long_side: Target long side size.
    Returns:
        Tensor with shape (F (optional), H, W, C) F = 1 if input is 3D, otherwise input shape[0]
    """
    # H and W are the two dims immediately before the channel dim for both the
    # 3D (H, W, C) and 4D (F, H, W, C) layouts. The previous `shape[-3:2]`
    # slice only worked for 3D input and raised a ValueError (1-tuple unpack)
    # for the documented 4D case.
    height, width = image.shape[-3:-1]
    max_side = max(height, width)
    scale = long_side / float(max_side)
    target_height = int(height * scale)
    target_width = int(width * scale)
    # resize_and_center_crop returns (1, C, F, H, W).
    resized = resize_and_center_crop(image, target_height, target_width)
    # Drop batch dim and move channels last: (1, C, F, H, W) -> (F, H, W, C).
    result = resized[0].permute(1, 2, 3, 0)
    # Preserve input dimensionality: squeeze F for single-frame inputs.
    return result[0] if result.shape[0] == 1 else result
37
+
38
+
39
def resize_and_center_crop(tensor: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """
    Resize tensor preserving aspect ratio (filling target), then center crop to exact dimensions.
    Args:
        tensor: Input tensor with shape (H, W, C) or (F, H, W, C)
        height: Target height
        width: Target width
    Returns:
        Tensor with shape (1, C, 1, height, width) for 3D input or (1, C, F, height, width) for 4D input
    Raises:
        ValueError: If the input is not 3- or 4-dimensional.
    """
    # Normalize both layouts to (F, C, H, W) using native tensor ops
    # (previously done via einops.rearrange — an unnecessary dependency here).
    if tensor.ndim == 3:
        frames = tensor.permute(2, 0, 1).unsqueeze(0)  # (H, W, C) -> (1, C, H, W)
    elif tensor.ndim == 4:
        frames = tensor.permute(0, 3, 1, 2)  # (F, H, W, C) -> (F, C, H, W)
    else:
        raise ValueError(f"Expected input with 3 or 4 dimensions; got shape {tensor.shape}.")

    src_h, src_w = frames.shape[-2:]

    scale = max(height / src_h, width / src_w)
    # Use ceil to avoid floating-point rounding causing new_h/new_w to be
    # slightly smaller than target, which would result in negative crop offsets.
    new_h = math.ceil(src_h * scale)
    new_w = math.ceil(src_w * scale)

    frames = torch.nn.functional.interpolate(frames, size=(new_h, new_w), mode="bilinear", align_corners=False)

    # Center crop to the exact target size.
    crop_top = (new_h - height) // 2
    crop_left = (new_w - width) // 2
    frames = frames[:, :, crop_top : crop_top + height, crop_left : crop_left + width]

    # (F, C, H, W) -> (1, C, F, H, W)
    return frames.permute(1, 0, 2, 3).unsqueeze(0)
72
+
73
+
74
def normalize_latent(latent: torch.Tensor, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Map pixel values in [0, 255] to [-1, 1] and move to the target device/dtype."""
    scaled = latent / 127.5 - 1.0
    return scaled.to(device=device, dtype=dtype)
76
+
77
+
78
def load_image_conditioning(
    image_path: str, height: int, width: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    """
    Loads an image from a path and preprocesses it for conditioning.
    Note: The image is resized to the nearest multiple of 2 for compatibility with video codecs.
    """
    # Decode, run the codec round-trip preprocessing, then resize/crop and normalize.
    pixels = decode_image(image_path=image_path)
    pixels = preprocess(image=pixels)
    frame = torch.tensor(pixels, dtype=torch.float32, device=device)
    frame = resize_and_center_crop(frame, height, width)
    return normalize_latent(frame, device, dtype)
91
+
92
+
93
def load_video_conditioning(
    video_path: str, height: int, width: int, frame_cap: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    """
    Loads a video from a path and preprocesses it for conditioning.
    Note: The video is resized to the nearest multiple of 2 for compatibility with video codecs.
    """
    prepared_frames = []
    for raw_frame in decode_video_from_file(path=video_path, frame_cap=frame_cap, device=device):
        frame = resize_and_center_crop(raw_frame.to(torch.float32), height, width)
        prepared_frames.append(normalize_latent(frame, device, dtype))
    # Concatenate along the frame axis; None when the video yielded no frames
    # (matches the previous incremental-concat behavior).
    return torch.cat(prepared_frames, dim=2) if prepared_frames else None
107
+
108
+
109
def decode_image(image_path: str) -> np.ndarray:
    """Load an image from disk and return its pixels, keeping only the first three (RGB) channels."""
    pixels = np.array(Image.open(image_path))
    return pixels[..., :3]
113
+
114
+
115
def _write_audio(
    container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int
) -> None:
    """Convert a stereo sample tensor to int16 and feed it into the container's audio stream.

    Accepts (N, 2) or (2, N) layouts; float inputs are clipped to [-1, 1] and
    scaled to int16 before being handed to the resampler/encoder.

    Raises:
        ValueError: If the samples cannot be interpreted as two-channel audio.
    """
    # Give 1-D input a channel axis so the layout checks below apply uniformly.
    if samples.ndim == 1:
        samples = samples[:, None]

    # Channel-first (2, N) input: transpose to samples-first (N, 2).
    if samples.shape[1] != 2 and samples.shape[0] == 2:
        samples = samples.T

    if samples.shape[1] != 2:
        raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")

    # Convert to int16 packed for ingestion; resampler converts to encoder fmt.
    if samples.dtype != torch.int16:
        clipped = torch.clip(samples, -1.0, 1.0)
        samples = (clipped * 32767.0).to(torch.int16)

    interleaved = samples.contiguous().reshape(1, -1).cpu().numpy()
    frame_in = av.AudioFrame.from_ndarray(
        interleaved,
        format="s16",
        layout="stereo",
    )
    frame_in.sample_rate = audio_sample_rate

    _resample_audio(container, audio_stream, frame_in)
140
+
141
+
142
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
    """Add a stereo AAC audio stream to the container and configure its codec context."""
    stream = container.add_stream("aac", rate=audio_sample_rate)
    codec = stream.codec_context
    codec.sample_rate = audio_sample_rate
    codec.layout = "stereo"
    # Express timestamps in samples (1 / sample_rate seconds per tick).
    codec.time_base = Fraction(1, audio_sample_rate)
    return stream
151
+
152
+
153
def _resample_audio(
    container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
) -> None:
    """Resample *frame_in* to the encoder's native format, encode/mux the result, then flush the encoder."""
    cc = audio_stream.codec_context

    # Use the encoder's format/layout/rate as the *target*
    target_format = cc.format or "fltp"  # AAC → usually fltp
    target_layout = cc.layout or "stereo"
    target_rate = cc.sample_rate or frame_in.sample_rate

    audio_resampler = av.audio.resampler.AudioResampler(
        format=target_format,
        layout=target_layout,
        rate=target_rate,
    )

    # Synthesize monotonically increasing PTS (in samples) for resampled
    # frames that come back without one.
    audio_next_pts = 0
    for rframe in audio_resampler.resample(frame_in):
        if rframe.pts is None:
            rframe.pts = audio_next_pts
            audio_next_pts += rframe.samples
        # NOTE(review): sample_rate is reset to the *input* frame's rate even
        # though the resampler targets cc.sample_rate — confirm this is intended.
        rframe.sample_rate = frame_in.sample_rate
        container.mux(audio_stream.encode(rframe))

    # flush audio encoder
    for packet in audio_stream.encode():
        container.mux(packet)
180
+
181
+
182
def encode_video(
    video: torch.Tensor | Iterator[torch.Tensor],
    fps: int,
    audio: torch.Tensor | None,
    audio_sample_rate: int | None,
    output_path: str,
    video_chunks_number: int,
) -> None:
    """Encode video chunks (and optional audio) into an H.264/AAC file at *output_path*.

    Args:
        video: A single chunk or an iterator of chunks; each chunk is unpacked
            as (frames, height, width, channels) and encoded as rgb24 frames.
        fps: Output frame rate.
        audio: Optional audio samples; requires ``audio_sample_rate``.
        audio_sample_rate: Sample rate of ``audio`` in Hz.
        output_path: Destination media file path.
        video_chunks_number: Total chunk count, used only for the progress bar.

    Raises:
        ValueError: If ``audio`` is provided without ``audio_sample_rate``.
    """
    # Normalize to an iterator so a single tensor follows the streaming path.
    if isinstance(video, torch.Tensor):
        video = iter([video])

    # Peek at the first chunk to learn the output dimensions.
    first_chunk = next(video)

    _, height, width, _ = first_chunk.shape

    container = av.open(output_path, mode="w")
    stream = container.add_stream("libx264", rate=int(fps))
    stream.width = width
    stream.height = height
    stream.pix_fmt = "yuv420p"

    if audio is not None:
        if audio_sample_rate is None:
            raise ValueError("audio_sample_rate is required when audio is provided")

        audio_stream = _prepare_audio_stream(container, audio_sample_rate)

    # Annotations corrected: this yields plain chunk tensors, not (tensor, int) tuples.
    def all_tiles(
        first_chunk: torch.Tensor, tiles_generator: Iterator[torch.Tensor]
    ) -> Generator[torch.Tensor, None, None]:
        # Re-prepend the chunk consumed by the peek above.
        yield first_chunk
        yield from tiles_generator

    for video_chunk in tqdm(all_tiles(first_chunk, video), total=video_chunks_number):
        video_chunk_cpu = video_chunk.to("cpu").numpy()
        for frame_array in video_chunk_cpu:
            frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
            for packet in stream.encode(frame):
                container.mux(packet)

    # Flush encoder
    for packet in stream.encode():
        container.mux(packet)

    if audio is not None:
        _write_audio(container, audio_stream, audio, audio_sample_rate)

    container.close()
230
+
231
+
232
def decode_audio_from_file(path: str, device: torch.device) -> torch.Tensor | None:
    """Decode the first audio stream of a media file into a float32 tensor.

    Args:
        path: Media file to read.
        device: Device the decoded tensor is placed on.

    Returns:
        Tensor of per-frame sample arrays concatenated along a leading frame
        dimension, or ``None`` when the file has no audio stream.
    """
    container = av.open(path)
    try:
        try:
            audio_stream = next(s for s in container.streams if s.type == "audio")
        except StopIteration:
            # No audio track present.
            return None
        frames = [
            torch.tensor(frame.to_ndarray(), dtype=torch.float32, device=device).unsqueeze(0)
            for frame in container.decode(audio_stream)
        ]
        return torch.cat(frames)
    finally:
        # Close exactly once; the original closed the container both inside the
        # try body and again in finally on the success path.
        container.close()
247
+
248
+
249
def decode_video_from_file(path: str, frame_cap: int, device: DeviceLikeType) -> Generator[torch.Tensor]:
    """Lazily decode video frames as uint8 RGB tensors with a leading batch dimension.

    Yields at most ``frame_cap`` frames (a non-positive cap yields every frame);
    the container is closed when the generator is exhausted or discarded.
    """
    remaining = frame_cap
    container = av.open(path)
    try:
        stream = next(s for s in container.streams if s.type == "video")
        for frame in container.decode(stream):
            rgb = frame.to_rgb().to_ndarray()
            yield torch.tensor(rgb, dtype=torch.uint8, device=device).unsqueeze(0)
            remaining -= 1
            if remaining == 0:
                break
    finally:
        container.close()
261
+
262
+
263
def encode_single_frame(output_file: str, image_array: np.ndarray, crf: float) -> None:
    """Encode a single RGB frame as a one-frame H.264 MP4 written to *output_file*."""
    container = av.open(output_file, "w", format="mp4")
    try:
        stream = container.add_stream("libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"})
        # Round to nearest multiple of 2 for compatibility with video codecs
        even_height = image_array.shape[0] // 2 * 2
        even_width = image_array.shape[1] // 2 * 2
        cropped = image_array[:even_height, :even_width]
        stream.height = even_height
        stream.width = even_width
        av_frame = av.VideoFrame.from_ndarray(cropped, format="rgb24").reformat(format="yuv420p")
        container.mux(stream.encode(av_frame))
        # Second encode() call with no frame flushes the encoder.
        container.mux(stream.encode())
    finally:
        container.close()
278
+
279
+
280
def decode_single_frame(video_file: str) -> np.array:
    """Decode and return the first video frame of *video_file* as an RGB ndarray."""
    container = av.open(video_file)
    try:
        video_stream = next(s for s in container.streams if s.type == "video")
        first_frame = next(container.decode(video_stream))
    finally:
        container.close()
    return first_frame.to_ndarray(format="rgb24")
288
+
289
+
290
def preprocess(image: np.array, crf: float = DEFAULT_IMAGE_CRF) -> np.array:
    """Round-trip *image* through an in-memory H.264 encode/decode at the given CRF.

    ``crf == 0`` skips the round-trip and returns the input unchanged.
    """
    if crf == 0:
        return image

    with BytesIO() as encoded_buffer:
        encode_single_frame(encoded_buffer, image, crf)
        compressed_bytes = encoded_buffer.getvalue()
    with BytesIO(compressed_bytes) as decoded_buffer:
        return decode_single_frame(decoded_buffer)
packages/ltx-pipelines/src/ltx_pipelines/utils/model_ledger.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import replace
2
+
3
+ import torch
4
+
5
+ from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
6
+ from ltx_core.loader.registry import DummyRegistry, Registry
7
+ from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder
8
+ from ltx_core.model.audio_vae import (
9
+ AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
10
+ VOCODER_COMFY_KEYS_FILTER,
11
+ AudioDecoder,
12
+ AudioDecoderConfigurator,
13
+ Vocoder,
14
+ VocoderConfigurator,
15
+ )
16
+ from ltx_core.model.transformer import (
17
+ LTXV_MODEL_COMFY_RENAMING_MAP,
18
+ LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP,
19
+ UPCAST_DURING_INFERENCE,
20
+ LTXModelConfigurator,
21
+ X0Model,
22
+ )
23
+ from ltx_core.model.upsampler import LatentUpsampler, LatentUpsamplerConfigurator
24
+ from ltx_core.model.video_vae import (
25
+ VAE_DECODER_COMFY_KEYS_FILTER,
26
+ VAE_ENCODER_COMFY_KEYS_FILTER,
27
+ VideoDecoder,
28
+ VideoDecoderConfigurator,
29
+ VideoEncoder,
30
+ VideoEncoderConfigurator,
31
+ )
32
+ from ltx_core.text_encoders.gemma import (
33
+ AV_GEMMA_TEXT_ENCODER_KEY_OPS,
34
+ AVGemmaTextEncoderModel,
35
+ AVGemmaTextEncoderModelConfigurator,
36
+ module_ops_from_gemma_root,
37
+ )
38
+
39
+ from ltx_core.model.audio_vae import (
40
+ AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
41
+ VOCODER_COMFY_KEYS_FILTER,
42
+ AudioDecoder,
43
+ AudioDecoderConfigurator,
44
+ Vocoder,
45
+ VocoderConfigurator,
46
+ AudioEncoder,
47
+ )
48
+ from ltx_core.model.audio_vae.model_configurator import (
49
+ AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
50
+ AudioEncoderConfigurator,
51
+ )
52
+
53
+
54
class ModelLedger:
    """
    Central coordinator for loading and building models used in an LTX pipeline.
    The ledger wires together multiple model builders (transformer, video VAE encoder/decoder,
    audio VAE encoder/decoder, vocoder, text encoder, and optional latent upsampler) and exposes
    factory methods for constructing model instances.
    ### Model Building
    Each model method (e.g. :meth:`transformer`, :meth:`video_decoder`, :meth:`text_encoder`)
    constructs a new model instance on each call. The builder uses the
    :class:`~ltx_core.loader.registry.Registry` to load weights from the checkpoint,
    instantiates the model with the configured ``dtype``, and moves it to ``self.device``.
    .. note::
        Models are **not cached**. Each call to a model method creates a new instance.
        Callers are responsible for storing references to models they wish to reuse
        and for freeing GPU memory (e.g. by deleting references and calling
        ``torch.cuda.empty_cache()``).
    ### Constructor parameters
    dtype:
        Torch dtype used when constructing all models (e.g. ``torch.bfloat16``).
    device:
        Target device to which models are moved after construction (e.g. ``torch.device("cuda")``).
    checkpoint_path:
        Path to a checkpoint directory or file containing the core model weights
        (transformer, video VAE, audio VAE, text encoder, vocoder). If ``None``, the
        corresponding builders are not created and calling those methods will raise
        a :class:`ValueError`.
    gemma_root_path:
        Base path to Gemma-compatible CLIP/text encoder weights. Required to
        initialize the text encoder builder; if omitted, :meth:`text_encoder` cannot be used.
    spatial_upsampler_path:
        Optional path to a latent upsampler checkpoint. If provided, the
        :meth:`spatial_upsampler` method becomes available; otherwise calling it raises
        a :class:`ValueError`.
    loras:
        Optional collection of LoRA configurations (paths, strengths, and key operations)
        that are applied on top of the base transformer weights when building the model.
    registry:
        Optional :class:`Registry` instance for weight caching across builders.
        Defaults to :class:`DummyRegistry` which performs no cross-builder caching.
    fp8transformer:
        If ``True``, builds the transformer with FP8 quantization and upcasting during inference.
    local_files_only:
        Forwarded to :func:`module_ops_from_gemma_root` when building the text encoder
        (presumably restricts weight resolution to local files — confirm with that helper).
    ### Creating Variants
    Use :meth:`with_loras` to create a new ``ModelLedger`` instance that includes
    additional LoRA configurations while sharing the same registry for weight caching.
    """

    def __init__(
        self,
        dtype: torch.dtype,
        device: torch.device,
        checkpoint_path: str | None = None,
        gemma_root_path: str | None = None,
        spatial_upsampler_path: str | None = None,
        loras: LoraPathStrengthAndSDOps | None = None,
        registry: Registry | None = None,
        fp8transformer: bool = False,
        local_files_only: bool = True,
    ):
        self.dtype = dtype
        self.device = device
        self.checkpoint_path = checkpoint_path
        self.gemma_root_path = gemma_root_path
        self.spatial_upsampler_path = spatial_upsampler_path
        self.loras = loras or ()
        self.registry = registry or DummyRegistry()
        self.fp8transformer = fp8transformer
        self.local_files_only = local_files_only
        self.build_model_builders()

    def build_model_builders(self) -> None:
        """Create the per-model builders for whichever checkpoint paths were provided."""
        if self.checkpoint_path is not None:
            self.transformer_builder = Builder(
                model_path=self.checkpoint_path,
                model_class_configurator=LTXModelConfigurator,
                model_sd_ops=LTXV_MODEL_COMFY_RENAMING_MAP,
                loras=tuple(self.loras),
                registry=self.registry,
            )

            self.vae_decoder_builder = Builder(
                model_path=self.checkpoint_path,
                model_class_configurator=VideoDecoderConfigurator,
                model_sd_ops=VAE_DECODER_COMFY_KEYS_FILTER,
                registry=self.registry,
            )

            self.vae_encoder_builder = Builder(
                model_path=self.checkpoint_path,
                model_class_configurator=VideoEncoderConfigurator,
                model_sd_ops=VAE_ENCODER_COMFY_KEYS_FILTER,
                registry=self.registry,
            )

            self.audio_decoder_builder = Builder(
                model_path=self.checkpoint_path,
                model_class_configurator=AudioDecoderConfigurator,
                model_sd_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
                registry=self.registry,
            )

            self.vocoder_builder = Builder(
                model_path=self.checkpoint_path,
                model_class_configurator=VocoderConfigurator,
                model_sd_ops=VOCODER_COMFY_KEYS_FILTER,
                registry=self.registry,
            )

            self.audio_encoder_builder = Builder(
                model_path=self.checkpoint_path,
                model_class_configurator=AudioEncoderConfigurator,
                model_sd_ops=AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
                registry=self.registry,
            )

        if self.gemma_root_path is not None:
            # NOTE(review): the text encoder builder still reads weights from
            # checkpoint_path; a ledger built with only gemma_root_path passes
            # model_path=None here — confirm Builder supports that.
            self.text_encoder_builder = Builder(
                model_path=self.checkpoint_path,
                model_class_configurator=AVGemmaTextEncoderModelConfigurator,
                model_sd_ops=AV_GEMMA_TEXT_ENCODER_KEY_OPS,
                registry=self.registry,
                module_ops=module_ops_from_gemma_root(self.gemma_root_path, self.local_files_only),
            )

        if self.spatial_upsampler_path is not None:
            self.upsampler_builder = Builder(
                model_path=self.spatial_upsampler_path,
                model_class_configurator=LatentUpsamplerConfigurator,
                registry=self.registry,
            )

    def _target_device(self) -> torch.device:
        """Device on which weights are materialized before the final ``.to(self.device)`` move.

        With a caching registry, models are first built on CPU so the cached
        weights stay off the accelerator.
        """
        if isinstance(self.registry, DummyRegistry) or self.registry is None:
            return self.device
        else:
            return torch.device("cpu")

    def with_loras(self, loras: LoraPathStrengthAndSDOps) -> "ModelLedger":
        """Return a new ledger with *loras* appended, sharing this ledger's registry.

        All other construction settings are carried over unchanged.
        """
        return ModelLedger(
            dtype=self.dtype,
            device=self.device,
            checkpoint_path=self.checkpoint_path,
            gemma_root_path=self.gemma_root_path,
            spatial_upsampler_path=self.spatial_upsampler_path,
            loras=(*self.loras, *loras),
            registry=self.registry,
            fp8transformer=self.fp8transformer,
            # Fix: previously omitted, silently resetting the derived ledger to
            # the default (True) even when this ledger used local_files_only=False.
            local_files_only=self.local_files_only,
        )

    def transformer(self) -> X0Model:
        """Build the denoising transformer (optionally FP8-quantized), wrapped as an X0Model."""
        if not hasattr(self, "transformer_builder"):
            raise ValueError(
                "Transformer not initialized. Please provide a checkpoint path to the ModelLedger constructor."
            )
        if self.fp8transformer:
            # FP8 path: downcast linear weights at load time and upcast during inference.
            fp8_builder = replace(
                self.transformer_builder,
                module_ops=(UPCAST_DURING_INFERENCE,),
                model_sd_ops=LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP,
            )
            return X0Model(fp8_builder.build(device=self._target_device())).to(self.device).eval()
        else:
            return (
                X0Model(self.transformer_builder.build(device=self._target_device(), dtype=self.dtype))
                .to(self.device)
                .eval()
            )

    def audio_encoder(self) -> AudioEncoder:
        """Build the audio VAE encoder."""
        if not hasattr(self, "audio_encoder_builder"):
            raise ValueError(
                "Audio encoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
            )
        return self.audio_encoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()

    def video_decoder(self) -> VideoDecoder:
        """Build the video VAE decoder."""
        if not hasattr(self, "vae_decoder_builder"):
            raise ValueError(
                "Video decoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
            )

        return self.vae_decoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()

    def video_encoder(self) -> VideoEncoder:
        """Build the video VAE encoder."""
        if not hasattr(self, "vae_encoder_builder"):
            raise ValueError(
                "Video encoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
            )

        return self.vae_encoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()

    def text_encoder(self) -> AVGemmaTextEncoderModel:
        """Build the Gemma-based text encoder."""
        if not hasattr(self, "text_encoder_builder"):
            raise ValueError(
                "Text encoder not initialized. Please provide a checkpoint path and gemma root path to the "
                "ModelLedger constructor."
            )

        return self.text_encoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()

    def audio_decoder(self) -> AudioDecoder:
        """Build the audio VAE decoder."""
        if not hasattr(self, "audio_decoder_builder"):
            raise ValueError(
                "Audio decoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
            )

        return self.audio_decoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()

    def vocoder(self) -> Vocoder:
        """Build the vocoder."""
        if not hasattr(self, "vocoder_builder"):
            raise ValueError(
                "Vocoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
            )

        return self.vocoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()

    def spatial_upsampler(self) -> LatentUpsampler:
        """Build the latent spatial upsampler."""
        if not hasattr(self, "upsampler_builder"):
            raise ValueError("Upsampler not initialized. Please provide upsampler path to the ModelLedger constructor.")

        return self.upsampler_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
packages/ltx-pipelines/src/ltx_pipelines/utils/types.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Protocol
2
+
3
+ import torch
4
+
5
+ from ltx_core.components.patchifiers import AudioPatchifier, VideoLatentPatchifier
6
+ from ltx_core.components.protocols import DiffusionStepProtocol
7
+ from ltx_core.types import LatentState
8
+ from ltx_pipelines.utils.constants import VIDEO_LATENT_CHANNELS, VIDEO_SCALE_FACTORS
9
+
10
+
11
class PipelineComponents:
    """
    Container class for pipeline components used throughout the LTX pipelines.
    Attributes:
        dtype (torch.dtype): Default torch dtype for tensors in the pipeline.
        device (torch.device): Target device to place tensors and modules on.
        video_scale_factors (SpatioTemporalScaleFactors): Scale factors (T, H, W) for VAE latent space.
        video_latent_channels (int): Number of channels in the video latent representation.
        video_patchifier (VideoLatentPatchifier): Patchifier instance for video latents.
        audio_patchifier (AudioPatchifier): Patchifier instance for audio latents.
    """

    def __init__(
        self,
        dtype: torch.dtype,
        device: torch.device,
    ):
        # Placement settings shared by the pipeline.
        self.dtype = dtype
        self.device = device

        # Latent-space geometry, taken from the module-level constants.
        self.video_scale_factors = VIDEO_SCALE_FACTORS
        self.video_latent_channels = VIDEO_LATENT_CHANNELS

        # Per-modality patchifiers (patch size fixed at 1).
        self.video_patchifier = VideoLatentPatchifier(patch_size=1)
        self.audio_patchifier = AudioPatchifier(patch_size=1)
36
+
37
+
38
class DenoisingFunc(Protocol):
    """
    Protocol for a denoising function used in the LTX pipeline.
    Implementations perform one denoising evaluation: they receive the full
    sigma schedule plus the index of the current step and return the model's
    denoised predictions for both modalities.
    Args:
        video_state (LatentState): The current latent state for video.
        audio_state (LatentState): The current latent state for audio.
        sigmas (torch.Tensor): A 1D tensor of sigma values for each diffusion step.
        step_index (int): Index of the current denoising step.
    Returns:
        tuple[torch.Tensor, torch.Tensor]: The denoised video and audio tensors.
    """

    def __call__(
        self, video_state: LatentState, audio_state: LatentState, sigmas: torch.Tensor, step_index: int
    ) -> tuple[torch.Tensor, torch.Tensor]: ...
53
+
54
+
55
class DenoisingLoopFunc(Protocol):
    """
    Protocol for a denoising loop function used in the LTX pipeline.
    Args:
        sigmas (torch.Tensor): A 1D tensor of sigma values for each diffusion step.
        video_state (LatentState): The current latent state for video.
        audio_state (LatentState): The current latent state for audio.
        stepper (DiffusionStepProtocol): The diffusion step protocol to use.
    Returns:
        tuple[LatentState | None, LatentState | None]: The denoised video and audio
        latent states; callers treat ``None`` as an aborted loop.
    """

    def __call__(
        self,
        sigmas: torch.Tensor,
        video_state: LatentState,
        audio_state: LatentState,
        stepper: DiffusionStepProtocol,
        # Return annotation corrected: the loop yields LatentState objects
        # (possibly None — denoise_audio_video None-checks them), not raw tensors.
    ) -> tuple[LatentState | None, LatentState | None]: ...