Fabrice-TIERCELIN commited on
Commit
b323cbe
·
verified ·
1 Parent(s): de02dc3

Delete packages/ltx-pipelines/ltx_pipelines

Browse files
packages/ltx-pipelines/ltx_pipelines/__init__.py DELETED
@@ -1,26 +0,0 @@
1
- """
2
- LTX-2 Pipelines: High-level video generation pipelines and utilities.
3
- This package provides ready-to-use pipelines for video generation:
4
- - TI2VidOneStagePipeline: Text/image-to-video in a single stage
5
- - TI2VidTwoStagesPipeline: Two-stage generation with upsampling
6
- - DistilledPipeline: Fast distilled two-stage generation
7
- - ICLoraPipeline: Image/video conditioning with distilled LoRA
8
- - KeyframeInterpolationPipeline: Keyframe-based video interpolation
9
- - ModelLedger: Central coordinator for loading and building models
10
- For more detailed components and utilities, import from specific submodules
11
- like `ltx_pipelines.utils.media_io` or `ltx_pipelines.utils.constants`.
12
- """
13
-
14
- from ltx_pipelines.distilled import DistilledPipeline
15
- from ltx_pipelines.ic_lora import ICLoraPipeline
16
- from ltx_pipelines.keyframe_interpolation import KeyframeInterpolationPipeline
17
- from ltx_pipelines.ti2vid_one_stage import TI2VidOneStagePipeline
18
- from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline
19
-
20
- __all__ = [
21
- "DistilledPipeline",
22
- "ICLoraPipeline",
23
- "KeyframeInterpolationPipeline",
24
- "TI2VidOneStagePipeline",
25
- "TI2VidTwoStagesPipeline",
26
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
packages/ltx-pipelines/ltx_pipelines/distilled.py DELETED
@@ -1,475 +0,0 @@
1
- # Copyright (c) 2025 Lightricks. All rights reserved.
2
- # Created by Amit Pintz.
3
-
4
-
5
- import torch
6
-
7
- from ltx_core.components.diffusion_steps import EulerDiffusionStep
8
- from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
9
- from ltx_core.components.noisers import GaussianNoiser
10
- from ltx_core.components.protocols import DiffusionStepProtocol
11
- from ltx_core.conditioning import ConditioningItem, VideoConditionByKeyframeIndex, ConditioningError
12
- from ltx_core.loader import LoraPathStrengthAndSDOps
13
- from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
14
- from ltx_core.model.upsampler import upsample_video
15
- from ltx_core.model.video_vae import TilingConfig, VideoEncoder, get_video_chunks_number
16
- from ltx_core.model.video_vae import decode_video as vae_decode_video
17
- from ltx_core.text_encoders.gemma import encode_text
18
- from ltx_core.types import LatentState, VideoPixelShape
19
- from ltx_core.tools import LatentTools
20
- from ltx_pipelines import utils
21
- from ltx_pipelines.utils import ModelLedger
22
- from ltx_pipelines.utils.args import default_2_stage_distilled_arg_parser
23
- from ltx_pipelines.utils.constants import (
24
- AUDIO_SAMPLE_RATE,
25
- DEFAULT_LORA_STRENGTH,
26
- DISTILLED_SIGMA_VALUES,
27
- STAGE_2_DISTILLED_SIGMA_VALUES,
28
- )
29
- from ltx_pipelines.utils.helpers import (
30
- assert_resolution,
31
- cleanup_memory,
32
- denoise_audio_video,
33
- euler_denoising_loop,
34
- generate_enhanced_prompt,
35
- get_device,
36
- image_conditionings_by_replacing_latent,
37
- simple_denoising_func,
38
- )
39
- from ltx_pipelines.utils.media_io import encode_video, load_video_conditioning
40
- from ltx_pipelines.utils.types import PipelineComponents
41
-
42
- import torchaudio
43
- from ltx_core.model.audio_vae import AudioProcessor
44
- from ltx_core.types import AudioLatentShape, VideoPixelShape
45
-
46
- class AudioConditionByLatent(ConditioningItem):
47
- """
48
- Conditions audio generation by injecting a full latent sequence.
49
- Replaces tokens in the latent state with the provided audio latents,
50
- and sets denoise strength according to the strength parameter.
51
- """
52
-
53
- def __init__(self, latent: torch.Tensor, strength: float):
54
- self.latent = latent
55
- self.strength = strength
56
-
57
- def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState:
58
- if not isinstance(latent_tools.target_shape, AudioLatentShape):
59
- raise ConditioningError("Audio conditioning requires an audio latent target shape.")
60
-
61
- cond_batch, cond_channels, cond_frames, cond_bins = self.latent.shape
62
- tgt_batch, tgt_channels, tgt_frames, tgt_bins = latent_tools.target_shape.to_torch_shape()
63
-
64
- if (cond_batch, cond_channels, cond_frames, cond_bins) != (tgt_batch, tgt_channels, tgt_frames, tgt_bins):
65
- raise ConditioningError(
66
- f"Can't apply audio conditioning item to latent with shape {latent_tools.target_shape}, expected "
67
- f"shape is ({tgt_batch}, {tgt_channels}, {tgt_frames}, {tgt_bins})."
68
- )
69
-
70
- tokens = latent_tools.patchifier.patchify(self.latent)
71
- latent_state = latent_state.clone()
72
- latent_state.latent[:, : tokens.shape[1]] = tokens
73
- latent_state.clean_latent[:, : tokens.shape[1]] = tokens
74
- latent_state.denoise_mask[:, : tokens.shape[1]] = 1.0 - self.strength
75
-
76
- return latent_state
77
-
78
- device = get_device()
79
-
80
-
81
- class DistilledPipeline:
82
- def __init__(
83
- self,
84
- checkpoint_path: str,
85
- gemma_root: str,
86
- spatial_upsampler_path: str,
87
- loras: list[LoraPathStrengthAndSDOps],
88
- device: torch.device = device,
89
- fp8transformer: bool = False,
90
- local_files_only: bool = True,
91
- ):
92
- self.device = device
93
- self.dtype = torch.bfloat16
94
-
95
- self.model_ledger = ModelLedger(
96
- dtype=self.dtype,
97
- device=device,
98
- checkpoint_path=checkpoint_path,
99
- spatial_upsampler_path=spatial_upsampler_path,
100
- gemma_root_path=gemma_root,
101
- loras=loras,
102
- fp8transformer=fp8transformer,
103
- local_files_only=local_files_only
104
- )
105
-
106
- self.pipeline_components = PipelineComponents(
107
- dtype=self.dtype,
108
- device=device,
109
- )
110
-
111
- # Cached models to avoid reloading
112
- self._video_encoder = None
113
- self._transformer = None
114
-
115
- def _build_audio_conditionings_from_waveform(
116
- self,
117
- input_waveform: torch.Tensor,
118
- input_sample_rate: int,
119
- num_frames: int,
120
- fps: float,
121
- strength: float,
122
- ) -> list[AudioConditionByLatent] | None:
123
- strength = float(strength)
124
- if strength <= 0.0:
125
- return None
126
-
127
- # Expect waveform as:
128
- # - (T,) or (C,T) or (B,C,T). Convert to (B,C,T)
129
- waveform = input_waveform
130
- if waveform.ndim == 1:
131
- waveform = waveform.unsqueeze(0).unsqueeze(0)
132
- elif waveform.ndim == 2:
133
- waveform = waveform.unsqueeze(0)
134
- elif waveform.ndim != 3:
135
- raise ValueError(f"input_waveform must be 1D/2D/3D, got shape {tuple(waveform.shape)}")
136
-
137
- # Get audio encoder + its config
138
- audio_encoder = self.model_ledger.audio_encoder() # assumes ledger exposes it
139
- # If you want to cache it like video_encoder/transformer, you can.
140
- target_sr = int(getattr(audio_encoder, "sample_rate"))
141
- target_channels = int(getattr(audio_encoder, "in_channels", waveform.shape[1]))
142
- mel_bins = int(getattr(audio_encoder, "mel_bins"))
143
- mel_hop = int(getattr(audio_encoder, "mel_hop_length"))
144
- n_fft = int(getattr(audio_encoder, "n_fft"))
145
-
146
- # Match channels
147
- if waveform.shape[1] != target_channels:
148
- if waveform.shape[1] == 1 and target_channels > 1:
149
- waveform = waveform.repeat(1, target_channels, 1)
150
- elif target_channels == 1:
151
- waveform = waveform.mean(dim=1, keepdim=True)
152
- else:
153
- waveform = waveform[:, :target_channels, :]
154
- if waveform.shape[1] < target_channels:
155
- pad_ch = target_channels - waveform.shape[1]
156
- pad = torch.zeros((waveform.shape[0], pad_ch, waveform.shape[2]), dtype=waveform.dtype)
157
- waveform = torch.cat([waveform, pad], dim=1)
158
-
159
- # Resample if needed (CPU float32 is safest for torchaudio)
160
- waveform = waveform.to(device="cpu", dtype=torch.float32)
161
- if int(input_sample_rate) != target_sr:
162
- waveform = torchaudio.functional.resample(waveform, int(input_sample_rate), target_sr)
163
-
164
- # Waveform -> Mel
165
- audio_processor = AudioProcessor(
166
- sample_rate=target_sr,
167
- mel_bins=mel_bins,
168
- mel_hop_length=mel_hop,
169
- n_fft=n_fft,
170
- ).to(waveform.device)
171
-
172
- mel = audio_processor.waveform_to_mel(waveform, target_sr)
173
-
174
- # Mel -> latent (run encoder on its own device/dtype)
175
- audio_params = next(audio_encoder.parameters(), None)
176
- enc_device = audio_params.device if audio_params is not None else self.device
177
- enc_dtype = audio_params.dtype if audio_params is not None else self.dtype
178
-
179
- mel = mel.to(device=enc_device, dtype=enc_dtype)
180
- with torch.inference_mode():
181
- audio_latent = audio_encoder(mel)
182
-
183
- # Pad/trim latent to match the target video duration
184
- audio_downsample = getattr(getattr(audio_encoder, "patchifier", None), "audio_latent_downsample_factor", 4)
185
- target_shape = AudioLatentShape.from_video_pixel_shape(
186
- VideoPixelShape(batch=audio_latent.shape[0], frames=int(num_frames), width=1, height=1, fps=float(fps)),
187
- channels=audio_latent.shape[1],
188
- mel_bins=audio_latent.shape[3],
189
- sample_rate=target_sr,
190
- hop_length=mel_hop,
191
- audio_latent_downsample_factor=audio_downsample,
192
- )
193
- target_frames = int(target_shape.frames)
194
-
195
- if audio_latent.shape[2] < target_frames:
196
- pad_frames = target_frames - audio_latent.shape[2]
197
- pad = torch.zeros(
198
- (audio_latent.shape[0], audio_latent.shape[1], pad_frames, audio_latent.shape[3]),
199
- device=audio_latent.device,
200
- dtype=audio_latent.dtype,
201
- )
202
- audio_latent = torch.cat([audio_latent, pad], dim=2)
203
- elif audio_latent.shape[2] > target_frames:
204
- audio_latent = audio_latent[:, :, :target_frames, :]
205
-
206
- # Move latent to pipeline device/dtype for conditioning object
207
- audio_latent = audio_latent.to(device=self.device, dtype=self.dtype)
208
-
209
- return [AudioConditionByLatent(audio_latent, strength)]
210
-
211
- def _prepare_output_waveform(
212
- self,
213
- input_waveform: torch.Tensor,
214
- input_sample_rate: int,
215
- target_sample_rate: int,
216
- num_frames: int,
217
- fps: float,
218
- ) -> torch.Tensor:
219
- """
220
- Returns waveform on CPU, float32, resampled to target_sample_rate and
221
- trimmed/padded to match video duration.
222
- Output shape: (T,) for mono or (C, T) for multi-channel.
223
- """
224
- wav = input_waveform
225
-
226
- # Accept (T,), (C,T), (B,C,T)
227
- if wav.ndim == 3:
228
- wav = wav[0]
229
- elif wav.ndim == 2:
230
- pass
231
- elif wav.ndim == 1:
232
- wav = wav.unsqueeze(0)
233
- else:
234
- raise ValueError(f"input_waveform must be 1D/2D/3D, got {tuple(wav.shape)}")
235
-
236
- # Now wav is (C, T)
237
- wav = wav.detach().to("cpu", dtype=torch.float32)
238
-
239
- # Resample if needed
240
- if int(input_sample_rate) != int(target_sample_rate):
241
- wav = torchaudio.functional.resample(wav, int(input_sample_rate), int(target_sample_rate))
242
-
243
- # Match video duration
244
- duration_sec = float(num_frames) / float(fps)
245
- target_len = int(round(duration_sec * float(target_sample_rate)))
246
-
247
- cur_len = int(wav.shape[-1])
248
- if cur_len > target_len:
249
- wav = wav[..., :target_len]
250
- elif cur_len < target_len:
251
- pad = target_len - cur_len
252
- wav = torch.nn.functional.pad(wav, (0, pad))
253
-
254
- # If mono, return (T,) for convenience
255
- if wav.shape[0] == 1:
256
- return wav[0]
257
- return wav
258
-
259
-
260
- @torch.inference_mode()
261
- def __call__(
262
- self,
263
- prompt: str,
264
- output_path: str,
265
- seed: int,
266
- height: int,
267
- width: int,
268
- num_frames: int,
269
- frame_rate: float,
270
- images: list[tuple[str, int, float]],
271
- video_conditioning: list[tuple[str, float]] | None = None,
272
- video_conditioning_frame_idx: int = 0,
273
- apply_video_conditioning_to_stage2: bool = False,
274
- tiling_config: TilingConfig | None = None,
275
- video_context: torch.Tensor | None = None,
276
- audio_context: torch.Tensor | None = None,
277
- input_waveform: torch.Tensor | None = None,
278
- input_waveform_sample_rate: int | None = None,
279
- audio_strength: float = 1.0, # or audio_scale, your naming
280
- ) -> None:
281
- generator = torch.Generator(device=self.device).manual_seed(seed)
282
- noiser = GaussianNoiser(generator=generator)
283
- stepper = EulerDiffusionStep()
284
- dtype = torch.bfloat16
285
-
286
- audio_conditionings = None
287
- if input_waveform is not None:
288
- if input_waveform_sample_rate is None:
289
- raise ValueError("input_waveform_sample_rate must be provided when input_waveform is set.")
290
- audio_conditionings = self._build_audio_conditionings_from_waveform(
291
- input_waveform=input_waveform,
292
- input_sample_rate=int(input_waveform_sample_rate),
293
- num_frames=num_frames,
294
- fps=frame_rate,
295
- strength=audio_strength,
296
- )
297
-
298
- # Use pre-computed embeddings if provided, otherwise encode text
299
- if video_context is None or audio_context is None:
300
- text_encoder = self.model_ledger.text_encoder()
301
- context_p = encode_text(text_encoder, prompts=[prompt])[0]
302
- video_context, audio_context = context_p
303
-
304
- torch.cuda.synchronize()
305
- del text_encoder
306
- utils.cleanup_memory()
307
- else:
308
- # Move pre-computed embeddings to device if needed
309
- video_context = video_context.to(self.device)
310
- audio_context = audio_context.to(self.device)
311
-
312
- # Stage 1: Initial low resolution video generation.
313
- # Load models only if not already cached
314
- if self._video_encoder is None:
315
- self._video_encoder = self.model_ledger.video_encoder()
316
- video_encoder = self._video_encoder
317
-
318
- if self._transformer is None:
319
- self._transformer = self.model_ledger.transformer()
320
- transformer = self._transformer
321
- stage_1_sigmas = torch.Tensor(DISTILLED_SIGMA_VALUES).to(self.device)
322
-
323
- def denoising_loop(
324
- sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
325
- ) -> tuple[LatentState, LatentState]:
326
- return euler_denoising_loop(
327
- sigmas=sigmas,
328
- video_state=video_state,
329
- audio_state=audio_state,
330
- stepper=stepper,
331
- denoise_fn=simple_denoising_func(
332
- video_context=video_context,
333
- audio_context=audio_context,
334
- transformer=transformer, # noqa: F821
335
- ),
336
- )
337
-
338
- stage_1_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
339
- stage_1_conditionings = self._create_conditionings(
340
- images=images,
341
- video_conditioning=video_conditioning,
342
- height=stage_1_output_shape.height,
343
- width=stage_1_output_shape.width,
344
- num_frames=num_frames,
345
- video_encoder=video_encoder,
346
- video_conditioning_frame_idx=video_conditioning_frame_idx,
347
- dtype=dtype,
348
- )
349
-
350
- video_state, audio_state = denoise_audio_video(
351
- output_shape=stage_1_output_shape,
352
- conditionings=stage_1_conditionings,
353
- audio_conditionings=audio_conditionings,
354
- noiser=noiser,
355
- sigmas=stage_1_sigmas,
356
- stepper=stepper,
357
- denoising_loop_fn=denoising_loop,
358
- components=self.pipeline_components,
359
- dtype=dtype,
360
- device=self.device,
361
- )
362
-
363
- # Stage 2: Upsample and refine the video at higher resolution with distilled LORA.
364
- upscaled_video_latent = upsample_video(
365
- latent=video_state.latent[:1], video_encoder=video_encoder, upsampler=self.model_ledger.spatial_upsampler()
366
- )
367
-
368
- torch.cuda.synchronize()
369
- cleanup_memory()
370
-
371
- stage_2_sigmas = torch.Tensor(STAGE_2_DISTILLED_SIGMA_VALUES).to(self.device)
372
- stage_2_output_shape = VideoPixelShape(
373
- batch=1, frames=num_frames, width=width * 2, height=height * 2, fps=frame_rate
374
- )
375
- if apply_video_conditioning_to_stage2:
376
- stage_2_conditionings = self._create_conditionings(
377
- images=images,
378
- video_conditioning=video_conditioning,
379
- height=stage_2_output_shape.height,
380
- width=stage_2_output_shape.width,
381
- num_frames=num_frames,
382
- video_encoder=video_encoder,
383
- video_conditioning_frame_idx=video_conditioning_frame_idx,
384
- )
385
- else:
386
- stage_2_conditionings = image_conditionings_by_replacing_latent(
387
- images=images,
388
- height=stage_2_output_shape.height,
389
- width=stage_2_output_shape.width,
390
- video_encoder=video_encoder,
391
- dtype=dtype,
392
- device=self.device,
393
- )
394
- video_chunks_number = get_video_chunks_number(num_frames, tiling_config)
395
- video_state, audio_state = denoise_audio_video(
396
- output_shape=stage_2_output_shape,
397
- conditionings=stage_2_conditionings,
398
- audio_conditionings=audio_conditionings,
399
- noiser=noiser,
400
- sigmas=stage_2_sigmas,
401
- stepper=stepper,
402
- denoising_loop_fn=denoising_loop,
403
- components=self.pipeline_components,
404
- dtype=dtype,
405
- device=self.device,
406
- noise_scale=stage_2_sigmas[0],
407
- initial_video_latent=upscaled_video_latent,
408
- initial_audio_latent=audio_state.latent,
409
- )
410
-
411
- torch.cuda.synchronize()
412
- # del transformer
413
- # del video_encoder
414
- # utils.cleanup_memory()
415
-
416
- decoded_video = vae_decode_video(video_state.latent, self.model_ledger.video_decoder(), tiling_config)
417
- decoded_audio = vae_decode_audio(audio_state.latent, self.model_ledger.audio_decoder(), self.model_ledger.vocoder())
418
-
419
- encode_video(
420
- video=decoded_video,
421
- fps=frame_rate,
422
- audio=decoded_audio,
423
- audio_sample_rate=AUDIO_SAMPLE_RATE,
424
- output_path=output_path,
425
- video_chunks_number=video_chunks_number,
426
- )
427
-
428
-
429
- def _create_conditionings(
430
- self,
431
- images: list[tuple[str, int, float]],
432
- video_conditioning: list[tuple[str, float]] | None,
433
- height: int,
434
- width: int,
435
- num_frames: int,
436
- video_encoder,
437
- video_conditioning_frame_idx: int,
438
- dtype: torch.dtype,
439
- ):
440
- # 1) Keep ORIGINAL behavior: image conditioning by replacing latent
441
- conditionings = image_conditionings_by_replacing_latent(
442
- images=images,
443
- height=height,
444
- width=width,
445
- video_encoder=video_encoder,
446
- dtype=dtype,
447
- device=self.device,
448
- )
449
-
450
- # 2) Optional: add video conditioning (IC-LoRA style)
451
- if not video_conditioning:
452
- return conditionings
453
-
454
- for video_path, strength in video_conditioning:
455
- video = load_video_conditioning(
456
- video_path=video_path,
457
- height=height,
458
- width=width,
459
- frame_cap=num_frames, # ✅ correct kwarg name
460
- dtype=dtype,
461
- device=self.device,
462
- )
463
-
464
- encoded_video = video_encoder(video)
465
-
466
- # ✅ match IC-LoRA: append the conditioning object directly
467
- conditionings.append(
468
- VideoConditionByKeyframeIndex(
469
- keyframes=encoded_video,
470
- frame_idx=video_conditioning_frame_idx,
471
- strength=strength,
472
- )
473
- )
474
-
475
- return conditionings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
packages/ltx-pipelines/ltx_pipelines/ic_lora.py DELETED
@@ -1,309 +0,0 @@
1
- import logging
2
- from collections.abc import Iterator
3
-
4
- import torch
5
-
6
- from ltx_core.components.diffusion_steps import EulerDiffusionStep
7
- from ltx_core.components.noisers import GaussianNoiser
8
- from ltx_core.components.protocols import DiffusionStepProtocol
9
- from ltx_core.conditioning import ConditioningItem, VideoConditionByKeyframeIndex
10
- from ltx_core.loader import LoraPathStrengthAndSDOps
11
- from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
12
- from ltx_core.model.upsampler import upsample_video
13
- from ltx_core.model.video_vae import TilingConfig, VideoEncoder, get_video_chunks_number
14
- from ltx_core.model.video_vae import decode_video as vae_decode_video
15
- from ltx_core.text_encoders.gemma import encode_text
16
- from ltx_core.types import LatentState, VideoPixelShape
17
- from ltx_pipelines.utils import ModelLedger
18
- from ltx_pipelines.utils.args import VideoConditioningAction, default_2_stage_distilled_arg_parser
19
- from ltx_pipelines.utils.constants import (
20
- AUDIO_SAMPLE_RATE,
21
- DISTILLED_SIGMA_VALUES,
22
- STAGE_2_DISTILLED_SIGMA_VALUES,
23
- )
24
- from ltx_pipelines.utils.helpers import (
25
- assert_resolution,
26
- cleanup_memory,
27
- denoise_audio_video,
28
- euler_denoising_loop,
29
- generate_enhanced_prompt,
30
- get_device,
31
- image_conditionings_by_replacing_latent,
32
- simple_denoising_func,
33
- )
34
- from ltx_pipelines.utils.media_io import encode_video, load_video_conditioning
35
- from ltx_pipelines.utils.types import PipelineComponents
36
-
37
- device = get_device()
38
-
39
-
40
- class ICLoraPipeline:
41
- """
42
- Two-stage video generation pipeline with In-Context (IC) LoRA support.
43
- Allows conditioning the generated video on control signals such as depth maps,
44
- human pose, or image edges via the video_conditioning parameter.
45
- The specific IC-LoRA model should be provided via the loras parameter.
46
- Stage 1 generates video at the target resolution, then Stage 2 upsamples
47
- by 2x and refines with additional denoising steps for higher quality output.
48
- """
49
-
50
- def __init__(
51
- self,
52
- checkpoint_path: str,
53
- spatial_upsampler_path: str,
54
- gemma_root: str,
55
- loras: list[LoraPathStrengthAndSDOps],
56
- device: torch.device = device,
57
- fp8transformer: bool = False,
58
- ):
59
- self.dtype = torch.bfloat16
60
- self.stage_1_model_ledger = ModelLedger(
61
- dtype=self.dtype,
62
- device=device,
63
- checkpoint_path=checkpoint_path,
64
- spatial_upsampler_path=spatial_upsampler_path,
65
- gemma_root_path=gemma_root,
66
- loras=loras,
67
- fp8transformer=fp8transformer,
68
- )
69
- self.stage_2_model_ledger = ModelLedger(
70
- dtype=self.dtype,
71
- device=device,
72
- checkpoint_path=checkpoint_path,
73
- spatial_upsampler_path=spatial_upsampler_path,
74
- gemma_root_path=gemma_root,
75
- loras=[],
76
- fp8transformer=fp8transformer,
77
- )
78
- self.pipeline_components = PipelineComponents(
79
- dtype=self.dtype,
80
- device=device,
81
- )
82
- self.device = device
83
-
84
- @torch.inference_mode()
85
- def __call__(
86
- self,
87
- prompt: str,
88
- seed: int,
89
- height: int,
90
- width: int,
91
- num_frames: int,
92
- frame_rate: float,
93
- images: list[tuple[str, int, float]],
94
- video_conditioning: list[tuple[str, float]],
95
- enhance_prompt: bool = False,
96
- tiling_config: TilingConfig | None = None,
97
- ) -> tuple[Iterator[torch.Tensor], torch.Tensor]:
98
- assert_resolution(height=height, width=width, is_two_stage=True)
99
-
100
- generator = torch.Generator(device=self.device).manual_seed(seed)
101
- noiser = GaussianNoiser(generator=generator)
102
- stepper = EulerDiffusionStep()
103
- dtype = torch.bfloat16
104
-
105
- text_encoder = self.stage_1_model_ledger.text_encoder()
106
-
107
- if enhance_prompt:
108
- prompt = generate_enhanced_prompt(
109
- text_encoder, prompt, images[0][0] if len(images) > 0 else None, seed=seed
110
- )
111
- video_context, audio_context = encode_text(text_encoder, prompts=[prompt])[0]
112
-
113
- torch.cuda.synchronize()
114
- del text_encoder
115
- cleanup_memory()
116
-
117
- # Stage 1: Initial low resolution video generation.
118
- video_encoder = self.stage_1_model_ledger.video_encoder()
119
- transformer = self.stage_1_model_ledger.transformer()
120
- stage_1_sigmas = torch.Tensor(DISTILLED_SIGMA_VALUES).to(self.device)
121
-
122
- def first_stage_denoising_loop(
123
- sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
124
- ) -> tuple[LatentState, LatentState]:
125
- return euler_denoising_loop(
126
- sigmas=sigmas,
127
- video_state=video_state,
128
- audio_state=audio_state,
129
- stepper=stepper,
130
- denoise_fn=simple_denoising_func(
131
- video_context=video_context,
132
- audio_context=audio_context,
133
- transformer=transformer, # noqa: F821
134
- ),
135
- )
136
-
137
- stage_1_output_shape = VideoPixelShape(
138
- batch=1,
139
- frames=num_frames,
140
- width=width // 2,
141
- height=height // 2,
142
- fps=frame_rate,
143
- )
144
- stage_1_conditionings = self._create_conditionings(
145
- images=images,
146
- video_conditioning=video_conditioning,
147
- height=stage_1_output_shape.height,
148
- width=stage_1_output_shape.width,
149
- video_encoder=video_encoder,
150
- num_frames=num_frames,
151
- )
152
- video_state, audio_state = denoise_audio_video(
153
- output_shape=stage_1_output_shape,
154
- conditionings=stage_1_conditionings,
155
- noiser=noiser,
156
- sigmas=stage_1_sigmas,
157
- stepper=stepper,
158
- denoising_loop_fn=first_stage_denoising_loop,
159
- components=self.pipeline_components,
160
- dtype=dtype,
161
- device=self.device,
162
- )
163
-
164
- torch.cuda.synchronize()
165
- del transformer
166
- cleanup_memory()
167
-
168
- # Stage 2: Upsample and refine the video at higher resolution with distilled LORA.
169
- upscaled_video_latent = upsample_video(
170
- latent=video_state.latent[:1],
171
- video_encoder=video_encoder,
172
- upsampler=self.stage_2_model_ledger.spatial_upsampler(),
173
- )
174
-
175
- torch.cuda.synchronize()
176
- cleanup_memory()
177
-
178
- transformer = self.stage_2_model_ledger.transformer()
179
- distilled_sigmas = torch.Tensor(STAGE_2_DISTILLED_SIGMA_VALUES).to(self.device)
180
-
181
- def second_stage_denoising_loop(
182
- sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
183
- ) -> tuple[LatentState, LatentState]:
184
- return euler_denoising_loop(
185
- sigmas=sigmas,
186
- video_state=video_state,
187
- audio_state=audio_state,
188
- stepper=stepper,
189
- denoise_fn=simple_denoising_func(
190
- video_context=video_context,
191
- audio_context=audio_context,
192
- transformer=transformer, # noqa: F821
193
- ),
194
- )
195
-
196
- stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
197
- stage_2_conditionings = image_conditionings_by_replacing_latent(
198
- images=images,
199
- height=stage_2_output_shape.height,
200
- width=stage_2_output_shape.width,
201
- video_encoder=video_encoder,
202
- dtype=self.dtype,
203
- device=self.device,
204
- )
205
-
206
- video_state, audio_state = denoise_audio_video(
207
- output_shape=stage_2_output_shape,
208
- conditionings=stage_2_conditionings,
209
- noiser=noiser,
210
- sigmas=distilled_sigmas,
211
- stepper=stepper,
212
- denoising_loop_fn=second_stage_denoising_loop,
213
- components=self.pipeline_components,
214
- dtype=dtype,
215
- device=self.device,
216
- noise_scale=distilled_sigmas[0],
217
- initial_video_latent=upscaled_video_latent,
218
- initial_audio_latent=audio_state.latent,
219
- )
220
-
221
- torch.cuda.synchronize()
222
- del transformer
223
- del video_encoder
224
- cleanup_memory()
225
-
226
- decoded_video = vae_decode_video(video_state.latent, self.stage_2_model_ledger.video_decoder(), tiling_config)
227
- decoded_audio = vae_decode_audio(
228
- audio_state.latent, self.stage_2_model_ledger.audio_decoder(), self.stage_2_model_ledger.vocoder()
229
- )
230
- return decoded_video, decoded_audio
231
-
232
- def _create_conditionings(
233
- self,
234
- images: list[tuple[str, int, float]],
235
- video_conditioning: list[tuple[str, float]],
236
- height: int,
237
- width: int,
238
- num_frames: int,
239
- video_encoder: VideoEncoder,
240
- ) -> list[ConditioningItem]:
241
- conditionings = image_conditionings_by_replacing_latent(
242
- images=images,
243
- height=height,
244
- width=width,
245
- video_encoder=video_encoder,
246
- dtype=self.dtype,
247
- device=self.device,
248
- )
249
-
250
- for video_path, strength in video_conditioning:
251
- video = load_video_conditioning(
252
- video_path=video_path,
253
- height=height,
254
- width=width,
255
- frame_cap=num_frames,
256
- dtype=self.dtype,
257
- device=self.device,
258
- )
259
- encoded_video = video_encoder(video)
260
- conditionings.append(VideoConditionByKeyframeIndex(keyframes=encoded_video, frame_idx=0, strength=strength))
261
-
262
- return conditionings
263
-
264
-
265
- @torch.inference_mode()
266
- def main() -> None:
267
- logging.getLogger().setLevel(logging.INFO)
268
- parser = default_2_stage_distilled_arg_parser()
269
- parser.add_argument(
270
- "--video-conditioning",
271
- action=VideoConditioningAction,
272
- nargs=2,
273
- metavar=("PATH", "STRENGTH"),
274
- required=True,
275
- )
276
- args = parser.parse_args()
277
- pipeline = ICLoraPipeline(
278
- checkpoint_path=args.checkpoint_path,
279
- spatial_upsampler_path=args.spatial_upsampler_path,
280
- gemma_root=args.gemma_root,
281
- loras=args.lora,
282
- fp8transformer=args.enable_fp8,
283
- )
284
- tiling_config = TilingConfig.default()
285
- video_chunks_number = get_video_chunks_number(args.num_frames, tiling_config)
286
- video, audio = pipeline(
287
- prompt=args.prompt,
288
- seed=args.seed,
289
- height=args.height,
290
- width=args.width,
291
- num_frames=args.num_frames,
292
- frame_rate=args.frame_rate,
293
- images=args.images,
294
- video_conditioning=args.video_conditioning,
295
- tiling_config=tiling_config,
296
- )
297
-
298
- encode_video(
299
- video=video,
300
- fps=args.frame_rate,
301
- audio=audio,
302
- audio_sample_rate=AUDIO_SAMPLE_RATE,
303
- output_path=args.output_path,
304
- video_chunks_number=video_chunks_number,
305
- )
306
-
307
-
308
- if __name__ == "__main__":
309
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
packages/ltx-pipelines/ltx_pipelines/keyframe_interpolation.py DELETED
@@ -1,273 +0,0 @@
1
- import logging
2
- from collections.abc import Iterator
3
-
4
- import torch
5
-
6
- from ltx_core.components.diffusion_steps import EulerDiffusionStep
7
- from ltx_core.components.guiders import CFGGuider
8
- from ltx_core.components.noisers import GaussianNoiser
9
- from ltx_core.components.protocols import DiffusionStepProtocol
10
- from ltx_core.components.schedulers import LTX2Scheduler
11
- from ltx_core.loader import LoraPathStrengthAndSDOps
12
- from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
13
- from ltx_core.model.upsampler import upsample_video
14
- from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
15
- from ltx_core.model.video_vae import decode_video as vae_decode_video
16
- from ltx_core.text_encoders.gemma import encode_text
17
- from ltx_core.types import LatentState, VideoPixelShape
18
- from ltx_pipelines.utils import ModelLedger
19
- from ltx_pipelines.utils.args import default_2_stage_arg_parser
20
- from ltx_pipelines.utils.constants import (
21
- AUDIO_SAMPLE_RATE,
22
- STAGE_2_DISTILLED_SIGMA_VALUES,
23
- )
24
- from ltx_pipelines.utils.helpers import (
25
- assert_resolution,
26
- cleanup_memory,
27
- denoise_audio_video,
28
- euler_denoising_loop,
29
- generate_enhanced_prompt,
30
- get_device,
31
- guider_denoising_func,
32
- image_conditionings_by_adding_guiding_latent,
33
- simple_denoising_func,
34
- )
35
- from ltx_pipelines.utils.media_io import encode_video
36
- from ltx_pipelines.utils.types import PipelineComponents
37
-
38
- device = get_device()
39
-
40
-
41
- class KeyframeInterpolationPipeline:
42
- """
43
- Keyframe-based Two-stage video interpolation pipeline.
44
- Interpolates between keyframes to generate a video with smoother transitions.
45
- Stage 1 generates video at the target resolution, then Stage 2 upsamples
46
- by 2x and refines with additional denoising steps for higher quality output.
47
- """
48
-
49
- def __init__(
50
- self,
51
- checkpoint_path: str,
52
- distilled_lora: list[LoraPathStrengthAndSDOps],
53
- spatial_upsampler_path: str,
54
- gemma_root: str,
55
- loras: list[LoraPathStrengthAndSDOps],
56
- device: torch.device = device,
57
- fp8transformer: bool = False,
58
- ):
59
- self.device = device
60
- self.dtype = torch.bfloat16
61
- self.stage_1_model_ledger = ModelLedger(
62
- dtype=self.dtype,
63
- device=device,
64
- checkpoint_path=checkpoint_path,
65
- spatial_upsampler_path=spatial_upsampler_path,
66
- gemma_root_path=gemma_root,
67
- loras=loras,
68
- fp8transformer=fp8transformer,
69
- )
70
- self.stage_2_model_ledger = self.stage_1_model_ledger.with_loras(
71
- loras=distilled_lora,
72
- )
73
- self.pipeline_components = PipelineComponents(
74
- dtype=self.dtype,
75
- device=device,
76
- )
77
-
78
- @torch.inference_mode()
79
- def __call__( # noqa: PLR0913
80
- self,
81
- prompt: str,
82
- negative_prompt: str,
83
- seed: int,
84
- height: int,
85
- width: int,
86
- num_frames: int,
87
- frame_rate: float,
88
- num_inference_steps: int,
89
- cfg_guidance_scale: float,
90
- images: list[tuple[str, int, float]],
91
- tiling_config: TilingConfig | None = None,
92
- enhance_prompt: bool = False,
93
- ) -> tuple[Iterator[torch.Tensor], torch.Tensor]:
94
- assert_resolution(height=height, width=width, is_two_stage=True)
95
-
96
- generator = torch.Generator(device=self.device).manual_seed(seed)
97
- noiser = GaussianNoiser(generator=generator)
98
- stepper = EulerDiffusionStep()
99
- cfg_guider = CFGGuider(cfg_guidance_scale)
100
- dtype = torch.bfloat16
101
-
102
- text_encoder = self.stage_1_model_ledger.text_encoder()
103
- if enhance_prompt:
104
- prompt = generate_enhanced_prompt(
105
- text_encoder, prompt, images[0][0] if len(images) > 0 else None, seed=seed
106
- )
107
- context_p, context_n = encode_text(text_encoder, prompts=[prompt, negative_prompt])
108
- v_context_p, a_context_p = context_p
109
- v_context_n, a_context_n = context_n
110
-
111
- torch.cuda.synchronize()
112
- del text_encoder
113
- cleanup_memory()
114
-
115
- # Stage 1: Initial low resolution video generation.
116
- video_encoder = self.stage_1_model_ledger.video_encoder()
117
- transformer = self.stage_1_model_ledger.transformer()
118
- sigmas = LTX2Scheduler().execute(steps=num_inference_steps).to(dtype=torch.float32, device=self.device)
119
-
120
- def first_stage_denoising_loop(
121
- sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
122
- ) -> tuple[LatentState, LatentState]:
123
- return euler_denoising_loop(
124
- sigmas=sigmas,
125
- video_state=video_state,
126
- audio_state=audio_state,
127
- stepper=stepper,
128
- denoise_fn=guider_denoising_func(
129
- cfg_guider,
130
- v_context_p,
131
- v_context_n,
132
- a_context_p,
133
- a_context_n,
134
- transformer=transformer, # noqa: F821
135
- ),
136
- )
137
-
138
- stage_1_output_shape = VideoPixelShape(
139
- batch=1,
140
- frames=num_frames,
141
- width=width // 2,
142
- height=height // 2,
143
- fps=frame_rate,
144
- )
145
- stage_1_conditionings = image_conditionings_by_adding_guiding_latent(
146
- images=images,
147
- height=stage_1_output_shape.height,
148
- width=stage_1_output_shape.width,
149
- video_encoder=video_encoder,
150
- dtype=dtype,
151
- device=self.device,
152
- )
153
- video_state, audio_state = denoise_audio_video(
154
- output_shape=stage_1_output_shape,
155
- conditionings=stage_1_conditionings,
156
- noiser=noiser,
157
- sigmas=sigmas,
158
- stepper=stepper,
159
- denoising_loop_fn=first_stage_denoising_loop,
160
- components=self.pipeline_components,
161
- dtype=dtype,
162
- device=self.device,
163
- )
164
-
165
- torch.cuda.synchronize()
166
- del transformer
167
- cleanup_memory()
168
-
169
- # Stage 2: Upsample and refine the video at higher resolution with distilled LORA.
170
- upscaled_video_latent = upsample_video(
171
- latent=video_state.latent[:1],
172
- video_encoder=video_encoder,
173
- upsampler=self.stage_2_model_ledger.spatial_upsampler(),
174
- )
175
-
176
- torch.cuda.synchronize()
177
- cleanup_memory()
178
-
179
- transformer = self.stage_2_model_ledger.transformer()
180
- distilled_sigmas = torch.Tensor(STAGE_2_DISTILLED_SIGMA_VALUES).to(self.device)
181
-
182
- def second_stage_denoising_loop(
183
- sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
184
- ) -> tuple[LatentState, LatentState]:
185
- return euler_denoising_loop(
186
- sigmas=sigmas,
187
- video_state=video_state,
188
- audio_state=audio_state,
189
- stepper=stepper,
190
- denoise_fn=simple_denoising_func(
191
- video_context=v_context_p,
192
- audio_context=a_context_p,
193
- transformer=transformer, # noqa: F821
194
- ),
195
- )
196
-
197
- stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
198
- stage_2_conditionings = image_conditionings_by_adding_guiding_latent(
199
- images=images,
200
- height=stage_2_output_shape.height,
201
- width=stage_2_output_shape.width,
202
- video_encoder=video_encoder,
203
- dtype=dtype,
204
- device=self.device,
205
- )
206
- video_state, audio_state = denoise_audio_video(
207
- output_shape=stage_2_output_shape,
208
- conditionings=stage_2_conditionings,
209
- noiser=noiser,
210
- sigmas=distilled_sigmas,
211
- stepper=stepper,
212
- denoising_loop_fn=second_stage_denoising_loop,
213
- components=self.pipeline_components,
214
- dtype=dtype,
215
- device=self.device,
216
- noise_scale=distilled_sigmas[0],
217
- initial_video_latent=upscaled_video_latent,
218
- initial_audio_latent=audio_state.latent,
219
- )
220
-
221
- torch.cuda.synchronize()
222
- del transformer
223
- del video_encoder
224
- cleanup_memory()
225
-
226
- decoded_video = vae_decode_video(video_state.latent, self.stage_2_model_ledger.video_decoder(), tiling_config)
227
- decoded_audio = vae_decode_audio(
228
- audio_state.latent, self.stage_2_model_ledger.audio_decoder(), self.stage_2_model_ledger.vocoder()
229
- )
230
- return decoded_video, decoded_audio
231
-
232
-
233
- @torch.inference_mode()
234
- def main() -> None:
235
- logging.getLogger().setLevel(logging.INFO)
236
- parser = default_2_stage_arg_parser()
237
- args = parser.parse_args()
238
- pipeline = KeyframeInterpolationPipeline(
239
- checkpoint_path=args.checkpoint_path,
240
- distilled_lora=args.distilled_lora,
241
- spatial_upsampler_path=args.spatial_upsampler_path,
242
- gemma_root=args.gemma_root,
243
- loras=args.lora,
244
- fp8transformer=args.enable_fp8,
245
- )
246
- tiling_config = TilingConfig.default()
247
- video_chunks_number = get_video_chunks_number(args.num_frames, tiling_config)
248
- video, audio = pipeline(
249
- prompt=args.prompt,
250
- negative_prompt=args.negative_prompt,
251
- seed=args.seed,
252
- height=args.height,
253
- width=args.width,
254
- num_frames=args.num_frames,
255
- frame_rate=args.frame_rate,
256
- num_inference_steps=args.num_inference_steps,
257
- cfg_guidance_scale=args.cfg_guidance_scale,
258
- images=args.images,
259
- tiling_config=tiling_config,
260
- )
261
-
262
- encode_video(
263
- video=video,
264
- fps=args.frame_rate,
265
- audio=audio,
266
- audio_sample_rate=AUDIO_SAMPLE_RATE,
267
- output_path=args.output_path,
268
- video_chunks_number=video_chunks_number,
269
- )
270
-
271
-
272
- if __name__ == "__main__":
273
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
packages/ltx-pipelines/ltx_pipelines/ti2vid_one_stage.py DELETED
@@ -1,193 +0,0 @@
1
- import logging
2
- from collections.abc import Iterator
3
-
4
- import torch
5
-
6
- from ltx_core.components.diffusion_steps import EulerDiffusionStep
7
- from ltx_core.components.guiders import CFGGuider
8
- from ltx_core.components.noisers import GaussianNoiser
9
- from ltx_core.components.protocols import DiffusionStepProtocol
10
- from ltx_core.components.schedulers import LTX2Scheduler
11
- from ltx_core.loader import LoraPathStrengthAndSDOps
12
- from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
13
- from ltx_core.model.video_vae import decode_video as vae_decode_video
14
- from ltx_core.text_encoders.gemma import encode_text
15
- from ltx_core.types import LatentState, VideoPixelShape
16
- from ltx_pipelines.utils import ModelLedger
17
- from ltx_pipelines.utils.args import default_1_stage_arg_parser
18
- from ltx_pipelines.utils.constants import AUDIO_SAMPLE_RATE
19
- from ltx_pipelines.utils.helpers import (
20
- assert_resolution,
21
- cleanup_memory,
22
- denoise_audio_video,
23
- euler_denoising_loop,
24
- generate_enhanced_prompt,
25
- get_device,
26
- guider_denoising_func,
27
- image_conditionings_by_replacing_latent,
28
- )
29
- from ltx_pipelines.utils.media_io import encode_video
30
- from ltx_pipelines.utils.types import PipelineComponents
31
-
32
- device = get_device()
33
-
34
-
35
- class TI2VidOneStagePipeline:
36
- """
37
- Single-stage text/image-to-video generation pipeline.
38
- Generates video at the target resolution in a single diffusion pass with
39
- classifier-free guidance (CFG). Supports optional image conditioning via
40
- the images parameter.
41
- """
42
-
43
- def __init__(
44
- self,
45
- checkpoint_path: str,
46
- gemma_root: str,
47
- loras: list[LoraPathStrengthAndSDOps],
48
- device: torch.device = device,
49
- fp8transformer: bool = False,
50
- ):
51
- self.dtype = torch.bfloat16
52
- self.device = device
53
- self.model_ledger = ModelLedger(
54
- dtype=self.dtype,
55
- device=device,
56
- checkpoint_path=checkpoint_path,
57
- gemma_root_path=gemma_root,
58
- loras=loras,
59
- fp8transformer=fp8transformer,
60
- )
61
- self.pipeline_components = PipelineComponents(
62
- dtype=self.dtype,
63
- device=device,
64
- )
65
-
66
- def __call__( # noqa: PLR0913
67
- self,
68
- prompt: str,
69
- negative_prompt: str,
70
- seed: int,
71
- height: int,
72
- width: int,
73
- num_frames: int,
74
- frame_rate: float,
75
- num_inference_steps: int,
76
- cfg_guidance_scale: float,
77
- images: list[tuple[str, int, float]],
78
- enhance_prompt: bool = False,
79
- ) -> tuple[Iterator[torch.Tensor], torch.Tensor]:
80
- assert_resolution(height=height, width=width, is_two_stage=False)
81
-
82
- generator = torch.Generator(device=self.device).manual_seed(seed)
83
- noiser = GaussianNoiser(generator=generator)
84
- stepper = EulerDiffusionStep()
85
- cfg_guider = CFGGuider(cfg_guidance_scale)
86
- dtype = torch.bfloat16
87
-
88
- text_encoder = self.model_ledger.text_encoder()
89
- if enhance_prompt:
90
- prompt = generate_enhanced_prompt(
91
- text_encoder, prompt, images[0][0] if len(images) > 0 else None, seed=seed
92
- )
93
- context_p, context_n = encode_text(text_encoder, prompts=[prompt, negative_prompt])
94
- v_context_p, a_context_p = context_p
95
- v_context_n, a_context_n = context_n
96
-
97
- torch.cuda.synchronize()
98
- del text_encoder
99
- cleanup_memory()
100
-
101
- # Stage 1: Initial low resolution video generation.
102
- video_encoder = self.model_ledger.video_encoder()
103
- transformer = self.model_ledger.transformer()
104
- sigmas = LTX2Scheduler().execute(steps=num_inference_steps).to(dtype=torch.float32, device=self.device)
105
-
106
- def first_stage_denoising_loop(
107
- sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
108
- ) -> tuple[LatentState, LatentState]:
109
- return euler_denoising_loop(
110
- sigmas=sigmas,
111
- video_state=video_state,
112
- audio_state=audio_state,
113
- stepper=stepper,
114
- denoise_fn=guider_denoising_func(
115
- cfg_guider,
116
- v_context_p,
117
- v_context_n,
118
- a_context_p,
119
- a_context_n,
120
- transformer=transformer, # noqa: F821
121
- ),
122
- )
123
-
124
- stage_1_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
125
- stage_1_conditionings = image_conditionings_by_replacing_latent(
126
- images=images,
127
- height=stage_1_output_shape.height,
128
- width=stage_1_output_shape.width,
129
- video_encoder=video_encoder,
130
- dtype=dtype,
131
- device=self.device,
132
- )
133
-
134
- video_state, audio_state = denoise_audio_video(
135
- output_shape=stage_1_output_shape,
136
- conditionings=stage_1_conditionings,
137
- noiser=noiser,
138
- sigmas=sigmas,
139
- stepper=stepper,
140
- denoising_loop_fn=first_stage_denoising_loop,
141
- components=self.pipeline_components,
142
- dtype=dtype,
143
- device=self.device,
144
- )
145
-
146
- torch.cuda.synchronize()
147
- del transformer
148
- cleanup_memory()
149
-
150
- decoded_video = vae_decode_video(video_state.latent, self.model_ledger.video_decoder())
151
- decoded_audio = vae_decode_audio(
152
- audio_state.latent, self.model_ledger.audio_decoder(), self.model_ledger.vocoder()
153
- )
154
-
155
- return decoded_video, decoded_audio
156
-
157
-
158
- @torch.inference_mode()
159
- def main() -> None:
160
- logging.getLogger().setLevel(logging.INFO)
161
- parser = default_1_stage_arg_parser()
162
- args = parser.parse_args()
163
- pipeline = TI2VidOneStagePipeline(
164
- checkpoint_path=args.checkpoint_path,
165
- gemma_root=args.gemma_root,
166
- loras=args.lora,
167
- fp8transformer=args.enable_fp8,
168
- )
169
- video, audio = pipeline(
170
- prompt=args.prompt,
171
- negative_prompt=args.negative_prompt,
172
- seed=args.seed,
173
- height=args.height,
174
- width=args.width,
175
- num_frames=args.num_frames,
176
- frame_rate=args.frame_rate,
177
- num_inference_steps=args.num_inference_steps,
178
- cfg_guidance_scale=args.cfg_guidance_scale,
179
- images=args.images,
180
- )
181
-
182
- encode_video(
183
- video=video,
184
- fps=args.frame_rate,
185
- audio=audio,
186
- audio_sample_rate=AUDIO_SAMPLE_RATE,
187
- output_path=args.output_path,
188
- video_chunks_number=1,
189
- )
190
-
191
-
192
- if __name__ == "__main__":
193
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
packages/ltx-pipelines/ltx_pipelines/ti2vid_two_stages.py DELETED
@@ -1,276 +0,0 @@
1
- import logging
2
- from collections.abc import Iterator
3
-
4
- import torch
5
-
6
- from ltx_core.components.diffusion_steps import EulerDiffusionStep
7
- from ltx_core.components.guiders import CFGGuider
8
- from ltx_core.components.noisers import GaussianNoiser
9
- from ltx_core.components.protocols import DiffusionStepProtocol
10
- from ltx_core.components.schedulers import LTX2Scheduler
11
- from ltx_core.loader import LoraPathStrengthAndSDOps
12
- from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
13
- from ltx_core.model.upsampler import upsample_video
14
- from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
15
- from ltx_core.model.video_vae import decode_video as vae_decode_video
16
- from ltx_core.text_encoders.gemma import encode_text
17
- from ltx_core.types import LatentState, VideoPixelShape
18
- from ltx_pipelines.utils import ModelLedger
19
- from ltx_pipelines.utils.args import default_2_stage_arg_parser
20
- from ltx_pipelines.utils.constants import (
21
- AUDIO_SAMPLE_RATE,
22
- STAGE_2_DISTILLED_SIGMA_VALUES,
23
- )
24
- from ltx_pipelines.utils.helpers import (
25
- assert_resolution,
26
- cleanup_memory,
27
- denoise_audio_video,
28
- euler_denoising_loop,
29
- generate_enhanced_prompt,
30
- get_device,
31
- guider_denoising_func,
32
- image_conditionings_by_replacing_latent,
33
- simple_denoising_func,
34
- )
35
- from ltx_pipelines.utils.media_io import encode_video
36
- from ltx_pipelines.utils.types import PipelineComponents
37
-
38
- device = get_device()
39
-
40
-
41
- class TI2VidTwoStagesPipeline:
42
- """
43
- Two-stage text/image-to-video generation pipeline.
44
- Stage 1 generates video at the target resolution with CFG guidance, then
45
- Stage 2 upsamples by 2x and refines using a distilled LoRA for higher
46
- quality output. Supports optional image conditioning via the images parameter.
47
- """
48
-
49
- def __init__(
50
- self,
51
- checkpoint_path: str,
52
- distilled_lora: list[LoraPathStrengthAndSDOps],
53
- spatial_upsampler_path: str,
54
- gemma_root: str,
55
- loras: list[LoraPathStrengthAndSDOps],
56
- device: str = device,
57
- fp8transformer: bool = False,
58
- ):
59
- self.device = device
60
- self.dtype = torch.bfloat16
61
- self.stage_1_model_ledger = ModelLedger(
62
- dtype=self.dtype,
63
- device=device,
64
- checkpoint_path=checkpoint_path,
65
- gemma_root_path=gemma_root,
66
- spatial_upsampler_path=spatial_upsampler_path,
67
- loras=loras,
68
- fp8transformer=fp8transformer,
69
- )
70
-
71
- self.stage_2_model_ledger = self.stage_1_model_ledger.with_loras(
72
- loras=distilled_lora,
73
- )
74
-
75
- self.pipeline_components = PipelineComponents(
76
- dtype=self.dtype,
77
- device=device,
78
- )
79
-
80
- @torch.inference_mode()
81
- def __call__( # noqa: PLR0913
82
- self,
83
- prompt: str,
84
- negative_prompt: str,
85
- seed: int,
86
- height: int,
87
- width: int,
88
- num_frames: int,
89
- frame_rate: float,
90
- num_inference_steps: int,
91
- cfg_guidance_scale: float,
92
- images: list[tuple[str, int, float]],
93
- tiling_config: TilingConfig | None = None,
94
- enhance_prompt: bool = False,
95
- ) -> tuple[Iterator[torch.Tensor], torch.Tensor]:
96
- assert_resolution(height=height, width=width, is_two_stage=True)
97
-
98
- generator = torch.Generator(device=self.device).manual_seed(seed)
99
- noiser = GaussianNoiser(generator=generator)
100
- stepper = EulerDiffusionStep()
101
- cfg_guider = CFGGuider(cfg_guidance_scale)
102
- dtype = torch.bfloat16
103
-
104
- text_encoder = self.stage_1_model_ledger.text_encoder()
105
- if enhance_prompt:
106
- prompt = generate_enhanced_prompt(
107
- text_encoder, prompt, images[0][0] if len(images) > 0 else None, seed=seed
108
- )
109
- context_p, context_n = encode_text(text_encoder, prompts=[prompt, negative_prompt])
110
- v_context_p, a_context_p = context_p
111
- v_context_n, a_context_n = context_n
112
-
113
- torch.cuda.synchronize()
114
- del text_encoder
115
- cleanup_memory()
116
-
117
- # Stage 1: Initial low resolution video generation.
118
- video_encoder = self.stage_1_model_ledger.video_encoder()
119
- transformer = self.stage_1_model_ledger.transformer()
120
- sigmas = LTX2Scheduler().execute(steps=num_inference_steps).to(dtype=torch.float32, device=self.device)
121
-
122
- def first_stage_denoising_loop(
123
- sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
124
- ) -> tuple[LatentState, LatentState]:
125
- return euler_denoising_loop(
126
- sigmas=sigmas,
127
- video_state=video_state,
128
- audio_state=audio_state,
129
- stepper=stepper,
130
- denoise_fn=guider_denoising_func(
131
- cfg_guider,
132
- v_context_p,
133
- v_context_n,
134
- a_context_p,
135
- a_context_n,
136
- transformer=transformer, # noqa: F821
137
- ),
138
- )
139
-
140
- stage_1_output_shape = VideoPixelShape(
141
- batch=1,
142
- frames=num_frames,
143
- width=width // 2,
144
- height=height // 2,
145
- fps=frame_rate,
146
- )
147
- stage_1_conditionings = image_conditionings_by_replacing_latent(
148
- images=images,
149
- height=stage_1_output_shape.height,
150
- width=stage_1_output_shape.width,
151
- video_encoder=video_encoder,
152
- dtype=dtype,
153
- device=self.device,
154
- )
155
- video_state, audio_state = denoise_audio_video(
156
- output_shape=stage_1_output_shape,
157
- conditionings=stage_1_conditionings,
158
- noiser=noiser,
159
- sigmas=sigmas,
160
- stepper=stepper,
161
- denoising_loop_fn=first_stage_denoising_loop,
162
- components=self.pipeline_components,
163
- dtype=dtype,
164
- device=self.device,
165
- )
166
-
167
- torch.cuda.synchronize()
168
- del transformer
169
- cleanup_memory()
170
-
171
- # Stage 2: Upsample and refine the video at higher resolution with distilled LORA.
172
- upscaled_video_latent = upsample_video(
173
- latent=video_state.latent[:1],
174
- video_encoder=video_encoder,
175
- upsampler=self.stage_2_model_ledger.spatial_upsampler(),
176
- )
177
-
178
- torch.cuda.synchronize()
179
- cleanup_memory()
180
-
181
- transformer = self.stage_2_model_ledger.transformer()
182
- distilled_sigmas = torch.Tensor(STAGE_2_DISTILLED_SIGMA_VALUES).to(self.device)
183
-
184
- def second_stage_denoising_loop(
185
- sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
186
- ) -> tuple[LatentState, LatentState]:
187
- return euler_denoising_loop(
188
- sigmas=sigmas,
189
- video_state=video_state,
190
- audio_state=audio_state,
191
- stepper=stepper,
192
- denoise_fn=simple_denoising_func(
193
- video_context=v_context_p,
194
- audio_context=a_context_p,
195
- transformer=transformer, # noqa: F821
196
- ),
197
- )
198
-
199
- stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
200
- stage_2_conditionings = image_conditionings_by_replacing_latent(
201
- images=images,
202
- height=stage_2_output_shape.height,
203
- width=stage_2_output_shape.width,
204
- video_encoder=video_encoder,
205
- dtype=dtype,
206
- device=self.device,
207
- )
208
- video_state, audio_state = denoise_audio_video(
209
- output_shape=stage_2_output_shape,
210
- conditionings=stage_2_conditionings,
211
- noiser=noiser,
212
- sigmas=distilled_sigmas,
213
- stepper=stepper,
214
- denoising_loop_fn=second_stage_denoising_loop,
215
- components=self.pipeline_components,
216
- dtype=dtype,
217
- device=self.device,
218
- noise_scale=distilled_sigmas[0],
219
- initial_video_latent=upscaled_video_latent,
220
- initial_audio_latent=audio_state.latent,
221
- )
222
-
223
- torch.cuda.synchronize()
224
- del transformer
225
- del video_encoder
226
- cleanup_memory()
227
-
228
- decoded_video = vae_decode_video(video_state.latent, self.stage_2_model_ledger.video_decoder(), tiling_config)
229
- decoded_audio = vae_decode_audio(
230
- audio_state.latent, self.stage_2_model_ledger.audio_decoder(), self.stage_2_model_ledger.vocoder()
231
- )
232
-
233
- return decoded_video, decoded_audio
234
-
235
-
236
- @torch.inference_mode()
237
- def main() -> None:
238
- logging.getLogger().setLevel(logging.INFO)
239
- parser = default_2_stage_arg_parser()
240
- args = parser.parse_args()
241
- pipeline = TI2VidTwoStagesPipeline(
242
- checkpoint_path=args.checkpoint_path,
243
- distilled_lora=args.distilled_lora,
244
- spatial_upsampler_path=args.spatial_upsampler_path,
245
- gemma_root=args.gemma_root,
246
- loras=args.lora,
247
- fp8transformer=args.enable_fp8,
248
- )
249
- tiling_config = TilingConfig.default()
250
- video_chunks_number = get_video_chunks_number(args.num_frames, tiling_config)
251
- video, audio = pipeline(
252
- prompt=args.prompt,
253
- negative_prompt=args.negative_prompt,
254
- seed=args.seed,
255
- height=args.height,
256
- width=args.width,
257
- num_frames=args.num_frames,
258
- frame_rate=args.frame_rate,
259
- num_inference_steps=args.num_inference_steps,
260
- cfg_guidance_scale=args.cfg_guidance_scale,
261
- images=args.images,
262
- tiling_config=tiling_config,
263
- )
264
-
265
- encode_video(
266
- video=video,
267
- fps=args.frame_rate,
268
- audio=audio,
269
- audio_sample_rate=AUDIO_SAMPLE_RATE,
270
- output_path=args.output_path,
271
- video_chunks_number=video_chunks_number,
272
- )
273
-
274
-
275
- if __name__ == "__main__":
276
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
packages/ltx-pipelines/ltx_pipelines/utils/args.py DELETED
@@ -1,277 +0,0 @@
1
- import argparse
2
- from pathlib import Path
3
-
4
- from ltx_core.loader import LTXV_LORA_COMFY_RENAMING_MAP, LoraPathStrengthAndSDOps
5
- from ltx_pipelines.utils.constants import (
6
- DEFAULT_1_STAGE_HEIGHT,
7
- DEFAULT_1_STAGE_WIDTH,
8
- DEFAULT_2_STAGE_HEIGHT,
9
- DEFAULT_2_STAGE_WIDTH,
10
- DEFAULT_CFG_GUIDANCE_SCALE,
11
- DEFAULT_FRAME_RATE,
12
- DEFAULT_LORA_STRENGTH,
13
- DEFAULT_NEGATIVE_PROMPT,
14
- DEFAULT_NUM_FRAMES,
15
- DEFAULT_NUM_INFERENCE_STEPS,
16
- DEFAULT_SEED,
17
- )
18
-
19
-
20
class VideoConditioningAction(argparse.Action):
    """argparse action that accumulates (resolved_path, strength) pairs on the namespace."""

    def __call__(
        self,
        parser: argparse.ArgumentParser,  # noqa: ARG002
        namespace: argparse.Namespace,
        values: list[str],
        option_string: str | None = None,  # noqa: ARG002
    ) -> None:
        raw_path, raw_strength = values
        entries = getattr(namespace, self.dest) or []
        entries.append((resolve_path(raw_path), float(raw_strength)))
        setattr(namespace, self.dest, entries)
34
-
35
-
36
class ImageAction(argparse.Action):
    """argparse action that accumulates (resolved_path, frame_idx, strength) triples on the namespace."""

    def __call__(
        self,
        parser: argparse.ArgumentParser,  # noqa: ARG002
        namespace: argparse.Namespace,
        values: list[str],
        option_string: str | None = None,  # noqa: ARG002
    ) -> None:
        raw_path, raw_frame, raw_strength = values
        entries = getattr(namespace, self.dest) or []
        entries.append((resolve_path(raw_path), int(raw_frame), float(raw_strength)))
        setattr(namespace, self.dest, entries)
51
-
52
-
53
class LoraAction(argparse.Action):
    """argparse action collecting LoRA specs as LoraPathStrengthAndSDOps entries.

    Accepts PATH and an optional STRENGTH per use; strength falls back to
    DEFAULT_LORA_STRENGTH when omitted.
    """

    def __call__(
        self,
        parser: argparse.ArgumentParser,  # noqa: ARG002
        namespace: argparse.Namespace,
        values: list[str],
        option_string: str | None = None,
    ) -> None:
        if len(values) > 2:
            msg = f"{option_string} accepts at most 2 arguments (PATH and optional STRENGTH), got {len(values)} values"
            raise argparse.ArgumentError(self, msg)

        raw_path = values[0]
        raw_strength = values[1] if len(values) > 1 else str(DEFAULT_LORA_STRENGTH)

        entries = getattr(namespace, self.dest) or []
        entries.append(
            LoraPathStrengthAndSDOps(resolve_path(raw_path), float(raw_strength), LTXV_LORA_COMFY_RENAMING_MAP)
        )
        setattr(namespace, self.dest, entries)
74
-
75
-
76
def resolve_path(path: str) -> str:
    """Expand ``~``, resolve the path absolutely, and return it in POSIX form."""
    expanded = Path(path).expanduser()
    return expanded.resolve().as_posix()
78
-
79
-
80
def basic_arg_parser() -> argparse.ArgumentParser:
    """Build the argument parser shared by every pipeline CLI.

    Covers model/encoder paths, prompt, output location, sampling parameters,
    and repeatable image/LoRA conditioning options.
    """
    p = argparse.ArgumentParser()
    p.add_argument(
        "--checkpoint-path",
        type=resolve_path,
        required=True,
        help="Path to LTX-2 model checkpoint (.safetensors file).",
    )
    p.add_argument(
        "--gemma-root",
        type=resolve_path,
        required=True,
        help="Path to the root directory containing the Gemma text encoder model files.",
    )
    p.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Text prompt describing the desired video content to be generated by the model.",
    )
    p.add_argument(
        "--output-path",
        type=resolve_path,
        required=True,
        help="Path to the output video file (MP4 format).",
    )
    p.add_argument(
        "--seed",
        type=int,
        default=DEFAULT_SEED,
        help=(
            f"Random seed value used to initialize the noise tensor for "
            f"reproducible generation (default: {DEFAULT_SEED})."
        ),
    )
    p.add_argument(
        "--height",
        type=int,
        default=DEFAULT_1_STAGE_HEIGHT,
        help=f"Height of the generated video in pixels, should be divisible by 32 (default: {DEFAULT_1_STAGE_HEIGHT}).",
    )
    p.add_argument(
        "--width",
        type=int,
        default=DEFAULT_1_STAGE_WIDTH,
        help=f"Width of the generated video in pixels, should be divisible by 32 (default: {DEFAULT_1_STAGE_WIDTH}).",
    )
    p.add_argument(
        "--num-frames",
        type=int,
        default=DEFAULT_NUM_FRAMES,
        help=f"Number of frames to generate in the output video sequence, num-frames = (8 x K) + 1, "
        f"where k is a non-negative integer (default: {DEFAULT_NUM_FRAMES}).",
    )
    p.add_argument(
        "--frame-rate",
        type=float,
        default=DEFAULT_FRAME_RATE,
        help=f"Frame rate of the generated video (fps) (default: {DEFAULT_FRAME_RATE}).",
    )
    p.add_argument(
        "--num-inference-steps",
        type=int,
        default=DEFAULT_NUM_INFERENCE_STEPS,
        help=(
            f"Number of denoising steps in the diffusion sampling process. "
            f"Higher values improve quality but increase generation time (default: {DEFAULT_NUM_INFERENCE_STEPS})."
        ),
    )
    p.add_argument(
        "--image",
        dest="images",
        action=ImageAction,
        nargs=3,
        metavar=("PATH", "FRAME_IDX", "STRENGTH"),
        default=[],
        help=(
            "Image conditioning input: path to image file, target frame index, "
            "and conditioning strength (all three required). Default: empty list [] (no image conditioning). "
            "Can be specified multiple times. Example: --image path/to/image1.jpg 0 0.8 "
            "--image path/to/image2.jpg 160 0.9"
        ),
    )
    p.add_argument(
        "--lora",
        dest="lora",
        # nargs="+" accepts 1-2 values (path + optional strength); LoraAction enforces the upper bound.
        action=LoraAction,
        nargs="+",
        metavar=("PATH", "STRENGTH"),
        default=[],
        help=(
            "LoRA (Low-Rank Adaptation) model: path to model file and optional strength "
            f"(default strength: {DEFAULT_LORA_STRENGTH}). Can be specified multiple times. "
            "Example: --lora path/to/lora1.safetensors 0.8 --lora path/to/lora2.safetensors"
        ),
    )
    p.add_argument(
        "--enable-fp8",
        action="store_true",
        help="Enable FP8 mode to reduce memory footprint by keeping model in lower precision. "
        "Note that calculations are still performed in bfloat16 precision.",
    )
    p.add_argument("--enhance-prompt", action="store_true")
    return p
