Files changed (1) hide show
  1. app.py +0 -542
app.py DELETED
@@ -1,542 +0,0 @@
1
- import os
2
- import subprocess
3
- import sys
4
-
5
# torch.compile / TorchDynamo are switched off through the environment, and
# this has to happen before torch is imported anywhere in the process.
os.environ.update({"TORCH_COMPILE_DISABLE": "1", "TORCHDYNAMO_DISABLE": "1"})

# xformers provides memory-efficient attention. The install is best-effort:
# check=False means a failed install does not abort startup here.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "xformers==0.0.32.post2", "--no-build-isolation"],
    check=False,
)

# The LTX-2 sources are cloned next to this file and pinned to one commit.
LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
LTX_COMMIT = "ae855f8538843825f9015a419cf4ba5edaf5eec2"

# Always start from a clean checkout: drop any previous clone first.
if os.path.exists(LTX_REPO_DIR):
    print(f"Removing existing repo at {LTX_REPO_DIR}...")
    subprocess.run(["rm", "-rf", LTX_REPO_DIR], check=True)

# Clone, then move to the pinned commit; either step failing aborts startup.
for _announcement, _git_command in (
    (f"Cloning {LTX_REPO_URL}...", ["git", "clone", LTX_REPO_URL, LTX_REPO_DIR]),
    (f"Checking out commit {LTX_COMMIT}...", ["git", "-C", LTX_REPO_DIR, "checkout", LTX_COMMIT]),
):
    print(_announcement)
    subprocess.run(_git_command, check=True)

# Editable installs of the two packages shipped in the repo, without pulling
# in (or touching) their dependencies.
_LTX_CORE_DIR = os.path.join(LTX_REPO_DIR, "packages", "ltx-core")
_LTX_PIPELINES_DIR = os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines")

print("Installing ltx-core and ltx-pipelines from pinned repo commit...")
subprocess.run(
    [
        sys.executable, "-m", "pip", "install",
        "--force-reinstall", "--no-deps",
        "-e", _LTX_CORE_DIR,
        "-e", _LTX_PIPELINES_DIR,
    ],
    check=True,
)

# Put both source trees at the front of sys.path as well; after these two
# inserts ltx-core/src is first and ltx-pipelines/src second.
sys.path.insert(0, os.path.join(_LTX_PIPELINES_DIR, "src"))
sys.path.insert(0, os.path.join(_LTX_CORE_DIR, "src"))
40
-
41
- import logging
42
- import random
43
- import tempfile
44
- from pathlib import Path
45
-
46
- import torch
47
- torch._dynamo.config.suppress_errors = True
48
- torch._dynamo.config.disable = True
49
-
50
- import spaces
51
- import gradio as gr
52
- import numpy as np
53
- from huggingface_hub import hf_hub_download, snapshot_download
54
-
55
- from ltx_core.components.diffusion_steps import EulerDiffusionStep
56
- from ltx_core.components.noisers import GaussianNoiser
57
- from ltx_core.model.audio_vae import encode_audio as vae_encode_audio
58
- from ltx_core.model.upsampler import upsample_video
59
- from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number, decode_video as vae_decode_video
60
- from ltx_core.quantization import QuantizationPolicy
61
- from ltx_core.types import Audio, AudioLatentShape, VideoPixelShape
62
- from ltx_pipelines.distilled import DistilledPipeline
63
- from ltx_pipelines.utils import euler_denoising_loop
64
- from ltx_pipelines.utils.args import ImageConditioningInput
65
- from ltx_pipelines.utils.constants import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
66
- from ltx_pipelines.utils.helpers import (
67
- cleanup_memory,
68
- combined_image_conditionings,
69
- denoise_video_only,
70
- encode_prompts,
71
- simple_denoising_func,
72
- )
73
- from ltx_pipelines.utils.media_io import decode_audio_from_file, encode_video
74
-
75
# Replace the attention kernel referenced by the LTX attention module with the
# xformers implementation. The patch is best-effort: any failure is printed
# and the module keeps whatever implementation it already had.
from ltx_core.model.transformer import attention as _attn_mod

