# inference.py - uploaded by Perth0603 (commit 4e24576, "Upload inference.py", verified; 10.8 kB).
import re
import joblib
import pandas as pd
import numpy as np
from urllib.parse import urlparse
from typing import Dict, Any
_SUSPICIOUS_TOKENS = [
"login", "verify", "secure", "update", "bank", "pay", "account", "webscr"
]
_IPV4_PATTERN = re.compile(r"(?:\d{1,3}\.){3}\d{1,3}")
_BRAND_NAMES = [
"facebook","paypal","google","amazon","apple","microsoft",
"instagram","netflix","bank","hsbc","linkedin","yahoo","outlook"
]
_SUSPICIOUS_TLDS = {"zip","xyz","top","ru","kim","support","ltd","work","gq","tk","ml"}
try:
from rapidfuzz import fuzz # type: ignore
def _sim(a: str, b: str) -> float:
return fuzz.ratio(a, b) / 100.0
except Exception: # pragma: no cover
import difflib
def _sim(a: str, b: str) -> float: # type: ignore
return difflib.SequenceMatcher(None, a, b).ratio()
def _ensure_scheme(u: str) -> str:
return u if re.match(r'^[a-zA-Z][a-zA-Z0-9+.\-]*://', u) else 'http://' + u
def _get_hostname(u: str) -> str:
    """Return the lower-cased hostname of URL *u*, IDNA-decoded where possible.

    A missing scheme is filled in via ``_ensure_scheme``. If parsing fails for
    any reason, ``''`` is returned instead of raising. If IDNA decoding fails
    (non-ASCII or malformed punycode), the host is kept as parsed.
    """
    try:
        parsed_host = urlparse(_ensure_scheme(u)).hostname or ''
    except Exception:
        return ''
    try:
        parsed_host = parsed_host.encode('ascii').decode('idna')
    except Exception:
        pass  # best effort: keep the undecoded host
    return parsed_host.lower()
