Perth0603 committed on
Commit
4e24576
·
verified ·
1 Parent(s): b2fa32d

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +217 -5
inference.py CHANGED
@@ -2,6 +2,7 @@ import re
2
  import joblib
3
  import pandas as pd
4
  import numpy as np
 
5
  from typing import Dict, Any
6
 
7
  _SUSPICIOUS_TOKENS = [
@@ -9,10 +10,70 @@ _SUSPICIOUS_TOKENS = [
9
  ]
10
  _IPV4_PATTERN = re.compile(r"(?:\d{1,3}\.){3}\d{1,3}")
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  def _engineer_features(url_series: pd.Series) -> pd.DataFrame:
14
  s = url_series.astype(str)
15
  out = pd.DataFrame(index=s.index)
 
 
16
  out["url_len"] = s.str.len().fillna(0)
17
  out["count_dot"] = s.str.count(r"\.")
18
  out["count_hyphen"] = s.str.count("-")
@@ -28,6 +89,118 @@ def _engineer_features(url_series: pd.Series) -> pd.DataFrame:
28
  out["starts_https"] = s.str.startswith("https").astype(int)
29
  out["ends_with_exe"] = s.str.endswith(".exe").astype(int)
30
  out["ends_with_zip"] = s.str.endswith(".zip").astype(int)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  return out
32
 
33
 
@@ -45,14 +218,21 @@ def load_bundle(path: str) -> Dict[str, Any]:
45
 
46
 
47
  def predict_url(url: str, bundle: Dict[str, Any], threshold: float = 0.5) -> Dict[str, Any]:
48
- """Predict phishing probability for a single URL using the saved bundle."""
 
 
 
 
49
  url_col = bundle["url_col"]
50
  feature_cols = bundle["feature_cols"]
 
51
  model_type = bundle.get("model_type", "xgboost_bst")
52
  model = bundle["model"]
53
 
54
  row = pd.DataFrame({url_col: [url]})
55
- feats = _engineer_features(row[url_col])[feature_cols]
 
 
56
 
57
  if model_type == "xgboost_bst":
58
  import xgboost as xgb # local import to keep base env minimal
@@ -68,22 +248,54 @@ def predict_url(url: str, bundle: Dict[str, Any], threshold: float = 0.5) -> Dic
68
  else:
69
  proba = float(model.predict_proba(feats)[:, 1][0])
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  pred = int(proba >= threshold)
72
- return {
 
 
 
 
73
  "url": url,
74
  "phishing_probability": proba,
75
  "predicted_label": pred,
76
  "backend": model_type,
77
  }
 
 
 
78
 
79
 
80
  if __name__ == "__main__":
81
  # Simple manual test (optional)
82
  try:
83
- bundle = load_bundle("rf_url_phishing_xgboost_bst.joblib")
84
  print(
85
  predict_url(
86
- "http://secure-login-account-update.example.com/session?id=123",
87
  bundle=bundle,
88
  )
89
  )
 
2
  import joblib
3
  import pandas as pd
4
  import numpy as np
5
+ from urllib.parse import urlparse
6
  from typing import Dict, Any
7
 
8
  _SUSPICIOUS_TOKENS = [
 
10
  ]
11
  _IPV4_PATTERN = re.compile(r"(?:\d{1,3}\.){3}\d{1,3}")
12
 
13
+ _BRAND_NAMES = [
14
+ "facebook","paypal","google","amazon","apple","microsoft",
15
+ "instagram","netflix","bank","hsbc","linkedin","yahoo","outlook"
16
+ ]
17
+ _SUSPICIOUS_TLDS = {"zip","xyz","top","ru","kim","support","ltd","work","gq","tk","ml"}
18
+
19
+ try:
20
+ from rapidfuzz import fuzz # type: ignore
21
+ def _sim(a: str, b: str) -> float:
22
+ return fuzz.ratio(a, b) / 100.0
23
+ except Exception: # pragma: no cover
24
+ import difflib
25
+ def _sim(a: str, b: str) -> float: # type: ignore
26
+ return difflib.SequenceMatcher(None, a, b).ratio()
27
+
28
+
29
+ def _ensure_scheme(u: str) -> str:
30
+ return u if re.match(r'^[a-zA-Z][a-zA-Z0-9+.\-]*://', u) else 'http://' + u
31
+
32
+
33
def _get_hostname(u: str) -> str:
    """Return the lower-cased hostname of URL *u*, or ``''`` on any failure.

    ASCII punycode (``xn--``) labels are decoded to their Unicode form on a
    best-effort basis.  NOTE(review): downstream feature code checks the
    returned host for a literal ``xn--`` prefix, which this decoding removes
    — confirm that interaction is intended.
    """
    try:
        parsed = urlparse(_ensure_scheme(u))
        host = parsed.hostname or ''
    except Exception:
        return ''
    try:
        # Best-effort IDNA decode; non-ASCII hosts are left untouched.
        host = host.encode('ascii').decode('idna')
    except Exception:
        pass
    return host.lower()
43
+
44
+
45
+ def _get_sld(host: str) -> str:
46
+ parts = host.split('.')
47
+ if len(parts) >= 2:
48
+ return parts[-2]
49
+ return host
50
+
51
+
52
+ def _get_tld(host: str) -> str:
53
+ parts = host.split('.')
54
+ return parts[-1] if len(parts) >= 2 else ''
55
+
56
+
57
+ def _shannon_entropy(s: str) -> float:
58
+ if not s:
59
+ return 0.0
60
+ counts = {}
61
+ for ch in s:
62
+ counts[ch] = counts.get(ch, 0) + 1
63
+ probs = np.array(list(counts.values()), dtype=float)
64
+ probs /= probs.sum()
65
+ return float(-(probs * np.log2(probs)).sum())
66
+
67
+
68
+ def _clean_for_brand(s: str) -> str:
69
+ return re.sub(r'[^a-z]', '', re.sub(r'\d+', '', s.lower()))
70
+
71
 
72
  def _engineer_features(url_series: pd.Series) -> pd.DataFrame:
73
  s = url_series.astype(str)
74
  out = pd.DataFrame(index=s.index)
75
+
76
+ # Lexical features
77
  out["url_len"] = s.str.len().fillna(0)
78
  out["count_dot"] = s.str.count(r"\.")
79
  out["count_hyphen"] = s.str.count("-")
 
89
  out["starts_https"] = s.str.startswith("https").astype(int)
90
  out["ends_with_exe"] = s.str.endswith(".exe").astype(int)
91
  out["ends_with_zip"] = s.str.endswith(".zip").astype(int)
92
+
93
+ # Host-derived
94
+ host = s.apply(_get_hostname)
95
+ sld = host.apply(_get_sld)
96
+ tld = host.apply(_get_tld)
97
+
98
+ out['host_len'] = host.str.len().fillna(0)
99
+ sub_count = host.str.count(r'\.') - 1
100
+ out['subdomain_count'] = sub_count.fillna(0).clip(lower=0).astype(int)
101
+ out['tld_suspicious'] = tld.isin(list(_SUSPICIOUS_TLDS)).astype(int)
102
+ out['has_punycode'] = host.str.contains('xn--', na=False).astype(int)
103
+
104
+ out['sld_len'] = sld.str.len().fillna(0)
105
+ sld_digit_count = sld.str.count(r'\d')
106
+ out['sld_digit_ratio'] = (sld_digit_count / out['sld_len'].replace(0, np.nan)).fillna(0)
107
+ out['sld_entropy'] = sld.apply(_shannon_entropy).astype(float)
108
+
109
+ # Brand similarity features
110
+ sld_clean = sld.apply(_clean_for_brand)
111
+
112
+ def _max_brand_sim(name: str) -> float:
113
+ if not isinstance(name, str) or not name:
114
+ return 0.0
115
+ best = 0.0
116
+ for b in _BRAND_NAMES:
117
+ sc = _sim(name, b)
118
+ if sc > best:
119
+ best = sc
120
+ return float(best)
121
+
122
+ out['max_brand_sim'] = sld_clean.apply(_max_brand_sim).astype(float)
123
+ out['like_facebook'] = sld_clean.apply(lambda x: 1 if _sim(x, 'facebook') >= 0.82 else 0).astype(int)
124
+
125
+ OFFICIAL_DOMAINS = {
126
+ 'facebook': ['facebook.com'],
127
+ 'paypal': ['paypal.com'],
128
+ 'google': ['google.com'],
129
+ 'amazon': ['amazon.com'],
130
+ 'apple': ['apple.com'],
131
+ 'microsoft': ['microsoft.com'],
132
+ 'instagram': ['instagram.com'],
133
+ 'netflix': ['netflix.com'],
134
+ 'hsbc': ['hsbc.com'],
135
+ 'linkedin': ['linkedin.com'],
136
+ 'yahoo': ['yahoo.com'],
137
+ 'outlook': ['outlook.com']
138
+ }
139
+
140
+ def _normalize_leet(name: str) -> str:
141
+ if not isinstance(name, str):
142
+ return ''
143
+ table = str.maketrans({'0':'o','1':'l','3':'e','4':'a','5':'s','7':'t','2':'z','8':'b'})
144
+ return name.translate(table)
145
+
146
+ def _best_brand(name: str):
147
+ if not isinstance(name, str) or not name:
148
+ return '', 0.0
149
+ best_b, best_s = '', 0.0
150
+ for b in _BRAND_NAMES:
151
+ sc = _sim(name, b)
152
+ if sc > best_s:
153
+ best_b, best_s = b, sc
154
+ return best_b, float(best_s)
155
+
156
+ def _get_etld1(h: str) -> str:
157
+ parts = h.split('.') if isinstance(h, str) else []
158
+ if len(parts) >= 2:
159
+ return parts[-2] + '.' + parts[-1]
160
+ return h
161
+
162
+ etld1 = host.apply(_get_etld1)
163
+ brand_best_and_sim = sld_clean.apply(_best_brand)
164
+ brand_best = brand_best_and_sim.apply(lambda x: x[0])
165
+ brand_best_sim = brand_best_and_sim.apply(lambda x: x[1])
166
+
167
+ out['is_official_brand_domain'] = [
168
+ 1 if bb and et in OFFICIAL_DOMAINS.get(bb, []) else 0
169
+ for bb, et in zip(brand_best, etld1)
170
+ ]
171
+
172
+ out['brand_digit_insertion'] = ((sld_clean == brand_best) & (sld.str.contains(r'\d'))).astype(int)
173
+
174
+ sld_leet_norm = sld.apply(_normalize_leet).apply(_clean_for_brand)
175
+ def _max_brand_sim_leet(name: str) -> float:
176
+ if not isinstance(name, str) or not name:
177
+ return 0.0
178
+ best = 0.0
179
+ for b in _BRAND_NAMES:
180
+ sc = _sim(name, b)
181
+ if sc > best:
182
+ best = sc
183
+ return float(best)
184
+ out['max_brand_sim_leet'] = sld_leet_norm.apply(_max_brand_sim_leet).astype(float)
185
+ out['like_brand_leet'] = (out['max_brand_sim_leet'] >= 0.88).astype(int)
186
+
187
+ def _contains_brand_extra(name: str) -> int:
188
+ if not isinstance(name, str) or not name:
189
+ return 0
190
+ for b in _BRAND_NAMES:
191
+ if name != b and b in name:
192
+ return 1
193
+ return 0
194
+ out['sld_contains_brand_extra'] = sld_clean.apply(_contains_brand_extra).astype(int)
195
+
196
+ out['brand_impersonation'] = (
197
+ ((brand_best_sim >= 0.88) | (out['like_brand_leet'] == 1) | (out['sld_contains_brand_extra'] == 1))
198
+ & (out['is_official_brand_domain'] == 0)
199
+ ).astype(int)
200
+
201
+ out['sld_has_hyphen'] = sld.str.contains('-', na=False).astype(int)
202
+ out['sld_has_digits'] = (sld.str.count(r'\d') > 0).astype(int)
203
+
204
  return out
205
 
206
 
 
218
 
219
 
220
  def predict_url(url: str, bundle: Dict[str, Any], threshold: float = 0.5) -> Dict[str, Any]:
221
+ """Predict phishing probability for a single URL using the saved bundle.
222
+
223
+ Applies a rule-based typosquatting guard to catch cases like face123book.com
224
+ even if the model probability is low.
225
+ """
226
  url_col = bundle["url_col"]
227
  feature_cols = bundle["feature_cols"]
228
+ trained_feature_cols = bundle.get("trained_feature_cols")
229
  model_type = bundle.get("model_type", "xgboost_bst")
230
  model = bundle["model"]
231
 
232
  row = pd.DataFrame({url_col: [url]})
233
+ feats_full = _engineer_features(row[url_col])
234
+ desired_cols = list(trained_feature_cols) if trained_feature_cols is not None else list(feature_cols)
235
+ feats = feats_full.reindex(columns=desired_cols, fill_value=0)
236
 
237
  if model_type == "xgboost_bst":
238
  import xgboost as xgb # local import to keep base env minimal
 
248
  else:
249
  proba = float(model.predict_proba(feats)[:, 1][0])
250
 
251
+ # Rule-based typosquatting guard using enriched features (computed regardless of model schema)
252
+ def _bool(feature: str, default: int = 0) -> int:
253
+ return int(feature in feats_full.columns and bool(feats_full.iloc[0].get(feature, default)))
254
+
255
+ def _float(feature: str, default: float = 0.0) -> float:
256
+ return float(feats_full.iloc[0].get(feature, default)) if feature in feats_full.columns else default
257
+
258
+ like_brand = (
259
+ _bool('brand_impersonation') == 1 or
260
+ _bool('like_brand_leet') == 1 or
261
+ _float('max_brand_sim_leet') >= 0.90 or
262
+ _float('max_brand_sim') >= 0.90 or
263
+ _bool('sld_contains_brand_extra') == 1
264
+ )
265
+ risky_host = (
266
+ _bool('is_official_brand_domain') == 0 and
267
+ (
268
+ _bool('sld_has_digits') == 1 or
269
+ _bool('sld_has_hyphen') == 1 or
270
+ _bool('tld_suspicious') == 1 or
271
+ _bool('has_punycode') == 1
272
+ )
273
+ )
274
+ rule_triggered = bool(like_brand and risky_host)
275
+
276
  pred = int(proba >= threshold)
277
+ if rule_triggered and pred == 0:
278
+ pred = 1
279
+ proba = max(proba, 0.9)
280
+
281
+ result = {
282
  "url": url,
283
  "phishing_probability": proba,
284
  "predicted_label": pred,
285
  "backend": model_type,
286
  }
287
+ if rule_triggered:
288
+ result["rule"] = "typosquat_guard"
289
+ return result
290
 
291
 
292
  if __name__ == "__main__":
293
  # Simple manual test (optional)
294
  try:
295
+ bundle = load_bundle("models/rf_url_phishing_xgboost_bst.joblib")
296
  print(
297
  predict_url(
298
+ "www.face123book.com",
299
  bundle=bundle,
300
  )
301
  )