Mpavan45 commited on
Commit
fd77546
·
verified ·
1 Parent(s): 9835fcd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -20
app.py CHANGED
@@ -1,23 +1,21 @@
1
  import streamlit as st
2
  import tensorflow as tf
3
- import pickle
4
  import numpy as np
5
- import pandas as pd
6
 
7
-
8
- # Load Model
9
  model = tf.keras.models.load_model("news_classification_rnn.h5")
10
 
11
  # Load Preprocessing Function
12
  with open("preprocessing1.pkl", "rb") as f:
13
- clean_text = pickle.load(f)
14
 
15
- # Load TF-IDF Vectorizer
16
- with open("text_vectorizer(1)).pkl", "rb") as f:
17
- vectorizer = pickle.load(f)
18
 
19
  # Define News Categories
20
- news_categories = ["Business", "Sci/Tech","Sports","World"]
21
 
22
  # Streamlit UI
23
  st.title("📰 News Classification with Simple RNN")
@@ -26,20 +24,20 @@ st.write("Enter a news headline to predict its category.")
26
  user_input = st.text_area("Enter News Text:", "")
27
 
28
  if st.button("Classify"):
29
- if user_input:
30
- # # Preprocess Input
31
- # processed_text = clean_text(user_input)
32
 
33
- # Convert text to integer sequence
34
- text_sequence = tokenizer.texts_to_sequences([processed_text])
35
 
36
- # Pad the sequence to match model input size
37
- text_padded = tf.keras.preprocessing.sequence.pad_sequences(text_sequence, maxlen=100)
38
 
39
- # Prediction
40
- prediction = model.predict(text_padded)
41
  category = np.argmax(prediction)
42
 
43
- st.success(f"Predicted Category: {news_categories[category]}")
44
  else:
45
- st.warning("Please enter a news headline.")
 
1
  import streamlit as st
2
  import tensorflow as tf
3
+ import dill # Use dill instead of pickle for preprocessing function
4
  import numpy as np
 
5
 
6
+ # Load Trained Model
 
7
  model = tf.keras.models.load_model("news_classification_rnn.h5")
8
 
9
  # Load Preprocessing Function
10
  with open("preprocessing1.pkl", "rb") as f:
11
+ clean_text = dill.load(f)
12
 
13
+ # Load Text Vectorization Layer
14
+ with open("text_vectorizer.pkl", "rb") as f:
15
+ vectorizer = dill.load(f)
16
 
17
  # Define News Categories
18
+ news_categories = ["Business", "Sci/Tech", "Sports", "World"]
19
 
20
  # Streamlit UI
21
  st.title("📰 News Classification with Simple RNN")
 
24
  user_input = st.text_area("Enter News Text:", "")
25
 
26
  if st.button("Classify"):
27
+ if user_input.strip():
28
+ # Preprocess Input
29
+ processed_text = clean_text(user_input)
30
 
31
+ # Vectorize Input (Convert text to integer sequence)
32
+ text_sequence = vectorizer([processed_text]) # Directly vectorizes text
33
 
34
+ # Ensure correct shape (model expects batch input)
35
+ text_sequence = np.array(text_sequence)
36
 
37
+ # Make Prediction
38
+ prediction = model.predict(text_sequence)
39
  category = np.argmax(prediction)
40
 
41
+ st.success(f"Predicted Category: **{news_categories[category]}**")
42
  else:
43
+ st.warning("Please enter a news headline.")