dagloop5 commited on
Commit
e48c5ee
·
verified ·
1 Parent(s): bfed689

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -611
app.py DELETED
@@ -1,611 +0,0 @@
1
- # =============================================================================
2
- # Installation and Setup
3
- # =============================================================================
4
- import os
5
- import subprocess
6
- import sys
7
-
8
- os.environ["TORCH_COMPILE_DISABLE"] = "1"
9
- os.environ["TORCHDYNAMO_DISABLE"] = "1"
10
-
11
- # Clone LTX-2 repo at the commit with ModelLedger-based HQ pipeline
12
- LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
13
- LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
14
- LTX_COMMIT_SHA = "ae855f8538843825f9015a419cf4ba5edaf5eec2"
15
-
16
- if not os.path.exists(LTX_REPO_DIR):
17
- print(f"Cloning {LTX_REPO_URL} at commit {LTX_COMMIT_SHA}...")
18
- os.makedirs(LTX_REPO_DIR)
19
- subprocess.run(["git", "init", LTX_REPO_DIR], check=True)
20
- subprocess.run(["git", "remote", "add", "origin", LTX_REPO_URL], cwd=LTX_REPO_DIR, check=True)
21
- subprocess.run(["git", "fetch", "--depth", "1", "origin", LTX_COMMIT_SHA], cwd=LTX_REPO_DIR, check=True)
22
- subprocess.run(["git", "checkout", LTX_COMMIT_SHA], cwd=LTX_REPO_DIR, check=True)
23
-
24
- sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines", "src"))
25
- sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-core", "src"))
26
-
27
- # =============================================================================
28
- # Imports
29
- # =============================================================================
30
- import logging
31
- import random
32
- import tempfile
33
- from pathlib import Path
34
-
35
- import torch
36
- torch._dynamo.config.suppress_errors = True
37
- torch._dynamo.config.disable = True
38
-
39
- import gradio as gr
40
- import spaces
41
- import numpy as np
42
- from huggingface_hub import hf_hub_download, snapshot_download
43
-
44
- # Core LTX imports from the ModelLedger-based HQ pipeline
45
- from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
46
- from ltx_core.quantization import QuantizationPolicy
47
- from ltx_core.loader import LoraPathStrengthAndSDOps, LTXV_LORA_COMFY_RENAMING_MAP
48
- from ltx_core.components.guiders import MultiModalGuiderParams
49
-
50
- # Import the HQ pipeline with ModelLedger
51
- from ltx_pipelines.ti2vid_two_stages_hq import TI2VidTwoStagesHQPipeline
52
-
53
- # Import constants
54
- from ltx_pipelines.utils.constants import LTX_2_3_HQ_PARAMS
55
-
56
- logging.getLogger().setLevel(logging.INFO)
57
-
58
- # =============================================================================
59
- # Constants
60
- # =============================================================================
61
-
62
- LTX_MODEL_REPO = "Lightricks/LTX-2.3"
63
- GEMMA_REPO = "Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
64
- DEFAULT_FRAME_RATE = 24.0
65
- MIN_DIM, MAX_DIM, STEP = 256, 1280, 64
66
- MIN_FRAMES, MAX_FRAMES = 9, 257
67
- MAX_SEED = np.iinfo(np.int32).max
68
-
69
- DEFAULT_PROMPT = (
70
- "A majestic eagle soaring over mountain peaks at sunset, "
71
- "wings spread wide against the orange sky, feathers catching the light, "
72
- "wind currents visible in the motion blur, cinematic slow motion, 4K quality"
73
- )
74
- DEFAULT_NEGATIVE_PROMPT = (
75
- "worst quality, inconsistent motion, blurry, jittery, distorted, "
76
- "deformed, artifacts, text, watermark, logo, frame, border, "
77
- "low resolution, pixelated, unnatural, fake, CGI, cartoon"
78
- )
79
-
80
- # =============================================================================
81
- # Model Download
82
- # =============================================================================
83
-
84
- print("=" * 80)
85
- print("Downloading LTX-2.3 models...")
86
- print("=" * 80)
87
-
88
- checkpoint_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-dev.safetensors")
89
- spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
90
- distilled_lora_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-distilled-lora-384.safetensors")
91
- gemma_root = snapshot_download(repo_id=GEMMA_REPO)
92
-
93
- print(f"Dev checkpoint: {checkpoint_path}")
94
- print(f"Spatial upsampler: {spatial_upsampler_path}")
95
- print(f"Distilled LoRA: {distilled_lora_path}")
96
- print(f"Gemma root: {gemma_root}")
97
- print("=" * 80)
98
-
99
- # =============================================================================
100
- # Pipeline Initialization
101
- # =============================================================================
102
-
103
- print("Initializing TI2VidTwoStagesHQPipeline with ModelLedger...")
104
-
105
- # Create LoRA configuration for distilled model
106
- distilled_lora = [
107
- LoraPathStrengthAndSDOps(
108
- path=distilled_lora_path,
109
- strength=1.0, # Will be set per-stage (0.25, 0.5)
110
- sd_ops=LTXV_LORA_COMFY_RENAMING_MAP,
111
- )
112
- ]
113
-
114
- # Initialize the ModelLedger-based HQ pipeline
115
- pipeline = TI2VidTwoStagesHQPipeline(
116
- checkpoint_path=checkpoint_path,
117
- distilled_lora=distilled_lora,
118
- distilled_lora_strength_stage_1=0.25, # From HQ params
119
- distilled_lora_strength_stage_2=0.50, # From HQ params
120
- spatial_upsampler_path=spatial_upsampler_path,
121
- gemma_root=gemma_root,
122
- loras=(), # No additional custom LoRAs
123
- quantization=QuantizationPolicy.fp8_cast(),
124
- )
125
-
126
- print("Pipeline initialized successfully!")
127
- print("=" * 80)
128
-
129
- # =============================================================================
130
- # ZeroGPU Tensor Preloading - ModelLedger Pattern
131
- # =============================================================================
132
- # This pipeline has TWO model ledgers: stage_1_model_ledger and stage_2_model_ledger
133
- # We must preload both for ZeroGPU tensor packing to work
134
-
135
- print("Preloading all models for ZeroGPU tensor packing...")
136
- print("This may take a few minutes...")
137
-
138
- # ===== Stage 1 Model Ledger =====
139
- print(" Loading Stage 1 models...")
140
-
141
- # Stage 1 transformer
142
- ledger1 = pipeline.stage_1_model_ledger
143
- _transformer_s1 = ledger1.transformer()
144
- ledger1.transformer = lambda: _transformer_s1
145
-
146
- # Stage 1 video encoder
147
- _video_encoder_s1 = ledger1.video_encoder()
148
- ledger1.video_encoder = lambda: _video_encoder_s1
149
-
150
- # Stage 1 video decoder
151
- _video_decoder_s1 = ledger1.video_decoder()
152
- ledger1.video_decoder = lambda: _video_decoder_s1
153
-
154
- # Stage 1 audio decoder
155
- _audio_decoder_s1 = ledger1.audio_decoder()
156
- ledger1.audio_decoder = lambda: _audio_decoder_s1
157
-
158
- # Stage 1 vocoder
159
- _vocoder_s1 = ledger1.vocoder()
160
- ledger1.vocoder = lambda: _vocoder_s1
161
-
162
- # Stage 1 spatial upsampler
163
- _spatial_upsampler_s1 = ledger1.spatial_upsampler()
164
- ledger1.spatial_upsampler = lambda: _spatial_upsampler_s1
165
-
166
- # Stage 1 text encoder (Gemma)
167
- _text_encoder_s1 = ledger1.text_encoder()
168
- ledger1.text_encoder = lambda: _text_encoder_s1
169
-
170
- # Stage 1 embeddings processor
171
- _embeddings_processor_s1 = ledger1.embeddings_processor()
172
- ledger1.embeddings_processor = lambda: _embeddings_processor_s1
173
-
174
- print(" Stage 1 models loaded")
175
-
176
- # ===== Stage 2 Model Ledger =====
177
- print(" Loading Stage 2 models...")
178
-
179
- # Stage 2 transformer
180
- ledger2 = pipeline.stage_2_model_ledger
181
- _transformer_s2 = ledger2.transformer()
182
- ledger2.transformer = lambda: _transformer_s2
183
-
184
- # Stage 2 video encoder
185
- _video_encoder_s2 = ledger2.video_encoder()
186
- ledger2.video_encoder = lambda: _video_encoder_s2
187
-
188
- # Stage 2 video decoder
189
- _video_decoder_s2 = ledger2.video_decoder()
190
- ledger2.video_decoder = lambda: _video_decoder_s2
191
-
192
- # Stage 2 audio decoder
193
- _audio_decoder_s2 = ledger2.audio_decoder()
194
- ledger2.audio_decoder = lambda: _audio_decoder_s2
195
-
196
- # Stage 2 vocoder
197
- _vocoder_s2 = ledger2.vocoder()
198
- ledger2.vocoder = lambda: _vocoder_s2
199
-
200
- # Stage 2 spatial upsampler
201
- _spatial_upsampler_s2 = ledger2.spatial_upsampler()
202
- ledger2.spatial_upsampler = lambda: _spatial_upsampler_s2
203
-
204
- # Stage 2 text encoder (Gemma) - can reuse from stage 1
205
- _text_encoder_s2 = ledger2.text_encoder()
206
- ledger2.text_encoder = lambda: _text_encoder_s2
207
-
208
- # Stage 2 embeddings processor - can reuse from stage 1
209
- _embeddings_processor_s2 = ledger2.embeddings_processor()
210
- ledger2.embeddings_processor = lambda: _embeddings_processor_s2
211
-
212
- print(" Stage 2 models loaded")
213
-
214
- # Create global references to prevent garbage collection
215
- # Module-level variables persist for ZeroGPU to pack
216
- _transformer_stage1 = _transformer_s1
217
- _transformer_stage2 = _transformer_s2
218
- _video_encoder_stage1 = _video_encoder_s1
219
- _video_encoder_stage2 = _video_encoder_s2
220
- _video_decoder_stage1 = _video_decoder_s1
221
- _video_decoder_stage2 = _video_decoder_s2
222
- _audio_decoder_stage1 = _audio_decoder_s1
223
- _audio_decoder_stage2 = _audio_decoder_s2
224
- _vocoder_stage1 = _vocoder_s1
225
- _vocoder_stage2 = _vocoder_s2
226
- _spatial_upsampler_stage1 = _spatial_upsampler_s1
227
- _spatial_upsampler_stage2 = _spatial_upsampler_s2
228
- _text_encoder_stage1 = _text_encoder_s1
229
- _text_encoder_stage2 = _text_encoder_s2
230
- _embeddings_processor_stage1 = _embeddings_processor_s1
231
- _embeddings_processor_stage2 = _embeddings_processor_s2
232
-
233
- print("All models preloaded for ZeroGPU tensor packing!")
234
- print("=" * 80)
235
-
236
- # =============================================================================
237
- # Helper Functions
238
- # =============================================================================
239
-
240
- def log_memory(tag: str):
241
- if torch.cuda.is_available():
242
- allocated = torch.cuda.memory_allocated() / 1024**3
243
- peak = torch.cuda.max_memory_allocated() / 1024**3
244
- free, total = torch.cuda.mem_get_info()
245
- print(f"[VRAM {tag}] allocated={allocated:.2f}GB peak={peak:.2f}GB free={free / 1024**3:.2f}GB total={total / 1024**3:.2f}GB")
246
-
247
-
248
- def calculate_frames(duration: float, frame_rate: float = DEFAULT_FRAME_RATE) -> int:
249
- ideal_frames = int(duration * frame_rate)
250
- ideal_frames = max(ideal_frames, MIN_FRAMES)
251
- k = round((ideal_frames - 1) / 8)
252
- frames = k * 8 + 1
253
- return min(frames, MAX_FRAMES)
254
-
255
-
256
- def validate_resolution(height: int, width: int) -> tuple[int, int]:
257
- height = round(height / STEP) * STEP
258
- width = round(width / STEP) * STEP
259
- height = max(MIN_DIM, min(height, MAX_DIM))
260
- width = max(MIN_DIM, min(width, MAX_DIM))
261
- return height, width
262
-
263
-
264
- def detect_aspect_ratio(image) -> str:
265
- if image is None:
266
- return "16:9"
267
- if hasattr(image, "size"):
268
- w, h = image.size
269
- elif hasattr(image, "shape"):
270
- h, w = image.shape[:2]
271
- else:
272
- return "16:9"
273
- ratio = w / h
274
- candidates = {"16:9": 16/9, "9:16": 9/16, "1:1": 1.0}
275
- return min(candidates, key=lambda k: abs(ratio - candidates[k]))
276
-
277
-
278
- RESOLUTIONS = {
279
- "16:9": {"width": 1280, "height": 704},
280
- "9:16": {"width": 704, "height": 1280},
281
- "1:1": {"width": 960, "height": 960},
282
- }
283
-
284
-
285
- def get_duration(
286
- prompt: str,
287
- negative_prompt: str,
288
- input_image,
289
- duration: float,
290
- seed: int,
291
- randomize_seed: bool,
292
- height: int,
293
- width: int,
294
- enhance_prompt: bool,
295
- video_cfg_scale: float,
296
- video_stg_scale: float,
297
- video_rescale_scale: float,
298
- video_a2v_scale: float,
299
- audio_cfg_scale: float,
300
- audio_stg_scale: float,
301
- audio_rescale_scale: float,
302
- audio_v2a_scale: float,
303
- progress,
304
- ) -> int:
305
- base = 60
306
- if duration > 4:
307
- base += 15
308
- if duration > 6:
309
- base += 15
310
- if height > 700 or width > 1000:
311
- base += 15
312
- frames_from_duration = int(duration * DEFAULT_FRAME_RATE)
313
- if frames_from_duration > 81:
314
- base += 10
315
- return min(base, 90)
316
-
317
-
318
- @spaces.GPU(duration=get_duration)
319
- @torch.inference_mode()
320
- def generate_video(
321
- prompt: str,
322
- negative_prompt: str,
323
- input_image,
324
- duration: float,
325
- seed: int,
326
- randomize_seed: bool,
327
- height: int,
328
- width: int,
329
- enhance_prompt: bool,
330
- video_cfg_scale: float,
331
- video_stg_scale: float,
332
- video_rescale_scale: float,
333
- video_a2v_scale: float,
334
- audio_cfg_scale: float,
335
- audio_stg_scale: float,
336
- audio_rescale_scale: float,
337
- audio_v2a_scale: float,
338
- progress=gr.Progress(track_tqdm=True),
339
- ):
340
- """
341
- Generate high-quality video using the ModelLedger-based HQ pipeline.
342
-
343
- This pipeline uses ModelLedger (like DistilledPipeline) for ZeroGPU compatibility,
344
- while supporting CFG/negative prompts (like original TI2VidTwoStagesHQPipeline).
345
- """
346
- try:
347
- torch.cuda.reset_peak_memory_stats()
348
- log_memory("start")
349
-
350
- current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
351
- print(f"Using seed: {current_seed}")
352
-
353
- height, width = validate_resolution(int(height), int(width))
354
- print(f"Resolution: {width}x{height}")
355
-
356
- num_frames = calculate_frames(duration, DEFAULT_FRAME_RATE)
357
- print(f"Frames: {num_frames} ({duration}s @ {DEFAULT_FRAME_RATE}fps)")
358
-
359
- # Prepare image conditioning if provided
360
- images = []
361
- if input_image is not None:
362
- output_dir = Path("outputs")
363
- output_dir.mkdir(exist_ok=True)
364
- temp_image_path = output_dir / f"temp_input_{current_seed}.jpg"
365
- if hasattr(input_image, "save"):
366
- input_image.save(temp_image_path)
367
- else:
368
- import shutil
369
- shutil.copy(input_image, temp_image_path)
370
- # Use ImageConditioningInput format
371
- images = [(str(temp_image_path), 1.0)] # (path, strength)
372
-
373
- tiling_config = TilingConfig.default()
374
- video_chunks_number = get_video_chunks_number(num_frames, tiling_config)
375
-
376
- # Configure MultiModalGuider parameters
377
- video_guider_params = MultiModalGuiderParams(
378
- cfg_scale=video_cfg_scale,
379
- stg_scale=video_stg_scale,
380
- rescale_scale=video_rescale_scale,
381
- modality_scale=video_a2v_scale,
382
- skip_step=0,
383
- stg_blocks=[],
384
- )
385
-
386
- audio_guider_params = MultiModalGuiderParams(
387
- cfg_scale=audio_cfg_scale,
388
- stg_scale=audio_stg_scale,
389
- rescale_scale=audio_rescale_scale,
390
- modality_scale=audio_v2a_scale,
391
- skip_step=0,
392
- stg_blocks=[],
393
- )
394
-
395
- log_memory("before pipeline call")
396
-
397
- # Call the pipeline
398
- video, audio = pipeline(
399
- prompt=prompt,
400
- negative_prompt=negative_prompt,
401
- seed=current_seed,
402
- height=height,
403
- width=width,
404
- num_frames=num_frames,
405
- frame_rate=DEFAULT_FRAME_RATE,
406
- num_inference_steps=LTX_2_3_HQ_PARAMS.num_inference_steps,
407
- video_guider_params=video_guider_params,
408
- audio_guider_params=audio_guider_params,
409
- images=images,
410
- tiling_config=tiling_config,
411
- enhance_prompt=enhance_prompt,
412
- )
413
-
414
- log_memory("after pipeline call")
415
-
416
- # Encode video with audio
417
- from ltx_pipelines.utils.media_io import encode_video
418
- output_path = tempfile.mktemp(suffix=".mp4")
419
- encode_video(
420
- video=video,
421
- fps=DEFAULT_FRAME_RATE,
422
- audio=audio,
423
- output_path=output_path,
424
- video_chunks_number=video_chunks_number,
425
- )
426
-
427
- log_memory("after encode_video")
428
- return str(output_path), current_seed
429
-
430
- except Exception as e:
431
- import traceback
432
- log_memory("on error")
433
- print(f"Error: {str(e)}\n{traceback.format_exc()}")
434
- return None, current_seed
435
-
436
-
437
- # =============================================================================
438
- # Gradio UI
439
- # =============================================================================
440
-
441
- css = """
442
- .fillable {max-width: 1200px !important}
443
- .progress-text {color: white}
444
- """
445
-
446
- with gr.Blocks(title="LTX-2.3 Two-Stage HQ Video Generation") as demo:
447
- gr.Markdown("# LTX-2.3 Two-Stage HQ Video Generation")
448
- gr.Markdown(
449
- "High-quality text/image-to-video generation using ModelLedger. "
450
- "[[Model]](https://huggingface.co/Lightricks/LTX-2.3) "
451
- "[[GitHub]](https://github.com/Lightricks/LTX-2)"
452
- )
453
-
454
- with gr.Row():
455
- with gr.Column():
456
- input_image = gr.Image(
457
- label="Input Image (Optional - for image-to-video)",
458
- type="pil",
459
- sources=["upload", "webcam", "clipboard"]
460
- )
461
-
462
- prompt = gr.Textbox(
463
- label="Prompt",
464
- info="Describe the video you want to generate",
465
- value=DEFAULT_PROMPT,
466
- lines=3,
467
- )
468
-
469
- negative_prompt = gr.Textbox(
470
- label="Negative Prompt",
471
- info="What to avoid in the generated video",
472
- value=DEFAULT_NEGATIVE_PROMPT,
473
- lines=2,
474
- )
475
-
476
- duration = gr.Slider(
477
- label="Duration (seconds)",
478
- minimum=0.5,
479
- maximum=8.0,
480
- value=2.0,
481
- step=0.1,
482
- )
483
-
484
- enhance_prompt = gr.Checkbox(
485
- label="Enhance Prompt",
486
- value=False,
487
- )
488
-
489
- generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
490
-
491
- with gr.Column():
492
- output_video = gr.Video(
493
- label="Generated Video",
494
- autoplay=True,
495
- interactive=False
496
- )
497
-
498
- with gr.Accordion("Advanced Settings", open=False):
499
- with gr.Row():
500
- width = gr.Number(label="Width", value=1280, precision=0)
501
- height = gr.Number(label="Height", value=704, precision=0)
502
-
503
- with gr.Row():
504
- seed = gr.Number(label="Seed", value=42, precision=0, minimum=0, maximum=MAX_SEED)
505
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
506
-
507
- gr.Markdown("### Video Guidance Parameters")
508
- gr.Markdown("Control how strongly the model follows the video prompt.")
509
-
510
- with gr.Row():
511
- video_cfg_scale = gr.Slider(
512
- label="Video CFG Scale",
513
- minimum=1.0,
514
- maximum=10.0,
515
- value=LTX_2_3_HQ_PARAMS.video_guider_params.cfg_scale,
516
- step=0.1,
517
- )
518
- video_stg_scale = gr.Slider(
519
- label="Video STG Scale",
520
- minimum=0.0,
521
- maximum=2.0,
522
- value=0.0,
523
- step=0.1,
524
- )
525
-
526
- with gr.Row():
527
- video_rescale_scale = gr.Slider(
528
- label="Video Rescale",
529
- minimum=0.0,
530
- maximum=2.0,
531
- value=0.45,
532
- step=0.1,
533
- )
534
- video_a2v_scale = gr.Slider(
535
- label="A2V Scale",
536
- minimum=0.0,
537
- maximum=5.0,
538
- value=3.0,
539
- step=0.1,
540
- )
541
-
542
- gr.Markdown("### Audio Guidance Parameters")
543
- gr.Markdown("Control audio generation quality and sync.")
544
-
545
- with gr.Row():
546
- audio_cfg_scale = gr.Slider(
547
- label="Audio CFG Scale",
548
- minimum=1.0,
549
- maximum=15.0,
550
- value=LTX_2_3_HQ_PARAMS.audio_guider_params.cfg_scale,
551
- step=0.1,
552
- )
553
- audio_stg_scale = gr.Slider(
554
- label="Audio STG Scale",
555
- minimum=0.0,
556
- maximum=2.0,
557
- value=0.0,
558
- step=0.1,
559
- )
560
-
561
- with gr.Row():
562
- audio_rescale_scale = gr.Slider(
563
- label="Audio Rescale",
564
- minimum=0.0,
565
- maximum=2.0,
566
- value=1.0,
567
- step=0.1,
568
- )
569
- audio_v2a_scale = gr.Slider(
570
- label="V2A Scale",
571
- minimum=0.0,
572
- maximum=5.0,
573
- value=3.0,
574
- step=0.1,
575
- )
576
-
577
- def on_image_upload(image, current_h, current_w):
578
- if image is None:
579
- return gr.update(), gr.update()
580
- aspect = detect_aspect_ratio(image)
581
- if aspect in RESOLUTIONS:
582
- return (
583
- gr.update(value=RESOLUTIONS[aspect]["width"]),
584
- gr.update(value=RESOLUTIONS[aspect]["height"])
585
- )
586
- return gr.update(), gr.update()
587
-
588
- input_image.change(
589
- fn=on_image_upload,
590
- inputs=[input_image, height, width],
591
- outputs=[width, height],
592
- )
593
-
594
- generate_btn.click(
595
- fn=generate_video,
596
- inputs=[
597
- prompt, negative_prompt, input_image, duration,
598
- seed, randomize_seed, height, width, enhance_prompt,
599
- video_cfg_scale, video_stg_scale, video_rescale_scale, video_a2v_scale,
600
- audio_cfg_scale, audio_stg_scale, audio_rescale_scale, audio_v2a_scale,
601
- ],
602
- outputs=[output_video, seed],
603
- )
604
-
605
-
606
- if __name__ == "__main__":
607
- demo.queue().launch(
608
- theme=gr.themes.Citrus(),
609
- css=css,
610
- mcp_server=True,
611
- )