linoyts HF Staff commited on
Commit
95a227e
Β·
verified Β·
1 Parent(s): 21c3a50

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +566 -380
app.py CHANGED
@@ -10,18 +10,12 @@ os.environ["TORCHDYNAMO_DISABLE"] = "1"
10
  subprocess.run([sys.executable, "-m", "pip", "install", "xformers==0.0.32.post2", "--no-build-isolation"], check=False)
11
 
12
  # Clone LTX-2 repo and install packages
13
- LTX_REPO_URL = "https://github.com/linoytsaban/LTX-2.git"
14
- LTX_REPO_BRANCH = "patch-1"
15
  LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
16
 
17
- if os.path.exists(LTX_REPO_DIR):
18
- subprocess.run(["rm", "-rf", LTX_REPO_DIR], check=True)
19
-
20
- print(f"Cloning {LTX_REPO_URL}@{LTX_REPO_BRANCH}...")
21
- subprocess.run(
22
- ["git", "clone", "--depth", "1", "--branch", LTX_REPO_BRANCH, LTX_REPO_URL, LTX_REPO_DIR],
23
- check=True,
24
- )
25
 
26
  print("Installing ltx-core and ltx-pipelines from cloned repo...")
27
  subprocess.run(
@@ -31,7 +25,6 @@ subprocess.run(
31
  check=True,
32
  )
33
 
34
-
35
  sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines", "src"))
36
  sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-core", "src"))
37
 
@@ -48,21 +41,42 @@ import spaces
48
  import gradio as gr
49
  import numpy as np
50
  from huggingface_hub import hf_hub_download, snapshot_download
 
51
 
52
- from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
53
- from ltx_core.quantization import QuantizationPolicy
54
  from ltx_core.components.diffusion_steps import EulerDiffusionStep
55
  from ltx_core.components.noisers import GaussianNoiser
 
 
 
 
 
 
 
56
  from ltx_core.model.audio_vae import encode_audio as vae_encode_audio
57
  from ltx_core.model.upsampler import upsample_video
 
58
  from ltx_core.model.video_vae import decode_video as vae_decode_video
59
- from ltx_core.types import Audio, AudioLatentShape, VideoPixelShape
60
- from ltx_pipelines.ic_lora import ICLoraPipeline
61
- from ltx_pipelines.utils.helpers import combined_image_conditionings
62
- from ltx_pipelines.utils import cleanup_memory, denoise_audio_video, encode_prompts, euler_denoising_loop, simple_denoising_func
63
  from ltx_pipelines.utils.args import ImageConditioningInput
64
  from ltx_pipelines.utils.constants import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
65
- from ltx_pipelines.utils.media_io import decode_audio_from_file, encode_video
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  # Force-patch xformers attention into the LTX attention module.
68
  from ltx_core.model.transformer import attention as _attn_mod
@@ -76,55 +90,139 @@ except Exception as e:
76
 
77
  logging.getLogger().setLevel(logging.INFO)
78
 
79
- MAX_SEED = np.iinfo(np.int32).max
80
- DEFAULT_FRAME_RATE = 24.0
81
-
82
- # Resolution presets: (width, height)
83
- RESOLUTIONS = {
84
- "high": {"16:9": (1536, 1024), "9:16": (1024, 1536), "1:1": (1024, 1024)},
85
- "low": {"16:9": (768, 512), "9:16": (512, 768), "1:1": (768, 768)},
86
- }
87
 
88
- # Model repos
89
- LTX_MODEL_REPO = "diffusers-internal-dev/ltx-23"
90
- GEMMA_REPO = "google/gemma-3-12b-it-qat-q4_0-unquantized"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
- # Available IC-LoRAs for LTX-2.3 (22B)
93
- IC_LORA_OPTIONS = {
94
- "Union Control (Depth + Canny)": {
95
- "repo": "Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control",
96
- "filename": "ltx-2.3-22b-ic-lora-union-control-ref0.5.safetensors",
97
- },
98
- "Motion Track Control": {
99
- "repo": "Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control",
100
- "filename": "ltx-2.3-22b-ic-lora-motion-track-control-ref0.5.safetensors",
101
- },
102
- }
103
 
104
- # Download model checkpoints
105
- print("=" * 80)
106
- print("Downloading LTX-2.3 distilled model + Gemma + IC-LoRAs...")
107
- print("=" * 80)
 
 
 
 
 
 
108
 
109
- # checkpoint_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-distilled.safetensors")
110
- checkpoint_path = hf_hub_download(repo_id="linoyts/ltx-2.3-22b-distilled-motion-track-control-fused", filename="ltx-2.3-22b-distilled-motion-track-control-fused.safetensors")
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
- spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
113
- gemma_root = snapshot_download(repo_id=GEMMA_REPO)
 
 
114
 
115
- print(f"Checkpoint: {checkpoint_path}")
116
- print(f"Spatial upsampler: {spatial_upsampler_path}")
117
- print(f"Gemma root: {gemma_root}")
 
 
 
 
 
 
 
 
118
 
119
- # Build initial pipeline with the first IC-LoRA
120
- default_lora_name = "Union Control (Depth + Canny)"
121
- current_pipeline = None
122
- current_lora_name = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
 
 
 
 
 
 
 
 
 
 
124
 
125
- class AudioConditionedICLoraPipeline(ICLoraPipeline):
126
- """IC-LoRA pipeline with optional audio conditioning, adapted from multimodalart's audio-input Space."""
 
127
 
 
128
  def __call__(
129
  self,
130
  prompt: str,
@@ -134,85 +232,111 @@ class AudioConditionedICLoraPipeline(ICLoraPipeline):
134
  num_frames: int,
135
  frame_rate: float,
136
  images: list[ImageConditioningInput],
137
- video_conditioning: list[tuple[str, float]],
138
  audio_path: str | None = None,
139
- enhance_prompt: bool = False,
140
  tiling_config: TilingConfig | None = None,
141
- conditioning_attention_strength: float = 1.0,
142
- skip_stage_2: bool = False,
143
- conditioning_attention_mask: torch.Tensor | None = None,
144
  ):
145
- if audio_path is None:
146
- return super().__call__(
147
- prompt=prompt,
148
- seed=seed,
149
- height=height,
150
- width=width,
151
- num_frames=num_frames,
152
- frame_rate=frame_rate,
153
- images=images,
154
- video_conditioning=video_conditioning,
155
- enhance_prompt=enhance_prompt,
156
- tiling_config=tiling_config,
157
- conditioning_attention_strength=conditioning_attention_strength,
158
- skip_stage_2=skip_stage_2,
159
- conditioning_attention_mask=conditioning_attention_mask,
160
- )
161
 
162
  generator = torch.Generator(device=self.device).manual_seed(seed)
163
  noiser = GaussianNoiser(generator=generator)
164
  stepper = EulerDiffusionStep()
165
  dtype = torch.bfloat16
166
 
 
 
167
  (ctx_p,) = encode_prompts(
168
  [prompt],
169
  self.stage_1_model_ledger,
170
  enhance_first_prompt=enhance_prompt,
171
  enhance_prompt_image=images[0].path if len(images) > 0 else None,
172
- enhance_prompt_seed=seed,
173
  )
174
  video_context, audio_context = ctx_p.video_encoding, ctx_p.audio_encoding
175
 
176
- video_duration = num_frames / frame_rate
177
- decoded_audio = decode_audio_from_file(audio_path, self.device, 0.0, video_duration)
178
- if decoded_audio is None:
179
- raise ValueError(f"Could not extract audio stream from {audio_path}")
180
-
181
- encoded_audio_latent = vae_encode_audio(decoded_audio, self.stage_1_model_ledger.audio_encoder())
182
- audio_shape = AudioLatentShape.from_duration(batch=1, duration=video_duration, channels=8, mel_bins=16)
183
- expected_frames = audio_shape.frames
184
- actual_frames = encoded_audio_latent.shape[2]
185
-
186
- if actual_frames > expected_frames:
187
- encoded_audio_latent = encoded_audio_latent[:, :, :expected_frames, :]
188
- elif actual_frames < expected_frames:
189
- pad = torch.zeros(
190
- encoded_audio_latent.shape[0],
191
- encoded_audio_latent.shape[1],
192
- expected_frames - actual_frames,
193
- encoded_audio_latent.shape[3],
194
- device=encoded_audio_latent.device,
195
- dtype=encoded_audio_latent.dtype,
 
 
 
 
 
 
 
 
 
 
 
196
  )
197
- encoded_audio_latent = torch.cat([encoded_audio_latent, pad], dim=2)
198
 
199
- stage_1_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width // 2, height=height // 2, fps=frame_rate)
 
200
  video_encoder = self.stage_1_model_ledger.video_encoder()
201
- stage_1_conditionings = self._create_conditionings(
 
 
 
 
 
 
 
202
  images=images,
203
- video_conditioning=video_conditioning,
204
  height=stage_1_output_shape.height,
205
  width=stage_1_output_shape.width,
206
  video_encoder=video_encoder,
207
- num_frames=num_frames,
208
- conditioning_attention_strength=conditioning_attention_strength,
209
- conditioning_attention_mask=conditioning_attention_mask,
210
  )
211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  transformer = self.stage_1_model_ledger.transformer()
213
  stage_1_sigmas = torch.Tensor(DISTILLED_SIGMA_VALUES).to(self.device)
214
 
215
- def first_stage_denoising_loop(sigmas, video_state, audio_state, stepper):
216
  return euler_denoising_loop(
217
  sigmas=sigmas,
218
  video_state=video_state,
@@ -225,32 +349,39 @@ class AudioConditionedICLoraPipeline(ICLoraPipeline):
225
  ),
226
  )
