Robys01 commited on
Commit
6492f9e
·
1 Parent(s): e259963

Try this header file

Browse files
Files changed (1) hide show
  1. app.py +26 -8
app.py CHANGED
@@ -1,4 +1,29 @@
1
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import time
3
  import torch
4
  from models import UNet
@@ -10,15 +35,8 @@ from huggingface_hub import hf_hub_download
10
  import tempfile
11
  import requests
12
 
13
- os.environ["HF_HOME"] = "/tmp/huggingface"
14
- os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface"
15
- os.environ["XDG_CACHE_HOME"] = "/tmp/.cache"
16
- os.environ["GRADIO_CACHE_DIR"] = "/tmp/.gradio"
17
-
18
- MODEL_DIR = "/tmp/model"
19
- os.makedirs(MODEL_DIR, exist_ok=True)
20
  MODEL_PATH = os.path.join(MODEL_DIR, "best_unet_model.pth")
21
-
22
  # Download model if missing
23
  def download_model():
24
  print("Starting model download at", time.strftime("%Y-%m-%d %H:%M:%S"))
 
1
  import os
2
+ import pathlib
3
+
4
+ # Hugging Face Hub cache → /tmp
5
+ os.environ["HF_HOME"] = "/tmp/huggingface"
6
+ os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface"
7
+
8
+ # XDG cache (for any other libraries) → /tmp
9
+ os.environ["XDG_CACHE_HOME"] = "/tmp/.cache"
10
+
11
+ # Gradio general cache and examples cache → /tmp
12
+ os.environ["GRADIO_CACHE_DIR"] = "/tmp/.gradio"
13
+ os.environ["GRADIO_EXAMPLES_CACHE"] = "/tmp/.gradio/cached_examples"
14
+
15
+ # Pre-create all those directories so no one falls back to ./ .gradio or model/
16
+ for d in [
17
+ "/tmp/huggingface",
18
+ "/tmp/.cache",
19
+ "/tmp/.gradio",
20
+ os.environ["GRADIO_EXAMPLES_CACHE"],
21
+ ]:
22
+ pathlib.Path(d).mkdir(parents=True, exist_ok=True)
23
+
24
+ MODEL_DIR = "/tmp/model"
25
+ pathlib.Path(MODEL_DIR).mkdir(parents=True, exist_ok=True)
26
+
27
  import time
28
  import torch
29
  from models import UNet
 
35
  import tempfile
36
  import requests
37
 
38
+ # Path to your downloaded model
 
 
 
 
 
 
39
  MODEL_PATH = os.path.join(MODEL_DIR, "best_unet_model.pth")
 
40
  # Download model if missing
41
  def download_model():
42
  print("Starting model download at", time.strftime("%Y-%m-%d %H:%M:%S"))