dagloop5 committed on
Commit
2061db8
·
verified ·
1 Parent(s): 51195c6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +930 -0
app.py ADDED
@@ -0,0 +1,930 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # Installation and Setup
3
+ # =============================================================================
4
+ import os
5
+ import subprocess
6
+ import sys
7
+
8
+ # Disable torch.compile / dynamo before any torch import
9
+ os.environ["TORCH_COMPILE_DISABLE"] = "1"
10
+ os.environ["TORCHDYNAMO_DISABLE"] = "1"
11
+
12
+ # Clone LTX-2 repo at specific commit
13
+ LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
14
+ LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
15
+ LTX_COMMIT_SHA = "a2c3f24078eb918171967f74b6f66b756b29ee45"
16
+
17
# Fetch the pinned LTX-2 commit once. If any git step fails, the partially
# initialised directory is removed so that the next start retries the clone
# instead of silently skipping it because the directory already exists.
if not os.path.exists(LTX_REPO_DIR):
    import shutil

    print(f"Cloning {LTX_REPO_URL} at commit {LTX_COMMIT_SHA}...")
    os.makedirs(LTX_REPO_DIR)
    try:
        subprocess.run(["git", "init", LTX_REPO_DIR], check=True)
        subprocess.run(["git", "remote", "add", "origin", LTX_REPO_URL], cwd=LTX_REPO_DIR, check=True)
        # Shallow-fetch only the pinned commit, then check it out detached.
        subprocess.run(["git", "fetch", "--depth", "1", "origin", LTX_COMMIT_SHA], cwd=LTX_REPO_DIR, check=True)
        subprocess.run(["git", "checkout", LTX_COMMIT_SHA], cwd=LTX_REPO_DIR, check=True)
    except BaseException:
        # Covers CalledProcessError, a missing git binary and interrupts alike.
        shutil.rmtree(LTX_REPO_DIR, ignore_errors=True)
        raise
