Mpavan45 commited on
Commit
b0a4e3e
·
verified ·
1 Parent(s): c3da79e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -29
app.py CHANGED
@@ -183,37 +183,27 @@ st.markdown(
183
 
184
  st.markdown("<div class='prompt-box'>Paste the article content below to analyze its category with PressGuard🛡️</div>", unsafe_allow_html=True)
185
 
186
- # Check if NLTK resources are already downloaded
187
- nltk_data_path = os.path.expanduser('~/nltk_data')
188
  if not os.path.exists(nltk_data_path):
189
  os.makedirs(nltk_data_path)
190
 
191
- try:
192
- nltk.data.find('tokenizers/punkt')
193
- except LookupError:
194
- nltk.download('punkt', download_dir=nltk_data_path)
 
 
195
 
196
- try:
197
- nltk.data.find('corpora/stopwords')
198
- except LookupError:
199
- nltk.download('stopwords', download_dir=nltk_data_path)
200
-
201
- try:
202
- nltk.data.find('corpora/wordnet')
203
- except LookupError:
204
- nltk.download('wordnet', download_dir=nltk_data_path)
205
 
206
  # Initialize stopwords and lemmatizer
207
  stop_words = set(stopwords.words('english')).union({"pm"})
208
  lemmatizer = WordNetLemmatizer()
209
 
210
- import nltk
211
- nltk.download('punkt')
212
- nltk.download('stopwords')
213
- nltk.download('wordnet')
214
-
215
 
216
- # Preprocessing Function
217
  def pre_process(x):
218
  x = x.lower()
219
  x = re.sub("<.*?>", "", x)
@@ -224,32 +214,55 @@ def pre_process(x):
224
  x = emoji.demojize(x)
225
  x = re.sub(":.*?:", "", x)
226
  x = re.sub("[^a-zA-Z0-9\\s_]", "", x)
 
227
  words = word_tokenize(x)
228
  words = [word for word in words if word not in stop_words]
229
  x = " ".join([lemmatizer.lemmatize(word) for word in words])
230
  return x
231
 
232
- # Load
233
- rm -rf /home/user/nltk_data
234
  @st.cache_resource
235
  def load_model():
236
- model = keras.models.load_model("model_m3_new.keras")
237
- vectorizer = keras.models.load_model("vec_text_m3_new.keras")
 
 
 
 
 
 
238
  with open("label_encoder_m5.pkl", 'rb') as file:
239
  label_encoder = pickle.load(file)
 
240
  return model, vectorizer, label_encoder
241
 
 
 
242
  model, vectorizer, label_encoder = load_model()
243
 
244
- # Prediction Function
 
245
  def predict_category(text):
246
  processed_text = [pre_process(text)]
247
- text_vectorized = pad_sequences(vectorizer(processed_text).numpy().tolist(), padding='pre', maxlen=128)
 
 
 
 
 
 
 
248
  prediction = model.predict(text_vectorized)
249
  category_idx = np.argmax(prediction, axis=1)[0]
 
 
250
  return label_encoder.inverse_transform([category_idx])[0]
251
 
252
- # User Input
 
 
 
253
  input_text = st.text_area("Enter News Article:", height=200)
254
 
255
  if st.button("Analyze", key="analyze-btn", help="Click to classify the news article"):
@@ -257,4 +270,4 @@ if st.button("Analyze", key="analyze-btn", help="Click to classify the news arti
257
  category = predict_category(input_text)
258
  st.markdown(f"<div class='result-box'>Predicted Category: {category}</div>", unsafe_allow_html=True)
259
  else:
260
- st.warning("Please enter some text to analyze.")
 
183
 
184
  st.markdown("<div class='prompt-box'>Paste the article content below to analyze its category with PressGuard🛡️</div>", unsafe_allow_html=True)
185
 
186
# Ensure NLTK resources are available in the custom data directory.
nltk_data_path = '/root/nltk_data'  # Hugging Face Spaces runs as root; keep downloads here
if not os.path.exists(nltk_data_path):
    os.makedirs(nltk_data_path)

# Register the custom directory BEFORE probing with nltk.data.find();
# otherwise resources downloaded there on a previous run are invisible
# and get re-downloaded on every app start.
if nltk_data_path not in nltk.data.path:
    nltk.data.path.append(nltk_data_path)

# Download each resource only if it is not already present.
for resource in ['punkt', 'stopwords', 'wordnet']:
    # 'punkt' lives under tokenizers/, the others under corpora/.
    lookup = f'tokenizers/{resource}' if resource == 'punkt' else f'corpora/{resource}'
    try:
        nltk.data.find(lookup)
    except LookupError:
        nltk.download(resource, download_dir=nltk_data_path)
 
 
 
 
 
 
 
200
 
201
# Filtering vocabulary: English stopwords plus the domain token "pm",
# and a single shared lemmatizer instance reused by pre_process().
stop_words = set(stopwords.words('english')) | {"pm"}
lemmatizer = WordNetLemmatizer()
204
 
 
 
 
 
 
205
 
206
+ # Preprocessing Function
207
  def pre_process(x):
208
  x = x.lower()
209
  x = re.sub("<.*?>", "", x)
 
214
  x = emoji.demojize(x)
215
  x = re.sub(":.*?:", "", x)
216
  x = re.sub("[^a-zA-Z0-9\\s_]", "", x)
217
+
218
  words = word_tokenize(x)
219
  words = [word for word in words if word not in stop_words]
220
  x = " ".join([lemmatizer.lemmatize(word) for word in words])
221
  return x
222
 
223
+
224
# Load Model and Vectorizer (cached so files are read once per session)
@st.cache_resource
def load_model():
    """Load the trained Keras classifier, the pickled text vectorizer,
    and the pickled label encoder from disk.

    Returns:
        (model, vectorizer, label_encoder) tuple.
    """
    # Keras model saved in the native .keras format.
    classifier = tf.keras.models.load_model("model_m3_new.keras")

    # Vectorizer is an sklearn-style object persisted with pickle.
    with open("vec_text_m3_new.pkl", 'rb') as fh:
        text_vectorizer = pickle.load(fh)

    # Label encoder maps class indices back to category names.
    with open("label_encoder_m5.pkl", 'rb') as fh:
        encoder = pickle.load(fh)

    return classifier, text_vectorizer, encoder
239
 
240
+
241
+ # Load models
242
  model, vectorizer, label_encoder = load_model()
243
 
244
+
245
# ✅ Prediction Function
def predict_category(text):
    """Classify a raw article string and return its category label."""
    cleaned = [pre_process(text)]

    # Vectorize the cleaned text into a dense matrix, then pad/truncate
    # to the fixed width of 128 the model expects.
    # NOTE(review): padding a vectorizer output with pad_sequences assumes
    # the model was trained on 128-wide dense rows — confirm vs. training code.
    features = vectorizer.transform(cleaned).toarray()
    features = pad_sequences(features, padding='pre', maxlen=128)

    # Pick the highest-scoring class and map it back to its label.
    scores = model.predict(features)
    best_idx = np.argmax(scores, axis=1)[0]
    return label_encoder.inverse_transform([best_idx])[0]
261
 
262
+
263
# ✅ Streamlit UI — page title and the article input box.
st.title("AI-Powered News Categorization")

input_text = st.text_area("Enter News Article:", height=200)
267
 
268
  if st.button("Analyze", key="analyze-btn", help="Click to classify the news article"):
 
270
  category = predict_category(input_text)
271
  st.markdown(f"<div class='result-box'>Predicted Category: {category}</div>", unsafe_allow_html=True)
272
  else:
273
+ st.warning("Please enter some text to analyze.")