dagloop5 committed on
Commit
bfed689
·
verified ·
1 Parent(s): 287d66d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +611 -0
app.py ADDED
@@ -0,0 +1,611 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # Installation and Setup
3
+ # =============================================================================
4
+ import os
5
+ import subprocess
6
+ import sys
7
+
8
# Turn off torch.compile / TorchDynamo via environment variables; these are
# set here so they are already in place when ``import torch`` runs later.
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["TORCHDYNAMO_DISABLE"] = "1"

# Clone LTX-2 repo at the commit with ModelLedger-based HQ pipeline
LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
LTX_COMMIT_SHA = "ae855f8538843825f9015a419cf4ba5edaf5eec2"

# The checkout is considered usable only when the ``packages`` tree (the one
# added to sys.path below) exists. Checking just the top-level directory would
# accept a clone that was interrupted after ``os.makedirs`` and never retry it.
if not os.path.isdir(os.path.join(LTX_REPO_DIR, "packages")):
    import shutil

    print(f"Cloning {LTX_REPO_URL} at commit {LTX_COMMIT_SHA}...")
    # Remove leftovers from an earlier failed attempt before starting over.
    shutil.rmtree(LTX_REPO_DIR, ignore_errors=True)
    os.makedirs(LTX_REPO_DIR)
    try:
        subprocess.run(["git", "init", LTX_REPO_DIR], check=True)
        subprocess.run(["git", "remote", "add", "origin", LTX_REPO_URL], cwd=LTX_REPO_DIR, check=True)
        # Shallow-fetch only the pinned commit, then check it out.
        subprocess.run(["git", "fetch", "--depth", "1", "origin", LTX_COMMIT_SHA], cwd=LTX_REPO_DIR, check=True)
        subprocess.run(["git", "checkout", LTX_COMMIT_SHA], cwd=LTX_REPO_DIR, check=True)
    except (subprocess.CalledProcessError, OSError):
        # Do not leave a half-initialised repository behind; the next start
        # should attempt a fresh clone. The original error still propagates.
        shutil.rmtree(LTX_REPO_DIR, ignore_errors=True)
        raise

# Make the two source packages of the checkout importable.
sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines", "src"))
sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-core", "src"))
26
+
27
+ # =============================================================================
28
+ # Imports
29
+ # =============================================================================
30
+ import logging
31
+ import random
32
+ import tempfile
33
+ from pathlib import Path
34
+
35
+ import torch
36
+ torch._dynamo.config.suppress_errors = True
37
+ torch._dynamo.config.disable = True
38
+
39
+ import gradio as gr
40
+ import spaces
41
+ import numpy as np
42
+ from huggingface_hub import hf_hub_download, snapshot_download
43
+
44
+ # Core LTX imports from the ModelLedger-based HQ pipeline
45
+ from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
46
+ from ltx_core.quantization import QuantizationPolicy
47
+ from ltx_core.loader import LoraPathStrengthAndSDOps, LTXV_LORA_COMFY_RENAMING_MAP
48
+ from ltx_core.components.guiders import MultiModalGuiderParams
49
+
50
+ # Import the HQ pipeline with ModelLedger
51
+ from ltx_pipelines.ti2vid_two_stages_hq import TI2VidTwoStagesHQPipeline
52
+
53
+ # Import constants
54
+ from ltx_pipelines.utils.constants import LTX_2_3_HQ_PARAMS
55
+
56
+ logging.getLogger().setLevel(logging.INFO)
57
+
58
+ # =============================================================================
59
+ # Constants
60
+ # =============================================================================
61
+
62
# Hugging Face repositories the model files are fetched from.
LTX_MODEL_REPO = "Lightricks/LTX-2.3"
GEMMA_REPO = "Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"

# Frame rate used for frame counting and for encoding the result.
DEFAULT_FRAME_RATE = 24.0

# Resolution bounds and snapping step (see validate_resolution).
MIN_DIM = 256
MAX_DIM = 1280
STEP = 64

# Frame-count bounds (see calculate_frames).
MIN_FRAMES = 9
MAX_FRAMES = 257

# Largest seed offered in the UI / drawn at random: the int32 maximum.
MAX_SEED = np.iinfo(np.int32).max

