JasonYinnnn commited on
Commit
51fb385
·
1 Parent(s): 0e9bb13

asign cache to /data

Browse files
Files changed (1) hide show
  1. app.py +14 -7
app.py CHANGED
@@ -7,6 +7,19 @@ import gradio as gr
7
  import spaces
8
 
9
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  import uuid
11
  from typing import Any, List, Optional, Union
12
  import cv2
@@ -46,7 +59,7 @@ generated_object_map = {}
46
 
47
  ############## 3D-Fixer model
48
  model_dir = 'HorizonRobotics/3D-Fixer'
49
- local_dir = 'checkpoints/3D-Fixer'
50
  snapshot_download(repo_id=model_dir, local_dir=local_dir)
51
  ############## 3D-Fixer model
52
 
@@ -192,12 +205,6 @@ def run_depth_estimation(
192
  ) -> Image.Image:
193
  rgb_image = image_prompts["image"].convert("RGB")
194
 
195
- import torch
196
- print(torch.__version__)
197
- print(torch.version.cuda)
198
- print(torch.__file__)
199
- print(torch.cuda.is_available())
200
-
201
  from threeDFixer.datasets.utils import (
202
  normalize_vertices,
203
  project2ply
 
7
  import spaces
8
 
9
  import os
10
+ # Hugging Face 相关缓存
11
+ os.environ["HF_HOME"] = "/data/.huggingface"
12
+ os.environ["HF_HUB_CACHE"] = "/data/.cache/huggingface/hub"
13
+ os.environ["HF_DATASETS_CACHE"] = "/data/.cache/huggingface/datasets"
14
+ os.environ["TRANSFORMERS_CACHE"] = "/data/.cache/huggingface/hub"
15
+
16
+ # PyTorch / Torch Hub / TorchVision / 部分 timm 权重
17
+ os.environ["TORCH_HOME"] = "/data/.cache/torch"
18
+ os.environ["XDG_CACHE_HOME"] = "/data/.cache"
19
+
20
+ # 关键:PyTorch JIT CUDA/C++ 扩展编译缓存目录
21
+ os.environ["TORCH_EXTENSIONS_DIR"] = "/data/.cache/torch_extensions"
22
+
23
  import uuid
24
  from typing import Any, List, Optional, Union
25
  import cv2
 
59
 
60
  ############## 3D-Fixer model
61
  model_dir = 'HorizonRobotics/3D-Fixer'
62
+ local_dir = "/data/checkpoints/3D-Fixer"
63
  snapshot_download(repo_id=model_dir, local_dir=local_dir)
64
  ############## 3D-Fixer model
65
 
 
205
  ) -> Image.Image:
206
  rgb_image = image_prompts["image"].convert("RGB")
207
 
 
 
 
 
 
 
208
  from threeDFixer.datasets.utils import (
209
  normalize_vertices,
210
  project2ply