dagloop5 committed on
Commit
287d66d
·
verified ·
1 Parent(s): 9d8a7db

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -927
app.py DELETED
@@ -1,927 +0,0 @@
1
- # =============================================================================
2
- # Installation and Setup
3
- # =============================================================================
4
- import os
5
- import subprocess
6
- import sys
7
-
8
# Disable torch.compile / dynamo before any torch import
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["TORCHDYNAMO_DISABLE"] = "1"

# Clone LTX-2 repo at specific commit
LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
LTX_COMMIT_SHA = "a2c3f24078eb918171967f74b6f66b756b29ee45"

# The checkout is only usable once the "packages" tree (added to sys.path
# below) exists. Testing for it instead of the bare repo directory lets a
# start-up that follows a failed or interrupted clone recover on its own.
_LTX_PACKAGES_DIR = os.path.join(LTX_REPO_DIR, "packages")

if not os.path.isdir(_LTX_PACKAGES_DIR):
    import shutil

    print(f"Cloning {LTX_REPO_URL} at commit {LTX_COMMIT_SHA}...")
    # Drop any half-initialised leftover so `git init` starts from scratch.
    shutil.rmtree(LTX_REPO_DIR, ignore_errors=True)
    os.makedirs(LTX_REPO_DIR)
    try:
        subprocess.run(["git", "init", LTX_REPO_DIR], check=True)
        subprocess.run(["git", "remote", "add", "origin", LTX_REPO_URL], cwd=LTX_REPO_DIR, check=True)
        subprocess.run(["git", "fetch", "--depth", "1", "origin", LTX_COMMIT_SHA], cwd=LTX_REPO_DIR, check=True)
        subprocess.run(["git", "checkout", LTX_COMMIT_SHA], cwd=LTX_REPO_DIR, check=True)
    except (subprocess.CalledProcessError, OSError):
        # Do not leave a broken checkout behind; the original error still
        # propagates so the failure stays visible in the logs.
        shutil.rmtree(LTX_REPO_DIR, ignore_errors=True)
        raise

sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines", "src"))
sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-core", "src"))
27
-
28
- # =============================================================================
29
- # Imports
30
- # =============================================================================
31
- import logging
32
- import random
33
- import tempfile
34
- from pathlib import Path
35
- from typing import Optional, Any
36
-
37
- import torch
38
- torch._dynamo.config.suppress_errors = True
39
- torch._dynamo.config.disable = True
40
-
41
- import gradio as gr
42
- import spaces
43
- import numpy as np
44
- from huggingface_hub import hf_hub_download, snapshot_download
45
-
46
- # Core LTX imports
47
- from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
48
- from ltx_core.quantization import QuantizationPolicy
49
- from ltx_core.loader import LoraPathStrengthAndSDOps, LTXV_LORA_COMFY_RENAMING_MAP
50
- from ltx_core.components.guiders import MultiModalGuider, MultiModalGuiderParams
51
- from ltx_core.components.noisers import GaussianNoiser
52
- from ltx_core.components.schedulers import LTX2Scheduler
53
- from ltx_core.components.diffusion_steps import Res2sDiffusionStep
54
- from ltx_core.types import Audio, VideoLatentShape, VideoPixelShape
55
-
56
- # Pipeline utilities
57
- from ltx_pipelines.utils.args import ImageConditioningInput
58
- from ltx_pipelines.utils.media_io import encode_video
59
- from ltx_pipelines.utils.denoisers import GuidedDenoiser, SimpleDenoiser
60
- from ltx_pipelines.utils.samplers import res2s_audio_video_denoising_loop
61
- from ltx_pipelines.utils.types import ModalitySpec
62
- from ltx_pipelines.utils.helpers import assert_resolution, combined_image_conditionings, get_device
63
- from ltx_pipelines.utils.constants import LTX_2_3_HQ_PARAMS, STAGE_2_DISTILLED_SIGMA_VALUES
64
-
65
- # Model builders
66
- from ltx_core.loader.single_gpu_model_builder import (
67
- TransformerBuilder,
68
- VideoEncoderBuilder,
69
- VideoDecoderBuilder,
70
- AudioDecoderBuilder,
71
- VocoderBuilder,
72
- UpsamplerBuilder,
73
- TextEncoderBuilder,
74
- )
75
- from ltx_core.model.transformer import X0Model
76
- from ltx_core.model.video_vae import VideoEncoder, VideoDecoder
77
- from ltx_core.model.audio_vae import AudioDecoder as AVAudioDecoder, Vocoder
78
- from ltx_core.model.upsampler import LatentUpsampler
79
- from ltx_core.text_encoders.gemma import GemmaTextEncoder
80
- from ltx_core.text_encoders.gemma.embeddings_processor import EmbeddingsProcessorBuilder
81
-
82
- logging.getLogger().setLevel(logging.INFO)
83
-
84
- # =============================================================================
85
- # Constants and Configuration
86
- # =============================================================================
87
-
88
# Hugging Face repositories the weights are pulled from.
LTX_MODEL_REPO = "Lightricks/LTX-2.3"
GEMMA_REPO = "Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"