# Default prompts, assembled from their comma-separated clauses.
DEFAULT_PROMPT = ", ".join(
    [
        "A majestic eagle soaring over mountain peaks at sunset",
        "wings spread wide against the orange sky",
        "feathers catching the light",
        "wind currents visible in the motion blur",
        "cinematic slow motion",
        "4K quality",
    ]
)
DEFAULT_NEGATIVE_PROMPT = ", ".join(
    [
        "worst quality", "inconsistent motion", "blurry", "jittery", "distorted",
        "deformed", "artifacts", "text", "watermark", "logo", "frame", "border",
        "low resolution", "pixelated", "unnatural", "fake", "CGI", "cartoon",
    ]
)
79
+
80
+ # =============================================================================
81
+ # Model Download
82
+ # =============================================================================
83
+
84
print("=" * 80, "Downloading LTX-2.3 models...", "=" * 80, sep="\n")

# The three single-file artefacts come from the LTX model repo and are
# fetched in this order: dev checkpoint, spatial upsampler, distilled LoRA.
checkpoint_path, spatial_upsampler_path, distilled_lora_path = (
    hf_hub_download(repo_id=LTX_MODEL_REPO, filename=weights_file)
    for weights_file in (
        "ltx-2.3-22b-dev.safetensors",
        "ltx-2.3-spatial-upscaler-x2-1.1.safetensors",
        "ltx-2.3-22b-distilled-lora-384.safetensors",
    )
)
# The Gemma repo is downloaded as a whole snapshot.
gemma_root = snapshot_download(repo_id=GEMMA_REPO)

# Report where every artefact ended up.
print(
    f"Dev checkpoint: {checkpoint_path}",
    f"Spatial upsampler: {spatial_upsampler_path}",
    f"Distilled LoRA: {distilled_lora_path}",
    f"Gemma root: {gemma_root}",
    "=" * 80,
    sep="\n",
)
98
+
99
+ # =============================================================================
100
+ # Pipeline Initialization
101
+ # =============================================================================
102
+
103
print("Initializing TI2VidTwoStagesHQPipeline with ModelLedger...")

# LoRA entry for the distilled weights. The strength given here is 1.0; the
# per-stage strengths (0.25 and 0.5) are passed to the pipeline below.
_distilled_lora_entry = LoraPathStrengthAndSDOps(
    path=distilled_lora_path,
    strength=1.0,
    sd_ops=LTXV_LORA_COMFY_RENAMING_MAP,
)
distilled_lora = [_distilled_lora_entry]

# ModelLedger-based HQ pipeline: dev checkpoint + distilled LoRA, the x2
# spatial upsampler, the Gemma snapshot, no extra LoRAs, fp8-cast quantization.
# NOTE(review): the per-stage strengths are labelled "From HQ params" upstream
# but are hard-coded here - confirm they match LTX_2_3_HQ_PARAMS.
pipeline = TI2VidTwoStagesHQPipeline(
    checkpoint_path=checkpoint_path,
    distilled_lora=distilled_lora,
    distilled_lora_strength_stage_1=0.25,
    distilled_lora_strength_stage_2=0.50,
    spatial_upsampler_path=spatial_upsampler_path,
    gemma_root=gemma_root,
    loras=(),
    quantization=QuantizationPolicy.fp8_cast(),
)

print("Pipeline initialized successfully!", "=" * 80, sep="\n")
128
+
129
+ # =============================================================================
130
+ # ZeroGPU Tensor Preloading - ModelLedger Pattern
131
+ # =============================================================================
132
+ # This pipeline has TWO model ledgers: stage_1_model_ledger and stage_2_model_ledger
133
+ # We must preload both for ZeroGPU tensor packing to work
134
+
135
def _pin_component(ledger, component_name):
    """Build one ledger component and make the ledger return that instance.

    Calls ``ledger.<component_name>()`` once, then replaces the attribute on
    the ledger with a zero-argument callable that hands back the object that
    was just built. The built object is returned so the caller can keep a
    module-level reference to it.
    """
    instance = getattr(ledger, component_name)()
    setattr(ledger, component_name, lambda: instance)
    return instance


print("Preloading all models for ZeroGPU tensor packing...")
print("This may take a few minutes...")

# ===== Stage 1 Model Ledger =====
print(" Loading Stage 1 models...")
ledger1 = pipeline.stage_1_model_ledger
_transformer_s1 = _pin_component(ledger1, "transformer")
_video_encoder_s1 = _pin_component(ledger1, "video_encoder")
_video_decoder_s1 = _pin_component(ledger1, "video_decoder")
_audio_decoder_s1 = _pin_component(ledger1, "audio_decoder")
_vocoder_s1 = _pin_component(ledger1, "vocoder")
_spatial_upsampler_s1 = _pin_component(ledger1, "spatial_upsampler")
_text_encoder_s1 = _pin_component(ledger1, "text_encoder")  # Gemma
_embeddings_processor_s1 = _pin_component(ledger1, "embeddings_processor")
print(" Stage 1 models loaded")