def _get_sld(host: str) -> str:
parts = host.split('.')
if len(parts) >= 2:
return parts[-2]
return host
def _get_tld(host: str) -> str:
parts = host.split('.')
return parts[-1] if len(parts) >= 2 else ''
def _shannon_entropy(s: str) -> float:
if not s:
return 0.0
counts = {}
for ch in s:
counts[ch] = counts.get(ch, 0) + 1
probs = np.array(list(counts.values()), dtype=float)
probs /= probs.sum()
return float(-(probs * np.log2(probs)).sum())
def _clean_for_brand(s: str) -> str:
return re.sub(r'[^a-z]', '', re.sub(r'\d+', '', s.lower()))
def _engineer_features(url_series: pd.Series) -> pd.DataFrame:
    """Compute the model's feature frame for a Series of URLs.

    Every value is coerced with ``astype(str)``. The returned frame shares
    ``url_series``'s index and has one numeric column per feature, grouped as
    lexical (raw URL string), host-derived, and brand-similarity features.
    ``predict_url`` reindexes this frame to the bundle's feature list and also
    reads several columns directly for its typosquatting guard, so column
    names must stay stable.
    """
    s = url_series.astype(str)
    out = pd.DataFrame(index=s.index)

    # ---- Lexical features over the raw URL string ----
    out["url_len"] = s.str.len().fillna(0)
    out["count_dot"] = s.str.count(r"\.")
    out["count_hyphen"] = s.str.count("-")
    out["count_digit"] = s.str.count(r"\d")
    out["count_at"] = s.str.count("@")
    # Raw string: a plain "\?" is an invalid escape sequence (SyntaxWarning on
    # Python 3.12+). The resulting regex is identical.
    out["count_qmark"] = s.str.count(r"\?")
    out["count_eq"] = s.str.count("=")
    out["count_slash"] = s.str.count("/")
    # Empty URLs get a ratio of 0 rather than a division by zero.
    out["digit_ratio"] = (out["count_digit"] / out["url_len"].replace(0, np.nan)).fillna(0)
    out["has_ip"] = s.str.contains(_IPV4_PATTERN).astype(int)
    for tok in _SUSPICIOUS_TOKENS:
        out[f"has_{tok}"] = s.str.contains(tok, case=False, regex=False).astype(int)
    out["starts_https"] = s.str.startswith("https").astype(int)
    out["ends_with_exe"] = s.str.endswith(".exe").astype(int)
    out["ends_with_zip"] = s.str.endswith(".zip").astype(int)

    # ---- Host-derived features ----
    host = s.apply(_get_hostname)
    sld = host.apply(_get_sld)
    tld = host.apply(_get_tld)
    out['host_len'] = host.str.len().fillna(0)
    # Dots beyond the one separating SLD and TLD, floored at 0.
    sub_count = host.str.count(r'\.') - 1
    out['subdomain_count'] = sub_count.fillna(0).clip(lower=0).astype(int)
    out['tld_suspicious'] = tld.isin(list(_SUSPICIOUS_TLDS)).astype(int)
    # _get_hostname IDNA-decodes "xn--" labels into Unicode and leaves hosts
    # that are already non-ASCII untouched, so a bare "xn--" substring test
    # only ever caught undecodable hosts. Flag non-ASCII hosts too, so that
    # internationalised (homograph-prone) hosts are actually detected.
    out['has_punycode'] = host.apply(lambda h: 'xn--' in h or not h.isascii()).astype(int)
    out['sld_len'] = sld.str.len().fillna(0)
    sld_digit_count = sld.str.count(r'\d')
    out['sld_digit_ratio'] = (sld_digit_count / out['sld_len'].replace(0, np.nan)).fillna(0)
    out['sld_entropy'] = sld.apply(_shannon_entropy).astype(float)

    # ---- Brand-similarity features ----
    def _best_brand(name: str):
        """Return ``(closest brand, similarity)`` for *name*, or ``('', 0.0)`` if it is empty.

        Ties keep the earliest brand in ``_BRAND_NAMES`` (strict ``>``).
        """
        if not isinstance(name, str) or not name:
            return '', 0.0
        best_b, best_s = '', 0.0
        for b in _BRAND_NAMES:
            sc = _sim(name, b)
            if sc > best_s:
                best_b, best_s = b, sc
        return best_b, float(best_s)

    sld_clean = sld.apply(_clean_for_brand)
    brand_best_and_sim = sld_clean.apply(_best_brand)
    brand_best = brand_best_and_sim.apply(lambda x: x[0])
    brand_best_sim = brand_best_and_sim.apply(lambda x: x[1]).astype(float)
    # The best-match score is exactly the max similarity over all brands, so
    # reuse it instead of scanning the brand list a second time.
    out['max_brand_sim'] = brand_best_sim
    out['like_facebook'] = sld_clean.apply(lambda x: 1 if _sim(x, 'facebook') >= 0.82 else 0).astype(int)

    # Registrable domains ("sld.tld") treated as each brand's official site.
    OFFICIAL_DOMAINS = {
        'facebook': ['facebook.com'],
        'paypal': ['paypal.com'],
        'google': ['google.com'],
        'amazon': ['amazon.com'],
        'apple': ['apple.com'],
        'microsoft': ['microsoft.com'],
        'instagram': ['instagram.com'],
        'netflix': ['netflix.com'],
        'hsbc': ['hsbc.com'],
        'linkedin': ['linkedin.com'],
        'yahoo': ['yahoo.com'],
        'outlook': ['outlook.com']
    }

    # Built once per call rather than once per SLD.
    leet_table = str.maketrans({'0': 'o', '1': 'l', '3': 'e', '4': 'a', '5': 's', '7': 't', '2': 'z', '8': 'b'})

    def _normalize_leet(name: str) -> str:
        """Undo common leet-speak digit substitutions (0->o, 1->l, 3->e, ...); non-str -> ''."""
        if not isinstance(name, str):
            return ''
        return name.translate(leet_table)

    def _get_etld1(h: str) -> str:
        """Return the last two labels of *h* ('a.b.example.com' -> 'example.com').

        NOTE(review): naive; multi-part public suffixes such as 'co.uk' are not handled.
        """
        parts = h.split('.') if isinstance(h, str) else []
        if len(parts) >= 2:
            return parts[-2] + '.' + parts[-1]
        return h

    etld1 = host.apply(_get_etld1)
    out['is_official_brand_domain'] = [
        1 if bb and et in OFFICIAL_DOMAINS.get(bb, []) else 0
        for bb, et in zip(brand_best, etld1)
    ]
    # NOTE(review): an all-digit SLD (e.g. host '1.2.3.4' -> sld '3') cleans to ''
    # and equals brand_best == '', so this flag is 1 even though no brand is
    # involved. It is left as-is to avoid train/serve skew; fix it together with
    # a retrain.
    out['brand_digit_insertion'] = ((sld_clean == brand_best) & (sld_digit_count > 0)).astype(int)
    sld_leet_norm = sld.apply(_normalize_leet).apply(_clean_for_brand)
    out['max_brand_sim_leet'] = sld_leet_norm.apply(lambda n: _best_brand(n)[1]).astype(float)
    out['like_brand_leet'] = (out['max_brand_sim_leet'] >= 0.88).astype(int)

    def _contains_brand_extra(name: str) -> int:
        """Return 1 if *name* contains a brand as a proper substring (e.g. 'paypalsecure'), else 0."""
        if not isinstance(name, str) or not name:
            return 0
        for b in _BRAND_NAMES:
            if name != b and b in name:
                return 1
        return 0

    out['sld_contains_brand_extra'] = sld_clean.apply(_contains_brand_extra).astype(int)
    # Looks like a brand (close match, leet match, or brand plus extra text)
    # but is not that brand's official domain.
    out['brand_impersonation'] = (
        ((brand_best_sim >= 0.88) | (out['like_brand_leet'] == 1) | (out['sld_contains_brand_extra'] == 1))
        & (out['is_official_brand_domain'] == 0)
    ).astype(int)
    out['sld_has_hyphen'] = sld.str.contains('-', na=False).astype(int)
    out['sld_has_digits'] = (sld_digit_count > 0).astype(int)
    return out
