import os

# Point every cache-aware library (HF hub, transformers, torch) at a
# writable location before any of them is imported. setdefault keeps any
# values already present in the environment.
_CACHE_ENV_DEFAULTS = {
    "HOME": "/data",
    "XDG_CACHE_HOME": "/data/.cache",
    "HF_HOME": "/data/.cache",
    "TRANSFORMERS_CACHE": "/data/.cache",
    "TORCH_HOME": "/data/.cache",
}
for _name, _value in _CACHE_ENV_DEFAULTS.items():
    os.environ.setdefault(_name, _value)
| |
|
| | from fastapi import FastAPI |
| | from fastapi.responses import JSONResponse |
| | from pydantic import BaseModel |
| | from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| | from huggingface_hub import hf_hub_download |
| | import joblib |
| | import torch |
| | import re |
| | import numpy as np |
| | import pandas as pd |
# xgboost is an optional dependency: it is only needed when the URL model
# bundle declares model_type == 'xgboost_bst'. If it is missing, the
# /predict-url handler raises an explicit error for that bundle type.
try:
    import xgboost as xgb
except Exception:
    xgb = None
| |
|
| |
|
# Hugging Face repo id of the text (email) phishing classifier.
MODEL_ID = os.environ.get("MODEL_ID", "Perth0603/phishing-email-mobilebert")
# Repo and artifact name for the URL-classifier joblib bundle.
URL_REPO = os.environ.get("URL_REPO", "Perth0603/Random-Forest-Model-for-PhishingDetection")
URL_REPO_TYPE = os.environ.get("URL_REPO_TYPE", "model")
# NOTE(review): despite the 'rf_' prefix, the default filename says
# 'xgboost_bst'; the bundle's own 'model_type' field decides at runtime.
URL_FILENAME = os.environ.get("URL_FILENAME", "rf_url_phishing_xgboost_bst.joblib")

# Writable directory used as cache_dir for all downloads; created eagerly.
CACHE_DIR = os.environ.get("HF_CACHE_DIR", "/data/.cache")
os.makedirs(CACHE_DIR, exist_ok=True)

app = FastAPI(title="Phishing Text Classifier", version="1.0.0")
| |
|
| |
|
class PredictPayload(BaseModel):
    """Request body for POST /predict.

    ``inputs`` is the raw text (e.g. an email body) to classify.
    """
    inputs: str
| |
|
| |
|
| | |
# Lazily-initialized, process-wide singletons; populated on first use by
# _load_model() / _load_url_model().
_tokenizer = None   # AutoTokenizer for MODEL_ID
_model = None       # AutoModelForSequenceClassification for MODEL_ID
_url_bundle = None  # dict with 'model' (+ optional 'feature_cols', 'url_col', 'model_type')
| |
|
| |
|
def _load_url_model():
    """Load the URL-classifier bundle once, preferring a local copy.

    Checks the current working directory for URL_FILENAME first; otherwise
    downloads the artifact from the Hugging Face Hub into CACHE_DIR. The
    result is memoized in the module-level ``_url_bundle``.
    """
    global _url_bundle
    if _url_bundle is not None:
        return  # already loaded

    candidate = os.path.join(os.getcwd(), URL_FILENAME)
    if os.path.exists(candidate):
        _url_bundle = joblib.load(candidate)
        return

    artifact = hf_hub_download(
        repo_id=URL_REPO,
        filename=URL_FILENAME,
        repo_type=URL_REPO_TYPE,
        cache_dir=CACHE_DIR,
    )
    _url_bundle = joblib.load(artifact)
| |
|
| |
|
| | |
| | _SUSPICIOUS_TOKENS = ["login", "verify", "secure", "update", "bank", "pay", "account", "webscr"] |
| | _ipv4_pattern = re.compile(r'(?:\d{1,3}\.){3}\d{1,3}') |
| |
|
| | def _engineer_features(df: pd.DataFrame, url_col: str, feature_cols: list[str] | None = None) -> pd.DataFrame: |
| | s = df[url_col].astype(str) |
| | out = pd.DataFrame(index=df.index) |
| | out['url_len'] = s.str.len().fillna(0) |
| | out['count_dot'] = s.str.count(r'\.') |
| | out['count_hyphen'] = s.str.count('-') |
| | out['count_digit'] = s.str.count(r'\d') |
| | out['count_at'] = s.str.count('@') |
| | out['count_qmark'] = s.str.count('\?') |
| | out['count_eq'] = s.str.count('=') |
| | out['count_slash'] = s.str.count('/') |
| | out['digit_ratio'] = (out['count_digit'] / out['url_len'].replace(0, np.nan)).fillna(0) |
| | out['has_ip'] = s.str.contains(_ipv4_pattern).astype(int) |
| | for tok in _SUSPICIOUS_TOKENS: |
| | out[f'has_{tok}'] = s.str.contains(tok, case=False, regex=False).astype(int) |
| | out['starts_https'] = s.str.startswith('https').astype(int) |
| | out['ends_with_exe'] = s.str.endswith('.exe').astype(int) |
| | out['ends_with_zip'] = s.str.endswith('.zip').astype(int) |
| | return out if feature_cols is None else out[feature_cols] |
| |
|
| |
|
def _load_model():
    """Load tokenizer + classifier once per process and warm them up.

    Called at the top of every /predict request; after the first call it is
    a cheap no-op thanks to the module-level memoization.
    """
    global _tokenizer, _model
    if _tokenizer is None or _model is None:
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
        _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
        # Fix: run the warm-up forward pass only on first load. Performing
        # it unconditionally made every request pay an extra tokenize +
        # forward pass for no benefit.
        with torch.no_grad():
            _ = _model(**_tokenizer(["warm up"], return_tensors="pt")).logits