24
+
25
+ sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines", "src"))
26
+ sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-core", "src"))
27
+
28
+ # =============================================================================
29
+ # Imports
30
+ # =============================================================================
31
+ import logging
32
+ import random
33
+ import tempfile
34
+ from pathlib import Path
35
+ from typing import Optional, Any
36
+
37
+ import torch
38
+ torch._dynamo.config.suppress_errors = True
39
+ torch._dynamo.config.disable = True
40
+
41
+ import gradio as gr
42
+ import spaces
43
+ import numpy as np
44
+ from huggingface_hub import hf_hub_download, snapshot_download
45
+
46
+ # Core LTX imports
47
+ from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
48
+ from ltx_core.quantization import QuantizationPolicy
49
+ from ltx_core.loader import LoraPathStrengthAndSDOps, LTXV_LORA_COMFY_RENAMING_MAP
50
+ from ltx_core.components.guiders import MultiModalGuider, MultiModalGuiderParams
51
+ from ltx_core.components.noisers import GaussianNoiser
52
+ from ltx_core.components.schedulers import LTX2Scheduler
53
+ from ltx_core.components.diffusion_steps import Res2sDiffusionStep
54
+ from ltx_core.types import Audio, VideoLatentShape, VideoPixelShape
55
+
56
+ # Pipeline utilities
57
+ from ltx_pipelines.utils.args import ImageConditioningInput
58
+ from ltx_pipelines.utils.media_io import encode_video
59
+ from ltx_pipelines.utils.denoisers import GuidedDenoiser, SimpleDenoiser
60
+ from ltx_pipelines.utils.samplers import res2s_audio_video_denoising_loop
61
+ from ltx_pipelines.utils.types import ModalitySpec
62
+ from ltx_pipelines.utils.helpers import assert_resolution, combined_image_conditionings, get_device
63
+ from ltx_pipelines.utils.constants import LTX_2_3_HQ_PARAMS, STAGE_2_DISTILLED_SIGMA_VALUES
64
+
65
+ # Model builders
66
+ from ltx_core.loader.single_gpu_model_builder import (
67
+ TransformerBuilder,
68
+ VideoEncoderBuilder,
69
+ VideoDecoderBuilder,
70
+ AudioDecoderBuilder,
71
+ VocoderBuilder,
72
+ UpsamplerBuilder,
73
+ TextEncoderBuilder,
74
+ )
75
+ from ltx_core.model.transformer import X0Model
76
+ from ltx_core.model.video_vae import VideoEncoder, VideoDecoder
77
+ from ltx_core.model.audio_vae import AudioDecoder as AVAudioDecoder, Vocoder
78
+ from ltx_core.model.upsampler import LatentUpsampler
79
+ from ltx_core.text_encoders.gemma import GemmaTextEncoder
80
+ from ltx_core.text_encoders.gemma.embeddings_processor import EmbeddingsProcessorBuilder
81
+
82
+ logging.getLogger().setLevel(logging.INFO)
83
+
84
+ # =============================================================================
85
+ # Constants and Configuration
86
+ # =============================================================================
87
+
88
+ LTX_MODEL_REPO = "Lightricks/LTX-2.3"
89
+ GEMMA_REPO = "Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
90
+ DEFAULT_FRAME_RATE = 24.0
91
+ MIN_DIM, MAX_DIM, STEP = 256, 1280, 64
92
+ MIN_FRAMES, MAX_FRAMES = 9, 257
93
+ MAX_SEED = np.iinfo(np.int32).max
94
+
95
+ DEFAULT_PROMPT = (
96
+ "A majestic eagle soaring over mountain peaks at sunset, "
97
+ "wings spread wide against the orange sky, feathers catching the light, "
98
+ "wind currents visible in the motion blur, cinematic slow motion, 4K quality"
99
+ )
100
+ DEFAULT_NEGATIVE_PROMPT = (
101
+ "worst quality, inconsistent motion, blurry, jittery, distorted, "
102
+ "deformed, artifacts, text, watermark, logo, frame, border, "
103
+ "low resolution, pixelated, unnatural, fake, CGI, cartoon"
104
+ )
105
+
106
+ # =============================================================================
107
+ # HQ Pipeline with model_ledger - Custom Implementation
108
+ # =============================================================================
109
+
110
class HQModelLedger:
    """
    Lazily builds and caches the models used by the two-stage HQ pipeline.

    Every accessor (``transformer``, ``video_encoder``, ...) builds its model
    on first use through the matching builder, moves it to ``device``, puts it
    in eval mode and stores it on the instance, so later calls return the same
    object. Per the original author this mimics the ledger pattern of the
    official DistilledPipeline Space, used for ZeroGPU tensor packing.
    """

    def __init__(
        self,
        checkpoint_path: str,
        distilled_lora_path: str,
        distilled_lora_strength_stage_1: float,
        distilled_lora_strength_stage_2: float,
        spatial_upsampler_path: str,
        gemma_root: str,
        loras: tuple,
        device: torch.device,
        dtype: torch.dtype,
        quantization: Optional[QuantizationPolicy] = None,
    ):
        # Placement and precision shared by every model.
        self.device = device
        self.dtype = dtype
        self._target_device = device

        # Where the weights come from.
        self._checkpoint_path = checkpoint_path
        self._spatial_upsampler_path = spatial_upsampler_path
        self._gemma_root = gemma_root
        # NOTE(review): stored, but no builder in this class reads it, so the
        # policy currently has no effect on the models built here.
        self._quantization = quantization

        # LoRA settings: the distilled LoRA (one strength per stage) plus any
        # caller-supplied LoRAs, which are applied to both stages.
        self._distilled_lora_path = distilled_lora_path
        self._distilled_lora_strength_stage_1 = distilled_lora_strength_stage_1
        self._distilled_lora_strength_stage_2 = distilled_lora_strength_stage_2
        self._loras = loras

        # Model cache slots; filled by the accessors on first use.
        self._transformer_stage1 = None
        self._transformer_stage2 = None
        self._video_encoder = None
        self._video_decoder = None
        self._audio_decoder = None
        self._vocoder = None
        self._spatial_upsampler = None
        self._text_encoder = None
        self._embeddings_processor = None

        self._build_configs()

    def _build_configs(self):
        """Create one builder per model; only the transformers carry LoRAs."""

        def lora_stack(distilled_strength):
            # Distilled LoRA first, then the custom LoRAs in their given order.
            distilled = LoraPathStrengthAndSDOps(
                path=self._distilled_lora_path,
                strength=distilled_strength,
                sd_ops=LTXV_LORA_COMFY_RENAMING_MAP,
            )
            return [distilled, *self._loras]

        checkpoint = self._checkpoint_path

        # The two stages share the checkpoint and differ only in LoRA strength.
        self._transformer_builder_stage1 = TransformerBuilder.from_checkpoint(checkpoint).with_loras(
            lora_stack(self._distilled_lora_strength_stage_1)
        )
        self._transformer_builder_stage2 = TransformerBuilder.from_checkpoint(checkpoint).with_loras(
            lora_stack(self._distilled_lora_strength_stage_2)
        )

        # Remaining builders take no LoRAs.
        self._video_encoder_builder = VideoEncoderBuilder.from_checkpoint(checkpoint)
        self._video_decoder_builder = VideoDecoderBuilder.from_checkpoint(checkpoint)
        self._audio_decoder_builder = AudioDecoderBuilder.from_checkpoint(checkpoint)
        self._vocoder_builder = VocoderBuilder.from_checkpoint(checkpoint)
        self._spatial_upsampler_builder = UpsamplerBuilder.from_checkpoint(self._spatial_upsampler_path)
        self._text_encoder_builder = TextEncoderBuilder.from_gemma(self._gemma_root)
        self._embeddings_processor_builder = EmbeddingsProcessorBuilder.from_checkpoint(checkpoint)

    def _get_or_build(self, attr: str, label: str, builder, wrap=None):
        """Return the model cached under ``attr``, building it on first use.

        ``wrap`` is an optional callable applied to the freshly built model
        before it is moved to the device and switched to eval mode.
        """
        cached = getattr(self, attr)
        if cached is None:
            print(f" Building {label}...")
            built = builder.build(device=self._target_device, dtype=self.dtype)
            if wrap is not None:
                built = wrap(built)
            cached = built.to(self.device).eval()
            setattr(self, attr, cached)
        return cached

    def transformer(self, stage: int = 1):
        """Get or build the transformer; any ``stage`` other than 1 selects stage 2."""
        if stage == 1:
            return self._get_or_build(
                "_transformer_stage1",
                "transformer (stage 1)",
                self._transformer_builder_stage1,
                wrap=X0Model,
            )
        return self._get_or_build(
            "_transformer_stage2",
            "transformer (stage 2)",
            self._transformer_builder_stage2,
            wrap=X0Model,
        )

    def video_encoder(self):
        """Get or build video encoder."""
        return self._get_or_build("_video_encoder", "video encoder", self._video_encoder_builder)

    def video_decoder(self):
        """Get or build video decoder."""
        return self._get_or_build("_video_decoder", "video decoder", self._video_decoder_builder)

    def audio_decoder(self):
        """Get or build audio decoder."""
        return self._get_or_build("_audio_decoder", "audio decoder", self._audio_decoder_builder)

    def vocoder(self):
        """Get or build vocoder."""
        return self._get_or_build("_vocoder", "vocoder", self._vocoder_builder)

    def spatial_upsampler(self):
        """Get or build spatial upsampler."""
        return self._get_or_build("_spatial_upsampler", "spatial upsampler", self._spatial_upsampler_builder)

    def text_encoder(self):
        """Get or build text encoder."""
        return self._get_or_build("_text_encoder", "text encoder (Gemma)", self._text_encoder_builder)

    def embeddings_processor(self):
        """Get or build embeddings processor."""
        return self._get_or_build(
            "_embeddings_processor",
            "embeddings processor",
            self._embeddings_processor_builder,
        )
