import os
# Point every cache/home location at the writable /data volume (Spaces Docker
# containers have a read-only default HOME). These must be set BEFORE
# transformers/torch are imported below, since those libraries read the
# variables at import time.
os.environ.setdefault("HOME", "/data")
os.environ.setdefault("XDG_CACHE_HOME", "/data/.cache")
os.environ.setdefault("HF_HOME", "/data/.cache")
os.environ.setdefault("TRANSFORMERS_CACHE", "/data/.cache")
os.environ.setdefault("TORCH_HOME", "/data/.cache")
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from huggingface_hub import hf_hub_download
import joblib
import torch
import re
import numpy as np
import pandas as pd
try:
import xgboost as xgb # type: ignore
except Exception:
xgb = None # optional; required if bundle uses xgboost
# Hub repo of the fine-tuned text classifier used by /predict.
MODEL_ID = os.environ.get("MODEL_ID", "Perth0603/phishing-email-mobilebert")
# Hub repo holding the joblib bundle for the URL model used by /predict-url.
URL_REPO = os.environ.get("URL_REPO", "Perth0603/Random-Forest-Model-for-PhishingDetection")
URL_REPO_TYPE = os.environ.get("URL_REPO_TYPE", "model")  # model|space|dataset
# NOTE: set to your artifact filename, e.g. rf_url_phishing_xgboost_bst.joblib
URL_FILENAME = os.environ.get("URL_FILENAME", "rf_url_phishing_xgboost_bst.joblib")
# Ensure writable cache directory for HF/torch inside Spaces Docker
CACHE_DIR = os.environ.get("HF_CACHE_DIR", "/data/.cache")
os.makedirs(CACHE_DIR, exist_ok=True)
app = FastAPI(title="Phishing Text Classifier", version="1.0.0")
class PredictPayload(BaseModel):
    """Request body for POST /predict."""

    # Raw text (e.g. an email body) to classify.
    inputs: str
# Lazy singletons, populated on first request by _load_model/_load_url_model
# so container start-up stays fast.
_tokenizer = None
_model = None
_url_bundle = None  # dict: {model, feature_cols, url_col, label_col, model_type}
def _load_url_model():
    """Populate the ``_url_bundle`` singleton on first use.

    A copy of the artifact committed next to the app (current working
    directory) takes precedence; otherwise the file is downloaded from the
    Hugging Face Hub into the writable cache directory.
    """
    global _url_bundle
    if _url_bundle is not None:
        return  # already loaded
    # Prefer a local artifact (e.g. committed into the Space repo).
    artifact = os.path.join(os.getcwd(), URL_FILENAME)
    if not os.path.exists(artifact):
        # No local copy: fetch the artifact from the Hub instead.
        artifact = hf_hub_download(
            repo_id=URL_REPO,
            filename=URL_FILENAME,
            repo_type=URL_REPO_TYPE,
            cache_dir=CACHE_DIR,
        )
    _url_bundle = joblib.load(artifact)
# URL feature engineering (must match training)
_SUSPICIOUS_TOKENS = ["login", "verify", "secure", "update", "bank", "pay", "account", "webscr"]
_ipv4_pattern = re.compile(r'(?:\d{1,3}\.){3}\d{1,3}')
def _engineer_features(df: pd.DataFrame, url_col: str, feature_cols: list[str] | None = None) -> pd.DataFrame:
s = df[url_col].astype(str)
out = pd.DataFrame(index=df.index)
out['url_len'] = s.str.len().fillna(0)
out['count_dot'] = s.str.count(r'\.')
out['count_hyphen'] = s.str.count('-')
out['count_digit'] = s.str.count(r'\d')
out['count_at'] = s.str.count('@')
out['count_qmark'] = s.str.count('\?')
out['count_eq'] = s.str.count('=')
out['count_slash'] = s.str.count('/')
out['digit_ratio'] = (out['count_digit'] / out['url_len'].replace(0, np.nan)).fillna(0)
out['has_ip'] = s.str.contains(_ipv4_pattern).astype(int)
for tok in _SUSPICIOUS_TOKENS:
out[f'has_{tok}'] = s.str.contains(tok, case=False, regex=False).astype(int)
out['starts_https'] = s.str.startswith('https').astype(int)
out['ends_with_exe'] = s.str.endswith('.exe').astype(int)
out['ends_with_zip'] = s.str.endswith('.zip').astype(int)
return out if feature_cols is None else out[feature_cols]
def _load_model():
    """Lazily initialise the text tokenizer/model pair, then warm them up."""
    global _tokenizer, _model
    if _tokenizer is not None and _model is not None:
        return  # already initialised
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
    _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
    # One throwaway forward pass so the first real request does not pay
    # lazy torch/kernel initialisation costs.
    with torch.no_grad():
        _ = _model(**_tokenizer(["warm up"], return_tensors="pt")).logits