def load_bundle(path: str) -> Dict[str, Any]:
    """Load the saved joblib bundle produced by the notebook.

    Returns a dict with keys: model, feature_cols, url_col, label_col, model_type.
    ``predict_url`` also reads an optional ``trained_feature_cols`` key if present.

    Raises:
        ValueError: if the file does not hold a mapping, or the mapping lacks
            one of the required keys.

    Security: ``joblib.load`` unpickles the file, which can execute arbitrary
    code. Only load bundles from trusted sources.
    """
    from collections.abc import Mapping

    bundle = joblib.load(path)
    # A bare model (or any non-mapping) previously failed with an opaque
    # AttributeError on ``.keys()``; report it clearly instead.
    if not isinstance(bundle, Mapping):
        raise ValueError(
            f"Bundle at {path!r} is a {type(bundle).__name__}, expected a dict"
        )
    required = {"model", "feature_cols", "url_col", "label_col", "model_type"}
    missing = required - set(bundle.keys())
    if missing:
        raise ValueError(f"Bundle missing keys: {missing}")
    return bundle
def _model_probability(model: Any, model_type: str, feats: pd.DataFrame) -> float:
    """Score the single row in *feats* with *model* and return its positive-class probability."""
    if model_type == "xgboost_bst":
        import xgboost as xgb  # local import to keep base env minimal
        return float(model.predict(xgb.DMatrix(feats))[0])
    if model_type == "cuml_rf":
        # Only the GPU-stack import is guarded. Errors raised by the prediction
        # itself (e.g. a feature mismatch) must surface as-is instead of being
        # misreported as a missing dependency.
        try:
            import cudf  # type: ignore
        except Exception as e:  # pragma: no cover
            raise RuntimeError("cudf/cuml required for this bundle but not available") from e
        gfeats = cudf.DataFrame.from_pandas(feats)
        return float(model.predict_proba(gfeats)[:, 1].to_pandas().values[0])
    # Any other backend is assumed to expose a scikit-learn style predict_proba.
    return float(model.predict_proba(feats)[:, 1][0])


