linoyts HF Staff committed on
Commit
b6775a7
·
verified ·
1 Parent(s): 85c5116

add audio input support

Browse files
Files changed (1) hide show
  1. app.py +202 -26
app.py CHANGED
@@ -42,11 +42,25 @@ import gradio as gr
42
  import numpy as np
43
  from huggingface_hub import hf_hub_download, snapshot_download
44
 
45
- from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
 
 
 
 
46
  from ltx_core.quantization import QuantizationPolicy
 
47
  from ltx_pipelines.distilled import DistilledPipeline
 
48
  from ltx_pipelines.utils.args import ImageConditioningInput
49
- from ltx_pipelines.utils.media_io import encode_video
 
 
 
 
 
 
 
 
50
 
51
  # Force-patch xformers attention into the LTX attention module.
52
  from ltx_core.model.transformer import attention as _attn_mod
@@ -75,6 +89,169 @@ RESOLUTIONS = {
75
  "low": {"16:9": (768, 512), "9:16": (512, 768), "1:1": (768, 768)},
76
  }
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  # Model repos
79
  LTX_MODEL_REPO = "diffusers-internal-dev/ltx-23"
80
  GEMMA_REPO = "google/gemma-3-12b-it-qat-q4_0-unquantized"
@@ -92,8 +269,8 @@ print(f"Checkpoint: {checkpoint_path}")
92
  print(f"Spatial upsampler: {spatial_upsampler_path}")
93
  print(f"Gemma root: {gemma_root}")
94
 