# Generation limits used by the helpers and the UI.
DEFAULT_FRAME_RATE = 24.0
MIN_DIM = 256
MAX_DIM = 1280
STEP = 64
MIN_FRAMES = 9
MAX_FRAMES = 257
MAX_SEED = np.iinfo(np.int32).max

# Default prompts shown in the UI text boxes.
DEFAULT_PROMPT = ", ".join(
    (
        "A majestic eagle soaring over mountain peaks at sunset",
        "wings spread wide against the orange sky",
        "feathers catching the light",
        "wind currents visible in the motion blur",
        "cinematic slow motion",
        "4K quality",
    )
)
DEFAULT_NEGATIVE_PROMPT = ", ".join(
    (
        "worst quality",
        "inconsistent motion",
        "blurry",
        "jittery",
        "distorted",
        "deformed",
        "artifacts",
        "text",
        "watermark",
        "logo",
        "frame",
        "border",
        "low resolution",
        "pixelated",
        "unnatural",
        "fake",
        "CGI",
        "cartoon",
    )
)
105
-
106
- # =============================================================================
107
- # HQ Pipeline with model_ledger - Custom Implementation
108
- # =============================================================================
109
-
110
class HQModelLedger:
    """
    Lazily builds and caches the models used by the HQ two-stage pipeline.

    Every accessor builds its model on first use through the matching
    builder, stores it on the instance, and returns the stored instance on
    later calls.
    """

    def __init__(
        self,
        checkpoint_path: str,
        distilled_lora_path: str,
        distilled_lora_strength_stage_1: float,
        distilled_lora_strength_stage_2: float,
        spatial_upsampler_path: str,
        gemma_root: str,
        loras: tuple,
        device: torch.device,
        dtype: torch.dtype,
        quantization: Optional[QuantizationPolicy] = None,
    ):
        # Placement / precision shared by every model.
        self.device = device
        self.dtype = dtype
        self._target_device = device

        # Where the weights live.
        self._checkpoint_path = checkpoint_path
        self._spatial_upsampler_path = spatial_upsampler_path
        self._gemma_root = gemma_root
        # NOTE(review): stored but not read by any builder in this class, so
        # the quantization policy currently has no effect here - confirm.
        self._quantization = quantization

        # Cache slots, filled on first access.
        self._transformer_stage1 = None
        self._transformer_stage2 = None
        self._video_encoder = None
        self._video_decoder = None
        self._audio_decoder = None
        self._vocoder = None
        self._spatial_upsampler = None
        self._text_encoder = None
        self._embeddings_processor = None

        # LoRA settings.
        self._distilled_lora_path = distilled_lora_path
        self._distilled_lora_strength_stage_1 = distilled_lora_strength_stage_1
        self._distilled_lora_strength_stage_2 = distilled_lora_strength_stage_2
        self._loras = loras

        self._build_configs()

    def _build_configs(self):
        """Create one builder per model; only the transformers get LoRAs."""

        def distilled_lora(strength):
            return LoraPathStrengthAndSDOps(
                path=self._distilled_lora_path,
                strength=strength,
                sd_ops=LTXV_LORA_COMFY_RENAMING_MAP,
            )

        # Distilled LoRA first (stage-specific strength), then custom LoRAs.
        stage1_loras = [distilled_lora(self._distilled_lora_strength_stage_1), *self._loras]
        stage2_loras = [distilled_lora(self._distilled_lora_strength_stage_2), *self._loras]

        checkpoint = self._checkpoint_path
        self._transformer_builder_stage1 = TransformerBuilder.from_checkpoint(checkpoint).with_loras(stage1_loras)
        self._transformer_builder_stage2 = TransformerBuilder.from_checkpoint(checkpoint).with_loras(stage2_loras)

        # Remaining builders carry no LoRAs.
        self._video_encoder_builder = VideoEncoderBuilder.from_checkpoint(checkpoint)
        self._video_decoder_builder = VideoDecoderBuilder.from_checkpoint(checkpoint)
        self._audio_decoder_builder = AudioDecoderBuilder.from_checkpoint(checkpoint)
        self._vocoder_builder = VocoderBuilder.from_checkpoint(checkpoint)
        self._spatial_upsampler_builder = UpsamplerBuilder.from_checkpoint(self._spatial_upsampler_path)
        self._text_encoder_builder = TextEncoderBuilder.from_gemma(self._gemma_root)
        self._embeddings_processor_builder = EmbeddingsProcessorBuilder.from_checkpoint(checkpoint)

    def _get_or_build(self, slot, label, builder, wrap=None):
        """Return the model cached in `slot`, building it first if empty.

        `wrap`, when given, is applied to the freshly built model before it
        is moved to the device and switched to eval mode.
        """
        cached = getattr(self, slot)
        if cached is None:
            print(f" Building {label}...")
            built = builder.build(device=self._target_device, dtype=self.dtype)
            if wrap is not None:
                built = wrap(built)
            cached = built.to(self.device).eval()
            setattr(self, slot, cached)
        return cached

    def transformer(self, stage: int = 1):
        """Get or build the transformer; any stage other than 1 means stage 2."""
        if stage == 1:
            return self._get_or_build(
                "_transformer_stage1", "transformer (stage 1)", self._transformer_builder_stage1, wrap=X0Model
            )
        return self._get_or_build(
            "_transformer_stage2", "transformer (stage 2)", self._transformer_builder_stage2, wrap=X0Model
        )

    def video_encoder(self):
        """Get or build video encoder."""
        return self._get_or_build("_video_encoder", "video encoder", self._video_encoder_builder)

    def video_decoder(self):
        """Get or build video decoder."""
        return self._get_or_build("_video_decoder", "video decoder", self._video_decoder_builder)

    def audio_decoder(self):
        """Get or build audio decoder."""
        return self._get_or_build("_audio_decoder", "audio decoder", self._audio_decoder_builder)

    def vocoder(self):
        """Get or build vocoder."""
        return self._get_or_build("_vocoder", "vocoder", self._vocoder_builder)

    def spatial_upsampler(self):
        """Get or build spatial upsampler."""
        return self._get_or_build("_spatial_upsampler", "spatial upsampler", self._spatial_upsampler_builder)

    def text_encoder(self):
        """Get or build text encoder."""
        return self._get_or_build("_text_encoder", "text encoder (Gemma)", self._text_encoder_builder)

    def embeddings_processor(self):
        """Get or build embeddings processor."""
        return self._get_or_build(
            "_embeddings_processor", "embeddings processor", self._embeddings_processor_builder
        )
