Perth0603 committed on
Commit
bdde6ee
·
verified ·
1 Parent(s): d506ae1

Upload 7 files

Browse files
Files changed (2) hide show
  1. Dockerfile +8 -0
  2. app.py +244 -558
Dockerfile CHANGED
@@ -21,6 +21,14 @@ COPY requirements.txt /app/requirements.txt
21
  RUN pip install -r /app/requirements.txt
22
 
23
  COPY app.py /app/app.py
 
 
 
 
 
 
 
 
24
 
25
  EXPOSE 7860
26
  CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
 
# Install Python dependencies first so this layer is cached across code edits.
RUN pip install -r /app/requirements.txt

# Application code and the CSV configuration files it reads at runtime.
COPY app.py /app/app.py
COPY autocalib_phishy.csv /app/autocalib_phishy.csv
COPY autocalib_legit.csv /app/autocalib_legit.csv
COPY known_hosts.csv /app/known_hosts.csv

# Default CSV envs to follow CSVs in image (can be overridden in Space settings)
ENV AUTOCALIB_PHISHY_CSV=/app/autocalib_phishy.csv \
    AUTOCALIB_LEGIT_CSV=/app/autocalib_legit.csv \
    KNOWN_HOSTS_CSV=/app/known_hosts.csv

# 7860 is the standard Hugging Face Spaces port.
EXPOSE 7860
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py CHANGED
@@ -1,558 +1,244 @@
1
- import os
2
- os.environ.setdefault("HOME", "/data")
3
- os.environ.setdefault("XDG_CACHE_HOME", "/data/.cache")
4
- os.environ.setdefault("HF_HOME", "/data/.cache")
5
- os.environ.setdefault("TRANSFORMERS_CACHE", "/data/.cache")
6
- os.environ.setdefault("TORCH_HOME", "/data/.cache")
7
-
8
- from typing import Optional, List, Dict, Any
9
- import csv
10
- from urllib.parse import urlparse
11
- import threading
12
- import re
13
- import numpy as np
14
- import pandas as pd
15
- import joblib
16
- import torch
17
- from fastapi import FastAPI
18
- from fastapi.responses import JSONResponse
19
- from pydantic import BaseModel
20
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
21
- from huggingface_hub import hf_hub_download
22
-
23
- try:
24
- import xgboost as xgb # type: ignore
25
- except Exception:
26
- xgb = None
27
-
28
- # -------------------------
29
- # Environment / config
30
- # -------------------------
31
- MODEL_ID = os.environ.get("MODEL_ID", "Perth0603/phishing-email-mobilebert")
32
- # Support both legacy and HF_* envs
33
- URL_REPO = os.environ.get("HF_URL_MODEL_ID", os.environ.get("URL_REPO", "Perth0603/Random-Forest-Model-for-PhishingDetection"))
34
- URL_REPO_TYPE = os.environ.get("HF_URL_REPO_TYPE", os.environ.get("URL_REPO_TYPE", "model")) # model|space|dataset
35
- URL_FILENAME = os.environ.get("HF_URL_FILENAME", os.environ.get("URL_FILENAME", "rf_url_phishing_xgboost_bst.joblib"))
36
- CACHE_DIR = os.environ.get("HF_CACHE_DIR", "/data/.cache")
37
- os.makedirs(CACHE_DIR, exist_ok=True)
38
-
39
- # Force-thread cap helps tiny Spaces
40
- torch.set_num_threads(int(os.environ.get("TORCH_NUM_THREADS", "1")))
41
-
42
- # Optional manual override (beats everything): "PHISH" or "LEGIT"
43
- URL_POSITIVE_CLASS_ENV = os.environ.get("URL_POSITIVE_CLASS", "").strip().upper() # "", "PHISH", "LEGIT"
44
-
45
- app = FastAPI(title="PhishWatch API", version="1.2.0")
46
-
47
- # -------------------------
48
- # Schemas
49
- # -------------------------
50
- class PredictPayload(BaseModel):
51
- inputs: str
52
-
53
- class PredictUrlPayload(BaseModel):
54
- url: str
55
-
56
- # -------------------------
57
- # Lazy singletons
58
- # -------------------------
59
- _tokenizer: Optional[AutoTokenizer] = None
60
- _model: Optional[AutoModelForSequenceClassification] = None
61
- _id2label: Dict[int, str] = {0: "LEGIT", 1: "PHISH"}
62
- _label2id: Dict[str, int] = {"LEGIT": 0, "PHISH": 1}
63
-
64
- _url_bundle: Optional[Dict[str, Any]] = None
65
- _model_lock = threading.Lock()
66
- _url_lock = threading.Lock()
67
-
68
- # Calibrated flag: is XGB class 1 == PHISH?
69
- _url_phish_is_positive: Optional[bool] = None
70
-
71
- # -------------------------
72
- # Autocalibration URL prototypes (CSV-driven)
73
- # -------------------------
74
- # Provide CSV files for calibration lists to avoid code edits:
75
- # - AUTOCALIB_PHISHY_CSV (default hf_space/autocalib_phishy.csv)
76
- # - AUTOCALIB_LEGIT_CSV (default hf_space/autocalib_legit.csv)
77
- # These lists are loaded at startup and before each request (hot-reload safe).
78
- _AUTOCALIB_PHISHY_URLS: List[str] = []
79
- _AUTOCALIB_LEGIT_URLS: List[str] = []
80
-
81
- # Known host overrides via CSV (suffix-matched):
82
- # - KNOWN_HOSTS_CSV (default hf_space/known_hosts.csv) with columns host,label
83
- _KNOWN_LEGIT_HOSTS: List[str] = []
84
- _KNOWN_PHISH_HOSTS: List[str] = []
85
-
86
- # Helpers to normalize and match hosts by suffix (handles subdomains)
87
- def _normalize_host(value: str) -> str:
88
- v = value.strip().lower()
89
- if v.startswith("www."):
90
- v = v[4:]
91
- return v
92
-
93
- def _host_matches_any(host: str, known: List[str]) -> bool:
94
- base = _normalize_host(host)
95
- for item in known:
96
- k = _normalize_host(item)
97
- if base == k or base.endswith("." + k):
98
- return True
99
- return False
100
-
101
- # Optional CSV configuration
102
- def _read_urls_from_csv(path: str) -> List[str]:
103
- urls: List[str] = []
104
- try:
105
- with open(path, newline="", encoding="utf-8") as f:
106
- reader = csv.DictReader(f)
107
- if "url" in (reader.fieldnames or []):
108
- for row in reader:
109
- val = str(row.get("url", "")).strip()
110
- if val:
111
- urls.append(val)
112
- else:
113
- f.seek(0)
114
- f2 = csv.reader(f)
115
- for row in f2:
116
- if not row:
117
- continue
118
- val = str(row[0]).strip()
119
- if val.lower() == "url":
120
- continue
121
- if val:
122
- urls.append(val)
123
- except Exception as e:
124
- print(f"[csv] failed reading URLs from {path}: {e}")
125
- return urls
126
-
127
- def _read_hosts_from_csv(path: str) -> Dict[str, str]:
128
- host_to_label: Dict[str, str] = {}
129
- try:
130
- with open(path, newline="", encoding="utf-8") as f:
131
- reader = csv.DictReader(f)
132
- fields = [x.lower() for x in (reader.fieldnames or [])]
133
- if "host" in fields and "label" in fields:
134
- for row in reader:
135
- host = str(row.get("host", "")).strip().lower()
136
- label = str(row.get("label", "")).strip().upper()
137
- if host and label in ("PHISH", "LEGIT"):
138
- host_to_label[host] = label
139
- else:
140
- f.seek(0)
141
- f2 = csv.reader(f)
142
- for row in f2:
143
- if len(row) < 2:
144
- continue
145
- host = str(row[0]).strip().lower()
146
- label = str(row[1]).strip().upper()
147
- if host.lower() == "host" and label == "LABEL":
148
- continue
149
- if host and label in ("PHISH", "LEGIT"):
150
- host_to_label[host] = label
151
- except Exception as e:
152
- print(f"[csv] failed reading hosts from {path}: {e}")
153
- return host_to_label
154
-
155
- def _load_csv_configs_if_any():
156
- base_dir = os.path.dirname(__file__)
157
- phishy_csv = os.environ.get("AUTOCALIB_PHISHY_CSV", os.path.join(base_dir, "autocalib_phishy.csv"))
158
- legit_csv = os.environ.get("AUTOCALIB_LEGIT_CSV", os.path.join(base_dir, "autocalib_legit.csv"))
159
- hosts_csv = os.environ.get("KNOWN_HOSTS_CSV", os.path.join(base_dir, "known_hosts.csv"))
160
-
161
- if os.path.exists(phishy_csv):
162
- urls = _read_urls_from_csv(phishy_csv)
163
- if urls:
164
- print(f"[csv] loaded phishy URLs: {len(urls)} from {phishy_csv}")
165
- _AUTOCALIB_PHISHY_URLS[:] = urls
166
- if os.path.exists(legit_csv):
167
- urls = _read_urls_from_csv(legit_csv)
168
- if urls:
169
- print(f"[csv] loaded legit URLs: {len(urls)} from {legit_csv}")
170
- _AUTOCALIB_LEGIT_URLS[:] = urls
171
- if os.path.exists(hosts_csv):
172
- mapping = _read_hosts_from_csv(hosts_csv)
173
- if mapping:
174
- print(f"[csv] loaded known hosts: {len(mapping)} from {hosts_csv}")
175
- _KNOWN_LEGIT_HOSTS.clear()
176
- _KNOWN_PHISH_HOSTS.clear()
177
- for host, label in mapping.items():
178
- if label == "LEGIT":
179
- _KNOWN_LEGIT_HOSTS.append(host)
180
- elif label == "PHISH":
181
- _KNOWN_PHISH_HOSTS.append(host)
182
-
183
- # -------------------------
184
- # URL features (must match training)
185
- # -------------------------
186
- _SUSPICIOUS_TOKENS = ["login", "verify", "secure", "update", "bank", "pay", "account", "webscr"]
187
- _ipv4_pattern = re.compile(r"(?:\d{1,3}\.){3}\d{1,3}")
188
-
189
- def _engineer_features(df: pd.DataFrame, url_col: str, feature_cols: Optional[List[str]] = None) -> pd.DataFrame:
190
- s = df[url_col].astype(str).fillna("")
191
- out = pd.DataFrame(index=df.index)
192
- out["url_len"] = s.str.len()
193
- out["count_dot"] = s.str.count(r"\.")
194
- out["count_hyphen"] = s.str.count("-")
195
- out["count_digit"] = s.str.count(r"\d")
196
- out["count_at"] = s.str.count("@")
197
- out["count_qmark"] = s.str.count(r"\?")
198
- out["count_eq"] = s.str.count("=")
199
- out["count_slash"] = s.str.count("/")
200
- out["digit_ratio"] = (out["count_digit"] / out["url_len"].replace(0, np.nan)).fillna(0)
201
- out["has_ip"] = s.str.contains(_ipv4_pattern).fillna(False).astype(int)
202
- for tok in _SUSPICIOUS_TOKENS:
203
- out[f"has_{tok}"] = s.str.contains(tok, case=False, regex=False).fillna(False).astype(int)
204
- out["starts_https"] = s.str.startswith("https").astype(int)
205
- out["ends_with_exe"] = s.str.endswith(".exe").astype(int)
206
- out["ends_with_zip"] = s.str.endswith(".zip").astype(int)
207
- return out if not feature_cols else out[feature_cols]
208
-
209
- # -------------------------
210
- # Loaders
211
- # -------------------------
212
- def _load_model():
213
- global _tokenizer, _model, _id2label, _label2id, _text_phish_id
214
- if _tokenizer is None or _model is None:
215
- with _model_lock:
216
- if _tokenizer is None or _model is None:
217
- _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
218
- _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
219
- cfg = getattr(_model, "config", None)
220
- if cfg is not None and getattr(cfg, "id2label", None):
221
- _id2label = {int(k): v for k, v in cfg.id2label.items()}
222
- _label2id = {v: int(k) for k, v in _id2label.items()}
223
- # Try to detect which index corresponds to PHISH/SPAM
224
- _text_phish_id = _detect_text_phish_id(_id2label)
225
- with torch.no_grad():
226
- _ = _model(**_tokenizer(["warm up"], return_tensors="pt")).logits
227
-
228
- # Detect which label id corresponds to phishing for text models based on label strings
229
- _text_phish_id: Optional[int] = None
230
-
231
- def _detect_text_phish_id(id2label: Dict[int, str]) -> Optional[int]:
232
- candidates_phish = ("PHISH", "SPAM", "MALICIOUS", "POSITIVE")
233
- # Prefer explicit PHISH/SPAM over generic POSITIVE
234
- priority_order = ("PHISH", "SPAM", "MALICIOUS", "POSITIVE")
235
- norm = {k: str(v).strip().upper() for k, v in id2label.items()}
236
- # exact/substring match in priority order
237
- for token in priority_order:
238
- for k, v in norm.items():
239
- if token in v:
240
- return int(k)
241
- return None
242
-
243
- def _load_url_model():
244
- global _url_bundle
245
- if _url_bundle is None:
246
- with _url_lock:
247
- if _url_bundle is None:
248
- local_path = os.path.join(os.getcwd(), URL_FILENAME)
249
- if os.path.exists(local_path):
250
- _url_bundle = joblib.load(local_path)
251
- else:
252
- model_path = hf_hub_download(
253
- repo_id=URL_REPO,
254
- filename=URL_FILENAME,
255
- repo_type=URL_REPO_TYPE,
256
- cache_dir=CACHE_DIR,
257
- )
258
- _url_bundle = joblib.load(model_path)
259
-
260
- def _xgb_predict_class1_prob(booster, feats: pd.DataFrame) -> float:
261
- # predicts P(class==1) under binary:logistic objective
262
- dmat = xgb.DMatrix(feats)
263
- return float(booster.predict(dmat)[0])
264
-
265
- def _auto_calibrate_phish_positive(bundle: Dict[str, Any], feature_cols: List[str], url_col: str) -> bool:
266
- """
267
- Heuristic: probe with 'obviously phishy' and 'obviously legit' URLs.
268
- If mean P(class1) for phishy < legit, then class1 ≈ LEGIT → return False.
269
- Otherwise, class1 ≈ PHISH → return True.
270
- """
271
- # If user forces it via env, honor that first.
272
- if URL_POSITIVE_CLASS_ENV in ("PHISH", "LEGIT"):
273
- return URL_POSITIVE_CLASS_ENV == "PHISH"
274
-
275
- # If bundle has explicit flag, use it.
276
- if "phish_is_positive" in bundle:
277
- return bool(bundle["phish_is_positive"])
278
-
279
- phishy = _AUTOCALIB_PHISHY_URLS
280
- legit = _AUTOCALIB_LEGIT_URLS
281
- # Safe fallback if CSVs are missing/empty
282
- if not phishy:
283
- phishy = [
284
- "http://198.51.100.23/login/update?acc=123",
285
- "http://secure-login-account-update.example.com/session?id=123",
286
- "http://bank.verify-update-security.com/confirm",
287
- "http://paypal.com.account-verify.cn/login",
288
- "http://abc.xyz/downloads/invoice.exe",
289
- ]
290
- if not legit:
291
- legit = [
292
- "https://www.wikipedia.org/",
293
- "https://www.microsoft.com/",
294
- "https://www.python.org/",
295
- "https://www.openai.com/",
296
- ]
297
-
298
- model = bundle.get("model")
299
- model_type: str = str(bundle.get("model_type") or "")
300
-
301
- def _batch_mean(urls: List[str]) -> float:
302
- df = pd.DataFrame({url_col: urls})
303
- feats = _engineer_features(df, url_col, feature_cols)
304
- # XGBoost booster path
305
- if model_type == "xgboost_bst" and xgb is not None:
306
- try:
307
- # Predict row-by-row to be conservative about input formats
308
- return float(np.mean([_xgb_predict_class1_prob(model, pd.DataFrame([feats.iloc[i]])) for i in range(len(feats))]))
309
- except Exception:
310
- pass
311
- # scikit-learn-like path with predict_proba
312
- if hasattr(model, "predict_proba"):
313
- proba = model.predict_proba(feats)
314
- classes = bundle.get("classes", getattr(model, "classes_", None))
315
- class1_idx = 1
316
- if classes is not None:
317
- try:
318
- classes_list = list(classes)
319
- if 1 in classes_list:
320
- class1_idx = classes_list.index(1)
321
- else:
322
- class1_idx = 1 if len(classes_list) > 1 else 0
323
- except Exception:
324
- class1_idx = 1 if proba.shape[1] > 1 else 0
325
- return float(np.mean(proba[:, class1_idx]))
326
- # Fallback: use hard predictions and treat label==1 as prob 1
327
- try:
328
- preds = model.predict(feats)
329
- vals: List[float] = []
330
- for p in preds:
331
- if isinstance(p, (int, float, np.integer, np.floating)):
332
- vals.append(1.0 if int(p) == 1 else 0.0)
333
- else:
334
- up = str(p).strip().upper()
335
- vals.append(1.0 if up.startswith("PHISH") or up == "1" else 0.0)
336
- return float(np.mean(vals)) if vals else 0.0
337
- except Exception:
338
- return 0.0
339
-
340
- try:
341
- phishy_mean = _batch_mean(phishy)
342
- legit_mean = _batch_mean(legit)
343
- except Exception as e:
344
- # If anything goes wrong, default to class1=PHISH to mimic common convention
345
- print(f"[autocalib] failed: {e}")
346
- return True
347
-
348
- # If phishy scores LOWER than legit for class1, then class1 is likely LEGIT
349
- class1_is_phish = phishy_mean > legit_mean
350
- print(f"[autocalib] phishy_mean={phishy_mean:.6f} legit_mean={legit_mean:.6f} -> class1_is_phish={class1_is_phish}")
351
- return class1_is_phish
352
-
353
- # Optional: pre-load on startup
354
- @app.on_event("startup")
355
- def _startup():
356
- try:
357
- _load_model()
358
- except Exception as e:
359
- print(f"[startup] text model load failed: {e}")
360
- try:
361
- _load_url_model()
362
- # Load CSV-driven config if present
363
- _load_csv_configs_if_any()
364
- global _url_phish_is_positive
365
- b = _url_bundle
366
- if isinstance(b, dict) and _url_phish_is_positive is None:
367
- try:
368
- feature_cols: List[str] = b.get("feature_cols") or []
369
- url_col: str = b.get("url_col") or "url"
370
- _url_phish_is_positive = _auto_calibrate_phish_positive(b, feature_cols, url_col)
371
- except Exception as ce:
372
- print(f"[startup] url model calibration failed: {ce}")
373
- except Exception as e:
374
- print(f"[startup] url model load failed: {e}")
375
-
376
- # -------------------------
377
- # Routes
378
- # -------------------------
379
- @app.get("/")
380
- def root():
381
- return {"status": "ok", "model": MODEL_ID}
382
-
383
- @app.post("/predict")
384
- def predict(payload: PredictPayload):
385
- try:
386
- _load_model()
387
- text = (payload.inputs or "").strip()
388
- if not text:
389
- return JSONResponse(status_code=400, content={"error": "Empty input"})
390
- with torch.no_grad():
391
- inputs = _tokenizer([text], return_tensors="pt", truncation=True, max_length=512)
392
- logits = _model(**inputs).logits
393
- probs = torch.softmax(logits, dim=-1)[0]
394
- score, idx = torch.max(probs, dim=0)
395
-
396
- # Normalize label to PHISH/LEGIT if we could detect PHISH id
397
- if _text_phish_id is not None and 0 <= _text_phish_id < probs.shape[0]:
398
- phish_prob = float(probs[_text_phish_id])
399
- norm_label = "PHISH" if phish_prob >= 0.5 else "LEGIT"
400
- norm_score = phish_prob if norm_label == "PHISH" else (1.0 - phish_prob)
401
- return {"label": norm_label, "score": float(norm_score), "raw_index": int(idx)}
402
- else:
403
- # Fallback to model's provided labels
404
- label = _id2label.get(int(idx), str(int(idx)))
405
- return {"label": label, "score": float(score), "raw_index": int(idx)}
406
- except Exception as e:
407
- return JSONResponse(status_code=500, content={"error": str(e)})
408
-
409
- @app.post("/predict-url")
410
- def predict_url(payload: PredictUrlPayload):
411
- try:
412
- _load_url_model()
413
- # Load CSV-based config if present (hot-reload safe)
414
- _load_csv_configs_if_any()
415
- bundle = _url_bundle
416
- if not isinstance(bundle, dict) or "model" not in bundle:
417
- raise RuntimeError("Loaded URL artifact is not a bundle dict with 'model'.")
418
-
419
- model = bundle["model"]
420
- feature_cols: List[str] = bundle.get("feature_cols") or []
421
- url_col: str = bundle.get("url_col") or "url"
422
- model_type: str = bundle.get("model_type") or ""
423
-
424
- url_str = (payload.url or "").strip()
425
- if not url_str:
426
- return JSONResponse(status_code=400, content={"error": "Empty url"})
427
-
428
- row = pd.DataFrame({url_col: [url_str]})
429
- feats = _engineer_features(row, url_col, feature_cols)
430
-
431
- # ----- compute P(PHISH) -----
432
- phish_proba: float = 0.0
433
- meta_phish_is_positive: Optional[bool] = bundle.get("phish_is_positive", None)
434
-
435
- # Resolve polarity precedence: ENV > bundle flag > auto-calibration > default True
436
- if URL_POSITIVE_CLASS_ENV in ("PHISH", "LEGIT"):
437
- phish_is_positive = (URL_POSITIVE_CLASS_ENV == "PHISH")
438
- elif meta_phish_is_positive is not None:
439
- phish_is_positive = bool(meta_phish_is_positive)
440
- else:
441
- global _url_phish_is_positive
442
- if _url_phish_is_positive is None:
443
- try:
444
- _url_phish_is_positive = _auto_calibrate_phish_positive(bundle, feature_cols, url_col)
445
- except Exception as ce:
446
- print(f"[predict-url] auto-calibration failed: {ce}")
447
- phish_is_positive = _url_phish_is_positive if _url_phish_is_positive is not None else True
448
-
449
- backend_debug = {
450
- "phish_is_positive_resolved": phish_is_positive,
451
- "phish_is_positive_bundle": meta_phish_is_positive,
452
- "phish_is_positive_env": URL_POSITIVE_CLASS_ENV if URL_POSITIVE_CLASS_ENV else None,
453
- }
454
-
455
- # Known-domain override after polarity is resolved
456
- host = (urlparse(url_str).hostname or "").lower()
457
- if host:
458
- override_label: Optional[str] = None
459
- if _host_matches_any(host, _KNOWN_LEGIT_HOSTS):
460
- override_label = "LEGIT"
461
- elif _host_matches_any(host, _KNOWN_PHISH_HOSTS):
462
- override_label = "PHISH"
463
- if override_label is not None:
464
- # Map numeric label according to resolved polarity
465
- predicted_label_numeric = 1 if ((override_label == "PHISH") == bool(phish_is_positive)) else 0
466
- phish_proba_override = 0.99 if override_label == "PHISH" else 0.01
467
- score_override = phish_proba_override if override_label == "PHISH" else (1.0 - phish_proba_override)
468
- return {
469
- "label": override_label,
470
- "predicted_label": int(predicted_label_numeric),
471
- "score": float(score_override),
472
- "phishing_probability": float(phish_proba_override),
473
- "backend": str(model_type),
474
- "threshold": 0.5,
475
- "override": {
476
- "reason": "known_host",
477
- "host": host,
478
- },
479
- "phish_is_positive": bool(phish_is_positive),
480
- "phish_is_positive_bundle": meta_phish_is_positive,
481
- "phish_is_positive_env": URL_POSITIVE_CLASS_ENV if URL_POSITIVE_CLASS_ENV else None,
482
- "feature_cols": feature_cols,
483
- "url_col": url_col,
484
- }
485
-
486
- raw_p_class1_debug: Optional[float] = None
487
-
488
- if isinstance(model_type, str) and model_type == "xgboost_bst":
489
- if xgb is None:
490
- raise RuntimeError("xgboost is not installed but required for this model bundle.")
491
- dmat = xgb.DMatrix(feats)
492
- raw_p_class1 = float(model.predict(dmat)[0]) # P(class == 1)
493
- raw_p_class1_debug = raw_p_class1
494
- phish_proba = raw_p_class1 if phish_is_positive else (1.0 - raw_p_class1)
495
-
496
- elif hasattr(model, "predict_proba"):
497
- proba = model.predict_proba(feats)[0]
498
- classes = bundle.get("classes", getattr(model, "classes_", None))
499
- label_map = bundle.get("label_map")
500
- if classes is not None and len(proba) == 2:
501
- classes_list = list(classes)
502
- phish_idx = None
503
- if isinstance(label_map, dict):
504
- for i, c in enumerate(classes_list):
505
- mapped = str(label_map.get(int(c), "")).upper()
506
- if mapped.startswith("PHISH"):
507
- phish_idx = i
508
- break
509
- if phish_idx is None:
510
- # fall back to whichever index matches current polarity
511
- # if phish_is_positive → column for class 1, else column for class 0
512
- target_class = 1 if phish_is_positive else 0
513
- if target_class in classes_list:
514
- phish_idx = classes_list.index(target_class)
515
- else:
516
- phish_idx = 1 if phish_is_positive else 0
517
- phish_proba = float(proba[phish_idx])
518
- else:
519
- phish_proba = float(proba[1]) if len(proba) > 1 else float(np.max(proba))
520
-
521
- else:
522
- pred = model.predict(feats)[0]
523
- if isinstance(pred, (int, float, np.integer, np.floating)):
524
- label_numeric = int(pred)
525
- # interpret through polarity
526
- if label_numeric in (0, 1):
527
- phish_proba = 1.0 if ((label_numeric == 1) == phish_is_positive) else 0.0
528
- else:
529
- phish_proba = float(label_numeric) # best-effort
530
- else:
531
- up = str(pred).strip().upper()
532
- phish_proba = 1.0 if up.startswith("PHISH") else 0.0
533
-
534
- phish_proba = float(phish_proba)
535
- label = "PHISH" if phish_proba >= 0.5 else "LEGIT"
536
- score = phish_proba if label == "PHISH" else (1.0 - phish_proba)
537
- # Map to numeric dataset-style label using resolved polarity
538
- # If PHISH is the positive (class 1), PHISH -> 1 else 0; if not, invert
539
- predicted_label_numeric = 1 if ((label == "PHISH") == bool(phish_is_positive)) else 0
540
-
541
- return {
542
- "label": label,
543
- "predicted_label": int(predicted_label_numeric),
544
- "score": float(score),
545
- "phishing_probability": float(phish_proba),
546
- "backend": str(model_type),
547
- "threshold": 0.5,
548
- # Debug/trace so you can see exactly what was used
549
- "phish_is_positive": bool(phish_is_positive),
550
- "phish_is_positive_bundle": meta_phish_is_positive,
551
- "phish_is_positive_env": URL_POSITIVE_CLASS_ENV if URL_POSITIVE_CLASS_ENV else None,
552
- "raw_proba_class1": float(raw_p_class1_debug) if raw_p_class1_debug is not None else None,
553
- "feature_cols": feature_cols,
554
- "url_col": url_col,
555
- }
556
-
557
- except Exception as e:
558
- return JSONResponse(status_code=500, content={"error": str(e)})
 