print(f"[ATTN] Before patch: memory_efficient_attention={_attn_mod.memory_efficient_attention}")
try:
    from xformers.ops import memory_efficient_attention as _mea

    _attn_mod.memory_efficient_attention = _mea
    print(f"[ATTN] After patch: memory_efficient_attention={_attn_mod.memory_efficient_attention}")
except Exception as patch_error:
    # Reported, never raised: startup continues without the patch.
    print(f"[ATTN] xformers patch FAILED: {type(patch_error).__name__}: {patch_error}")
84
-
85
# Raise the root logger to INFO.
logging.getLogger().setLevel(logging.INFO)

# Upper bound for seeds (largest signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max

DEFAULT_PROMPT = " ".join(
    [
        "An astronaut hatches from a fragile egg on the surface of the Moon,",
        "the shell cracking and peeling apart in gentle low-gravity motion.",
        "Fine lunar dust lifts and drifts outward with each movement, floating",
        "in slow arcs before settling back onto the ground.",
    ]
)

DEFAULT_FRAME_RATE = 24.0

# Resolution presets, keyed by quality tier and then by aspect ratio.
# Every value is a (width, height) pair.
RESOLUTIONS = {
    "high": {
        "16:9": (1536, 1024),
        "9:16": (1024, 1536),
        "1:1": (1024, 1024),
    },
    "low": {
        "16:9": (768, 512),
        "9:16": (512, 768),
        "1:1": (768, 768),
    },
}
101
-
102
-
103
class LTX23DistilledA2VPipeline(DistilledPipeline):
    """Distilled pipeline that can optionally be conditioned on an audio file.

    Without ``audio_path`` a call is forwarded unchanged to
    ``DistilledPipeline.__call__``. With ``audio_path`` the audio is encoded
    to a latent once and passed as ``initial_audio_latent`` to both
    video-only denoising stages; the audio returned alongside the video is
    the decoded input waveform.
    """

    @staticmethod
    def _fit_latent_frames(latent, target_frames):
        """Trim or zero-pad ``latent`` along dim 2 to exactly ``target_frames``."""
        current_frames = latent.shape[2]
        if current_frames > target_frames:
            return latent[:, :, :target_frames, :]
        if current_frames < target_frames:
            # new_zeros keeps the dtype and device of the source latent.
            filler = latent.new_zeros(
                (latent.shape[0], latent.shape[1], target_frames - current_frames, latent.shape[3])
            )
            return torch.cat([latent, filler], dim=2)
        return latent

    def __call__(
        self,
        prompt: str,
        seed: int,
        height: int,
        width: int,
        num_frames: int,
        frame_rate: float,
        images: list[ImageConditioningInput],
        audio_path: str | None = None,
        tiling_config: TilingConfig | None = None,
        enhance_prompt: bool = False,
    ):
        """Generate a video, optionally driven by the audio at ``audio_path``.

        Returns:
            The parent pipeline's result when ``audio_path`` is ``None``;
            otherwise a ``(decoded_video, Audio)`` pair.

        Raises:
            ValueError: if no audio could be decoded from ``audio_path``.
        """
        if audio_path is None:
            # Nothing audio-specific to do: defer entirely to the parent.
            return super().__call__(
                prompt=prompt,
                seed=seed,
                height=height,
                width=width,
                num_frames=num_frames,
                frame_rate=frame_rate,
                images=images,
                tiling_config=tiling_config,
                enhance_prompt=enhance_prompt,
            )

        generator = torch.Generator(device=self.device).manual_seed(seed)
        noiser = GaussianNoiser(generator=generator)
        stepper = EulerDiffusionStep()
        dtype = torch.bfloat16

        # --- text conditioning -------------------------------------------
        first_image_path = None
        if len(images) > 0:
            first_image_path = images[0].path
        (prompt_encoding,) = encode_prompts(
            [prompt],
            self.model_ledger,
            enhance_first_prompt=enhance_prompt,
            enhance_prompt_image=first_image_path,
        )
        video_context = prompt_encoding.video_encoding
        audio_context = prompt_encoding.audio_encoding

        # --- audio conditioning ------------------------------------------
        clip_seconds = num_frames / frame_rate
        decoded_audio = decode_audio_from_file(audio_path, self.device, 0.0, clip_seconds)
        if decoded_audio is None:
            raise ValueError(f"Could not extract audio stream from {audio_path}")

        audio_latent = vae_encode_audio(decoded_audio, self.model_ledger.audio_encoder())
        target_audio_shape = AudioLatentShape.from_duration(
            batch=1, duration=clip_seconds, channels=8, mel_bins=16
        )
        # The encoded latent must have exactly as many frames as the clip needs.
        audio_latent = self._fit_latent_frames(audio_latent, target_audio_shape.frames)

        video_encoder = self.model_ledger.video_encoder()
        transformer = self.model_ledger.transformer()
        stage_1_sigmas = torch.tensor(DISTILLED_SIGMA_VALUES, device=self.device)

        def denoising_loop(sigmas, video_state, audio_state, stepper):
            step_fn = simple_denoising_func(
                video_context=video_context,
                audio_context=audio_context,
                transformer=transformer,
            )
            return euler_denoising_loop(
                sigmas=sigmas,
                video_state=video_state,
                audio_state=audio_state,
                stepper=stepper,
                denoise_fn=step_fn,
            )

        def run_stage(output_shape, sigmas, **extra_denoise_kwargs):
            # Image conditionings are rebuilt for each stage's own resolution.
            conditionings = combined_image_conditionings(
                images=images,
                height=output_shape.height,
                width=output_shape.width,
                video_encoder=video_encoder,
                dtype=dtype,
                device=self.device,
            )
            return denoise_video_only(
                output_shape=output_shape,
                conditionings=conditionings,
                noiser=noiser,
                sigmas=sigmas,
                stepper=stepper,
                denoising_loop_fn=denoising_loop,
                components=self.pipeline_components,
                dtype=dtype,
                device=self.device,
                initial_audio_latent=audio_latent,
                **extra_denoise_kwargs,
            )

        # --- stage 1: half width / half height ---------------------------
        half_res_shape = VideoPixelShape(
            batch=1,
            frames=num_frames,
            width=width // 2,
            height=height // 2,
            fps=frame_rate,
        )
        video_state = run_stage(half_res_shape, stage_1_sigmas)

        torch.cuda.synchronize()
        cleanup_memory()

        # --- stage 2: upsample the stage-1 latent, refine at full size ----
        upscaled_video_latent = upsample_video(
            latent=video_state.latent[:1],
            video_encoder=video_encoder,
            upsampler=self.model_ledger.spatial_upsampler(),
        )
        stage_2_sigmas = torch.tensor(STAGE_2_DISTILLED_SIGMA_VALUES, device=self.device)
        full_res_shape = VideoPixelShape(
            batch=1, frames=num_frames, width=width, height=height, fps=frame_rate
        )
        video_state = run_stage(
            full_res_shape,
            stage_2_sigmas,
            noise_scale=stage_2_sigmas[0],
            initial_video_latent=upscaled_video_latent,
        )

        torch.cuda.synchronize()
        # Drop the local references to the large models before decoding.
        del transformer
        del video_encoder
        cleanup_memory()

        decoded_video = vae_decode_video(
            video_state.latent,
            self.model_ledger.video_decoder(),
            tiling_config,
            generator,
        )
        original_audio = Audio(
            waveform=decoded_audio.waveform.squeeze(0),
            sampling_rate=decoded_audio.sampling_rate,
        )
        return decoded_video, original_audio