293
-
294
-
295
class TI2VidTwoStagesHQPipelineWithLedger:
    """
    Two-stage text/image-to-video pipeline backed by an HQModelLedger.

    Stage 1 denoises at half resolution with guided (positive + negative)
    contexts; stage 2 upsamples the stage-1 video latent and refines it at
    full resolution with the distilled sigma schedule.
    """

    def __init__(
        self,
        checkpoint_path: str,
        distilled_lora_path: str,
        distilled_lora_strength_stage_1: float,
        distilled_lora_strength_stage_2: float,
        spatial_upsampler_path: str,
        gemma_root: str,
        loras: tuple = (),
        device: Optional[torch.device] = None,
        quantization: Optional[QuantizationPolicy] = None,
        torch_compile: bool = False,
    ):
        self.device = device or get_device()
        self.dtype = torch.bfloat16
        # Stored only; nothing in this class reads it.
        self._torch_compile = torch_compile

        # All models are obtained through the ledger.
        self.model_ledger = HQModelLedger(
            checkpoint_path=checkpoint_path,
            distilled_lora_path=distilled_lora_path,
            distilled_lora_strength_stage_1=distilled_lora_strength_stage_1,
            distilled_lora_strength_stage_2=distilled_lora_strength_stage_2,
            spatial_upsampler_path=spatial_upsampler_path,
            gemma_root=gemma_root,
            loras=loras,
            device=self.device,
            dtype=self.dtype,
            quantization=quantization,
        )

        # Scheduler and stepper
        self._scheduler = LTX2Scheduler()
        self._stepper = Res2sDiffusionStep()

    @torch.inference_mode()
    def __call__(
        self,
        prompt: str,
        negative_prompt: str,
        seed: int,
        height: int,
        width: int,
        num_frames: int,
        frame_rate: float,
        num_inference_steps: int,
        video_guider_params: MultiModalGuiderParams,
        audio_guider_params: MultiModalGuiderParams,
        images: list[ImageConditioningInput],
        tiling_config: Optional[TilingConfig] = None,
        enhance_prompt: bool = False,
        streaming_prefetch_count: Optional[int] = None,
        max_batch_size: int = 1,
    ):
        """Run both stages and return `(decoded_video, decoded_audio)`."""
        assert_resolution(height=height, width=width, is_two_stage=True)

        rng = torch.Generator(device=self.device).manual_seed(seed)
        noiser = GaussianNoiser(generator=rng)

        models = self.model_ledger
        text_encoder = models.text_encoder()
        embeddings_processor = models.embeddings_processor()
        video_encoder = models.video_encoder()

        # --- Prompt encoding -------------------------------------------------
        # The first conditioning image and the seed are forwarded only when
        # prompt enhancement is requested.
        positive_hidden = text_encoder([prompt])
        enhance_image = images[0].path if (len(images) > 0 and enhance_prompt) else None
        enhance_seed = seed if enhance_prompt else None
        positive = embeddings_processor.create_embeddings(
            positive_hidden, video_encoder, enhance_image, enhance_seed
        )
        negative = embeddings_processor.create_embeddings(
            text_encoder([negative_prompt]), video_encoder, None, None
        )

        v_context_p, a_context_p = positive.video_encoding, positive.audio_encoding
        v_context_n, a_context_n = negative.video_encoding, negative.audio_encoding

        # Arguments shared by both denoising loops.
        shared_loop_args = dict(
            noiser=noiser,
            stepper=self._stepper,
            frames=num_frames,
            fps=frame_rate,
            streaming_prefetch_count=streaming_prefetch_count,
            max_batch_size=max_batch_size,
        )

        # --- Stage 1: half resolution, guided ---------------------------------
        half_res = VideoPixelShape(
            batch=1,
            frames=num_frames,
            width=width // 2,
            height=height // 2,
            fps=frame_rate,
        )
        half_res_conditionings = combined_image_conditionings(
            images=images,
            height=half_res.height,
            width=half_res.width,
            video_encoder=video_encoder,
            dtype=self.dtype,
            device=self.device,
        )

        # Passed to the scheduler as its `latent` argument.
        placeholder_latent = torch.empty(
            VideoLatentShape.from_pixel_shape(half_res).to_torch_shape(),
            dtype=self.dtype,
            device=self.device,
        )
        sigmas = self._scheduler.execute(latent=placeholder_latent, steps=num_inference_steps)
        sigmas = sigmas.to(dtype=torch.float32, device=self.device)

        stage_1_transformer = models.transformer(stage=1)
        guided_denoiser = GuidedDenoiser(
            v_context=v_context_p,
            a_context=a_context_p,
            video_guider=MultiModalGuider(params=video_guider_params, negative_context=v_context_n),
            audio_guider=MultiModalGuider(params=audio_guider_params, negative_context=a_context_n),
        )
        video_state, audio_state = res2s_audio_video_denoising_loop(
            transformer=stage_1_transformer,
            denoiser=guided_denoiser,
            sigmas=sigmas,
            width=half_res.width,
            height=half_res.height,
            video=ModalitySpec(context=v_context_p, conditionings=half_res_conditionings),
            audio=ModalitySpec(context=a_context_p),
            **shared_loop_args,
        )

        # --- Stage 2: upsample and refine -------------------------------------
        upsampler = models.spatial_upsampler()
        upscaled_video_latent = upsampler(video_state.latent[:1])

        distilled_sigmas = torch.tensor(
            STAGE_2_DISTILLED_SIGMA_VALUES, device=self.device, dtype=torch.float32
        )
        full_res_conditionings = combined_image_conditionings(
            images=images,
            height=height,
            width=width,
            video_encoder=video_encoder,
            dtype=self.dtype,
            device=self.device,
        )

        stage_2_transformer = models.transformer(stage=2)
        # Both modalities use the first distilled sigma as their noise scale.
        refine_noise_scale = distilled_sigmas[0].item()
        video_state, audio_state = res2s_audio_video_denoising_loop(
            transformer=stage_2_transformer,
            denoiser=SimpleDenoiser(v_context=v_context_p, a_context=a_context_p),
            sigmas=distilled_sigmas,
            width=width,
            height=height,
            video=ModalitySpec(
                context=v_context_p,
                conditionings=full_res_conditionings,
                noise_scale=refine_noise_scale,
                initial_latent=upscaled_video_latent,
            ),
            audio=ModalitySpec(
                context=a_context_p,
                noise_scale=refine_noise_scale,
                initial_latent=audio_state.latent,
            ),
            **shared_loop_args,
        )

        # --- Decode -----------------------------------------------------------
        video_decoder = models.video_decoder()
        audio_decoder = models.audio_decoder()

        decoded_video = video_decoder(video_state.latent, tiling_config, rng)
        decoded_audio = audio_decoder(audio_state.latent)
        return decoded_video, decoded_audio