184
-
185
-
186
def default_1_stage_arg_parser() -> argparse.ArgumentParser:
    """Extend the basic parser with CFG guidance scale and negative prompt for single-stage generation."""
    p = basic_arg_parser()
    p.add_argument(
        "--cfg-guidance-scale",
        type=float,
        default=DEFAULT_CFG_GUIDANCE_SCALE,
        help=(
            f"Classifier-free guidance (CFG) scale controlling how strongly "
            f"the model adheres to the prompt. Higher values increase prompt "
            f"adherence but may reduce diversity (default: {DEFAULT_CFG_GUIDANCE_SCALE})."
        ),
    )
    p.add_argument(
        "--negative-prompt",
        type=str,
        default=DEFAULT_NEGATIVE_PROMPT,
        help=(
            "Negative prompt describing what should not appear in the generated video, "
            "used to guide the diffusion process away from unwanted content. "
            "Default: a comprehensive negative prompt covering common artifacts and quality issues."
        ),
    )
    return p
210
-
211
-
212
def default_2_stage_arg_parser() -> argparse.ArgumentParser:
    """Two-stage parser: doubled default resolution plus distilled-LoRA and spatial-upsampler options."""
    p = default_1_stage_arg_parser()
    p.set_defaults(height=DEFAULT_2_STAGE_HEIGHT, width=DEFAULT_2_STAGE_WIDTH)

    # Rewrite the height/width help strings so they reflect the two-stage defaults.
    for action in p._actions:  # noqa: SLF001
        if "--height" in action.option_strings:
            action.help = (
                f"Height of the generated video in pixels, should be divisible by 64 "
                f"(default: {DEFAULT_2_STAGE_HEIGHT})."
            )
        if "--width" in action.option_strings:
            action.help = (
                f"Width of the generated video in pixels, should be divisible by 64 (default: {DEFAULT_2_STAGE_WIDTH})."
            )

    p.add_argument(
        "--distilled-lora",
        dest="distilled_lora",
        # nargs="+" accepts 1-2 values (path + optional strength); LoraAction enforces the upper bound.
        action=LoraAction,
        nargs="+",
        metavar=("PATH", "STRENGTH"),
        required=True,
        help=(
            "Distilled LoRA (Low-Rank Adaptation) model used in the second stage (upscaling and refinement): "
            f"path to model file and optional strength (default strength: {DEFAULT_LORA_STRENGTH}). "
            "The second stage upsamples the video by 2x resolution and refines it using a distilled "
            "denoising schedule (fewer steps, no CFG). The distilled LoRA is specifically trained "
            "for this refinement process to improve quality at higher resolutions. "
            "Example: --distilled-lora path/to/distilled_lora.safetensors 0.8"
        ),
    )
    p.add_argument(
        "--spatial-upsampler-path",
        type=resolve_path,
        required=True,
        help=(
            "Path to the spatial upsampler model used to increase the resolution "
            "of the generated video in the latent space."
        ),
    )
    return p