# ===== Stage 2 Model Ledger =====
print(" Loading Stage 2 models...")
ledger2 = pipeline.stage_2_model_ledger
_transformer_s2 = _pin_component(ledger2, "transformer")
_video_encoder_s2 = _pin_component(ledger2, "video_encoder")
_video_decoder_s2 = _pin_component(ledger2, "video_decoder")
_audio_decoder_s2 = _pin_component(ledger2, "audio_decoder")
_vocoder_s2 = _pin_component(ledger2, "vocoder")
_spatial_upsampler_s2 = _pin_component(ledger2, "spatial_upsampler")
# NOTE(review): the original comments say the text encoder and embeddings
# processor "can reuse from stage 1", yet both are built again from ledger2.
_text_encoder_s2 = _pin_component(ledger2, "text_encoder")
_embeddings_processor_s2 = _pin_component(ledger2, "embeddings_processor")
print(" Stage 2 models loaded")

# Additional module-level references to every preloaded object, kept under
# their ``*_stage1`` / ``*_stage2`` names.
_transformer_stage1, _transformer_stage2 = _transformer_s1, _transformer_s2
_video_encoder_stage1, _video_encoder_stage2 = _video_encoder_s1, _video_encoder_s2
_video_decoder_stage1, _video_decoder_stage2 = _video_decoder_s1, _video_decoder_s2
_audio_decoder_stage1, _audio_decoder_stage2 = _audio_decoder_s1, _audio_decoder_s2
_vocoder_stage1, _vocoder_stage2 = _vocoder_s1, _vocoder_s2
_spatial_upsampler_stage1, _spatial_upsampler_stage2 = _spatial_upsampler_s1, _spatial_upsampler_s2
_text_encoder_stage1, _text_encoder_stage2 = _text_encoder_s1, _text_encoder_s2
_embeddings_processor_stage1, _embeddings_processor_stage2 = (
    _embeddings_processor_s1,
    _embeddings_processor_s2,
)

