Perth0603 committed on
Commit
311de59
·
verified ·
1 Parent(s): 6d68a85

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -48
app.py CHANGED
@@ -21,7 +21,7 @@ from huggingface_hub import hf_hub_download
21
  try:
22
  import xgboost as xgb # type: ignore
23
  except Exception:
24
- xgb = None # optional; required if bundle uses xgboost
25
 
26
  # -------------------------
27
  # Environment / config
@@ -30,14 +30,16 @@ MODEL_ID = os.environ.get("MODEL_ID", "Perth0603/phishing-email-mobilebert")
30
  URL_REPO = os.environ.get("URL_REPO", "Perth0603/Random-Forest-Model-for-PhishingDetection")
31
  URL_REPO_TYPE = os.environ.get("URL_REPO_TYPE", "model") # model|space|dataset
32
  URL_FILENAME = os.environ.get("URL_FILENAME", "rf_url_phishing_xgboost_bst.joblib")
33
-
34
  CACHE_DIR = os.environ.get("HF_CACHE_DIR", "/data/.cache")
35
  os.makedirs(CACHE_DIR, exist_ok=True)
36
 
37
- # You can control torch threads for tiny machines
38
  torch.set_num_threads(int(os.environ.get("TORCH_NUM_THREADS", "1")))
39
 
40
- app = FastAPI(title="Phishing Text & URL Classifier", version="1.1.0")
 
 
 
41
 
42
  # -------------------------
43
  # Schemas
@@ -56,20 +58,20 @@ _model: Optional[AutoModelForSequenceClassification] = None
56
  _id2label: Dict[int, str] = {0: "LEGIT", 1: "PHISH"}
57
  _label2id: Dict[str, int] = {"LEGIT": 0, "PHISH": 1}
58
 
59
- _url_bundle: Optional[Dict[str, Any]] = None # {model, feature_cols, url_col, model_type, ...}
60
-
61
  _model_lock = threading.Lock()
62
  _url_lock = threading.Lock()
63
 
 
 
 
64
  # -------------------------
65
- # URL feature engineering
66
- # (must match training)
67
  # -------------------------
68
  _SUSPICIOUS_TOKENS = ["login", "verify", "secure", "update", "bank", "pay", "account", "webscr"]
69
  _ipv4_pattern = re.compile(r"(?:\d{1,3}\.){3}\d{1,3}")
70
 
71
  def _engineer_features(df: pd.DataFrame, url_col: str, feature_cols: Optional[List[str]] = None) -> pd.DataFrame:
72
- # Be robust to NaNs and non-strings
73
  s = df[url_col].astype(str).fillna("")
74
  out = pd.DataFrame(index=df.index)
75
  out["url_len"] = s.str.len()