1
import os

# Environment defaults suitable for HF Spaces.
# IMPORTANT: these must be set BEFORE importing huggingface_hub — it resolves
# HF_HOME / cache directories at import time, so setting them afterwards
# (as the previous ordering did) has no effect on its cache location.
os.environ.setdefault("HOME", "/data")
os.environ.setdefault("XDG_CACHE_HOME", "/data/.cache")
os.environ.setdefault("HF_HOME", "/data/.cache")
# Kept for transitive consumers even though transformers/torch are no longer
# imported directly by this module.
os.environ.setdefault("TRANSFORMERS_CACHE", "/data/.cache")
os.environ.setdefault("TORCH_HOME", "/data/.cache")

import csv
import re
import threading
from typing import Optional, List, Dict, Any
from urllib.parse import urlparse

import joblib
import numpy as np
import pandas as pd
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from huggingface_hub import hf_hub_download
from pydantic import BaseModel

# xgboost is optional: only needed when the bundle's model_type is "xgboost_bst".
try:
    import xgboost as xgb  # type: ignore
except Exception:
    xgb = None


# Config (HF_* names take precedence over the legacy names)
URL_REPO = os.environ.get(
    "HF_URL_MODEL_ID",
    os.environ.get("URL_REPO", "Perth0603/Random-Forest-Model-for-PhishingDetection"),
)
URL_REPO_TYPE = os.environ.get("HF_URL_REPO_TYPE", os.environ.get("URL_REPO_TYPE", "model"))
URL_FILENAME = os.environ.get("HF_URL_FILENAME", os.environ.get("URL_FILENAME", "rf_url_phishing_xgboost_bst.joblib"))
CACHE_DIR = os.environ.get("HF_CACHE_DIR", "/data/.cache")
os.makedirs(CACHE_DIR, exist_ok=True)

