mcherif commited on
Commit
9625337
·
1 Parent(s): 98ae70f

Share vit checkpoint with RAG assistant project

Browse files
models/vit-finetuned/model.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b5249456310ed9b3884807fc9a172980c3950ec008aa89dadea0d97e5d9da5e4
3
- size 343334720
 
 
 
 
src/app_gradio.py CHANGED
@@ -6,6 +6,7 @@ import json
6
  import base64
7
  from transformers import AutoImageProcessor, AutoModelForImageClassification
8
  import sys
 
9
 
10
  # Debug flag: set to True to enable, False to disable
11
  DEBUG = False
@@ -16,7 +17,7 @@ def dbg(msg: str):
16
  print(f"[DEBUG] {msg}")
17
 
18
 
19
- MODEL_DIR = "models/vit-finetuned"
20
 
21
  # Build a safe path to the logo and expose it via Gradio's file= URL
22
  LOGO_REL = "images/plant-disease-logo.png"
 
6
  import base64
7
  from transformers import AutoImageProcessor, AutoModelForImageClassification
8
  import sys
9
+ from model_paths import resolve_model_dir
10
 
11
  # Debug flag: set to True to enable, False to disable
12
  DEBUG = False
 
17
  print(f"[DEBUG] {msg}")
18
 
19
 
20
+ MODEL_DIR = resolve_model_dir()
21
 
22
  # Build a safe path to the logo and expose it via Gradio's file= URL
23
  LOGO_REL = "images/plant-disease-logo.png"
src/model_paths.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+
4
+ def resolve_model_dir() -> str:
5
+ """
6
+ Determine which vit-finetuned folder to use.
7
+ Prefers the shared copy in HF-Plant-Disease-RAG-Assistant if present,
8
+ otherwise falls back to this repository's models directory.
9
+ """
10
+ override = os.getenv("MODEL_DIR")
11
+ if override:
12
+ return os.path.abspath(override)
13
+
14
+ repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
15
+
16
+ shared_dir = os.path.abspath(
17
+ os.path.join(
18
+ repo_root,
19
+ "..",
20
+ "HF-Plant-Disease-RAG-Assistant",
21
+ "models",
22
+ "vit-finetuned",
23
+ )
24
+ )
25
+ shared_model = os.path.join(shared_dir, "model.safetensors")
26
+ if os.path.exists(shared_model):
27
+ return shared_dir
28
+
29
+ local_dir = os.path.join(repo_root, "models", "vit-finetuned")
30
+ local_model = os.path.join(local_dir, "model.safetensors")
31
+ if os.path.exists(local_model):
32
+ return local_dir
33
+
34
+ raise FileNotFoundError(
35
+ "model.safetensors not found in shared or local model directories."
36
+ )
src/streamlit_app.py CHANGED
@@ -14,9 +14,10 @@ else:
14
  import os
15
  import json
16
  from transformers import AutoImageProcessor, AutoModelForImageClassification
 
17
 
18
  # Model directory
19
- MODEL_DIR = "models/vit-finetuned"
20
 
21
  st.title("Plant Disease Classifier")
22
  st.write("App is starting...")
 
14
  import os
15
  import json
16
  from transformers import AutoImageProcessor, AutoModelForImageClassification
17
+ from model_paths import resolve_model_dir
18
 
19
  # Model directory
20
+ MODEL_DIR = resolve_model_dir()
21
 
22
  st.title("Plant Disease Classifier")
23
  st.write("App is starting...")