Thanut003 committed on
Commit
76c0de9
·
verified ·
1 Parent(s): 05f284d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -28
app.py CHANGED
@@ -130,8 +130,12 @@ import re
130
  import nltk
131
  import numpy as np
132
  import traceback
 
133
 
134
  # --- 1. SETUP ---
 
 
 
135
  from khmernltk import word_tokenize
136
 
137
  # NLTK Setup
@@ -174,18 +178,11 @@ def khmer_tokenize(text):
174
 
175
  return " ".join(processed_tokens)
176
 
177
- # --- 2. LOAD MODELS ---
178
- print("Loading processors...")
179
- try:
180
- vectorizer = joblib.load("tfidf_vectorizer.joblib")
181
- svd = joblib.load("truncated_svd.joblib")
182
- print("✅ Vectorizer & SVD loaded")
183
- except Exception as e:
184
- print(f"❌ CRITICAL LOAD ERROR: {e}")
185
- vectorizer = None
186
- svd = None
187
-
188
- models = {}
189
  model_files = {
190
  "XGBoost": "xgboost_model.joblib",
191
  "LightGBM": "lightgbm_model.joblib",
@@ -194,31 +191,49 @@ model_files = {
194
  "Linear SVM": "linear_svm_model.joblib"
195
  }
196
 
197
- for name, filename in model_files.items():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  try:
199
- models[name] = joblib.load(filename)
200
- print(f"✅ Loaded {name}")
 
 
 
201
  except Exception as e:
202
- print(f"⚠️ Skipping {name}: {e}")
 
203
 
204
  # --- 3. PREDICTION FUNCTION ---
205
  def predict(text, model_name):
206
  if not text:
207
  return "Please enter text", {}, []
208
 
209
- if model_name not in models:
210
- return "Model not found", {}, []
 
211
 
212
- if vectorizer is None or svd is None:
213
- return "Vectorizers not loaded", {}, []
 
214
 
215
  try:
216
  processed = khmer_tokenize(text)
217
  vectors = vectorizer.transform([processed])
218
  vectors_reduced = svd.transform(vectors)
219
 
220
- current_model = models[model_name]
221
-
222
  # --- Keyword Extraction ---
223
  feature_array = np.array(vectorizer.get_feature_names_out())
224
  tfidf_sorting = np.argsort(vectors.toarray()).flatten()[::-1]
@@ -233,21 +248,34 @@ def predict(text, model_name):
233
  confidences = {}
234
  top_label = ""
235
 
 
236
  if hasattr(current_model, "predict_proba"):
237
  probas = current_model.predict_proba(vectors_reduced)[0]
238
  for i in range(len(LABELS)):
239
  if i < len(probas):
240
  confidences[LABELS[i]] = float(probas[i])
241
  top_label = max(confidences, key=confidences.get)
 
 
242
  else:
243
  raw_pred = current_model.predict(vectors_reduced)[0]
244
- if isinstance(raw_pred, (int, np.integer, float, np.floating)):
 
 
 
 
 
 
 
 
245
  pred_idx = int(raw_pred)
246
  top_label = LABELS[pred_idx]
247
- confidences = {LABELS[pred_idx]: 1.0}
 
248
  else:
249
  top_label = str(raw_pred)
250
- confidences = {top_label: 1.0}
 
251
 
252
  return top_label, confidences, keywords
253
 
@@ -260,7 +288,7 @@ demo = gr.Interface(
260
  fn=predict,
261
  inputs=[
262
  gr.Textbox(lines=5, placeholder="Enter Khmer news text here...", label="Input Text"),
263
- gr.Dropdown(choices=list(models.keys()), value="XGBoost", label="Select Model")
264
  ],
265
  outputs=[
266
  gr.Label(label="Top Prediction"),
@@ -272,5 +300,4 @@ demo = gr.Interface(
272
  )
273
 
274
  if __name__ == "__main__":
275
- # ssr_mode=False is CRITICAL for stability on HF Spaces currently
276
- demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)
 
130
  import nltk
131
  import numpy as np
132
  import traceback
133
+ import warnings
134
 
135
  # --- 1. SETUP ---
136
# Filter out the harmless warnings from the logs
# NOTE(review): filterwarnings("ignore") suppresses EVERY warning category,
# including DeprecationWarning — consider narrowing to specific categories
# (e.g. UserWarning) when debugging.
warnings.filterwarnings("ignore")
138
+
139
  from khmernltk import word_tokenize
140
 
141
  # NLTK Setup
 
178
 
179
  return " ".join(processed_tokens)
180
 
181
+ # --- 2. LAZY LOADING MODELS (Prevents Crashing) ---
182
+ vectorizer = None
183
+ svd = None
184
+ models_cache = {}
185
+
 
 
 
 
 
 
 
186
  model_files = {
187
  "XGBoost": "xgboost_model.joblib",
188
  "LightGBM": "lightgbm_model.joblib",
 
191
  "Linear SVM": "linear_svm_model.joblib"
192
  }
193
 
194
def load_vectorizers():
    """Lazily load the TF-IDF vectorizer and SVD reducer on first use.

    Populates the module-level ``vectorizer`` and ``svd`` globals from their
    joblib artifacts. Returns True when both are available, False on failure.
    """
    global vectorizer, svd
    # Check BOTH globals: a previous call may have loaded the vectorizer and
    # then failed on the SVD, leaving the pair half-loaded. The original
    # `if vectorizer is None` test would report success in that state and
    # predict() would crash later on `svd.transform`.
    if vectorizer is None or svd is None:
        try:
            vectorizer = joblib.load("tfidf_vectorizer.joblib")
            svd = joblib.load("truncated_svd.joblib")
        except Exception as e:
            print(f"Error loading vectorizers: {e}")
            # Reset to a clean state so the next call retries the full load.
            vectorizer = None
            svd = None
            return False
    return True
+
205
def get_model(name):
    """Return the classifier registered under *name*, loading it on demand.

    Each joblib file is read at most once and memoized in ``models_cache``.
    Returns None for unknown names or when deserialization fails.
    """
    # Fast path: already loaded and cached.
    if name in models_cache:
        return models_cache[name]

    # Unknown model name — nothing to load.
    filename = model_files.get(name)
    if not filename:
        return None

    # Keep the try body minimal: only the load itself can realistically fail.
    try:
        model = joblib.load(filename)
    except Exception as e:
        print(f"Error loading {name}: {e}")
        return None

    models_cache[name] = model
    return model
 
219
  # --- 3. PREDICTION FUNCTION ---
220
  def predict(text, model_name):
221
  if not text:
222
  return "Please enter text", {}, []
223
 
224
+ # Load resources only when needed
225
+ if not load_vectorizers():
226
+ return "System Error: Vectorizers missing", {}, []
227
 
228
+ current_model = get_model(model_name)
229
+ if current_model is None:
230
+ return f"Error: Could not load {model_name}", {}, []
231
 
232
  try:
233
  processed = khmer_tokenize(text)
234
  vectors = vectorizer.transform([processed])
235
  vectors_reduced = svd.transform(vectors)
236
 
 
 
237
  # --- Keyword Extraction ---
238
  feature_array = np.array(vectorizer.get_feature_names_out())
239
  tfidf_sorting = np.argsort(vectors.toarray()).flatten()[::-1]
 
248
  confidences = {}
249
  top_label = ""
250
 
251
+ # CASE A: Models that support probabilities (XGBoost, RF, etc.)
252
  if hasattr(current_model, "predict_proba"):
253
  probas = current_model.predict_proba(vectors_reduced)[0]
254
  for i in range(len(LABELS)):
255
  if i < len(probas):
256
  confidences[LABELS[i]] = float(probas[i])
257
  top_label = max(confidences, key=confidences.get)
258
+
259
+ # CASE B: Models without probabilities (SVM, etc.)
260
  else:
261
  raw_pred = current_model.predict(vectors_reduced)[0]
262
+
263
+ # --- FIX FOR [1 0 0 0] ISSUE ---
264
+ # If the model returns an array/list (One-Hot Encoded), we find the index of the highest value
265
+ if isinstance(raw_pred, (list, np.ndarray)):
266
+ pred_idx = np.argmax(raw_pred)
267
+ top_label = LABELS[pred_idx]
268
+
269
+ # If the model returns a simple integer number
270
+ elif isinstance(raw_pred, (int, np.integer, float, np.floating)):
271
  pred_idx = int(raw_pred)
272
  top_label = LABELS[pred_idx]
273
+
274
+ # Fallback (rare)
275
  else:
276
  top_label = str(raw_pred)
277
+
278
+ confidences = {top_label: 1.0}
279
 
280
  return top_label, confidences, keywords
281
 
 
288
  fn=predict,
289
  inputs=[
290
  gr.Textbox(lines=5, placeholder="Enter Khmer news text here...", label="Input Text"),
291
+ gr.Dropdown(choices=list(model_files.keys()), value="XGBoost", label="Select Model")
292
  ],
293
  outputs=[
294
  gr.Label(label="Top Prediction"),
 
300
  )
301
 
302
if __name__ == "__main__":
    # Launch the Gradio app with default settings; `demo` is the
    # gr.Interface built above (outside this view).
    demo.launch()