293
+
294
+
295
class TI2VidTwoStagesHQPipelineWithLedger:
    """
    Two-stage text/image-to-video pipeline that pulls its models from a ledger.

    Stage 1 denoises at half the requested resolution with guidance; stage 2
    upsamples the stage-1 video latent and refines it on the distilled sigma
    schedule. Per the original author it matches TI2VidTwoStagesHQPipeline but
    uses ``model_ledger`` for ZeroGPU compatibility.
    """

    def __init__(
        self,
        checkpoint_path: str,
        distilled_lora_path: str,
        distilled_lora_strength_stage_1: float,
        distilled_lora_strength_stage_2: float,
        spatial_upsampler_path: str,
        gemma_root: str,
        loras: tuple = (),
        device: Optional[torch.device] = None,
        quantization: Optional[QuantizationPolicy] = None,
        torch_compile: bool = False,
    ):
        self.device = device or get_device()
        self.dtype = torch.bfloat16
        # Recorded only; nothing in this class reads it.
        self._torch_compile = torch_compile

        # All models are obtained through this ledger.
        self.model_ledger = HQModelLedger(
            checkpoint_path=checkpoint_path,
            distilled_lora_path=distilled_lora_path,
            distilled_lora_strength_stage_1=distilled_lora_strength_stage_1,
            distilled_lora_strength_stage_2=distilled_lora_strength_stage_2,
            spatial_upsampler_path=spatial_upsampler_path,
            gemma_root=gemma_root,
            loras=loras,
            device=self.device,
            dtype=self.dtype,
            quantization=quantization,
        )

        # Sigma scheduler for stage 1 and the step rule shared by both stages.
        self._scheduler = LTX2Scheduler()
        self._stepper = Res2sDiffusionStep()

    def _image_conditionings(self, images, height, width, video_encoder):
        """Build the image conditionings for a given target resolution."""
        return combined_image_conditionings(
            images=images,
            height=height,
            width=width,
            video_encoder=video_encoder,
            dtype=self.dtype,
            device=self.device,
        )

    @torch.inference_mode()
    def __call__(
        self,
        prompt: str,
        negative_prompt: str,
        seed: int,
        height: int,
        width: int,
        num_frames: int,
        frame_rate: float,
        num_inference_steps: int,
        video_guider_params: MultiModalGuiderParams,
        audio_guider_params: MultiModalGuiderParams,
        images: list[ImageConditioningInput],
        tiling_config: Optional[TilingConfig] = None,
        enhance_prompt: bool = False,
        streaming_prefetch_count: Optional[int] = None,
        max_batch_size: int = 1,
    ):
        """Run both stages and return ``(decoded_video, decoded_audio)``."""
        assert_resolution(height=height, width=width, is_two_stage=True)

        # One seeded generator feeds the noiser and, later, the video decoder.
        rng = torch.Generator(device=self.device).manual_seed(seed)
        noiser = GaussianNoiser(generator=rng)

        ledger = self.model_ledger
        text_encoder = ledger.text_encoder()
        embeddings_processor = ledger.embeddings_processor()
        video_encoder = ledger.video_encoder()

        # Positive prompt: the first image path and the seed are forwarded only
        # when prompt enhancement is requested.
        positive_hidden = text_encoder([prompt])
        enhance_image_path = images[0].path if len(images) > 0 and enhance_prompt else None
        enhance_seed = seed if enhance_prompt else None
        positive = embeddings_processor.create_embeddings(
            positive_hidden,
            video_encoder,
            enhance_image_path,
            enhance_seed,
        )

        # Negative prompt: never enhanced.
        negative_hidden = text_encoder([negative_prompt])
        negative = embeddings_processor.create_embeddings(
            negative_hidden,
            video_encoder,
            None,
            None,
        )

        v_context_p = positive.video_encoding
        a_context_p = positive.audio_encoding
        v_context_n = negative.video_encoding
        a_context_n = negative.audio_encoding

        # ---- Stage 1: guided generation at half resolution -----------------
        half_res = VideoPixelShape(
            batch=1,
            frames=num_frames,
            width=width // 2,
            height=height // 2,
            fps=frame_rate,
        )
        stage_1_conditionings = self._image_conditionings(
            images, half_res.height, half_res.width, video_encoder
        )

        # Uninitialised latent of the stage-1 shape, passed to the scheduler
        # together with the step count.
        latent_shape = VideoLatentShape.from_pixel_shape(half_res).to_torch_shape()
        shape_probe = torch.empty(latent_shape, dtype=self.dtype, device=self.device)
        sigmas = self._scheduler.execute(latent=shape_probe, steps=num_inference_steps)
        sigmas = sigmas.to(dtype=torch.float32, device=self.device)

        stage_1_transformer = ledger.transformer(stage=1)
        guided_denoiser = GuidedDenoiser(
            v_context=v_context_p,
            a_context=a_context_p,
            video_guider=MultiModalGuider(params=video_guider_params, negative_context=v_context_n),
            audio_guider=MultiModalGuider(params=audio_guider_params, negative_context=a_context_n),
        )

        video_state, audio_state = res2s_audio_video_denoising_loop(
            transformer=stage_1_transformer,
            denoiser=guided_denoiser,
            sigmas=sigmas,
            noiser=noiser,
            stepper=self._stepper,
            width=half_res.width,
            height=half_res.height,
            frames=num_frames,
            fps=frame_rate,
            video=ModalitySpec(context=v_context_p, conditionings=stage_1_conditionings),
            audio=ModalitySpec(context=a_context_p),
            streaming_prefetch_count=streaming_prefetch_count,
            max_batch_size=max_batch_size,
        )

        # ---- Stage 2: upsample, then refine at full resolution --------------
        upsampler = ledger.spatial_upsampler()
        upscaled_video_latent = upsampler(video_state.latent[:1])

        distilled_sigmas = torch.tensor(
            STAGE_2_DISTILLED_SIGMA_VALUES, device=self.device, dtype=torch.float32
        )
        stage_2_conditionings = self._image_conditionings(images, height, width, video_encoder)
        stage_2_transformer = ledger.transformer(stage=2)

        # Both modalities are re-noised to the first distilled sigma.
        renoise_scale = distilled_sigmas[0].item()
        video_state, audio_state = res2s_audio_video_denoising_loop(
            transformer=stage_2_transformer,
            denoiser=SimpleDenoiser(v_context=v_context_p, a_context=a_context_p),
            sigmas=distilled_sigmas,
            noiser=noiser,
            stepper=self._stepper,
            width=width,
            height=height,
            frames=num_frames,
            fps=frame_rate,
            video=ModalitySpec(
                context=v_context_p,
                conditionings=stage_2_conditionings,
                noise_scale=renoise_scale,
                initial_latent=upscaled_video_latent,
            ),
            audio=ModalitySpec(
                context=a_context_p,
                noise_scale=renoise_scale,
                initial_latent=audio_state.latent,
            ),
            streaming_prefetch_count=streaming_prefetch_count,
            max_batch_size=max_batch_size,
        )

        # ---- Decode ----------------------------------------------------------
        video_decoder = ledger.video_decoder()
        audio_decoder = ledger.audio_decoder()

        decoded_video = video_decoder(video_state.latent, tiling_config, rng)
        decoded_audio = audio_decoder(audio_state.latent)

        return decoded_video, decoded_audio