483
-
484
-
485
- # =============================================================================
486
- # Model Download
487
- # =============================================================================
488
-
489
print("=" * 80)
print("Downloading LTX-2.3 models...")
print("=" * 80)

# The three LTX files come from the same repo and are fetched in this order:
# dev checkpoint, spatial upscaler, distilled LoRA.
checkpoint_path, spatial_upsampler_path, distilled_lora_path = (
    hf_hub_download(repo_id=LTX_MODEL_REPO, filename=weights_file)
    for weights_file in (
        "ltx-2.3-22b-dev.safetensors",
        "ltx-2.3-spatial-upscaler-x2-1.1.safetensors",
        "ltx-2.3-22b-distilled-lora-384.safetensors",
    )
)
# The Gemma text encoder is a whole repository snapshot, not a single file.
gemma_root = snapshot_download(repo_id=GEMMA_REPO)

print(
    f"Checkpoint: {checkpoint_path}",
    f"Spatial upsampler: {spatial_upsampler_path}",
    f"Distilled LoRA: {distilled_lora_path}",
    f"Gemma root: {gemma_root}",
    sep="\n",
)

print("=" * 80)
print("All models downloaded!")
print("=" * 80)
506
-
507
- # =============================================================================
508
- # Pipeline Initialization
509
- # =============================================================================
510
-
511
print("Initializing TI2VidTwoStagesHQPipelineWithLedger...")