@app.get("/")
def root():
    """Liveness probe: report service health and the configured text model."""
    body = {"status": "ok"}
    body["model"] = MODEL_ID
    return body
@app.post("/predict")
def predict(payload: PredictPayload):
    """Classify text as LEGIT or PHISH.

    Returns ``{"label": ..., "score": ...}`` on success, or a 500 JSON
    error body if model loading or inference raises.
    """
    try:
        _load_model()
        with torch.no_grad():
            encoded = _tokenizer(
                [payload.inputs], return_tensors="pt", truncation=True, max_length=512
            )
            probabilities = torch.softmax(_model(**encoded).logits, dim=-1)[0]
            best_score, best_idx = torch.max(probabilities, dim=0)
    except Exception as exc:  # surface any load/inference failure as JSON
        return JSONResponse(status_code=500, content={"error": str(exc)})
    # Map common ids to labels (kept generic; the model config also has these).
    class_names = {0: "LEGIT", 1: "PHISH"}
    label = class_names.get(int(best_idx), str(int(best_idx)))
    return {"label": label, "score": float(best_score)}
class PredictUrlPayload(BaseModel):
    """Request body for POST /predict-url."""

    # URL string to classify.
    url: str
def _classify_with_bundle(model, model_type, feats):
    """Score one engineered-feature row with the bundled URL model.

    Returns a ``(label, score)`` pair where label is "PHISH"/"LEGIT" and
    score is the phishing probability when the model exposes one, else a
    0/1 pseudo-score derived from the hard prediction.
    """
    if isinstance(model_type, str) and model_type == 'xgboost_bst':
        # Raw xgboost Booster: needs a DMatrix; predict() yields P(phish).
        if xgb is None:
            raise RuntimeError("xgboost is not installed but required for this model bundle.")
        score = float(model.predict(xgb.DMatrix(feats))[0])
        return ("PHISH" if score >= 0.5 else "LEGIT"), score
    if hasattr(model, "predict_proba"):
        proba = model.predict_proba(feats)[0]
        if len(proba) == 2:
            # Binary classifier: column 1 is the phishing class.
            score = float(proba[1])
            return ("PHISH" if score >= 0.5 else "LEGIT"), score
        # Multi-class fallback: report the arg-max class.
        max_idx = int(np.argmax(proba))
        return ("PHISH" if max_idx == 1 else "LEGIT"), float(proba[max_idx])
    # Last resort: hard predictions only (numeric 0/1 or a string label).
    pred = model.predict(feats)[0]
    if isinstance(pred, (int, float, np.integer, np.floating)):
        label = "PHISH" if int(pred) == 1 else "LEGIT"
    else:
        up = str(pred).strip().upper()
        label = "PHISH" if up in ("PHISH", "PHISHING", "MALICIOUS") else "LEGIT"
    return label, (1.0 if label == "PHISH" else 0.0)


@app.post("/predict-url")
def predict_url(payload: PredictUrlPayload):
    """Classify a URL as LEGIT or PHISH with the bundled URL model.

    Lazily loads the model bundle, engineers the same lexical features used
    at training time, and dispatches on the bundle's declared model type.
    Any failure is returned as a 500 JSON error body.
    """
    try:
        _load_url_model()
        bundle = _url_bundle
        if not isinstance(bundle, dict) or 'model' not in bundle:
            raise RuntimeError("Loaded URL artifact is not a bundle dict with 'model'.")
        model = bundle['model']
        # FIX: previously a missing 'feature_cols' fell back to [], which made
        # _engineer_features select ZERO columns and fed an empty frame to the
        # model. Fall back to None so all engineered features are used.
        feature_cols = bundle.get('feature_cols') or None
        url_col = bundle.get('url_col') or 'url'
        model_type = bundle.get('model_type') or ''
        row = pd.DataFrame({url_col: [payload.url]})
        feats = _engineer_features(row, url_col, feature_cols)
        label, score = _classify_with_bundle(model, model_type, feats)
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
    return {"label": label, "score": float(score)}