Adityaganesh committed on
Commit
294f97f
·
verified ·
1 Parent(s): c81c93f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -9
app.py CHANGED
@@ -3,18 +3,22 @@ import emoji
3
  import nltk
4
  import numpy as np
5
  import streamlit as st
 
 
6
  from nltk.tokenize import word_tokenize
7
  from nltk.stem import WordNetLemmatizer
8
  from tensorflow.keras.models import load_model
9
  from tensorflow.keras.preprocessing.sequence import pad_sequences
10
- from tensorflow.keras.preprocessing.text import Tokenizer
11
 
12
  # Ensure necessary downloads
13
  nltk.download("punkt")
14
  nltk.download("wordnet")
 
 
15
 
16
  lemmatizer = WordNetLemmatizer()
17
 
 
18
  def pre_process(x):
19
  x = x.lower()
20
  x = re.sub("<.*?>", "", x)
@@ -32,13 +36,17 @@ def pre_process(x):
32
  # Load trained model
33
  model = load_model("best_rnn_model.h5")
34
 
35
- # Tokenizer (Ensure this matches the one used during training)
36
- MAX_LENGTH = 100 # Set this to the same max length used in training
37
- tokenizer = Tokenizer() # Load your trained tokenizer here
 
 
 
38
 
39
  # Class labels
40
  class_labels = ['Sports', 'Business', 'SciTech', 'World']
41
 
 
42
  def predict_category(text):
43
  processed_text = pre_process(text)
44
  seq = tokenizer.texts_to_sequences([processed_text])
@@ -48,14 +56,14 @@ def predict_category(text):
48
  return predicted_label
49
 
50
  # Streamlit UI
51
- st.title("News Category Classifier")
52
  st.write("Enter a news headline or article snippet, and the model will predict its category.")
53
 
54
- user_input = st.text_area("Enter text here:")
55
 
56
- if st.button("Predict"):
57
  if user_input.strip():
58
  prediction = predict_category(user_input)
59
- st.success(f"Predicted Category: {prediction}")
60
  else:
61
- st.warning("Please enter some text to classify.")
 
3
  import nltk
4
  import numpy as np
5
  import streamlit as st
6
+ import pickle # To load the tokenizer
7
+
8
  from nltk.tokenize import word_tokenize
9
  from nltk.stem import WordNetLemmatizer
10
  from tensorflow.keras.models import load_model
11
  from tensorflow.keras.preprocessing.sequence import pad_sequences
 
12
 
13
  # Ensure necessary downloads
14
  nltk.download("punkt")
15
  nltk.download("wordnet")
16
+ nltk.download("omw-1.4")
17
+ nltk.download("averaged_perceptron_tagger")
18
 
19
  lemmatizer = WordNetLemmatizer()
20
 
21
+ # Function to preprocess text
22
  def pre_process(x):
23
  x = x.lower()
24
  x = re.sub("<.*?>", "", x)
 
36
  # Load trained model
37
  model = load_model("best_rnn_model.h5")
38
 
39
+ # Load the same tokenizer used during training
40
+ with open("tokenizer.pickle", "rb") as handle:
41
+ tokenizer = pickle.load(handle)
42
+
43
+ # Maximum length (must match training settings)
44
+ MAX_LENGTH = 100
45
 
46
  # Class labels
47
  class_labels = ['Sports', 'Business', 'SciTech', 'World']
48
 
49
+ # Function to predict category
50
  def predict_category(text):
51
  processed_text = pre_process(text)
52
  seq = tokenizer.texts_to_sequences([processed_text])
 
56
  return predicted_label
57
 
58
  # Streamlit UI
59
+ st.title("📰 News Category Classifier")
60
  st.write("Enter a news headline or article snippet, and the model will predict its category.")
61
 
62
+ user_input = st.text_area("✍ Enter text here:")
63
 
64
+ if st.button("🔍 Predict"):
65
  if user_input.strip():
66
  prediction = predict_category(user_input)
67
+ st.success(f"📌 Predicted Category: **{prediction}**")
68
  else:
69
+ st.warning("⚠️ Please enter some text to classify.")