Mpavan45 committed on
Commit
db9518c
·
verified ·
1 Parent(s): ada421e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +179 -30
app.py CHANGED
@@ -1,42 +1,191 @@
1
  import streamlit as st
2
- import tensorflow as tf
3
  import numpy as np
4
- import dill
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- # Load the trained model with custom layers
7
- from tensorflow.keras.layers import TextVectorization
8
- model = tf.keras.models.load_model("news_classification_rnn1.h5",
9
- custom_objects={"TextVectorization": TextVectorization})
10
 
11
- # Load Preprocessing Function
12
- with open("preprocessing1.pkl", "rb") as f:
13
- clean_text = dill.load(f)
14
 
15
- # Load Text Vectorization Layer
16
- with open("vector.pkl", "rb") as f:
17
- vectorizer = dill.load(f)
18
 
19
- # Define News Categories
20
- news_categories = ["Business", "Sci/Tech", "Sports", "World"]
 
21
 
22
- # Streamlit UI
23
- st.title("📰 News Classification with Simple RNN")
24
- st.write("Enter a news headline to predict its category.")
25
 
26
- user_input = st.text_area("Enter News Text:", "")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- if st.button("Classify"):
29
- if user_input.strip():
30
- # Preprocess input
31
- processed_text = clean_text(user_input)
32
 
33
- # Vectorize input and convert to numpy array
34
- text_sequence = np.array(vectorizer([processed_text]))
 
35
 
36
- # Predict Category
37
- prediction = model.predict(text_sequence)
38
- category = np.argmax(prediction)
 
 
 
 
 
 
 
39
 
40
- st.success(f"Predicted Category: **{news_categories[category]}**")
41
- else:
42
- st.warning("⚠ Please enter a news headline.")
 
1
  import streamlit as st
 
2
  import numpy as np
3
+ import re
4
+ import emoji
5
+ from textblob import TextBlob
6
+ import spacy
7
+ import nltk
8
+ from nltk.corpus import stopwords
9
+ import tensorflow as tf
10
+ import keras
11
+ from keras.utils import pad_sequences
12
+ import pickle
13
+
14
# Page Config — must be the first Streamlit call in the script
st.set_page_config(page_title="Newsense AI", page_icon="📰", layout="wide")

# Download necessary resources (no-op if the corpus is already present)
nltk.download('stopwords')

# Load SpaCy model (en_core_web_sm must be installed in the environment)
nlp = spacy.load("en_core_web_sm")

# English stopwords plus the domain-specific token "pm"
# NOTE(review): stop_words is defined but never used in this file as shown —
# confirm whether stopword removal was meant to be part of pre_process.
stop_words = set(stopwords.words('english')).union({"pm"})
25
+
26
+ # Pre-processing function (without parentheses extraction)
27
def pre_process(x):
    """Normalize raw news text for the classifier.

    Lowercases, strips HTML tags/URLs/mentions/emojis and selected
    punctuation, spell-corrects with TextBlob, and lemmatizes with SpaCy.

    Args:
        x: Raw input string.

    Returns:
        The cleaned, lemmatized text as a single space-separated string.
    """
    # Convert to lowercase
    x = x.lower()

    # Remove HTML tags
    x = re.sub(r"<.*?>", "", x)

    # Remove URLs
    x = re.sub(r"http[s]?://\S+", "", x)

    # Remove mentions and hashtags (@..., #...)
    x = re.sub(r"[@#]\S+", "", x)

    # Remove emojis
    x = emoji.replace_emoji(x, replace="")

    # Replace selected punctuation (-, ., :, \, ,) with spaces
    x = re.sub(r"[-.:,\\]", " ", x)

    # Unwrap single- and double-quoted spans (keep the inner text)
    x = re.sub(r"['\"](.*?)['\"]", r'\1', x)

    # Remove content inside parentheses
    x = re.sub(r"\(.*?\)", "", x)

    # Collapse whitespace
    x = re.sub(r"\s+", " ", x).strip()

    # Spell checking (TextBlob; can be slow on long inputs)
    x = str(TextBlob(x).correct())

    # Lemmatization using SpaCy
    x = " ".join(token.lemma_ for token in nlp(x))

    # BUG FIX: the previous `return " ".join(x)` iterated over the
    # *characters* of the string, turning e.g. "hello" into "h e l l o".
    # x is already a space-joined string at this point.
    return x
62
+
63
@st.cache_resource
def load_model():
    """Load the trained classifier and its label encoder.

    Cached with st.cache_resource so the disk reads happen once per
    server session rather than on every script rerun.

    Returns:
        Tuple of (keras classification model, fitted label encoder).
    """
    classifier = keras.models.load_model("model_m3_new.keras")
    with open("label_encoder_m5.pkl", 'rb') as fh:
        encoder = pickle.load(fh)
    return classifier, encoder


model, label_encoder = load_model()
71
 