483
+
484
+
485
+ # =============================================================================
486
+ # Model Download
487
+ # =============================================================================
488
+
489
+ print("=" * 80)
490
+ print("Downloading LTX-2.3 models...")
491
+ print("=" * 80)
492
+
493
+ checkpoint_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-dev.safetensors")
494
+ spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
495
+ distilled_lora_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-distilled-lora-384.safetensors")
496
+ gemma_root = snapshot_download(repo_id=GEMMA_REPO)
497
+
498
+ print(f"Checkpoint: {checkpoint_path}")
499
+ print(f"Spatial upsampler: {spatial_upsampler_path}")
500
+ print(f"Distilled LoRA: {distilled_lora_path}")
501
+ print(f"Gemma root: {gemma_root}")
502
+
503
+ print("=" * 80)
504
+ print("All models downloaded!")
505
+ print("=" * 80)
506
+
507
+ # =============================================================================
508
+ # Pipeline Initialization
509
+ # =============================================================================
510
+
511
+ print("Initializing TI2VidTwoStagesHQPipelineWithLedger...")
512
+
513
+ pipeline = TI2VidTwoStagesHQPipelineWithLedger(
514
+ checkpoint_path=checkpoint_path,
515
+ distilled_lora_path=distilled_lora_path,
516
+ distilled_lora_strength_stage_1=0.25,
517
+ distilled_lora_strength_stage_2=0.50,
518
+ spatial_upsampler_path=spatial_upsampler_path,
519
+ gemma_root=gemma_root,
520
+ loras=(),
521
+ quantization=QuantizationPolicy.fp8_cast(),
522
+ torch_compile=False,
523
+ )
524
+
525
+ print("Pipeline initialized successfully!")
526
+ print("=" * 80)
527
+
528
+ # =============================================================================
529
+ # ZeroGPU Tensor Preloading - model_ledger Pattern
530
+ # =============================================================================
531
print("Preloading all models for ZeroGPU tensor packing...")
print("This may take a few minutes...")

