mcherif commited on
Commit
61e2d01
·
1 Parent(s): 9625337

Fix cache init and copy full src

Browse files
Files changed (2) hide show
  1. Dockerfile +3 -4
  2. src/app_gradio.py +34 -13
Dockerfile CHANGED
@@ -40,9 +40,8 @@ RUN pip uninstall -y multipart || true
40
  # Copy model and app code
41
  COPY models/vit-finetuned ./models/vit-finetuned
42
 
43
- # Copy FastAPI and Streamlit apps
44
- #COPY ./src/streamlit_app.py ./src/streamlit_app.py
45
- COPY ./src/app_gradio.py ./src/app_gradio.py
46
 
47
  # Copy assets
48
  COPY ./images/plant-disease-logo.png ./images/plant-disease-logo.png
@@ -63,4 +62,4 @@ EXPOSE 7860
63
  #CMD streamlit run src/streamlit_app.py --server.port=7860 --server.address=0.0.0.0
64
  # To run FastAPI, comment out the above CMD and uncomment below:
65
  #CMD python -c "import sys, os; print('Python:', sys.version); print('Env:', dict(os.environ)); os.system('ls -lR /app');" && uvicorn src.inference.app_fastapi:app --host 0.0.0.0 --port 7860
66
- CMD python src/app_gradio.py
 
40
  # Copy model and app code
41
  COPY models/vit-finetuned ./models/vit-finetuned
42
 
43
+ # Copy application source files
44
+ COPY ./src/ ./src/
 
45
 
46
  # Copy assets
47
  COPY ./images/plant-disease-logo.png ./images/plant-disease-logo.png
 
62
  #CMD streamlit run src/streamlit_app.py --server.port=7860 --server.address=0.0.0.0
63
  # To run FastAPI, comment out the above CMD and uncomment below:
64
  #CMD python -c "import sys, os; print('Python:', sys.version); print('Env:', dict(os.environ)); os.system('ls -lR /app');" && uvicorn src.inference.app_fastapi:app --host 0.0.0.0 --port 7860
65
+ CMD python src/app_gradio.py
src/app_gradio.py CHANGED
@@ -1,13 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  from PIL import Image
3
  import torch
4
- import os
5
- import json
6
- import base64
7
  from transformers import AutoImageProcessor, AutoModelForImageClassification
8
- import sys
 
 
9
  from model_paths import resolve_model_dir
10
 
 
 
 
 
11
  # Debug flag: set to True to enable, False to disable
12
  DEBUG = False
13
 
@@ -133,12 +151,15 @@ with gr.Blocks() as demo:
133
  gr.Markdown(f"<sub>{logo_dbg}</sub>")
134
 
135
  if __name__ == "__main__":
136
- LOCAL = "--local" in sys.argv
137
-
138
- if LOCAL:
139
- # Local dev: bind to loopback so the browser check works on Windows
140
- demo.launch(server_name="127.0.0.1", server_port=7860)
141
- else:
142
- # HF Spaces / containers: bind to 0.0.0.0 and use provided PORT
143
- port = int(os.getenv("PORT", "7860"))
144
- demo.launch(server_name="0.0.0.0", server_port=port)
 
 
 
 
1
+ import argparse
2
+ import base64
3
+ import json
4
+ import os
5
+ import sys
6
+
7
+ CURR_DIR = os.path.dirname(os.path.abspath(__file__))
8
+ if CURR_DIR not in sys.path:
9
+ sys.path.insert(0, CURR_DIR)
10
+
11
+ cache_root = os.path.abspath(os.path.join(CURR_DIR, "..", ".cache", "huggingface"))
12
+ os.makedirs(cache_root, exist_ok=True)
13
+ os.environ.setdefault("HF_HOME", cache_root)
14
+ os.environ.setdefault("TRANSFORMERS_CACHE", cache_root)
15
+
16
  import gradio as gr
17
  from PIL import Image
18
  import torch
 
 
 
19
  from transformers import AutoImageProcessor, AutoModelForImageClassification
20
+
21
+ SPACE_ENV_VARS = ("SPACE_ID", "HF_SPACE_ID", "REPO_ID")
22
+
23
  from model_paths import resolve_model_dir
24
 
25
+
26
+ def running_in_space() -> bool:
27
+ return any(os.getenv(var) for var in SPACE_ENV_VARS)
28
+
29
  # Debug flag: set to True to enable, False to disable
30
  DEBUG = False
31
 
 
151
  gr.Markdown(f"<sub>{logo_dbg}</sub>")
152
 
153
  if __name__ == "__main__":
154
+ parser = argparse.ArgumentParser(description="Launch the Plant Disease Classifier Gradio interface.")
155
+ parser.add_argument("--local", action="store_true", help="Force binding to loopback for local dev.")
156
+ parser.add_argument("--share", action="store_true", help="Create a public Gradio share URL.")
157
+ parser.add_argument("--port", type=int, help="Port to bind the Gradio server to.")
158
+ parser.add_argument("--host", help="Override the server host/IP address.")
159
+ args = parser.parse_args()
160
+
161
+ port = args.port or int(os.getenv("PORT", "7860"))
162
+ default_host = "127.0.0.1" if args.local or not running_in_space() else "0.0.0.0"
163
+ host = args.host or os.getenv("GRADIO_SERVER_NAME") or default_host
164
+
165
+ demo.launch(server_name=host, server_port=port, share=args.share)