mcherif commited on
Commit
d7b1b94
·
1 Parent(s): efd28a7

Download shared model snapshot when missing

Browse files
Files changed (2) hide show
  1. src/app_gradio.py +2 -2
  2. src/model_paths.py +60 -10
src/app_gradio.py CHANGED
@@ -36,8 +36,8 @@ def _select_cache_root() -> str:
36
 
37
 
38
  cache_root = _select_cache_root()
39
- os.environ.setdefault("HF_HOME", cache_root)
40
- os.environ.setdefault("TRANSFORMERS_CACHE", cache_root)
41
 
42
  import gradio as gr
43
  from PIL import Image
 
36
 
37
 
38
  cache_root = _select_cache_root()
39
+ os.environ["HF_HOME"] = cache_root
40
+ os.environ["TRANSFORMERS_CACHE"] = cache_root
41
 
42
  import gradio as gr
43
  from PIL import Image
src/model_paths.py CHANGED
@@ -1,4 +1,9 @@
1
  import os
 
 
 
 
 
2
 
3
 
4
  def resolve_model_dir(require_weights: bool = True) -> str:
@@ -20,6 +25,30 @@ def resolve_model_dir(require_weights: bool = True) -> str:
20
 
21
  repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  shared_dir = os.path.abspath(
24
  os.path.join(
25
  repo_root,
@@ -32,16 +61,37 @@ def resolve_model_dir(require_weights: bool = True) -> str:
32
  shared_model = os.path.join(shared_dir, "model.safetensors")
33
  if os.path.exists(shared_model):
34
  return shared_dir
 
35
 
36
- local_dir = os.path.join(repo_root, "models", "vit-finetuned")
37
- if require_weights:
38
- local_model = os.path.join(local_dir, "model.safetensors")
39
- if os.path.exists(local_model):
40
- return local_dir
41
- raise FileNotFoundError(
42
- "model.safetensors not found in shared or local model directories. "
43
- "Set MODEL_DIR to point at a valid model location."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  )
 
 
45
 
46
- os.makedirs(local_dir, exist_ok=True)
47
- return local_dir
 
 
1
  import os
2
+ import tempfile
3
+ from pathlib import Path
4
+ from typing import Optional
5
+
6
+ from huggingface_hub import snapshot_download
7
 
8
 
9
  def resolve_model_dir(require_weights: bool = True) -> str:
 
25
 
26
  repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
27
 
28
+ shared_dir = _resolve_shared_repo_path(repo_root)
29
+ if shared_dir:
30
+ return shared_dir
31
+
32
+ local_dir = os.path.join(repo_root, "models", "vit-finetuned")
33
+ if require_weights:
34
+ local_model = os.path.join(local_dir, "model.safetensors")
35
+ if os.path.exists(local_model):
36
+ return local_dir
37
+
38
+ remote_dir = _download_remote_model()
39
+ if remote_dir:
40
+ return remote_dir
41
+
42
+ raise FileNotFoundError(
43
+ "model.safetensors not found in shared or local model directories. "
44
+ "Set MODEL_DIR to point at a valid model location."
45
+ )
46
+
47
+ os.makedirs(local_dir, exist_ok=True)
48
+ return local_dir
49
+
50
+
51
+ def _resolve_shared_repo_path(repo_root: str) -> Optional[str]:
52
  shared_dir = os.path.abspath(
53
  os.path.join(
54
  repo_root,
 
61
  shared_model = os.path.join(shared_dir, "model.safetensors")
62
  if os.path.exists(shared_model):
63
  return shared_dir
64
+ return None
65
 
66
+
67
+ def _download_remote_model() -> Optional[str]:
68
+ repo_id = os.getenv("MODEL_REPO_ID", "mcherif/Plant-Disease-RAG-Assistant")
69
+ repo_type = os.getenv("MODEL_REPO_TYPE", "space")
70
+ subdir = os.getenv("MODEL_REPO_SUBDIR", "models/vit-finetuned")
71
+
72
+ if not repo_id:
73
+ return None
74
+
75
+ cache_dir = os.getenv("MODEL_REPO_CACHE_DIR")
76
+ if cache_dir:
77
+ cache_path = Path(cache_dir)
78
+ else:
79
+ cache_path = Path(os.getenv("HF_HOME", tempfile.gettempdir()))
80
+ cache_path = cache_path / "model_snapshots"
81
+
82
+ cache_path.mkdir(parents=True, exist_ok=True)
83
+
84
+ try:
85
+ snapshot_path = snapshot_download(
86
+ repo_id=repo_id,
87
+ repo_type=repo_type,
88
+ allow_patterns=[f"{subdir}/*"],
89
+ local_dir=str(cache_path),
90
+ local_dir_use_symlinks=False,
91
  )
92
+ except Exception:
93
+ return None
94
 
95
+ candidate = Path(snapshot_path) / subdir
96
+ model_file = candidate / "model.safetensors"
97
+ return str(candidate) if model_file.exists() else None