Spaces:
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -183,37 +183,27 @@ st.markdown(
|
|
| 183 |
|
| 184 |
st.markdown("<div class='prompt-box'>Paste the article content below to analyze its category with PressGuard🛡️</div>", unsafe_allow_html=True)
|
| 185 |
|
| 186 |
-
#
|
| 187 |
-
nltk_data_path =
|
| 188 |
if not os.path.exists(nltk_data_path):
|
| 189 |
os.makedirs(nltk_data_path)
|
| 190 |
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
|
|
|
|
|
|
| 195 |
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
except LookupError:
|
| 199 |
-
nltk.download('stopwords', download_dir=nltk_data_path)
|
| 200 |
-
|
| 201 |
-
try:
|
| 202 |
-
nltk.data.find('corpora/wordnet')
|
| 203 |
-
except LookupError:
|
| 204 |
-
nltk.download('wordnet', download_dir=nltk_data_path)
|
| 205 |
|
| 206 |
# Initialize stopwords and lemmatizer
|
| 207 |
stop_words = set(stopwords.words('english')).union({"pm"})
|
| 208 |
lemmatizer = WordNetLemmatizer()
|
| 209 |
|
| 210 |
-
import nltk
|
| 211 |
-
nltk.download('punkt')
|
| 212 |
-
nltk.download('stopwords')
|
| 213 |
-
nltk.download('wordnet')
|
| 214 |
-
|
| 215 |
|
| 216 |
-
# Preprocessing Function
|
| 217 |
def pre_process(x):
|
| 218 |
x = x.lower()
|
| 219 |
x = re.sub("<.*?>", "", x)
|
|
@@ -224,32 +214,55 @@ def pre_process(x):
|
|
| 224 |
x = emoji.demojize(x)
|
| 225 |
x = re.sub(":.*?:", "", x)
|
| 226 |
x = re.sub("[^a-zA-Z0-9\\s_]", "", x)
|
|
|
|
| 227 |
words = word_tokenize(x)
|
| 228 |
words = [word for word in words if word not in stop_words]
|
| 229 |
x = " ".join([lemmatizer.lemmatize(word) for word in words])
|
| 230 |
return x
|
| 231 |
|
| 232 |
-
|
| 233 |
-
|
| 234 |
@st.cache_resource
|
| 235 |
def load_model():
|
| 236 |
-
|
| 237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
with open("label_encoder_m5.pkl", 'rb') as file:
|
| 239 |
label_encoder = pickle.load(file)
|
|
|
|
| 240 |
return model, vectorizer, label_encoder
|
| 241 |
|
|
|
|
|
|
|
| 242 |
model, vectorizer, label_encoder = load_model()
|
| 243 |
|
| 244 |
-
|
|
|
|
| 245 |
def predict_category(text):
|
| 246 |
processed_text = [pre_process(text)]
|
| 247 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
prediction = model.predict(text_vectorized)
|
| 249 |
category_idx = np.argmax(prediction, axis=1)[0]
|
|
|
|
|
|
|
| 250 |
return label_encoder.inverse_transform([category_idx])[0]
|
| 251 |
|
| 252 |
-
|
|
|
|
|
|
|
|
|
|
| 253 |
input_text = st.text_area("Enter News Article:", height=200)
|
| 254 |
|
| 255 |
if st.button("Analyze", key="analyze-btn", help="Click to classify the news article"):
|
|
@@ -257,4 +270,4 @@ if st.button("Analyze", key="analyze-btn", help="Click to classify the news arti
|
|
| 257 |
category = predict_category(input_text)
|
| 258 |
st.markdown(f"<div class='result-box'>Predicted Category: {category}</div>", unsafe_allow_html=True)
|
| 259 |
else:
|
| 260 |
-
st.warning("Please enter some text to analyze.")
|
|
|
|
| 183 |
|
| 184 |
# Prompt shown above the input box (raw HTML styled by the app's CSS).
st.markdown("<div class='prompt-box'>Paste the article content below to analyze its category with PressGuard🛡️</div>", unsafe_allow_html=True)
|
| 185 |
|
| 186 |
+
# Ensure NLTK resources are downloaded in the correct directory
nltk_data_path = '/root/nltk_data'  # Use the correct path in Hugging Face Spaces
os.makedirs(nltk_data_path, exist_ok=True)

# Register the custom directory BEFORE probing for resources: otherwise
# nltk.data.find() never looks in nltk_data_path, every LookupError fires,
# and the corpora are re-downloaded on every cold start.
nltk.data.path.append(nltk_data_path)

# Download NLTK resources only if not already present
for resource in ['punkt', 'stopwords', 'wordnet']:
    # 'punkt' lives under tokenizers/; the other resources are corpora.
    lookup = f'tokenizers/{resource}' if resource == 'punkt' else f'corpora/{resource}'
    try:
        nltk.data.find(lookup)
    except LookupError:
        nltk.download(resource, download_dir=nltk_data_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
|
| 201 |
# Build the stop-word set (standard English list plus the domain token "pm")
# and the WordNet lemmatizer; both are used by pre_process().
stop_words = set(stopwords.words('english')) | {"pm"}
lemmatizer = WordNetLemmatizer()
|
| 204 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
+
# ✅ Preprocessing Function
|
| 207 |
def pre_process(x):
|
| 208 |
x = x.lower()
|
| 209 |
x = re.sub("<.*?>", "", x)
|
|
|
|
| 214 |
x = emoji.demojize(x)
|
| 215 |
x = re.sub(":.*?:", "", x)
|
| 216 |
x = re.sub("[^a-zA-Z0-9\\s_]", "", x)
|
| 217 |
+
|
| 218 |
words = word_tokenize(x)
|
| 219 |
words = [word for word in words if word not in stop_words]
|
| 220 |
x = " ".join([lemmatizer.lemmatize(word) for word in words])
|
| 221 |
return x
|
| 222 |
|
| 223 |
+
|
| 224 |
+
# ✅ Load Model and Vectorizer
@st.cache_resource
def load_model():
    """Load and cache the trained artifacts for the classifier.

    Returns:
        tuple: (keras model, fitted text vectorizer, fitted label encoder),
        in that order. Cached by Streamlit so the files are read only once.
    """
    def unpickle(path):
        # Local helper: read one pickled artifact from disk.
        with open(path, 'rb') as file:
            return pickle.load(file)

    # Keras classifier
    model = tf.keras.models.load_model("model_m3_new.keras")
    # Vectorizer and label encoder pickled at training time
    vectorizer = unpickle("vec_text_m3_new.pkl")
    label_encoder = unpickle("label_encoder_m5.pkl")
    return model, vectorizer, label_encoder
|
| 239 |
|
| 240 |
+
|
| 241 |
+
# Load models once at import time (cached across reruns by st.cache_resource)
model, vectorizer, label_encoder = load_model()
|
| 243 |
|
| 244 |
+
|
| 245 |
+
# ✅ Prediction Function
def predict_category(text):
    """Classify a raw news article and return its category name.

    Runs the article through pre_process(), the fitted vectorizer, and the
    cached keras model, then maps the winning class index back to its label.
    """
    cleaned = [pre_process(text)]

    # Turn the cleaned text into a dense feature matrix.
    features = vectorizer.transform(cleaned).toarray()

    # Pad/truncate to the fixed input length of the network.
    # NOTE(review): applying pad_sequences to vectorizer output assumes the
    # model was trained on identically padded vectors — confirm against the
    # training pipeline.
    features = pad_sequences(features, padding='pre', maxlen=128)

    # Pick the highest-scoring class and decode it to a category name.
    scores = model.predict(features)
    best_idx = np.argmax(scores, axis=1)[0]
    return label_encoder.inverse_transform([best_idx])[0]
|
| 261 |
|
| 262 |
+
|
| 263 |
+
# ✅ Streamlit UI
st.title("AI-Powered News Categorization")

# Free-form input area for the article text to classify.
input_text = st.text_area("Enter News Article:", height=200)
|
| 267 |
|
| 268 |
if st.button("Analyze", key="analyze-btn", help="Click to classify the news article"):
|
|
|
|
| 270 |
category = predict_category(input_text)
|
| 271 |
st.markdown(f"<div class='result-box'>Predicted Category: {category}</div>", unsafe_allow_html=True)
|
| 272 |
else:
|
| 273 |
+
st.warning("Please enter some text to analyze.")
|