Perth0603 committed on
Commit
f4317f9
·
verified ·
1 Parent(s): 311de59

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -17
app.py CHANGED
@@ -65,6 +65,27 @@ _url_lock = threading.Lock()
65
  # Calibrated flag: is XGB class 1 == PHISH?
66
  _url_phish_is_positive: Optional[bool] = None
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  # -------------------------
69
  # URL features (must match training)
70
  # -------------------------
@@ -95,7 +116,7 @@ def _engineer_features(df: pd.DataFrame, url_col: str, feature_cols: Optional[Li
95
  # Loaders
96
  # -------------------------
97
  def _load_model():
98
- global _tokenizer, _model, _id2label, _label2id
99
  if _tokenizer is None or _model is None:
100
  with _model_lock:
101
  if _tokenizer is None or _model is None:
@@ -105,9 +126,26 @@ def _load_model():
105
  if cfg is not None and getattr(cfg, "id2label", None):
106
  _id2label = {int(k): v for k, v in cfg.id2label.items()}
107
  _label2id = {v: int(k) for k, v in _id2label.items()}
 
 
108
  with torch.no_grad():
109
  _ = _model(**_tokenizer(["warm up"], return_tensors="pt")).logits
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  def _load_url_model():
112
  global _url_bundle
113
  if _url_bundle is None:
@@ -144,20 +182,8 @@ def _auto_calibrate_phish_positive(bundle: Dict[str, Any], feature_cols: List[st
144
  if "phish_is_positive" in bundle:
145
  return bool(bundle["phish_is_positive"])
146
 
147
- phishy = [
148
- "http://198.51.100.23/login/update?acc=123",
149
- "http://secure-login-account-update.example.com/session?id=123",
150
- "http://bank.verify-update-security.com/confirm",
151
- "http://paypal.com.account-verify.cn/login",
152
- "http://abc.xyz/downloads/invoice.exe"
153
- ]
154
- legit = [
155
- "https://www.wikipedia.org/",
156
- "https://www.microsoft.com/",
157
- "https://www.openai.com/",
158
- "https://www.python.org/",
159
- "https://www.gov.uk/"
160
- ]
161
 
162
  def _batch_mean(urls: List[str]) -> float:
163
  df = pd.DataFrame({url_col: urls})
@@ -218,8 +244,17 @@ def predict(payload: PredictPayload):
218
  logits = _model(**inputs).logits
219
  probs = torch.softmax(logits, dim=-1)[0]
220
  score, idx = torch.max(probs, dim=0)
221
- label = _id2label.get(int(idx), str(int(idx)))
222
- return {"label": label, "score": float(score), "raw_index": int(idx)}
 
 
 
 
 
 
 
 
 
223
  except Exception as e:
224
  return JSONResponse(status_code=500, content={"error": str(e)})
225
 
 
65
  # Calibrated flag: is XGB class 1 == PHISH?
66
  _url_phish_is_positive: Optional[bool] = None
67
 
68
+ # -------------------------
69
+ # Autocalibration URL prototypes (editable)
70
+ # -------------------------
71
+ # You can edit these lists to define which URLs are considered obviously phishy/legit
72
+ # for polarity auto-calibration of classical URL models (e.g., XGBoost, scikit-learn).
73
+ _AUTOCALIB_PHISHY_URLS: List[str] = [
74
+ "http://198.51.100.23/login/update?acc=123",
75
+ "http://secure-login-account-update.example.com/session?id=123",
76
+ "http://bank.verify-update-security.com/confirm",
77
+ "http://paypal.com.account-verify.cn/login",
78
+ "http://abc.xyz/downloads/invoice.exe",
79
+ ]
80
+
81
+ _AUTOCALIB_LEGIT_URLS: List[str] = [
82
+ "https://www.wikipedia.org/",
83
+ "https://www.microsoft.com/",
84
+ "https://www.openai.com/",
85
+ "https://www.python.org/",
86
+ "https://www.gov.uk/",
87
+ ]
88
+
89
  # -------------------------
90
  # URL features (must match training)
91
  # -------------------------
 
116
  # Loaders
117
  # -------------------------
118
  def _load_model():
119
+ global _tokenizer, _model, _id2label, _label2id, _text_phish_id
120
  if _tokenizer is None or _model is None:
121
  with _model_lock:
122
  if _tokenizer is None or _model is None:
 
126
  if cfg is not None and getattr(cfg, "id2label", None):
127
  _id2label = {int(k): v for k, v in cfg.id2label.items()}
128
  _label2id = {v: int(k) for k, v in _id2label.items()}
129
+ # Try to detect which index corresponds to PHISH/SPAM
130
+ _text_phish_id = _detect_text_phish_id(_id2label)
131
  with torch.no_grad():
132
  _ = _model(**_tokenizer(["warm up"], return_tensors="pt")).logits
133
 
134
+ # Detect which label id corresponds to phishing for text models based on label strings
135
+ _text_phish_id: Optional[int] = None
136
+
137
+ def _detect_text_phish_id(id2label: Dict[int, str]) -> Optional[int]:
138
+ candidates_phish = ("PHISH", "SPAM", "MALICIOUS", "POSITIVE")
139
+ # Prefer explicit PHISH/SPAM over generic POSITIVE
140
+ priority_order = ("PHISH", "SPAM", "MALICIOUS", "POSITIVE")
141
+ norm = {k: str(v).strip().upper() for k, v in id2label.items()}
142
+ # exact/substring match in priority order
143
+ for token in priority_order:
144
+ for k, v in norm.items():
145
+ if token in v:
146
+ return int(k)
147
+ return None
148
+
149
  def _load_url_model():
150
  global _url_bundle
151
  if _url_bundle is None:
 
182
  if "phish_is_positive" in bundle:
183
  return bool(bundle["phish_is_positive"])
184
 
185
+ phishy = _AUTOCALIB_PHISHY_URLS
186
+ legit = _AUTOCALIB_LEGIT_URLS
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
  def _batch_mean(urls: List[str]) -> float:
189
  df = pd.DataFrame({url_col: urls})
 
244
  logits = _model(**inputs).logits
245
  probs = torch.softmax(logits, dim=-1)[0]
246
  score, idx = torch.max(probs, dim=0)
247
+
248
+ # Normalize label to PHISH/LEGIT if we could detect PHISH id
249
+ if _text_phish_id is not None and 0 <= _text_phish_id < probs.shape[0]:
250
+ phish_prob = float(probs[_text_phish_id])
251
+ norm_label = "PHISH" if phish_prob >= 0.5 else "LEGIT"
252
+ norm_score = phish_prob if norm_label == "PHISH" else (1.0 - phish_prob)
253
+ return {"label": norm_label, "score": float(norm_score), "raw_index": int(idx)}
254
+ else:
255
+ # Fallback to model's provided labels
256
+ label = _id2label.get(int(idx), str(int(idx)))
257
+ return {"label": label, "score": float(score), "raw_index": int(idx)}
258
  except Exception as e:
259
  return JSONResponse(status_code=500, content={"error": str(e)})
260