# Settings for the shared pipeline instance: distilled LoRA at 0.25 in
# stage 1 and 0.50 in stage 2, no extra LoRAs, fp8-cast quantization policy.
_pipeline_settings = {
    "checkpoint_path": checkpoint_path,
    "distilled_lora_path": distilled_lora_path,
    "distilled_lora_strength_stage_1": 0.25,
    "distilled_lora_strength_stage_2": 0.50,
    "spatial_upsampler_path": spatial_upsampler_path,
    "gemma_root": gemma_root,
    "loras": (),
    "quantization": QuantizationPolicy.fp8_cast(),
    "torch_compile": False,
}
pipeline = TI2VidTwoStagesHQPipelineWithLedger(**_pipeline_settings)

print("Pipeline initialized successfully!")
print("=" * 80)
527
-
528
- # =============================================================================
529
- # ZeroGPU Tensor Preloading - model_ledger Pattern
530
- # =============================================================================
531
print("Preloading all models for ZeroGPU tensor packing...")
print("This may take a few minutes...")

# Access model ledger
ledger = pipeline.model_ledger

# Build every model now. Each one is kept both in its ledger cache slot and
# in a module-level name.
print(" Loading transformer (stage 1)...")
ledger._transformer_stage1 = _transformer_s1 = ledger.transformer(stage=1)

print(" Loading transformer (stage 2)...")
ledger._transformer_stage2 = _transformer_s2 = ledger.transformer(stage=2)

print(" Loading video encoder...")
ledger._video_encoder = _ve = ledger.video_encoder()

print(" Loading video decoder...")
ledger._video_decoder = _vd = ledger.video_decoder()

print(" Loading audio decoder...")
ledger._audio_decoder = _ad = ledger.audio_decoder()

print(" Loading vocoder...")
ledger._vocoder = _voc = ledger.vocoder()

print(" Loading spatial upsampler...")
ledger._spatial_upsampler = _su = ledger.spatial_upsampler()

print(" Loading text encoder (Gemma)...")
ledger._text_encoder = _te = ledger.text_encoder()

print(" Loading embeddings processor...")
ledger._embeddings_processor = _ep = ledger.embeddings_processor()


# Replace the lazy accessors with plain getters that only ever return the
# preloaded instances (the original notes this as the critical step for
# ZeroGPU tensor packing).
def ledger_transformer(stage=1):
    if stage == 1:
        return ledger._transformer_stage1
    return ledger._transformer_stage2


def _preloaded(slot):
    """Return a zero-argument getter reading `slot` from the global ledger."""

    def getter():
        return getattr(ledger, slot)

    return getter


