Update app.py
Browse files
app.py
CHANGED
|
@@ -12,12 +12,15 @@ from typing import Optional
|
|
| 12 |
from huggingface_hub import hf_hub_download
|
| 13 |
from ltx_pipelines.distilled import DistilledPipeline
|
| 14 |
from ltx_core.tiling import TilingConfig
|
|
|
|
|
|
|
| 15 |
from ltx_pipelines.constants import (
|
| 16 |
DEFAULT_SEED,
|
| 17 |
DEFAULT_HEIGHT,
|
| 18 |
DEFAULT_WIDTH,
|
| 19 |
DEFAULT_NUM_FRAMES,
|
| 20 |
DEFAULT_FRAME_RATE,
|
|
|
|
| 21 |
)
|
| 22 |
|
| 23 |
# Default prompt from docstring example
|
|
@@ -26,7 +29,8 @@ DEFAULT_PROMPT = "An astronaut hatches from a fragile egg on the surface of the
|
|
| 26 |
# HuggingFace Hub defaults
|
| 27 |
DEFAULT_REPO_ID = "LTX-Colab/LTX-Video-Preview"
|
| 28 |
DEFAULT_GEMMA_REPO_ID = "google/gemma-3-12b-it-qat-q4_0-unquantized"
|
| 29 |
-
DEFAULT_CHECKPOINT_FILENAME = "ltx-2-19b-
|
|
|
|
| 30 |
DEFAULT_SPATIAL_UPSAMPLER_FILENAME = "ltx-2-spatial-upscaler-x2-1.0-rc1.safetensors"
|
| 31 |
|
| 32 |
def get_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None):
|
|
@@ -52,18 +56,29 @@ print("Loading LTX-2 Distilled pipeline...")
|
|
| 52 |
print("=" * 80)
|
| 53 |
|
| 54 |
checkpoint_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_CHECKPOINT_FILENAME)
|
|
|
|
| 55 |
spatial_upsampler_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_SPATIAL_UPSAMPLER_FILENAME)
|
| 56 |
|
| 57 |
print(f"Initializing pipeline with:")
|
| 58 |
print(f" checkpoint_path={checkpoint_path}")
|
|
|
|
| 59 |
print(f" spatial_upsampler_path={spatial_upsampler_path}")
|
| 60 |
print(f" gemma_root={DEFAULT_GEMMA_REPO_ID}")
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
pipeline = DistilledPipeline(
|
| 63 |
checkpoint_path=checkpoint_path,
|
| 64 |
spatial_upsampler_path=spatial_upsampler_path,
|
| 65 |
gemma_root=DEFAULT_GEMMA_REPO_ID,
|
| 66 |
-
loras=
|
| 67 |
fp8transformer=False,
|
| 68 |
)
|
| 69 |
|
|
@@ -224,4 +239,4 @@ with gr.Blocks(title="LTX-2 Distilled Image-to-Video") as demo:
|
|
| 224 |
|
| 225 |
|
| 226 |
if __name__ == "__main__":
|
| 227 |
-
demo.launch(theme=gr.themes.Citrus()
|
|
|
|
| 12 |
from huggingface_hub import hf_hub_download
|
| 13 |
from ltx_pipelines.distilled import DistilledPipeline
|
| 14 |
from ltx_core.tiling import TilingConfig
|
| 15 |
+
from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
|
| 16 |
+
from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
|
| 17 |
from ltx_pipelines.constants import (
|
| 18 |
DEFAULT_SEED,
|
| 19 |
DEFAULT_HEIGHT,
|
| 20 |
DEFAULT_WIDTH,
|
| 21 |
DEFAULT_NUM_FRAMES,
|
| 22 |
DEFAULT_FRAME_RATE,
|
| 23 |
+
DEFAULT_LORA_STRENGTH,
|
| 24 |
)
|
| 25 |
|
| 26 |
# Default prompt from docstring example
|
|
|
|
| 29 |
# HuggingFace Hub defaults
|
| 30 |
DEFAULT_REPO_ID = "LTX-Colab/LTX-Video-Preview"
|
| 31 |
DEFAULT_GEMMA_REPO_ID = "google/gemma-3-12b-it-qat-q4_0-unquantized"
|
| 32 |
+
DEFAULT_CHECKPOINT_FILENAME = "ltx-2-19b-dev-rc1.safetensors"
|
| 33 |
+
DEFAULT_DISTILLED_LORA_FILENAME = "ltx-2-19b-distilled-lora-384-rc1.safetensors"
|
| 34 |
DEFAULT_SPATIAL_UPSAMPLER_FILENAME = "ltx-2-spatial-upscaler-x2-1.0-rc1.safetensors"
|
| 35 |
|
| 36 |
def get_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None):
|
|
|
|
| 56 |
print("=" * 80)
|
| 57 |
|
| 58 |
checkpoint_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_CHECKPOINT_FILENAME)
|
| 59 |
+
distilled_lora_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_DISTILLED_LORA_FILENAME)
|
| 60 |
spatial_upsampler_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_SPATIAL_UPSAMPLER_FILENAME)
|
| 61 |
|
| 62 |
print(f"Initializing pipeline with:")
|
| 63 |
print(f" checkpoint_path={checkpoint_path}")
|
| 64 |
+
print(f" distilled_lora_path={distilled_lora_path}")
|
| 65 |
print(f" spatial_upsampler_path={spatial_upsampler_path}")
|
| 66 |
print(f" gemma_root={DEFAULT_GEMMA_REPO_ID}")
|
| 67 |
|
| 68 |
+
# Load distilled LoRA as a regular LoRA
|
| 69 |
+
loras = [
|
| 70 |
+
LoraPathStrengthAndSDOps(
|
| 71 |
+
path=distilled_lora_path,
|
| 72 |
+
strength=DEFAULT_LORA_STRENGTH,
|
| 73 |
+
sd_ops=LTXV_LORA_COMFY_RENAMING_MAP,
|
| 74 |
+
)
|
| 75 |
+
]
|
| 76 |
+
|
| 77 |
pipeline = DistilledPipeline(
|
| 78 |
checkpoint_path=checkpoint_path,
|
| 79 |
spatial_upsampler_path=spatial_upsampler_path,
|
| 80 |
gemma_root=DEFAULT_GEMMA_REPO_ID,
|
| 81 |
+
loras=loras,
|
| 82 |
fp8transformer=False,
|
| 83 |
)
|
| 84 |
|
|
|
|
| 239 |
|
| 240 |
|
| 241 |
if __name__ == "__main__":
|
| 242 |
+
demo.launch(theme=gr.themes.Citrus())
|