# Access model ledger
ledger = pipeline.model_ledger

# Build every model once at import time. Each ledger accessor stores the model
# it builds on the ledger itself, so the results need no re-assignment. The
# module-level names below keep an extra strong reference to each model.
print(" Loading transformer (stage 1)...")
_transformer_s1 = ledger.transformer(stage=1)

print(" Loading transformer (stage 2)...")
_transformer_s2 = ledger.transformer(stage=2)

print(" Loading video encoder...")
_ve = ledger.video_encoder()

print(" Loading video decoder...")
_vd = ledger.video_decoder()

print(" Loading audio decoder...")
_ad = ledger.audio_decoder()

print(" Loading vocoder...")
_voc = ledger.vocoder()

print(" Loading spatial upsampler...")
_su = ledger.spatial_upsampler()

print(" Loading text encoder (Gemma)...")
_te = ledger.text_encoder()

print(" Loading embeddings processor...")
_ep = ledger.embeddings_processor()


# Replace the lazy accessors with plain getters that return the preloaded
# instances. The original author marks this as the critical step for ZeroGPU
# tensor packing.
def ledger_transformer(stage=1):
    """Return the preloaded transformer; any stage other than 1 selects stage 2."""
    return ledger._transformer_stage1 if stage == 1 else ledger._transformer_stage2


ledger.transformer = ledger_transformer
ledger.video_encoder = lambda: ledger._video_encoder
ledger.video_decoder = lambda: ledger._video_decoder
ledger.audio_decoder = lambda: ledger._audio_decoder
ledger.vocoder = lambda: ledger._vocoder
ledger.spatial_upsampler = lambda: ledger._spatial_upsampler
ledger.text_encoder = lambda: ledger._text_encoder
ledger.embeddings_processor = lambda: ledger._embeddings_processor

# No `global` statement here: names assigned at module level are already
# global, and declaring them `global` after assignment is a SyntaxError that
# stops the whole file from compiling.

print("All models preloaded for ZeroGPU tensor packing!")
print("=" * 80)
593
+
594
+ # =============================================================================
595
+ # Helper Functions
596
+ # =============================================================================
597
+
598
def log_memory(tag: str):
    """Print CUDA memory statistics labelled with ``tag``; do nothing without CUDA."""
    if not torch.cuda.is_available():
        return

    gib = 1024**3  # bytes per GiB
    allocated = torch.cuda.memory_allocated() / gib
    peak = torch.cuda.max_memory_allocated() / gib
    free, total = torch.cuda.mem_get_info()
    print(f"[VRAM {tag}] allocated={allocated:.2f}GB peak={peak:.2f}GB free={free / gib:.2f}GB total={total / gib:.2f}GB")