# Polarity override: "PHISH" or "LEGIT"; empty means default (class 1 = PHISH)
URL_POSITIVE_CLASS_ENV = os.environ.get("URL_POSITIVE_CLASS", "").strip().upper()

# CSV configuration (defaults to files in same directory)
BASE_DIR = os.path.dirname(__file__)
AUTOCALIB_PHISHY_CSV = os.environ.get("AUTOCALIB_PHISHY_CSV", os.path.join(BASE_DIR, "autocalib_phishy.csv"))
AUTOCALIB_LEGIT_CSV = os.environ.get("AUTOCALIB_LEGIT_CSV", os.path.join(BASE_DIR, "autocalib_legit.csv"))
KNOWN_HOSTS_CSV = os.environ.get("KNOWN_HOSTS_CSV", os.path.join(BASE_DIR, "known_hosts.csv"))


app = FastAPI(title="PhishWatch URL API", version="2.0.0")
51
+
52
+
53
# Request body for POST /predict-url: `url` is the raw URL string to classify.
class PredictUrlPayload(BaseModel):
    url: str
55
+
56
+
57
# Lazily-loaded model artifact (expected: dict with "model", "feature_cols",
# "url_col", "model_type" — TODO confirm against the training-side bundle).
# _url_lock guards the one-time initialization in _load_url_model.
_url_bundle: Optional[Dict[str, Any]] = None
_url_lock = threading.Lock()
59
+
60
+
61
+ def _normalize_host(value: str) -> str:
62
+ v = value.strip().lower()
63
+ if v.startswith("www."):
64
+ v = v[4:]
65
+ return v
66
+
67
+
68
+ def _host_matches_any(host: str, known: List[str]) -> bool:
69
+ base = _normalize_host(host)
70
+ for item in known:
71
+ k = _normalize_host(item)
72
+ if base == k or base.endswith("." + k):
73
+ return True
74
+ return False
75
+
76
+
77
+ def _read_urls_from_csv(path: str) -> List[str]:
78
+ urls: List[str] = []
79
+ try:
80
+ with open(path, newline="", encoding="utf-8") as f:
81
+ reader = csv.DictReader(f)
82
+ if "url" in (reader.fieldnames or []):
83
+ for row in reader:
84
+ val = str(row.get("url", "")).strip()
85
+ if val:
86
+ urls.append(val)
87
+ else:
88
+ f.seek(0)
89
+ f2 = csv.reader(f)
90
+ for row in f2:
91
+ if not row:
92
+ continue
93
+ val = str(row[0]).strip()
94
+ if val.lower() == "url":
95
+ continue
96
+ if val:
97
+ urls.append(val)
98
+ except FileNotFoundError:
99
+ pass
100
+ except Exception as e:
101
+ print(f"[csv] failed reading URLs from {path}: {e}")
102
+ return urls
103
+
104
+
105
+ def _read_hosts_from_csv(path: str) -> Dict[str, str]:
106
+ out: Dict[str, str] = {}
107
+ try:
108
+ with open(path, newline="", encoding="utf-8") as f:
109
+ reader = csv.DictReader(f)
110
+ fields = [x.lower() for x in (reader.fieldnames or [])]
111
+ if "host" in fields and "label" in fields:
112
+ for row in reader:
113
+ host = str(row.get("host", "")).strip()
114
+ label = str(row.get("label", "")).strip().upper()
115
+ if host and label in ("PHISH", "LEGIT"):
116
+ out[host] = label
117
+ except FileNotFoundError:
118
+ pass
119
+ except Exception as e:
120
+ print(f"[csv] failed reading hosts from {path}: {e}")
121
+ return out
122
+
123
+
124
+ def _engineer_features(urls: List[str], feature_cols: List[str]) -> pd.DataFrame:
125
+ s = pd.Series(urls, dtype=str)
126
+ out = pd.DataFrame()
127
+ out["url_len"] = s.str.len().fillna(0)
128
+ out["count_dot"] = s.str.count(r"\.")
129
+ out["count_hyphen"] = s.str.count("-")
130
+ out["count_digit"] = s.str.count(r"\d")
131
+ out["count_at"] = s.str.count("@")
132
+ out["count_qmark"] = s.str.count(r"\?")
133
+ out["count_eq"] = s.str.count("=")
134
+ out["count_slash"] = s.str.count("/")
135
+ out["digit_ratio"] = (out["count_digit"] / out["url_len"].replace(0, np.nan)).fillna(0)
136
+ out["has_ip"] = s.str.contains(r"(?:\d{1,3}\.){3}\d{1,3}").astype(int)
137
+ for tok in ["login", "verify", "secure", "update", "bank", "pay", "account", "webscr"]:
138
+ out[f"has_{tok}"] = s.str.contains(tok, case=False, regex=False).astype(int)
139
+ out["starts_https"] = s.str.startswith("https").astype(int)
140
+ out["ends_with_exe"] = s.str.endswith(".exe").astype(int)
141
+ out["ends_with_zip"] = s.str.endswith(".zip").astype(int)
142
+ return out[feature_cols]
143
+
144
+
145
def _load_url_model():
    """Load the URL model bundle exactly once (thread-safe, double-checked).

    A local copy of URL_FILENAME in the working directory takes precedence;
    otherwise the artifact is downloaded from the configured Hugging Face
    repo into CACHE_DIR and deserialized with joblib.
    """
    global _url_bundle
    if _url_bundle is not None:
        return
    with _url_lock:
        # Re-check under the lock: another thread may have loaded it already.
        if _url_bundle is not None:
            return
        artifact_path = os.path.join(os.getcwd(), URL_FILENAME)
        if not os.path.exists(artifact_path):
            artifact_path = hf_hub_download(
                repo_id=URL_REPO,
                filename=URL_FILENAME,
                repo_type=URL_REPO_TYPE,
                cache_dir=CACHE_DIR,
            )
        _url_bundle = joblib.load(artifact_path)