72
@st.cache_resource
def _load_vectorizer():
    """Load the saved text-vectorization model once and reuse it.

    PERF FIX: it was previously re-loaded from disk on every single
    prediction call.
    """
    return keras.models.load_model("vec_text_m3_new.keras")


def predict_category(text):
    """Classify a news article into one of the trained categories.

    Args:
        text: Raw article text entered by the user.

    Returns:
        Tuple of (predicted category label, cleaned/pre-processed text).
    """
    cleaned_text = pre_process(text)

    vectorizer = _load_vectorizer()

    # Vectorize the pre-processed text. np.asarray accepts both an
    # ndarray and an eager tensor; the old `.numpy()` call would raise
    # AttributeError on the plain ndarray that Model.predict returns.
    sequences = np.asarray(vectorizer.predict(np.array([cleaned_text])))
    text_vectorized = pad_sequences(sequences, padding='pre', maxlen=128)

    # Model prediction
    prediction = model.predict(text_vectorized)
    category_idx = np.argmax(prediction, axis=1)[0]

    return label_encoder.inverse_transform([category_idx])[0], cleaned_text
 
 
85
 
86
# Custom CSS injected into the page (unsafe_allow_html required for raw HTML).
# NOTE(review): .title references animation "elegantFadeSlide" but no
# @keyframes block is defined here — confirm it exists elsewhere or the
# animation is a silent no-op.
st.markdown(
    """
    <style>
    body {
        background-image: url('https://cdn-uploads.huggingface.co/production/uploads/67441c51a784a9d15cb12871/4FFTjgkYjYUq6w-0gR15v.jpeg');
        background-size: cover;
        background-repeat: no-repeat;
        background-attachment: fixed;
    }
    .title {
        font-size: 60px;
        font-weight: bold;
        color: white;
        background: linear-gradient(60deg, #880E4F, #4A235A, #311B92, #000000);
        padding: 20px;
        border-radius: 20px;
        box-shadow: 0 8px 25px rgba(136, 14, 79, 0.5),
        0 4px 15px rgba(74, 35, 90, 0.6);
        display: inline-block;
        margin-bottom: 20px;
        text-align: center;
        animation: elegantFadeSlide 1.5s ease-out forwards;
    }
    .input-box {
        display: flex;
        flex-direction: column;
        align-items: center;
        gap: 20px;
        margin: 0 auto;
        width: 80%;
    }
    .input-prompt {
        font-size: 22px;
        font-weight: bold;
        color: #ffffff;
        text-align: center;
        opacity: 0.8;
    }
    div.stTextArea textarea {
        width: 100%;
        height: 200px;
        padding: 20px;
        border-radius: 15px;
        background-color: rgba(0, 0, 0, 0.7);
        color: white;
        font-size: 18px;
        outline: none;
        box-shadow: 0 6px 20px rgba(136, 14, 79, 0.3);
        transition: all 0.5s ease;
    }
    div.stTextArea textarea:hover {
        transform: scale(1.05);
        box-shadow: 0 10px 30px rgba(136, 14, 79, 0.5);
    }
    .analyze-button {
        width: 200px;
        height: 60px;
        border-radius: 30px;
        background: linear-gradient(45deg, #880E4F, #4A235A, #311B92, #000000);
        font-size: 20px;
        font-weight: bold;
        color: white;
        border: none;
        cursor: pointer;
        transition: all 0.4s ease;
    }
    .analyze-button:hover {
        transform: scale(1.1);
        box-shadow: 0 12px 35px rgba(49, 27, 146, 0.8);
    }
    .result-box {
        text-align: center;
        font-size: 28px;
        font-weight: bold;
        color: white;
        background: linear-gradient(60deg, #880E4F, #4A235A, #311B92, #000000);
        padding: 30px;
        border-radius: 20px;
        box-shadow: 0 6px 20px rgba(74, 35, 90, 0.5);
        margin-top: 30px;
    }
    </style>
    """,
    unsafe_allow_html=True
)
172
 
173
# Streamlit UI layout
st.markdown('<div class="title">📰 Newsense AI - News Classification</div>', unsafe_allow_html=True)

# Input and button section
# NOTE(review): st.markdown emits this opening div as its own element — it
# does not actually wrap the widgets rendered below; confirm the intent.
st.markdown('<div class="input-box">', unsafe_allow_html=True)
user_input = st.text_area("Enter your news article:", height=200)

# Predict button
if st.button("Classify", key="analyze-button"):
    # BUG FIX: whitespace-only input previously passed the bare truthiness
    # check and was sent to the model; strip before testing (matches the
    # guard the earlier version of this app used).
    if user_input.strip():
        category, cleaned_text = predict_category(user_input)

        # Display the prediction and the cleaned text
        st.markdown(f'<div class="result-box">Prediction: {category}</div>', unsafe_allow_html=True)
        st.markdown(f'<div class="result-box">Cleaned Text: {cleaned_text}</div>', unsafe_allow_html=True)
    else:
        st.warning("Please enter some text to classify!")

st.markdown('</div>', unsafe_allow_html=True)