linoyts HF Staff commited on
Commit
a9a3bf1
·
verified ·
1 Parent(s): 0a8b4db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -2
app.py CHANGED
@@ -9,7 +9,7 @@ from pathlib import Path
9
  import spaces
10
  import gradio as gr
11
  import torch
12
- from huggingface_hub import hf_hub_download
13
 
14
  # Import from public LTX-2 package
15
  # Install with: pip install git+https://github.com/Lightricks/LTX-2.git
@@ -21,6 +21,7 @@ DEFAULT_REPO_ID = "Lightricks/LTX-2"
21
  DEFAULT_GEMMA_REPO_ID = "google/gemma-3-12b-it-qat-q4_0-unquantized"
22
  DEFAULT_CHECKPOINT_FILENAME = "ltx-2-19b-dev-fp8.safetensors"
23
 
 
24
  def get_hub_or_local_checkpoint(repo_id: str, filename: str):
25
  """Download from HuggingFace Hub."""
26
  print(f"Downloading {filename} from {repo_id}...")
@@ -28,19 +29,28 @@ def get_hub_or_local_checkpoint(repo_id: str, filename: str):
28
  print(f"Downloaded to {ckpt_path}")
29
  return ckpt_path
30
 
 
 
 
 
 
 
 
31
  # Initialize model ledger and text encoder at startup (load once, keep in memory)
32
  print("=" * 80)
33
  print("Loading Gemma Text Encoder...")
34
  print("=" * 80)
35
 
36
  checkpoint_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_CHECKPOINT_FILENAME)
 
37
  device = "cuda"
38
 
39
  print(f"Initializing text encoder with:")
40
  print(f" checkpoint_path={checkpoint_path}")
41
- print(f" gemma_root={DEFAULT_GEMMA_REPO_ID}")
42
  print(f" device={device}")
43
 
 
44
  model_ledger = ModelLedger(
45
  dtype=torch.bfloat16,
46
  device=device,
 
9
  import spaces
10
  import gradio as gr
11
  import torch
12
+ from huggingface_hub import hf_hub_download,snapshot_download
13
 
14
  # Import from public LTX-2 package
15
  # Install with: pip install git+https://github.com/Lightricks/LTX-2.git
 
21
  DEFAULT_GEMMA_REPO_ID = "google/gemma-3-12b-it-qat-q4_0-unquantized"
22
  DEFAULT_CHECKPOINT_FILENAME = "ltx-2-19b-dev-fp8.safetensors"
23
 
24
+
25
  def get_hub_or_local_checkpoint(repo_id: str, filename: str):
26
  """Download from HuggingFace Hub."""
27
  print(f"Downloading {filename} from {repo_id}...")
 
29
  print(f"Downloaded to {ckpt_path}")
30
  return ckpt_path
31
 
32
+ def download_gemma_model(repo_id: str):
33
+ """Download the full Gemma model directory."""
34
+ print(f"Downloading Gemma model from {repo_id}...")
35
+ local_dir = snapshot_download(repo_id=repo_id)
36
+ print(f"Gemma model downloaded to {local_dir}")
37
+ return local_dir
38
+
39
  # Initialize model ledger and text encoder at startup (load once, keep in memory)
40
  print("=" * 80)
41
  print("Loading Gemma Text Encoder...")
42
  print("=" * 80)
43
 
44
  checkpoint_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_CHECKPOINT_FILENAME)
45
+ gemma_local_path = download_gemma_model(DEFAULT_GEMMA_REPO_ID)
46
  device = "cuda"
47
 
48
  print(f"Initializing text encoder with:")
49
  print(f" checkpoint_path={checkpoint_path}")
50
+ print(f" gemma_root={gemma_local_path}")
51
  print(f" device={device}")
52
 
53
+
54
  model_ledger = ModelLedger(
55
  dtype=torch.bfloat16,
56
  device=device,