263
-
264
-
265
# Hugging Face repositories the weights are fetched from.
LTX_MODEL_REPO = "Lightricks/LTX-2.3"
GEMMA_REPO = "google/gemma-3-12b-it-qat-q4_0-unquantized"

print("=" * 80, "Downloading LTX-2.3 distilled model + Gemma...", "=" * 80, sep="\n")

# Two single files from the LTX repo, plus a full snapshot of the Gemma repo.
checkpoint_path = hf_hub_download(
    repo_id=LTX_MODEL_REPO,
    filename="ltx-2.3-22b-distilled-1.1.safetensors",
)
spatial_upsampler_path = hf_hub_download(
    repo_id=LTX_MODEL_REPO,
    filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors",
)
gemma_root = snapshot_download(repo_id=GEMMA_REPO)

print(
    f"Checkpoint: {checkpoint_path}",
    f"Spatial upsampler: {spatial_upsampler_path}",
    f"Gemma root: {gemma_root}",
    sep="\n",
)

# Build the audio-capable pipeline, with the Gemma text encoder, no LoRAs
# and the fp8-cast quantization policy.
pipeline = LTX23DistilledA2VPipeline(
    distilled_checkpoint_path=checkpoint_path,
    spatial_upsampler_path=spatial_upsampler_path,
    gemma_root=gemma_root,
    loras=[],
    quantization=QuantizationPolicy.fp8_cast(),
)
290
-
291
# Preload all models for ZeroGPU tensor packing: every ledger accessor is
# called once up front, then replaced by a callable that hands back that same
# instance, so later calls do not build the model again.
print("Preloading all models (including Gemma and audio components)...")
ledger = pipeline.model_ledger

