alex commited on
Commit
fc134c2
·
1 Parent(s): f8c25ce

allow loRA

Browse files
app.py CHANGED
@@ -30,6 +30,8 @@ from ltx_pipelines.utils.constants import (
30
  DEFAULT_FRAME_RATE,
31
  DEFAULT_LORA_STRENGTH,
32
  )
 
 
33
 
34
 
35
  MAX_SEED = np.iinfo(np.int32).max
@@ -182,24 +184,48 @@ print("Loading LTX-2 Distilled pipeline...")
182
  print("=" * 80)
183
 
184
  checkpoint_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_CHECKPOINT_FILENAME)
185
- distilled_lora_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_DISTILLED_LORA_FILENAME)
186
  spatial_upsampler_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_SPATIAL_UPSAMPLER_FILENAME)
187
 
188
  print(f"Initializing pipeline with:")
189
  print(f" checkpoint_path={checkpoint_path}")
190
- print(f" distilled_lora_path={distilled_lora_path}")
191
  print(f" spatial_upsampler_path={spatial_upsampler_path}")
192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
  # Load distilled LoRA as a regular LoRA
195
  loras = [
 
196
  LoraPathStrengthAndSDOps(
197
  path=distilled_lora_path,
198
  strength=DEFAULT_LORA_STRENGTH,
199
  sd_ops=LTXV_LORA_COMFY_RENAMING_MAP,
200
- )
 
 
 
201
  ]
202
 
 
 
 
 
 
 
 
 
203
  # Initialize pipeline WITHOUT text encoder (gemma_root=None)
204
  # Text encoding will be done by external space