161
+
162
+
163
@app.get("/")
def root():
    # Health-check endpoint: confirms the service is up and running in
    # URL-only mode (the text-classification endpoint was removed).
    return {"status": "ok", "backend": "url-only"}
166
+
167
+
168
@app.post("/predict-url")
def predict_url(payload: PredictUrlPayload):
    """Classify a URL as PHISH or LEGIT.

    Evaluation order:
      1. Known-host CSV override (suffix match on the URL's hostname),
         re-read on every request so file edits apply without a restart.
      2. Model probability of class 1, interpreted through the polarity
         setting (URL_POSITIVE_CLASS env; default: class 1 == PHISH).

    Returns a JSON dict with label, numeric label, score and probability.
    Any failure is reported as a 500 JSON payload with the error message.
    """
    try:
        _load_url_model()

        # Only the known-hosts CSV is consulted here. The autocalibration
        # CSVs were previously read on every request but never used after
        # the autocalibration logic was removed — dead per-request I/O.
        host_map = _read_hosts_from_csv(KNOWN_HOSTS_CSV)

        bundle = _url_bundle
        if not isinstance(bundle, dict) or "model" not in bundle:
            raise RuntimeError("Loaded URL artifact is not a bundle dict with 'model'.")

        model = bundle["model"]
        feature_cols: List[str] = bundle.get("feature_cols") or []
        url_col: str = bundle.get("url_col") or "url"
        model_type: str = bundle.get("model_type") or ""

        url_str = (payload.url or "").strip()
        if not url_str:
            return JSONResponse(status_code=400, content={"error": "Empty url"})

        # Polarity: env override or default (class 1 == PHISH). Resolved once;
        # previously this identical expression was duplicated in both branches.
        phish_is_positive = True if URL_POSITIVE_CLASS_ENV == "" else (URL_POSITIVE_CLASS_ENV == "PHISH")

        # Known-host override (suffix match handles subdomains)
        host = (urlparse(url_str).hostname or "").lower()
        if host and host_map:
            for known_host, known_label in host_map.items():
                if _host_matches_any(host, [known_host]):
                    label = known_label
                    predicted_label = 1 if ((label == "PHISH") == phish_is_positive) else 0
                    phish_proba = 0.99 if label == "PHISH" else 0.01
                    score = phish_proba if label == "PHISH" else (1.0 - phish_proba)
                    return {
                        "label": label,
                        "predicted_label": int(predicted_label),
                        "score": float(score),
                        "phishing_probability": float(phish_proba),
                        "backend": str(model_type),
                        "threshold": 0.5,
                        "url_col": url_col,
                    }

        # Mirror inference.py exactly for probability of class 1
        feats = _engineer_features([url_str], feature_cols)
        if model_type == "xgboost_bst":
            if xgb is None:
                raise RuntimeError("xgboost not installed")
            raw_p_class1 = float(model.predict(xgb.DMatrix(feats))[0])
        elif hasattr(model, "predict_proba"):
            proba = model.predict_proba(feats)[0]
            # Guard single-class models that expose only one probability column.
            raw_p_class1 = float(proba[1]) if len(proba) > 1 else float(proba[0])
        else:
            # Hard-label fallback; tolerate string labels ("PHISH", "1"),
            # which previously raised ValueError in int() and surfaced as 500s.
            pred = model.predict(feats)[0]
            if isinstance(pred, str):
                up = pred.strip().upper()
                raw_p_class1 = 1.0 if up.startswith("PHISH") or up == "1" else 0.0
            else:
                raw_p_class1 = 1.0 if int(pred) == 1 else 0.0

        phish_proba = raw_p_class1 if phish_is_positive else (1.0 - raw_p_class1)
        label = "PHISH" if phish_proba >= 0.5 else "LEGIT"
        predicted_label = 1 if ((label == "PHISH") == phish_is_positive) else 0
        score = phish_proba if label == "PHISH" else (1.0 - phish_proba)

        return {
            "label": label,
            "predicted_label": int(predicted_label),
            "score": float(score),
            "phishing_probability": float(phish_proba),
            "backend": str(model_type),
            "threshold": 0.5,
            "url_col": url_col,
        }
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
243
+
244
+