| |
|
| |
|
@app.get("/")
def root():
    """Health/info endpoint: service liveness plus the active text model id."""
    info = {"status": "ok", "model": MODEL_ID}
    return info
| |
|
| |
|
@app.post("/predict")
def predict(payload: PredictPayload):
    """Classify ``payload.inputs`` with the text phishing model.

    Returns ``{"label": "LEGIT"|"PHISH", "score": top-class probability}``;
    any failure is reported as a 500 JSON body ``{"error": ...}``.
    """
    try:
        _load_model()
        with torch.no_grad():
            encoded = _tokenizer(
                [payload.inputs], return_tensors="pt", truncation=True, max_length=512
            )
            probabilities = torch.softmax(_model(**encoded).logits, dim=-1)[0]
            idx = int(torch.argmax(probabilities))
            top_score = float(probabilities[idx])
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})

    # NOTE(review): label map is hard-coded (index 1 assumed to be phishing);
    # unknown indices fall back to the stringified index. Confirm against the
    # model's config.id2label.
    label = {0: "LEGIT", 1: "PHISH"}.get(idx, str(idx))
    return {"label": label, "score": top_score}
| |
|
| |
|
class PredictUrlPayload(BaseModel):
    """Request body for POST /predict-url.

    ``url`` is the URL string to score.
    """
    url: str
| |
|
| |
|
@app.post("/predict-url")
def predict_url(payload: PredictUrlPayload):
    """Classify ``payload.url`` with the joblib URL-model bundle.

    The bundle must be a dict with a 'model' entry and may carry
    'feature_cols', 'url_col', and 'model_type'. Three model shapes are
    supported: a raw xgboost Booster ('xgboost_bst'), a probabilistic
    sklearn-style estimator (``predict_proba``), and a hard-label
    ``predict``-only classifier. Errors are reported as a 500 JSON body.
    """
    try:
        _load_url_model()
        bundle = _url_bundle
        if not isinstance(bundle, dict) or 'model' not in bundle:
            raise RuntimeError("Loaded URL artifact is not a bundle dict with 'model'.")
        model = bundle['model']
        # Fix: a missing/empty feature list must become None. The previous
        # `or []` fallback made _engineer_features select out[[]] — a
        # zero-column frame — silently feeding the model no features.
        feature_cols = bundle.get('feature_cols') or None
        url_col = bundle.get('url_col') or 'url'
        model_type = bundle.get('model_type') or ''

        row = pd.DataFrame({url_col: [payload.url]})
        feats = _engineer_features(row, url_col, feature_cols)

        score = None
        label = None

        if model_type == 'xgboost_bst':
            # Raw Booster: predict() yields the positive-class probability.
            if xgb is None:
                raise RuntimeError("xgboost is not installed but required for this model bundle.")
            dmat = xgb.DMatrix(feats)
            score = float(model.predict(dmat)[0])
            label = "PHISH" if score >= 0.5 else "LEGIT"
        elif hasattr(model, "predict_proba"):
            proba = model.predict_proba(feats)[0]
            if len(proba) == 2:
                # Binary classifier: column 1 is the phishing probability.
                score = float(proba[1])
                label = "PHISH" if score >= 0.5 else "LEGIT"
            else:
                # Multi-class fallback: class index 1 is treated as phishing.
                max_idx = int(np.argmax(proba))
                score = float(proba[max_idx])
                label = "PHISH" if max_idx == 1 else "LEGIT"
        else:
            # Hard-label classifier: map the prediction to a label with a
            # degenerate 0/1 score.
            pred = model.predict(feats)[0]
            if isinstance(pred, (int, float, np.integer, np.floating)):
                label = "PHISH" if int(pred) == 1 else "LEGIT"
                score = 1.0 if label == "PHISH" else 0.0
            else:
                up = str(pred).strip().upper()
                if up in ("PHISH", "PHISHING", "MALICIOUS"):
                    label, score = "PHISH", 1.0
                else:
                    label, score = "LEGIT", 0.0
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})

    return {"label": label, "score": float(score)}
| |
|
| |
|
| |
|