Fabrice-TIERCELIN committed on
Commit
de02dc3
·
verified ·
1 Parent(s): 4f5fe16

Upload 6 files

Browse files
packages/ltx-pipelines/src/ltx_pipelines/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
LTX-2 Pipelines: High-level video generation pipelines and utilities.

This package provides ready-to-use pipelines for video generation:

- TI2VidOneStagePipeline: Text/image-to-video in a single stage
- TI2VidTwoStagesPipeline: Two-stage generation with upsampling
- DistilledPipeline: Fast distilled two-stage generation
- ICLoraPipeline: Image/video conditioning with distilled LoRA
- KeyframeInterpolationPipeline: Keyframe-based video interpolation

Note: ModelLedger (the central coordinator for loading and building models) is
not re-exported here; import it from `ltx_pipelines.utils`.

For more detailed components and utilities, import from specific submodules
like `ltx_pipelines.utils.media_io` or `ltx_pipelines.utils.constants`.
"""

# Re-export the public pipeline classes at package level.
from ltx_pipelines.distilled import DistilledPipeline
from ltx_pipelines.ic_lora import ICLoraPipeline
from ltx_pipelines.keyframe_interpolation import KeyframeInterpolationPipeline
from ltx_pipelines.ti2vid_one_stage import TI2VidOneStagePipeline
from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline

# Explicit public API of the package.
__all__ = [
    "DistilledPipeline",
    "ICLoraPipeline",
    "KeyframeInterpolationPipeline",
    "TI2VidOneStagePipeline",
    "TI2VidTwoStagesPipeline",
]
packages/ltx-pipelines/src/ltx_pipelines/distilled.py ADDED
@@ -0,0 +1,475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Copyright (c) 2025 Lightricks. All rights reserved.
# Created by Amit Pintz.

# Imports consolidated into a single PEP 8-ordered block: the original file
# imported torchaudio/AudioProcessor/AudioLatentShape after the first import
# group and imported VideoPixelShape twice (once alone, once together with
# AudioLatentShape). All names are kept; only the duplicate is merged.

import torch
import torchaudio

from ltx_core.components.diffusion_steps import EulerDiffusionStep
from ltx_core.components.noisers import GaussianNoiser
from ltx_core.components.protocols import DiffusionStepProtocol
from ltx_core.conditioning import ConditioningError, ConditioningItem, VideoConditionByKeyframeIndex
from ltx_core.loader import LoraPathStrengthAndSDOps
from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
from ltx_core.model.audio_vae import AudioProcessor
from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
from ltx_core.model.upsampler import upsample_video
from ltx_core.model.video_vae import TilingConfig, VideoEncoder, get_video_chunks_number
from ltx_core.model.video_vae import decode_video as vae_decode_video
from ltx_core.text_encoders.gemma import encode_text
from ltx_core.tools import LatentTools
from ltx_core.types import AudioLatentShape, LatentState, VideoPixelShape
from ltx_pipelines import utils
from ltx_pipelines.utils import ModelLedger
from ltx_pipelines.utils.args import default_2_stage_distilled_arg_parser
from ltx_pipelines.utils.constants import (
    AUDIO_SAMPLE_RATE,
    DEFAULT_LORA_STRENGTH,
    DISTILLED_SIGMA_VALUES,
    STAGE_2_DISTILLED_SIGMA_VALUES,
)
from ltx_pipelines.utils.helpers import (
    assert_resolution,
    cleanup_memory,
    denoise_audio_video,
    euler_denoising_loop,
    generate_enhanced_prompt,
    get_device,
    image_conditionings_by_replacing_latent,
    simple_denoising_func,
)
from ltx_pipelines.utils.media_io import encode_video, load_video_conditioning
from ltx_pipelines.utils.types import PipelineComponents
class AudioConditionByLatent(ConditioningItem):
    """
    Conditioning item that injects a pre-computed audio latent sequence.

    The provided latent is patchified and written over the leading tokens of
    the latent state; the denoise mask for those tokens is set to
    ``1.0 - strength`` so that higher strengths keep more of the injected audio.
    """

    def __init__(self, latent: torch.Tensor, strength: float):
        # latent: (batch, channels, frames, mel_bins) audio latent to inject.
        self.latent = latent
        # strength in [0, 1]; controls how strongly the latent is preserved.
        self.strength = strength

    def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState:
        """Return a copy of `latent_state` with the audio latent spliced in."""
        # Audio conditioning only makes sense against an audio latent target.
        if not isinstance(latent_tools.target_shape, AudioLatentShape):
            raise ConditioningError("Audio conditioning requires an audio latent target shape.")

        cond_batch, cond_channels, cond_frames, cond_bins = self.latent.shape
        tgt_batch, tgt_channels, tgt_frames, tgt_bins = latent_tools.target_shape.to_torch_shape()

        # The conditioning latent must match the target shape exactly.
        if (cond_batch, cond_channels, cond_frames, cond_bins) != (tgt_batch, tgt_channels, tgt_frames, tgt_bins):
            raise ConditioningError(
                f"Can't apply audio conditioning item to latent with shape {latent_tools.target_shape}, expected "
                f"shape is ({tgt_batch}, {tgt_channels}, {tgt_frames}, {tgt_bins})."
            )

        # Patchify and overwrite the leading token span of a cloned state.
        tokens = latent_tools.patchifier.patchify(self.latent)
        token_count = tokens.shape[1]

        updated_state = latent_state.clone()
        updated_state.latent[:, :token_count] = tokens
        updated_state.clean_latent[:, :token_count] = tokens
        updated_state.denoise_mask[:, :token_count] = 1.0 - self.strength

        return updated_state
# Module-level default compute device, resolved once at import time; used as
# the default for DistilledPipeline's `device` parameter.
device = get_device()
class DistilledPipeline:
    """
    Fast two-stage distilled audio+video generation pipeline.

    Stage 1 denoises audio and video latents at the requested resolution using
    the distilled sigma schedule; Stage 2 spatially upsamples the video latent
    by 2x and refines both latents with a short distilled schedule. The decoded
    result is written to disk as a video file with an audio track.

    Generation can optionally be conditioned on reference images, conditioning
    videos, and/or a reference audio waveform.
    """

    def __init__(
        self,
        checkpoint_path: str,
        gemma_root: str,
        spatial_upsampler_path: str,
        loras: list[LoraPathStrengthAndSDOps],
        device: torch.device = device,
        fp8transformer: bool = False,
        local_files_only: bool = True,
    ):
        """
        Args:
            checkpoint_path: Path to the main model checkpoint.
            gemma_root: Root directory of the Gemma text encoder.
            spatial_upsampler_path: Path to the spatial upsampler weights.
            loras: LoRAs (path, strength, state-dict ops) to apply.
            device: Compute device; defaults to the module-level detected one.
            fp8transformer: Load the transformer with FP8 weights when True.
            local_files_only: Forbid network downloads when True.
        """
        self.device = device
        self.dtype = torch.bfloat16

        self.model_ledger = ModelLedger(
            dtype=self.dtype,
            device=device,
            checkpoint_path=checkpoint_path,
            spatial_upsampler_path=spatial_upsampler_path,
            gemma_root_path=gemma_root,
            loras=loras,
            fp8transformer=fp8transformer,
            local_files_only=local_files_only,
        )

        self.pipeline_components = PipelineComponents(
            dtype=self.dtype,
            device=device,
        )

        # Models cached on the instance so repeated calls skip reloading.
        self._video_encoder = None
        self._transformer = None

    def _build_audio_conditionings_from_waveform(
        self,
        input_waveform: torch.Tensor,
        input_sample_rate: int,
        num_frames: int,
        fps: float,
        strength: float,
    ) -> list[AudioConditionByLatent] | None:
        """
        Encode a reference waveform into an audio-latent conditioning item.

        Args:
            input_waveform: Waveform shaped (T,), (C, T) or (B, C, T).
            input_sample_rate: Sample rate of `input_waveform` in Hz.
            num_frames: Number of video frames being generated.
            fps: Video frame rate; with num_frames it fixes the latent length.
            strength: Conditioning strength; <= 0 disables audio conditioning.

        Returns:
            A single-item list with the conditioning object, or None when
            strength is non-positive.

        Raises:
            ValueError: If the waveform has more than 3 dimensions.
        """
        strength = float(strength)
        if strength <= 0.0:
            return None

        # Normalize waveform layout to (B, C, T).
        waveform = input_waveform
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0).unsqueeze(0)
        elif waveform.ndim == 2:
            waveform = waveform.unsqueeze(0)
        elif waveform.ndim != 3:
            raise ValueError(f"input_waveform must be 1D/2D/3D, got shape {tuple(waveform.shape)}")

        # Audio encoder plus its preprocessing configuration.
        # NOTE(review): assumes the ledger exposes `audio_encoder()` and that the
        # encoder carries sample_rate / mel_* / n_fft attributes — confirm.
        audio_encoder = self.model_ledger.audio_encoder()
        target_sr = int(getattr(audio_encoder, "sample_rate"))
        target_channels = int(getattr(audio_encoder, "in_channels", waveform.shape[1]))
        mel_bins = int(getattr(audio_encoder, "mel_bins"))
        mel_hop = int(getattr(audio_encoder, "mel_hop_length"))
        n_fft = int(getattr(audio_encoder, "n_fft"))

        # Match the encoder's expected channel count.
        if waveform.shape[1] != target_channels:
            if waveform.shape[1] == 1 and target_channels > 1:
                waveform = waveform.repeat(1, target_channels, 1)
            elif target_channels == 1:
                waveform = waveform.mean(dim=1, keepdim=True)
            else:
                waveform = waveform[:, :target_channels, :]
                if waveform.shape[1] < target_channels:
                    pad_ch = target_channels - waveform.shape[1]
                    # BUGFIX: create the padding on the waveform's own device so
                    # torch.cat cannot mix CPU and GPU tensors for GPU inputs.
                    pad = torch.zeros(
                        (waveform.shape[0], pad_ch, waveform.shape[2]),
                        dtype=waveform.dtype,
                        device=waveform.device,
                    )
                    waveform = torch.cat([waveform, pad], dim=1)

        # Resample if needed (CPU float32 is safest for torchaudio).
        waveform = waveform.to(device="cpu", dtype=torch.float32)
        if int(input_sample_rate) != target_sr:
            waveform = torchaudio.functional.resample(waveform, int(input_sample_rate), target_sr)

        # Waveform -> mel spectrogram.
        audio_processor = AudioProcessor(
            sample_rate=target_sr,
            mel_bins=mel_bins,
            mel_hop_length=mel_hop,
            n_fft=n_fft,
        ).to(waveform.device)
        mel = audio_processor.waveform_to_mel(waveform, target_sr)

        # Mel -> latent, run on the encoder's own device/dtype.
        audio_params = next(audio_encoder.parameters(), None)
        enc_device = audio_params.device if audio_params is not None else self.device
        enc_dtype = audio_params.dtype if audio_params is not None else self.dtype
        mel = mel.to(device=enc_device, dtype=enc_dtype)
        with torch.inference_mode():
            audio_latent = audio_encoder(mel)

        # Pad/trim the latent along the frame axis to match the video duration.
        audio_downsample = getattr(getattr(audio_encoder, "patchifier", None), "audio_latent_downsample_factor", 4)
        target_shape = AudioLatentShape.from_video_pixel_shape(
            VideoPixelShape(batch=audio_latent.shape[0], frames=int(num_frames), width=1, height=1, fps=float(fps)),
            channels=audio_latent.shape[1],
            mel_bins=audio_latent.shape[3],
            sample_rate=target_sr,
            hop_length=mel_hop,
            audio_latent_downsample_factor=audio_downsample,
        )
        target_frames = int(target_shape.frames)

        if audio_latent.shape[2] < target_frames:
            pad_frames = target_frames - audio_latent.shape[2]
            pad = torch.zeros(
                (audio_latent.shape[0], audio_latent.shape[1], pad_frames, audio_latent.shape[3]),
                device=audio_latent.device,
                dtype=audio_latent.dtype,
            )
            audio_latent = torch.cat([audio_latent, pad], dim=2)
        elif audio_latent.shape[2] > target_frames:
            audio_latent = audio_latent[:, :, :target_frames, :]

        # Move the latent to the pipeline device/dtype for the conditioning object.
        audio_latent = audio_latent.to(device=self.device, dtype=self.dtype)
        return [AudioConditionByLatent(audio_latent, strength)]

    def _prepare_output_waveform(
        self,
        input_waveform: torch.Tensor,
        input_sample_rate: int,
        target_sample_rate: int,
        num_frames: int,
        fps: float,
    ) -> torch.Tensor:
        """
        Return a waveform on CPU, float32, resampled to target_sample_rate and
        trimmed/padded to match the video duration.

        Output shape: (T,) for mono or (C, T) for multi-channel.

        Raises:
            ValueError: If the waveform has more than 3 dimensions.
        """
        wav = input_waveform

        # Accept (T,), (C,T), (B,C,T); reduce everything to (C, T).
        if wav.ndim == 3:
            wav = wav[0]  # drop the batch dimension, keep the first item
        elif wav.ndim == 2:
            pass
        elif wav.ndim == 1:
            wav = wav.unsqueeze(0)
        else:
            raise ValueError(f"input_waveform must be 1D/2D/3D, got {tuple(wav.shape)}")

        wav = wav.detach().to("cpu", dtype=torch.float32)

        # Resample if needed.
        if int(input_sample_rate) != int(target_sample_rate):
            wav = torchaudio.functional.resample(wav, int(input_sample_rate), int(target_sample_rate))

        # Trim/zero-pad to exactly the video's duration.
        duration_sec = float(num_frames) / float(fps)
        target_len = int(round(duration_sec * float(target_sample_rate)))

        cur_len = int(wav.shape[-1])
        if cur_len > target_len:
            wav = wav[..., :target_len]
        elif cur_len < target_len:
            wav = torch.nn.functional.pad(wav, (0, target_len - cur_len))

        # If mono, return (T,) for convenience.
        if wav.shape[0] == 1:
            return wav[0]
        return wav

    @torch.inference_mode()
    def __call__(
        self,
        prompt: str,
        output_path: str,
        seed: int,
        height: int,
        width: int,
        num_frames: int,
        frame_rate: float,
        images: list[tuple[str, int, float]],
        video_conditioning: list[tuple[str, float]] | None = None,
        video_conditioning_frame_idx: int = 0,
        apply_video_conditioning_to_stage2: bool = False,
        tiling_config: TilingConfig | None = None,
        video_context: torch.Tensor | None = None,
        audio_context: torch.Tensor | None = None,
        input_waveform: torch.Tensor | None = None,
        input_waveform_sample_rate: int | None = None,
        audio_strength: float = 1.0,
    ) -> None:
        """
        Generate an audio+video clip and write it to `output_path`.

        Args:
            prompt: Text prompt describing the desired video.
            output_path: Destination file path for the encoded result.
            seed: RNG seed for reproducible generation.
            height: Stage-1 video height in pixels (stage 2 doubles it).
            width: Stage-1 video width in pixels (stage 2 doubles it).
            num_frames: Number of video frames to generate.
            frame_rate: Output frame rate in frames per second.
            images: (path, frame_index, strength) image conditioning tuples.
            video_conditioning: Optional (path, strength) control-video tuples.
            video_conditioning_frame_idx: Keyframe index for video conditioning.
            apply_video_conditioning_to_stage2: Re-apply control videos in stage 2.
            tiling_config: Optional VAE tiling configuration for decoding.
            video_context: Optional pre-computed video text embeddings.
            audio_context: Optional pre-computed audio text embeddings.
            input_waveform: Optional reference waveform for audio conditioning.
            input_waveform_sample_rate: Sample rate of `input_waveform` (required
                when `input_waveform` is given).
            audio_strength: Strength of the waveform conditioning; <= 0 disables.

        Raises:
            ValueError: If `input_waveform` is given without its sample rate.
        """
        generator = torch.Generator(device=self.device).manual_seed(seed)
        noiser = GaussianNoiser(generator=generator)
        stepper = EulerDiffusionStep()
        dtype = torch.bfloat16

        # Optional audio conditioning from a reference waveform.
        audio_conditionings = None
        if input_waveform is not None:
            if input_waveform_sample_rate is None:
                raise ValueError("input_waveform_sample_rate must be provided when input_waveform is set.")
            audio_conditionings = self._build_audio_conditionings_from_waveform(
                input_waveform=input_waveform,
                input_sample_rate=int(input_waveform_sample_rate),
                num_frames=num_frames,
                fps=frame_rate,
                strength=audio_strength,
            )

        # Use pre-computed embeddings if provided, otherwise encode the prompt.
        if video_context is None or audio_context is None:
            text_encoder = self.model_ledger.text_encoder()
            context_p = encode_text(text_encoder, prompts=[prompt])[0]
            video_context, audio_context = context_p

            # Free the text encoder before loading the heavy diffusion models.
            torch.cuda.synchronize()
            del text_encoder
            utils.cleanup_memory()
        else:
            # Move pre-computed embeddings to the pipeline device if needed.
            video_context = video_context.to(self.device)
            audio_context = audio_context.to(self.device)

        # Stage 1: initial generation at the requested resolution.
        # Load models only if not already cached on the instance.
        if self._video_encoder is None:
            self._video_encoder = self.model_ledger.video_encoder()
        video_encoder = self._video_encoder

        if self._transformer is None:
            self._transformer = self.model_ledger.transformer()
        transformer = self._transformer
        stage_1_sigmas = torch.Tensor(DISTILLED_SIGMA_VALUES).to(self.device)

        def denoising_loop(
            sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
        ) -> tuple[LatentState, LatentState]:
            return euler_denoising_loop(
                sigmas=sigmas,
                video_state=video_state,
                audio_state=audio_state,
                stepper=stepper,
                denoise_fn=simple_denoising_func(
                    video_context=video_context,
                    audio_context=audio_context,
                    transformer=transformer,
                ),
            )

        stage_1_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
        stage_1_conditionings = self._create_conditionings(
            images=images,
            video_conditioning=video_conditioning,
            height=stage_1_output_shape.height,
            width=stage_1_output_shape.width,
            num_frames=num_frames,
            video_encoder=video_encoder,
            video_conditioning_frame_idx=video_conditioning_frame_idx,
            dtype=dtype,
        )

        video_state, audio_state = denoise_audio_video(
            output_shape=stage_1_output_shape,
            conditionings=stage_1_conditionings,
            audio_conditionings=audio_conditionings,
            noiser=noiser,
            sigmas=stage_1_sigmas,
            stepper=stepper,
            denoising_loop_fn=denoising_loop,
            components=self.pipeline_components,
            dtype=dtype,
            device=self.device,
        )

        # Stage 2: upsample the video latent 2x and refine at higher resolution.
        upscaled_video_latent = upsample_video(
            latent=video_state.latent[:1], video_encoder=video_encoder, upsampler=self.model_ledger.spatial_upsampler()
        )

        torch.cuda.synchronize()
        cleanup_memory()

        stage_2_sigmas = torch.Tensor(STAGE_2_DISTILLED_SIGMA_VALUES).to(self.device)
        stage_2_output_shape = VideoPixelShape(
            batch=1, frames=num_frames, width=width * 2, height=height * 2, fps=frame_rate
        )
        if apply_video_conditioning_to_stage2:
            stage_2_conditionings = self._create_conditionings(
                images=images,
                video_conditioning=video_conditioning,
                height=stage_2_output_shape.height,
                width=stage_2_output_shape.width,
                num_frames=num_frames,
                video_encoder=video_encoder,
                video_conditioning_frame_idx=video_conditioning_frame_idx,
                dtype=dtype,  # BUGFIX: dtype is a required parameter; it was missing here
            )
        else:
            stage_2_conditionings = image_conditionings_by_replacing_latent(
                images=images,
                height=stage_2_output_shape.height,
                width=stage_2_output_shape.width,
                video_encoder=video_encoder,
                dtype=dtype,
                device=self.device,
            )
        video_chunks_number = get_video_chunks_number(num_frames, tiling_config)
        video_state, audio_state = denoise_audio_video(
            output_shape=stage_2_output_shape,
            conditionings=stage_2_conditionings,
            audio_conditionings=audio_conditionings,
            noiser=noiser,
            sigmas=stage_2_sigmas,
            stepper=stepper,
            denoising_loop_fn=denoising_loop,
            components=self.pipeline_components,
            dtype=dtype,
            device=self.device,
            noise_scale=stage_2_sigmas[0],
            initial_video_latent=upscaled_video_latent,
            initial_audio_latent=audio_state.latent,
        )

        torch.cuda.synchronize()
        # NOTE: transformer and video_encoder intentionally stay cached on the
        # instance (self._transformer / self._video_encoder) for reuse.

        decoded_video = vae_decode_video(video_state.latent, self.model_ledger.video_decoder(), tiling_config)
        decoded_audio = vae_decode_audio(
            audio_state.latent, self.model_ledger.audio_decoder(), self.model_ledger.vocoder()
        )

        encode_video(
            video=decoded_video,
            fps=frame_rate,
            audio=decoded_audio,
            audio_sample_rate=AUDIO_SAMPLE_RATE,
            output_path=output_path,
            video_chunks_number=video_chunks_number,
        )

    def _create_conditionings(
        self,
        images: list[tuple[str, int, float]],
        video_conditioning: list[tuple[str, float]] | None,
        height: int,
        width: int,
        num_frames: int,
        video_encoder,
        video_conditioning_frame_idx: int,
        dtype: torch.dtype,
    ):
        """
        Build conditioning items for one denoising stage.

        Image conditionings are applied by replacing latent tokens; optional
        control videos (IC-LoRA style) are appended as keyframe-indexed
        conditioning items.
        """
        conditionings = image_conditionings_by_replacing_latent(
            images=images,
            height=height,
            width=width,
            video_encoder=video_encoder,
            dtype=dtype,
            device=self.device,
        )

        if not video_conditioning:
            return conditionings

        for video_path, strength in video_conditioning:
            video = load_video_conditioning(
                video_path=video_path,
                height=height,
                width=width,
                frame_cap=num_frames,
                dtype=dtype,
                device=self.device,
            )
            encoded_video = video_encoder(video)
            conditionings.append(
                VideoConditionByKeyframeIndex(
                    keyframes=encoded_video,
                    frame_idx=video_conditioning_frame_idx,
                    strength=strength,
                )
            )

        return conditionings
packages/ltx-pipelines/src/ltx_pipelines/ic_lora.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from collections.abc import Iterator
3
+
4
+ import torch
5
+
6
+ from ltx_core.components.diffusion_steps import EulerDiffusionStep
7
+ from ltx_core.components.noisers import GaussianNoiser
8
+ from ltx_core.components.protocols import DiffusionStepProtocol
9
+ from ltx_core.conditioning import ConditioningItem, VideoConditionByKeyframeIndex
10
+ from ltx_core.loader import LoraPathStrengthAndSDOps
11
+ from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
12
+ from ltx_core.model.upsampler import upsample_video
13
+ from ltx_core.model.video_vae import TilingConfig, VideoEncoder, get_video_chunks_number
14
+ from ltx_core.model.video_vae import decode_video as vae_decode_video
15
+ from ltx_core.text_encoders.gemma import encode_text
16
+ from ltx_core.types import LatentState, VideoPixelShape
17
+ from ltx_pipelines.utils import ModelLedger
18
+ from ltx_pipelines.utils.args import VideoConditioningAction, default_2_stage_distilled_arg_parser
19
+ from ltx_pipelines.utils.constants import (
20
+ AUDIO_SAMPLE_RATE,
21
+ DISTILLED_SIGMA_VALUES,
22
+ STAGE_2_DISTILLED_SIGMA_VALUES,
23
+ )
24
+ from ltx_pipelines.utils.helpers import (
25
+ assert_resolution,
26
+ cleanup_memory,
27
+ denoise_audio_video,
28
+ euler_denoising_loop,
29
+ generate_enhanced_prompt,
30
+ get_device,
31
+ image_conditionings_by_replacing_latent,
32
+ simple_denoising_func,
33
+ )
34
+ from ltx_pipelines.utils.media_io import encode_video, load_video_conditioning
35
+ from ltx_pipelines.utils.types import PipelineComponents
36
+
# Module-level default compute device, resolved once at import time; used as
# the default for ICLoraPipeline's `device` parameter.
device = get_device()
class ICLoraPipeline:
    """
    Two-stage video generation pipeline with In-Context (IC) LoRA support.
    Allows conditioning the generated video on control signals such as depth maps,
    human pose, or image edges via the video_conditioning parameter.
    The specific IC-LoRA model should be provided via the loras parameter.
    Stage 1 generates video at half the target resolution, then Stage 2 upsamples
    by 2x and refines with additional denoising steps for higher quality output.
    """

    def __init__(
        self,
        checkpoint_path: str,
        spatial_upsampler_path: str,
        gemma_root: str,
        loras: list[LoraPathStrengthAndSDOps],
        device: torch.device = device,
        fp8transformer: bool = False,
    ):
        """
        Args:
            checkpoint_path: Path to the main model checkpoint.
            spatial_upsampler_path: Path to the spatial upsampler weights.
            gemma_root: Root directory of the Gemma text encoder.
            loras: IC-LoRA weights; applied to the stage-1 ledger only.
            device: Compute device; defaults to the module-level detected one.
            fp8transformer: Load the transformer with FP8 weights when True.
        """
        self.dtype = torch.bfloat16
        # Stage 1 carries the IC-LoRA weights.
        self.stage_1_model_ledger = ModelLedger(
            dtype=self.dtype,
            device=device,
            checkpoint_path=checkpoint_path,
            spatial_upsampler_path=spatial_upsampler_path,
            gemma_root_path=gemma_root,
            loras=loras,
            fp8transformer=fp8transformer,
        )
        # Stage 2 deliberately loads the base model with no LoRAs.
        self.stage_2_model_ledger = ModelLedger(
            dtype=self.dtype,
            device=device,
            checkpoint_path=checkpoint_path,
            spatial_upsampler_path=spatial_upsampler_path,
            gemma_root_path=gemma_root,
            loras=[],
            fp8transformer=fp8transformer,
        )
        self.pipeline_components = PipelineComponents(
            dtype=self.dtype,
            device=device,
        )
        self.device = device

    @torch.inference_mode()
    def __call__(
        self,
        prompt: str,
        seed: int,
        height: int,
        width: int,
        num_frames: int,
        frame_rate: float,
        images: list[tuple[str, int, float]],
        video_conditioning: list[tuple[str, float]],
        enhance_prompt: bool = False,
        tiling_config: TilingConfig | None = None,
    ) -> tuple[Iterator[torch.Tensor], torch.Tensor]:
        """
        Generate a video conditioned on control videos (and optional images).

        Args:
            prompt: Text prompt describing the desired video.
            seed: RNG seed for reproducible generation.
            height: Output height in pixels (stage 1 runs at height // 2).
            width: Output width in pixels (stage 1 runs at width // 2).
            num_frames: Number of video frames to generate.
            frame_rate: Output frame rate in frames per second.
            images: (path, frame_index, strength) image conditioning tuples.
            video_conditioning: (path, strength) control-video tuples.
            enhance_prompt: Rewrite the prompt via the text encoder when True.
            tiling_config: Optional VAE tiling configuration for decoding.

        Returns:
            Tuple of (decoded video, decoded audio).
        """
        assert_resolution(height=height, width=width, is_two_stage=True)

        generator = torch.Generator(device=self.device).manual_seed(seed)
        noiser = GaussianNoiser(generator=generator)
        stepper = EulerDiffusionStep()
        dtype = torch.bfloat16

        text_encoder = self.stage_1_model_ledger.text_encoder()

        if enhance_prompt:
            # Uses the first conditioning image (if any) as extra context.
            prompt = generate_enhanced_prompt(
                text_encoder, prompt, images[0][0] if len(images) > 0 else None, seed=seed
            )
        video_context, audio_context = encode_text(text_encoder, prompts=[prompt])[0]

        # The text encoder is no longer needed; free it before loading the
        # heavy stage-1 models.
        torch.cuda.synchronize()
        del text_encoder
        cleanup_memory()

        # Stage 1: Initial low resolution video generation.
        video_encoder = self.stage_1_model_ledger.video_encoder()
        transformer = self.stage_1_model_ledger.transformer()
        stage_1_sigmas = torch.Tensor(DISTILLED_SIGMA_VALUES).to(self.device)

        def first_stage_denoising_loop(
            sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
        ) -> tuple[LatentState, LatentState]:
            return euler_denoising_loop(
                sigmas=sigmas,
                video_state=video_state,
                audio_state=audio_state,
                stepper=stepper,
                denoise_fn=simple_denoising_func(
                    video_context=video_context,
                    audio_context=audio_context,
                    transformer=transformer,  # noqa: F821
                ),
            )

        # Stage 1 runs at half the requested resolution; stage 2 upsamples 2x.
        stage_1_output_shape = VideoPixelShape(
            batch=1,
            frames=num_frames,
            width=width // 2,
            height=height // 2,
            fps=frame_rate,
        )
        stage_1_conditionings = self._create_conditionings(
            images=images,
            video_conditioning=video_conditioning,
            height=stage_1_output_shape.height,
            width=stage_1_output_shape.width,
            video_encoder=video_encoder,
            num_frames=num_frames,
        )
        video_state, audio_state = denoise_audio_video(
            output_shape=stage_1_output_shape,
            conditionings=stage_1_conditionings,
            noiser=noiser,
            sigmas=stage_1_sigmas,
            stepper=stepper,
            denoising_loop_fn=first_stage_denoising_loop,
            components=self.pipeline_components,
            dtype=dtype,
            device=self.device,
        )

        # Free the stage-1 (LoRA) transformer before loading the stage-2 one.
        torch.cuda.synchronize()
        del transformer
        cleanup_memory()

        # Stage 2: Upsample and refine the video at higher resolution with distilled LORA.
        upscaled_video_latent = upsample_video(
            latent=video_state.latent[:1],
            video_encoder=video_encoder,
            upsampler=self.stage_2_model_ledger.spatial_upsampler(),
        )

        torch.cuda.synchronize()
        cleanup_memory()

        transformer = self.stage_2_model_ledger.transformer()
        distilled_sigmas = torch.Tensor(STAGE_2_DISTILLED_SIGMA_VALUES).to(self.device)

        def second_stage_denoising_loop(
            sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
        ) -> tuple[LatentState, LatentState]:
            return euler_denoising_loop(
                sigmas=sigmas,
                video_state=video_state,
                audio_state=audio_state,
                stepper=stepper,
                denoise_fn=simple_denoising_func(
                    video_context=video_context,
                    audio_context=audio_context,
                    transformer=transformer,  # noqa: F821
                ),
            )

        stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
        # Stage 2 re-applies only the image conditionings (not the control videos).
        stage_2_conditionings = image_conditionings_by_replacing_latent(
            images=images,
            height=stage_2_output_shape.height,
            width=stage_2_output_shape.width,
            video_encoder=video_encoder,
            dtype=self.dtype,
            device=self.device,
        )

        video_state, audio_state = denoise_audio_video(
            output_shape=stage_2_output_shape,
            conditionings=stage_2_conditionings,
            noiser=noiser,
            sigmas=distilled_sigmas,
            stepper=stepper,
            denoising_loop_fn=second_stage_denoising_loop,
            components=self.pipeline_components,
            dtype=dtype,
            device=self.device,
            noise_scale=distilled_sigmas[0],
            initial_video_latent=upscaled_video_latent,
            initial_audio_latent=audio_state.latent,
        )

        torch.cuda.synchronize()
        del transformer
        del video_encoder
        cleanup_memory()

        decoded_video = vae_decode_video(video_state.latent, self.stage_2_model_ledger.video_decoder(), tiling_config)
        decoded_audio = vae_decode_audio(
            audio_state.latent, self.stage_2_model_ledger.audio_decoder(), self.stage_2_model_ledger.vocoder()
        )
        return decoded_video, decoded_audio

    def _create_conditionings(
        self,
        images: list[tuple[str, int, float]],
        video_conditioning: list[tuple[str, float]],
        height: int,
        width: int,
        num_frames: int,
        video_encoder: VideoEncoder,
    ) -> list[ConditioningItem]:
        """
        Build stage-1 conditionings: image conditionings by latent replacement
        plus one keyframe-indexed conditioning per control video.
        """
        conditionings = image_conditionings_by_replacing_latent(
            images=images,
            height=height,
            width=width,
            video_encoder=video_encoder,
            dtype=self.dtype,
            device=self.device,
        )

        for video_path, strength in video_conditioning:
            video = load_video_conditioning(
                video_path=video_path,
                height=height,
                width=width,
                frame_cap=num_frames,
                dtype=self.dtype,
                device=self.device,
            )
            encoded_video = video_encoder(video)
            # Control videos are anchored at frame 0.
            conditionings.append(VideoConditionByKeyframeIndex(keyframes=encoded_video, frame_idx=0, strength=strength))

        return conditionings
263
+
264
+
@torch.inference_mode()
def main() -> None:
    """CLI entry point: parse arguments, run the IC-LoRA pipeline, encode the result."""
    logging.getLogger().setLevel(logging.INFO)

    parser = default_2_stage_distilled_arg_parser()
    parser.add_argument(
        "--video-conditioning",
        action=VideoConditioningAction,
        nargs=2,
        metavar=("PATH", "STRENGTH"),
        required=True,
    )
    cli_args = parser.parse_args()

    ic_lora_pipeline = ICLoraPipeline(
        checkpoint_path=cli_args.checkpoint_path,
        spatial_upsampler_path=cli_args.spatial_upsampler_path,
        gemma_root=cli_args.gemma_root,
        loras=cli_args.lora,
        fp8transformer=cli_args.enable_fp8,
    )

    # The chunk count must match the tiling configuration used for decoding.
    tiling_config = TilingConfig.default()
    chunk_count = get_video_chunks_number(cli_args.num_frames, tiling_config)

    video, audio = ic_lora_pipeline(
        prompt=cli_args.prompt,
        seed=cli_args.seed,
        height=cli_args.height,
        width=cli_args.width,
        num_frames=cli_args.num_frames,
        frame_rate=cli_args.frame_rate,
        images=cli_args.images,
        video_conditioning=cli_args.video_conditioning,
        tiling_config=tiling_config,
    )

    encode_video(
        video=video,
        fps=cli_args.frame_rate,
        audio=audio,
        audio_sample_rate=AUDIO_SAMPLE_RATE,
        output_path=cli_args.output_path,
        video_chunks_number=chunk_count,
    )


if __name__ == "__main__":
    main()
packages/ltx-pipelines/src/ltx_pipelines/keyframe_interpolation.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from collections.abc import Iterator
3
+
4
+ import torch
5
+
6
+ from ltx_core.components.diffusion_steps import EulerDiffusionStep
7
+ from ltx_core.components.guiders import CFGGuider
8
+ from ltx_core.components.noisers import GaussianNoiser
9
+ from ltx_core.components.protocols import DiffusionStepProtocol
10
+ from ltx_core.components.schedulers import LTX2Scheduler
11
+ from ltx_core.loader import LoraPathStrengthAndSDOps
12
+ from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
13
+ from ltx_core.model.upsampler import upsample_video
14
+ from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
15
+ from ltx_core.model.video_vae import decode_video as vae_decode_video
16
+ from ltx_core.text_encoders.gemma import encode_text
17
+ from ltx_core.types import LatentState, VideoPixelShape
18
+ from ltx_pipelines.utils import ModelLedger
19
+ from ltx_pipelines.utils.args import default_2_stage_arg_parser
20
+ from ltx_pipelines.utils.constants import (
21
+ AUDIO_SAMPLE_RATE,
22
+ STAGE_2_DISTILLED_SIGMA_VALUES,
23
+ )
24
+ from ltx_pipelines.utils.helpers import (
25
+ assert_resolution,
26
+ cleanup_memory,
27
+ denoise_audio_video,
28
+ euler_denoising_loop,
29
+ generate_enhanced_prompt,
30
+ get_device,
31
+ guider_denoising_func,
32
+ image_conditionings_by_adding_guiding_latent,
33
+ simple_denoising_func,
34
+ )
35
+ from ltx_pipelines.utils.media_io import encode_video
36
+ from ltx_pipelines.utils.types import PipelineComponents
37
+
38
+ device = get_device()
39
+
40
+
41
class KeyframeInterpolationPipeline:
    """
    Keyframe-based two-stage video interpolation pipeline.

    Interpolates between keyframes to generate a video with smoother transitions.
    Stage 1 generates video at half the target resolution (CFG-guided), then
    Stage 2 upsamples by 2x and refines with additional denoising steps using a
    distilled LoRA for higher quality output at the target resolution.
    """

    def __init__(
        self,
        checkpoint_path: str,
        distilled_lora: list[LoraPathStrengthAndSDOps],
        spatial_upsampler_path: str,
        gemma_root: str,
        loras: list[LoraPathStrengthAndSDOps],
        device: torch.device = device,
        fp8transformer: bool = False,
    ):
        """
        Build the model ledgers for both stages.

        Args:
            checkpoint_path: Path to the base model checkpoint.
            distilled_lora: LoRA(s) applied on top of the base weights for stage 2.
            spatial_upsampler_path: Path to the 2x spatial upsampler weights.
            gemma_root: Root path of the Gemma text-encoder assets.
            loras: LoRA(s) applied to the stage-1 (base) model.
            device: Target device; defaults to the module-level detected device.
            fp8transformer: Load the transformer in FP8 when True.
        """
        self.device = device
        self.dtype = torch.bfloat16
        self.stage_1_model_ledger = ModelLedger(
            dtype=self.dtype,
            device=device,
            checkpoint_path=checkpoint_path,
            spatial_upsampler_path=spatial_upsampler_path,
            gemma_root_path=gemma_root,
            loras=loras,
            fp8transformer=fp8transformer,
        )
        # Stage 2 shares stage-1 weights with the distilled LoRA layered on top.
        self.stage_2_model_ledger = self.stage_1_model_ledger.with_loras(
            loras=distilled_lora,
        )
        self.pipeline_components = PipelineComponents(
            dtype=self.dtype,
            device=device,
        )

    @torch.inference_mode()
    def __call__(  # noqa: PLR0913
        self,
        prompt: str,
        negative_prompt: str,
        seed: int,
        height: int,
        width: int,
        num_frames: int,
        frame_rate: float,
        num_inference_steps: int,
        cfg_guidance_scale: float,
        images: list[tuple[str, int, float]],
        tiling_config: TilingConfig | None = None,
        enhance_prompt: bool = False,
    ) -> tuple[Iterator[torch.Tensor], torch.Tensor]:
        """
        Run the full two-stage generation and return (decoded_video, decoded_audio).

        Args:
            prompt: Positive text prompt.
            negative_prompt: Negative text prompt for CFG.
            seed: RNG seed for the noise generator.
            height: Target output height in pixels (stage 1 runs at height // 2).
            width: Target output width in pixels (stage 1 runs at width // 2).
            num_frames: Number of output frames.
            frame_rate: Output frames per second.
            num_inference_steps: Stage-1 denoising step count.
            cfg_guidance_scale: Classifier-free guidance scale for stage 1.
            images: Keyframe conditionings as (path, frame_index, strength) tuples.
            tiling_config: Optional tiled VAE-decoding config for the final decode.
            enhance_prompt: Rewrite the prompt with the text model before encoding.
        """
        assert_resolution(height=height, width=width, is_two_stage=True)

        generator = torch.Generator(device=self.device).manual_seed(seed)
        noiser = GaussianNoiser(generator=generator)
        stepper = EulerDiffusionStep()
        cfg_guider = CFGGuider(cfg_guidance_scale)
        dtype = torch.bfloat16

        text_encoder = self.stage_1_model_ledger.text_encoder()
        if enhance_prompt:
            # First conditioning image (if any) is given as visual context for enhancement.
            prompt = generate_enhanced_prompt(
                text_encoder, prompt, images[0][0] if len(images) > 0 else None, seed=seed
            )
        # Each encoded prompt splits into a (video, audio) context pair.
        context_p, context_n = encode_text(text_encoder, prompts=[prompt, negative_prompt])
        v_context_p, a_context_p = context_p
        v_context_n, a_context_n = context_n

        # Finish queued GPU work before freeing the text encoder to reclaim VRAM.
        torch.cuda.synchronize()
        del text_encoder
        cleanup_memory()

        # Stage 1: Initial low resolution video generation.
        video_encoder = self.stage_1_model_ledger.video_encoder()
        transformer = self.stage_1_model_ledger.transformer()
        sigmas = LTX2Scheduler().execute(steps=num_inference_steps).to(dtype=torch.float32, device=self.device)

        def first_stage_denoising_loop(
            sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
        ) -> tuple[LatentState, LatentState]:
            # CFG-guided Euler loop over both positive and negative contexts.
            return euler_denoising_loop(
                sigmas=sigmas,
                video_state=video_state,
                audio_state=audio_state,
                stepper=stepper,
                denoise_fn=guider_denoising_func(
                    cfg_guider,
                    v_context_p,
                    v_context_n,
                    a_context_p,
                    a_context_n,
                    transformer=transformer,  # noqa: F821
                ),
            )

        # Stage 1 renders at half the target resolution; stage 2 upsamples 2x.
        stage_1_output_shape = VideoPixelShape(
            batch=1,
            frames=num_frames,
            width=width // 2,
            height=height // 2,
            fps=frame_rate,
        )
        stage_1_conditionings = image_conditionings_by_adding_guiding_latent(
            images=images,
            height=stage_1_output_shape.height,
            width=stage_1_output_shape.width,
            video_encoder=video_encoder,
            dtype=dtype,
            device=self.device,
        )
        video_state, audio_state = denoise_audio_video(
            output_shape=stage_1_output_shape,
            conditionings=stage_1_conditionings,
            noiser=noiser,
            sigmas=sigmas,
            stepper=stepper,
            denoising_loop_fn=first_stage_denoising_loop,
            components=self.pipeline_components,
            dtype=dtype,
            device=self.device,
        )

        torch.cuda.synchronize()
        del transformer
        cleanup_memory()

        # Stage 2: Upsample and refine the video at higher resolution with distilled LORA.
        upscaled_video_latent = upsample_video(
            latent=video_state.latent[:1],  # only the first batch element is upsampled
            video_encoder=video_encoder,
            upsampler=self.stage_2_model_ledger.spatial_upsampler(),
        )

        torch.cuda.synchronize()
        cleanup_memory()

        transformer = self.stage_2_model_ledger.transformer()
        distilled_sigmas = torch.Tensor(STAGE_2_DISTILLED_SIGMA_VALUES).to(self.device)

        def second_stage_denoising_loop(
            sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
        ) -> tuple[LatentState, LatentState]:
            # Distilled refinement needs no CFG: positive contexts only.
            return euler_denoising_loop(
                sigmas=sigmas,
                video_state=video_state,
                audio_state=audio_state,
                stepper=stepper,
                denoise_fn=simple_denoising_func(
                    video_context=v_context_p,
                    audio_context=a_context_p,
                    transformer=transformer,  # noqa: F821
                ),
            )

        stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
        stage_2_conditionings = image_conditionings_by_adding_guiding_latent(
            images=images,
            height=stage_2_output_shape.height,
            width=stage_2_output_shape.width,
            video_encoder=video_encoder,
            dtype=dtype,
            device=self.device,
        )
        # Stage 2 resumes from the upscaled video latent and the stage-1 audio
        # latent; noise_scale matches the first distilled sigma so the latents
        # enter the loop at the expected noise level.
        video_state, audio_state = denoise_audio_video(
            output_shape=stage_2_output_shape,
            conditionings=stage_2_conditionings,
            noiser=noiser,
            sigmas=distilled_sigmas,
            stepper=stepper,
            denoising_loop_fn=second_stage_denoising_loop,
            components=self.pipeline_components,
            dtype=dtype,
            device=self.device,
            noise_scale=distilled_sigmas[0],
            initial_video_latent=upscaled_video_latent,
            initial_audio_latent=audio_state.latent,
        )

        torch.cuda.synchronize()
        del transformer
        del video_encoder
        cleanup_memory()

        decoded_video = vae_decode_video(video_state.latent, self.stage_2_model_ledger.video_decoder(), tiling_config)
        decoded_audio = vae_decode_audio(
            audio_state.latent, self.stage_2_model_ledger.audio_decoder(), self.stage_2_model_ledger.vocoder()
        )
        return decoded_video, decoded_audio
231
+
232
+
233
@torch.inference_mode()
def main() -> None:
    """CLI entry point: parse arguments, run keyframe interpolation, and encode the result."""
    logging.getLogger().setLevel(logging.INFO)

    cli_args = default_2_stage_arg_parser().parse_args()

    interpolation_pipeline = KeyframeInterpolationPipeline(
        checkpoint_path=cli_args.checkpoint_path,
        distilled_lora=cli_args.distilled_lora,
        spatial_upsampler_path=cli_args.spatial_upsampler_path,
        gemma_root=cli_args.gemma_root,
        loras=cli_args.lora,
        fp8transformer=cli_args.enable_fp8,
    )

    # The VAE decodes in temporal chunks; the muxer needs the chunk count up front.
    tiling = TilingConfig.default()
    chunk_count = get_video_chunks_number(cli_args.num_frames, tiling)

    video_frames, audio_track = interpolation_pipeline(
        prompt=cli_args.prompt,
        negative_prompt=cli_args.negative_prompt,
        seed=cli_args.seed,
        height=cli_args.height,
        width=cli_args.width,
        num_frames=cli_args.num_frames,
        frame_rate=cli_args.frame_rate,
        num_inference_steps=cli_args.num_inference_steps,
        cfg_guidance_scale=cli_args.cfg_guidance_scale,
        images=cli_args.images,
        tiling_config=tiling,
    )

    encode_video(
        video=video_frames,
        fps=cli_args.frame_rate,
        audio=audio_track,
        audio_sample_rate=AUDIO_SAMPLE_RATE,
        output_path=cli_args.output_path,
        video_chunks_number=chunk_count,
    )
packages/ltx-pipelines/src/ltx_pipelines/ti2vid_one_stage.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from collections.abc import Iterator
3
+
4
+ import torch
5
+
6
+ from ltx_core.components.diffusion_steps import EulerDiffusionStep
7
+ from ltx_core.components.guiders import CFGGuider
8
+ from ltx_core.components.noisers import GaussianNoiser
9
+ from ltx_core.components.protocols import DiffusionStepProtocol
10
+ from ltx_core.components.schedulers import LTX2Scheduler
11
+ from ltx_core.loader import LoraPathStrengthAndSDOps
12
+ from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
13
+ from ltx_core.model.video_vae import decode_video as vae_decode_video
14
+ from ltx_core.text_encoders.gemma import encode_text
15
+ from ltx_core.types import LatentState, VideoPixelShape
16
+ from ltx_pipelines.utils import ModelLedger
17
+ from ltx_pipelines.utils.args import default_1_stage_arg_parser
18
+ from ltx_pipelines.utils.constants import AUDIO_SAMPLE_RATE
19
+ from ltx_pipelines.utils.helpers import (
20
+ assert_resolution,
21
+ cleanup_memory,
22
+ denoise_audio_video,
23
+ euler_denoising_loop,
24
+ generate_enhanced_prompt,
25
+ get_device,
26
+ guider_denoising_func,
27
+ image_conditionings_by_replacing_latent,
28
+ )
29
+ from ltx_pipelines.utils.media_io import encode_video
30
+ from ltx_pipelines.utils.types import PipelineComponents
31
+
32
+ device = get_device()
33
+
34
+
35
class TI2VidOneStagePipeline:
    """
    Single-stage text/image-to-video generation pipeline.

    Generates video at the target resolution in a single diffusion pass with
    classifier-free guidance (CFG). Supports optional image conditioning via
    the images parameter.
    """

    def __init__(
        self,
        checkpoint_path: str,
        gemma_root: str,
        loras: list[LoraPathStrengthAndSDOps],
        device: torch.device = device,
        fp8transformer: bool = False,
    ):
        """
        Build the model ledger and shared pipeline components.

        Args:
            checkpoint_path: Path to the base model checkpoint.
            gemma_root: Root path of the Gemma text-encoder assets.
            loras: LoRA(s) applied to the base model.
            device: Target device; defaults to the module-level detected device.
            fp8transformer: Load the transformer in FP8 when True.
        """
        self.dtype = torch.bfloat16
        self.device = device
        self.model_ledger = ModelLedger(
            dtype=self.dtype,
            device=device,
            checkpoint_path=checkpoint_path,
            gemma_root_path=gemma_root,
            loras=loras,
            fp8transformer=fp8transformer,
        )
        self.pipeline_components = PipelineComponents(
            dtype=self.dtype,
            device=device,
        )

    # Fix: the sibling pipelines (TI2VidTwoStagesPipeline, KeyframeInterpolationPipeline)
    # run __call__ under inference mode; this one only got it via main()'s decorator,
    # so library callers would build autograd state across the whole denoising loop.
    @torch.inference_mode()
    def __call__(  # noqa: PLR0913
        self,
        prompt: str,
        negative_prompt: str,
        seed: int,
        height: int,
        width: int,
        num_frames: int,
        frame_rate: float,
        num_inference_steps: int,
        cfg_guidance_scale: float,
        images: list[tuple[str, int, float]],
        enhance_prompt: bool = False,
    ) -> tuple[Iterator[torch.Tensor], torch.Tensor]:
        """
        Run single-stage generation and return (decoded_video, decoded_audio).

        Args:
            prompt: Positive text prompt.
            negative_prompt: Negative text prompt for CFG.
            seed: RNG seed for the noise generator.
            height: Output height in pixels.
            width: Output width in pixels.
            num_frames: Number of output frames.
            frame_rate: Output frames per second.
            num_inference_steps: Denoising step count.
            cfg_guidance_scale: Classifier-free guidance scale.
            images: Image conditionings as (path, frame_index, strength) tuples.
            enhance_prompt: Rewrite the prompt with the text model before encoding.
        """
        assert_resolution(height=height, width=width, is_two_stage=False)

        generator = torch.Generator(device=self.device).manual_seed(seed)
        noiser = GaussianNoiser(generator=generator)
        stepper = EulerDiffusionStep()
        cfg_guider = CFGGuider(cfg_guidance_scale)
        dtype = torch.bfloat16

        text_encoder = self.model_ledger.text_encoder()
        if enhance_prompt:
            # First conditioning image (if any) is given as visual context for enhancement.
            prompt = generate_enhanced_prompt(
                text_encoder, prompt, images[0][0] if len(images) > 0 else None, seed=seed
            )
        # Each encoded prompt splits into a (video, audio) context pair.
        context_p, context_n = encode_text(text_encoder, prompts=[prompt, negative_prompt])
        v_context_p, a_context_p = context_p
        v_context_n, a_context_n = context_n

        # Finish queued GPU work before freeing the text encoder to reclaim VRAM.
        torch.cuda.synchronize()
        del text_encoder
        cleanup_memory()

        # Single diffusion pass at the target resolution.
        video_encoder = self.model_ledger.video_encoder()
        transformer = self.model_ledger.transformer()
        sigmas = LTX2Scheduler().execute(steps=num_inference_steps).to(dtype=torch.float32, device=self.device)

        def first_stage_denoising_loop(
            sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
        ) -> tuple[LatentState, LatentState]:
            # CFG-guided Euler loop over both positive and negative contexts.
            return euler_denoising_loop(
                sigmas=sigmas,
                video_state=video_state,
                audio_state=audio_state,
                stepper=stepper,
                denoise_fn=guider_denoising_func(
                    cfg_guider,
                    v_context_p,
                    v_context_n,
                    a_context_p,
                    a_context_n,
                    transformer=transformer,  # noqa: F821
                ),
            )

        stage_1_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
        stage_1_conditionings = image_conditionings_by_replacing_latent(
            images=images,
            height=stage_1_output_shape.height,
            width=stage_1_output_shape.width,
            video_encoder=video_encoder,
            dtype=dtype,
            device=self.device,
        )

        video_state, audio_state = denoise_audio_video(
            output_shape=stage_1_output_shape,
            conditionings=stage_1_conditionings,
            noiser=noiser,
            sigmas=sigmas,
            stepper=stepper,
            denoising_loop_fn=first_stage_denoising_loop,
            components=self.pipeline_components,
            dtype=dtype,
            device=self.device,
        )

        torch.cuda.synchronize()
        del transformer
        cleanup_memory()

        decoded_video = vae_decode_video(video_state.latent, self.model_ledger.video_decoder())
        decoded_audio = vae_decode_audio(
            audio_state.latent, self.model_ledger.audio_decoder(), self.model_ledger.vocoder()
        )

        return decoded_video, decoded_audio
156
+
157
+
158
@torch.inference_mode()
def main() -> None:
    """CLI entry point: parse arguments, run the single-stage pipeline, and encode the result."""
    logging.getLogger().setLevel(logging.INFO)

    cli_args = default_1_stage_arg_parser().parse_args()

    one_stage_pipeline = TI2VidOneStagePipeline(
        checkpoint_path=cli_args.checkpoint_path,
        gemma_root=cli_args.gemma_root,
        loras=cli_args.lora,
        fp8transformer=cli_args.enable_fp8,
    )

    video_frames, audio_track = one_stage_pipeline(
        prompt=cli_args.prompt,
        negative_prompt=cli_args.negative_prompt,
        seed=cli_args.seed,
        height=cli_args.height,
        width=cli_args.width,
        num_frames=cli_args.num_frames,
        frame_rate=cli_args.frame_rate,
        num_inference_steps=cli_args.num_inference_steps,
        cfg_guidance_scale=cli_args.cfg_guidance_scale,
        images=cli_args.images,
    )

    # Single-stage decode produces the whole video at once, hence one chunk.
    encode_video(
        video=video_frames,
        fps=cli_args.frame_rate,
        audio=audio_track,
        audio_sample_rate=AUDIO_SAMPLE_RATE,
        output_path=cli_args.output_path,
        video_chunks_number=1,
    )
packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from collections.abc import Iterator
3
+
4
+ import torch
5
+
6
+ from ltx_core.components.diffusion_steps import EulerDiffusionStep
7
+ from ltx_core.components.guiders import CFGGuider
8
+ from ltx_core.components.noisers import GaussianNoiser
9
+ from ltx_core.components.protocols import DiffusionStepProtocol
10
+ from ltx_core.components.schedulers import LTX2Scheduler
11
+ from ltx_core.loader import LoraPathStrengthAndSDOps
12
+ from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
13
+ from ltx_core.model.upsampler import upsample_video
14
+ from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
15
+ from ltx_core.model.video_vae import decode_video as vae_decode_video
16
+ from ltx_core.text_encoders.gemma import encode_text
17
+ from ltx_core.types import LatentState, VideoPixelShape
18
+ from ltx_pipelines.utils import ModelLedger
19
+ from ltx_pipelines.utils.args import default_2_stage_arg_parser
20
+ from ltx_pipelines.utils.constants import (
21
+ AUDIO_SAMPLE_RATE,
22
+ STAGE_2_DISTILLED_SIGMA_VALUES,
23
+ )
24
+ from ltx_pipelines.utils.helpers import (
25
+ assert_resolution,
26
+ cleanup_memory,
27
+ denoise_audio_video,
28
+ euler_denoising_loop,
29
+ generate_enhanced_prompt,
30
+ get_device,
31
+ guider_denoising_func,
32
+ image_conditionings_by_replacing_latent,
33
+ simple_denoising_func,
34
+ )
35
+ from ltx_pipelines.utils.media_io import encode_video
36
+ from ltx_pipelines.utils.types import PipelineComponents
37
+
38
+ device = get_device()
39
+
40
+
41
class TI2VidTwoStagesPipeline:
    """
    Two-stage text/image-to-video generation pipeline.

    Stage 1 generates video at half the target resolution with CFG guidance,
    then Stage 2 upsamples by 2x and refines using a distilled LoRA for higher
    quality output at the target resolution. Supports optional image
    conditioning via the images parameter.
    """

    def __init__(
        self,
        checkpoint_path: str,
        distilled_lora: list[LoraPathStrengthAndSDOps],
        spatial_upsampler_path: str,
        gemma_root: str,
        loras: list[LoraPathStrengthAndSDOps],
        device: torch.device = device,
        fp8transformer: bool = False,
    ):
        """
        Build the model ledgers for both stages.

        Args:
            checkpoint_path: Path to the base model checkpoint.
            distilled_lora: LoRA(s) applied on top of the base weights for stage 2.
            spatial_upsampler_path: Path to the 2x spatial upsampler weights.
            gemma_root: Root path of the Gemma text-encoder assets.
            loras: LoRA(s) applied to the stage-1 (base) model.
            device: Target device; defaults to the module-level detected device.
            fp8transformer: Load the transformer in FP8 when True.
        """
        self.device = device
        self.dtype = torch.bfloat16
        self.stage_1_model_ledger = ModelLedger(
            dtype=self.dtype,
            device=device,
            checkpoint_path=checkpoint_path,
            gemma_root_path=gemma_root,
            spatial_upsampler_path=spatial_upsampler_path,
            loras=loras,
            fp8transformer=fp8transformer,
        )

        # Stage 2 shares stage-1 weights with the distilled LoRA layered on top.
        self.stage_2_model_ledger = self.stage_1_model_ledger.with_loras(
            loras=distilled_lora,
        )

        self.pipeline_components = PipelineComponents(
            dtype=self.dtype,
            device=device,
        )

    @torch.inference_mode()
    def __call__(  # noqa: PLR0913
        self,
        prompt: str,
        negative_prompt: str,
        seed: int,
        height: int,
        width: int,
        num_frames: int,
        frame_rate: float,
        num_inference_steps: int,
        cfg_guidance_scale: float,
        images: list[tuple[str, int, float]],
        tiling_config: TilingConfig | None = None,
        enhance_prompt: bool = False,
    ) -> tuple[Iterator[torch.Tensor], torch.Tensor]:
        """
        Run the full two-stage generation and return (decoded_video, decoded_audio).

        Args:
            prompt: Positive text prompt.
            negative_prompt: Negative text prompt for CFG.
            seed: RNG seed for the noise generator.
            height: Target output height in pixels (stage 1 runs at height // 2).
            width: Target output width in pixels (stage 1 runs at width // 2).
            num_frames: Number of output frames.
            frame_rate: Output frames per second.
            num_inference_steps: Stage-1 denoising step count.
            cfg_guidance_scale: Classifier-free guidance scale for stage 1.
            images: Image conditionings as (path, frame_index, strength) tuples.
            tiling_config: Optional tiled VAE-decoding config for the final decode.
            enhance_prompt: Rewrite the prompt with the text model before encoding.
        """
        assert_resolution(height=height, width=width, is_two_stage=True)

        generator = torch.Generator(device=self.device).manual_seed(seed)
        noiser = GaussianNoiser(generator=generator)
        stepper = EulerDiffusionStep()
        cfg_guider = CFGGuider(cfg_guidance_scale)
        dtype = torch.bfloat16

        text_encoder = self.stage_1_model_ledger.text_encoder()
        if enhance_prompt:
            # First conditioning image (if any) is given as visual context for enhancement.
            prompt = generate_enhanced_prompt(
                text_encoder, prompt, images[0][0] if len(images) > 0 else None, seed=seed
            )
        # Each encoded prompt splits into a (video, audio) context pair.
        context_p, context_n = encode_text(text_encoder, prompts=[prompt, negative_prompt])
        v_context_p, a_context_p = context_p
        v_context_n, a_context_n = context_n

        # Finish queued GPU work before freeing the text encoder to reclaim VRAM.
        torch.cuda.synchronize()
        del text_encoder
        cleanup_memory()

        # Stage 1: Initial low resolution video generation.
        video_encoder = self.stage_1_model_ledger.video_encoder()
        transformer = self.stage_1_model_ledger.transformer()
        sigmas = LTX2Scheduler().execute(steps=num_inference_steps).to(dtype=torch.float32, device=self.device)

        def first_stage_denoising_loop(
            sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
        ) -> tuple[LatentState, LatentState]:
            # CFG-guided Euler loop over both positive and negative contexts.
            return euler_denoising_loop(
                sigmas=sigmas,
                video_state=video_state,
                audio_state=audio_state,
                stepper=stepper,
                denoise_fn=guider_denoising_func(
                    cfg_guider,
                    v_context_p,
                    v_context_n,
                    a_context_p,
                    a_context_n,
                    transformer=transformer,  # noqa: F821
                ),
            )

        # Stage 1 renders at half the target resolution; stage 2 upsamples 2x.
        stage_1_output_shape = VideoPixelShape(
            batch=1,
            frames=num_frames,
            width=width // 2,
            height=height // 2,
            fps=frame_rate,
        )
        stage_1_conditionings = image_conditionings_by_replacing_latent(
            images=images,
            height=stage_1_output_shape.height,
            width=stage_1_output_shape.width,
            video_encoder=video_encoder,
            dtype=dtype,
            device=self.device,
        )
        video_state, audio_state = denoise_audio_video(
            output_shape=stage_1_output_shape,
            conditionings=stage_1_conditionings,
            noiser=noiser,
            sigmas=sigmas,
            stepper=stepper,
            denoising_loop_fn=first_stage_denoising_loop,
            components=self.pipeline_components,
            dtype=dtype,
            device=self.device,
        )

        torch.cuda.synchronize()
        del transformer
        cleanup_memory()

        # Stage 2: Upsample and refine the video at higher resolution with distilled LORA.
        upscaled_video_latent = upsample_video(
            latent=video_state.latent[:1],  # only the first batch element is upsampled
            video_encoder=video_encoder,
            upsampler=self.stage_2_model_ledger.spatial_upsampler(),
        )

        torch.cuda.synchronize()
        cleanup_memory()

        transformer = self.stage_2_model_ledger.transformer()
        distilled_sigmas = torch.Tensor(STAGE_2_DISTILLED_SIGMA_VALUES).to(self.device)

        def second_stage_denoising_loop(
            sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
        ) -> tuple[LatentState, LatentState]:
            # Distilled refinement needs no CFG: positive contexts only.
            return euler_denoising_loop(
                sigmas=sigmas,
                video_state=video_state,
                audio_state=audio_state,
                stepper=stepper,
                denoise_fn=simple_denoising_func(
                    video_context=v_context_p,
                    audio_context=a_context_p,
                    transformer=transformer,  # noqa: F821
                ),
            )

        stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
        stage_2_conditionings = image_conditionings_by_replacing_latent(
            images=images,
            height=stage_2_output_shape.height,
            width=stage_2_output_shape.width,
            video_encoder=video_encoder,
            dtype=dtype,
            device=self.device,
        )
        # Stage 2 resumes from the upscaled video latent and the stage-1 audio
        # latent; noise_scale matches the first distilled sigma so the latents
        # enter the loop at the expected noise level.
        video_state, audio_state = denoise_audio_video(
            output_shape=stage_2_output_shape,
            conditionings=stage_2_conditionings,
            noiser=noiser,
            sigmas=distilled_sigmas,
            stepper=stepper,
            denoising_loop_fn=second_stage_denoising_loop,
            components=self.pipeline_components,
            dtype=dtype,
            device=self.device,
            noise_scale=distilled_sigmas[0],
            initial_video_latent=upscaled_video_latent,
            initial_audio_latent=audio_state.latent,
        )

        torch.cuda.synchronize()
        del transformer
        del video_encoder
        cleanup_memory()

        decoded_video = vae_decode_video(video_state.latent, self.stage_2_model_ledger.video_decoder(), tiling_config)
        decoded_audio = vae_decode_audio(
            audio_state.latent, self.stage_2_model_ledger.audio_decoder(), self.stage_2_model_ledger.vocoder()
        )

        return decoded_video, decoded_audio
234
+
235
+
236
@torch.inference_mode()
def main() -> None:
    """CLI entry point: parse arguments, run the two-stage pipeline, and encode the result."""
    logging.getLogger().setLevel(logging.INFO)

    cli_args = default_2_stage_arg_parser().parse_args()

    two_stage_pipeline = TI2VidTwoStagesPipeline(
        checkpoint_path=cli_args.checkpoint_path,
        distilled_lora=cli_args.distilled_lora,
        spatial_upsampler_path=cli_args.spatial_upsampler_path,
        gemma_root=cli_args.gemma_root,
        loras=cli_args.lora,
        fp8transformer=cli_args.enable_fp8,
    )

    # The VAE decodes in temporal chunks; the muxer needs the chunk count up front.
    tiling = TilingConfig.default()
    chunk_count = get_video_chunks_number(cli_args.num_frames, tiling)

    video_frames, audio_track = two_stage_pipeline(
        prompt=cli_args.prompt,
        negative_prompt=cli_args.negative_prompt,
        seed=cli_args.seed,
        height=cli_args.height,
        width=cli_args.width,
        num_frames=cli_args.num_frames,
        frame_rate=cli_args.frame_rate,
        num_inference_steps=cli_args.num_inference_steps,
        cfg_guidance_scale=cli_args.cfg_guidance_scale,
        images=cli_args.images,
        tiling_config=tiling,
    )

    encode_video(
        video=video_frames,
        fps=cli_args.frame_rate,
        audio=audio_track,
        audio_sample_rate=AUDIO_SAMPLE_RATE,
        output_path=cli_args.output_path,
        video_chunks_number=chunk_count,
    )