@@ -87,7 +89,7 @@ def _engineer_features(df: pd.DataFrame, url_col: str, feature_cols: Optional[Li
87
  out["starts_https"] = s.str.startswith("https").astype(int)
88
  out["ends_with_exe"] = s.str.endswith(".exe").astype(int)
89
  out["ends_with_zip"] = s.str.endswith(".zip").astype(int)
90
- return out if feature_cols is None or len(feature_cols) == 0 else out[feature_cols]
91
 
92
  # -------------------------
93
  # Loaders
@@ -99,12 +101,10 @@ def _load_model():
99
  if _tokenizer is None or _model is None:
100
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
101
  _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
102
- # Use model-config labels if present
103
  cfg = getattr(_model, "config", None)
104
  if cfg is not None and getattr(cfg, "id2label", None):
105
  _id2label = {int(k): v for k, v in cfg.id2label.items()}
106
  _label2id = {v: int(k) for k, v in _id2label.items()}
107
- # Warm-up
108
  with torch.no_grad():
109
  _ = _model(**_tokenizer(["warm up"], return_tensors="pt")).logits
110
 
@@ -113,7 +113,6 @@ def _load_url_model():
113
  if _url_bundle is None:
114
  with _url_lock:
115
  if _url_bundle is None:
116
- # Prefer local artifact if present (e.g., committed into the Space)
117
  local_path = os.path.join(os.getcwd(), URL_FILENAME)
118
  if os.path.exists(local_path):
119
  _url_bundle = joblib.load(local_path)
@@ -126,7 +125,59 @@ def _load_url_model():
126
  )
127
  _url_bundle = joblib.load(model_path)
128
 
129
- # Optional: try warm loading on startup
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  @app.on_event("startup")
131
  def _startup():
132
  try:
@@ -135,6 +186,16 @@ def _startup():
135
  print(f"[startup] text model load failed: {e}")
136
  try:
137
  _load_url_model()
 
 
 
 
 
 
 
 
 
 
138
  except Exception as e:
139
  print(f"[startup] url model load failed: {e}")
140
 
@@ -182,77 +243,87 @@ def predict_url(payload: PredictUrlPayload):
182
  row = pd.DataFrame({url_col: [url_str]})
183
  feats = _engineer_features(row, url_col, feature_cols)
184
 
185
- # Standardize on producing P(PHISH) first
186
- phish_proba: Optional[float] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
  if isinstance(model_type, str) and model_type == "xgboost_bst":
189
  if xgb is None:
190
  raise RuntimeError("xgboost is not installed but required for this model bundle.")
191
  dmat = xgb.DMatrix(feats)
192
- raw_p_class1 = float(model.predict(dmat)[0]) # probability of class "1" under binary:logistic
193
- # Interpret polarity using bundle metadata
194
- phish_is_positive = bool(bundle.get("phish_is_positive", True)) # default: class 1 = PHISH
195
- if phish_is_positive:
196
- phish_proba = raw_p_class1
197
- else:
198
- phish_proba = 1.0 - raw_p_class1
199
 
200
  elif hasattr(model, "predict_proba"):
201
  proba = model.predict_proba(feats)[0]
202
- # Try to discover which column corresponds to PHISH
203
  classes = bundle.get("classes", getattr(model, "classes_", None))
204
- label_map = bundle.get("label_map") # e.g., {0:"LEGIT", 1:"PHISH"} or reversed
205
  if classes is not None and len(proba) == 2:
206
  classes_list = list(classes)
207
  phish_idx = None
208
  if isinstance(label_map, dict):
209
- # Find class whose mapped label starts with "PHISH"
210
  for i, c in enumerate(classes_list):
211
  mapped = str(label_map.get(int(c), "")).upper()
212
  if mapped.startswith("PHISH"):
213
  phish_idx = i
214
  break
215
  if phish_idx is None:
216
- # fallback: assume class '1' is PHISH if present; else index 1
217
- phish_idx = classes_list.index(1) if 1 in classes_list else 1
 
 
 
 
 
218
  phish_proba = float(proba[phish_idx])
219
  else:
220
- # Unknown multi-class; best-effort: use index 1 if exists, else argmax
221
  phish_proba = float(proba[1]) if len(proba) > 1 else float(np.max(proba))
222
 
223
  else:
224
- # Plain predict interface: try to coerce to PHISH/LEGIT
225
  pred = model.predict(feats)[0]
226
  if isinstance(pred, (int, float, np.integer, np.floating)):
227
- # Assume numeric label space where 1 = PHISH by default
228
  label_numeric = int(pred)
229
- phish_proba = 1.0 if label_numeric == 1 else 0.0
 
 
 
 
230
  else:
231
  up = str(pred).strip().upper()
232
- if up in ("PHISH", "PHISHING", "MALICIOUS"):
233
- phish_proba = 1.0
234
- else:
235
- phish_proba = 0.0
236
-
237
- # Safety: ensure probability is a float
238
- phish_proba = float(phish_proba or 0.0)
239
 
240
- # Derive human label and a display confidence:
241
- # - label: PHISH if P>=0.5 else LEGIT
242
- # - score: confidence for the chosen label (mirrors text endpoint behavior)
243
  label = "PHISH" if phish_proba >= 0.5 else "LEGIT"
244
- display_score = phish_proba if label == "PHISH" else (1.0 - phish_proba)
245
 
246
- # Helpful metadata for debugging polarity mismatches
247
  return {
248
  "label": label,
249
- "score": float(display_score),
250
  "phishing_probability": float(phish_proba),
251
  "backend": str(model_type),
252
  "threshold": 0.5,
253
- "phish_is_positive": bool(bundle.get("phish_is_positive", True)),
254
- "classes": bundle.get("classes"),
255
- "label_map": bundle.get("label_map"),
 
256
  "feature_cols": feature_cols,
257
  "url_col": url_col,
258
  }
 
21
  try:
22
  import xgboost as xgb # type: ignore
23
  except Exception:
24
+ xgb = None
25
 
26
  # -------------------------
27
  # Environment / config
 
30
  URL_REPO = os.environ.get("URL_REPO", "Perth0603/Random-Forest-Model-for-PhishingDetection")
31
  URL_REPO_TYPE = os.environ.get("URL_REPO_TYPE", "model") # model|space|dataset
32
  URL_FILENAME = os.environ.get("URL_FILENAME", "rf_url_phishing_xgboost_bst.joblib")
 
33
  CACHE_DIR = os.environ.get("HF_CACHE_DIR", "/data/.cache")
34
  os.makedirs(CACHE_DIR, exist_ok=True)
35
 
36
+ # Capping the torch thread count helps on tiny Spaces
37
  torch.set_num_threads(int(os.environ.get("TORCH_NUM_THREADS", "1")))
38
 
39
+ # Optional manual override (beats everything): "PHISH" or "LEGIT"
40
+ URL_POSITIVE_CLASS_ENV = os.environ.get("URL_POSITIVE_CLASS", "").strip().upper() # "", "PHISH", "LEGIT"
41
+
42
+ app = FastAPI(title="PhishWatch API", version="1.2.0")
43
 
44
  # -------------------------
45
  # Schemas
 
58
  _id2label: Dict[int, str] = {0: "LEGIT", 1: "PHISH"}
59
  _label2id: Dict[str, int] = {"LEGIT": 0, "PHISH": 1}
60
 
61
+ _url_bundle: Optional[Dict[str, Any]] = None
 
62
  _model_lock = threading.Lock()
63
  _url_lock = threading.Lock()
64
 
65
+ # Calibrated flag: is XGB class 1 == PHISH?
66
+ _url_phish_is_positive: Optional[bool] = None
67
+
68
  # -------------------------
69
+ # URL features (must match training)
 
70
  # -------------------------
71
  _SUSPICIOUS_TOKENS = ["login", "verify", "secure", "update", "bank", "pay", "account", "webscr"]
72
  _ipv4_pattern = re.compile(r"(?:\d{1,3}\.){3}\d{1,3}")
73
 
74
  def _engineer_features(df: pd.DataFrame, url_col: str, feature_cols: Optional[List[str]] = None) -> pd.DataFrame:
 
75
  s = df[url_col].astype(str).fillna("")
76
  out = pd.DataFrame(index=df.index)
77
  out["url_len"] = s.str.len()
 
89
  out["starts_https"] = s.str.startswith("https").astype(int)
90
  out["ends_with_exe"] = s.str.endswith(".exe").astype(int)
91
  out["ends_with_zip"] = s.str.endswith(".zip").astype(int)
92
+ return out if not feature_cols else out[feature_cols]
93
 
94
  # -------------------------
95
  # Loaders
 
101
  if _tokenizer is None or _model is None:
102
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
103
  _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
 
104
  cfg = getattr(_model, "config", None)
105
  if cfg is not None and getattr(cfg, "id2label", None):
106
  _id2label = {int(k): v for k, v in cfg.id2label.items()}
107
  _label2id = {v: int(k) for k, v in _id2label.items()}
 
108
  with torch.no_grad():
109
  _ = _model(**_tokenizer(["warm up"], return_tensors="pt")).logits
110
 
 
113
  if _url_bundle is None:
114
  with _url_lock:
115
  if _url_bundle is None:
 
116
  local_path = os.path.join(os.getcwd(), URL_FILENAME)
117
  if os.path.exists(local_path):
118
  _url_bundle = joblib.load(local_path)
 
125
  )