205
  pipeline = DistilledPipeline(
@@ -222,23 +248,6 @@ print("=" * 80)
222
  print("Pipeline fully loaded and ready!")
223
  print("=" * 80)
224
 
225
- def get_duration(
226
- input_image,
227
- prompt,
228
- duration,
229
- enhance_prompt,
230
- seed,
231
- randomize_seed,
232
- height,
233
- width,
234
- progress
235
- ):
236
- if duration <= 5:
237
- return 80
238
- elif duration <= 10:
239
- return 120
240
- else:
241
- return 180
242
 
243
  class RadioAnimated(gr.HTML):
244
  """
@@ -274,41 +283,254 @@ class RadioAnimated(gr.HTML):
274
 
275
  js_on_load = r"""
276
  (() => {
277
- const wrap = element.querySelector('.ra-wrap');
278
- const inner = element.querySelector('.ra-inner');
279
- const highlight = element.querySelector('.ra-highlight');
280
- const inputs = Array.from(element.querySelectorAll('.ra-input'));
281
-
282
- if (!inputs.length) return;
283
-
284
- const choices = inputs.map(i => i.value);
285
-
286
- function setHighlightByIndex(idx) {
287
- const n = choices.length;
288
- const pct = 100 / n;
289
- highlight.style.width = `calc(${pct}% - 6px)`;
290
- highlight.style.transform = `translateX(${idx * 100}%)`;
291
- }
292
-
293
- function setCheckedByValue(val, shouldTrigger=false) {
294
- const idx = Math.max(0, choices.indexOf(val));
295
- inputs.forEach((inp, i) => { inp.checked = (i === idx); });
296
- setHighlightByIndex(idx);
297
-
298
- props.value = choices[idx];
299
- if (shouldTrigger) trigger('change', props.value);
300
- }
301
-
302
- // Init from props.value
303
- setCheckedByValue(props.value ?? choices[0], false);
304
-
305
- // Input handlers
306
- inputs.forEach((inp) => {
307
- inp.addEventListener('change', () => {
308
- setCheckedByValue(inp.value, true);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
  });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  })();
 
312
  """
313
 
314
  super().__init__(
@@ -318,10 +540,42 @@ class RadioAnimated(gr.HTML):
318
  **kwargs
319
  )
320
 
 
321
  def generate_video_example(input_image, prompt, duration, progress=gr.Progress(track_tqdm=True)):
322
- output_video, seed = generate_video(input_image, prompt, 5, True, 42, True, DEFAULT_1_STAGE_HEIGHT, DEFAULT_1_STAGE_WIDTH, progress)
 
 
 
 
 
 
 
 
 
 
 
 
323
 
324
  return output_video
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
 
326
  @spaces.GPU(duration=get_duration)
327
  def generate_video(
@@ -333,6 +587,7 @@ def generate_video(
333
  randomize_seed: bool = True,
334
  height: int = DEFAULT_1_STAGE_HEIGHT,
335
  width: int = DEFAULT_1_STAGE_WIDTH,
 
336
  progress=gr.Progress(track_tqdm=True),
337
  ):
338
  """
@@ -346,8 +601,10 @@ def generate_video(
346
  randomize_seed: If True, a random seed is generated for each run.
347
  height: Output video height in pixels.
348
  width: Output video width in pixels.
 
349
  progress: Gradio progress tracker.
350
  Returns:
 
351
  A tuple of:
352
  - output_path: Path to the generated MP4 video file.
353
  - seed: The seed used for generation.
@@ -396,6 +653,20 @@ def generate_video(
396
  del embeddings, final_prompt, status
397
  torch.cuda.empty_cache()
398
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399
  # Run inference - progress automatically tracks tqdm from pipeline
400
  pipeline(
401
  prompt=prompt,
@@ -431,7 +702,42 @@ def apply_duration(duration: str):
431
  duration_s = int(duration[:-1])
432
  return duration_s
433
 
 
434
  css = """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
435
  #col-container {
436
  margin: 0 auto;
437
  max-width: 1600px;
@@ -570,6 +876,176 @@ css += """
570
  }
571
  """
572
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
573
 
574
  with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
575
  gr.HTML(
@@ -605,12 +1081,19 @@ with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
605
  height=512
606
  )
607
 
 
 
 
 
 
 
608
  prompt = gr.Textbox(
609
  label="Prompt",
610
  value="Make this image come alive with cinematic motion, smooth animation",
611
  lines=3,
612
  max_lines=3,
613
- placeholder="Describe the motion and animation you want..."
 
614
  )
615
 
616
  enhance_prompt = gr.Checkbox(
@@ -633,10 +1116,9 @@ with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
633
 
634
  with gr.Column(elem_id="step-column"):
635
  output_video = gr.Video(label="Generated Video", autoplay=True, height=512)
636
-
637
- with gr.Row():
638
-
639
- with gr.Column():
640
  radioanimated_duration = RadioAnimated(
641
  choices=["3s", "5s", "10s", "15s"],
642
  value="3s",
@@ -651,8 +1133,7 @@ with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
651
  step=0.1,
652
  visible=False
653
  )
654
-
655
- with gr.Column():
656
  radioanimated_resolution = RadioAnimated(
657
  choices=["768x512", "512x512", "512x768"],
658
  value=f"{DEFAULT_1_STAGE_WIDTH}x{DEFAULT_1_STAGE_HEIGHT}",
@@ -661,10 +1142,30 @@ with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
661
 
662
  width = gr.Number(label="Width", value=DEFAULT_1_STAGE_WIDTH, precision=0, visible=False)
663
  height = gr.Number(label="Height", value=DEFAULT_1_STAGE_HEIGHT, precision=0, visible=False)
664
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
665
 
666
  generate_btn = gr.Button("🤩 Generate Video", variant="primary", elem_classes="button-gradient")
667
 
 
 
 
 
 
 
668
 
669
  radioanimated_duration.change(
670
  fn=apply_duration,
@@ -678,6 +1179,13 @@ with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
678
  outputs=[width, height],
679
  api_visibility="private"
680
  )
 
 
 
 
 
 
 
681
 
682
  generate_btn.click(
683
  fn=generate_video,
@@ -690,6 +1198,7 @@ with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
690
  randomize_seed,
691
  height,
692
  width,
 
693
  ],
694
  outputs=[output_video,seed]
695
  )
@@ -716,7 +1225,7 @@ with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
716
 
717
  ],
718
  fn=generate_video_example,
719
- inputs=[input_image, prompt],
720
  outputs = [output_video],
721
  label="Example",
722
  cache_examples=True,
 
30
  DEFAULT_FRAME_RATE,
31
  DEFAULT_LORA_STRENGTH,
32
  )
33
+ from ltx_core.loader.single_gpu_model_builder import set_lora_enabled
34
+
35
 
36
 
37
  MAX_SEED = np.iinfo(np.int32).max
 
184
  print("=" * 80)
185
 
186
  checkpoint_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_CHECKPOINT_FILENAME)
 
187
  spatial_upsampler_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_SPATIAL_UPSAMPLER_FILENAME)
188
 
189
  print(f"Initializing pipeline with:")
190
  print(f" checkpoint_path={checkpoint_path}")
 
191
  print(f" spatial_upsampler_path={spatial_upsampler_path}")
192
 
193
+ distilled_lora_path = get_hub_or_local_checkpoint(
194
+ DEFAULT_REPO_ID,
195
+ DEFAULT_DISTILLED_LORA_FILENAME,
196
+ )
197
+
198
+ dolly_in_lora_path = get_hub_or_local_checkpoint(
199
+ "Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In",
200
+ "ltx-2-19b-lora-camera-control-dolly-in.safetensors",
201
+ )
202
+ dolly_out_lora_path = get_hub_or_local_checkpoint(
203
+ "Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out",
204
+ "ltx-2-19b-lora-camera-control-dolly-out.safetensors",
205
+ )
206
+
207
 
208
  # Load distilled LoRA as a regular LoRA
209
  loras = [
210
+ # --- fused / base behavior ---
211
  LoraPathStrengthAndSDOps(
212
  path=distilled_lora_path,
213
  strength=DEFAULT_LORA_STRENGTH,
214
  sd_ops=LTXV_LORA_COMFY_RENAMING_MAP,
215
+ ),
216
+ # # --- runtime-toggle camera controls ---#
217
+ LoraPathStrengthAndSDOps(dolly_in_lora_path, DEFAULT_LORA_STRENGTH, LTXV_LORA_COMFY_RENAMING_MAP),
218
+ LoraPathStrengthAndSDOps(dolly_out_lora_path, DEFAULT_LORA_STRENGTH, LTXV_LORA_COMFY_RENAMING_MAP),
219
  ]
220
 
221
+ # Runtime-toggle LoRAs (exclude fused distilled at index 0)
222
+ RUNTIME_LORA_CHOICES = [
223
+ ("No LoRA", -1),
224
+ ("Dolly In", 0),
225
+ ("Dolly Out", 1),
226
+ ]
227
+
228
+
229
  # Initialize pipeline WITHOUT text encoder (gemma_root=None)
230
  # Text encoding will be done by external space
231
  pipeline = DistilledPipeline(
 
248
  print("Pipeline fully loaded and ready!")
249
  print("=" * 80)
250
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
252
  class RadioAnimated(gr.HTML):
253
  """
 
283
 
284
  js_on_load = r"""
285
  (() => {
286
+ const wrap = element.querySelector('.ra-wrap');
287
+ const inner = element.querySelector('.ra-inner');
288
+ const highlight = element.querySelector('.ra-highlight');
289
+ const inputs = Array.from(element.querySelectorAll('.ra-input'));
290
+ const labels = Array.from(element.querySelectorAll('.ra-label'));
291
+
292
+ if (!inputs.length || !labels.length) return;
293
+
294
+ const choices = inputs.map(i => i.value);
295
+ const PAD = 6; // must match .ra-inner padding and .ra-highlight top/left
296
+
297
+ let currentIdx = 0;
298
+
299
+ function setHighlightByIndex(idx) {
300
+ currentIdx = idx;
301
+
302
+ const lbl = labels[idx];
303
+ if (!lbl) return;
304
+
305
+ const innerRect = inner.getBoundingClientRect();
306
+ const lblRect = lbl.getBoundingClientRect();
307
+
308
+ // width matches the label exactly
309
+ highlight.style.width = `${lblRect.width}px`;
310
+
311
+ // highlight has left: 6px, so subtract PAD to align
312
+ const x = (lblRect.left - innerRect.left - PAD);
313
+ highlight.style.transform = `translateX(${x}px)`;
314
+ }
315
+
316
+ function setCheckedByValue(val, shouldTrigger=false) {
317
+ const idx = Math.max(0, choices.indexOf(val));
318
+ inputs.forEach((inp, i) => { inp.checked = (i === idx); });
319
+
320
+ // Wait a frame in case fonts/layout settle (prevents rare drift)
321
+ requestAnimationFrame(() => setHighlightByIndex(idx));
322
+
323
+ props.value = choices[idx];
324
+ if (shouldTrigger) trigger('change', props.value);
325
+ }
326
+
327
+ // Init
328
+ setCheckedByValue(props.value ?? choices[0], false);
329
+
330
+ // Input handlers
331
+ inputs.forEach((inp) => {
332
+ inp.addEventListener('change', () => setCheckedByValue(inp.value, true));
333
  });
334
+
335
+ // Recalc on resize (important in Gradio layouts)
336
+ window.addEventListener('resize', () => setHighlightByIndex(currentIdx));
337
+ })();
338
+
339
+ """
340
+
341
+ super().__init__(
342
+ value=value,
343
+ html_template=html_template,
344
+ js_on_load=js_on_load,
345
+ **kwargs
346
+ )
347
+
348
+
349
+ class PromptBox(gr.HTML):
350
+ """
351
+ DeepSite-like prompt box (HTML textarea) that behaves like an input component.
352
+ Outputs: the current text value (string)
353
+ """
354
+ def __init__(self, value="", placeholder="Describe the video with audio you want to generate...", **kwargs):
355
+ uid = uuid.uuid4().hex[:8]
356
+
357
+ html_template = f"""
358
+ <div style="text-align:center; font-weight:600; margin-bottom:6px;">
359
+ Prompt
360
+ </div>
361
+ <div class="ds-prompt" data-ds="{uid}">
362
+ <textarea class="ds-textarea" rows="3"
363
+ placeholder="{placeholder}"></textarea>
364
+ </div>
365
+ """
366
+
367
+ js_on_load = r"""
368
+ (() => {
369
+ const textarea = element.querySelector(".ds-textarea");
370
+ if (!textarea) return;
371
+
372
+ // Auto-resize (optional, but nice)
373
+ const autosize = () => {
374
+ textarea.style.height = "0px";
375
+ textarea.style.height = Math.min(textarea.scrollHeight, 240) + "px";
376
+ };
377
+
378
+ // Set initial value from props.value
379
+ const setValue = (v, triggerChange=false) => {
380
+ const val = (v ?? "");
381
+ if (textarea.value !== val) textarea.value = val;
382
+ autosize();
383
+
384
+ props.value = textarea.value;
385
+ if (triggerChange) trigger("change", props.value);
386
+ };
387
+
388
+ setValue(props.value, false);
389
+
390
+ // Update Gradio value on input
391
+ textarea.addEventListener("input", () => {
392
+ autosize();
393
+ props.value = textarea.value;
394
+ trigger("change", props.value);
395
  });
396
+
397
+ let last = props.value;
398
+ const syncFromProps = () => {
399
+ if (props.value !== last) {
400
+ last = props.value;
401
+ setValue(last, false); // don't re-trigger change loop
402
+ }
403
+ requestAnimationFrame(syncFromProps);
404
+ };
405
+ requestAnimationFrame(syncFromProps);
406
+ })();
407
+ """
408
+
409
+ super().__init__(
410
+ value=value,
411
+ html_template=html_template,
412
+ js_on_load=js_on_load,
413
+ **kwargs
414
+ )
415
+
416
+ class CameraDropdown(gr.HTML):
417
+ """
418
+ Custom dropdown (More-style).
419
+ Outputs: selected option string, e.g. "Dolly Left"
420
+ """
421
+ def __init__(self, choices, value="None", title="Camera LoRA", **kwargs):
422
+ if not choices:
423
+ raise ValueError("CameraDropdown requires choices.")
424
+
425
+ uid = uuid.uuid4().hex[:8]
426
+ safe_choices = [str(c) for c in choices]
427
+
428
+ items_html = "\n".join(
429
+ f"""<button type="button" class="cd-item" data-value="{c}">{c}</button>"""
430
+ for c in safe_choices
431
+ )
432
+
433
+ html_template = f"""
434
+ <div class="cd-wrap" data-cd="{uid}">
435
+ <button type="button" class="cd-trigger" aria-haspopup="menu" aria-expanded="false">
436
+ <span class="cd-trigger-text">More</span>
437
+ <span class="cd-caret">▾</span>
438
+ </button>
439
+
440
+ <div class="cd-menu" role="menu" aria-hidden="true">
441
+ <div class="cd-title">{title}</div>
442
+ <div class="cd-items">
443
+ {items_html}
444
+ </div>
445
+ </div>
446
+ </div>
447
+ """
448
+
449
+ js_on_load = r"""
450
+ (() => {
451
+ const wrap = element.querySelector(".cd-wrap");
452
+ const trigger = element.querySelector(".cd-trigger");
453
+ const triggerText = element.querySelector(".cd-trigger-text");
454
+ const menu = element.querySelector(".cd-menu");
455
+ const items = Array.from(element.querySelectorAll(".cd-item"));
456
+
457
+ if (!wrap || !trigger || !menu || !items.length) return;
458
+
459
+ function closeMenu() {
460
+ menu.classList.remove("open");
461
+ trigger.setAttribute("aria-expanded", "false");
462
+ menu.setAttribute("aria-hidden", "true");
463
+ }
464
+
465
+ function openMenu() {
466
+ menu.classList.add("open");
467
+ trigger.setAttribute("aria-expanded", "true");
468
+ menu.setAttribute("aria-hidden", "false");
469
+ }
470
+
471
+ function setValue(val, shouldTrigger = false) {
472
+ const v = (val ?? "None");
473
+ props.value = v;
474
+ triggerText.textContent = v;
475
+
476
+ items.forEach(btn => {
477
+ btn.classList.toggle("selected", btn.dataset.value === v);
478
+ });
479
+
480
+ if (shouldTrigger) trigger("change", props.value);
481
+ }
482
+
483
+ // Toggle menu
484
+ trigger.addEventListener("pointerdown", (e) => {
485
+ e.preventDefault(); // prevents focus/blur weirdness
486
+ e.stopPropagation();
487
+ if (menu.classList.contains("open")) closeMenu();
488
+ else openMenu();
489
+ });
490
+
491
+ // Close on outside interaction (use capture so it wins)
492
+ document.addEventListener("pointerdown", (e) => {
493
+ if (!wrap.contains(e.target)) closeMenu();
494
+ }, true);
495
+
496
+ // Close on ESC
497
+ document.addEventListener("keydown", (e) => {
498
+ if (e.key === "Escape") closeMenu();
499
+ });
500
+
501
+ // Close when focus leaves the dropdown (keyboard users)
502
+ wrap.addEventListener("focusout", (e) => {
503
+ // if the newly-focused element isn't inside wrap, close
504
+ if (!wrap.contains(e.relatedTarget)) closeMenu();
505
+ });
506
+
507
+ // Item selection: use pointerdown so it closes immediately
508
+ items.forEach((btn) => {
509
+ btn.addEventListener("pointerdown", (e) => {
510
+ e.preventDefault();
511
+ e.stopPropagation();
512
+
513
+ // close first so it never "sticks" open
514
+ closeMenu();
515
+ setValue(btn.dataset.value, true);
516
+ });
517
+ });
518
+
519
+ // init
520
+ setValue((props.value ?? "None"), false);
521
+
522
+ // sync from Python updates
523
+ let last = props.value;
524
+ const syncFromProps = () => {
525
+ if (props.value !== last) {
526
+ last = props.value;
527
+ setValue(last, false);
528
+ }
529
+ requestAnimationFrame(syncFromProps);
530
+ };
531
+ requestAnimationFrame(syncFromProps);
532
  })();
533
+
534
  """
535
 
536
  super().__init__(
 
540
  **kwargs
541
  )
542
 
543
+
544
  def generate_video_example(input_image, prompt, duration, progress=gr.Progress(track_tqdm=True)):
545
+
546
+ output_video, seed = generate_video(
547
+ input_image,
548
+ prompt,
549
+ 5, # duration seconds
550
+ True, # enhance_prompt
551
+ 42, # seed
552
+ True, # randomize_seed
553
+ DEFAULT_1_STAGE_HEIGHT, # height
554
+ DEFAULT_1_STAGE_WIDTH, # width
555
+ "No LoRA",
556
+ progress
557
+ )
558
 
559
  return output_video
560
+
561
+ def get_duration(
562
+ input_image,
563
+ prompt,
564
+ duration,
565
+ enhance_prompt,
566
+ seed,
567
+ randomize_seed,
568
+ height,
569
+ width,
570
+ camera_lora,
571
+ progress
572
+ ):
573
+ if duration <= 5:
574
+ return 80
575
+ elif duration <= 10:
576
+ return 120
577
+ else:
578
+ return 180
579
 
580
  @spaces.GPU(duration=get_duration)
581
  def generate_video(
 
587
  randomize_seed: bool = True,
588
  height: int = DEFAULT_1_STAGE_HEIGHT,
589
  width: int = DEFAULT_1_STAGE_WIDTH,
590
+ camera_lora: str = "No LoRA",
591
  progress=gr.Progress(track_tqdm=True),
592
  ):
593
  """
 
601
  randomize_seed: If True, a random seed is generated for each run.
602
  height: Output video height in pixels.
603
  width: Output video width in pixels.
604
+ camera_lora: Camera motion control LoRA to apply during generation (enables exactly one at runtime).
605
  progress: Gradio progress tracker.
606
  Returns:
607
+
608
  A tuple of:
609
  - output_path: Path to the generated MP4 video file.
610
  - seed: The seed used for generation.
 
653
  del embeddings, final_prompt, status
654
  torch.cuda.empty_cache()
655
 
656
+
657
+ # Map dropdown name -> adapter index
658
+ name_to_idx = {name: idx for name, idx in RUNTIME_LORA_CHOICES}
659
+ selected_idx = name_to_idx.get(camera_lora, -1)
660
+
661
+ # Disable all runtime adapters first (0..N-1)
662
+ # N here is len(RUNTIME_LORA_CHOICES)-1 because "None" isn't an adapter
663
+ for i in range(len(RUNTIME_LORA_CHOICES) - 1):
664
+ set_lora_enabled(pipeline._transformer, i, False)
665
+
666
+ # Enable selected one (if any)
667
+ if selected_idx >= 0:
668
+ set_lora_enabled(pipeline._transformer, selected_idx, True)
669
+
670
  # Run inference - progress automatically tracks tqdm from pipeline
671
  pipeline(
672
  prompt=prompt,
 
702
  duration_s = int(duration[:-1])
703
  return duration_s
704
 
705
+
706
  css = """
707
+
708
+ /* Make the row behave nicely */
709
+ #controls-row {
710
+ display: flex;
711
+ align-items: center;
712
+ gap: 12px;
713
+ flex-wrap: nowrap; /* or wrap if you prefer on small screens */
714
+ }
715
+
716
+ /* Stop these components from stretching */
717
+ #controls-row > * {
718
+ flex: 0 0 auto !important;
719
+ width: auto !important;
720
+ min-width: 0 !important;
721
+ }
722
+
723
+ #controls-row #camera_lora_ui {
724
+ margin-left: auto !important;
725
+ }
726
+
727
+ /* Gradio HTML components often have an inner wrapper div that is width:100% */
728
+ #camera_lora_ui,
729
+ #camera_lora_ui > div {
730
+ width: fit-content !important;
731
+ }
732
+
733
+ /* Same idea for your radio HTML blocks (optional but helps) */
734
+ #radioanimated_duration,
735
+ #radioanimated_duration > div,
736
+ #radioanimated_resolution,
737
+ #radioanimated_resolution > div {
738
+ width: fit-content !important;
739
+ }
740
+
741
  #col-container {
