Mpavan45 commited on
Commit
c8c9751
·
verified ·
1 Parent(s): 697dda9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -10
app.py CHANGED
@@ -1,18 +1,17 @@
1
  import streamlit as st
2
  import tensorflow as tf
3
- import dill # Use dill instead of pickle for preprocessing function
4
  import numpy as np
5
 
6
- # Load Trained Model
7
  model = tf.keras.models.load_model("news_classification_rnn.h5")
8
 
9
  # Load Preprocessing Function
 
10
  with open("preprocessing1.pkl", "rb") as f:
11
  clean_text = dill.load(f)
12
 
13
- # Load Text Vectorization Layer
14
- with open("text_vectorizer (1).pkl", "rb") as f:
15
- vectorizer = dill.load(f)
16
 
17
  # Define News Categories
18
  news_categories = ["Business", "Sci/Tech", "Sports", "World"]
@@ -25,16 +24,16 @@ user_input = st.text_area("Enter News Text:", "")
25
 
26
  if st.button("Classify"):
27
  if user_input.strip():
28
- # Preprocess Input
29
  processed_text = clean_text(user_input)
30
 
31
- # Vectorize Input (Convert text to integer sequence)
32
- text_sequence = vectorizer([processed_text]) # Directly vectorizes text
33
 
34
- # Ensure correct shape (model expects batch input)
35
  text_sequence = np.array(text_sequence)
36
 
37
- # Make Prediction
38
  prediction = model.predict(text_sequence)
39
  category = np.argmax(prediction)
40
 
 
1
  import streamlit as st
2
  import tensorflow as tf
 
3
  import numpy as np
4
 
5
+ # Load the trained model
6
  model = tf.keras.models.load_model("news_classification_rnn.h5")
7
 
8
  # Load Preprocessing Function
9
+ import dill
10
  with open("preprocessing1.pkl", "rb") as f:
11
  clean_text = dill.load(f)
12
 
13
+ # Load Text Vectorization Layer from SavedModel
14
+ vectorizer = tf.saved_model.load("vectorizer")
 
15
 
16
  # Define News Categories
17
  news_categories = ["Business", "Sci/Tech", "Sports", "World"]
 
24
 
25
  if st.button("Classify"):
26
  if user_input.strip():
27
+ # Preprocess input
28
  processed_text = clean_text(user_input)
29
 
30
+ # Convert text to integer sequence
31
+ text_sequence = vectorizer([processed_text])
32
 
33
+ # Convert to numpy array (model expects batch input)
34
  text_sequence = np.array(text_sequence)
35
 
36
+ # Predict Category
37
  prediction = model.predict(text_sequence)
38
  category = np.argmax(prediction)
39