zferd commited on
Commit
ca8354c
·
verified ·
1 Parent(s): 371065d

Upload streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +15 -28
streamlit_app.py CHANGED
@@ -11,7 +11,6 @@ import pandas as pd
11
  import streamlit as st
12
  from PIL import Image, ImageOps
13
  import tensorflow as tf
14
- from huggingface_hub import hf_hub_download
15
  from tensorflow.keras.applications.resnet50 import preprocess_input
16
 
17
  tf.get_logger().setLevel("ERROR")
@@ -28,10 +27,9 @@ if "upload" not in st.session_state:
28
  if "probs" not in st.session_state:
29
  st.session_state.probs = None
30
 
31
- # ---- Hugging Face Repo Details --- #
32
- REPO_ID = "zferd/welding-defect"
33
- MODEL_FILENAME = "model/final_single_phase.h5"
34
- CONFIG_FILENAME = "model/training_config.json"
35
  IMG_SIZE = (224, 224)
36
 
37
  # ---- Pretty display labels
@@ -75,22 +73,12 @@ class CastLayer(tf.keras.layers.Layer):
75
 
76
 
77
  @st.cache_resource
78
- def load_model_and_config_from_hub():
79
- """Downloads files from the Hub and loads model and config."""
80
- # Get token from environment (set in HF Space secrets if repo is private)
81
- token = os.environ.get("HF_TOKEN")
82
-
83
- # Download model and config files with token
84
- model_path = hf_hub_download(
85
- repo_id=REPO_ID,
86
- filename=MODEL_FILENAME,
87
- token=token,
88
- )
89
- config_path = hf_hub_download(
90
- repo_id=REPO_ID,
91
- filename=CONFIG_FILENAME,
92
- token=token,
93
- )
94
 
95
  # Load the Keras model with custom_objects so that 'Cast' is known
96
  custom_objects = {
@@ -98,13 +86,13 @@ def load_model_and_config_from_hub():
98
  }
99
 
100
  model = tf.keras.models.load_model(
101
- model_path,
102
  compile=False,
103
  custom_objects=custom_objects,
104
  )
105
 
106
  # Load class names from the config file
107
- with open(config_path, "r") as f:
108
  cfg = json.load(f)
109
  class_names = cfg.get("class_names", ["CR", "LP", "ND", "PO"]) # Fallback
110
 
@@ -145,14 +133,13 @@ def upload_cb():
145
  def weld():
146
  st.title("🔎 Weld Defect Classifier")
147
 
148
- # Load resources from the Hub
149
- model, class_names = None, None
150
  try:
151
- model, class_names = load_model_and_config_from_hub()
152
  except Exception as e:
153
- st.error(f"Error loading model from Hugging Face Hub: {str(e)}")
154
- st.error("Make sure the model repository is accessible and HF_TOKEN is set if needed.")
155
  st.stop()
 
156
 
157
  st.file_uploader(
158
  "Upload an image",
 
11
  import streamlit as st
12
  from PIL import Image, ImageOps
13
  import tensorflow as tf
 
14
  from tensorflow.keras.applications.resnet50 import preprocess_input
15
 
16
  tf.get_logger().setLevel("ERROR")
 
27
  if "probs" not in st.session_state:
28
  st.session_state.probs = None
29
 
30
+ # ---- Local model file paths (inside THIS Space repo) --- #
31
+ MODEL_PATH = "model/final_single_phase.h5"
32
+ CONFIG_PATH = "model/training_config.json"
 
33
  IMG_SIZE = (224, 224)
34
 
35
  # ---- Pretty display labels
 
73
 
74
 
75
  @st.cache_resource
76
+ def load_model_and_config():
77
+ """Loads model + config from local files inside the Space."""
78
+ if not os.path.exists(MODEL_PATH):
79
+ raise FileNotFoundError(f"Model file not found at: {MODEL_PATH}")
80
+ if not os.path.exists(CONFIG_PATH):
81
+ raise FileNotFoundError(f"Config file not found at: {CONFIG_PATH}")
 
 
 
 
 
 
 
 
 
 
82
 
83
  # Load the Keras model with custom_objects so that 'Cast' is known
84
  custom_objects = {
 
86
  }
87
 
88
  model = tf.keras.models.load_model(
89
+ MODEL_PATH,
90
  compile=False,
91
  custom_objects=custom_objects,
92
  )
93
 
94
  # Load class names from the config file
95
+ with open(CONFIG_PATH, "r") as f:
96
  cfg = json.load(f)
97
  class_names = cfg.get("class_names", ["CR", "LP", "ND", "PO"]) # Fallback
98
 
 
133
  def weld():
134
  st.title("🔎 Weld Defect Classifier")
135
 
136
+ # Load resources from local files
 
137
  try:
138
+ model, class_names = load_model_and_config()
139
  except Exception as e:
140
+ st.error(f"Error loading model/config: {str(e)}")
 
141
  st.stop()
142
+ return
143
 
144
  st.file_uploader(
145
  "Upload an image",