StatusReport commited on
Commit
593bb6a
·
verified ·
1 Parent(s): 423ba6e

App: overcome 50GB limit by using mounted storage

Browse files
Files changed (1) hide show
  1. app.py +7 -19
app.py CHANGED
@@ -6,11 +6,6 @@ import sys
6
  os.environ["TORCH_COMPILE_DISABLE"] = "1"
7
  os.environ["TORCHDYNAMO_DISABLE"] = "1"
8
 
9
- # Place caches in persistent storage
10
- os.environ["HF_HOME"] = "/data/.huggingface"
11
- os.environ["HF_HUB_CACHE"] = "/data/.cache/huggingface/hub"
12
- os.environ["HF_DATASETS_CACHE"] = "/data/.cache/huggingface/datasets"
13
-
14
  # Install xformers for memory-efficient attention
15
  subprocess.run([sys.executable, "-m", "pip", "install", "xformers==0.0.32.post2", "--no-build-isolation"], check=False)
16
 
@@ -86,22 +81,15 @@ RESOLUTIONS = {
86
  "low": {"16:9": (768, 512), "9:16": (512, 768), "1:1": (768, 768)},
87
  }
88
 
89
- # Model repos
90
- LTX_MODEL_REPO = "Lightricks/LTX-2.3"
91
- GEMMA_REPO = "google/gemma-3-12b-it-qat-q4_0-unquantized"
92
-
93
- # Download model checkpoints
94
- print("=" * 80)
95
- print("Downloading LTX-2.3 distilled model + Gemma...")
96
- print("=" * 80)
97
 
98
- checkpoint_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-distilled-1.1.safetensors")
99
- spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
100
- gemma_root = snapshot_download(repo_id=GEMMA_REPO)
101
 
102
- print(f"Checkpoint: {checkpoint_path}")
103
- print(f"Spatial upsampler: {spatial_upsampler_path}")
104
- print(f"Gemma root: {gemma_root}")
105
 
106
  # Initialize pipeline WITH text encoder
107
  pipeline = DistilledPipeline(
 
6
  os.environ["TORCH_COMPILE_DISABLE"] = "1"
7
  os.environ["TORCHDYNAMO_DISABLE"] = "1"
8
 
 
 
 
 
 
9
  # Install xformers for memory-efficient attention
10
  subprocess.run([sys.executable, "-m", "pip", "install", "xformers==0.0.32.post2", "--no-build-isolation"], check=False)
11
 
 
81
  "low": {"16:9": (768, 512), "9:16": (512, 768), "1:1": (768, 768)},
82
  }
83
 
84
+ LTX_MOUNT = "/models/ltx"
85
+ GEMMA_MOUNT = "/models/gemma"
 
 
 
 
 
 
86
 
87
+ DISTILLED_FILENAME = "ltx-2.3-22b-distilled.safetensors"
88
+ UPSCALER_FILENAME = "ltx-2.3-spatial-upscaler-x2-1.0.safetensors"
 
89
 
90
+ ltx_checkpoint = os.path.join(LTX_MOUNT, DISTILLED_FILENAME)
91
+ ltx_upscaler = os.path.join(LTX_MOUNT, UPSCALER_FILENAME)
92
+ gemma_root = GEMMA_MOUNT
93
 
94
  # Initialize pipeline WITH text encoder
95
  pipeline = DistilledPipeline(