Perth0603 commited on
Commit
54fa158
·
verified ·
1 Parent(s): 6a642c0

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -13
app.py CHANGED
@@ -27,9 +27,10 @@ except Exception:
27
  # Environment / config
28
  # -------------------------
29
  MODEL_ID = os.environ.get("MODEL_ID", "Perth0603/phishing-email-mobilebert")
30
- URL_REPO = os.environ.get("URL_REPO", "Perth0603/Random-Forest-Model-for-PhishingDetection")
31
- URL_REPO_TYPE = os.environ.get("URL_REPO_TYPE", "model") # model|space|dataset
32
- URL_FILENAME = os.environ.get("URL_FILENAME", "rf_url_phishing_xgboost_bst.joblib")
 
33
  CACHE_DIR = os.environ.get("HF_CACHE_DIR", "/data/.cache")
34
  os.makedirs(CACHE_DIR, exist_ok=True)
35
 
@@ -185,10 +186,47 @@ def _auto_calibrate_phish_positive(bundle: Dict[str, Any], feature_cols: List[st
185
  phishy = _AUTOCALIB_PHISHY_URLS
186
  legit = _AUTOCALIB_LEGIT_URLS
187
 
 
 
 
188
  def _batch_mean(urls: List[str]) -> float:
189
  df = pd.DataFrame({url_col: urls})
190
- f = _engineer_features(df, url_col, feature_cols)
191
- return float(np.mean([_xgb_predict_class1_prob(bundle["model"], pd.DataFrame([f.iloc[i]])) for i in range(len(f))]))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
193
  try:
194
  phishy_mean = _batch_mean(phishy)
@@ -212,16 +250,15 @@ def _startup():
212
  print(f"[startup] text model load failed: {e}")
213
  try:
214
  _load_url_model()
215
- # Calibrate for XGB if needed
216
  global _url_phish_is_positive
217
  b = _url_bundle
218
- if isinstance(b, dict) and b.get("model_type") == "xgboost_bst" and _url_phish_is_positive is None:
219
- if xgb is None:
220
- print("[startup] xgboost not installed; cannot calibrate URL model.")
221
- else:
222
  feature_cols: List[str] = b.get("feature_cols") or []
223
  url_col: str = b.get("url_col") or "url"
224
  _url_phish_is_positive = _auto_calibrate_phish_positive(b, feature_cols, url_col)
 
 
225
  except Exception as e:
226
  print(f"[startup] url model load failed: {e}")
227
 
@@ -288,10 +325,12 @@ def predict_url(payload: PredictUrlPayload):
288
  elif meta_phish_is_positive is not None:
289
  phish_is_positive = bool(meta_phish_is_positive)
290
  else:
291
- # If not yet calibrated, do it now for xgb
292
  global _url_phish_is_positive
293
- if _url_phish_is_positive is None and model_type == "xgboost_bst" and xgb is not None:
294
- _url_phish_is_positive = _auto_calibrate_phish_positive(bundle, feature_cols, url_col)
 
 
 
295
  phish_is_positive = _url_phish_is_positive if _url_phish_is_positive is not None else True
296
 
297
  backend_debug = {
 
27
  # Environment / config
28
  # -------------------------
29
  MODEL_ID = os.environ.get("MODEL_ID", "Perth0603/phishing-email-mobilebert")
30
+ # Support both legacy and HF_* envs
31
+ URL_REPO = os.environ.get("HF_URL_MODEL_ID", os.environ.get("URL_REPO", "Perth0603/Random-Forest-Model-for-PhishingDetection"))
32
+ URL_REPO_TYPE = os.environ.get("HF_URL_REPO_TYPE", os.environ.get("URL_REPO_TYPE", "model")) # model|space|dataset
33
+ URL_FILENAME = os.environ.get("HF_URL_FILENAME", os.environ.get("URL_FILENAME", "rf_url_phishing_xgboost_bst.joblib"))
34
  CACHE_DIR = os.environ.get("HF_CACHE_DIR", "/data/.cache")
35
  os.makedirs(CACHE_DIR, exist_ok=True)
36
 
 
186
  phishy = _AUTOCALIB_PHISHY_URLS
187
  legit = _AUTOCALIB_LEGIT_URLS
188
 
189
+ model = bundle.get("model")
190
+ model_type: str = str(bundle.get("model_type") or "")
191
+
192
  def _batch_mean(urls: List[str]) -> float:
193
  df = pd.DataFrame({url_col: urls})
194
+ feats = _engineer_features(df, url_col, feature_cols)
195
+ # XGBoost booster path
196
+ if model_type == "xgboost_bst" and xgb is not None:
197
+ try:
198
+ # Predict row-by-row to be conservative about input formats
199
+ return float(np.mean([_xgb_predict_class1_prob(model, pd.DataFrame([feats.iloc[i]])) for i in range(len(feats))]))
200
+ except Exception:
201
+ pass
202
+ # scikit-learn-like path with predict_proba
203
+ if hasattr(model, "predict_proba"):
204
+ proba = model.predict_proba(feats)
205
+ classes = bundle.get("classes", getattr(model, "classes_", None))
206
+ class1_idx = 1
207
+ if classes is not None:
208
+ try:
209
+ classes_list = list(classes)
210
+ if 1 in classes_list:
211
+ class1_idx = classes_list.index(1)
212
+ else:
213
+ class1_idx = 1 if len(classes_list) > 1 else 0
214
+ except Exception:
215
+ class1_idx = 1 if proba.shape[1] > 1 else 0
216
+ return float(np.mean(proba[:, class1_idx]))
217
+ # Fallback: use hard predictions and treat label==1 as prob 1
218
+ try:
219
+ preds = model.predict(feats)
220
+ vals: List[float] = []
221
+ for p in preds:
222
+ if isinstance(p, (int, float, np.integer, np.floating)):
223
+ vals.append(1.0 if int(p) == 1 else 0.0)
224
+ else:
225
+ up = str(p).strip().upper()
226
+ vals.append(1.0 if up.startswith("PHISH") or up == "1" else 0.0)
227
+ return float(np.mean(vals)) if vals else 0.0
228
+ except Exception:
229
+ return 0.0
230
 
231
  try:
232
  phishy_mean = _batch_mean(phishy)
 
250
  print(f"[startup] text model load failed: {e}")
251
  try:
252
  _load_url_model()
 
253
  global _url_phish_is_positive
254
  b = _url_bundle
255
+ if isinstance(b, dict) and _url_phish_is_positive is None:
256
+ try:
 
 
257
  feature_cols: List[str] = b.get("feature_cols") or []
258
  url_col: str = b.get("url_col") or "url"
259
  _url_phish_is_positive = _auto_calibrate_phish_positive(b, feature_cols, url_col)
260
+ except Exception as ce:
261
+ print(f"[startup] url model calibration failed: {ce}")
262
  except Exception as e:
263
  print(f"[startup] url model load failed: {e}")
264
 
 
325
  elif meta_phish_is_positive is not None:
326
  phish_is_positive = bool(meta_phish_is_positive)
327
  else:
 
328
  global _url_phish_is_positive
329
+ if _url_phish_is_positive is None:
330
+ try:
331
+ _url_phish_is_positive = _auto_calibrate_phish_positive(bundle, feature_cols, url_col)
332
+ except Exception as ce:
333
+ print(f"[predict-url] auto-calibration failed: {ce}")
334
  phish_is_positive = _url_phish_is_positive if _url_phish_is_positive is not None else True
335
 
336
  backend_debug = {