ledger.transformer = ledger_transformer
ledger.video_encoder = _preloaded("_video_encoder")
ledger.video_decoder = _preloaded("_video_decoder")
ledger.audio_decoder = _preloaded("_audio_decoder")
ledger.vocoder = _preloaded("_vocoder")
ledger.spatial_upsampler = _preloaded("_spatial_upsampler")
ledger.text_encoder = _preloaded("_text_encoder")
ledger.embeddings_processor = _preloaded("_embeddings_processor")

print("All models preloaded for ZeroGPU tensor packing!")
print("=" * 80)
590
-
591
- # =============================================================================
592
- # Helper Functions
593
- # =============================================================================
594
-
595
def log_memory(tag: str):
    """Print current CUDA memory statistics labelled with `tag`.

    Does nothing when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return
    gib = 1024**3
    allocated = torch.cuda.memory_allocated() / gib
    peak = torch.cuda.max_memory_allocated() / gib
    free, total = torch.cuda.mem_get_info()
    print(f"[VRAM {tag}] allocated={allocated:.2f}GB peak={peak:.2f}GB free={free / gib:.2f}GB total={total / gib:.2f}GB")
601
-
602
-
603
def calculate_frames(duration: float, frame_rate: float = DEFAULT_FRAME_RATE) -> int:
    """Convert a duration in seconds to a frame count of the form 8*k + 1.

    The raw count is floored at MIN_FRAMES, snapped to the nearest 8*k + 1,
    and finally capped at MAX_FRAMES.
    """
    requested = max(int(duration * frame_rate), MIN_FRAMES)
    snapped = 8 * round((requested - 1) / 8) + 1
    return min(snapped, MAX_FRAMES)
609
-
610
-
611
def validate_resolution(height: int, width: int) -> tuple[int, int]:
    """Snap both dimensions to a multiple of STEP, clamped to [MIN_DIM, MAX_DIM].

    Returns `(height, width)`.
    """

    def snap(value):
        aligned = round(value / STEP) * STEP
        return max(MIN_DIM, min(aligned, MAX_DIM))

    return snap(height), snap(width)
617
-
618
-
619
def detect_aspect_ratio(image) -> str:
    """Return the preset aspect ratio ("16:9", "9:16" or "1:1") closest to `image`.

    Accepts objects exposing `shape` as (height, width, ...) such as NumPy
    arrays, or `size` as (width, height) such as PIL images. Falls back to
    "16:9" for None, unsupported objects, and images with a non-positive
    dimension.
    """
    default = "16:9"
    if image is None:
        return default

    # `shape` is checked first: NumPy arrays also have a `size` attribute, but
    # it is a scalar element count, so unpacking it as (w, h) would raise.
    if hasattr(image, "shape"):
        h, w = image.shape[:2]
    elif hasattr(image, "size"):
        w, h = image.size
    else:
        return default

    # A degenerate image has no meaningful ratio (and h == 0 would divide by zero).
    if w <= 0 or h <= 0:
        return default

    ratio = w / h
    candidates = {"16:9": 16 / 9, "9:16": 9 / 16, "1:1": 1.0}
    return min(candidates, key=lambda name: abs(ratio - candidates[name]))
631
-
632
-
633
# Preset output size (pixels) for each supported aspect ratio.
RESOLUTIONS = {
    aspect: {"width": preset_width, "height": preset_height}
    for aspect, (preset_width, preset_height) in (
        ("16:9", (1280, 704)),
        ("9:16", (704, 1280)),
        ("1:1", (960, 960)),
    )
}
638
-
639
-
640
def get_duration(
    prompt: str,
    negative_prompt: str,
    input_image,
    duration: float,
    seed: int,
    randomize_seed: bool,
    height: int,
    width: int,
    enhance_prompt: bool,
    video_cfg_scale: float,
    video_stg_scale: float,
    video_rescale_scale: float,
    video_a2v_scale: float,
    audio_cfg_scale: float,
    audio_stg_scale: float,
    audio_rescale_scale: float,
    audio_v2a_scale: float,
    progress,
) -> int:
    """Estimate the GPU time budget, in seconds, for one generation request.

    Mirrors the signature of `generate_video`; only `duration`, `height` and
    `width` influence the result. Starts from 60 s, adds a surcharge for each
    condition that holds, and caps the total at 90 s.
    """
    surcharges = (
        (duration > 4, 15),
        (duration > 6, 15),
        (height > 700 or width > 1000, 15),
        (int(duration * DEFAULT_FRAME_RATE) > 81, 10),
    )
    estimate = 60 + sum(extra for applies, extra in surcharges if applies)
    return min(estimate, 90)
671
-
672
-
673
@spaces.GPU(duration=get_duration)
@torch.inference_mode()
def generate_video(
    prompt: str,
    negative_prompt: str,
    input_image,
    duration: float,
    seed: int,
    randomize_seed: bool,
    height: int,
    width: int,
    enhance_prompt: bool,
    video_cfg_scale: float,
    video_stg_scale: float,
    video_rescale_scale: float,
    video_a2v_scale: float,
    audio_cfg_scale: float,
    audio_stg_scale: float,
    audio_rescale_scale: float,
    audio_v2a_scale: float,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a video (with audio) and return `(mp4_path, seed_used)`.

    When `input_image` is given it is written to `outputs/` and used as a
    full-strength conditioning on frame 0. On any failure the error and its
    traceback are printed and `(None, seed)` is returned, so the UI stays
    responsive instead of raising.
    """
    # Bound before the try block: the error path returns the seed, and a
    # failure before the seed is resolved must not raise UnboundLocalError.
    current_seed = seed
    try:
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        log_memory("start")

        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
        print(f"Using seed: {current_seed}")

        height, width = validate_resolution(int(height), int(width))
        print(f"Resolution: {width}x{height}")

        num_frames = calculate_frames(duration, DEFAULT_FRAME_RATE)
        print(f"Frames: {num_frames} ({duration}s @ {DEFAULT_FRAME_RATE}fps)")

        images = []
        if input_image is not None:
            output_dir = Path("outputs")
            output_dir.mkdir(exist_ok=True)
            temp_image_path = output_dir / f"temp_input_{current_seed}.jpg"
            if hasattr(input_image, "save"):
                try:
                    input_image.save(temp_image_path)
                except OSError:
                    # JPEG cannot store some modes (e.g. RGBA or P from a PNG
                    # or clipboard paste); retry after converting to RGB.
                    input_image.convert("RGB").save(temp_image_path)
            else:
                import shutil
                shutil.copy(input_image, temp_image_path)
            images = [ImageConditioningInput(
                path=str(temp_image_path),
                frame_idx=0,
                strength=1.0
            )]

        tiling_config = TilingConfig.default()
        video_chunks_number = get_video_chunks_number(num_frames, tiling_config)

        video_guider_params = MultiModalGuiderParams(
            cfg_scale=video_cfg_scale,
            stg_scale=video_stg_scale,
            rescale_scale=video_rescale_scale,
            modality_scale=video_a2v_scale,
            skip_step=0,
            stg_blocks=[],
        )

        audio_guider_params = MultiModalGuiderParams(
            cfg_scale=audio_cfg_scale,
            stg_scale=audio_stg_scale,
            rescale_scale=audio_rescale_scale,
            modality_scale=audio_v2a_scale,
            skip_step=0,
            stg_blocks=[],
        )

        log_memory("before pipeline call")

        video, audio = pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            seed=current_seed,
            height=height,
            width=width,
            num_frames=num_frames,
            frame_rate=DEFAULT_FRAME_RATE,
            num_inference_steps=LTX_2_3_HQ_PARAMS.num_inference_steps,
            video_guider_params=video_guider_params,
            audio_guider_params=audio_guider_params,
            images=images,
            tiling_config=tiling_config,
            enhance_prompt=enhance_prompt,
        )

        log_memory("after pipeline call")

        # Reserve the output file atomically; tempfile.mktemp only returns a
        # name and is deprecated because another process can claim it first.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as output_file:
            output_path = output_file.name
        encode_video(
            video=video,
            fps=DEFAULT_FRAME_RATE,
            audio=audio,
            output_path=output_path,
            video_chunks_number=video_chunks_number,
        )

        log_memory("after encode_video")
        return str(output_path), current_seed

    except Exception as e:
        # Deliberate best-effort boundary: report the failure and hand the UI
        # an empty result rather than propagating.
        import traceback
        log_memory("on error")
        print(f"Error: {str(e)}\n{traceback.format_exc()}")
        return None, current_seed