def _typosquat_rule_triggered(feats_full: pd.DataFrame) -> bool:
    """Return True if the first row looks like a brand look-alike on a risky, non-official host."""
    row = feats_full.iloc[0]

    def _bool(feature: str) -> bool:
        # Columns absent from the frame count as False.
        return feature in row.index and bool(row[feature])

    def _float(feature: str, default: float = 0.0) -> float:
        return float(row[feature]) if feature in row.index else default

    like_brand = (
        _bool('brand_impersonation')
        or _bool('like_brand_leet')
        or _float('max_brand_sim_leet') >= 0.90
        or _float('max_brand_sim') >= 0.90
        or _bool('sld_contains_brand_extra')
    )
    risky_host = not _bool('is_official_brand_domain') and (
        _bool('sld_has_digits')
        or _bool('sld_has_hyphen')
        or _bool('tld_suspicious')
        or _bool('has_punycode')
    )
    return bool(like_brand and risky_host)


def predict_url(url: str, bundle: Dict[str, Any], threshold: float = 0.5) -> Dict[str, Any]:
    """Predict phishing probability for a single URL using the saved bundle.

    Applies a rule-based typosquatting guard to catch cases like face123book.com
    even if the model probability is low.

    Args:
        url: The URL to score. A missing scheme is tolerated.
        bundle: A dict as returned by ``load_bundle``.
        threshold: Probability at or above which the model's label is 1.

    Returns:
        A dict with ``url``, ``phishing_probability``, ``predicted_label`` and
        ``backend``, plus ``rule="typosquat_guard"`` whenever the guard matched.

    Raises:
        KeyError: if ``url_col``, ``feature_cols`` or ``model`` is missing from *bundle*.
        RuntimeError: for a ``cuml_rf`` bundle when ``cudf`` cannot be imported.
    """
    url_col = bundle["url_col"]
    feature_cols = bundle["feature_cols"]
    trained_feature_cols = bundle.get("trained_feature_cols")
    model_type = bundle.get("model_type", "xgboost_bst")
    model = bundle["model"]

    row = pd.DataFrame({url_col: [url]})
    feats_full = _engineer_features(row[url_col])
    # Feed the model exactly the columns it was trained on. Unknown columns are
    # dropped and missing ones are zero-filled.
    desired_cols = list(trained_feature_cols) if trained_feature_cols is not None else list(feature_cols)
    feats = feats_full.reindex(columns=desired_cols, fill_value=0)

    proba = _model_probability(model, model_type, feats)

    # The rule guard reads the full engineered frame, regardless of the model schema.
    rule_triggered = _typosquat_rule_triggered(feats_full)
    pred = int(proba >= threshold)
    if rule_triggered and pred == 0:
        pred = 1
        # NOTE(review): with threshold > 0.9 the reported probability can stay
        # below the threshold even though the label is forced to 1.
        proba = max(proba, 0.9)
    result = {
        "url": url,
        "phishing_probability": proba,
        "predicted_label": pred,
        "backend": model_type,
    }
    if rule_triggered:
        result["rule"] = "typosquat_guard"
    return result
if __name__ == "__main__":
    # Simple manual smoke test (optional): score a known typosquat with a
    # locally saved bundle, if one exists at the path below.
    demo_bundle_path = "models/rf_url_phishing_xgboost_bst.joblib"
    try:
        bundle = load_bundle(demo_bundle_path)
        print(
            predict_url(
                "www.face123book.com",
                bundle=bundle,
            )
        )
    except FileNotFoundError:
        # Report the path actually tried; it is relative to the working
        # directory, under models/, not the directory itself.
        print(f"Bundle not found at {demo_bundle_path!r}. This is expected inside the source repo.")