stanlys96 commited on
Commit
d91171b
·
verified ·
1 Parent(s): 21a9c98

Upload 6 files

Browse files
Files changed (1) hide show
  1. prediction.py +7 -3
prediction.py CHANGED
@@ -17,13 +17,17 @@ def get_feeling(number):
17
  feeling = number_to_feeling.get(str(number), "Unknown feeling")
18
  return feeling
19
 
20
- with st.spinner("Loading the model, please wait..."):
21
- the_model = tf.keras.models.load_model('model.keras', custom_objects={'KerasLayer': tf_hub.KerasLayer})
 
22
 
23
  def app():
24
  st.header('Prediction', divider='rainbow')
25
 
26
  user_input = st.text_input("Enter your text here:")
 
 
 
27
  if st.button('Predict', type="secondary"):
28
  data = {
29
  "text_processed": [
@@ -33,7 +37,7 @@ def app():
33
  df = pd.DataFrame(data)
34
  with st.spinner("Making prediction..."):
35
  # Replace with your preprocessing and prediction code
36
- predictions = the_model.predict(df)
37
  predicted_class = np.argmax(predictions, axis=1)
38
  the_sentiment = predicted_class[0]
39
 
 
17
  feeling = number_to_feeling.get(str(number), "Unknown feeling")
18
  return feeling
19
 
20
+ # Load the model function
21
+ def load_model():
22
+ return tf.keras.models.load_model('model.keras', custom_objects={'KerasLayer': tf_hub.KerasLayer})
23
 
24
  def app():
25
  st.header('Prediction', divider='rainbow')
26
 
27
  user_input = st.text_input("Enter your text here:")
28
+ if 'the_model' not in st.session_state:
29
+ with st.spinner("Loading the model, please wait..."):
30
+ st.session_state.the_model = load_model()
31
  if st.button('Predict', type="secondary"):
32
  data = {
33
  "text_processed": [
 
37
  df = pd.DataFrame(data)
38
  with st.spinner("Making prediction..."):
39
  # Replace with your preprocessing and prediction code
40
+ predictions = st.session_state.the_model.predict(df)
41
  predicted_class = np.argmax(predictions, axis=1)
42
  the_sentiment = predicted_class[0]
43