604
+
605
+
606
def calculate_frames(duration: float, frame_rate: float = DEFAULT_FRAME_RATE) -> int:
    """Convert a duration in seconds to a frame count of the form ``8 * k + 1``.

    The raw count ``int(duration * frame_rate)`` is raised to at least
    ``MIN_FRAMES``, snapped to the nearest ``8 * k + 1`` and capped at
    ``MAX_FRAMES``.
    """
    requested = max(int(duration * frame_rate), MIN_FRAMES)
    blocks_of_eight = round((requested - 1) / 8)
    return min(8 * blocks_of_eight + 1, MAX_FRAMES)
612
+
613
+
614
def validate_resolution(height: int, width: int) -> tuple[int, int]:
    """Snap each dimension to a multiple of ``STEP`` and clamp it to ``[MIN_DIM, MAX_DIM]``.

    Returns the adjusted ``(height, width)`` pair.
    """

    def snap(value):
        nearest = round(value / STEP) * STEP
        return max(MIN_DIM, min(nearest, MAX_DIM))

    return snap(height), snap(width)
620
+
621
+
622
def detect_aspect_ratio(image) -> str:
    """Return the supported aspect ratio closest to that of ``image``.

    Args:
        image: ``None``, an object whose ``size`` is a ``(width, height)``
            pair (as PIL images provide), or an array-like whose ``shape``
            starts with ``(height, width)``.

    Returns:
        One of ``"16:9"``, ``"9:16"`` or ``"1:1"``. ``"16:9"`` is the fallback
        when the dimensions cannot be determined or are not positive.
    """
    fallback = "16:9"
    if image is None:
        return fallback

    size = getattr(image, "size", None)
    shape = getattr(image, "shape", None)
    if isinstance(size, (tuple, list)) and len(size) == 2:
        w, h = size
    elif shape is not None and len(shape) >= 2:
        # NumPy arrays also have `.size`, but it is an element count (an int),
        # so arrays must be measured through `.shape`.
        h, w = shape[:2]
    else:
        return fallback

    # A degenerate image would otherwise divide by zero below.
    if w <= 0 or h <= 0:
        return fallback

    ratio = w / h
    candidates = {"16:9": 16 / 9, "9:16": 9 / 16, "1:1": 1.0}
    return min(candidates, key=lambda name: abs(ratio - candidates[name]))
634
+
635
+
636
+ RESOLUTIONS = {
637
+ "16:9": {"width": 1280, "height": 704},
638
+ "9:16": {"width": 704, "height": 1280},
639
+ "1:1": {"width": 960, "height": 960},
640
+ }
641
+
642
+
643
def get_duration(
    prompt: str,
    negative_prompt: str,
    input_image,
    duration: float,
    seed: int,
    randomize_seed: bool,
    height: int,
    width: int,
    enhance_prompt: bool,
    video_cfg_scale: float,
    video_stg_scale: float,
    video_rescale_scale: float,
    video_a2v_scale: float,
    audio_cfg_scale: float,
    audio_stg_scale: float,
    audio_rescale_scale: float,
    audio_v2a_scale: float,
    progress,
) -> int:
    """Estimate the GPU time budget, in seconds, for one generation request.

    Takes the same parameters as ``generate_video`` because it is passed as
    ``duration=`` to ``spaces.GPU``; only ``duration``, ``height`` and
    ``width`` influence the result. Starts at 60 and adds a surcharge for
    longer clips, larger resolutions and higher frame counts, capped at 90.
    """
    surcharges = (
        (duration > 4, 15),
        (duration > 6, 15),
        (height > 700 or width > 1000, 15),
        (int(duration * DEFAULT_FRAME_RATE) > 81, 10),
    )
    budget = 60 + sum(extra for applies, extra in surcharges if applies)
    return min(budget, 90)
