Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -9,7 +9,7 @@ from pathlib import Path
|
|
| 9 |
import spaces
|
| 10 |
import gradio as gr
|
| 11 |
import torch
|
| 12 |
-
from huggingface_hub import hf_hub_download
|
| 13 |
|
| 14 |
# Import from public LTX-2 package
|
| 15 |
# Install with: pip install git+https://github.com/Lightricks/LTX-2.git
|
|
@@ -21,6 +21,7 @@ DEFAULT_REPO_ID = "Lightricks/LTX-2"
|
|
| 21 |
DEFAULT_GEMMA_REPO_ID = "google/gemma-3-12b-it-qat-q4_0-unquantized"
|
| 22 |
DEFAULT_CHECKPOINT_FILENAME = "ltx-2-19b-dev-fp8.safetensors"
|
| 23 |
|
|
|
|
| 24 |
def get_hub_or_local_checkpoint(repo_id: str, filename: str):
|
| 25 |
"""Download from HuggingFace Hub."""
|
| 26 |
print(f"Downloading {filename} from {repo_id}...")
|
|
@@ -28,19 +29,28 @@ def get_hub_or_local_checkpoint(repo_id: str, filename: str):
|
|
| 28 |
print(f"Downloaded to {ckpt_path}")
|
| 29 |
return ckpt_path
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
# Initialize model ledger and text encoder at startup (load once, keep in memory)
|
| 32 |
print("=" * 80)
|
| 33 |
print("Loading Gemma Text Encoder...")
|
| 34 |
print("=" * 80)
|
| 35 |
|
| 36 |
checkpoint_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_CHECKPOINT_FILENAME)
|
|
|
|
| 37 |
device = "cuda"
|
| 38 |
|
| 39 |
print(f"Initializing text encoder with:")
|
| 40 |
print(f" checkpoint_path={checkpoint_path}")
|
| 41 |
-
print(f" gemma_root={
|
| 42 |
print(f" device={device}")
|
| 43 |
|
|
|
|
| 44 |
model_ledger = ModelLedger(
|
| 45 |
dtype=torch.bfloat16,
|
| 46 |
device=device,
|
|
|
|
| 9 |
import spaces
|
| 10 |
import gradio as gr
|
| 11 |
import torch
|
| 12 |
+
from huggingface_hub import hf_hub_download,snapshot_download
|
| 13 |
|
| 14 |
# Import from public LTX-2 package
|
| 15 |
# Install with: pip install git+https://github.com/Lightricks/LTX-2.git
|
|
|
|
| 21 |
DEFAULT_GEMMA_REPO_ID = "google/gemma-3-12b-it-qat-q4_0-unquantized"
|
| 22 |
DEFAULT_CHECKPOINT_FILENAME = "ltx-2-19b-dev-fp8.safetensors"
|
| 23 |
|
| 24 |
+
|
| 25 |
def get_hub_or_local_checkpoint(repo_id: str, filename: str):
|
| 26 |
"""Download from HuggingFace Hub."""
|
| 27 |
print(f"Downloading {filename} from {repo_id}...")
|
|
|
|
| 29 |
print(f"Downloaded to {ckpt_path}")
|
| 30 |
return ckpt_path
|
| 31 |
|
| 32 |
+
def download_gemma_model(repo_id: str):
|
| 33 |
+
"""Download the full Gemma model directory."""
|
| 34 |
+
print(f"Downloading Gemma model from {repo_id}...")
|
| 35 |
+
local_dir = snapshot_download(repo_id=repo_id)
|
| 36 |
+
print(f"Gemma model downloaded to {local_dir}")
|
| 37 |
+
return local_dir
|
| 38 |
+
|
| 39 |
# Initialize model ledger and text encoder at startup (load once, keep in memory)
|
| 40 |
print("=" * 80)
|
| 41 |
print("Loading Gemma Text Encoder...")
|
| 42 |
print("=" * 80)
|
| 43 |
|
| 44 |
checkpoint_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_CHECKPOINT_FILENAME)
|
| 45 |
+
gemma_local_path = download_gemma_model(DEFAULT_GEMMA_REPO_ID)
|
| 46 |
device = "cuda"
|
| 47 |
|
| 48 |
print(f"Initializing text encoder with:")
|
| 49 |
print(f" checkpoint_path={checkpoint_path}")
|
| 50 |
+
print(f" gemma_root={gemma_local_path}")
|
| 51 |
print(f" device={device}")
|
| 52 |
|
| 53 |
+
|
| 54 |
model_ledger = ModelLedger(
|
| 55 |
dtype=torch.bfloat16,
|
| 56 |
device=device,
|