227
 
228
- video_state, audio_state = denoise_audio_video(
229
- output_shape=stage_1_output_shape,
230
- conditionings=stage_1_conditionings,
231
- noiser=noiser,
232
- sigmas=stage_1_sigmas,
233
- stepper=stepper,
234
- denoising_loop_fn=first_stage_denoising_loop,
235
- components=self.pipeline_components,
236
- dtype=dtype,
237
- device=self.device,
238
- initial_audio_latent=encoded_audio_latent,
239
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
  torch.cuda.synchronize()
242
- del transformer
243
  cleanup_memory()
244
 
245
- if skip_stage_2:
246
- decoded_video = vae_decode_video(
247
- video_state.latent, self.stage_1_model_ledger.video_decoder(), tiling_config, generator
248
- )
249
- original_audio = Audio(waveform=decoded_audio.waveform.squeeze(0), sampling_rate=decoded_audio.sampling_rate)
250
- del video_encoder
251
- cleanup_memory()
252
- return decoded_video, original_audio
253
-
254
  upscaled_video_latent = upsample_video(
255
  latent=video_state.latent[:1],
256
  video_encoder=video_encoder,
@@ -260,10 +391,11 @@ class AudioConditionedICLoraPipeline(ICLoraPipeline):
260
  torch.cuda.synchronize()
261
  cleanup_memory()
262
 
263
- transformer = self.stage_2_model_ledger.transformer()
 
264
  stage_2_sigmas = torch.Tensor(STAGE_2_DISTILLED_SIGMA_VALUES).to(self.device)
265
 
266
- def second_stage_denoising_loop(sigmas, video_state, audio_state, stepper):
267
  return euler_denoising_loop(
268
  sigmas=sigmas,
269
  video_state=video_state,
@@ -272,275 +404,334 @@ class AudioConditionedICLoraPipeline(ICLoraPipeline):
272
  denoise_fn=simple_denoising_func(
273
  video_context=video_context,
274
  audio_context=audio_context,
275
- transformer=transformer,
276
  ),
277
  )
278
 
279
- stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
 
 
 
280
  stage_2_conditionings = combined_image_conditionings(
281
  images=images,
282
- #video_conditioning=video_conditioning,
283
  height=stage_2_output_shape.height,
284
  width=stage_2_output_shape.width,
285
  video_encoder=video_encoder,
286
  dtype=dtype,
287
  device=self.device,
288
- # num_frames=num_frames,
289
- # conditioning_attention_strength=conditioning_attention_strength,
290
- # conditioning_attention_mask=conditioning_attention_mask,
291
  )
292
 
293
- video_state, audio_state = denoise_audio_video(
294
- output_shape=stage_2_output_shape,
295
- conditionings=stage_2_conditionings,
296
- noiser=noiser,
297
- sigmas=stage_2_sigmas,
298
- stepper=stepper,
299
- denoising_loop_fn=second_stage_denoising_loop,
300
- components=self.pipeline_components,
301
- dtype=dtype,
302
- device=self.device,
303
- noise_scale=stage_2_sigmas[0],
304
- initial_video_latent=upscaled_video_latent,
305
- initial_audio_latent=encoded_audio_latent,
306
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
 
308
  torch.cuda.synchronize()
309
- del transformer
310
- del video_encoder
311
  cleanup_memory()
312
 
 
313
  decoded_video = vae_decode_video(
314
- video_state.latent, self.stage_2_model_ledger.video_decoder(), tiling_config, generator
 
 
 
315
  )
316
- original_audio = Audio(waveform=decoded_audio.waveform.squeeze(0), sampling_rate=decoded_audio.sampling_rate)
317
- return decoded_video, original_audio
318
-
319
-
320
- def build_pipeline(lora_name: str) -> AudioConditionedICLoraPipeline:
321
- """Build the fused IC-LoRA pipeline with optional audio conditioning."""
322
- pipe = AudioConditionedICLoraPipeline(
323
- distilled_checkpoint_path=checkpoint_path,
324
- spatial_upsampler_path=spatial_upsampler_path,
325
- gemma_root=gemma_root,
326
- loras=[],
327
- quantization=QuantizationPolicy.fp8_cast(),
328
- )
329
- return pipe
330
-
331
-
332
- def preload_pipeline(pipe: ICLoraPipeline) -> None:
333
- """Preload all models from both ledgers for ZeroGPU tensor packing."""
334
- print("Preloading stage 1 models (with IC-LoRA)...")
335
- import gc
336
-
337
- def cleanup_vram():
338
- gc.collect()
339
- if torch.cuda.is_available():
340
- torch.cuda.empty_cache()
341
-
342
- # Preload all models for ZeroGPU tensor packing.
343
- print("Preloading stage 1 models (with IC-LoRA)...")
344
- s1 = pipe.stage_1_model_ledger
345
-
346
- _s1_text_encoder = s1.text_encoder()
347
- _s1_embeddings_processor = s1.gemma_embeddings_processor()
348
- _s1_video_encoder = s1.video_encoder()
349
- cleanup_vram()
350
-
351
- _s1_transformer = s1.transformer()
352
- cleanup_vram()
353
-
354
- s1.transformer = lambda: _s1_transformer
355
- s1.video_encoder = lambda: _s1_video_encoder
356
- s1.text_encoder = lambda: _s1_text_encoder
357
- s1.gemma_embeddings_processor = lambda: _s1_embeddings_processor
358
-
359
- # Free stage 1 builders β€” we've cached the built models via lambdas
360
- if hasattr(s1, 'transformer_builder'): del s1.transformer_builder
361
- if hasattr(s1, 'vae_encoder_builder'): del s1.vae_encoder_builder
362
- if hasattr(s1, 'text_encoder_builder'): del s1.text_encoder_builder
363
- if hasattr(s1, 'embeddings_processor_builder'): del s1.embeddings_processor_builder
364
- cleanup_vram()
365
-
366
- print("Preloading stage 2 models (without IC-LoRA)...")
367
- s2 = pipe.stage_2_model_ledger
368
-
369
- _s2_video_encoder = s2.video_encoder()
370
- _s2_video_decoder = s2.video_decoder()
371
- _s2_audio_decoder = s2.audio_decoder()
372
- _s2_vocoder = s2.vocoder()
373
- _s2_spatial_upsampler = s2.spatial_upsampler()
374
- cleanup_vram()
375
-
376
- _s2_transformer = s2.transformer()
377
- cleanup_vram()
378
-
379
- s2.transformer = lambda: _s2_transformer
380
- s2.video_encoder = lambda: _s2_video_encoder
381
- s2.video_decoder = lambda: _s2_video_decoder
382
- s2.audio_decoder = lambda: _s2_audio_decoder
383
- s2.vocoder = lambda: _s2_vocoder
384
- s2.spatial_upsampler = lambda: _s2_spatial_upsampler
385
-
386
- # Free stage 2 builders
387
- if hasattr(s2, 'transformer_builder'): del s2.transformer_builder
388
- if hasattr(s2, 'vae_encoder_builder'): del s2.vae_encoder_builder
389
- if hasattr(s2, 'vae_decoder_builder'): del s2.vae_decoder_builder
390
- if hasattr(s2, 'audio_decoder_builder'): del s2.audio_decoder_builder
391
- if hasattr(s2, 'vocoder_builder'): del s2.vocoder_builder
392
- if hasattr(s2, 'upsampler_builder'): del s2.upsampler_builder
393
- cleanup_vram()
394
-
395
- print("All models preloaded!")
396
-
397
-
398
- print(f"Building initial pipeline with IC-LoRA: {default_lora_name}")
399
- current_pipeline = build_pipeline(default_lora_name)
400
- current_lora_name = default_lora_name
401
- preload_pipeline(current_pipeline)
402
 
403
  print("=" * 80)
404
- print("Pipeline ready!")
405
  print("=" * 80)
406
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
 
408
- def log_memory(tag: str):
409
- if torch.cuda.is_available():
410
- allocated = torch.cuda.memory_allocated() / 1024**3
411
- peak = torch.cuda.max_memory_allocated() / 1024**3
412
- free, total = torch.cuda.mem_get_info()
413
- print(f"[VRAM {tag}] allocated={allocated:.2f}GB peak={peak:.2f}GB free={free / 1024**3:.2f}GB total={total / 1024**3:.2f}GB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
 
415
 
416
- def detect_aspect_ratio(media) -> str:
 
 
 
417
  """Detect the closest aspect ratio from an image or video."""
418
- if media is None:
419
  return "16:9"
420
- if hasattr(media, "size"):
421
- w, h = media.size
422
- elif hasattr(media, "shape"):
423
- h, w = media.shape[:2]
 
 
 
 
 
 
 
424
  else:
425
- return "16:9"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
426
  ratio = w / h
427
  candidates = {"16:9": 16 / 9, "9:16": 9 / 16, "1:1": 1.0}
428
  return min(candidates, key=lambda k: abs(ratio - candidates[k]))
429
 
430
 
431
- def on_media_upload(input_image, high_res):
432
  """Auto-set resolution when image is uploaded."""
433
- aspect = detect_aspect_ratio(input_image)
 
 
 
 
 
 
 
 
 
 
434
  tier = "high" if high_res else "low"
435
  w, h = RESOLUTIONS[tier][aspect]
436
  return gr.update(value=w), gr.update(value=h)
437
 
438
 
439
- def on_highres_toggle(input_image, high_res):
440
  """Update resolution when high-res toggle changes."""
441
- aspect = detect_aspect_ratio(input_image)
 
442
  tier = "high" if high_res else "low"
443
  w, h = RESOLUTIONS[tier][aspect]
444
  return gr.update(value=w), gr.update(value=h)
445
 
446
 
447
- def extract_audio_from_video(video_path: str, output_dir: Path, seed: int) -> str:
448
- """Extract the audio track from a video into a wav file for audio conditioning."""
449
- audio_output = output_dir / f"extracted_audio_{seed}.wav"
450
- cmd = [
451
- "ffmpeg", "-y",
452
- "-i", str(video_path),
453
- "-vn",
454
- "-acodec", "pcm_s16le",
455
- "-ar", "48000",
456
- "-ac", "2",
457
- str(audio_output),
458
- ]
459
- result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
460
- if result.returncode != 0 or not audio_output.exists() or audio_output.stat().st_size == 0:
461
- raise gr.Error("Could not extract audio from the conditioning video. Make sure the uploaded video contains an audio track.")
462
- return str(audio_output)
463
-
464
-
465
- @spaces.GPU(duration=120, size="xlarge")
466
  @torch.inference_mode()
467
  def generate_video(
468
  input_image,
469
- conditioning_video,
470
  input_audio,
471
- use_video_audio,
472
  prompt: str,
473
  duration: float,
474
- ic_lora_choice: str,
475
  conditioning_strength: float,
476
  enhance_prompt: bool,
477
- skip_stage_2: bool,
478
  seed: int,
479
  randomize_seed: bool,
480
  height: int,
481
  width: int,
482
  progress=gr.Progress(track_tqdm=True),
483
  ):
484
- global current_pipeline, current_lora_name
485
-
486
  try:
487
  torch.cuda.reset_peak_memory_stats()
488
- log_memory("start")
489
-
490
- # Rebuild pipeline if IC-LoRA changed
491
- if ic_lora_choice != current_lora_name:
492
- print(f"Switching IC-LoRA: {current_lora_name} β†’ {ic_lora_choice}")
493
- current_pipeline = build_pipeline(ic_lora_choice)
494
- current_lora_name = ic_lora_choice
495
- preload_pipeline(current_pipeline)
496
-
497
  current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
498
 
499
  frame_rate = DEFAULT_FRAME_RATE
500
  num_frames = int(duration * frame_rate) + 1
501
  num_frames = ((num_frames - 1 + 7) // 8) * 8 + 1
502
 
503
- print(f"Generating: {height}x{width}, {num_frames} frames ({duration}s), seed={current_seed}")
504
- print(f"IC-LoRA: {ic_lora_choice}, conditioning_strength: {conditioning_strength}")
 
 
 
505
  if input_audio is not None:
506
- print(f"Audio conditioning: {input_audio}")
 
 
 
507
 
508
- output_dir = Path("outputs")
509
- output_dir.mkdir(exist_ok=True)
510
 
511
- # Image conditioning (optional, for I2V)
512
  images = []
513
  if input_image is not None:
514
- temp_image_path = output_dir / f"temp_input_{current_seed}.jpg"
515
- if hasattr(input_image, "save"):
516
- input_image.save(temp_image_path)
517
- else:
518
- temp_image_path = Path(input_image)
519
- images.append(ImageConditioningInput(path=str(temp_image_path), frame_idx=0, strength=1.0))
520
-
521
- # Video conditioning for IC-LoRA (reference video)
522
- video_conditioning = []
523
- video_path = None
524
- if conditioning_video is not None:
525
- video_path = str(conditioning_video)
526
- video_conditioning.append((video_path, conditioning_strength))
527
- print(f"Video conditioning: {video_path} (strength={conditioning_strength})")
528
-
529
- audio_path = input_audio
530
- if use_video_audio:
531
- if video_path is None:
532
- raise gr.Error("Enable 'Use audio from conditioning video' only when a conditioning video is uploaded.")
533
- audio_path = extract_audio_from_video(video_path, output_dir, current_seed)
534
- print(f"Extracted audio from conditioning video: {audio_path}")
535
- elif audio_path is not None:
536
- print(f"Using uploaded audio file: {audio_path}")
537
 
538
  tiling_config = TilingConfig.default()
539
  video_chunks_number = get_video_chunks_number(num_frames, tiling_config)
540
 
541
- log_memory("before pipeline call")
542
-
543
- video, audio = current_pipeline(
544
  prompt=prompt,
545
  seed=current_seed,
546
  height=int(height),
@@ -548,16 +739,13 @@ def generate_video(
548
  num_frames=num_frames,
549
  frame_rate=frame_rate,
550
  images=images,
 
551
  video_conditioning=video_conditioning,
552
- audio_path=audio_path,
553
  tiling_config=tiling_config,
554
  enhance_prompt=enhance_prompt,
555
- conditioning_attention_strength=conditioning_strength,
556
- skip_stage_2=skip_stage_2,
557
  )
558
 
559
- log_memory("after pipeline call")
560
-
561
  output_path = tempfile.mktemp(suffix=".mp4")
562
  encode_video(
563
  video=video,
@@ -567,74 +755,73 @@ def generate_video(
567
  video_chunks_number=video_chunks_number,
568
  )
569
 
570
- log_memory("after encode_video")
571
  return str(output_path), current_seed
572
 
573
- except gr.Error:
574
- raise
575
  except Exception as e:
576
  import traceback
577
- log_memory("on error")
578
  print(f"Error: {str(e)}\n{traceback.format_exc()}")
579
  return None, current_seed
580
 
581
 
582
- with gr.Blocks(title="LTX-2.3 IC-LoRA") as demo:
583
- gr.Markdown("# LTX-2.3 IC-LoRA: Video-to-Video + Audio Conditioning")
 
 
 
584
  gr.Markdown(
585
- "Video-to-video transformations using IC-LoRA conditioning with optional audio-driven generation. "
586
- "Upload a **conditioning video** as the IC-LoRA reference signal, optionally add an **input audio** file "
587
- "to preserve soundtrack or lip-sync timing, optionally provide an input image for I2V, and describe the desired output. "
588
  "[[model]](https://huggingface.co/Lightricks/LTX-2.3) "
589
  "[[code]](https://github.com/Lightricks/LTX-2)"
590
  )
591
 
592
  with gr.Row():
593
  with gr.Column():
 
594
  with gr.Row():
595
- conditioning_video = gr.Video(
596
- label="Conditioning Video (IC-LoRA Reference)",
597
-
 
 
 
 
598
  )
599
- input_image = gr.Image(label="Input Image (Optional)", type="pil")
600
- input_audio = gr.Audio(label="Input Audio (Optional)", type="filepath")
601
- use_video_audio = gr.Checkbox(
602
- label="Use audio from reference video",
603
- value=False,
604
- info="extracts the audio track from the uploaded video",
605
  )
 
606
  prompt = gr.Textbox(
607
  label="Prompt",
608
- info="Describe the desired output β€” the IC-LoRA controls structure from the reference",
609
- value="A cinematic scene with dramatic lighting and rich detail, smooth motion",
610
  lines=3,
611
- placeholder="Describe the video you want to generate...",
612
  )
613
 
614
  with gr.Row():
615
- duration = gr.Slider(label="Duration (seconds)", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
616
- ic_lora_choice = gr.Dropdown(
617
- label="IC-LoRA",
618
- choices=list(IC_LORA_OPTIONS.keys()),
619
- value=default_lora_name,
620
- visible=False
 
 
621
  )
622
 
623
-
 
 
 
624
  generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
625
 
626
  with gr.Accordion("Advanced Settings", open=False):
627
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, value=42, step=1)
 
 
628
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
629
- with gr.Row():
630
- conditioning_strength = gr.Slider(
631
- label="Conditioning Strength", minimum=0.1, maximum=1.0, value=1.0, step=0.05,
632
- )
633
- with gr.Column():
634
- enhance_prompt = gr.Checkbox(label="Enhance Prompt", value=False)
635
- high_res = gr.Checkbox(label="High Resolution", value=True)
636
- skip_stage_2 = gr.Checkbox(label="Skip Stage 2 (faster, half res)", value=False)
637
-
638
  with gr.Row():
639
  width = gr.Number(label="Width", value=1536, precision=0)
640
  height = gr.Number(label="Height", value=1024, precision=0)
@@ -642,33 +829,32 @@ with gr.Blocks(title="LTX-2.3 IC-LoRA") as demo:
642
  with gr.Column():
643
  output_video = gr.Video(label="Generated Video", autoplay=True)
644
 
645
- # Auto-detect aspect ratio from uploaded image
646
  input_image.change(
647
- fn=on_media_upload,
648
- inputs=[input_image, high_res],
 
 
 
 
 
649
  outputs=[width, height],
650
  )
651
  high_res.change(
652
  fn=on_highres_toggle,
653
- inputs=[input_image, high_res],
654
  outputs=[width, height],
655
  )
656
-
657
  generate_btn.click(
658
  fn=generate_video,
659
  inputs=[
660
- input_image, conditioning_video,
661
- input_audio, use_video_audio, prompt, duration, ic_lora_choice, conditioning_strength,
662
- enhance_prompt, skip_stage_2,
663
  seed, randomize_seed, height, width,
664
  ],
665
  outputs=[output_video, seed],
666
  )
667
 
668
 
669
- css = """
670
- .fillable{max-width: 1200px !important}
671
- """
672
-
673
  if __name__ == "__main__":
674
- demo.launch(theme=gr.themes.Citrus(), css=css)
 
10
  subprocess.run([sys.executable, "-m", "pip", "install", "xformers==0.0.32.post2", "--no-build-isolation"], check=False)
11
 
12
  # Clone LTX-2 repo and install packages
13
+ LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
 
14
  LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
15
 
16
+ if not os.path.exists(LTX_REPO_DIR):
17
+ print(f"Cloning {LTX_REPO_URL}...")
18
+ subprocess.run(["git", "clone", "--depth", "1", LTX_REPO_URL, LTX_REPO_DIR], check=True)
 
 
 
 
 
19
 
20
  print("Installing ltx-core and ltx-pipelines from cloned repo...")
21
  subprocess.run(
 
25
  check=True,
26
  )
27
 
 
28
  sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines", "src"))
29
  sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-core", "src"))
30
 
 
41
  import gradio as gr
42
  import numpy as np
43
  from huggingface_hub import hf_hub_download, snapshot_download
44
+ from safetensors import safe_open
45
 
 
 
46
  from ltx_core.components.diffusion_steps import EulerDiffusionStep
47
  from ltx_core.components.noisers import GaussianNoiser
48
+ from ltx_core.conditioning import (
49
+ ConditioningItem,
50
+ ConditioningItemAttentionStrengthWrapper,
51
+ VideoConditionByReferenceLatent,
52
+ )
53
+ from ltx_core.loader import LoraPathStrengthAndSDOps, LTXV_LORA_COMFY_RENAMING_MAP
54
+ from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
55
  from ltx_core.model.audio_vae import encode_audio as vae_encode_audio
56
  from ltx_core.model.upsampler import upsample_video
57
+ from ltx_core.model.video_vae import TilingConfig, VideoEncoder, get_video_chunks_number
58
  from ltx_core.model.video_vae import decode_video as vae_decode_video
59
+ from ltx_core.quantization import QuantizationPolicy
60
+ from ltx_core.types import Audio, AudioLatentShape, LatentState, VideoLatentShape, VideoPixelShape
61
+ from ltx_pipelines.utils import ModelLedger, euler_denoising_loop
 
62
  from ltx_pipelines.utils.args import ImageConditioningInput
63
  from ltx_pipelines.utils.constants import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
64
+ from ltx_pipelines.utils.helpers import (
65
+ assert_resolution,
66
+ cleanup_memory,
67
+ combined_image_conditionings,
68
+ denoise_audio_video,
69
+ denoise_video_only,
70
+ encode_prompts,
71
+ get_device,
72
+ simple_denoising_func,
73
+ )
74
+ from ltx_pipelines.utils.media_io import (
75
+ decode_audio_from_file,
76
+ encode_video,
77
+ load_video_conditioning,
78
+ )
79
+ from ltx_pipelines.utils.types import PipelineComponents
80
 
81
  # Force-patch xformers attention into the LTX attention module.
82
  from ltx_core.model.transformer import attention as _attn_mod
 
90
 
91
  logging.getLogger().setLevel(logging.INFO)
92
 
 
 
 
 
 
 
 
 
93
 
94
+ # ─────────────────────────────────────────────────────────────────────────────
95
+ # Helper: read reference downscale factor from IC-LoRA metadata
96
+ # ─────────────────────────────────────────────────────────────────────────────
97
+ def _read_lora_reference_downscale_factor(lora_path: str) -> int:
98
+ try:
99
+ with safe_open(lora_path, framework="pt") as f:
100
+ metadata = f.metadata() or {}
101
+ return int(metadata.get("reference_downscale_factor", 1))
102
+ except Exception as e:
103
+ logging.warning(f"Failed to read metadata from LoRA file '{lora_path}': {e}")
104
+ return 1
105
+
106
+
107
+ # ─────────────────────────────────────────────────────────────────────────────
108
+ # Unified Pipeline: Distilled + Audio + IC-LoRA Video-to-Video
109
+ # ─────────────────────────────────────────────────────────────────────────────
110
+ class LTX23UnifiedPipeline:
111
+ """
112
+ Unified LTX-2.3 pipeline supporting all generation modes:
113
+ β€’ Text-to-Video
114
+ β€’ Image-to-Video (first-frame conditioning)
115
+ β€’ Audio-to-Video (lip-sync / BGM conditioning with external audio)
116
+ β€’ Video-to-Video (IC-LoRA reference video conditioning)
117
+ β€’ Any combination of the above
118
+
119
+ Architecture:
120
+ - stage_1_model_ledger: transformer WITH IC-LoRA fused (used for Stage 1)
121
+ - stage_2_model_ledger: transformer WITHOUT IC-LoRA (used for Stage 2 upsampling)
122
+ - When no IC-LoRA is provided, both stages use the same base model.
123
+ """
124
+
125
+ def __init__(
126
+ self,
127
+ distilled_checkpoint_path: str,
128
+ spatial_upsampler_path: str,
129
+ gemma_root: str,
130
+ ic_loras: list[LoraPathStrengthAndSDOps] | None = None,
131
+ device: torch.device | None = None,
132
+ quantization: QuantizationPolicy | None = None,
133
+ ):
134
+ self.device = device or get_device()
135
+ self.dtype = torch.bfloat16
136
 
137
+ ic_loras = ic_loras or []
138
+ self.has_ic_lora = len(ic_loras) > 0
 
 
 
 
 
 
 
 
 
139
 
140
+ # Stage 1: transformer with IC-LoRA (if provided)
141
+ self.stage_1_model_ledger = ModelLedger(
142
+ dtype=self.dtype,
143
+ device=self.device,
144
+ checkpoint_path=distilled_checkpoint_path,
145
+ spatial_upsampler_path=spatial_upsampler_path,
146
+ gemma_root_path=gemma_root,
147
+ loras=ic_loras,
148
+ quantization=quantization,
149
+ )
150
 
151
+ if self.has_ic_lora:
152
+ # Stage 2 needs a separate transformer WITHOUT IC-LoRA
153
+ self.stage_2_model_ledger = ModelLedger(
154
+ dtype=self.dtype,
155
+ device=self.device,
156
+ checkpoint_path=distilled_checkpoint_path,
157
+ spatial_upsampler_path=spatial_upsampler_path,
158
+ gemma_root_path=gemma_root,
159
+ loras=[],
160
+ quantization=quantization,
161
+ )
162
+ else:
163
+ # No IC-LoRA: share a single ledger for both stages (saves ~half VRAM)
164
+ self.stage_2_model_ledger = self.stage_1_model_ledger
165
 
166
+ self.pipeline_components = PipelineComponents(
167
+ dtype=self.dtype,
168
+ device=self.device,
169
+ )
170
 
171
+ # Read reference downscale factor from IC-LoRA metadata
172
+ self.reference_downscale_factor = 1
173
+ for lora in ic_loras:
174
+ scale = _read_lora_reference_downscale_factor(lora.path)
175
+ if scale != 1:
176
+ if self.reference_downscale_factor not in (1, scale):
177
+ raise ValueError(
178
+ f"Conflicting reference_downscale_factor: "
179
+ f"already {self.reference_downscale_factor}, got {scale}"
180
+ )
181
+ self.reference_downscale_factor = scale
182
 
183
+ # ── Video reference conditioning (from ICLoraPipeline) ───────────────
184
+ def _create_ic_conditionings(
185
+ self,
186
+ video_conditioning: list[tuple[str, float]],
187
+ height: int,
188
+ width: int,
189
+ num_frames: int,
190
+ video_encoder: VideoEncoder,
191
+ conditioning_strength: float = 1.0,
192
+ ) -> list[ConditioningItem]:
193
+ """Create IC-LoRA video reference conditioning items."""
194
+ conditionings: list[ConditioningItem] = []
195
+ scale = self.reference_downscale_factor
196
+ ref_height = height // scale
197
+ ref_width = width // scale
198
+
199
+ for video_path, strength in video_conditioning:
200
+ video = load_video_conditioning(
201
+ video_path=video_path,
202
+ height=ref_height,
203
+ width=ref_width,
204
+ frame_cap=num_frames,
205
+ dtype=self.dtype,
206
+ device=self.device,
207
+ )
208
+ encoded_video = video_encoder(video)
209
 
210
+ cond = VideoConditionByReferenceLatent(
211
+ latent=encoded_video,
212
+ downscale_factor=scale,
213
+ strength=strength,
214
+ )
215
+ if conditioning_strength < 1.0:
216
+ cond = ConditioningItemAttentionStrengthWrapper(
217
+ cond, attention_mask=conditioning_strength
218
+ )
219
+ conditionings.append(cond)
220
 
221
+ if conditionings:
222
+ logging.info(f"[IC-LoRA] Added {len(conditionings)} video conditioning(s)")
223
+ return conditionings
224
 
225
+ # ── Main generation entry point ──────────────────────────────────────
226
  def __call__(
227
  self,
228
  prompt: str,
 
232
  num_frames: int,
233
  frame_rate: float,
234
  images: list[ImageConditioningInput],
 
235
  audio_path: str | None = None,
236
+ video_conditioning: list[tuple[str, float]] | None = None,
237
  tiling_config: TilingConfig | None = None,
238
+ enhance_prompt: bool = False,
239
+ conditioning_strength: float = 1.0,
 
240
  ):
241
+ """
242
+ Generate video with any combination of conditioning.
243
+
244
+ Args:
245
+ audio_path: Path to external audio file for lipsync/BGM conditioning.
246
+ video_conditioning: List of (path, strength) tuples for IC-LoRA V2V.
247
+ conditioning_strength: Scale for IC-LoRA attention influence [0, 1].
248
+ Returns:
249
+ Tuple of (decoded_video_iterator, Audio).
250
+ """
251
+ assert_resolution(height=height, width=width, is_two_stage=True)
252
+
253
+ has_audio = audio_path is not None
254
+ has_video_cond = bool(video_conditioning)
 
 
255
 
256
  generator = torch.Generator(device=self.device).manual_seed(seed)
257
  noiser = GaussianNoiser(generator=generator)
258
  stepper = EulerDiffusionStep()
259
  dtype = torch.bfloat16
260
 
261
+ # ── Encode text prompt ───────────────────────────────────────────
262
+ # Use stage_1 ledger for prompt encoding (has text encoder)
263
  (ctx_p,) = encode_prompts(
264
  [prompt],
265
  self.stage_1_model_ledger,
266
  enhance_first_prompt=enhance_prompt,
267
  enhance_prompt_image=images[0].path if len(images) > 0 else None,
 
268
  )
269
  video_context, audio_context = ctx_p.video_encoding, ctx_p.audio_encoding
270
 
271
+ # ── Encode external audio (if provided) ─────────────────────────
272
+ encoded_audio_latent = None
273
+ decoded_audio_for_output = None
274
+ if has_audio:
275
+ video_duration = num_frames / frame_rate
276
+ decoded_audio = decode_audio_from_file(audio_path, self.device, 0.0, video_duration)
277
+ if decoded_audio is None:
278
+ raise ValueError(f"Could not extract audio stream from {audio_path}")
279
+
280
+ encoded_audio_latent = vae_encode_audio(
281
+ decoded_audio, self.stage_1_model_ledger.audio_encoder()
282
+ )
283
+ audio_shape = AudioLatentShape.from_duration(
284
+ batch=1, duration=video_duration, channels=8, mel_bins=16
285
+ )
286
+ expected_frames = audio_shape.frames
287
+ actual_frames = encoded_audio_latent.shape[2]
288
+
289
+ if actual_frames > expected_frames:
290
+ encoded_audio_latent = encoded_audio_latent[:, :, :expected_frames, :]
291
+ elif actual_frames < expected_frames:
292
+ pad = torch.zeros(
293
+ encoded_audio_latent.shape[0], encoded_audio_latent.shape[1],
294
+ expected_frames - actual_frames, encoded_audio_latent.shape[3],
295
+ device=encoded_audio_latent.device, dtype=encoded_audio_latent.dtype,
296
+ )
297
+ encoded_audio_latent = torch.cat([encoded_audio_latent, pad], dim=2)
298
+
299
+ decoded_audio_for_output = Audio(
300
+ waveform=decoded_audio.waveform.squeeze(0),
301
+ sampling_rate=decoded_audio.sampling_rate,
302
  )
 
303
 
304
+ # ── Build conditionings for Stage 1 ──────────────────────────────
305
+ # Use stage_1 video encoder (has IC-LoRA context)
306
  video_encoder = self.stage_1_model_ledger.video_encoder()
307
+
308
+ stage_1_output_shape = VideoPixelShape(
309
+ batch=1, frames=num_frames,
310
+ width=width // 2, height=height // 2, fps=frame_rate,
311
+ )
312
+
313
+ # Image conditionings
314
+ stage_1_conditionings = combined_image_conditionings(
315
  images=images,
 
316
  height=stage_1_output_shape.height,
317
  width=stage_1_output_shape.width,
318
  video_encoder=video_encoder,
319
+ dtype=dtype,
320
+ device=self.device,
 
321
  )
322
 
323
+ # IC-LoRA video reference conditionings
324
+ if has_video_cond:
325
+ ic_conds = self._create_ic_conditionings(
326
+ video_conditioning=video_conditioning,
327
+ height=stage_1_output_shape.height,
328
+ width=stage_1_output_shape.width,
329
+ num_frames=num_frames,
330
+ video_encoder=video_encoder,
331
+ conditioning_strength=conditioning_strength,
332
+ )
333
+ stage_1_conditionings.extend(ic_conds)
334
+
335
+ # ── Stage 1: Low-res generation ──────────────────────────────────
336
  transformer = self.stage_1_model_ledger.transformer()
337
  stage_1_sigmas = torch.Tensor(DISTILLED_SIGMA_VALUES).to(self.device)
338
 
339
+ def denoising_loop(sigmas, video_state, audio_state, stepper):
340
  return euler_denoising_loop(
341
  sigmas=sigmas,
342
  video_state=video_state,
 
349
  ),
350
  )
351
 
352
+ if has_audio:
353
+ # Audio mode: denoise video only, use external audio latent
354
+ video_state = denoise_video_only(
355
+ output_shape=stage_1_output_shape,
356
+ conditionings=stage_1_conditionings,
357
+ noiser=noiser,
358
+ sigmas=stage_1_sigmas,
359
+ stepper=stepper,
360
+ denoising_loop_fn=denoising_loop,
361
+ components=self.pipeline_components,
362
+ dtype=dtype,
363
+ device=self.device,
364
+ initial_audio_latent=encoded_audio_latent,
365
+ )
366
+ audio_state = None # we'll use the original audio for output
367
+ else:
368
+ # Standard / IC-only mode: denoise both audio and video
369
+ video_state, audio_state = denoise_audio_video(
370
+ output_shape=stage_1_output_shape,
371
+ conditionings=stage_1_conditionings,
372
+ noiser=noiser,
373
+ sigmas=stage_1_sigmas,
374
+ stepper=stepper,
375
+ denoising_loop_fn=denoising_loop,
376
+ components=self.pipeline_components,
377
+ dtype=dtype,
378
+ device=self.device,
379
+ )
380
 
381
  torch.cuda.synchronize()
 
382
  cleanup_memory()
383
 
384
+ # ── Stage 2: Upsample + Refine ──────────────────────────────────
 
 
 
 
 
 
 
 
385
  upscaled_video_latent = upsample_video(
386
  latent=video_state.latent[:1],
387
  video_encoder=video_encoder,
 
391
  torch.cuda.synchronize()
392
  cleanup_memory()
393
 
394
+ # Stage 2 uses the transformer WITHOUT IC-LoRA
395
+ transformer_s2 = self.stage_2_model_ledger.transformer()
396
  stage_2_sigmas = torch.Tensor(STAGE_2_DISTILLED_SIGMA_VALUES).to(self.device)
397
 
398
+ def denoising_loop_s2(sigmas, video_state, audio_state, stepper):
399
  return euler_denoising_loop(
400
  sigmas=sigmas,
401
  video_state=video_state,
 
404
  denoise_fn=simple_denoising_func(
405
  video_context=video_context,
406
  audio_context=audio_context,
407
+ transformer=transformer_s2,
408
  ),
409
  )
410
 
411
+ stage_2_output_shape = VideoPixelShape(
412
+ batch=1, frames=num_frames,
413
+ width=width, height=height, fps=frame_rate,
414
+ )
415
  stage_2_conditionings = combined_image_conditionings(
416
  images=images,
 
417
  height=stage_2_output_shape.height,
418
  width=stage_2_output_shape.width,
419
  video_encoder=video_encoder,
420
  dtype=dtype,
421
  device=self.device,
 
 
 
422
  )
423
 
424
+ if has_audio:
425
+ video_state = denoise_video_only(
426
+ output_shape=stage_2_output_shape,
427
+ conditionings=stage_2_conditionings,
428
+ noiser=noiser,
429
+ sigmas=stage_2_sigmas,
430
+ stepper=stepper,
431
+ denoising_loop_fn=denoising_loop_s2,
432
+ components=self.pipeline_components,
433
+ dtype=dtype,
434
+ device=self.device,
435
+ noise_scale=stage_2_sigmas[0],
436
+ initial_video_latent=upscaled_video_latent,
437
+ initial_audio_latent=encoded_audio_latent,
438
+ )
439
+ audio_state = None
440
+ else:
441
+ video_state, audio_state = denoise_audio_video(
442
+ output_shape=stage_2_output_shape,
443
+ conditionings=stage_2_conditionings,
444
+ noiser=noiser,
445
+ sigmas=stage_2_sigmas,
446
+ stepper=stepper,
447
+ denoising_loop_fn=denoising_loop_s2,
448
+ components=self.pipeline_components,
449
+ dtype=dtype,
450
+ device=self.device,
451
+ noise_scale=stage_2_sigmas[0],
452
+ initial_video_latent=upscaled_video_latent,
453
+ initial_audio_latent=audio_state.latent,
454
+ )
455
 
456
  torch.cuda.synchronize()
457
+ del transformer, transformer_s2, video_encoder
 
458
  cleanup_memory()
459
 
460
+ # ── Decode ───────────────────────────────────────────────────────
461
  decoded_video = vae_decode_video(
462
+ video_state.latent,
463
+ self.stage_2_model_ledger.video_decoder(),
464
+ tiling_config,
465
+ generator,
466
  )
467
+
468
+ if has_audio:
469
+ output_audio = decoded_audio_for_output
470
+ else:
471
+ output_audio = vae_decode_audio(
472
+ audio_state.latent,
473
+ self.stage_2_model_ledger.audio_decoder(),
474
+ self.stage_2_model_ledger.vocoder(),
475
+ )
476
+
477
+ return decoded_video, output_audio
478
+
479
+
480
# ─────────────────────────────────────────────────────────────────────────────
# Constants
# ─────────────────────────────────────────────────────────────────────────────
# Upper bound for the seed slider / RNG seeding (fits in a signed 32-bit int).
MAX_SEED = np.iinfo(np.int32).max
DEFAULT_PROMPT = (
    "An astronaut hatches from a fragile egg on the surface of the Moon, "
    "the shell cracking and peeling apart in gentle low-gravity motion."
)
DEFAULT_FRAME_RATE = 24.0

# (width, height) presets per quality tier and aspect ratio.
RESOLUTIONS = {
    "high": {"16:9": (1536, 1024), "9:16": (1024, 1536), "1:1": (1024, 1024)},
    "low": {"16:9": (768, 512), "9:16": (512, 768), "1:1": (768, 768)},
}

# Available IC-LoRA models: display name -> HF Hub repo + weights filename.
IC_LORA_OPTIONS = {
    "Union Control (Depth + Edge)": {
        "repo": "Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control",
        "filename": "ltx-2.3-22b-ic-lora-union-control-ref0.5.safetensors",
    },
    "Motion Track Control": {
        "repo": "Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control",
        "filename": "ltx-2.3-22b-ic-lora-motion-track-control-ref0.5.safetensors",
    },
}
DEFAULT_IC_LORA = "Union Control (Depth + Edge)"
+
508
+
509
+ # ─────────────────────────────────────────────────────────────────────────────
510
+ # Download Models
511
+ # ─────────────────────────────────────────────────────────────────────────────
512
+ LTX_MODEL_REPO = "Lightricks/LTX-2.3"
513
+ GEMMA_REPO = "google/gemma-3-12b-it-qat-q4_0-unquantized"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
514
 
515
  print("=" * 80)
516
+ print("Downloading LTX-2.3 distilled model + Gemma + IC-LoRA...")
517
  print("=" * 80)
518
 
519
+ checkpoint_path = hf_hub_download(
520
+ repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-distilled.safetensors"
521
+ )
522
+ spatial_upsampler_path = hf_hub_download(
523
+ repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors"
524
+ )
525
+ gemma_root = snapshot_download(repo_id=GEMMA_REPO)
526
+
527
+ # Download default IC-LoRA
528
+ default_lora_info = IC_LORA_OPTIONS[DEFAULT_IC_LORA]
529
+ default_ic_lora_path = hf_hub_download(
530
+ repo_id=default_lora_info["repo"], filename=default_lora_info["filename"]
531
+ )
532
+
533
+ print(f"Checkpoint: {checkpoint_path}")
534
+ print(f"Spatial upsampler: {spatial_upsampler_path}")
535
+ print(f"Gemma root: {gemma_root}")
536
+ print(f"IC-LoRA: {default_ic_lora_path}")
537
+
538
+
539
+ # ─────────────────────────────────────────────────────────────────────────────
540
+ # Initialize Pipeline
541
+ # ─────────────────────────────────────────────────────────────────────────────
542
+ ic_loras = [
543
+ LoraPathStrengthAndSDOps(default_ic_lora_path, 1.0, LTXV_LORA_COMFY_RENAMING_MAP)
544
+ ]
545
+
546
+ pipeline = LTX23UnifiedPipeline(
547
+ distilled_checkpoint_path=checkpoint_path,
548
+ spatial_upsampler_path=spatial_upsampler_path,
549
+ gemma_root=gemma_root,
550
+ ic_loras=ic_loras,
551
+ quantization=QuantizationPolicy.fp8_cast(),
552
+ )
553
 
554
# Preload all models for ZeroGPU tensor packing.
print("Preloading all models (including Gemma, Audio encoders)...")

# Shared ledger: preload once. Separate ledgers (IC-LoRA): preload both.
_ledger_1 = pipeline.stage_1_model_ledger
_ledger_2 = pipeline.stage_2_model_ledger
_shared = _ledger_1 is _ledger_2

# Stage 1 models (with IC-LoRA if loaded).
# Pattern: call each ledger accessor once to materialize the model, then
# monkey-patch the accessor with a lambda returning the cached instance, so
# later accessor calls inside the pipeline reuse the preloaded object.
# NOTE(review): assumes the accessors are plain zero-arg methods — confirm
# against ModelLedger.
_s1_transformer = _ledger_1.transformer()
_s1_video_encoder = _ledger_1.video_encoder()
_s1_text_encoder = _ledger_1.text_encoder()
_s1_embeddings = _ledger_1.gemma_embeddings_processor()
_s1_audio_encoder = _ledger_1.audio_encoder()

_ledger_1.transformer = lambda: _s1_transformer
_ledger_1.video_encoder = lambda: _s1_video_encoder
_ledger_1.text_encoder = lambda: _s1_text_encoder
_ledger_1.gemma_embeddings_processor = lambda: _s1_embeddings
_ledger_1.audio_encoder = lambda: _s1_audio_encoder

if _shared:
    # Single ledger — also preload decoder/upsampler/vocoder on the same object
    _video_decoder = _ledger_1.video_decoder()
    _audio_decoder = _ledger_1.audio_decoder()
    _vocoder = _ledger_1.vocoder()
    _spatial_upsampler = _ledger_1.spatial_upsampler()

    _ledger_1.video_decoder = lambda: _video_decoder
    _ledger_1.audio_decoder = lambda: _audio_decoder
    _ledger_1.vocoder = lambda: _vocoder
    _ledger_1.spatial_upsampler = lambda: _spatial_upsampler
    print("  (single shared ledger — no IC-LoRA)")
else:
    # Stage 2 models (separate transformer without IC-LoRA)
    _s2_transformer = _ledger_2.transformer()
    _s2_video_encoder = _ledger_2.video_encoder()
    _s2_video_decoder = _ledger_2.video_decoder()
    _s2_audio_decoder = _ledger_2.audio_decoder()
    _s2_vocoder = _ledger_2.vocoder()
    _s2_spatial_upsampler = _ledger_2.spatial_upsampler()
    _s2_text_encoder = _ledger_2.text_encoder()
    _s2_embeddings = _ledger_2.gemma_embeddings_processor()
    _s2_audio_encoder = _ledger_2.audio_encoder()

    _ledger_2.transformer = lambda: _s2_transformer
    _ledger_2.video_encoder = lambda: _s2_video_encoder
    _ledger_2.video_decoder = lambda: _s2_video_decoder
    _ledger_2.audio_decoder = lambda: _s2_audio_decoder
    _ledger_2.vocoder = lambda: _s2_vocoder
    _ledger_2.spatial_upsampler = lambda: _s2_spatial_upsampler
    _ledger_2.text_encoder = lambda: _s2_text_encoder
    _ledger_2.gemma_embeddings_processor = lambda: _s2_embeddings
    _ledger_2.audio_encoder = lambda: _s2_audio_encoder
    print("  (two separate ledgers — IC-LoRA active)")

print("All models preloaded!")
print("=" * 80)
612
 
613
 
614
# ─────────────────────────────────────────────────────────────────────────────
# UI Helpers
# ─────────────────────────────────────────────────────────────────────────────
_IMAGE_EXTENSIONS = ("jpg", "jpeg", "png", "bmp", "webp", "gif", "tiff")


def _probe_image_size(media_path):
    """Return (width, height) of an image, or None if it cannot be read.

    The import is inside the try so a missing Pillow installation degrades
    to the "cannot read" path instead of raising ImportError.
    """
    try:
        import PIL.Image
        with PIL.Image.open(media_path) as img:
            return img.size
    except Exception:
        return None


def _probe_video_size(media_path):
    """Return (width, height) of a video's first stream, or None on failure."""
    try:
        import av
        with av.open(str(media_path)) as container:
            stream = container.streams.video[0]
            return stream.codec_context.width, stream.codec_context.height
    except Exception:
        return None


def detect_aspect_ratio(media_path) -> str:
    """Detect the closest aspect ratio from an image or video.

    Args:
        media_path: Path to an image or video file, or None.

    Returns:
        One of "16:9", "9:16" or "1:1" — whichever is nearest to the media's
        width/height ratio. Defaults to "16:9" when the path is None or the
        file cannot be probed.
    """
    if media_path is None:
        return "16:9"

    ext = str(media_path).lower().rsplit(".", 1)[-1] if "." in str(media_path) else ""

    if ext in _IMAGE_EXTENSIONS:
        size = _probe_image_size(media_path)
    else:
        # Unknown extension: try video first, then fall back to image probing.
        size = _probe_video_size(media_path) or _probe_image_size(media_path)

    if size is None:
        return "16:9"

    w, h = size
    ratio = w / h
    candidates = {"16:9": 16 / 9, "9:16": 9 / 16, "1:1": 1.0}
    return min(candidates, key=lambda k: abs(ratio - candidates[k]))
651
 
652
 
653
def on_image_upload(image, video, high_res):
    """Auto-set resolution when image is uploaded."""
    # Prefer the freshly uploaded image; otherwise probe any existing video.
    if image is not None:
        media = image
    else:
        media = video
    quality = "high" if high_res else "low"
    new_width, new_height = RESOLUTIONS[quality][detect_aspect_ratio(media)]
    return gr.update(value=new_width), gr.update(value=new_height)
660
+
661
+
662
def on_video_upload(video, image, high_res):
    """Auto-set resolution when video is uploaded."""
    # The just-uploaded video wins; an already-present image is the fallback.
    media = video
    if media is None:
        media = image
    quality = "high" if high_res else "low"
    new_width, new_height = RESOLUTIONS[quality][detect_aspect_ratio(media)]
    return gr.update(value=new_width), gr.update(value=new_height)
669
 
670
 
671
def on_highres_toggle(image, video, high_res):
    """Update resolution when high-res toggle changes."""
    # Re-derive aspect from whichever media is present (image takes priority).
    media = image
    if media is None:
        media = video
    quality = "high" if high_res else "low"
    new_width, new_height = RESOLUTIONS[quality][detect_aspect_ratio(media)]
    return gr.update(value=new_width), gr.update(value=new_height)
678
 
679
 
680
+ # ───────────────────��─────────────────────────────────────────────────────────
681
+ # Generation
682
+ # ─────────────────────────────────────────────────────────────────────────────
683
+ @spaces.GPU(duration=180)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
684
  @torch.inference_mode()
685
  def generate_video(
686
  input_image,
687
+ input_video,
688
  input_audio,
 
689
  prompt: str,
690
  duration: float,
 
691
  conditioning_strength: float,
692
  enhance_prompt: bool,
 
693
  seed: int,
694
  randomize_seed: bool,
695
  height: int,
696
  width: int,
697
  progress=gr.Progress(track_tqdm=True),
698
  ):
 
 
699
  try:
700
  torch.cuda.reset_peak_memory_stats()
 
 
 
 
 
 
 
 
 
701
  current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
702
 
703
  frame_rate = DEFAULT_FRAME_RATE
704
  num_frames = int(duration * frame_rate) + 1
705
  num_frames = ((num_frames - 1 + 7) // 8) * 8 + 1
706
 
707
+ mode_parts = []
708
+ if input_image is not None:
709
+ mode_parts.append("Image")
710
+ if input_video is not None:
711
+ mode_parts.append("Video(IC-LoRA)")
712
  if input_audio is not None:
713
+ mode_parts.append("Audio")
714
+ if not mode_parts:
715
+ mode_parts.append("Text")
716
+ mode_str = " + ".join(mode_parts)
717
 
718
+ print(f"[{mode_str}] Generating: {height}x{width}, {num_frames} frames "
719
+ f"({duration}s), seed={current_seed}")
720
 
721
+ # Build image conditionings
722
  images = []
723
  if input_image is not None:
724
+ images = [ImageConditioningInput(path=str(input_image), frame_idx=0, strength=1.0)]
725
+
726
+ # Build video conditionings for IC-LoRA / V2V
727
+ video_conditioning = None
728
+ if input_video is not None:
729
+ video_conditioning = [(str(input_video), 1.0)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
730
 
731
  tiling_config = TilingConfig.default()
732
  video_chunks_number = get_video_chunks_number(num_frames, tiling_config)
733
 
734
+ video, audio = pipeline(
 
 
735
  prompt=prompt,
736
  seed=current_seed,
737
  height=int(height),
 
739
  num_frames=num_frames,
740
  frame_rate=frame_rate,
741
  images=images,
742
+ audio_path=input_audio,
743
  video_conditioning=video_conditioning,
 
744
  tiling_config=tiling_config,
745
  enhance_prompt=enhance_prompt,
746
+ conditioning_strength=conditioning_strength,
 
747
  )
748
 
 
 
749
  output_path = tempfile.mktemp(suffix=".mp4")
750
  encode_video(
751
  video=video,
 
755
  video_chunks_number=video_chunks_number,
756
  )
757
 
 
758
  return str(output_path), current_seed
759
 
 
 
760
  except Exception as e:
761
  import traceback
 
762
  print(f"Error: {str(e)}\n{traceback.format_exc()}")
763
  return None, current_seed
764
 
765
 
766
+ # ─────────────────────────────────────────────────────────────────────────────
767
+ # Gradio UI
768
+ # ─────────────────────────────────────────────────────────────────────────────
769
+ with gr.Blocks(title="LTX-2.3 Unified: V2V + I2V + A2V") as demo:
770
+ gr.Markdown("# LTX-2.3 Unified: Video/Image/Audio β†’ Video")
771
  gr.Markdown(
772
+ "Unified pipeline for **video-to-video** (IC-LoRA), **image-to-video**, "
773
+ "and **audio-conditioned** generation with LTX-2.3 β€” use any combination of inputs. "
 
774
  "[[model]](https://huggingface.co/Lightricks/LTX-2.3) "
775
  "[[code]](https://github.com/Lightricks/LTX-2)"
776
  )
777
 
778
  with gr.Row():
779
  with gr.Column():
780
+ # All three inputs visible at once
781
  with gr.Row():
782
+ input_image = gr.Image(
783
+ label="πŸ–ΌοΈ Input Image (I2V β€” first frame)",
784
+ type="filepath",
785
+ )
786
+ input_video = gr.Video(
787
+ label="🎬 Reference Video (V2V β€” IC-LoRA)",
788
+ sources=["upload"],
789
  )
790
+ input_audio = gr.Audio(
791
+ label="πŸ”Š Input Audio (A2V β€” lipsync / BGM)",
792
+ type="filepath",
 
 
 
793
  )
794
+
795
  prompt = gr.Textbox(
796
  label="Prompt",
797
+ info="Describe the desired output β€” be as detailed as possible",
798
+ value="Make this come alive with cinematic motion, smooth animation",
799
  lines=3,
800
+ placeholder="Describe the motion, style, and content you want...",
801
  )
802
 
803
  with gr.Row():
804
+ duration = gr.Slider(
805
+ label="Duration (seconds)",
806
+ minimum=1.0, maximum=10.0, value=3.0, step=0.1,
807
+ )
808
+ conditioning_strength = gr.Slider(
809
+ label="V2V Conditioning Strength",
810
+ info="How closely to follow the reference video",
811
+ minimum=0.0, maximum=1.0, value=1.0, step=0.05,
812
  )
813
 
814
+ with gr.Row():
815
+ enhance_prompt = gr.Checkbox(label="Enhance Prompt", value=True)
816
+ high_res = gr.Checkbox(label="High Resolution", value=True)
817
+
818
  generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
819
 
820
  with gr.Accordion("Advanced Settings", open=False):
821
+ seed = gr.Slider(
822
+ label="Seed", minimum=0, maximum=MAX_SEED, value=42, step=1,
823
+ )
824
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
 
 
 
 
 
 
 
 
 
825
  with gr.Row():
826
  width = gr.Number(label="Width", value=1536, precision=0)
827
  height = gr.Number(label="Height", value=1024, precision=0)
 
829
  with gr.Column():
830
  output_video = gr.Video(label="Generated Video", autoplay=True)
831
 
832
+ # ── Event handlers ───────────────────────────────────────────────────
833
  input_image.change(
834
+ fn=on_image_upload,
835
+ inputs=[input_image, input_video, high_res],
836
+ outputs=[width, height],
837
+ )
838
+ input_video.change(
839
+ fn=on_video_upload,
840
+ inputs=[input_video, input_image, high_res],
841
  outputs=[width, height],
842
  )
843
  high_res.change(
844
  fn=on_highres_toggle,
845
+ inputs=[input_image, input_video, high_res],
846
  outputs=[width, height],
847
  )
 
848
  generate_btn.click(
849
  fn=generate_video,
850
  inputs=[
851
+ input_image, input_video, input_audio, prompt, duration,
852
+ conditioning_strength, enhance_prompt,
 
853
  seed, randomize_seed, height, width,
854
  ],
855
  outputs=[output_video, seed],
856
  )
857
 
858
 
 
 
 
 
859
if __name__ == "__main__":
    # BUG FIX: `theme` is a gr.Blocks(...) constructor argument, not a
    # launch() kwarg — passing it here raises TypeError on current Gradio
    # releases. Apply the theme via gr.Blocks(theme=gr.themes.Citrus())
    # at UI construction time instead.
    demo.launch()