782
-
783
-
784
- # =============================================================================
785
- # Gradio UI
786
- # =============================================================================
787
-
788
css = """
.fillable {max-width: 1200px !important}
.progress-text {color: white}
"""

# NOTE(review): the nesting below is reconstructed - the original indentation
# was lost. In particular, confirm whether the "Advanced Settings" accordion
# belongs at the top level of the Blocks (as written here) or inside the
# right-hand column.
with gr.Blocks(title="LTX-2.3 Two-Stage HQ Video Generation") as demo:
    gr.Markdown("# LTX-2.3 Two-Stage HQ Video Generation")
    gr.Markdown(
        "High-quality text/image-to-video generation using the dev model + distilled LoRA. "
        "[[Model]](https://huggingface.co/Lightricks/LTX-2.3) "
        "[[GitHub]](https://github.com/Lightricks/LTX-2)"
    )

    with gr.Row():
        # Left column: generation inputs.
        with gr.Column():
            input_image = gr.Image(
                label="Input Image (Optional - for image-to-video)",
                type="pil",
                sources=["upload", "webcam", "clipboard"],
            )
            prompt = gr.Textbox(
                label="Prompt",
                info="Describe the video you want to generate",
                value=DEFAULT_PROMPT,
                lines=3,
                placeholder="Enter your prompt here...",
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                info="What to avoid in the generated video",
                value=DEFAULT_NEGATIVE_PROMPT,
                lines=2,
            )
            duration = gr.Slider(label="Duration (seconds)", minimum=0.5, maximum=8.0, value=2.0, step=0.1)
            enhance_prompt = gr.Checkbox(label="Enhance Prompt", value=False)
            generate_btn = gr.Button("Generate Video", variant="primary", size="lg")

        # Right column: result.
        with gr.Column():
            output_video = gr.Video(label="Generated Video", autoplay=True, interactive=False)

    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            width = gr.Number(label="Width", value=1280, precision=0)
            height = gr.Number(label="Height", value=704, precision=0)

        with gr.Row():
            seed = gr.Number(label="Seed", value=42, precision=0, minimum=0, maximum=MAX_SEED)
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)

        gr.Markdown("### Video Guidance Parameters")

        with gr.Row():
            video_cfg_scale = gr.Slider(
                label="Video CFG Scale",
                minimum=1.0,
                maximum=10.0,
                value=LTX_2_3_HQ_PARAMS.video_guider_params.cfg_scale,
                step=0.1,
            )
            video_stg_scale = gr.Slider(label="Video STG Scale", minimum=0.0, maximum=2.0, value=0.0, step=0.1)

        with gr.Row():
            video_rescale_scale = gr.Slider(label="Video Rescale", minimum=0.0, maximum=2.0, value=0.45, step=0.1)
            video_a2v_scale = gr.Slider(label="A2V Scale", minimum=0.0, maximum=5.0, value=3.0, step=0.1)

        gr.Markdown("### Audio Guidance Parameters")

        with gr.Row():
            audio_cfg_scale = gr.Slider(
                label="Audio CFG Scale",
                minimum=1.0,
                maximum=15.0,
                value=LTX_2_3_HQ_PARAMS.audio_guider_params.cfg_scale,
                step=0.1,
            )
            audio_stg_scale = gr.Slider(label="Audio STG Scale", minimum=0.0, maximum=2.0, value=0.0, step=0.1)

        with gr.Row():
            audio_rescale_scale = gr.Slider(label="Audio Rescale", minimum=0.0, maximum=2.0, value=1.0, step=0.1)
            audio_v2a_scale = gr.Slider(label="V2A Scale", minimum=0.0, maximum=5.0, value=3.0, step=0.1)

    def on_image_upload(image, current_h, current_w):
        """Set width/height to the preset matching the uploaded image's aspect ratio.

        Leaves both fields untouched when the image is cleared or no preset
        exists for the detected ratio.
        """
        unchanged = (gr.update(), gr.update())
        if image is None:
            return unchanged
        preset = RESOLUTIONS.get(detect_aspect_ratio(image))
        if preset is None:
            return unchanged
        return gr.update(value=preset["width"]), gr.update(value=preset["height"])

    # Outputs are ordered (width, height) to match on_image_upload's return.
    input_image.change(
        fn=on_image_upload,
        inputs=[input_image, height, width],
        outputs=[width, height],
    )

    # Input order must match the positional parameters of generate_video.
    generation_inputs = [
        prompt, negative_prompt, input_image, duration,
        seed, randomize_seed, height, width, enhance_prompt,
        video_cfg_scale, video_stg_scale, video_rescale_scale, video_a2v_scale,
        audio_cfg_scale, audio_stg_scale, audio_rescale_scale, audio_v2a_scale,
    ]
    generate_btn.click(
        fn=generate_video,
        inputs=generation_inputs,
        outputs=[output_video, seed],
    )
920
-
921
-
922
if __name__ == "__main__":
    # Enable request queuing, then start the app with the Citrus theme, the
    # custom CSS and the MCP server switched on.
    queued_demo = demo.queue()
    queued_demo.launch(theme=gr.themes.Citrus(), css=css, mcp_server=True)