742
  margin: 0 auto;
743
  max-width: 1600px;
 
876
  }
877
  """
878
 
879
+ css += """
880
+ /* --- prompt box --- */
881
+ .ds-prompt{
882
+ width: 100%;
883
+ max-width: 720px;
884
+ margin-top: 3px;
885
+ }
886
+
887
+ .ds-textarea{
888
+ width: 100%;
889
+ box-sizing: border-box;
890
+
891
+ background: #2b2b2b;
892
+ color: rgba(255,255,255,0.9);
893
+
894
+ border: 1px solid rgba(255,255,255,0.12);
895
+ border-radius: 14px;
896
+
897
+ padding: 14px 16px;
898
+ outline: none;
899
+
900
+ font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Arial;
901
+ font-size: 15px;
902
+ line-height: 1.35;
903
+
904
+ resize: none;
905
+ height: 94px;
906
+ min-height: 94px;
907
+ max-height: 94px;
908
+ overflow-y: auto;
909
+ }
910
+
911
+ .ds-textarea::placeholder{
912
+ color: rgba(255,255,255,0.55);
913
+ }
914
+
915
+ .ds-textarea:focus{
916
+ border-color: rgba(255,255,255,0.22);
917
+ box-shadow: 0 0 0 3px rgba(255,255,255,0.06);
918
+ }
919
+ """
920
+
921
+ css += """
922
+ /* ---- camera dropdown ---- */
923
+
924
+ /* 1) Fix overlap: make the Gradio HTML block shrink-to-fit when it contains a CameraDropdown.
925
+ Gradio uses .gr-html for HTML components in most versions; older themes sometimes use .gradio-html.
926
+ This keeps your big header HTML unaffected because it doesn't contain .cd-wrap.
927
+ */
928
+
929
+ /* 2) Actual dropdown layout */
930
+ .cd-wrap{
931
+ position: relative;
932
+ display: inline-block;
933
+ }
934
+
935
+ /* 3) Match RadioAnimated pill size/feel */
936
+ .cd-trigger{
937
+ margin-top: 2px;
938
+ display: inline-flex;
939
+ align-items: center;
940
+ justify-content: center;
941
+ gap: 10px;
942
+
943
+ border: none;
944
+
945
+ box-sizing: border-box;
946
+ padding: 10px 18px;
947
+ min-height: 52px;
948
+ line-height: 1.2;
949
+
950
+ border-radius: 9999px;
951
+ background: #0b0b0b;
952
+
953
+ font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Arial;
954
+ font-size: 14px;
955
+
956
+ /* ✅ match .ra-label exactly */
957
+ color: rgba(255,255,255,0.7) !important;
958
+ font-weight: 600 !important;
959
+
960
+ cursor: pointer;
961
+ user-select: none;
962
+ white-space: nowrap;
963
+ }
964
+
965
+ /* Ensure inner spans match too */
966
+ .cd-trigger .cd-trigger-text,
967
+ .cd-trigger .cd-caret{
968
+ color: rgba(255,255,255,0.7) !important;
969
+ }
970
+
971
+ /* keep caret styling */
972
+ .cd-caret{
973
+ opacity: 0.8;
974
+ font-weight: 900;
975
+ }
976
+
977
+ /* 4) Ensure menu overlays neighbors and isn't clipped */
978
+ .cd-menu{
979
+ position: absolute;
980
+ top: calc(100% + 10px);
981
+ left: 0;
982
+
983
+ min-width: 240px;
984
+ background: #2b2b2b;
985
+ border: 1px solid rgba(255,255,255,0.14);
986
+ border-radius: 14px;
987
+ box-shadow: 0 18px 40px rgba(0,0,0,0.35);
988
+ padding: 10px;
989
+
990
+ opacity: 0;
991
+ transform: translateY(-6px);
992
+ pointer-events: none;
993
+ transition: opacity 160ms ease, transform 160ms ease;
994
+
995
+ z-index: 9999; /* was 50 */
996
+ }
997
+
998
+ .cd-menu.open{
999
+ opacity: 1;
1000
+ transform: translateY(0);
1001
+ pointer-events: auto;
1002
+ }
1003
+
1004
+ .cd-title{
1005
+ padding: 6px 8px 10px 8px;
1006
+ font-size: 12px;
1007
+ font-weight: 800;
1008
+ letter-spacing: 0.02em;
1009
+ color: rgba(255,255,255,0.55);
1010
+ text-transform: none;
1011
+ }
1012
+
1013
+ .cd-items{
1014
+ display: flex;
1015
+ flex-direction: column;
1016
+ gap: 6px;
1017
+ }
1018
+
1019
+ .cd-item{
1020
+ width: 100%;
1021
+ text-align: left;
1022
+ border: none;
1023
+ background: rgba(255,255,255,0.06);
1024
+ color: rgba(255,255,255,0.92);
1025
+ padding: 10px 10px;
1026
+ border-radius: 12px;
1027
+ cursor: pointer;
1028
+ font-size: 14px;
1029
+ font-weight: 700;
1030
+ transition: background 120ms ease, transform 80ms ease;
1031
+ }
1032
+
1033
+ .cd-item:hover{
1034
+ background: rgba(255,255,255,0.10);
1035
+ }
1036
+
1037
+ .cd-item:active{
1038
+ transform: translateY(1px);
1039
+ }
1040
+
1041
+ .cd-item.selected{
1042
+ background: rgba(139,255,151,0.22);
1043
+ border: 1px solid rgba(139,255,151,0.35);
1044
+ }
1045
+
1046
+ """
1047
+
1048
+
1049
 
1050
  with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
1051
  gr.HTML(
 
1081
  height=512
1082
  )
1083
 
1084
+
1085
+ prompt_ui = PromptBox(
1086
+ value="Make this image come alive with cinematic motion, smooth animation",
1087
+ elem_id="prompt_ui",
1088
+ )
1089
+
1090
  prompt = gr.Textbox(
1091
  label="Prompt",
1092
  value="Make this image come alive with cinematic motion, smooth animation",
1093
  lines=3,
1094
  max_lines=3,
1095
+ placeholder="Describe the motion and animation you want...",
1096
+ visible=False
1097
  )
1098
 
1099
  enhance_prompt = gr.Checkbox(
 
1116
 
1117
  with gr.Column(elem_id="step-column"):
1118
  output_video = gr.Video(label="Generated Video", autoplay=True, height=512)
1119
+
1120
+ with gr.Row(elem_id="controls-row"):
1121
+
 
1122
  radioanimated_duration = RadioAnimated(
1123
  choices=["3s", "5s", "10s", "15s"],
1124
  value="3s",
 
1133
  step=0.1,
1134
  visible=False
1135
  )
1136
+
 
1137
  radioanimated_resolution = RadioAnimated(
1138
  choices=["768x512", "512x512", "512x768"],
1139
  value=f"{DEFAULT_1_STAGE_WIDTH}x{DEFAULT_1_STAGE_HEIGHT}",
 
1142
 
1143
  width = gr.Number(label="Width", value=DEFAULT_1_STAGE_WIDTH, precision=0, visible=False)
1144
  height = gr.Number(label="Height", value=DEFAULT_1_STAGE_HEIGHT, precision=0, visible=False)
1145
+
1146
+ camera_lora_ui = CameraDropdown(
1147
+ choices=[name for name, _ in RUNTIME_LORA_CHOICES],
1148
+ value="No LoRA",
1149
+ title="Camera LoRA",
1150
+ elem_id="camera_lora_ui",
1151
+ )
1152
+
1153
+ # Hidden real dropdown (backend value)
1154
+ camera_lora = gr.Dropdown(
1155
+ label="Camera Control LoRA",
1156
+ choices=[name for name, _ in RUNTIME_LORA_CHOICES],
1157
+ value="No LoRA",
1158
+ visible=False
1159
+ )
1160
 
1161
  generate_btn = gr.Button("🤩 Generate Video", variant="primary", elem_classes="button-gradient")
1162
 
1163
+ camera_lora_ui.change(
1164
+ fn=lambda x: x,
1165
+ inputs=camera_lora_ui,
1166
+ outputs=camera_lora,
1167
+ api_visibility="private"
1168
+ )
1169
 
1170
  radioanimated_duration.change(
1171
  fn=apply_duration,
 
1179
  outputs=[width, height],
1180
  api_visibility="private"
1181
  )
1182
+ prompt_ui.change(
1183
+ fn=lambda x: x,
1184
+ inputs=prompt_ui,
1185
+ outputs=prompt,
1186
+ api_visibility="private"
1187
+ )
1188
+
1189
 
1190
  generate_btn.click(
1191
  fn=generate_video,
 
1198
  randomize_seed,
1199
  height,
1200
  width,
1201
+ camera_lora,
1202
  ],
1203
  outputs=[output_video,seed]
1204
  )
 
1225
 
1226
  ],
1227
  fn=generate_video_example,
1228
+ inputs=[input_image, prompt_ui],
1229
  outputs = [output_video],
1230
  label="Example",
1231
  cache_examples=True,
packages/ltx-core/src/ltx_core/loader/fuse_loras.py CHANGED
@@ -3,6 +3,7 @@ import triton
3
 
4
  from ltx_core.loader.kernels import fused_add_round_kernel
5
  from ltx_core.loader.primitives import LoraStateDictWithStrength, StateDict
 
6
 
7
  BLOCK_SIZE = 1024
8
 
@@ -59,42 +60,59 @@ def _prepare_deltas(
59
  return deltas[0]
60
  return torch.sum(torch.stack(deltas, dim=0), dim=0)
61
 
62
-
63
  def apply_loras(
64
  model_sd: StateDict,
65
  lora_sd_and_strengths: list[LoraStateDictWithStrength],
66
  dtype: torch.dtype,
67
  destination_sd: StateDict | None = None,
68
- ) -> StateDict:
69
- sd = {}
70
- if destination_sd is not None:
71
- sd = destination_sd.sd
72
  size = 0
73
  device = torch.device("meta")
74
  inner_dtypes = set()
 
 
 
 
75
  for key, weight in model_sd.sd.items():
76
  if weight is None:
77
  continue
 
 
 
 
78
  device = weight.device
79
  target_dtype = dtype if dtype is not None else weight.dtype
80
- deltas_dtype = target_dtype if target_dtype not in [torch.float8_e4m3fn, torch.float8_e5m2] else torch.bfloat16
 
81
  deltas = _prepare_deltas(lora_sd_and_strengths, key, deltas_dtype, device)
 
 
 
 
 
 
82
  if deltas is None:
83
  if key in sd:
84
  continue
85
- deltas = weight.clone().to(dtype=target_dtype, device=device)
86
- elif weight.dtype == torch.float8_e4m3fn:
87
- if str(device).startswith("cuda"):
88
- deltas = calculate_weight_float8_(deltas, weight)
89
- else:
90
- deltas.add_(weight.to(dtype=deltas.dtype, device=device))
91
- elif weight.dtype == torch.bfloat16:
92
- deltas.add_(weight)
93
  else:
94
- raise ValueError(f"Unsupported dtype: {weight.dtype}")
95
- sd[key] = deltas.to(dtype=target_dtype)
 
 
 
 
96
  inner_dtypes.add(target_dtype)
97
- size += deltas.nbytes
98
- if destination_sd is not None:
99
- return destination_sd
100
- return StateDict(sd, device, size, inner_dtypes)
 
 
 
 
 
 
 
 
3
 
4
  from ltx_core.loader.kernels import fused_add_round_kernel
5
  from ltx_core.loader.primitives import LoraStateDictWithStrength, StateDict
6
+ from typing import Iterable
7
 
8
  BLOCK_SIZE = 1024
9
 
 
60
  return deltas[0]
61
  return torch.sum(torch.stack(deltas, dim=0), dim=0)
62
 
 
63
  def apply_loras(
64
  model_sd: StateDict,
65
  lora_sd_and_strengths: list[LoraStateDictWithStrength],
66
  dtype: torch.dtype,
67
  destination_sd: StateDict | None = None,
68
+ return_affected: bool = False,
69
+ ) -> StateDict | tuple[StateDict, list[str]]:
70
+ sd = destination_sd.sd if destination_sd is not None else {}
 
71
  size = 0
72
  device = torch.device("meta")
73
  inner_dtypes = set()
74
+
75
+ affected_weight_keys: list[str] = []
76
+ affected_module_prefixes: set[str] = set()
77
+
78
  for key, weight in model_sd.sd.items():
79
  if weight is None:
80
  continue
81
+ if not key.endswith(".weight"):
82
+ # optional: skip non-weight tensors if your SD has them
83
+ continue
84
+
85
  device = weight.device
86
  target_dtype = dtype if dtype is not None else weight.dtype
87
+ deltas_dtype = target_dtype # you said ignore fp8 path
88
+
89
  deltas = _prepare_deltas(lora_sd_and_strengths, key, deltas_dtype, device)
90
+
91
+ # Record which weights are actually modified by LoRA
92
+ if deltas is not None:
93
+ affected_weight_keys.append(key)
94
+ affected_module_prefixes.add(key[: -len(".weight")])
95
+
96
  if deltas is None:
97
  if key in sd:
98
  continue
99
+ out = weight.clone().to(dtype=target_dtype, device=device)
 
 
 
 
 
 
 
100
  else:
101
+ # normal add_ path (bf16 etc)
102
+ out = deltas.to(dtype=target_dtype)
103
+ # IMPORTANT: add base weight
104
+ out.add_(weight.to(dtype=out.dtype, device=device))
105
+
106
+ sd[key] = out
107
  inner_dtypes.add(target_dtype)
108
+ size += out.nbytes
109
+
110
+ result = destination_sd if destination_sd is not None else StateDict(sd, device, size, inner_dtypes)
111
+
112
+ if return_affected:
113
+ # sorted for stable output
114
+ affected = sorted(affected_module_prefixes)
115
+ return result, affected
116
+
117
+ return result
118
+
packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py CHANGED
@@ -3,6 +3,7 @@ from dataclasses import dataclass, field, replace
3
  from typing import Generic
4
 
5
  import torch
 
6
 
7
  from ltx_core.loader.fuse_loras import apply_loras
8
  from ltx_core.loader.module_ops import ModuleOps
@@ -22,6 +23,109 @@ from ltx_core.model.model_protocol import ModelConfigurator, ModelType
22
  logger: logging.Logger = logging.getLogger(__name__)
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  @dataclass(frozen=True)
26
  class SingleGPUModelBuilder(Generic[ModelType], ModelBuilderProtocol[ModelType], LoRAAdaptableProtocol):
27
  """