print("All models preloaded for ZeroGPU tensor packing!")
print("=" * 80)
235
+
236
+ # =============================================================================
237
+ # Helper Functions
238
+ # =============================================================================
239
+
240
def log_memory(tag: str):
    """Print CUDA memory statistics (in GB) labelled with *tag*.

    Reports currently allocated memory, peak allocated memory, and the
    free/total device memory. Does nothing when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return
    bytes_per_gb = 1024**3
    allocated = torch.cuda.memory_allocated() / bytes_per_gb
    peak = torch.cuda.max_memory_allocated() / bytes_per_gb
    free, total = torch.cuda.mem_get_info()
    print(f"[VRAM {tag}] allocated={allocated:.2f}GB peak={peak:.2f}GB free={free / 1024**3:.2f}GB total={total / 1024**3:.2f}GB")
246
+
247
+
248
def calculate_frames(duration: float, frame_rate: float = DEFAULT_FRAME_RATE) -> int:
    """Convert a duration in seconds into a frame count of the form ``8 * k + 1``.

    The raw count ``int(duration * frame_rate)`` is raised to at least
    ``MIN_FRAMES``, rounded to the nearest ``8 * k + 1`` value, and finally
    capped at ``MAX_FRAMES``.
    """
    requested = max(int(duration * frame_rate), MIN_FRAMES)
    groups_of_eight = round((requested - 1) / 8)
    return min(8 * groups_of_eight + 1, MAX_FRAMES)
254
+
255
+
256
def validate_resolution(height: int, width: int) -> tuple[int, int]:
    """Snap both dimensions to a multiple of ``STEP`` and clamp them.

    Each value is rounded to the nearest multiple of ``STEP`` and then limited
    to the ``[MIN_DIM, MAX_DIM]`` range. Returns ``(height, width)``.
    """

    def _snap_and_clamp(value: int) -> int:
        snapped = round(value / STEP) * STEP
        return max(MIN_DIM, min(snapped, MAX_DIM))

    return _snap_and_clamp(height), _snap_and_clamp(width)
262
+
263
+
264
def detect_aspect_ratio(image) -> str:
    """Return the preset aspect ratio closest to the image's width/height.

    Args:
        image: ``None``, an array-like object exposing ``shape`` with height
            and width as its first two entries, or an object exposing
            ``size`` as a ``(width, height)`` pair (as PIL images do).

    Returns:
        One of ``"16:9"``, ``"9:16"`` or ``"1:1"``. ``"16:9"`` is returned
        when the image is ``None``, has neither attribute, has fewer than two
        dimensions, or has a non-positive height.
    """
    default = "16:9"
    if image is None:
        return default

    # ``shape`` is checked first: numpy arrays expose BOTH ``shape`` and an
    # integer ``size`` (element count), so testing ``size`` first would try to
    # unpack an int and raise for exactly the inputs this branch is meant for.
    shape = getattr(image, "shape", None)
    if shape is not None:
        if len(shape) < 2:
            return default
        h, w = shape[:2]
    elif hasattr(image, "size"):
        w, h = image.size
    else:
        return default

    if h <= 0:
        # Degenerate image; avoid dividing by zero.
        return default

    ratio = w / h
    candidates = {"16:9": 16 / 9, "9:16": 9 / 16, "1:1": 1.0}
    return min(candidates, key=lambda k: abs(ratio - candidates[k]))
276
+
277
+
278
# Output resolution preset for each supported aspect ratio.
RESOLUTIONS = {
    aspect: {"width": preset_width, "height": preset_height}
    for aspect, (preset_width, preset_height) in (
        ("16:9", (1280, 704)),
        ("9:16", (704, 1280)),
        ("1:1", (960, 960)),
    )
}
283
+
284
+
285
def get_duration(
    prompt: str,
    negative_prompt: str,
    input_image,
    duration: float,
    seed: int,
    randomize_seed: bool,
    height: int,
    width: int,
    enhance_prompt: bool,
    video_cfg_scale: float,
    video_stg_scale: float,
    video_rescale_scale: float,
    video_a2v_scale: float,
    audio_cfg_scale: float,
    audio_stg_scale: float,
    audio_rescale_scale: float,
    audio_v2a_scale: float,
    progress,
) -> int:
    """Estimate the GPU time budget for one ``generate_video`` call.

    This function is passed as the ``duration`` argument of ``spaces.GPU`` on
    ``generate_video`` and takes the same parameters. Only ``duration``,
    ``height`` and ``width`` influence the result.

    Starting from 60, surcharges are added for clips longer than 4 s and 6 s,
    for large resolutions, and for more than 81 frames; the total is capped
    at 90.
    """
    surcharges = (
        (duration > 4, 15),
        (duration > 6, 15),
        (height > 700 or width > 1000, 15),
        (int(duration * DEFAULT_FRAME_RATE) > 81, 10),
    )
    budget = 60 + sum(extra for applies, extra in surcharges if applies)
    return min(budget, 90)
316
+
317
+
318
@spaces.GPU(duration=get_duration)
@torch.inference_mode()
def generate_video(
    prompt: str,
    negative_prompt: str,
    input_image,
    duration: float,
    seed: int,
    randomize_seed: bool,
    height: int,
    width: int,
    enhance_prompt: bool,
    video_cfg_scale: float,
    video_stg_scale: float,
    video_rescale_scale: float,
    video_a2v_scale: float,
    audio_cfg_scale: float,
    audio_stg_scale: float,
    audio_rescale_scale: float,
    audio_v2a_scale: float,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Generate high-quality video using the ModelLedger-based HQ pipeline.

    This pipeline uses ModelLedger (like DistilledPipeline) for ZeroGPU compatibility,
    while supporting CFG/negative prompts (like original TI2VidTwoStagesHQPipeline).

    Args:
        prompt: Text description of the video.
        negative_prompt: Text describing what to avoid.
        input_image: Optional conditioning image; either an object with a
            ``save`` method (e.g. a PIL image) or a path that can be copied.
        duration: Clip length in seconds; converted with ``calculate_frames``.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: Draw a random seed in ``[0, MAX_SEED]`` instead.
        height, width: Requested size; normalised by ``validate_resolution``.
        enhance_prompt: Forwarded to the pipeline.
        video_cfg_scale, video_stg_scale, video_rescale_scale, video_a2v_scale:
            Video guider settings (``cfg_scale``, ``stg_scale``,
            ``rescale_scale``, ``modality_scale``).
        audio_cfg_scale, audio_stg_scale, audio_rescale_scale, audio_v2a_scale:
            The same settings for the audio guider.
        progress: Gradio progress tracker; not referenced in the body.

    Returns:
        ``(output_path, seed_used)``. On any exception the error and traceback
        are printed and ``(None, seed)`` is returned instead of raising.
    """
    # Fallback so the error path can always report a seed, even if the failure
    # happens before the effective seed has been chosen inside the ``try``.
    current_seed = seed
    try:
        torch.cuda.reset_peak_memory_stats()
        log_memory("start")

        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
        print(f"Using seed: {current_seed}")

        height, width = validate_resolution(int(height), int(width))
        print(f"Resolution: {width}x{height}")

        num_frames = calculate_frames(duration, DEFAULT_FRAME_RATE)
        print(f"Frames: {num_frames} ({duration}s @ {DEFAULT_FRAME_RATE}fps)")

        # Prepare image conditioning if provided
        images = []
        if input_image is not None:
            output_dir = Path("outputs")
            output_dir.mkdir(exist_ok=True)
            temp_image_path = output_dir / f"temp_input_{current_seed}.jpg"
            if hasattr(input_image, "save"):
                # JPEG cannot store alpha or palette modes (e.g. RGBA PNG
                # uploads), so normalise to RGB before saving as .jpg.
                if getattr(input_image, "mode", "RGB") not in ("RGB", "L"):
                    input_image = input_image.convert("RGB")
                input_image.save(temp_image_path)
            else:
                import shutil

                shutil.copy(input_image, temp_image_path)
            # NOTE(review): tuple layout (path, strength) is taken from the
            # original code; confirm against the pipeline's expected
            # image-conditioning input type.
            images = [(str(temp_image_path), 1.0)]  # (path, strength)

        tiling_config = TilingConfig.default()
        video_chunks_number = get_video_chunks_number(num_frames, tiling_config)

        # Configure MultiModalGuider parameters
        video_guider_params = MultiModalGuiderParams(
            cfg_scale=video_cfg_scale,
            stg_scale=video_stg_scale,
            rescale_scale=video_rescale_scale,
            modality_scale=video_a2v_scale,
            skip_step=0,
            stg_blocks=[],
        )

        audio_guider_params = MultiModalGuiderParams(
            cfg_scale=audio_cfg_scale,
            stg_scale=audio_stg_scale,
            rescale_scale=audio_rescale_scale,
            modality_scale=audio_v2a_scale,
            skip_step=0,
            stg_blocks=[],
        )

        log_memory("before pipeline call")

        # Call the pipeline
        video, audio = pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            seed=current_seed,
            height=height,
            width=width,
            num_frames=num_frames,
            frame_rate=DEFAULT_FRAME_RATE,
            num_inference_steps=LTX_2_3_HQ_PARAMS.num_inference_steps,
            video_guider_params=video_guider_params,
            audio_guider_params=audio_guider_params,
            images=images,
            tiling_config=tiling_config,
            enhance_prompt=enhance_prompt,
        )

        log_memory("after pipeline call")

        # Encode video with audio
        from ltx_pipelines.utils.media_io import encode_video

        # Reserve the output file atomically. ``tempfile.mktemp`` only returns
        # a name, which another process could claim before it is written.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as output_file:
            output_path = output_file.name
        encode_video(
            video=video,
            fps=DEFAULT_FRAME_RATE,
            audio=audio,
            output_path=output_path,
            video_chunks_number=video_chunks_number,
        )

        log_memory("after encode_video")
        return str(output_path), current_seed

    except Exception as e:
        import traceback

        log_memory("on error")
        print(f"Error: {str(e)}\n{traceback.format_exc()}")
        return None, current_seed