126
  _url_bundle = joblib.load(model_path)
127
 
128
+ def _xgb_predict_class1_prob(booster, feats: pd.DataFrame) -> float:
129
+ # predicts P(class==1) under binary:logistic objective
130
+ dmat = xgb.DMatrix(feats)
131
+ return float(booster.predict(dmat)[0])
132
+
133
+ def _auto_calibrate_phish_positive(bundle: Dict[str, Any], feature_cols: List[str], url_col: str) -> bool:
134
+ """
135
+ Heuristic: probe with 'obviously phishy' and 'obviously legit' URLs.
136
+ If mean P(class1) for the phishy URLs is not higher than for the legit URLs, then class1 ≈ LEGIT → return False.
137
+ Otherwise, class1 ≈ PHISH → return True.
138
+ """
139
+ # If user forces it via env, honor that first.
140
+ if URL_POSITIVE_CLASS_ENV in ("PHISH", "LEGIT"):
141
+ return URL_POSITIVE_CLASS_ENV == "PHISH"
142
+
143
+ # If bundle has explicit flag, use it.
144
+ if "phish_is_positive" in bundle:
145
+ return bool(bundle["phish_is_positive"])
146
+
147
+ phishy = [
148
+ "http://198.51.100.23/login/update?acc=123",
149
+ "http://secure-login-account-update.example.com/session?id=123",
150
+ "http://bank.verify-update-security.com/confirm",
151
+ "http://paypal.com.account-verify.cn/login",
152
+ "http://abc.xyz/downloads/invoice.exe"
153
+ ]
154
+ legit = [
155
+ "https://www.wikipedia.org/",
156
+ "https://www.microsoft.com/",
157
+ "https://www.openai.com/",
158
+ "https://www.python.org/",
159
+ "https://www.gov.uk/"
160
+ ]
161
+
162
+ def _batch_mean(urls: List[str]) -> float:
163
+ df = pd.DataFrame({url_col: urls})
164
+ f = _engineer_features(df, url_col, feature_cols)
165
+ return float(np.mean([_xgb_predict_class1_prob(bundle["model"], pd.DataFrame([f.iloc[i]])) for i in range(len(f))]))
166
+
167
+ try:
168
+ phishy_mean = _batch_mean(phishy)
169
+ legit_mean = _batch_mean(legit)
170
+ except Exception as e:
171
+ # If anything goes wrong, default to class1=PHISH to mimic common convention
172
+ print(f"[autocalib] failed: {e}")
173
+ return True
174
+
175
+ # If phishy scores LOWER than legit for class1, then class1 is likely LEGIT
176
+ class1_is_phish = phishy_mean > legit_mean
177
+ print(f"[autocalib] phishy_mean={phishy_mean:.6f} legit_mean={legit_mean:.6f} -> class1_is_phish={class1_is_phish}")
178
+ return class1_is_phish
179
+
180
+ # Optional: pre-load on startup
181
  @app.on_event("startup")