@@ -93,9 +197,29 @@ class SingleGPUModelBuilder(Generic[ModelType], ModelBuilderProtocol[ModelType],
93
  ]
94
  final_sd = apply_loras(
95
  model_sd=model_state_dict,
96
- lora_sd_and_strengths=lora_sd_and_strengths,
97
  dtype=dtype,
98
  destination_sd=model_state_dict if isinstance(self.registry, DummyRegistry) else None,
99
  )
100
  meta_model.load_state_dict(final_sd.sd, strict=False, assign=True)
101
- return self._return_model(meta_model, device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from typing import Generic
4
 
5
  import torch
6
+ import torch.nn as nn
7
 
8
  from ltx_core.loader.fuse_loras import apply_loras
9
  from ltx_core.loader.module_ops import ModuleOps
 
23
  logger: logging.Logger = logging.getLogger(__name__)
24
 
25
 
26
def get_submodule_and_parent(root: nn.Module, path: str):
    """Resolve a dotted ``path`` under ``root``.

    Returns ``(parent_module, child_name, child_module)`` where the child is
    the module reachable at ``path``. Purely numeric segments index into
    containers such as ``nn.Sequential`` / ``nn.ModuleList``; all other
    segments are attribute lookups.
    """

    def _descend(module: nn.Module, segment: str) -> nn.Module:
        # Numeric segment -> container indexing, otherwise attribute access.
        return module[int(segment)] if segment.isdigit() else getattr(module, segment)

    *ancestors, leaf = path.split(".")
    parent = root
    for segment in ancestors:
        parent = _descend(parent, segment)
    return parent, leaf, _descend(parent, leaf)
45
+
46
def set_submodule(root: nn.Module, path: str, new_module: nn.Module):
    """Replace the module found at ``path`` under ``root`` with ``new_module``."""
    parent, leaf, _old = get_submodule_and_parent(root, path)
    if leaf.isdigit():
        # Container slot (Sequential / ModuleList) — assign by index.
        parent[int(leaf)] = new_module
        return
    setattr(parent, leaf, new_module)
52
+
53
+
54
class MultiLoraLinear(nn.Module):
    """An ``nn.Linear`` wrapper carrying any number of toggleable LoRA deltas.

    When adapter ``i`` is enabled it contributes
    ``(x @ A_i.T) @ B_i.T * scale_i`` on top of the base layer's output.
    """

    def __init__(self, base: nn.Linear):
        super().__init__()
        self.base = base
        # Parallel lists: (A, B, scale) per adapter and its on/off flag.
        self.adapters: list[tuple[torch.Tensor, torch.Tensor, float]] = []
        self.enabled: list[bool] = []

    def add_adapter(self, A: torch.Tensor, B: torch.Tensor, scale: float, enabled: bool = True):
        """Append one LoRA adapter; buffers keep its tensors off .parameters()."""
        slot = len(self.adapters)
        self.register_buffer(f"lora_A_{slot}", A, persistent=False)
        self.register_buffer(f"lora_B_{slot}", B, persistent=False)
        self.adapters.append((A, B, float(scale)))
        self.enabled.append(bool(enabled))

    def set_enabled(self, idx: int, enabled: bool):
        """Toggle adapter ``idx``; out-of-range indices are a silent no-op."""
        if 0 <= idx < len(self.enabled):
            self.enabled[idx] = enabled

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.base(x)
        # Accumulate every enabled adapter's low-rank delta.
        for slot, is_on in enumerate(self.enabled):
            if is_on:
                A = getattr(self, f"lora_A_{slot}")
                B = getattr(self, f"lora_B_{slot}")
                out = out + ((x @ A.t()) @ B.t()) * self.adapters[slot][2]
        return out
84
+
85
def set_lora_enabled(model: nn.Module, adapter_idx: int, enabled: bool):
    """Toggle adapter ``adapter_idx`` in every MultiLoraLinear under ``model``."""
    wrapped_layers = (m for m in model.modules() if isinstance(m, MultiLoraLinear))
    for layer in wrapped_layers:
        layer.set_enabled(adapter_idx, enabled)
89
+
90
def patch_only_affected_linears(
    model: nn.Module,
    lora_sd: dict,
    affected_modules: list[str],
    strength: float,
    adapter_idx: int,
    default_enabled: bool = False,
):
    """Wrap each affected ``nn.Linear`` in MultiLoraLinear and attach one adapter.

    Every Linear prefix in ``affected_modules`` receives exactly one adapter
    slot per call, so slot ``adapter_idx`` refers to the same LoRA on every
    layer. Prefixes whose A/B weights are missing from ``lora_sd`` get a
    disabled rank-1 zero placeholder instead of being skipped: skipping them
    (as the previous implementation did) left the per-layer adapter lists
    misaligned, so ``set_lora_enabled(adapter_idx)`` could toggle the wrong
    LoRA on those layers.

    Args:
        model: root module, patched in place.
        lora_sd: LoRA state dict with ``{prefix}.lora_A.weight`` /
            ``{prefix}.lora_B.weight`` entries.
        affected_modules: module prefixes (from ``apply_loras``) that any
            LoRA in the set touches.
        strength: multiplier applied to this LoRA's delta (parity with the
            fused-merge path).
        adapter_idx: expected slot index for this LoRA; used to verify that
            adapter lists stay aligned across layers.
        default_enabled: initial on/off state for the new adapter.

    Raises:
        ValueError: if a wrapped layer already holds a number of adapters
            different from ``adapter_idx`` (adapters added out of order).
    """
    for prefix in affected_modules:
        _, _, mod = get_submodule_and_parent(model, prefix)

        # Reuse an existing wrapper, otherwise wrap the Linear in place.
        if isinstance(mod, MultiLoraLinear):
            wrapped = mod
        elif isinstance(mod, nn.Linear):
            wrapped = MultiLoraLinear(mod)
            set_submodule(model, prefix, wrapped)
        else:
            # Not a Linear — consistently skipped for every adapter, so
            # slot alignment across layers is unaffected.
            continue

        # Slot sanity check: the adapter being added must land at adapter_idx.
        if len(wrapped.adapters) != adapter_idx:
            raise ValueError(
                f"Adapter misalignment at {prefix!r}: expected slot "
                f"{adapter_idx}, found {len(wrapped.adapters)} adapters."
            )

        base_weight = wrapped.base.weight
        key_a = f"{prefix}.lora_A.weight"
        key_b = f"{prefix}.lora_B.weight"

        if key_a in lora_sd and key_b in lora_sd:
            A = lora_sd[key_a].to(device=base_weight.device, dtype=base_weight.dtype)
            B = lora_sd[key_b].to(device=base_weight.device, dtype=base_weight.dtype)
            scale = strength
        else:
            # Placeholder: rank-1 zeros contribute nothing even if enabled,
            # but keep adapter slots aligned across layers.
            A = torch.zeros(1, base_weight.shape[1], device=base_weight.device, dtype=base_weight.dtype)
            B = torch.zeros(base_weight.shape[0], 1, device=base_weight.device, dtype=base_weight.dtype)
            scale = 0.0

        wrapped.add_adapter(A, B, scale=scale, enabled=default_enabled)
127
+
128
+
129
  @dataclass(frozen=True)
130
  class SingleGPUModelBuilder(Generic[ModelType], ModelBuilderProtocol[ModelType], LoRAAdaptableProtocol):
131
  """
 
197
  ]
198
  final_sd = apply_loras(
199
  model_sd=model_state_dict,
200
+ lora_sd_and_strengths=[lora_sd_and_strengths[0]],
201
  dtype=dtype,
202
  destination_sd=model_state_dict if isinstance(self.registry, DummyRegistry) else None,
203
  )
204
  meta_model.load_state_dict(final_sd.sd, strict=False, assign=True)
205
+ model = self._return_model(meta_model, device)
206
+
207
+ _, affected_modules = apply_loras(
208
+ model_sd=model_state_dict,
209
+ lora_sd_and_strengths=lora_sd_and_strengths,
210
+ dtype=dtype,
211
+ destination_sd=None,
212
+ return_affected=True,
213
+ )
214
+
215
+ for runtime_idx, (lora_sd, strength) in enumerate(zip(lora_state_dicts[1:], lora_strengths[1:], strict=True)):
216
+ patch_only_affected_linears(
217
+ model,
218
+ lora_sd.sd,
219
+ affected_modules,
220
+ strength=strength,
221
+ adapter_idx=runtime_idx,
222
+ default_enabled=False, # start off
223
+ )
224
+
225
+ return model