Perth0603 committed on
Commit
00007cf
·
verified ·
1 Parent(s): 3d48ccd

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +93 -0
inference.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import joblib
3
+ import pandas as pd
4
+ import numpy as np
5
+ from typing import Dict, Any
6
+
7
# Substrings whose presence in a URL is exposed as a ``has_<token>`` feature
# by _engineer_features (matched case-insensitively, as literal text).
_SUSPICIOUS_TOKENS = [
    "login", "verify", "secure", "update", "bank", "pay", "account", "webscr"
]
# Four dot-separated groups of 1-3 digits, matched anywhere in the string.
# Octet values are not range-checked, so e.g. "999.999.999.999" also matches.
_IPV4_PATTERN = re.compile(r"(?:\d{1,3}\.){3}\d{1,3}")
11
+
12
+
13
def _engineer_features(url_series: pd.Series) -> pd.DataFrame:
    """Build the lexical feature frame for a series of URLs.

    Every value is coerced to ``str`` first, so non-string entries are
    featurized by their string form rather than skipped.

    Args:
        url_series: URLs to featurize; its index is preserved in the result.

    Returns:
        A DataFrame indexed like ``url_series`` with, in order: ``url_len``,
        the ``count_*`` character counts, ``digit_ratio``, ``has_ip``, one
        ``has_<token>`` column per entry of ``_SUSPICIOUS_TOKENS``,
        ``starts_https``, ``ends_with_exe`` and ``ends_with_zip``.
    """
    s = url_series.astype(str)
    out = pd.DataFrame(index=s.index)
    out["url_len"] = s.str.len().fillna(0)

    # Column name -> regex counted in each URL. Patterns containing regex
    # metacharacters are raw strings; the former "\?" was an invalid escape
    # sequence in a normal string literal.
    count_patterns = {
        "count_dot": r"\.",
        "count_hyphen": "-",
        "count_digit": r"\d",
        "count_at": "@",
        "count_qmark": r"\?",
        "count_eq": "=",
        "count_slash": "/",
    }
    for column, pattern in count_patterns.items():
        out[column] = s.str.count(pattern)

    # Swap zero lengths for NaN so the division cannot divide by zero, then
    # report a ratio of 0 for those rows.
    out["digit_ratio"] = (
        out["count_digit"] / out["url_len"].replace(0, np.nan)
    ).fillna(0)
    out["has_ip"] = s.str.contains(_IPV4_PATTERN).astype(int)

    # Tokens are matched as literal substrings, ignoring case.
    for tok in _SUSPICIOUS_TOKENS:
        out[f"has_{tok}"] = s.str.contains(tok, case=False, regex=False).astype(int)

    # Prefix/suffix checks are case-sensitive.
    out["starts_https"] = s.str.startswith("https").astype(int)
    out["ends_with_exe"] = s.str.endswith(".exe").astype(int)
    out["ends_with_zip"] = s.str.endswith(".zip").astype(int)
    return out
32
+
33
+
34
def load_bundle(path: str) -> Dict[str, Any]:
    """Load the saved joblib bundle produced by the notebook.

    Args:
        path: Location of the joblib file to load.

    Returns:
        A dict with keys: model, feature_cols, url_col, label_col, model_type

    Raises:
        ValueError: If the loaded object is not a mapping, or if any of the
            required keys is missing.
    """
    from collections.abc import Mapping  # local: only needed for validation

    # SECURITY: joblib.load unpickles the file, which can execute arbitrary
    # code. Only load bundles that come from a trusted source.
    bundle = joblib.load(path)

    # Fail with a clear message instead of an AttributeError on `.keys()`
    # when the file holds something other than a bundle mapping.
    if not isinstance(bundle, Mapping):
        raise ValueError(
            f"Bundle must be a mapping, got {type(bundle).__name__}"
        )

    required = {"model", "feature_cols", "url_col", "label_col", "model_type"}
    missing = required - set(bundle)
    if missing:
        raise ValueError(f"Bundle missing keys: {missing}")
    return bundle
45
+
46
+
47
def predict_url(url: str, bundle: Dict[str, Any], threshold: float = 0.5) -> Dict[str, Any]:
    """Predict phishing probability for a single URL using the saved bundle.

    Args:
        url: The URL to score.
        bundle: Model bundle; ``url_col``, ``feature_cols`` and ``model`` are
            read from it, and ``model_type`` defaults to ``"xgboost_bst"``
            when absent.
        threshold: Probability at or above which the URL is labelled 1.

    Returns:
        A dict with ``url``, ``phishing_probability`` (float),
        ``predicted_label`` (0 or 1) and ``backend`` (the model type used).

    Raises:
        RuntimeError: If the bundle is of type ``"cuml_rf"`` and ``cudf``
            cannot be imported.
    """
    url_col = bundle["url_col"]
    feature_cols = bundle["feature_cols"]
    model_type = bundle.get("model_type", "xgboost_bst")
    model = bundle["model"]

    row = pd.DataFrame({url_col: [url]})
    # Keep only the columns listed in the bundle, in the bundle's order.
    feats = _engineer_features(row[url_col])[feature_cols]

    if model_type == "xgboost_bst":
        import xgboost as xgb  # local import to keep base env minimal
        dmat = xgb.DMatrix(feats)
        proba = float(model.predict(dmat)[0])
    elif model_type == "cuml_rf":
        # Guard only the import. Errors raised while predicting must surface
        # as themselves rather than be reported as a missing dependency.
        try:
            import cudf  # type: ignore
        except Exception as e:  # pragma: no cover
            raise RuntimeError("cudf/cuml required for this bundle but not available") from e
        gfeats = cudf.DataFrame.from_pandas(feats)
        proba = float(model.predict_proba(gfeats)[:, 1].to_pandas().values[0])
    else:
        # Any other model type is treated as exposing predict_proba on a
        # pandas frame; column 1 is taken as the positive-class probability.
        proba = float(model.predict_proba(feats)[:, 1][0])

    pred = int(proba >= threshold)
    return {
        "url": url,
        "phishing_probability": proba,
        "predicted_label": pred,
        "backend": model_type,
    }
78
+
79
+
80
if __name__ == "__main__":
    # Optional smoke test: score one sample URL with a bundle expected in the
    # current working directory, and report politely when it is not there.
    sample_url = "http://secure-login-account-update.example.com/session?id=123"
    try:
        bundle = load_bundle("rf_url_phishing_xgboost_bst.joblib")
        result = predict_url(sample_url, bundle=bundle)
        print(result)
    except FileNotFoundError:
        print("Bundle not found in current directory. This is expected inside the source repo.")
92
+
93
+