182
  def _startup():
183
  try:
 
186
  print(f"[startup] text model load failed: {e}")
187
  try:
188
  _load_url_model()
189
+ # Calibrate for XGB if needed
190
+ global _url_phish_is_positive
191
+ b = _url_bundle
192
+ if isinstance(b, dict) and b.get("model_type") == "xgboost_bst" and _url_phish_is_positive is None:
193
+ if xgb is None:
194
+ print("[startup] xgboost not installed; cannot calibrate URL model.")
195
+ else:
196
+ feature_cols: List[str] = b.get("feature_cols") or []
197
+ url_col: str = b.get("url_col") or "url"
198
+ _url_phish_is_positive = _auto_calibrate_phish_positive(b, feature_cols, url_col)
199
  except Exception as e:
200
  print(f"[startup] url model load failed: {e}")
201
 
 
243
  row = pd.DataFrame({url_col: [url_str]})
244
  feats = _engineer_features(row, url_col, feature_cols)
245
 
246
+ # ----- compute P(PHISH) -----
247
+ phish_proba: float = 0.0
248
+ meta_phish_is_positive: Optional[bool] = bundle.get("phish_is_positive", None)
249
+
250
+ # Resolve polarity precedence: ENV > bundle flag > auto-calibration > default True
251
+ if URL_POSITIVE_CLASS_ENV in ("PHISH", "LEGIT"):
252
+ phish_is_positive = (URL_POSITIVE_CLASS_ENV == "PHISH")
253
+ elif meta_phish_is_positive is not None:
254
+ phish_is_positive = bool(meta_phish_is_positive)
255
+ else:
256
+ # If not yet calibrated, do it now for xgb
257
+ global _url_phish_is_positive
258
+ if _url_phish_is_positive is None and model_type == "xgboost_bst" and xgb is not None:
259
+ _url_phish_is_positive = _auto_calibrate_phish_positive(bundle, feature_cols, url_col)
260
+ phish_is_positive = _url_phish_is_positive if _url_phish_is_positive is not None else True
261
+
262
+ backend_debug = {
263
+ "phish_is_positive_resolved": phish_is_positive,
264
+ "phish_is_positive_bundle": meta_phish_is_positive,
265
+ "phish_is_positive_env": URL_POSITIVE_CLASS_ENV if URL_POSITIVE_CLASS_ENV else None,
266
+ }
267
 