_PRELOAD_ORDER = (
    "transformer",
    "video_encoder",
    "video_decoder",
    "audio_encoder",
    "audio_decoder",
    "vocoder",
    "spatial_upsampler",
    "text_encoder",
    "gemma_embeddings_processor",
)

# Pass 1: build everything through the ledger's original accessors, in order.
_preloaded = {_name: getattr(ledger, _name)() for _name in _PRELOAD_ORDER}

# Pass 2: only after all models exist, shadow each accessor on the ledger.
# The default argument binds the instance per iteration (no late binding).
for _name, _model in _preloaded.items():
    setattr(ledger, _name, lambda _cached=_model: _cached)

print("All models preloaded (including Gemma text encoder and audio encoder)!")

print("=" * 80, "Pipeline ready!", "=" * 80, sep="\n")
318
-
319
-
320
def log_memory(tag: str):
    """Print allocated, peak, free and total CUDA memory in GB, labelled ``tag``.

    Does nothing when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return

    bytes_per_gb = 1024**3
    allocated_gb = torch.cuda.memory_allocated() / bytes_per_gb
    peak_gb = torch.cuda.max_memory_allocated() / bytes_per_gb
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    free_gb = free_bytes / bytes_per_gb
    total_gb = total_bytes / bytes_per_gb
    print(f"[VRAM {tag}] allocated={allocated_gb:.2f}GB peak={peak_gb:.2f}GB free={free_gb:.2f}GB total={total_gb:.2f}GB")
326
-
327
-
328
def detect_aspect_ratio(image) -> str:
    """Return the preset aspect-ratio key closest to the image's width/height.

    Args:
        image: ``None``, an object whose ``size`` is a ``(width, height)``
            pair, or an object whose ``shape`` starts with
            ``(height, width)``.

    Returns:
        ``"16:9"``, ``"9:16"`` or ``"1:1"``. ``"16:9"`` is the fallback when
        the dimensions cannot be determined or the height is zero.
    """
    fallback = "16:9"
    if image is None:
        return fallback

    size = getattr(image, "size", None)
    shape = getattr(image, "shape", None)
    # ``size`` is only trusted when it really is a 2-item pair: NumPy arrays
    # also expose ``size`` (an int element count), which must not be unpacked.
    if isinstance(size, (tuple, list)) and len(size) == 2:
        w, h = size
    elif shape is not None and len(shape) >= 2:
        h, w = shape[:2]
    else:
        return fallback

    if not h:
        # Degenerate zero-height image: w / h would raise ZeroDivisionError.
        return fallback

    ratio = w / h
    candidates = {"16:9": 16 / 9, "9:16": 9 / 16, "1:1": 1.0}
    # Ties resolve to the first key in insertion order.
    return min(candidates, key=lambda k: abs(ratio - candidates[k]))
340
-
341
-
342
def on_image_upload(first_image, last_image, high_res):
    """Choose the preset width/height matching the uploaded frame.

    The first frame is the reference when present, otherwise the last frame.
    Returns two ``gr.update`` objects: one for width, one for height.
    """
    reference = last_image if first_image is None else first_image
    aspect_key = detect_aspect_ratio(reference)
    tier_presets = RESOLUTIONS["high"] if high_res else RESOLUTIONS["low"]
    preset_width, preset_height = tier_presets[aspect_key]
    return gr.update(value=preset_width), gr.update(value=preset_height)
348
-
349
-
350
def on_highres_toggle(first_image, last_image, high_res):
    """Recompute the preset width/height when the high-resolution box changes.

    The computation is the same as for an image upload, so it is delegated to
    ``on_image_upload``; the result is a (width update, height update) pair.
    """
    return on_image_upload(first_image, last_image, high_res)
356
-
357
-
358
@spaces.GPU(duration=75)
@torch.inference_mode()
def generate_video(
    first_image,
    last_image,
    input_audio,
    prompt: str,
    duration: float,
    enhance_prompt: bool = True,
    seed: int = 42,
    randomize_seed: bool = True,
    height: int = 1024,
    width: int = 1536,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a video (with audio) and write it to a temporary ``.mp4``.

    Args:
        first_image: Optional conditioning image for frame 0; either an
            object with a ``save`` method or a path.
        last_image: Optional conditioning image for the final frame, same
            accepted forms as ``first_image``.
        input_audio: Optional path to an audio file, forwarded to the pipeline
            as ``audio_path``.
        prompt: Text prompt.
        duration: Requested length in seconds.
        enhance_prompt: Forwarded to the pipeline.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: Draw a random seed in ``[0, MAX_SEED]`` instead.
        height: Output height in pixels.
        width: Output width in pixels.
        progress: Gradio progress tracker (tqdm tracking enabled).

    Returns:
        ``(output_path, seed_used)`` on success, ``(None, seed)`` when any
        exception occurs; the error and traceback are printed, not raised.
    """
    # Bound before the try block so the error path can always return a seed,
    # even when the failure happens before the real seed is chosen.
    current_seed = seed
    try:
        torch.cuda.reset_peak_memory_stats()
        log_memory("start")

        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

        frame_rate = DEFAULT_FRAME_RATE
        num_frames = int(duration * frame_rate) + 1
        # Round the frame count up to the next value of the form 8*k + 1.
        num_frames = ((num_frames - 1 + 7) // 8) * 8 + 1

        print(f"Generating: {height}x{width}, {num_frames} frames ({duration}s), seed={current_seed}")

        images = []
        output_dir = Path("outputs")
        output_dir.mkdir(exist_ok=True)

        # The pipeline takes image conditionings by file path, so in-memory
        # images are written to disk first; path-like inputs are used as-is.
        for label, image, frame_idx in (
            ("first", first_image, 0),
            ("last", last_image, num_frames - 1),
        ):
            if image is None:
                continue
            if hasattr(image, "save"):
                image_path = output_dir / f"temp_{label}_{current_seed}.jpg"
                image.save(image_path)
            else:
                image_path = Path(image)
            images.append(ImageConditioningInput(path=str(image_path), frame_idx=frame_idx, strength=1.0))

        tiling_config = TilingConfig.default()
        video_chunks_number = get_video_chunks_number(num_frames, tiling_config)

        log_memory("before pipeline call")

        video, audio = pipeline(
            prompt=prompt,
            seed=current_seed,
            height=int(height),
            width=int(width),
            num_frames=num_frames,
            frame_rate=frame_rate,
            images=images,
            audio_path=input_audio,
            tiling_config=tiling_config,
            enhance_prompt=enhance_prompt,
        )

        log_memory("after pipeline call")

        # tempfile.mktemp only returns a name and is race-prone; create the
        # file for real and keep it (delete=False) for the encoder to fill.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as output_file:
            output_path = output_file.name
        encode_video(
            video=video,
            fps=frame_rate,
            audio=audio,
            output_path=output_path,
            video_chunks_number=video_chunks_number,
        )

        log_memory("after encode_video")
        return str(output_path), current_seed

    except Exception as e:
        # Deliberate catch-all at the UI boundary: report and return no video.
        import traceback
        log_memory("on error")
        print(f"Error: {str(e)}\n{traceback.format_exc()}")
        return None, current_seed
442
-
443
-
444
# Gradio UI: inputs on the left, generated video on the right, then one
# example row and the event wiring.
# NOTE(review): the nesting of rows/columns below was reconstructed from a
# flattened view of the file; confirm it against the intended layout.
with gr.Blocks(title="LTX-2.3 Distilled") as demo:
    gr.Markdown("# LTX-2.3 F2LF: Fast Audio-Video Generation with Frame Conditioning")
    gr.Markdown(
        "Fast and high quality video + audio generation with first and last frame conditioning and optional audio input "
        "[[model]](https://huggingface.co/Lightricks/LTX-2.3) "
        "[[code]](https://github.com/Lightricks/LTX-2)"
    )

    with gr.Row():
        # Left column: every input control.
        with gr.Column():
            with gr.Row():
                first_image = gr.Image(label="First Frame (Optional)", type="pil")
                last_image = gr.Image(label="Last Frame (Optional)", type="pil")
            input_audio = gr.Audio(label="Audio Input (Optional)", type="filepath")
            prompt = gr.Textbox(
                label="Prompt",
                info="for best results - make it as elaborate as possible",
                value="Make this image come alive with cinematic motion, smooth animation",
                lines=3,
                placeholder="Describe the motion and animation you want...",
            )
            duration = gr.Slider(
                label="Duration (seconds)",
                minimum=1.0,
                maximum=10.0,
                value=3.0,
                step=0.1,
            )

            generate_btn = gr.Button("Generate Video", variant="primary", size="lg")

            with gr.Accordion("Advanced Settings", open=False):
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, value=10, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                with gr.Row():
                    width = gr.Number(label="Width", value=1536, precision=0)
                    height = gr.Number(label="Height", value=1024, precision=0)
                with gr.Row():
                    enhance_prompt = gr.Checkbox(label="Enhance Prompt", value=False)
                    high_res = gr.Checkbox(label="High Resolution", value=True)

        # Right column: the result.
        with gr.Column():
            output_video = gr.Video(label="Generated Video", autoplay=True)

    # Positional inputs of generate_video, in signature order; the example
    # row below supplies its values in this same order.
    generation_inputs = [
        first_image, last_image, input_audio, prompt, duration,
        enhance_prompt, seed, randomize_seed, height, width,
    ]
    # Inputs/outputs shared by all three resolution-preset callbacks.
    preset_inputs = [first_image, last_image, high_res]
    preset_outputs = [width, height]

    gr.Examples(
        examples=[
            [
                None,
                "pinkknit.jpg",
                None,
                "The camera falls downward through darkness as if dropped into a tunnel. "
                "As it slows, five friends wearing pink knitted hats and sunglasses lean "
                "over and look down toward the camera with curious expressions. The lens "
                "has a strong fisheye effect, creating a circular frame around them. They "
                "crowd together closely, forming a symmetrical cluster while staring "
                "directly into the lens.",
                3.0,
                False,
                42,
                True,
                1024,
                1024,
            ],
        ],
        inputs=generation_inputs,
    )

    # Changing either frame re-selects the preset width/height.
    for frame_input in (first_image, last_image):
        frame_input.change(
            fn=on_image_upload,
            inputs=preset_inputs,
            outputs=preset_outputs,
        )

    high_res.change(
        fn=on_highres_toggle,
        inputs=preset_inputs,
        outputs=preset_outputs,
    )

    # The seed that was actually used is written back into the seed slider.
    generate_btn.click(
        fn=generate_video,
        inputs=generation_inputs,
        outputs=[output_video, seed],
    )
535
-
536
-
537
# Custom stylesheet handed to launch(): limits elements with the
# ``fillable`` class to a maximum width of 1200px.
css = "\n.fillable{max-width: 1200px !important}\n"

if __name__ == "__main__":
    launch_options = {"theme": gr.themes.Citrus(), "css": css}
    demo.launch(**launch_options)