Upload app.py
Browse files
app.py
CHANGED
|
@@ -65,6 +65,27 @@ _url_lock = threading.Lock()
|
|
| 65 |
# Calibrated flag: is XGB class 1 == PHISH?
|
| 66 |
_url_phish_is_positive: Optional[bool] = None
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
# -------------------------
|
| 69 |
# URL features (must match training)
|
| 70 |
# -------------------------
|
|
@@ -95,7 +116,7 @@ def _engineer_features(df: pd.DataFrame, url_col: str, feature_cols: Optional[Li
|
|
| 95 |
# Loaders
|
| 96 |
# -------------------------
|
| 97 |
def _load_model():
|
| 98 |
-
global _tokenizer, _model, _id2label, _label2id
|
| 99 |
if _tokenizer is None or _model is None:
|
| 100 |
with _model_lock:
|
| 101 |
if _tokenizer is None or _model is None:
|
|
@@ -105,9 +126,26 @@ def _load_model():
|
|
| 105 |
if cfg is not None and getattr(cfg, "id2label", None):
|
| 106 |
_id2label = {int(k): v for k, v in cfg.id2label.items()}
|
| 107 |
_label2id = {v: int(k) for k, v in _id2label.items()}
|
|
|
|
|
|
|
| 108 |
with torch.no_grad():
|
| 109 |
_ = _model(**_tokenizer(["warm up"], return_tensors="pt")).logits
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
def _load_url_model():
|
| 112 |
global _url_bundle
|
| 113 |
if _url_bundle is None:
|
|
@@ -144,20 +182,8 @@ def _auto_calibrate_phish_positive(bundle: Dict[str, Any], feature_cols: List[st
|
|
| 144 |
if "phish_is_positive" in bundle:
|
| 145 |
return bool(bundle["phish_is_positive"])
|
| 146 |
|
| 147 |
-
phishy =
|
| 148 |
-
|
| 149 |
-
"http://secure-login-account-update.example.com/session?id=123",
|
| 150 |
-
"http://bank.verify-update-security.com/confirm",
|
| 151 |
-
"http://paypal.com.account-verify.cn/login",
|
| 152 |
-
"http://abc.xyz/downloads/invoice.exe"
|
| 153 |
-
]
|
| 154 |
-
legit = [
|
| 155 |
-
"https://www.wikipedia.org/",
|
| 156 |
-
"https://www.microsoft.com/",
|
| 157 |
-
"https://www.openai.com/",
|
| 158 |
-
"https://www.python.org/",
|
| 159 |
-
"https://www.gov.uk/"
|
| 160 |
-
]
|
| 161 |
|
| 162 |
def _batch_mean(urls: List[str]) -> float:
|
| 163 |
df = pd.DataFrame({url_col: urls})
|
|
@@ -218,8 +244,17 @@ def predict(payload: PredictPayload):
|
|
| 218 |
logits = _model(**inputs).logits
|
| 219 |
probs = torch.softmax(logits, dim=-1)[0]
|
| 220 |
score, idx = torch.max(probs, dim=0)
|
| 221 |
-
|
| 222 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
except Exception as e:
|
| 224 |
return JSONResponse(status_code=500, content={"error": str(e)})
|
| 225 |
|
|
|
|
| 65 |
# Calibrated flag: is XGB class 1 == PHISH?
|
| 66 |
_url_phish_is_positive: Optional[bool] = None
|
| 67 |
|
| 68 |
+
# -------------------------
# Autocalibration URL prototypes (editable)
# -------------------------
# You can edit these lists to define which URLs are considered obviously phishy/legit
# for polarity auto-calibration of classical URL models (e.g., XGBoost, scikit-learn).
# NOTE(review): these are heuristic prototypes — phishy side uses an IP-literal host,
# keyword-stuffed subdomains, a deceptive multi-label domain, and a direct .exe
# download; legit side uses well-known HTTPS homepages. Presumably consumed by
# _auto_calibrate_phish_positive to decide whether XGB class 1 means PHISH — confirm.
_AUTOCALIB_PHISHY_URLS: List[str] = [
    "http://198.51.100.23/login/update?acc=123",
    "http://secure-login-account-update.example.com/session?id=123",
    "http://bank.verify-update-security.com/confirm",
    "http://paypal.com.account-verify.cn/login",
    "http://abc.xyz/downloads/invoice.exe",
]

_AUTOCALIB_LEGIT_URLS: List[str] = [
    "https://www.wikipedia.org/",
    "https://www.microsoft.com/",
    "https://www.openai.com/",
    "https://www.python.org/",
    "https://www.gov.uk/",
]
|
| 88 |
+
|
| 89 |
# -------------------------
|
| 90 |
# URL features (must match training)
|
| 91 |
# -------------------------
|
|
|
|
| 116 |
# Loaders
|
| 117 |
# -------------------------
|
| 118 |
def _load_model():
|
| 119 |
+
global _tokenizer, _model, _id2label, _label2id, _text_phish_id
|
| 120 |
if _tokenizer is None or _model is None:
|
| 121 |
with _model_lock:
|
| 122 |
if _tokenizer is None or _model is None:
|
|
|
|
| 126 |
if cfg is not None and getattr(cfg, "id2label", None):
|
| 127 |
_id2label = {int(k): v for k, v in cfg.id2label.items()}
|
| 128 |
_label2id = {v: int(k) for k, v in _id2label.items()}
|
| 129 |
+
# Try to detect which index corresponds to PHISH/SPAM
|
| 130 |
+
_text_phish_id = _detect_text_phish_id(_id2label)
|
| 131 |
with torch.no_grad():
|
| 132 |
_ = _model(**_tokenizer(["warm up"], return_tensors="pt")).logits
|
| 133 |
|
| 134 |
+
# Detect which label id corresponds to phishing for text models based on label strings
|
| 135 |
+
_text_phish_id: Optional[int] = None
|
| 136 |
+
|
| 137 |
+
def _detect_text_phish_id(id2label: Dict[int, str]) -> Optional[int]:
|
| 138 |
+
candidates_phish = ("PHISH", "SPAM", "MALICIOUS", "POSITIVE")
|
| 139 |
+
# Prefer explicit PHISH/SPAM over generic POSITIVE
|
| 140 |
+
priority_order = ("PHISH", "SPAM", "MALICIOUS", "POSITIVE")
|
| 141 |
+
norm = {k: str(v).strip().upper() for k, v in id2label.items()}
|
| 142 |
+
# exact/substring match in priority order
|
| 143 |
+
for token in priority_order:
|
| 144 |
+
for k, v in norm.items():
|
| 145 |
+
if token in v:
|
| 146 |
+
return int(k)
|
| 147 |
+
return None
|
| 148 |
+
|
| 149 |
def _load_url_model():
|
| 150 |
global _url_bundle
|
| 151 |
if _url_bundle is None:
|
|
|
|
| 182 |
if "phish_is_positive" in bundle:
|
| 183 |
return bool(bundle["phish_is_positive"])
|
| 184 |
|
| 185 |
+
phishy = _AUTOCALIB_PHISHY_URLS
|
| 186 |
+
legit = _AUTOCALIB_LEGIT_URLS
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
def _batch_mean(urls: List[str]) -> float:
|
| 189 |
df = pd.DataFrame({url_col: urls})
|
|
|
|
| 244 |
logits = _model(**inputs).logits
|
| 245 |
probs = torch.softmax(logits, dim=-1)[0]
|
| 246 |
score, idx = torch.max(probs, dim=0)
|
| 247 |
+
|
| 248 |
+
# Normalize label to PHISH/LEGIT if we could detect PHISH id
|
| 249 |
+
if _text_phish_id is not None and 0 <= _text_phish_id < probs.shape[0]:
|
| 250 |
+
phish_prob = float(probs[_text_phish_id])
|
| 251 |
+
norm_label = "PHISH" if phish_prob >= 0.5 else "LEGIT"
|
| 252 |
+
norm_score = phish_prob if norm_label == "PHISH" else (1.0 - phish_prob)
|
| 253 |
+
return {"label": norm_label, "score": float(norm_score), "raw_index": int(idx)}
|
| 254 |
+
else:
|
| 255 |
+
# Fallback to model's provided labels
|
| 256 |
+
label = _id2label.get(int(idx), str(int(idx)))
|
| 257 |
+
return {"label": label, "score": float(score), "raw_index": int(idx)}
|
| 258 |
except Exception as e:
|
| 259 |
return JSONResponse(status_code=500, content={"error": str(e)})
|
| 260 |
|