252
-
253
-
254
def default_2_stage_distilled_arg_parser() -> argparse.ArgumentParser:
    """Distilled two-stage parser: basic options with doubled default resolution and a spatial upsampler."""
    p = basic_arg_parser()
    p.set_defaults(height=DEFAULT_2_STAGE_HEIGHT, width=DEFAULT_2_STAGE_WIDTH)

    # Rewrite the height/width help strings so they reflect the two-stage defaults.
    for action in p._actions:  # noqa: SLF001
        if "--height" in action.option_strings:
            action.help = (
                f"Height of the generated video in pixels, should be divisible by 64 "
                f"(default: {DEFAULT_2_STAGE_HEIGHT})."
            )
        if "--width" in action.option_strings:
            action.help = (
                f"Width of the generated video in pixels, should be divisible by 64 (default: {DEFAULT_2_STAGE_WIDTH})."
            )

    p.add_argument(
        "--spatial-upsampler-path",
        type=resolve_path,
        required=True,
        help=(
            "Path to the spatial upsampler model used to increase the resolution "
            "of the generated video in the latent space."
        ),
    )
    return p
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
packages/ltx-pipelines/ltx_pipelines/utils/constants.py DELETED
@@ -1,77 +0,0 @@
1
"""Shared constants for the LTX pipelines: diffusion schedules, generation defaults, VAE geometry, prompts."""

