dagloop5 commited on
Commit
9ada948
·
verified ·
1 Parent(s): 4c3fdc0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +723 -0
app.py ADDED
@@ -0,0 +1,723 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # Installation and Setup
3
+ # =============================================================================
4
+ import os
5
+ import subprocess
6
+ import sys
7
+
8
+ # Disable torch.compile / dynamo before any torch import
9
+ # This prevents CUDA initialization issues in the Space environment
10
+ os.environ["TORCH_COMPILE_DISABLE"] = "1"
11
+ os.environ["TORCHDYNAMO_DISABLE"] = "1"
12
+
13
+ # Clone LTX-2 repo at specific commit for reproducibility
14
+ # The commit ensures we have the exact pipeline code matching our analysis
15
+ LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
16
+ LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
17
+ # Using specific commit for stability - can be updated to main later
18
+ LTX_COMMIT_SHA = "ae855f8538843825f9015a419cf4ba5edaf5eec2"
19
+
20
+ if not os.path.exists(LTX_REPO_DIR):
21
+ print(f"Cloning {LTX_REPO_URL} at commit {LTX_COMMIT_SHA}...")
22
+ os.makedirs(LTX_REPO_DIR)
23
+ subprocess.run(["git", "init", LTX_REPO_DIR], check=True)
24
+ subprocess.run(["git", "remote", "add", "origin", LTX_REPO_URL], cwd=LTX_REPO_DIR, check=True)
25
+ subprocess.run(["git", "fetch", "--depth", "1", "origin", LTX_COMMIT_SHA], cwd=LTX_REPO_DIR, check=True)
26
+ subprocess.run(["git", "checkout", LTX_COMMIT_SHA], cwd=LTX_REPO_DIR, check=True)
27
+
28
+ # Add repo packages to Python path
29
+ # This allows us to import from ltx-core and ltx-pipelines
30
+ sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines", "src"))
31
+ sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-core", "src"))
32
+
33
+ # =============================================================================
34
+ # Imports
35
+ # =============================================================================
36
+ import logging
37
+ import random
38
+ import tempfile
39
+ from pathlib import Path
40
+
41
+ import torch
42
+ # Disable torch.compile/dynamo at runtime level
43
+ torch._dynamo.config.suppress_errors = True
44
+ torch._dynamo.config.disable = True
45
+
46
+ import spaces
47
+ import gradio as gr
48
+ import numpy as np
49
+ from huggingface_hub import hf_hub_download, snapshot_download
50
+
51
+ # Import from the cloned LTX-2 pipeline
52
+ # These imports come from ti2vid_two_stages_hq.py
53
+ from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
54
+ from ltx_core.quantization import QuantizationPolicy
55
+ from ltx_core.loader import LoraPathStrengthAndSDOps
56
+ from ltx_pipelines.ti2vid_two_stages_hq import TI2VidTwoStagesHQPipeline
57
+ from ltx_pipelines.utils.args import ImageConditioningInput
58
+ from ltx_pipelines.utils.media_io import encode_video
59
+ from ltx_pipelines.utils.constants import LTX_2_3_HQ_PARAMS
60
+ from ltx_core.components.guiders import MultiModalGuiderParams
61
+
62
+ # =============================================================================
63
+ # Constants and Configuration
64
+ # =============================================================================
65
+
66
+ # Model repository on Hugging Face
67
+ LTX_MODEL_REPO = "Lightricks/LTX-2.3"
68
+ GEMMA_REPO = "google/gemma-3-12b-it-qat-q4_0-unquantized"
69
+
70
+ # Default parameters from LTX_2_3_HQ_PARAMS
71
+ DEFAULT_FRAME_RATE = 24.0
72
+
73
+ # Resolution constraints (must be divisible by 64 for two-stage pipeline)
74
+ # The pipeline generates at half-resolution in Stage 1, so input must be divisible by 2
75
+ MIN_DIM = 256
76
+ MAX_DIM = 1280
77
+ STEP = 64 # Both width and height must be divisible by 64
78
+
79
+ # Duration constraints (frames must be 8*K + 1)
80
+ MIN_FRAMES = 9 # 8*1 + 1
81
+ MAX_FRAMES = 257 # 8*32 + 1
82
+
83
+ # Seed range
84
+ MAX_SEED = np.iinfo(np.int32).max
85
+
86
+ # Default prompts
87
+ DEFAULT_PROMPT = (
88
+ "A majestic eagle soaring over mountain peaks at sunset, "
89
+ "wings spread wide against the orange sky, feathers catching the light, "
90
+ "wind currents visible in the motion blur, cinematic slow motion, 4K quality"
91
+ )
92
+ DEFAULT_NEGATIVE_PROMPT = (
93
+ "worst quality, inconsistent motion, blurry, jittery, distorted, "
94
+ "deformed, artifacts, text, watermark, logo, frame, border, "
95
+ "low resolution, pixelated, unnatural, fake, CGI, cartoon"
96
+ )
97
+
98
+ # =============================================================================
99
+ # Model Download and Initialization
100
+ # =============================================================================
101
+
102
+ print("=" * 80)
103
+ print("Downloading LTX-2.3 models...")
104
+ print("=" * 80)
105
+
106
+ # Download all required model files
107
+ # 1. Dev checkpoint - full trainable 22B model
108
+ checkpoint_path = hf_hub_download(
109
+ repo_id=LTX_MODEL_REPO,
110
+ filename="ltx-2.3-22b-dev.safetensors"
111
+ )
112
+ print(f"Dev checkpoint: {checkpoint_path}")
113
+
114
+ # 2. Spatial upscaler - x2 upscaler for latent space
115
+ spatial_upsampler_path = hf_hub_download(
116
+ repo_id=LTX_MODEL_REPO,
117
+ filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors"
118
+ )
119
+ print(f"Spatial upsampler: {spatial_upsampler_path}")
120
+
121
+ # 3. Distilled LoRA - distilled knowledge in LoRA format (rank 384)
122
+ # This LoRA is specifically trained to work with the dev model
123
+ distilled_lora_path = hf_hub_download(
124
+ repo_id=LTX_MODEL_REPO,
125
+ filename="ltx-2.3-22b-distilled-lora-384.safetensors"
126
+ )
127
+ print(f"Distilled LoRA: {distilled_lora_path}")
128
+
129
+ # 4. Gemma text encoder - required for prompt encoding
130
+ gemma_root = snapshot_download(repo_id=GEMMA_REPO)
131
+ print(f"Gemma root: {gemma_root}")
132
+
133
+ print("=" * 80)
134
+ print("All models downloaded!")
135
+ print("=" * 80)
136
+
137
+ # =============================================================================
138
+ # Pipeline Initialization
139
+ # =============================================================================
140
+
141
+ # Create the LoraPathStrengthAndSDOps for distilled LoRA
142
+ # The sd_ops parameter uses the ComfyUI renaming map for compatibility
143
+ from ltx_core.loader import LTXV_LORA_COMFY_RENAMING_MAP
144
+
145
+ distilled_lora = [
146
+ LoraPathStrengthAndSDOps(
147
+ path=distilled_lora_path,
148
+ strength=1.0, # Will be set per-stage (0.25 for stage 1, 0.5 for stage 2)
149
+ sd_ops=LTXV_LORA_COMFY_RENAMING_MAP,
150
+ )
151
+ ]
152
+
153
+ # Initialize the Two-Stage HQ Pipeline
154
+ # Key parameters:
155
+ # - checkpoint_path: Full dev model (trainable)
156
+ # - distilled_lora: LoRA containing distilled knowledge
157
+ # - distilled_lora_strength_stage_1: 0.25 (lighter application at half-res)
158
+ # - distilled_lora_strength_stage_2: 0.5 (stronger application after upscaling)
159
+ # - spatial_upsampler_path: Required for two-stage upscaling
160
+ # - gemma_root: Gemma text encoder for prompt encoding
161
+ print("Initializing LTX-2.3 Two-Stage HQ Pipeline...")
162
+
163
+ pipeline = TI2VidTwoStagesHQPipeline(
164
+ checkpoint_path=checkpoint_path,
165
+ distilled_lora=distilled_lora,
166
+ distilled_lora_strength_stage_1=0.25, # From HQ params
167
+ distilled_lora_strength_stage_2=0.50, # From HQ params
168
+ spatial_upsampler_path=spatial_upsampler_path,
169
+ gemma_root=gemma_root,
170
+ loras=(), # No additional custom LoRAs for this Space
171
+ quantization=QuantizationPolicy.fp8_cast(), # FP8 for memory efficiency
172
+ torch_compile=False, # Disable for Space compatibility
173
+ )
174
+
175
+ print("Pipeline initialized successfully!")
176
+ print("=" * 80)
177
+
178
+ # =============================================================================
179
+ # Helper Functions
180
+ # =============================================================================
181
+
182
+ def log_memory(tag: str):
183
+ """Log current GPU memory usage for debugging."""
184
+ if torch.cuda.is_available():
185
+ allocated = torch.cuda.memory_allocated() / 1024**3
186
+ peak = torch.cuda.max_memory_allocated() / 1024**3
187
+ free, total = torch.cuda.mem_get_info()
188
+ print(f"[VRAM {tag}] allocated={allocated:.2f}GB peak={peak:.2f}GB free={free / 1024**3:.2f}GB total={total / 1024**3:.2f}GB")
189
+
190
+
191
+ def calculate_frames(duration: float, frame_rate: float = DEFAULT_FRAME_RATE) -> int:
192
+ """
193
+ Calculate number of frames from duration.
194
+
195
+ Frame count must be 8*K + 1 (K is a non-negative integer) for the LTX model.
196
+ This constraint comes from the temporal upsampling architecture.
197
+
198
+ Args:
199
+ duration: Duration in seconds
200
+ frame_rate: Frames per second
201
+
202
+ Returns:
203
+ Frame count that satisfies the 8*K + 1 constraint
204
+ """
205
+ ideal_frames = int(duration * frame_rate)
206
+ # Ensure it's at least MIN_FRAMES
207
+ ideal_frames = max(ideal_frames, MIN_FRAMES)
208
+ # Round to nearest 8*K + 1
209
+ k = round((ideal_frames - 1) / 8)
210
+ frames = k * 8 + 1
211
+ # Clamp to max
212
+ return min(frames, MAX_FRAMES)
213
+
214
+
215
+ def validate_resolution(height: int, width: int) -> tuple[int, int]:
216
+ """
217
+ Ensure resolution is valid for two-stage pipeline.
218
+
219
+ The two-stage pipeline requires:
220
+ - Both dimensions divisible by 64 (for final resolution)
221
+ - Stage 1 operates at half resolution (divisible by 32)
222
+
223
+ Args:
224
+ height: Target height
225
+ width: Target width
226
+
227
+ Returns:
228
+ Validated (height, width) tuple
229
+ """
230
+ # Round to nearest multiple of 64
231
+ height = round(height / STEP) * STEP
232
+ width = round(width / STEP) * STEP
233
+
234
+ # Clamp to valid range
235
+ height = max(MIN_DIM, min(height, MAX_DIM))
236
+ width = max(MIN_DIM, min(width, MAX_DIM))
237
+
238
+ return height, width
239
+
240
+
241
+ def detect_aspect_ratio(image) -> str:
242
+ """Detect the closest aspect ratio from an image for resolution presets."""
243
+ if image is None:
244
+ return "16:9"
245
+
246
+ if hasattr(image, "size"):
247
+ w, h = image.size
248
+ elif hasattr(image, "shape"):
249
+ h, w = image.shape[:2]
250
+ else:
251
+ return "16:9"
252
+
253
+ ratio = w / h
254
+ candidates = {"16:9": 16/9, "9:16": 9/16, "1:1": 1.0}
255
+ return min(candidates, key=lambda k: abs(ratio - candidates[k]))
256
+
257
+
258
+ # Resolution presets based on aspect ratio
259
+ RESOLUTIONS = {
260
+ "16:9": {"width": 1280, "height": 704}, # 960x540 * 1.33 = 1280x720, halved = 640x360 -> 1280x720
261
+ "9:16": {"width": 704, "height": 1280},
262
+ "1:1": {"width": 960, "height": 960},
263
+ }
264
+
265
+
266
+ def get_duration(
267
+ prompt: str,
268
+ negative_prompt: str,
269
+ duration: float,
270
+ height: int,
271
+ width: int,
272
+ num_frames: int,
273
+ enhance_prompt: bool,
274
+ video_cfg: float,
275
+ audio_cfg: float,
276
+ progress,
277
+ ) -> int:
278
+ """
279
+ Dynamically calculate GPU duration based on generation parameters.
280
+
281
+ This is used by @spaces.GPU to set the appropriate time limit.
282
+ Longer videos and higher resolution require more time.
283
+
284
+ Args:
285
+ duration: Video duration in seconds
286
+ height, width: Resolution
287
+ num_frames: Number of frames (indicates complexity)
288
+
289
+ Returns:
290
+ Duration in seconds for the GPU allocation
291
+ """
292
+ base = 60
293
+
294
+ # Longer videos need more time
295
+ if duration > 4:
296
+ base += 15
297
+ if duration > 6:
298
+ base += 15
299
+
300
+ # Higher resolution needs more time
301
+ if height > 700 or width > 1000:
302
+ base += 15
303
+
304
+ # More frames means more processing
305
+ if num_frames > 81:
306
+ base += 10
307
+
308
+ return min(base, 90)
309
+
310
+
311
+ @spaces.GPU(duration=get_duration)
312
+ @torch.inference_mode()
313
+ def generate_video(
314
+ prompt: str,
315
+ negative_prompt: str,
316
+ input_image,
317
+ duration: float,
318
+ seed: int,
319
+ randomize_seed: bool,
320
+ height: int,
321
+ width: int,
322
+ enhance_prompt: bool,
323
+ # Guidance parameters
324
+ video_cfg_scale: float,
325
+ video_stg_scale: float,
326
+ video_rescale_scale: float,
327
+ video_a2v_scale: float,
328
+ audio_cfg_scale: float,
329
+ audio_stg_scale: float,
330
+ audio_rescale_scale: float,
331
+ audio_v2a_scale: float,
332
+ progress=gr.Progress(track_tqdm=True),
333
+ ):
334
+ """
335
+ Generate high-quality video using the Two-Stage HQ Pipeline.
336
+
337
+ This function implements a two-stage generation process:
338
+
339
+ Stage 1 (Half Resolution + CFG):
340
+ - Generates video at half the target resolution
341
+ - Uses GuidedDenoiser with CFG (positive + negative prompts)
342
+ - Applies distilled LoRA at strength 0.25
343
+ - Res2s sampler for efficient second-order denoising
344
+
345
+ Stage 2 (Upscale + Refine):
346
+ - Upscales latent representation 2x using spatial upsampler
347
+ - Refines using SimpleDenoiser (no CFG, distilled approach)
348
+ - Applies distilled LoRA at strength 0.5
349
+ - 4-step refined denoising schedule
350
+
351
+ Args:
352
+ prompt: Text description of desired video content
353
+ negative_prompt: What to avoid in the video
354
+ input_image: Optional input image for image-to-video
355
+ duration: Video duration in seconds
356
+ seed: Random seed for reproducibility
357
+ randomize_seed: Whether to use a random seed
358
+ height, width: Target resolution (must be divisible by 64)
359
+ enhance_prompt: Whether to use prompt enhancement
360
+ video_cfg_scale: Video CFG (prompt adherence)
361
+ video_stg_scale: Video STG (spatio-temporal guidance)
362
+ video_rescale_scale: Video rescaling factor
363
+ video_a2v_scale: Audio-to-video cross-attention scale
364
+ audio_cfg_scale: Audio CFG (prompt adherence)
365
+ audio_stg_scale: Audio STG (spatio-temporal guidance)
366
+ audio_rescale_scale: Audio rescaling factor
367
+ audio_v2a_scale: Video-to-audio cross-attention scale
368
+
369
+ Returns:
370
+ Tuple of (output_video_path, used_seed)
371
+ """
372
+ try:
373
+ torch.cuda.reset_peak_memory_stats()
374
+ log_memory("start")
375
+
376
+ # Handle random seed
377
+ current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
378
+ print(f"Using seed: {current_seed}")
379
+
380
+ # Validate and adjust resolution
381
+ height, width = validate_resolution(int(height), int(width))
382
+ print(f"Resolution: {width}x{height}")
383
+
384
+ # Calculate frames (must be 8*K + 1)
385
+ num_frames = calculate_frames(duration, DEFAULT_FRAME_RATE)
386
+ print(f"Frames: {num_frames} ({duration}s @ {DEFAULT_FRAME_RATE}fps)")
387
+
388
+ # Prepare image conditioning if provided
389
+ images = []
390
+ if input_image is not None:
391
+ # Save input image temporarily
392
+ output_dir = Path("outputs")
393
+ output_dir.mkdir(exist_ok=True)
394
+ temp_image_path = output_dir / f"temp_input_{current_seed}.jpg"
395
+
396
+ if hasattr(input_image, "save"):
397
+ input_image.save(temp_image_path)
398
+ else:
399
+ import shutil
400
+ shutil.copy(input_image, temp_image_path)
401
+
402
+ # Create ImageConditioningInput
403
+ # path: image file path
404
+ # frame_idx: target frame to condition on (0 = first frame)
405
+ # strength: conditioning strength (1.0 = full influence)
406
+ images = [ImageConditioningInput(
407
+ path=str(temp_image_path),
408
+ frame_idx=0,
409
+ strength=1.0
410
+ )]
411
+
412
+ # Create tiling config for VAE decoding
413
+ # Tiling is necessary to avoid OOM errors during decoding
414
+ tiling_config = TilingConfig.default()
415
+ video_chunks_number = get_video_chunks_number(num_frames, tiling_config)
416
+
417
+ # Configure MultiModalGuider parameters
418
+ # These control how the model adheres to prompts and handles modality guidance
419
+
420
+ # Video guider parameters
421
+ # cfg_scale: Classifier-free guidance scale (higher = stronger prompt adherence)
422
+ # stg_scale: Spatio-temporal guidance scale (0 = disabled)
423
+ # rescale_scale: Rescaling factor for oversaturation prevention
424
+ # modality_scale: Cross-attention scale (audio-to-video)
425
+ # skip_step: Step skipping for faster inference (0 = no skipping)
426
+ # stg_blocks: Which transformer blocks to perturb for STG
427
+ video_guider_params = MultiModalGuiderParams(
428
+ cfg_scale=video_cfg_scale,
429
+ stg_scale=video_stg_scale,
430
+ rescale_scale=video_rescale_scale,
431
+ modality_scale=video_a2v_scale,
432
+ skip_step=0,
433
+ stg_blocks=[], # Empty for LTX 2.3 HQ
434
+ )
435
+
436
+ # Audio guider parameters
437
+ audio_guider_params = MultiModalGuiderParams(
438
+ cfg_scale=audio_cfg_scale,
439
+ stg_scale=audio_stg_scale,
440
+ rescale_scale=audio_rescale_scale,
441
+ modality_scale=audio_v2a_scale,
442
+ skip_step=0,
443
+ stg_blocks=[], # Empty for LTX 2.3 HQ
444
+ )
445
+
446
+ log_memory("before pipeline call")
447
+
448
+ # Call the pipeline
449
+ # The pipeline uses Res2sDiffusionStep for second-order sampling
450
+ # Stage 1: num_inference_steps from LTX_2_3_HQ_PARAMS (15 steps)
451
+ # Stage 2: Fixed 4-step schedule from STAGE_2_DISTILLED_SIGMAS
452
+ video, audio = pipeline(
453
+ prompt=prompt,
454
+ negative_prompt=negative_prompt,
455
+ seed=current_seed,
456
+ height=height,
457
+ width=width,
458
+ num_frames=num_frames,
459
+ frame_rate=DEFAULT_FRAME_RATE,
460
+ num_inference_steps=LTX_2_3_HQ_PARAMS.num_inference_steps, # 15 steps
461
+ video_guider_params=video_guider_params,
462
+ audio_guider_params=audio_guider_params,
463
+ images=images,
464
+ tiling_config=tiling_config,
465
+ enhance_prompt=enhance_prompt,
466
+ )
467
+
468
+ log_memory("after pipeline call")
469
+
470
+ # Encode video with audio
471
+ output_path = tempfile.mktemp(suffix=".mp4")
472
+ encode_video(
473
+ video=video,
474
+ fps=DEFAULT_FRAME_RATE,
475
+ audio=audio,
476
+ output_path=output_path,
477
+ video_chunks_number=video_chunks_number,
478
+ )
479
+
480
+ log_memory("after encode_video")
481
+ return str(output_path), current_seed
482
+
483
+ except Exception as e:
484
+ import traceback
485
+ log_memory("on error")
486
+ print(f"Error: {str(e)}\n{traceback.format_exc()}")
487
+ return None, current_seed
488
+
489
+
490
+ # =============================================================================
491
+ # Gradio UI
492
+ # =============================================================================
493
+
494
+ css = """
495
+ /* Custom styling for LTX-2.3 Space */
496
+ .fillable {max-width: 1200px !important}
497
+ .progress-text {color: white}
498
+ """
499
+
500
+ with gr.Blocks(title="LTX-2.3 Two-Stage HQ Video Generation", css=css) as demo:
501
+ gr.Markdown("# LTX-2.3 Two-Stage HQ Video Generation")
502
+ gr.Markdown(
503
+ "High-quality text/image-to-video generation using the dev model + distilled LoRA. "
504
+ "[[Model]](https://huggingface.co/Lightricks/LTX-2.3) "
505
+ "[[GitHub]](https://github.com/Lightricks/LTX-2)"
506
+ )
507
+
508
+ with gr.Row():
509
+ # Input Column
510
+ with gr.Column():
511
+ # Input image (optional)
512
+ input_image = gr.Image(
513
+ label="Input Image (Optional - for image-to-video)",
514
+ type="pil",
515
+ sources=["upload", "webcam", "clipboard"]
516
+ )
517
+
518
+ # Prompt inputs
519
+ prompt = gr.Textbox(
520
+ label="Prompt",
521
+ info="Describe the video you want to generate",
522
+ value=DEFAULT_PROMPT,
523
+ lines=3,
524
+ placeholder="Enter your prompt here..."
525
+ )
526
+
527
+ negative_prompt = gr.Textbox(
528
+ label="Negative Prompt",
529
+ info="What to avoid in the generated video",
530
+ value=DEFAULT_NEGATIVE_PROMPT,
531
+ lines=2,
532
+ placeholder="Enter negative prompt here..."
533
+ )
534
+
535
+ # Duration slider
536
+ duration = gr.Slider(
537
+ label="Duration (seconds)",
538
+ minimum=0.5,
539
+ maximum=8.0,
540
+ value=2.0,
541
+ step=0.1,
542
+ info="Video duration (clamped to 8K+1 frames)"
543
+ )
544
+
545
+ # Enhance prompt toggle
546
+ enhance_prompt = gr.Checkbox(
547
+ label="Enhance Prompt",
548
+ value=False,
549
+ info="Use Gemma to enhance the prompt for better results"
550
+ )
551
+
552
+ # Generate button
553
+ generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
554
+
555
+ # Output Column
556
+ with gr.Column():
557
+ output_video = gr.Video(
558
+ label="Generated Video",
559
+ autoplay=True,
560
+ interactive=False
561
+ )
562
+
563
+ # Advanced Settings Accordion
564
+ with gr.Accordion("Advanced Settings", open=False):
565
+ with gr.Row():
566
+ # Resolution inputs
567
+ width = gr.Number(
568
+ label="Width",
569
+ value=1280,
570
+ precision=0,
571
+ info="Must be divisible by 64"
572
+ )
573
+ height = gr.Number(
574
+ label="Height",
575
+ value=704,
576
+ precision=0,
577
+ info="Must be divisible by 64"
578
+ )
579
+
580
+ with gr.Row():
581
+ # Seed controls
582
+ seed = gr.Number(
583
+ label="Seed",
584
+ value=42,
585
+ precision=0,
586
+ minimum=0,
587
+ maximum=MAX_SEED
588
+ )
589
+ randomize_seed = gr.Checkbox(
590
+ label="Randomize Seed",
591
+ value=True
592
+ )
593
+
594
+ gr.Markdown("### Video Guidance Parameters")
595
+ gr.Markdown("Control how strongly the model follows the video prompt and handles guidance.")
596
+
597
+ with gr.Row():
598
+ video_cfg_scale = gr.Slider(
599
+ label="Video CFG Scale",
600
+ minimum=1.0,
601
+ maximum=10.0,
602
+ value=LTX_2_3_HQ_PARAMS.video_guider_params.cfg_scale,
603
+ step=0.1,
604
+ info="Classifier-free guidance for video (higher = stronger prompt adherence)"
605
+ )
606
+ video_stg_scale = gr.Slider(
607
+ label="Video STG Scale",
608
+ minimum=0.0,
609
+ maximum=2.0,
610
+ value=0.0,
611
+ step=0.1,
612
+ info="Spatio-temporal guidance (0 = disabled)"
613
+ )
614
+
615
+ with gr.Row():
616
+ video_rescale_scale = gr.Slider(
617
+ label="Video Rescale",
618
+ minimum=0.0,
619
+ maximum=2.0,
620
+ value=0.45,
621
+ step=0.1,
622
+ info="Rescaling factor for oversaturation prevention"
623
+ )
624
+ video_a2v_scale = gr.Slider(
625
+ label="A2V Scale",
626
+ minimum=0.0,
627
+ maximum=5.0,
628
+ value=3.0,
629
+ step=0.1,
630
+ info="Audio-to-video cross-attention scale"
631
+ )
632
+
633
+ gr.Markdown("### Audio Guidance Parameters")
634
+ gr.Markdown("Control audio generation quality and sync.")
635
+
636
+ with gr.Row():
637
+ audio_cfg_scale = gr.Slider(
638
+ label="Audio CFG Scale",
639
+ minimum=1.0,
640
+ maximum=15.0,
641
+ value=LTX_2_3_HQ_PARAMS.audio_guider_params.cfg_scale,
642
+ step=0.1,
643
+ info="Classifier-free guidance for audio"
644
+ )
645
+ audio_stg_scale = gr.Slider(
646
+ label="Audio STG Scale",
647
+ minimum=0.0,
648
+ maximum=2.0,
649
+ value=0.0,
650
+ step=0.1,
651
+ info="Spatio-temporal guidance for audio (0 = disabled)"
652
+ )
653
+
654
+ with gr.Row():
655
+ audio_rescale_scale = gr.Slider(
656
+ label="Audio Rescale",
657
+ minimum=0.0,
658
+ maximum=2.0,
659
+ value=1.0,
660
+ step=0.1,
661
+ info="Audio rescaling factor"
662
+ )
663
+ audio_v2a_scale = gr.Slider(
664
+ label="V2A Scale",
665
+ minimum=0.0,
666
+ maximum=5.0,
667
+ value=3.0,
668
+ step=0.1,
669
+ info="Video-to-audio cross-attention scale"
670
+ )
671
+
672
+ # Event handlers
673
+ def on_image_upload(image, current_h, current_w):
674
+ """Update resolution based on uploaded image aspect ratio."""
675
+ if image is None:
676
+ return gr.update(), gr.update()
677
+
678
+ aspect = detect_aspect_ratio(image)
679
+ if aspect in RESOLUTIONS:
680
+ return (
681
+ gr.update(value=RESOLUTIONS[aspect]["width"]),
682
+ gr.update(value=RESOLUTIONS[aspect]["height"])
683
+ )
684
+ return gr.update(), gr.update()
685
+
686
+ input_image.change(
687
+ fn=on_image_upload,
688
+ inputs=[input_image, height, width],
689
+ outputs=[width, height],
690
+ )
691
+
692
+ # Generate button click handler
693
+ generate_btn.click(
694
+ fn=generate_video,
695
+ inputs=[
696
+ prompt,
697
+ negative_prompt,
698
+ input_image,
699
+ duration,
700
+ seed,
701
+ randomize_seed,
702
+ height,
703
+ width,
704
+ enhance_prompt,
705
+ video_cfg_scale,
706
+ video_stg_scale,
707
+ video_rescale_scale,
708
+ video_a2v_scale,
709
+ audio_cfg_scale,
710
+ audio_stg_scale,
711
+ audio_rescale_scale,
712
+ audio_v2a_scale,
713
+ ],
714
+ outputs=[output_video, seed],
715
+ )
716
+
717
+
718
+ # =============================================================================
719
+ # Main Entry Point
720
+ # =============================================================================
721
+
722
+ if __name__ == "__main__":
723
+ demo.queue().launch(mcp_server=True)