Mpavan45 commited on
Commit
ada421e
·
verified ·
1 Parent(s): 4eeaf18

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -13
app.py CHANGED
@@ -1,21 +1,20 @@
1
  import streamlit as st
2
  import tensorflow as tf
3
  import numpy as np
 
4
 
5
- # Load the trained model
6
- model = tf.keras.models.load_model("news_classification_rnn1.h5")
 
 
7
 
8
  # Load Preprocessing Function
9
- import dill
10
  with open("preprocessing1.pkl", "rb") as f:
11
  clean_text = dill.load(f)
12
 
13
- # Load Text Vectorization Layer from SavedModel
14
- import pickle
15
-
16
  with open("vector.pkl", "rb") as f:
17
- vectorizer = pickle.load(f)
18
-
19
 
20
  # Define News Categories
21
  news_categories = ["Business", "Sci/Tech", "Sports", "World"]
@@ -31,11 +30,8 @@ if st.button("Classify"):
31
  # Preprocess input
32
  processed_text = clean_text(user_input)
33
 
34
- # Convert text to integer sequence
35
- text_sequence = vectorizer([processed_text])
36
-
37
- # Convert to numpy array (model expects batch input)
38
-
39
 
40
  # Predict Category
41
  prediction = model.predict(text_sequence)
 
1
  import streamlit as st
2
  import tensorflow as tf
3
  import numpy as np
4
+ import dill
5
 
6
+ # Load the trained model with custom layers
7
+ from tensorflow.keras.layers import TextVectorization
8
+ model = tf.keras.models.load_model("news_classification_rnn1.h5",
9
+ custom_objects={"TextVectorization": TextVectorization})
10
 
11
  # Load Preprocessing Function
 
12
  with open("preprocessing1.pkl", "rb") as f:
13
  clean_text = dill.load(f)
14
 
15
+ # Load Text Vectorization Layer
 
 
16
  with open("vector.pkl", "rb") as f:
17
+ vectorizer = dill.load(f)
 
18
 
19
  # Define News Categories
20
  news_categories = ["Business", "Sci/Tech", "Sports", "World"]
 
30
  # Preprocess input
31
  processed_text = clean_text(user_input)
32
 
33
+ # Vectorize input and convert to numpy array
34
+ text_sequence = np.array(vectorizer([processed_text]))
 
 
 
35
 
36
  # Predict Category
37
  prediction = model.predict(text_sequence)