from ltx_core.types import SpatioTemporalScaleFactors

# ---------------------------------------------------------------------------
# Diffusion schedules
# ---------------------------------------------------------------------------
# Sigma (noise-level) schedule for the distilled pipeline, tuned to match the
# distillation process.
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]

# Tail of the distilled schedule used by the stage-2 super-resolution pass.
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875, 0.0]

# ---------------------------------------------------------------------------
# Video generation defaults
# ---------------------------------------------------------------------------
DEFAULT_SEED = 10
DEFAULT_1_STAGE_HEIGHT = 512
DEFAULT_1_STAGE_WIDTH = 768
# Stage 2 upsamples the stage-1 output by 2x in each spatial dimension.
DEFAULT_2_STAGE_HEIGHT = DEFAULT_1_STAGE_HEIGHT * 2
DEFAULT_2_STAGE_WIDTH = DEFAULT_1_STAGE_WIDTH * 2
DEFAULT_NUM_FRAMES = 121
DEFAULT_FRAME_RATE = 24.0
DEFAULT_NUM_INFERENCE_STEPS = 40
DEFAULT_CFG_GUIDANCE_SCALE = 4.0

# ---------------------------------------------------------------------------
# Audio
# ---------------------------------------------------------------------------
AUDIO_SAMPLE_RATE = 24000