95
- # Initialize pipeline WITH text encoder
96
- pipeline = DistilledPipeline(
97
  distilled_checkpoint_path=checkpoint_path,
98
  spatial_upsampler_path=spatial_upsampler_path,
99
  gemma_root=gemma_root,
@@ -102,11 +279,12 @@ pipeline = DistilledPipeline(
102
  )
103
 
104
  # Preload all models for ZeroGPU tensor packing.
105
- print("Preloading all models (including Gemma)...")
106
  ledger = pipeline.model_ledger
107
  _transformer = ledger.transformer()
108
  _video_encoder = ledger.video_encoder()
109
  _video_decoder = ledger.video_decoder()
 
110
  _audio_decoder = ledger.audio_decoder()
111
  _vocoder = ledger.vocoder()
112
  _spatial_upsampler = ledger.spatial_upsampler()
@@ -116,12 +294,13 @@ _embeddings_processor = ledger.gemma_embeddings_processor()
116
  ledger.transformer = lambda: _transformer
117
  ledger.video_encoder = lambda: _video_encoder
118
  ledger.video_decoder = lambda: _video_decoder
 
119
  ledger.audio_decoder = lambda: _audio_decoder
120
  ledger.vocoder = lambda: _vocoder
121
  ledger.spatial_upsampler = lambda: _spatial_upsampler
122
  ledger.text_encoder = lambda: _text_encoder
123
  ledger.gemma_embeddings_processor = lambda: _embeddings_processor
124
- print("All models preloaded (including Gemma text encoder)!")
125
 
126
  print("=" * 80)
127
  print("Pipeline ready!")
@@ -137,7 +316,6 @@ def log_memory(tag: str):
137
 
138
 
139
  def detect_aspect_ratio(image) -> str:
140
- """Detect the closest aspect ratio (16:9, 9:16, or 1:1) from an image."""
141
  if image is None:
142
  return "16:9"
143
  if hasattr(image, "size"):
@@ -152,8 +330,6 @@ def detect_aspect_ratio(image) -> str:
152
 
153
 
154
  def on_image_upload(first_image, last_image, high_res):
155
- """Auto-set resolution when an image is uploaded."""
156
- # Use first image for aspect ratio detection, fall back to last image
157
  ref_image = first_image if first_image is not None else last_image
158
  aspect = detect_aspect_ratio(ref_image)
159
  tier = "high" if high_res else "low"
@@ -162,7 +338,6 @@ def on_image_upload(first_image, last_image, high_res):
162
 
163
 
164
  def on_highres_toggle(first_image, last_image, high_res):
165
- """Update resolution when high-res toggle changes."""
166
  ref_image = first_image if first_image is not None else last_image
167
  aspect = detect_aspect_ratio(ref_image)
168
  tier = "high" if high_res else "low"
@@ -175,6 +350,7 @@ def on_highres_toggle(first_image, last_image, high_res):
175
  def generate_video(
176
  first_image,
177
  last_image,
 
178
  prompt: str,
179
  duration: float,
180
  enhance_prompt: bool = True,
@@ -229,6 +405,7 @@ def generate_video(
229
  num_frames=num_frames,
230
  frame_rate=frame_rate,
231
  images=images,
 
232
  tiling_config=tiling_config,
233
  enhance_prompt=enhance_prompt,
234
  )
@@ -257,7 +434,7 @@ def generate_video(
257
  with gr.Blocks(title="LTX-2.3 Distilled") as demo:
258
  gr.Markdown("# LTX-2.3 F2LF: Fast Audio-Video Generation with Frame Conditioning")
259
  gr.Markdown(
260
- "Fast and high quality video + audio generation with first and last frame conditioing "
261
  "[[model]](https://huggingface.co/Lightricks/LTX-2.3) "
262
  "[[code]](https://github.com/Lightricks/LTX-2)"
263
  )
@@ -267,6 +444,7 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
267
  with gr.Row():
268
  first_image = gr.Image(label="First Frame (Optional)", type="pil")
269
  last_image = gr.Image(label="Last Frame (Optional)", type="pil")
 
270
  prompt = gr.Textbox(
271
  label="Prompt",
272
  info="for best results - make it as elaborate as possible",
@@ -274,7 +452,7 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
274
  lines=3,
275
  placeholder="Describe the motion and animation you want...",
276
  )
277
-
278
  with gr.Row():
279
  duration = gr.Slider(label="Duration (seconds)", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
280
  with gr.Column():
@@ -292,34 +470,33 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
292
 
293
  with gr.Column():
294
  output_video = gr.Video(label="Generated Video", autoplay=True)
295
-
296
  gr.Examples(
297
  examples=[
298
  [
299
- None, # first_image
300
- "pinkknit.jpg", # last_image
 
301
  "The camera falls downward through darkness as if dropped into a tunnel. "
302
  "As it slows, five friends wearing pink knitted hats and sunglasses lean "
303
  "over and look down toward the camera with curious expressions. The lens "
304
  "has a strong fisheye effect, creating a circular frame around them. They "
305
  "crowd together closely, forming a symmetrical cluster while staring "
306
  "directly into the lens.",
307
- 3.0, # duration
308
- False, # enhance_prompt
309
- 42, # seed
310
- True, # randomize_seed
 
311
  1024,
312
- 1024
313
-
314
  ],
315
  ],
316
  inputs=[
317
- first_image, last_image, prompt, duration,
318
  enhance_prompt, seed, randomize_seed, height, width,
319
  ],
320
  )
321
 
322
- # Auto-detect aspect ratio from uploaded image and set resolution
323
  first_image.change(
324
  fn=on_image_upload,
325
  inputs=[first_image, last_image, high_res],
@@ -332,7 +509,6 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
332
  outputs=[width, height],
333
  )
334
 
335
- # Update resolution when high-res toggle changes
336
  high_res.change(
337
  fn=on_highres_toggle,
338
  inputs=[first_image, last_image, high_res],
@@ -342,7 +518,7 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
342
  generate_btn.click(
343
  fn=generate_video,
344
  inputs=[
345
- first_image, last_image, prompt, duration, enhance_prompt,
346
  seed, randomize_seed, height, width,
347
  ],
348
  outputs=[output_video, seed],
@@ -354,4 +530,4 @@ css = """
354
  """
355
 
356
  if __name__ == "__main__":
357
- demo.launch(theme=gr.themes.Citrus(), css=css)
 
42
  import numpy as np
43
  from huggingface_hub import hf_hub_download, snapshot_download
44
 
45
+ from ltx_core.components.diffusion_steps import EulerDiffusionStep
46
+ from ltx_core.components.noisers import GaussianNoiser
47
+ from ltx_core.model.audio_vae import encode_audio as vae_encode_audio
48
+ from ltx_core.model.upsampler import upsample_video
49
+ from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number, decode_video as vae_decode_video
50
  from ltx_core.quantization import QuantizationPolicy
51
+ from ltx_core.types import Audio, AudioLatentShape, VideoPixelShape
52
  from ltx_pipelines.distilled import DistilledPipeline
53
+ from ltx_pipelines.utils import euler_denoising_loop
54
  from ltx_pipelines.utils.args import ImageConditioningInput
55
+ from ltx_pipelines.utils.constants import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
56
+ from ltx_pipelines.utils.helpers import (
57
+ cleanup_memory,
58
+ combined_image_conditionings,
59
+ denoise_video_only,
60
+ encode_prompts,
61
+ simple_denoising_func,
62
+ )
63
+ from ltx_pipelines.utils.media_io import decode_audio_from_file, encode_video
64
 
65
  # Force-patch xformers attention into the LTX attention module.
66
  from ltx_core.model.transformer import attention as _attn_mod
 
89
  "low": {"16:9": (768, 512), "9:16": (512, 768), "1:1": (768, 768)},
90
  }
91
 
92
+
93
+ class LTX23DistilledA2VPipeline(DistilledPipeline):
94
+ """DistilledPipeline with optional audio conditioning."""
95
+
96
+ def __call__(
97
+ self,
98
+ prompt: str,
99
+ seed: int,
100
+ height: int,
101
+ width: int,
102
+ num_frames: int,
103
+ frame_rate: float,
104
+ images: list[ImageConditioningInput],
105
+ audio_path: str | None = None,
106
+ tiling_config: TilingConfig | None = None,
107
+ enhance_prompt: bool = False,
108
+ ):
109
+ # Standard path when no audio input is provided.
110
+ if audio_path is None:
111
+ return super().__call__(
112
+ prompt=prompt,
113
+ seed=seed,
114
+ height=height,
115
+ width=width,
116
+ num_frames=num_frames,
117
+ frame_rate=frame_rate,
118
+ images=images,
119
+ tiling_config=tiling_config,
120
+ enhance_prompt=enhance_prompt,
121
+ )
122
+
123
+ generator = torch.Generator(device=self.device).manual_seed(seed)
124
+ noiser = GaussianNoiser(generator=generator)
125
+ stepper = EulerDiffusionStep()
126
+ dtype = torch.bfloat16
127
+
128
+ (ctx_p,) = encode_prompts(
129
+ [prompt],
130
+ self.model_ledger,
131
+ enhance_first_prompt=enhance_prompt,
132
+ enhance_prompt_image=images[0].path if len(images) > 0 else None,
133
+ )
134
+ video_context, audio_context = ctx_p.video_encoding, ctx_p.audio_encoding
135
+
136
+ video_duration = num_frames / frame_rate
137
+ decoded_audio = decode_audio_from_file(audio_path, self.device, 0.0, video_duration)
138
+ if decoded_audio is None:
139
+ raise ValueError(f"Could not extract audio stream from {audio_path}")
140
+
141
+ encoded_audio_latent = vae_encode_audio(decoded_audio, self.model_ledger.audio_encoder())
142
+ audio_shape = AudioLatentShape.from_duration(batch=1, duration=video_duration, channels=8, mel_bins=16)
143
+ expected_frames = audio_shape.frames
144
+ actual_frames = encoded_audio_latent.shape[2]
145
+
146
+ if actual_frames > expected_frames:
147
+ encoded_audio_latent = encoded_audio_latent[:, :, :expected_frames, :]
148
+ elif actual_frames < expected_frames:
149
+ pad = torch.zeros(
150
+ encoded_audio_latent.shape[0],
151
+ encoded_audio_latent.shape[1],
152
+ expected_frames - actual_frames,
153
+ encoded_audio_latent.shape[3],
154
+ device=encoded_audio_latent.device,
155
+ dtype=encoded_audio_latent.dtype,
156
+ )
157
+ encoded_audio_latent = torch.cat([encoded_audio_latent, pad], dim=2)
158
+
159
+ video_encoder = self.model_ledger.video_encoder()
160
+ transformer = self.model_ledger.transformer()
161
+ stage_1_sigmas = torch.tensor(DISTILLED_SIGMA_VALUES, device=self.device)
162
+
163
+ def denoising_loop(sigmas, video_state, audio_state, stepper):
164
+ return euler_denoising_loop(
165
+ sigmas=sigmas,
166
+ video_state=video_state,
167
+ audio_state=audio_state,
168
+ stepper=stepper,
169
+ denoise_fn=simple_denoising_func(
170
+ video_context=video_context,
171
+ audio_context=audio_context,
172
+ transformer=transformer,
173
+ ),
174
+ )
175
+
176
+ stage_1_output_shape = VideoPixelShape(
177
+ batch=1,
178
+ frames=num_frames,
179
+ width=width // 2,
180
+ height=height // 2,
181
+ fps=frame_rate,
182
+ )
183
+ stage_1_conditionings = combined_image_conditionings(
184
+ images=images,
185
+ height=stage_1_output_shape.height,
186
+ width=stage_1_output_shape.width,
187
+ video_encoder=video_encoder,
188
+ dtype=dtype,
189
+ device=self.device,
190
+ )
191
+ video_state = denoise_video_only(
192
+ output_shape=stage_1_output_shape,
193
+ conditionings=stage_1_conditionings,
194
+ noiser=noiser,
195
+ sigmas=stage_1_sigmas,
196
+ stepper=stepper,
197
+ denoising_loop_fn=denoising_loop,
198
+ components=self.pipeline_components,
199
+ dtype=dtype,
200
+ device=self.device,
201
+ initial_audio_latent=encoded_audio_latent,
202
+ )
203
+
204
+ torch.cuda.synchronize()
205
+ cleanup_memory()
206
+
207
+ upscaled_video_latent = upsample_video(
208
+ latent=video_state.latent[:1],
209
+ video_encoder=video_encoder,
210
+ upsampler=self.model_ledger.spatial_upsampler(),
211
+ )
212
+ stage_2_sigmas = torch.tensor(STAGE_2_DISTILLED_SIGMA_VALUES, device=self.device)
213
+ stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
214
+ stage_2_conditionings = combined_image_conditionings(
215
+ images=images,
216
+ height=stage_2_output_shape.height,
217
+ width=stage_2_output_shape.width,
218
+ video_encoder=video_encoder,
219
+ dtype=dtype,
220
+ device=self.device,
221
+ )
222
+ video_state = denoise_video_only(
223
+ output_shape=stage_2_output_shape,
224
+ conditionings=stage_2_conditionings,
225
+ noiser=noiser,
226
+ sigmas=stage_2_sigmas,
227
+ stepper=stepper,
228
+ denoising_loop_fn=denoising_loop,
229
+ components=self.pipeline_components,
230
+ dtype=dtype,
231
+ device=self.device,
232
+ noise_scale=stage_2_sigmas[0],
233
+ initial_video_latent=upscaled_video_latent,
234
+ initial_audio_latent=encoded_audio_latent,
235
+ )
236
+
237
+ torch.cuda.synchronize()
238
+ del transformer
239
+ del video_encoder
240
+ cleanup_memory()
241
+
242
+ decoded_video = vae_decode_video(
243
+ video_state.latent,
244
+ self.model_ledger.video_decoder(),
245
+ tiling_config,
246
+ generator,
247
+ )
248
+ original_audio = Audio(
249
+ waveform=decoded_audio.waveform.squeeze(0),
250
+ sampling_rate=decoded_audio.sampling_rate,
251
+ )
252
+ return decoded_video, original_audio
253
+
254
+
255
  # Model repos
256
  LTX_MODEL_REPO = "diffusers-internal-dev/ltx-23"
257
  GEMMA_REPO = "google/gemma-3-12b-it-qat-q4_0-unquantized"
 
269
  print(f"Spatial upsampler: {spatial_upsampler_path}")
270
  print(f"Gemma root: {gemma_root}")
271
 
272
+ # Initialize pipeline WITH text encoder and optional audio support
273
+ pipeline = LTX23DistilledA2VPipeline(
274
  distilled_checkpoint_path=checkpoint_path,
275
  spatial_upsampler_path=spatial_upsampler_path,
276
  gemma_root=gemma_root,
 
279
  )
280
 
281
  # Preload all models for ZeroGPU tensor packing.
282
+ print("Preloading all models (including Gemma and audio components)...")
283
  ledger = pipeline.model_ledger
284
  _transformer = ledger.transformer()
285
  _video_encoder = ledger.video_encoder()
286
  _video_decoder = ledger.video_decoder()
287
+ _audio_encoder = ledger.audio_encoder()
288
  _audio_decoder = ledger.audio_decoder()
289
  _vocoder = ledger.vocoder()
290
  _spatial_upsampler = ledger.spatial_upsampler()
 
294
  ledger.transformer = lambda: _transformer
295
  ledger.video_encoder = lambda: _video_encoder
296
  ledger.video_decoder = lambda: _video_decoder
297
+ ledger.audio_encoder = lambda: _audio_encoder
298
  ledger.audio_decoder = lambda: _audio_decoder
299
  ledger.vocoder = lambda: _vocoder
300
  ledger.spatial_upsampler = lambda: _spatial_upsampler
301
  ledger.text_encoder = lambda: _text_encoder
302
  ledger.gemma_embeddings_processor = lambda: _embeddings_processor
303
+ print("All models preloaded (including Gemma text encoder and audio encoder)!")
304
 
305
  print("=" * 80)
306
  print("Pipeline ready!")
 
316
 
317
 
318
  def detect_aspect_ratio(image) -> str:
 
319
  if image is None:
320
  return "16:9"
321
  if hasattr(image, "size"):
 
330
 
331
 
332
  def on_image_upload(first_image, last_image, high_res):
 
 
333
  ref_image = first_image if first_image is not None else last_image
334
  aspect = detect_aspect_ratio(ref_image)
335
  tier = "high" if high_res else "low"
 
338
 
339
 
340
  def on_highres_toggle(first_image, last_image, high_res):
 
341
  ref_image = first_image if first_image is not None else last_image
342
  aspect = detect_aspect_ratio(ref_image)
343
  tier = "high" if high_res else "low"
 
350
  def generate_video(
351
  first_image,
352
  last_image,
353
+ input_audio,
354
  prompt: str,
355
  duration: float,
356
  enhance_prompt: bool = True,
 
405
  num_frames=num_frames,
406
  frame_rate=frame_rate,
407
  images=images,
408
+ audio_path=input_audio,
409
  tiling_config=tiling_config,
410
  enhance_prompt=enhance_prompt,
411
  )
 
434
  with gr.Blocks(title="LTX-2.3 Distilled") as demo:
435
  gr.Markdown("# LTX-2.3 F2LF: Fast Audio-Video Generation with Frame Conditioning")
436
  gr.Markdown(
437
+ "Fast and high quality video + audio generation with first and last frame conditioning and optional audio input "
438
  "[[model]](https://huggingface.co/Lightricks/LTX-2.3) "
439
  "[[code]](https://github.com/Lightricks/LTX-2)"
440
  )
 
444
  with gr.Row():
445
  first_image = gr.Image(label="First Frame (Optional)", type="pil")
446
  last_image = gr.Image(label="Last Frame (Optional)", type="pil")
447
+ input_audio = gr.Audio(label="Audio Input (Optional)", type="filepath")
448
  prompt = gr.Textbox(
449
  label="Prompt",
450
  info="for best results - make it as elaborate as possible",
 
452
  lines=3,
453
  placeholder="Describe the motion and animation you want...",
454
  )
455
+
456
  with gr.Row():
457
  duration = gr.Slider(label="Duration (seconds)", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
458
  with gr.Column():
 
470
 
471
  with gr.Column():
472
  output_video = gr.Video(label="Generated Video", autoplay=True)
473
+
474
  gr.Examples(
475
  examples=[
476
  [
477
+ None,
478
+ "pinkknit.jpg",
479
+ None,
480
  "The camera falls downward through darkness as if dropped into a tunnel. "
481
  "As it slows, five friends wearing pink knitted hats and sunglasses lean "
482
  "over and look down toward the camera with curious expressions. The lens "
483
  "has a strong fisheye effect, creating a circular frame around them. They "
484
  "crowd together closely, forming a symmetrical cluster while staring "
485
  "directly into the lens.",
486
+ 3.0,
487
+ False,
488
+ 42,
489
+ True,
490
+ 1024,
491
  1024,
 
 
492
  ],
493
  ],
494
  inputs=[
495
+ first_image, last_image, input_audio, prompt, duration,
496
  enhance_prompt, seed, randomize_seed, height, width,
497
  ],
498
  )
499
 
 
500
  first_image.change(
501
  fn=on_image_upload,
502
  inputs=[first_image, last_image, high_res],
 
509
  outputs=[width, height],
510
  )
511
 
 
512
  high_res.change(
513
  fn=on_highres_toggle,
514
  inputs=[first_image, last_image, high_res],
 
518
  generate_btn.click(
519
  fn=generate_video,
520
  inputs=[
521
+ first_image, last_image, input_audio, prompt, duration, enhance_prompt,
522
  seed, randomize_seed, height, width,
523
  ],
524
  outputs=[output_video, seed],
 
530
  """
531
 
532
  if __name__ == "__main__":
533
+ demo.launch(theme=gr.themes.Citrus(), css=css)