linoyts HF Staff commited on
Commit
fb55646
·
verified ·
1 Parent(s): 6723455

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -3
app.py CHANGED
@@ -12,12 +12,15 @@ from typing import Optional
12
  from huggingface_hub import hf_hub_download
13
  from ltx_pipelines.distilled import DistilledPipeline
14
  from ltx_core.tiling import TilingConfig
 
 
15
  from ltx_pipelines.constants import (
16
  DEFAULT_SEED,
17
  DEFAULT_HEIGHT,
18
  DEFAULT_WIDTH,
19
  DEFAULT_NUM_FRAMES,
20
  DEFAULT_FRAME_RATE,
 
21
  )
22
 
23
  # Default prompt from docstring example
@@ -26,7 +29,8 @@ DEFAULT_PROMPT = "An astronaut hatches from a fragile egg on the surface of the
26
  # HuggingFace Hub defaults
27
  DEFAULT_REPO_ID = "LTX-Colab/LTX-Video-Preview"
28
  DEFAULT_GEMMA_REPO_ID = "google/gemma-3-12b-it-qat-q4_0-unquantized"
29
- DEFAULT_CHECKPOINT_FILENAME = "ltx-2-19b-distilled-rc1.safetensors"
 
30
  DEFAULT_SPATIAL_UPSAMPLER_FILENAME = "ltx-2-spatial-upscaler-x2-1.0-rc1.safetensors"
31
 
32
  def get_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None):
@@ -52,18 +56,29 @@ print("Loading LTX-2 Distilled pipeline...")
52
  print("=" * 80)
53
 
54
  checkpoint_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_CHECKPOINT_FILENAME)
 
55
  spatial_upsampler_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_SPATIAL_UPSAMPLER_FILENAME)
56
 
57
  print(f"Initializing pipeline with:")
58
  print(f" checkpoint_path={checkpoint_path}")
 
59
  print(f" spatial_upsampler_path={spatial_upsampler_path}")
60
  print(f" gemma_root={DEFAULT_GEMMA_REPO_ID}")
61
 
 
 
 
 
 
 
 
 
 
62
  pipeline = DistilledPipeline(
63
  checkpoint_path=checkpoint_path,
64
  spatial_upsampler_path=spatial_upsampler_path,
65
  gemma_root=DEFAULT_GEMMA_REPO_ID,
66
- loras=[],
67
  fp8transformer=False,
68
  )
69
 
@@ -224,4 +239,4 @@ with gr.Blocks(title="LTX-2 Distilled Image-to-Video") as demo:
224
 
225
 
226
  if __name__ == "__main__":
227
- demo.launch(theme=gr.themes.Citrus(), share=True)
 
12
  from huggingface_hub import hf_hub_download
13
  from ltx_pipelines.distilled import DistilledPipeline
14
  from ltx_core.tiling import TilingConfig
15
+ from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
16
+ from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
17
  from ltx_pipelines.constants import (
18
  DEFAULT_SEED,
19
  DEFAULT_HEIGHT,
20
  DEFAULT_WIDTH,
21
  DEFAULT_NUM_FRAMES,
22
  DEFAULT_FRAME_RATE,
23
+ DEFAULT_LORA_STRENGTH,
24
  )
25
 
26
  # Default prompt from docstring example
 
29
  # HuggingFace Hub defaults
30
  DEFAULT_REPO_ID = "LTX-Colab/LTX-Video-Preview"
31
  DEFAULT_GEMMA_REPO_ID = "google/gemma-3-12b-it-qat-q4_0-unquantized"
32
+ DEFAULT_CHECKPOINT_FILENAME = "ltx-2-19b-dev-rc1.safetensors"
33
+ DEFAULT_DISTILLED_LORA_FILENAME = "ltx-2-19b-distilled-lora-384-rc1.safetensors"
34
  DEFAULT_SPATIAL_UPSAMPLER_FILENAME = "ltx-2-spatial-upscaler-x2-1.0-rc1.safetensors"
35
 
36
  def get_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None):
 
56
  print("=" * 80)
57
 
58
  checkpoint_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_CHECKPOINT_FILENAME)
59
+ distilled_lora_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_DISTILLED_LORA_FILENAME)
60
  spatial_upsampler_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_SPATIAL_UPSAMPLER_FILENAME)
61
 
62
  print(f"Initializing pipeline with:")
63
  print(f" checkpoint_path={checkpoint_path}")
64
+ print(f" distilled_lora_path={distilled_lora_path}")
65
  print(f" spatial_upsampler_path={spatial_upsampler_path}")
66
  print(f" gemma_root={DEFAULT_GEMMA_REPO_ID}")
67
 
68
+ # Load distilled LoRA as a regular LoRA
69
+ loras = [
70
+ LoraPathStrengthAndSDOps(
71
+ path=distilled_lora_path,
72
+ strength=DEFAULT_LORA_STRENGTH,
73
+ sd_ops=LTXV_LORA_COMFY_RENAMING_MAP,
74
+ )
75
+ ]
76
+
77
  pipeline = DistilledPipeline(
78
  checkpoint_path=checkpoint_path,
79
  spatial_upsampler_path=spatial_upsampler_path,
80
  gemma_root=DEFAULT_GEMMA_REPO_ID,
81
+ loras=loras,
82
  fp8transformer=False,
83
  )
84
 
 
239
 
240
 
241
  if __name__ == "__main__":
242
+ demo.launch(theme=gr.themes.Citrus())