# ---------------------------------------------------------------------------
# LoRA
# ---------------------------------------------------------------------------
DEFAULT_LORA_STRENGTH = 1.0

# ---------------------------------------------------------------------------
# Video VAE architecture
# ---------------------------------------------------------------------------
VIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default()
VIDEO_LATENT_CHANNELS = 128

# ---------------------------------------------------------------------------
# Image preprocessing
# ---------------------------------------------------------------------------
# CRF (Constant Rate Factor) for H.264 encoding used in image conditioning.
# Lower = higher quality, 0 = lossless. This mimics compression artifacts.
DEFAULT_IMAGE_CRF = 33

# ---------------------------------------------------------------------------
# Prompts
# ---------------------------------------------------------------------------
DEFAULT_NEGATIVE_PROMPT = (
    "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
    "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
    "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
    "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
    "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
    "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
    "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
    "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
    "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
    "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
    "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
packages/ltx-pipelines/ltx_pipelines/utils/helpers.py DELETED
@@ -1,507 +0,0 @@
1
- import gc
2
- import logging
3
- from dataclasses import replace
4
-
5
- import torch
6
- from tqdm import tqdm
7
-
8
- from ltx_core.components.noisers import Noiser
9
- from ltx_core.components.protocols import DiffusionStepProtocol, GuiderProtocol
10
- from ltx_core.conditioning import (
11
- ConditioningItem,
12
- VideoConditionByKeyframeIndex,
13
- VideoConditionByLatentIndex,
14
- )
15
- from ltx_core.model.transformer import Modality, X0Model
16
- from ltx_core.model.video_vae import VideoEncoder
17
- from ltx_core.text_encoders.gemma import GemmaTextEncoderModelBase
18
- from ltx_core.tools import AudioLatentTools, LatentTools, VideoLatentTools
19
- from ltx_core.types import AudioLatentShape, LatentState, VideoLatentShape, VideoPixelShape
20
- from ltx_core.utils import to_denoised, to_velocity
21
- from ltx_pipelines.utils.media_io import decode_image, load_image_conditioning, resize_aspect_ratio_preserving
22
- from ltx_pipelines.utils.types import (
23
- DenoisingFunc,
24
- DenoisingLoopFunc,
25
- PipelineComponents,
26
- )
27
-
28
-
29
def get_device() -> torch.device:
    """Return the preferred compute device: CUDA when available, otherwise CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
-
34
-
35
def cleanup_memory() -> None:
    """Release Python garbage and, when a CUDA device is present, the CUDA cache.

    Fix: the CUDA calls are now guarded by ``torch.cuda.is_available()`` —
    the original unconditional ``empty_cache``/``synchronize`` raise on
    CPU-only hosts (no CUDA driver/device to synchronize).
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # Block until queued kernels finish so freed memory is actually reclaimed.
        torch.cuda.synchronize()
39
-
40
-
41
def image_conditionings_by_replacing_latent(
    images: list[tuple[str, int, float]],
    height: int,
    width: int,
    video_encoder: VideoEncoder,
    dtype: torch.dtype,
    device: torch.device,
) -> list[ConditioningItem]:
    """Encode each image and wrap it as a latent-replacement conditioning item.

    Each entry of ``images`` is ``(path, latent_idx, strength)``: the image is
    loaded at the target resolution, encoded with ``video_encoder``, and used to
    replace the latent at that index with the given strength.
    """
    items: list[ConditioningItem] = []
    for image_path, frame_idx, strength in images:
        pixels = load_image_conditioning(
            image_path=image_path,
            height=height,
            width=width,
            dtype=dtype,
            device=device,
        )
        items.append(
            VideoConditionByLatentIndex(
                latent=video_encoder(pixels),
                strength=strength,
                latent_idx=frame_idx,
            )
        )
    return items
68
-
69
-
70
def image_conditionings_by_adding_guiding_latent(
    images: list[tuple[str, int, float]],
    height: int,
    width: int,
    video_encoder: VideoEncoder,
    dtype: torch.dtype,
    device: torch.device,
) -> list[ConditioningItem]:
    """Encode each image and wrap it as a keyframe-guidance conditioning item.

    Each entry of ``images`` is ``(path, frame_idx, strength)``: the image is
    loaded at the target resolution, encoded with ``video_encoder``, and added
    as a guiding keyframe latent at that frame index.
    """
    items: list[ConditioningItem] = []
    for image_path, frame_idx, strength in images:
        pixels = load_image_conditioning(
            image_path=image_path,
            height=height,
            width=width,
            dtype=dtype,
            device=device,
        )
        items.append(
            VideoConditionByKeyframeIndex(keyframes=video_encoder(pixels), frame_idx=frame_idx, strength=strength)
        )
    return items
92
-
93
-
94
def euler_denoising_loop(
    sigmas: torch.Tensor,
    video_state: LatentState,
    audio_state: LatentState,
    stepper: DiffusionStepProtocol,
    denoise_fn: DenoisingFunc,
) -> tuple[LatentState, LatentState]:
    """Run the joint audio-video Euler sampling loop over ``sigmas``.

    For every step (all sigmas except the final one), ``denoise_fn`` is called as
    ``denoise_fn(video_state, audio_state, sigmas, step)`` and must return the
    denoised ``(video, audio)`` estimates. Each estimate is blended with the
    state's clean latent through its denoise mask, then ``stepper.step`` advances
    the noisy latent one step along the schedule.

    Returns the final ``(video_state, audio_state)`` pair.
    """
    last_step = len(sigmas) - 1
    for step in tqdm(range(last_step)):
        video_hat, audio_hat = denoise_fn(video_state, audio_state, sigmas, step)

        # Keep conditioned (mask == 0) regions pinned to their clean latents.
        video_hat = post_process_latent(video_hat, video_state.denoise_mask, video_state.clean_latent)
        audio_hat = post_process_latent(audio_hat, audio_state.denoise_mask, audio_state.clean_latent)

        video_state = replace(video_state, latent=stepper.step(video_state.latent, video_hat, sigmas, step))
        audio_state = replace(audio_state, latent=stepper.step(audio_state.latent, audio_hat, sigmas, step))

    return video_state, audio_state
142
-
143
-
144
def gradient_estimating_euler_denoising_loop(
    sigmas: torch.Tensor,
    video_state: LatentState,
    audio_state: LatentState,
    stepper: DiffusionStepProtocol,
    denoise_fn: DenoisingFunc,
    ge_gamma: float = 2.0,
) -> tuple[LatentState, LatentState]:
    """Joint audio-video denoising loop with gradient-estimation correction.

    Same contract as :func:`euler_denoising_loop`, but the denoised estimate is
    corrected by extrapolating the change in velocity between consecutive steps,
    scaled by ``ge_gamma`` (default 2.0).
    Paper: https://openreview.net/pdf?id=o2ND9v0CeK
    """
    prev_video_velocity: torch.Tensor | None = None
    prev_audio_velocity: torch.Tensor | None = None

    def _estimate(
        noisy: torch.Tensor, denoised: torch.Tensor, sigma: float, prev_velocity: torch.Tensor | None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Compute the current velocity; with a previous velocity available,
        # extrapolate it and re-derive the denoised sample from the correction.
        velocity = to_velocity(noisy, sigma, denoised)
        if prev_velocity is not None:
            corrected = ge_gamma * (velocity - prev_velocity) + prev_velocity
            denoised = to_denoised(noisy, corrected, sigma)
        return velocity, denoised

    for step, _ in enumerate(tqdm(sigmas[:-1])):
        video_hat, audio_hat = denoise_fn(video_state, audio_state, sigmas, step)

        video_hat = post_process_latent(video_hat, video_state.denoise_mask, video_state.clean_latent)
        audio_hat = post_process_latent(audio_hat, audio_state.denoise_mask, audio_state.clean_latent)

        # At the final (sigma -> 0) step, the denoised estimates are the result.
        if sigmas[step + 1] == 0:
            return replace(video_state, latent=video_hat), replace(audio_state, latent=audio_hat)

        prev_video_velocity, video_hat = _estimate(video_state.latent, video_hat, sigmas[step], prev_video_velocity)
        prev_audio_velocity, audio_hat = _estimate(audio_state.latent, audio_hat, sigmas[step], prev_audio_velocity)

        video_state = replace(video_state, latent=stepper.step(video_state.latent, video_hat, sigmas, step))
        audio_state = replace(audio_state, latent=stepper.step(audio_state.latent, audio_hat, sigmas, step))

    return video_state, audio_state
202
-
203
-
204
def noise_video_state(
    output_shape: VideoPixelShape,
    noiser: Noiser,
    conditionings: list[ConditioningItem],
    components: PipelineComponents,
    dtype: torch.dtype,
    device: torch.device,
    noise_scale: float = 1.0,
    initial_latent: torch.Tensor | None = None,
) -> tuple[LatentState, VideoLatentTools]:
    """Build video latent tools for ``output_shape`` and return a noised initial state.

    Conditionings are applied before noising. ``initial_latent``, when given,
    seeds the state; otherwise an empty initial state is created.
    """
    latent_shape = VideoLatentShape.from_pixel_shape(
        shape=output_shape,
        latent_channels=components.video_latent_channels,
        scale_factors=components.video_scale_factors,
    )
    tools = VideoLatentTools(components.video_patchifier, latent_shape, output_shape.fps)
    state = create_noised_state(
        tools=tools,
        conditionings=conditionings,
        noiser=noiser,
        dtype=dtype,
        device=device,
        noise_scale=noise_scale,
        initial_latent=initial_latent,
    )
    return state, tools
237
-
238
-
239
def noise_audio_state(
    output_shape: VideoPixelShape,
    noiser: Noiser,
    conditionings: list[ConditioningItem],
    components: PipelineComponents,
    dtype: torch.dtype,
    device: torch.device,
    noise_scale: float = 1.0,
    initial_latent: torch.Tensor | None = None,
    denoise_mask: torch.Tensor | None = None,
) -> tuple[LatentState, AudioLatentTools]:
    """Build audio latent tools for ``output_shape`` and return a noised initial state.

    Conditionings are applied before noising. ``initial_latent``, when given,
    seeds the state; ``denoise_mask`` is forwarded to
    :func:`create_noised_state` to pin masked-out regions to clean values.
    """
    latent_shape = AudioLatentShape.from_video_pixel_shape(output_shape)
    tools = AudioLatentTools(components.audio_patchifier, latent_shape)
    state = create_noised_state(
        tools=tools,
        conditionings=conditionings,
        noiser=noiser,
        dtype=dtype,
        device=device,
        noise_scale=noise_scale,
        initial_latent=initial_latent,
        denoise_mask=denoise_mask,
    )
    return state, tools
270
-
271
-
272
def create_noised_state(
    tools: LatentTools,
    conditionings: list[ConditioningItem],
    noiser: Noiser,
    dtype: torch.dtype,
    device: torch.device,
    noise_scale: float = 1.0,
    initial_latent: torch.Tensor | None = None,
    denoise_mask: torch.Tensor | None = None,
) -> LatentState:
    """Create, condition and noise a latent state.

    When ``denoise_mask`` is provided it is collapsed to one scalar (a tensor
    mask contributes its mean — solid-mask behavior), the pre-noise latent is
    kept as ``clean_latent``, and after noising the latent is re-blended so
    masked-out regions retain their clean values.
    """
    state = tools.create_initial_state(device, dtype, initial_latent)
    state = state_with_conditionings(state, conditionings, tools)

    if denoise_mask is not None:
        # Collapse any tensor mask to a single scalar (solid-mask behavior).
        if isinstance(denoise_mask, torch.Tensor):
            scalar = float(denoise_mask.mean().item())
        else:
            scalar = float(denoise_mask)
        state = replace(
            state,
            clean_latent=state.latent.clone(),
            denoise_mask=torch.full_like(state.denoise_mask, scalar),
        )

    state = noiser(state, noise_scale)

    if denoise_mask is not None:
        # Re-blend: noisy where the mask is 1, clean where it is 0.
        mask = state.denoise_mask.to(dtype=state.latent.dtype, device=state.latent.device)
        clean = state.clean_latent.to(dtype=state.latent.dtype, device=state.latent.device)
        state = replace(state, latent=state.latent * mask + clean * (1 - mask))

    return state
306
-
307
-
308
-
309
def state_with_conditionings(
    latent_state: LatentState, conditioning_items: list[ConditioningItem], latent_tools: LatentTools
) -> LatentState:
    """Apply each conditioning item to ``latent_state`` in order and return the result."""
    for item in conditioning_items:
        latent_state = item.apply_to(latent_state=latent_state, latent_tools=latent_tools)
    return latent_state
320
-
321
-
322
def post_process_latent(denoised: torch.Tensor, denoise_mask: torch.Tensor, clean: torch.Tensor) -> torch.Tensor:
    """Blend ``denoised`` with ``clean`` via ``denoise_mask`` (1 keeps denoised, 0 keeps clean)."""
    mask = denoise_mask.to(dtype=denoised.dtype)
    reference = clean.to(dtype=denoised.dtype)
    return denoised * mask + reference * (1 - mask)
327
-
328
-
329
def modality_from_latent_state(
    state: LatentState, context: torch.Tensor, sigma: float | torch.Tensor, enabled: bool = True
) -> Modality:
    """Wrap a latent state as a transformer ``Modality``.

    Timesteps are derived from the state's denoise mask scaled by ``sigma``;
    positions and the latent come straight from ``state``.
    """
    timesteps = timesteps_from_mask(state.denoise_mask, sigma)
    return Modality(
        enabled=enabled,
        latent=state.latent,
        timesteps=timesteps,
        positions=state.positions,
        context=context,
        context_mask=None,
    )
344
-
345
-
346
- def timesteps_from_mask(denoise_mask: torch.Tensor, sigma: float | torch.Tensor) -> torch.Tensor:
347
- """Compute timesteps from a denoise mask and sigma value.
348
- Multiplies the denoise mask by sigma to produce timesteps for each position
349
- in the latent state. Areas where the mask is 0 will have zero timesteps.
350
- """
351
- return denoise_mask * sigma
352
-
353
-
354
def simple_denoising_func(
    video_context: torch.Tensor, audio_context: torch.Tensor, transformer: X0Model
) -> DenoisingFunc:
    """Build a DenoisingFunc that runs ``transformer`` once per step with the given contexts (no guidance)."""

    def _step(
        video_state: LatentState, audio_state: LatentState, sigmas: torch.Tensor, step_index: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        sigma = sigmas[step_index]
        video_in = modality_from_latent_state(video_state, video_context, sigma)
        audio_in = modality_from_latent_state(audio_state, audio_context, sigma)
        return transformer(video=video_in, audio=audio_in, perturbations=None)

    return _step
368
-
369
-
370
def guider_denoising_func(
    guider: GuiderProtocol,
    v_context_p: torch.Tensor,
    v_context_n: torch.Tensor,
    a_context_p: torch.Tensor,
    a_context_n: torch.Tensor,
    transformer: X0Model,
) -> DenoisingFunc:
    """Build a DenoisingFunc with classifier-free-style guidance via ``guider``.

    The positive contexts always run. When the guider is enabled, the negative
    contexts run as well and the guider's delta between positive and negative
    estimates is added to each positive estimate.
    """

    def _guided_step(
        video_state: LatentState, audio_state: LatentState, sigmas: torch.Tensor, step_index: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        sigma = sigmas[step_index]
        video_pos = modality_from_latent_state(video_state, v_context_p, sigma)
        audio_pos = modality_from_latent_state(audio_state, a_context_p, sigma)
        video_hat, audio_hat = transformer(video=video_pos, audio=audio_pos, perturbations=None)

        if guider.enabled():
            video_neg = modality_from_latent_state(video_state, v_context_n, sigma)
            audio_neg = modality_from_latent_state(audio_state, a_context_n, sigma)
            neg_video_hat, neg_audio_hat = transformer(video=video_neg, audio=audio_neg, perturbations=None)

            video_hat = video_hat + guider.delta(video_hat, neg_video_hat)
            audio_hat = audio_hat + guider.delta(audio_hat, neg_audio_hat)

        return video_hat, audio_hat

    return _guided_step
398
-
399
-
400
def denoise_audio_video(  # noqa: PLR0913
    output_shape: VideoPixelShape,
    conditionings: list[ConditioningItem],
    noiser: Noiser,
    sigmas: torch.Tensor,
    stepper: DiffusionStepProtocol,
    denoising_loop_fn: DenoisingLoopFunc,
    components: PipelineComponents,
    dtype: torch.dtype,
    device: torch.device,
    audio_conditionings: list[ConditioningItem] | None = None,
    noise_scale: float = 1.0,
    initial_video_latent: torch.Tensor | None = None,
    initial_audio_latent: torch.Tensor | None = None,
) -> tuple[LatentState | None, LatentState | None]:
    """Noise, denoise and unpatchify joint audio/video latents.

    Builds noised initial states for both modalities (seeded by the optional
    initial latents), runs ``denoising_loop_fn`` over ``sigmas``, then strips
    conditioning tokens and unpatchifies both results. Returns ``(None, None)``
    when the loop yields a missing state.
    """
    video_state, video_tools = noise_video_state(
        output_shape=output_shape,
        noiser=noiser,
        conditionings=conditionings,
        components=components,
        dtype=dtype,
        device=device,
        noise_scale=noise_scale,
        initial_latent=initial_video_latent,
    )
    audio_state, audio_tools = noise_audio_state(
        output_shape=output_shape,
        noiser=noiser,
        conditionings=audio_conditionings or [],
        components=components,
        dtype=dtype,
        device=device,
        noise_scale=noise_scale,
        initial_latent=initial_audio_latent,
    )

    video_state, audio_state = denoising_loop_fn(
        sigmas,
        video_state,
        audio_state,
        stepper,
    )

    if video_state is None or audio_state is None:
        return None, None

    # Drop conditioning tokens, then convert back from patch to latent layout.
    video_state = video_tools.unpatchify(video_tools.clear_conditioning(video_state))
    audio_state = audio_tools.unpatchify(audio_tools.clear_conditioning(audio_state))

    return video_state, audio_state
459
-
460
-
461
-
462
# Gemma tends to emit typographic punctuation; map it back to plain ASCII.
_UNICODE_REPLACEMENTS = str.maketrans("\u2018\u2019\u201c\u201d\u2014\u2013\u00a0\u2032\u2212", "''\"\"-- '-")


def clean_response(text: str) -> str:
    """Clean a response from curly quotes and leading non-letter characters which Gemma tends to insert."""
    normalized = text.translate(_UNICODE_REPLACEMENTS)

    # Strip everything before the first alphabetic character.
    for index, character in enumerate(normalized):
        if character.isalpha():
            return normalized[index:]
    # No letters at all: return the translated text unchanged.
    return normalized
474
-
475
-
476
def generate_enhanced_prompt(
    text_encoder: GemmaTextEncoderModelBase,
    prompt: str,
    image_path: str | None = None,
    image_long_side: int = 896,
    seed: int = 42,
) -> str:
    """Generate an enhanced prompt from a text encoder and a prompt.

    When ``image_path`` is given, the image is loaded, resized so its long
    side equals ``image_long_side``, and used for image-grounded enhancement;
    otherwise text-only enhancement is used.
    """
    if image_path:
        reference = torch.tensor(decode_image(image_path=image_path))
        reference = resize_aspect_ratio_preserving(reference, image_long_side).to(torch.uint8)
        enhanced = text_encoder.enhance_i2v(prompt, reference, seed=seed)
    else:
        enhanced = text_encoder.enhance_t2v(prompt, seed=seed)
    logging.info(f"Enhanced prompt: {enhanced}")
    return clean_response(enhanced)
494
-
495
-
496
def assert_resolution(height: int, width: int, is_two_stage: bool) -> None:
    """Validate that the resolution fits the pipeline's divisibility constraint.

    Two-stage pipelines require height/width divisible by 64; one-stage
    pipelines require divisibility by 32.

    Raises:
        ValueError: if either dimension is not a multiple of the divisor.
    """
    divisor = 64 if is_two_stage else 32
    if height % divisor == 0 and width % divisor == 0:
        return
    raise ValueError(
        f"Resolution ({height}x{width}) is not divisible by {divisor}. "
        f"For {'two-stage' if is_two_stage else 'one-stage'} pipelines, "
        f"height and width must be multiples of {divisor}."
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
packages/ltx-pipelines/ltx_pipelines/utils/media_io.py DELETED
@@ -1,299 +0,0 @@
1
- import math
2
- from collections.abc import Generator, Iterator
3
- from fractions import Fraction
4
- from io import BytesIO
5
-
6
- import av
7
- import numpy as np
8
- import torch
9
- from einops import rearrange
10
- from PIL import Image
11
- from torch._prims_common import DeviceLikeType
12
- from tqdm import tqdm
13
-
14
- from ltx_pipelines.utils.constants import DEFAULT_IMAGE_CRF
15
-
16
-
17
def resize_aspect_ratio_preserving(image: torch.Tensor, long_side: int) -> torch.Tensor:
    """
    Resize image preserving aspect ratio (filling target long side).
    Preserves the input dimensions order.
    Args:
        image: Input image tensor with shape (F (optional), H, W, C)
        long_side: Target long side size.
    Returns:
        Tensor with shape (F (optional), H, W, C) F = 1 if input is 3D, otherwise input shape[0]
    """
    # H and W are always the third- and second-to-last dims for (F?, H, W, C).
    # BUG FIX: the previous slice `image.shape[-3:2]` only worked for 3D
    # inputs; for 4D (F, H, W, C) it produced a single-element slice and the
    # two-name unpack raised a ValueError.
    height, width = image.shape[-3:-1]
    max_side = max(height, width)
    scale = long_side / float(max_side)
    target_height = int(height * scale)
    target_width = int(width * scale)
    resized = resize_and_center_crop(image, target_height, target_width)
    # rearrange back to channels-last and drop the batch dimension
    result = rearrange(resized, "b c f h w -> b f h w c")[0]
    # Collapse the frame dimension for single-frame results (preserves the
    # 3D-in/3D-out contract; note a 4D input with F == 1 also collapses).
    return result[0] if result.shape[0] == 1 else result
37
-
38
-
39
def resize_and_center_crop(tensor: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """
    Resize tensor preserving aspect ratio (filling target), then center crop to exact dimensions.
    Args:
        tensor: Input tensor with shape (H, W, C) or (F, H, W, C)
        height: Target height
        width: Target width
    Returns:
        Tensor with shape (1, C, 1, height, width) for 3D input or (1, C, F, height, width) for 4D input
    Raises:
        ValueError: if the input is not 3- or 4-dimensional.
    """
    # Move channels first using native torch ops (the einops rearrange calls
    # here were trivial permutes; dropping them removes a dependency):
    # (H, W, C) -> (1, C, H, W); (F, H, W, C) -> (F, C, H, W).
    if tensor.ndim == 3:
        tensor = tensor.permute(2, 0, 1).unsqueeze(0)
    elif tensor.ndim == 4:
        tensor = tensor.permute(0, 3, 1, 2)
    else:
        raise ValueError(f"Expected input with 3 or 4 dimensions; got shape {tensor.shape}.")

    _, _, src_h, src_w = tensor.shape

    scale = max(height / src_h, width / src_w)
    # Use ceil to avoid floating-point rounding causing new_h/new_w to be
    # slightly smaller than target, which would result in negative crop offsets.
    new_h = math.ceil(src_h * scale)
    new_w = math.ceil(src_w * scale)

    tensor = torch.nn.functional.interpolate(tensor, size=(new_h, new_w), mode="bilinear", align_corners=False)

    # Center crop to the exact requested size.
    crop_top = (new_h - height) // 2
    crop_left = (new_w - width) // 2
    tensor = tensor[:, :, crop_top : crop_top + height, crop_left : crop_left + width]

    # (F, C, h, w) -> (1, C, F, h, w)
    return tensor.permute(1, 0, 2, 3).unsqueeze(0)
72
-
73
-
74
def normalize_latent(latent: torch.Tensor, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Map pixel values from [0, 255] to [-1, 1] and move to the given device/dtype."""
    scaled = latent / 127.5 - 1.0
    return scaled.to(device=device, dtype=dtype)
76
-
77
-
78
def load_image_conditioning(
    image_path: str, height: int, width: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    """
    Loads an image from a path and preprocesses it for conditioning.
    Note: The image is resized to the nearest multiple of 2 for compatibility with video codecs.
    """
    pixels = preprocess(image=decode_image(image_path=image_path))
    tensor = torch.tensor(pixels, dtype=torch.float32, device=device)
    tensor = resize_and_center_crop(tensor, height, width)
    return normalize_latent(tensor, device, dtype)
91
-
92
-
93
def load_video_conditioning(
    video_path: str, height: int, width: int, frame_cap: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor | None:
    """
    Loads a video from a path and preprocesses it for conditioning.
    Note: The video is resized to the nearest multiple of 2 for compatibility with video codecs.
    Returns:
        A (1, C, F, height, width) tensor, or None if the video yields no frames.
    """
    processed: list[torch.Tensor] = []
    for frame in decode_video_from_file(path=video_path, frame_cap=frame_cap, device=device):
        resized = resize_and_center_crop(frame.to(torch.float32), height, width)
        processed.append(normalize_latent(resized, device, dtype))
    # Concatenate once along the frame axis instead of inside the loop
    # (the original per-frame torch.cat was quadratic in the frame count).
    return torch.cat(processed, dim=2) if processed else None
107
-
108
-
109
def decode_image(image_path: str) -> np.ndarray:
    """Read an image file and return it as an array keeping only the first three channels."""
    pixels = np.array(Image.open(image_path))
    return pixels[..., :3]
113
-
114
-
115
def _write_audio(
    container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int
) -> None:
    """Encode a stereo sample tensor into the container's audio stream.

    Accepts samples shaped (N, 2) or (2, N); float inputs are assumed to be in
    [-1, 1] and are converted to int16 before ingestion.

    Raises:
        ValueError: if the samples cannot be interpreted as two channels.
    """
    # Promote 1D (N,) to (N, 1); this still fails the stereo check below, so
    # mono input ultimately raises — callers must supply two channels.
    if samples.ndim == 1:
        samples = samples[:, None]

    # Accept channel-first (2, N) layout by transposing to samples-first (N, 2).
    if samples.shape[1] != 2 and samples.shape[0] == 2:
        samples = samples.T

    if samples.shape[1] != 2:
        raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")

    # Convert to int16 packed for ingestion; resampler converts to encoder fmt.
    if samples.dtype != torch.int16:
        samples = torch.clip(samples, -1.0, 1.0)
        samples = (samples * 32767.0).to(torch.int16)

    # Interleave the (N, 2) samples into a single packed s16 plane for PyAV.
    frame_in = av.AudioFrame.from_ndarray(
        samples.contiguous().reshape(1, -1).cpu().numpy(),
        format="s16",
        layout="stereo",
    )
    frame_in.sample_rate = audio_sample_rate

    # Resampling, encoding, muxing and the encoder flush happen downstream.
    _resample_audio(container, audio_stream, frame_in)
140
-
141
-
142
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
    """Add a stereo AAC audio stream at the given sample rate and return it."""
    stream = container.add_stream("aac", rate=audio_sample_rate)
    codec = stream.codec_context
    codec.sample_rate = audio_sample_rate
    codec.layout = "stereo"
    codec.time_base = Fraction(1, audio_sample_rate)
    return stream
151
-
152
-
153
def _resample_audio(
    container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
) -> None:
    """Resample *frame_in* to the encoder's native format, mux it, and flush.

    The encoder's own format/layout/rate are taken as the resampling target so
    the packed s16 ingestion frame can be fed to any encoder configuration.
    """
    cc = audio_stream.codec_context

    # Use the encoder's format/layout/rate as the *target*
    target_format = cc.format or "fltp"  # AAC → usually fltp
    target_layout = cc.layout or "stereo"
    target_rate = cc.sample_rate or frame_in.sample_rate

    audio_resampler = av.audio.resampler.AudioResampler(
        format=target_format,
        layout=target_layout,
        rate=target_rate,
    )

    # Frames coming back from the resampler may lack a PTS; assign
    # monotonically increasing timestamps counted in samples.
    audio_next_pts = 0
    for rframe in audio_resampler.resample(frame_in):
        if rframe.pts is None:
            rframe.pts = audio_next_pts
            audio_next_pts += rframe.samples
        # NOTE(review): this stamps the *input* frame's rate onto the
        # resampled frame; if target_rate ever differs from
        # frame_in.sample_rate this looks wrong — confirm intended behavior.
        rframe.sample_rate = frame_in.sample_rate
        container.mux(audio_stream.encode(rframe))

    # flush audio encoder
    for packet in audio_stream.encode():
        container.mux(packet)
180
-
181
-
182
def encode_video(
    video: torch.Tensor | Iterator[torch.Tensor],
    fps: int,
    audio: torch.Tensor | None,
    audio_sample_rate: int | None,
    output_path: str,
    video_chunks_number: int,
) -> None:
    """Encode video chunks (and optional audio) into an H.264/AAC file.

    Args:
        video: A single (F, H, W, C) frame-chunk tensor or an iterator of such
            chunks (presumably uint8 RGB — confirm with callers).
        fps: Output frame rate.
        audio: Optional stereo sample tensor; muxed after the video frames.
        audio_sample_rate: Required when ``audio`` is provided.
        output_path: Destination media file path.
        video_chunks_number: Total chunk count; used only for the progress bar.

    Raises:
        ValueError: if ``audio`` is given without ``audio_sample_rate``.
    """
    if isinstance(video, torch.Tensor):
        video = iter([video])

    # Pull the first chunk eagerly to learn the output frame dimensions.
    first_chunk = next(video)

    _, height, width, _ = first_chunk.shape

    container = av.open(output_path, mode="w")
    stream = container.add_stream("libx264", rate=int(fps))
    stream.width = width
    stream.height = height
    stream.pix_fmt = "yuv420p"

    if audio is not None:
        if audio_sample_rate is None:
            raise ValueError("audio_sample_rate is required when audio is provided")

        audio_stream = _prepare_audio_stream(container, audio_sample_rate)

    # Re-attach the eagerly-consumed first chunk in front of the remaining
    # generator. (Annotations fixed: the generator yields plain tensors, not
    # (tensor, int) tuples.)
    def all_tiles(
        first_chunk: torch.Tensor, tiles_generator: Iterator[torch.Tensor]
    ) -> Generator[torch.Tensor, None, None]:
        yield first_chunk
        yield from tiles_generator

    for video_chunk in tqdm(all_tiles(first_chunk, video), total=video_chunks_number):
        video_chunk_cpu = video_chunk.to("cpu").numpy()
        for frame_array in video_chunk_cpu:
            frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
            for packet in stream.encode(frame):
                container.mux(packet)

    # Flush encoder
    for packet in stream.encode():
        container.mux(packet)

    if audio is not None:
        _write_audio(container, audio_stream, audio, audio_sample_rate)

    container.close()
230
-
231
-
232
def decode_audio_from_file(path: str, device: torch.device) -> torch.Tensor | None:
    """Decode the first audio stream of a media file into a float32 tensor.

    Returns:
        The concatenated decoded frames, or None if the file has no audio stream.
    """
    container = av.open(path)
    try:
        frames = []
        audio_stream = next(s for s in container.streams if s.type == "audio")
        for frame in container.decode(audio_stream):
            frames.append(torch.tensor(frame.to_ndarray(), dtype=torch.float32, device=device).unsqueeze(0))
        audio = torch.cat(frames)
    except StopIteration:
        # next() found no audio stream in the container.
        audio = None
    finally:
        # Single close point — the original also closed inside the try block,
        # which closed the container twice on the success path.
        container.close()

    return audio
247
-
248
-
249
def decode_video_from_file(path: str, frame_cap: int, device: DeviceLikeType) -> Generator[torch.Tensor]:
    """Lazily decode up to ``frame_cap`` RGB frames as (1, H, W, C) uint8 tensors.

    A non-positive ``frame_cap`` yields every frame in the file (the counter
    only stops iteration when it reaches exactly zero).
    """
    container = av.open(path)
    try:
        stream = next(s for s in container.streams if s.type == "video")
        remaining = frame_cap
        for frame in container.decode(stream):
            yield torch.tensor(frame.to_rgb().to_ndarray(), dtype=torch.uint8, device=device).unsqueeze(0)
            remaining -= 1
            if remaining == 0:
                break
    finally:
        container.close()
261
-
262
-
263
def encode_single_frame(output_file: str, image_array: np.ndarray, crf: float) -> None:
    """Encode a single RGB frame into an H.264 MP4 at the requested CRF."""
    container = av.open(output_file, "w", format="mp4")
    try:
        stream = container.add_stream("libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"})
        # Crop to even dimensions for compatibility with video codecs.
        even_height = image_array.shape[0] // 2 * 2
        even_width = image_array.shape[1] // 2 * 2
        cropped = image_array[:even_height, :even_width]
        stream.height = even_height
        stream.width = even_width
        av_frame = av.VideoFrame.from_ndarray(cropped, format="rgb24").reformat(format="yuv420p")
        container.mux(stream.encode(av_frame))
        container.mux(stream.encode())
    finally:
        container.close()
278
-
279
-
280
def decode_single_frame(video_file: str) -> np.ndarray:
    """Decode and return the first video frame of a file as an (H, W, 3) RGB array.

    Annotation fixed: ``np.array`` is a function, not a type; ``np.ndarray``
    is the correct annotation.
    """
    container = av.open(video_file)
    try:
        stream = next(s for s in container.streams if s.type == "video")
        frame = next(container.decode(stream))
    finally:
        container.close()
    return frame.to_ndarray(format="rgb24")
288
-
289
-
290
def preprocess(image: np.ndarray, crf: float = DEFAULT_IMAGE_CRF) -> np.ndarray:
    """Round-trip an image through single-frame H.264 compression at ``crf``.

    Presumably used so conditioning images carry codec statistics similar to
    decoded video — confirm with callers. A ``crf`` of 0 skips the round trip.
    (Annotations fixed: ``np.array`` is a function, not a type.)
    """
    if crf == 0:
        return image

    with BytesIO() as output_file:
        encode_single_frame(output_file, image, crf)
        video_bytes = output_file.getvalue()
    with BytesIO(video_bytes) as video_file:
        image_array = decode_single_frame(video_file)
    return image_array
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
packages/ltx-pipelines/ltx_pipelines/utils/model_ledger.py DELETED
@@ -1,275 +0,0 @@
1
- from dataclasses import replace
2
-
3
- import torch
4
-
5
- from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
6
- from ltx_core.loader.registry import DummyRegistry, Registry
7
- from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder
8
- from ltx_core.model.audio_vae import (
9
- AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
10
- VOCODER_COMFY_KEYS_FILTER,
11
- AudioDecoder,
12
- AudioDecoderConfigurator,
13
- Vocoder,
14
- VocoderConfigurator,
15
- )
16
- from ltx_core.model.transformer import (
17
- LTXV_MODEL_COMFY_RENAMING_MAP,
18
- LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP,
19
- UPCAST_DURING_INFERENCE,
20
- LTXModelConfigurator,
21
- X0Model,
22
- )
23
- from ltx_core.model.upsampler import LatentUpsampler, LatentUpsamplerConfigurator
24
- from ltx_core.model.video_vae import (
25
- VAE_DECODER_COMFY_KEYS_FILTER,
26
- VAE_ENCODER_COMFY_KEYS_FILTER,
27
- VideoDecoder,
28
- VideoDecoderConfigurator,
29
- VideoEncoder,
30
- VideoEncoderConfigurator,
31
- )
32
- from ltx_core.text_encoders.gemma import (
33
- AV_GEMMA_TEXT_ENCODER_KEY_OPS,
34
- AVGemmaTextEncoderModel,
35
- AVGemmaTextEncoderModelConfigurator,
36
- module_ops_from_gemma_root,
37
- )
38
-
39
- from ltx_core.model.audio_vae import (
40
- AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
41
- VOCODER_COMFY_KEYS_FILTER,
42
- AudioDecoder,
43
- AudioDecoderConfigurator,
44
- Vocoder,
45
- VocoderConfigurator,
46
- AudioEncoder,
47
- )
48
- from ltx_core.model.audio_vae.model_configurator import (
49
- AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
50
- AudioEncoderConfigurator,
51
- )
52
-
53
-
54
class ModelLedger:
    """
    Central coordinator for loading and building models used in an LTX pipeline.
    The ledger wires together multiple model builders (transformer, video VAE encoder/decoder,
    audio VAE encoder/decoder, vocoder, text encoder, and optional latent upsampler) and exposes
    factory methods for constructing model instances.
    ### Model Building
    Each model method (e.g. :meth:`transformer`, :meth:`video_decoder`, :meth:`text_encoder`)
    constructs a new model instance on each call. The builder uses the
    :class:`~ltx_core.loader.registry.Registry` to load weights from the checkpoint,
    instantiates the model with the configured ``dtype``, and moves it to ``self.device``.
    .. note::
        Models are **not cached**. Each call to a model method creates a new instance.
        Callers are responsible for storing references to models they wish to reuse
        and for freeing GPU memory (e.g. by deleting references and calling
        ``torch.cuda.empty_cache()``).
    ### Constructor parameters
    dtype:
        Torch dtype used when constructing all models (e.g. ``torch.bfloat16``).
    device:
        Target device to which models are moved after construction (e.g. ``torch.device("cuda")``).
    checkpoint_path:
        Path to a checkpoint directory or file containing the core model weights
        (transformer, video VAE, audio VAE, text encoder, vocoder). If ``None``, the
        corresponding builders are not created and calling those methods will raise
        a :class:`ValueError`.
    gemma_root_path:
        Base path to Gemma-compatible CLIP/text encoder weights. Required to
        initialize the text encoder builder; if omitted, :meth:`text_encoder` cannot be used.
    spatial_upsampler_path:
        Optional path to a latent upsampler checkpoint. If provided, the
        :meth:`spatial_upsampler` method becomes available; otherwise calling it raises
        a :class:`ValueError`.
    loras:
        Optional collection of LoRA configurations (paths, strengths, and key operations)
        that are applied on top of the base transformer weights when building the model.
    registry:
        Optional :class:`Registry` instance for weight caching across builders.
        Defaults to :class:`DummyRegistry` which performs no cross-builder caching.
    fp8transformer:
        If ``True``, builds the transformer with FP8 quantization and upcasting during inference.
    local_files_only:
        If ``True``, restricts the Gemma module resolution to local files (no downloads).
    ### Creating Variants
    Use :meth:`with_loras` to create a new ``ModelLedger`` instance that includes
    additional LoRA configurations while sharing the same registry for weight caching.
    """

    def __init__(
        self,
        dtype: torch.dtype,
        device: torch.device,
        checkpoint_path: str | None = None,
        gemma_root_path: str | None = None,
        spatial_upsampler_path: str | None = None,
        loras: LoraPathStrengthAndSDOps | None = None,
        registry: Registry | None = None,
        fp8transformer: bool = False,
        local_files_only: bool = True,
    ):
        self.dtype = dtype
        self.device = device
        self.checkpoint_path = checkpoint_path
        self.gemma_root_path = gemma_root_path
        self.spatial_upsampler_path = spatial_upsampler_path
        self.loras = loras or ()
        self.registry = registry or DummyRegistry()
        self.fp8transformer = fp8transformer
        self.local_files_only = local_files_only
        self.build_model_builders()

    def build_model_builders(self) -> None:
        """Create the per-model builders for whichever paths were provided."""
        if self.checkpoint_path is not None:
            self.transformer_builder = Builder(
                model_path=self.checkpoint_path,
                model_class_configurator=LTXModelConfigurator,
                model_sd_ops=LTXV_MODEL_COMFY_RENAMING_MAP,
                loras=tuple(self.loras),
                registry=self.registry,
            )

            self.vae_decoder_builder = Builder(
                model_path=self.checkpoint_path,
                model_class_configurator=VideoDecoderConfigurator,
                model_sd_ops=VAE_DECODER_COMFY_KEYS_FILTER,
                registry=self.registry,
            )

            self.vae_encoder_builder = Builder(
                model_path=self.checkpoint_path,
                model_class_configurator=VideoEncoderConfigurator,
                model_sd_ops=VAE_ENCODER_COMFY_KEYS_FILTER,
                registry=self.registry,
            )

            self.audio_decoder_builder = Builder(
                model_path=self.checkpoint_path,
                model_class_configurator=AudioDecoderConfigurator,
                model_sd_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
                registry=self.registry,
            )

            self.vocoder_builder = Builder(
                model_path=self.checkpoint_path,
                model_class_configurator=VocoderConfigurator,
                model_sd_ops=VOCODER_COMFY_KEYS_FILTER,
                registry=self.registry,
            )

            self.audio_encoder_builder = Builder(
                model_path=self.checkpoint_path,
                model_class_configurator=AudioEncoderConfigurator,
                model_sd_ops=AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
                registry=self.registry,
            )

        if self.gemma_root_path is not None:
            self.text_encoder_builder = Builder(
                model_path=self.checkpoint_path,
                model_class_configurator=AVGemmaTextEncoderModelConfigurator,
                model_sd_ops=AV_GEMMA_TEXT_ENCODER_KEY_OPS,
                registry=self.registry,
                module_ops=module_ops_from_gemma_root(self.gemma_root_path, self.local_files_only),
            )

        if self.spatial_upsampler_path is not None:
            self.upsampler_builder = Builder(
                model_path=self.spatial_upsampler_path,
                model_class_configurator=LatentUpsamplerConfigurator,
                registry=self.registry,
            )

    def _target_device(self) -> torch.device:
        # With a real registry, build on CPU first (presumably so cached
        # weights remain host-resident and shareable — confirm); models are
        # moved to self.device by the factory methods afterwards.
        if isinstance(self.registry, DummyRegistry) or self.registry is None:
            return self.device
        else:
            return torch.device("cpu")

    def with_loras(self, loras: LoraPathStrengthAndSDOps) -> "ModelLedger":
        """Return a new ledger with the given LoRAs appended, sharing the registry."""
        return ModelLedger(
            dtype=self.dtype,
            device=self.device,
            checkpoint_path=self.checkpoint_path,
            gemma_root_path=self.gemma_root_path,
            spatial_upsampler_path=self.spatial_upsampler_path,
            loras=(*self.loras, *loras),
            registry=self.registry,
            fp8transformer=self.fp8transformer,
            # Bug fix: local_files_only was previously dropped here, silently
            # resetting the derived ledger to the default (True).
            local_files_only=self.local_files_only,
        )

    def transformer(self) -> X0Model:
        """Build the (optionally FP8-quantized) transformer wrapped in X0Model."""
        if not hasattr(self, "transformer_builder"):
            raise ValueError(
                "Transformer not initialized. Please provide a checkpoint path to the ModelLedger constructor."
            )
        if self.fp8transformer:
            # FP8 path: downcast linear weights at load time, upcast during inference.
            fp8_builder = replace(
                self.transformer_builder,
                module_ops=(UPCAST_DURING_INFERENCE,),
                model_sd_ops=LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP,
            )
            return X0Model(fp8_builder.build(device=self._target_device())).to(self.device).eval()
        else:
            return (
                X0Model(self.transformer_builder.build(device=self._target_device(), dtype=self.dtype))
                .to(self.device)
                .eval()
            )

    def audio_encoder(self) -> AudioEncoder:
        """Build the audio VAE encoder."""
        if not hasattr(self, "audio_encoder_builder"):
            raise ValueError(
                "Audio encoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
            )
        return self.audio_encoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()

    def video_decoder(self) -> VideoDecoder:
        """Build the video VAE decoder."""
        if not hasattr(self, "vae_decoder_builder"):
            raise ValueError(
                "Video decoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
            )

        return self.vae_decoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()

    def video_encoder(self) -> VideoEncoder:
        """Build the video VAE encoder."""
        if not hasattr(self, "vae_encoder_builder"):
            raise ValueError(
                "Video encoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
            )

        return self.vae_encoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()

    def text_encoder(self) -> AVGemmaTextEncoderModel:
        """Build the Gemma-based text encoder."""
        if not hasattr(self, "text_encoder_builder"):
            raise ValueError(
                "Text encoder not initialized. Please provide a checkpoint path and gemma root path to the "
                "ModelLedger constructor."
            )

        return self.text_encoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()

    def audio_decoder(self) -> AudioDecoder:
        """Build the audio VAE decoder."""
        if not hasattr(self, "audio_decoder_builder"):
            raise ValueError(
                "Audio decoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
            )

        return self.audio_decoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()

    def vocoder(self) -> Vocoder:
        """Build the vocoder."""
        if not hasattr(self, "vocoder_builder"):
            raise ValueError(
                "Vocoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
            )

        return self.vocoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()

    def spatial_upsampler(self) -> LatentUpsampler:
        """Build the latent spatial upsampler."""
        if not hasattr(self, "upsampler_builder"):
            raise ValueError("Upsampler not initialized. Please provide upsampler path to the ModelLedger constructor.")

        return self.upsampler_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
packages/ltx-pipelines/ltx_pipelines/utils/types.py DELETED
@@ -1,73 +0,0 @@
1
- from typing import Protocol
2
-
3
- import torch
4
-
5
- from ltx_core.components.patchifiers import AudioPatchifier, VideoLatentPatchifier
6
- from ltx_core.components.protocols import DiffusionStepProtocol
7
- from ltx_core.types import LatentState
8
- from ltx_pipelines.utils.constants import VIDEO_LATENT_CHANNELS, VIDEO_SCALE_FACTORS
9
-
10
-
11
class PipelineComponents:
    """
    Container class for pipeline components used throughout the LTX pipelines.
    Attributes:
        dtype (torch.dtype): Default torch dtype for tensors in the pipeline.
        device (torch.device): Target device to place tensors and modules on.
        video_scale_factors: Scale factors (T, H, W) for the VAE latent space.
        video_latent_channels (int): Number of channels in the video latent representation.
        video_patchifier (VideoLatentPatchifier): Patchifier instance for video latents.
        audio_patchifier (AudioPatchifier): Patchifier instance for audio latents.
    """

    def __init__(self, dtype: torch.dtype, device: torch.device):
        self.dtype = dtype
        self.device = device

        # Latent-space geometry constants shared by all pipelines.
        self.video_scale_factors = VIDEO_SCALE_FACTORS
        self.video_latent_channels = VIDEO_LATENT_CHANNELS

        self.video_patchifier = VideoLatentPatchifier(patch_size=1)
        self.audio_patchifier = AudioPatchifier(patch_size=1)
36
-
37
-
38
class DenoisingFunc(Protocol):
    """
    Protocol for a single denoising call in the LTX pipeline.
    Args:
        video_state (LatentState): Current latent state for video.
        audio_state (LatentState): Current latent state for audio.
        sigmas (torch.Tensor): 1D tensor of sigma values, one per diffusion step.
        step_index (int): Index of the current denoising step.
    Returns:
        tuple[torch.Tensor, torch.Tensor]: The denoised video and audio tensors.
    """

    def __call__(
        self, video_state: LatentState, audio_state: LatentState, sigmas: torch.Tensor, step_index: int
    ) -> tuple[torch.Tensor, torch.Tensor]: ...
53
-
54
-
55
class DenoisingLoopFunc(Protocol):
    """
    Protocol for a denoising loop function used in the LTX pipeline.
    Args:
        sigmas (torch.Tensor): A 1D tensor of sigma values for each diffusion step.
        video_state (LatentState): The current latent state for video.
        audio_state (LatentState): The current latent state for audio.
        stepper (DiffusionStepProtocol): The diffusion step protocol to use.
    Returns:
        tuple[LatentState, LatentState]: The denoised video and audio latent states,
        or ``(None, None)`` on early termination.
    """

    # Return annotation improved: the original declared tuple[torch.Tensor,
    # torch.Tensor], contradicting its own docstring; callers treat the results
    # as LatentState objects and explicitly handle None.
    def __call__(
        self,
        sigmas: torch.Tensor,
        video_state: LatentState,
        audio_state: LatentState,
        stepper: DiffusionStepProtocol,
    ) -> tuple[LatentState | None, LatentState | None]: ...