268
  if isinstance(model_type, str) and model_type == "xgboost_bst":
269
  if xgb is None:
270
  raise RuntimeError("xgboost is not installed but required for this model bundle.")
271
  dmat = xgb.DMatrix(feats)
272
+ raw_p_class1 = float(model.predict(dmat)[0]) # P(class == 1)
273
+ phish_proba = raw_p_class1 if phish_is_positive else (1.0 - raw_p_class1)
 
 
 
 
 
274
 
275
  elif hasattr(model, "predict_proba"):
276
  proba = model.predict_proba(feats)[0]
 
277
  classes = bundle.get("classes", getattr(model, "classes_", None))
278
+ label_map = bundle.get("label_map")
279
  if classes is not None and len(proba) == 2:
280
  classes_list = list(classes)
281
  phish_idx = None
282
  if isinstance(label_map, dict):
 
283
  for i, c in enumerate(classes_list):
284
  mapped = str(label_map.get(int(c), "")).upper()
285
  if mapped.startswith("PHISH"):
286
  phish_idx = i
287
  break
288
  if phish_idx is None:
289
+ # fall back to whichever index matches current polarity
290
+ # if phish_is_positive, use the column for class 1; otherwise use the column for class 0
291
+ target_class = 1 if phish_is_positive else 0
292
+ if target_class in classes_list:
293
+ phish_idx = classes_list.index(target_class)
294
+ else:
295
+ phish_idx = 1 if phish_is_positive else 0
296
  phish_proba = float(proba[phish_idx])
297
  else:
 
298
  phish_proba = float(proba[1]) if len(proba) > 1 else float(np.max(proba))
299
 
300
  else:
 
301
  pred = model.predict(feats)[0]
302
  if isinstance(pred, (int, float, np.integer, np.floating)):
 
303
  label_numeric = int(pred)
304
+ # interpret through polarity
305
+ if label_numeric in (0, 1):
306
+ phish_proba = 1.0 if ((label_numeric == 1) == phish_is_positive) else 0.0
307
+ else:
308
+ phish_proba = float(label_numeric) # best-effort
309
  else:
310
  up = str(pred).strip().upper()
311
+ phish_proba = 1.0 if up.startswith("PHISH") else 0.0
 
 
 
 
 
 
312
 
313
+ phish_proba = float(phish_proba)
 
 
314
  label = "PHISH" if phish_proba >= 0.5 else "LEGIT"
315
+ score = phish_proba if label == "PHISH" else (1.0 - phish_proba)
316
 
 
317
  return {
318
  "label": label,
319
+ "score": float(score),
320
  "phishing_probability": float(phish_proba),
321
  "backend": str(model_type),
322
  "threshold": 0.5,
323
+ # Debug/trace so you can see exactly what was used
324
+ "phish_is_positive": bool(phish_is_positive),
325
+ "phish_is_positive_bundle": meta_phish_is_positive,
326
+ "phish_is_positive_env": URL_POSITIVE_CLASS_ENV if URL_POSITIVE_CLASS_ENV else None,
327
  "feature_cols": feature_cols,
328
  "url_col": url_col,
329
  }