674
+
675
+
676
@spaces.GPU(duration=get_duration)
@torch.inference_mode()
def generate_video(
    prompt: str,
    negative_prompt: str,
    input_image,
    duration: float,
    seed: int,
    randomize_seed: bool,
    height: int,
    width: int,
    enhance_prompt: bool,
    video_cfg_scale: float,
    video_stg_scale: float,
    video_rescale_scale: float,
    video_a2v_scale: float,
    audio_cfg_scale: float,
    audio_stg_scale: float,
    audio_rescale_scale: float,
    audio_v2a_scale: float,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a video (with audio) and return ``(mp4_path, seed_used)``.

    Args:
        prompt, negative_prompt: Text conditioning for the pipeline.
        input_image: Optional conditioning image for frame 0; either an object
            with a ``save`` method (e.g. a PIL image) or a file path.
        duration: Clip length in seconds, converted with ``calculate_frames``.
        seed, randomize_seed: Seed to use, or draw a random one when
            ``randomize_seed`` is true.
        height, width: Requested resolution, adjusted by ``validate_resolution``.
        enhance_prompt: Forwarded to the pipeline.
        video_*_scale, audio_*_scale: Guidance parameters per modality.
        progress: Gradio progress tracker.

    Returns:
        ``(output_path, current_seed)`` on success, or ``(None, current_seed)``
        when generation fails; the error and traceback are printed.
    """
    # Bound before the try block so the error path always has a seed to
    # return, even when the failure happens before the seed is resolved.
    current_seed = seed
    temp_image_path = None
    try:
        torch.cuda.reset_peak_memory_stats()
        log_memory("start")

        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
        print(f"Using seed: {current_seed}")

        height, width = validate_resolution(int(height), int(width))
        print(f"Resolution: {width}x{height}")

        num_frames = calculate_frames(duration, DEFAULT_FRAME_RATE)
        print(f"Frames: {num_frames} ({duration}s @ {DEFAULT_FRAME_RATE}fps)")

        images = []
        if input_image is not None:
            output_dir = Path("outputs")
            output_dir.mkdir(exist_ok=True)
            # A unique file per request: two requests sharing a seed must not
            # overwrite each other's conditioning image.
            fd, temp_name = tempfile.mkstemp(
                prefix=f"temp_input_{current_seed}_", suffix=".jpg", dir=output_dir
            )
            os.close(fd)
            temp_image_path = Path(temp_name)
            if hasattr(input_image, "save"):
                # JPEG cannot store alpha or palette modes; saving such an
                # image unconverted raises OSError.
                if getattr(input_image, "mode", "RGB") != "RGB":
                    input_image = input_image.convert("RGB")
                input_image.save(temp_image_path)
            else:
                import shutil
                shutil.copy(input_image, temp_image_path)
            images = [ImageConditioningInput(
                path=str(temp_image_path),
                frame_idx=0,
                strength=1.0,
            )]

        tiling_config = TilingConfig.default()
        video_chunks_number = get_video_chunks_number(num_frames, tiling_config)

        video_guider_params = MultiModalGuiderParams(
            cfg_scale=video_cfg_scale,
            stg_scale=video_stg_scale,
            rescale_scale=video_rescale_scale,
            modality_scale=video_a2v_scale,
            skip_step=0,
            stg_blocks=[],
        )

        audio_guider_params = MultiModalGuiderParams(
            cfg_scale=audio_cfg_scale,
            stg_scale=audio_stg_scale,
            rescale_scale=audio_rescale_scale,
            modality_scale=audio_v2a_scale,
            skip_step=0,
            stg_blocks=[],
        )

        log_memory("before pipeline call")

        video, audio = pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            seed=current_seed,
            height=height,
            width=width,
            num_frames=num_frames,
            frame_rate=DEFAULT_FRAME_RATE,
            num_inference_steps=LTX_2_3_HQ_PARAMS.num_inference_steps,
            video_guider_params=video_guider_params,
            audio_guider_params=audio_guider_params,
            images=images,
            tiling_config=tiling_config,
            enhance_prompt=enhance_prompt,
        )

        log_memory("after pipeline call")

        # mktemp() is deprecated and race-prone; reserve the file atomically
        # and let encode_video write to that path.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_video:
            output_path = tmp_video.name
        encode_video(
            video=video,
            fps=DEFAULT_FRAME_RATE,
            audio=audio,
            output_path=output_path,
            video_chunks_number=video_chunks_number,
        )

        log_memory("after encode_video")
        return str(output_path), current_seed

    except Exception as e:
        import traceback
        log_memory("on error")
        print(f"Error: {str(e)}\n{traceback.format_exc()}")
        return None, current_seed

    finally:
        # The conditioning image is only needed while this request runs.
        if temp_image_path is not None:
            temp_image_path.unlink(missing_ok=True)
785
+
786
+
787
+ # =============================================================================
788
+ # Gradio UI
789
+ # =============================================================================
790
+
791
+ css = """
792
+ .fillable {max-width: 1200px !important}
793
+ .progress-text {color: white}
794
+ """
795
+
796
+ with gr.Blocks(title="LTX-2.3 Two-Stage HQ Video Generation") as demo:
797
+ gr.Markdown("# LTX-2.3 Two-Stage HQ Video Generation")
798
+ gr.Markdown(
799
+ "High-quality text/image-to-video generation using the dev model + distilled LoRA. "
800
+ "[[Model]](https://huggingface.co/Lightricks/LTX-2.3) "
801
+ "[[GitHub]](https://github.com/Lightricks/LTX-2)"
802
+ )
803
+
804
+ with gr.Row():
805
+ with gr.Column():
806
+ input_image = gr.Image(
807
+ label="Input Image (Optional - for image-to-video)",
808
+ type="pil",
809
+ sources=["upload", "webcam", "clipboard"]
810
+ )
811
+
812
+ prompt = gr.Textbox(
813
+ label="Prompt",
814
+ info="Describe the video you want to generate",
815
+ value=DEFAULT_PROMPT,
816
+ lines=3,
817
+ placeholder="Enter your prompt here..."
818
+ )
819
+
820
+ negative_prompt = gr.Textbox(
821
+ label="Negative Prompt",
822
+ info="What to avoid in the generated video",
823
+ value=DEFAULT_NEGATIVE_PROMPT,
824
+ lines=2,
825
+ )
826
+
827
+ duration = gr.Slider(
828
+ label="Duration (seconds)",
829
+ minimum=0.5,
830
+ maximum=8.0,
831
+ value=2.0,
832
+ step=0.1,
833
+ )
834
+
835
+ enhance_prompt = gr.Checkbox(
836
+ label="Enhance Prompt",
837
+ value=False,
838
+ )
839
+
840
+ generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
841
+
842
+ with gr.Column():
843
+ output_video = gr.Video(
844
+ label="Generated Video",
845
+ autoplay=True,
846
+ interactive=False
847
+ )
848
+
849
+ with gr.Accordion("Advanced Settings", open=False):
850
+ with gr.Row():
851
+ width = gr.Number(label="Width", value=1280, precision=0)
852
+ height = gr.Number(label="Height", value=704, precision=0)
853
+
854
+ with gr.Row():
855
+ seed = gr.Number(label="Seed", value=42, precision=0, minimum=0, maximum=MAX_SEED)
856
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
857
+
858
+ gr.Markdown("### Video Guidance Parameters")
859
+
860
+ with gr.Row():
861
+ video_cfg_scale = gr.Slider(
862
+ label="Video CFG Scale", minimum=1.0, maximum=10.0,
863
+ value=LTX_2_3_HQ_PARAMS.video_guider_params.cfg_scale, step=0.1
864
+ )
865
+ video_stg_scale = gr.Slider(
866
+ label="Video STG Scale", minimum=0.0, maximum=2.0, value=0.0, step=0.1
867
+ )
868
+
869
+ with gr.Row():
870
+ video_rescale_scale = gr.Slider(
871
+ label="Video Rescale", minimum=0.0, maximum=2.0, value=0.45, step=0.1
872
+ )
873
+ video_a2v_scale = gr.Slider(
874
+ label="A2V Scale", minimum=0.0, maximum=5.0, value=3.0, step=0.1
875
+ )
876
+
877
+ gr.Markdown("### Audio Guidance Parameters")
878
+
879
+ with gr.Row():
880
+ audio_cfg_scale = gr.Slider(
881
+ label="Audio CFG Scale", minimum=1.0, maximum=15.0,
882
+ value=LTX_2_3_HQ_PARAMS.audio_guider_params.cfg_scale, step=0.1
883
+ )
884
+ audio_stg_scale = gr.Slider(
885
+ label="Audio STG Scale", minimum=0.0, maximum=2.0, value=0.0, step=0.1
886
+ )
887
+
888
+ with gr.Row():
889
+ audio_rescale_scale = gr.Slider(
890
+ label="Audio Rescale", minimum=0.0, maximum=2.0, value=1.0, step=0.1
891
+ )
892
+ audio_v2a_scale = gr.Slider(
893
+ label="V2A Scale", minimum=0.0, maximum=5.0, value=3.0, step=0.1
894
+ )
895
+
896
def on_image_upload(image, current_h, current_w):
    """Return ``(width, height)`` updates matching the uploaded image's aspect ratio.

    Both fields are left unchanged when there is no image or the detected
    aspect ratio has no preset in ``RESOLUTIONS``. ``current_h`` and
    ``current_w`` are accepted but not used.
    """
    if image is None:
        return gr.update(), gr.update()

    preset = RESOLUTIONS.get(detect_aspect_ratio(image))
    if preset is None:
        return gr.update(), gr.update()

    return gr.update(value=preset["width"]), gr.update(value=preset["height"])
906
+
907
+ input_image.change(
908
+ fn=on_image_upload,
909
+ inputs=[input_image, height, width],
910
+ outputs=[width, height],
911
+ )
912
+
913
+ generate_btn.click(
914
+ fn=generate_video,
915
+ inputs=[
916
+ prompt, negative_prompt, input_image, duration,
917
+ seed, randomize_seed, height, width, enhance_prompt,
918
+ video_cfg_scale, video_stg_scale, video_rescale_scale, video_a2v_scale,
919
+ audio_cfg_scale, audio_stg_scale, audio_rescale_scale, audio_v2a_scale,
920
+ ],
921
+ outputs=[output_video, seed],
922
+ )
923
+
924
+
925
+ if __name__ == "__main__":
926
+ demo.queue().launch(
927
+ theme=gr.themes.Citrus(),
928
+ css=css,
929
+ mcp_server=True,
930
+ )