435
+
436
+
437
+ # =============================================================================
438
+ # Gradio UI
439
+ # =============================================================================
440
+
441
# Extra CSS; passed to ``demo.launch`` in the ``__main__`` section.
css = """
.fillable {max-width: 1200px !important}
.progress-text {color: white}
"""

# NOTE(review): the extracted source lost its indentation. The nesting below
# (inputs column + output column inside one row, "Advanced Settings" accordion
# directly under the Blocks root) is the conventional reading - confirm the
# accordion placement against the original file.
with gr.Blocks(title="LTX-2.3 Two-Stage HQ Video Generation") as demo:
    gr.Markdown("# LTX-2.3 Two-Stage HQ Video Generation")
    gr.Markdown(
        "High-quality text/image-to-video generation using ModelLedger. "
        "[[Model]](https://huggingface.co/Lightricks/LTX-2.3) "
        "[[GitHub]](https://github.com/Lightricks/LTX-2)"
    )

    with gr.Row():
        # Left column: generation inputs.
        with gr.Column():
            input_image = gr.Image(
                label="Input Image (Optional - for image-to-video)",
                type="pil",
                sources=["upload", "webcam", "clipboard"],
            )
            prompt = gr.Textbox(
                label="Prompt", info="Describe the video you want to generate", value=DEFAULT_PROMPT, lines=3
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                info="What to avoid in the generated video",
                value=DEFAULT_NEGATIVE_PROMPT,
                lines=2,
            )
            duration = gr.Slider(label="Duration (seconds)", minimum=0.5, maximum=8.0, value=2.0, step=0.1)
            enhance_prompt = gr.Checkbox(label="Enhance Prompt", value=False)
            generate_btn = gr.Button("Generate Video", variant="primary", size="lg")

        # Right column: result.
        with gr.Column():
            output_video = gr.Video(label="Generated Video", autoplay=True, interactive=False)

    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            width = gr.Number(label="Width", value=1280, precision=0)
            height = gr.Number(label="Height", value=704, precision=0)

        with gr.Row():
            seed = gr.Number(label="Seed", value=42, precision=0, minimum=0, maximum=MAX_SEED)
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)

        gr.Markdown("### Video Guidance Parameters")
        gr.Markdown("Control how strongly the model follows the video prompt.")

        with gr.Row():
            video_cfg_scale = gr.Slider(
                label="Video CFG Scale",
                minimum=1.0,
                maximum=10.0,
                value=LTX_2_3_HQ_PARAMS.video_guider_params.cfg_scale,
                step=0.1,
            )
            video_stg_scale = gr.Slider(label="Video STG Scale", minimum=0.0, maximum=2.0, value=0.0, step=0.1)

        with gr.Row():
            video_rescale_scale = gr.Slider(label="Video Rescale", minimum=0.0, maximum=2.0, value=0.45, step=0.1)
            video_a2v_scale = gr.Slider(label="A2V Scale", minimum=0.0, maximum=5.0, value=3.0, step=0.1)

        gr.Markdown("### Audio Guidance Parameters")
        gr.Markdown("Control audio generation quality and sync.")

        with gr.Row():
            audio_cfg_scale = gr.Slider(
                label="Audio CFG Scale",
                minimum=1.0,
                maximum=15.0,
                value=LTX_2_3_HQ_PARAMS.audio_guider_params.cfg_scale,
                step=0.1,
            )
            audio_stg_scale = gr.Slider(label="Audio STG Scale", minimum=0.0, maximum=2.0, value=0.0, step=0.1)

        with gr.Row():
            audio_rescale_scale = gr.Slider(label="Audio Rescale", minimum=0.0, maximum=2.0, value=1.0, step=0.1)
            audio_v2a_scale = gr.Slider(label="V2A Scale", minimum=0.0, maximum=5.0, value=3.0, step=0.1)

    def on_image_upload(image, current_h, current_w):
        """Set width/height to the preset matching the uploaded image's aspect ratio.

        Returns two ``gr.update`` objects in (width, height) order; both are
        no-op updates when there is no image or no preset for the detected
        ratio. ``current_h`` and ``current_w`` are received but not used.
        """
        preset = None
        if image is not None:
            preset = RESOLUTIONS.get(detect_aspect_ratio(image))
        if preset is None:
            return gr.update(), gr.update()
        return gr.update(value=preset["width"]), gr.update(value=preset["height"])

    # Re-pick the resolution preset whenever the input image changes.
    input_image.change(
        fn=on_image_upload,
        inputs=[input_image, height, width],
        outputs=[width, height],
    )

    # Inputs are listed in the positional order of generate_video's parameters.
    generate_btn.click(
        fn=generate_video,
        inputs=[
            prompt,
            negative_prompt,
            input_image,
            duration,
            seed,
            randomize_seed,
            height,
            width,
            enhance_prompt,
            video_cfg_scale,
            video_stg_scale,
            video_rescale_scale,
            video_a2v_scale,
            audio_cfg_scale,
            audio_stg_scale,
            audio_rescale_scale,
            audio_v2a_scale,
        ],
        outputs=[output_video, seed],
    )
604
+
605
+
606
if __name__ == "__main__":
    # Enable request queuing, then start the app with the Citrus theme, the
    # custom CSS defined above, and the MCP server switched on.
    queued_demo = demo.queue()
    queued_demo.launch(theme=gr.themes.Citrus(), css=css, mcp_server=True)