Perth0603 committed on
Commit
c44f1d8
·
verified ·
1 Parent(s): 137bfbe

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -177
app.py DELETED
@@ -1,177 +0,0 @@
1
- import os
2
# Default HOME and the cache-related variables to locations under /data.
# setdefault() leaves any value already present in the environment untouched.
# These assignments sit ahead of the third-party imports that follow.
for _env_name, _env_default in (
    ("HOME", "/data"),
    ("XDG_CACHE_HOME", "/data/.cache"),
    ("HF_HOME", "/data/.cache"),
    ("TRANSFORMERS_CACHE", "/data/.cache"),
    ("TORCH_HOME", "/data/.cache"),
):
    os.environ.setdefault(_env_name, _env_default)
del _env_name, _env_default  # keep the loop variables out of the module namespace
7
-
8
- from fastapi import FastAPI
9
- from fastapi.responses import JSONResponse
10
- from pydantic import BaseModel
11
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
12
- from huggingface_hub import hf_hub_download
13
- import joblib
14
- import torch
15
- import re
16
- import numpy as np
17
- import pandas as pd
18
- try:
19
- import xgboost as xgb # type: ignore
20
- except Exception:
21
- xgb = None # optional; required if bundle uses xgboost
22
-
23
-
24
# Hub identifiers for the text classifier and the URL-model artifact.
# Each value can be overridden through the environment variable named in the call.
MODEL_ID = os.getenv("MODEL_ID", "Perth0603/phishing-email-mobilebert")
URL_REPO = os.getenv("URL_REPO", "Perth0603/Random-Forest-Model-for-PhishingDetection")
# Forwarded as repo_type when downloading the URL artifact: model|space|dataset
URL_REPO_TYPE = os.getenv("URL_REPO_TYPE", "model")
# NOTE: set to your artifact filename, e.g. rf_url_phishing_xgboost_bst.joblib
URL_FILENAME = os.getenv("URL_FILENAME", "rf_url_phishing_xgboost_bst.joblib")

# Writable cache directory for HF/torch inside Spaces Docker; created eagerly
# at import time so later downloads have somewhere to land.
CACHE_DIR = os.getenv("HF_CACHE_DIR", "/data/.cache")
os.makedirs(CACHE_DIR, exist_ok=True)

app = FastAPI(title="Phishing Text Classifier", version="1.0.0")
35
-
36
-
37
class PredictPayload(BaseModel):
    """Request body for POST /predict."""

    # Raw text handed to the tokenizer by the /predict handler.
    inputs: str
39
-
40
-
41
# Module-level caches, filled on first use by _load_model() / _load_url_model().
_tokenizer = _model = None
# Loaded URL artifact; expected to be a dict with keys such as
# {model, feature_cols, url_col, label_col, model_type}.
_url_bundle = None
45
-
46
-
47
def _load_url_model():
    """Populate the module-level ``_url_bundle`` on first call.

    A file named ``URL_FILENAME`` in the current working directory takes
    precedence (e.g. an artifact committed into the Space repo). Otherwise the
    artifact is fetched from the Hub repo ``URL_REPO`` into ``CACHE_DIR``.
    Later calls return immediately once the bundle is set.
    """
    global _url_bundle
    if _url_bundle is not None:
        return

    artifact_path = os.path.join(os.getcwd(), URL_FILENAME)
    if not os.path.exists(artifact_path):
        # No local copy: download the model artifact from HF Hub.
        artifact_path = hf_hub_download(
            repo_id=URL_REPO,
            filename=URL_FILENAME,
            repo_type=URL_REPO_TYPE,
            cache_dir=CACHE_DIR,
        )

    # NOTE(review): joblib.load unpickles the file, so the artifact source
    # must be trusted -- confirm URL_REPO is under your control.
    _url_bundle = joblib.load(artifact_path)
63
-
64
-
65
- # URL feature engineering (must match training)
66
- _SUSPICIOUS_TOKENS = ["login", "verify", "secure", "update", "bank", "pay", "account", "webscr"]
67
- _ipv4_pattern = re.compile(r'(?:\d{1,3}\.){3}\d{1,3}')
68
-
69
- def _engineer_features(df: pd.DataFrame, url_col: str, feature_cols: list[str] | None = None) -> pd.DataFrame:
70
- s = df[url_col].astype(str)
71
- out = pd.DataFrame(index=df.index)
72
- out['url_len'] = s.str.len().fillna(0)
73
- out['count_dot'] = s.str.count(r'\.')
74
- out['count_hyphen'] = s.str.count('-')
75
- out['count_digit'] = s.str.count(r'\d')
76
- out['count_at'] = s.str.count('@')
77
- out['count_qmark'] = s.str.count('\?')
78
- out['count_eq'] = s.str.count('=')
79
- out['count_slash'] = s.str.count('/')
80
- out['digit_ratio'] = (out['count_digit'] / out['url_len'].replace(0, np.nan)).fillna(0)
81
- out['has_ip'] = s.str.contains(_ipv4_pattern).astype(int)
82
- for tok in _SUSPICIOUS_TOKENS:
83
- out[f'has_{tok}'] = s.str.contains(tok, case=False, regex=False).astype(int)
84
- out['starts_https'] = s.str.startswith('https').astype(int)
85
- out['ends_with_exe'] = s.str.endswith('.exe').astype(int)
86
- out['ends_with_zip'] = s.str.endswith('.zip').astype(int)
87
- return out if feature_cols is None else out[feature_cols]
88
-
89
-
90
def _load_model():
    """Load the tokenizer and sequence-classification model on first call.

    Both are fetched for ``MODEL_ID`` with ``CACHE_DIR`` as the cache
    directory and stored in the module-level ``_tokenizer`` / ``_model``.
    A single dummy forward pass is run right after loading. Calls made once
    both globals are set do nothing.
    """
    global _tokenizer, _model
    if _tokenizer is not None and _model is not None:
        return

    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
    _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)

    # Warm-up: push one tiny batch through the model and discard the logits.
    warmup_batch = _tokenizer(["warm up"], return_tensors="pt")
    with torch.no_grad():
        _ = _model(**warmup_batch).logits
98
-
99
-
100
@app.get("/")
def root():
    """Report service status together with the configured text-model id."""
    status_body = {"status": "ok"}
    status_body["model"] = MODEL_ID
    return status_body
103
-
104
-
105
@app.post("/predict")
def predict(payload: PredictPayload):
    """Classify ``payload.inputs`` with the text model.

    Returns ``{"label", "score"}`` where ``score`` is the softmax probability
    of the highest-scoring class. Any exception raised while loading the model
    or running inference is reported as a 500 JSON response carrying the
    exception text.
    """
    try:
        _load_model()
        with torch.no_grad():
            # Input is truncated to at most 512 tokens.
            encoded = _tokenizer([payload.inputs], return_tensors="pt", truncation=True, max_length=512)
            class_probs = torch.softmax(_model(**encoded).logits, dim=-1)[0]
            top_prob, top_idx = torch.max(class_probs, dim=0)
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})

    # Map common ids to labels (kept generic); unknown ids fall back to the
    # numeric id rendered as a string.
    class_id = int(top_idx)
    label_names = {0: "LEGIT", 1: "PHISH"}
    return {"label": label_names.get(class_id, str(class_id)), "score": float(top_prob)}
121
-
122
-
123
class PredictUrlPayload(BaseModel):
    """Request body for POST /predict-url."""

    # URL string that the /predict-url handler turns into features.
    url: str
125
-
126
-
127
@app.post("/predict-url")
def predict_url(payload: PredictUrlPayload):
    """Classify ``payload.url`` with the URL model bundle.

    The bundle must be a dict containing ``'model'``; ``'feature_cols'``,
    ``'url_col'`` and ``'model_type'`` are optional. Returns
    ``{"label", "score"}`` with label ``"PHISH"`` or ``"LEGIT"``. Any
    exception raised while loading, featurising or predicting is reported as a
    500 JSON response carrying the exception text.
    """
    try:
        _load_url_model()
        bundle = _url_bundle
        if not isinstance(bundle, dict) or 'model' not in bundle:
            raise RuntimeError("Loaded URL artifact is not a bundle dict with 'model'.")
        model = bundle['model']
        # A missing or empty column list must reach _engineer_features as None
        # (its "return every feature" case); passing [] selects zero columns.
        feature_cols = bundle.get('feature_cols')
        if feature_cols is None or len(feature_cols) == 0:
            feature_cols = None
        url_col = bundle.get('url_col') or 'url'
        model_type = bundle.get('model_type') or ''

        row = pd.DataFrame({url_col: [payload.url]})
        feats = _engineer_features(row, url_col, feature_cols)

        score = None
        label = None

        if isinstance(model_type, str) and model_type == 'xgboost_bst':
            # Bundle declares model_type 'xgboost_bst': predict via a DMatrix
            # and treat the single output as the PHISH score.
            if xgb is None:
                raise RuntimeError("xgboost is not installed but required for this model bundle.")
            dmat = xgb.DMatrix(feats)
            proba = float(model.predict(dmat)[0])
            score = proba
            label = "PHISH" if score >= 0.5 else "LEGIT"
        elif hasattr(model, "predict_proba"):
            proba = model.predict_proba(feats)[0]
            if len(proba) == 2:
                # Binary case: index 1 is taken as the PHISH probability.
                score = float(proba[1])
                label = "PHISH" if score >= 0.5 else "LEGIT"
            else:
                # Any other class count: report the arg-max probability;
                # only class index 1 maps to PHISH.
                max_idx = int(np.argmax(proba))
                score = float(proba[max_idx])
                label = "PHISH" if max_idx == 1 else "LEGIT"
        else:
            # Hard predictions only: synthesise a 0/1 score from the label.
            pred = model.predict(feats)[0]
            if isinstance(pred, (int, float, np.integer, np.floating)):
                label = "PHISH" if int(pred) == 1 else "LEGIT"
                score = 1.0 if label == "PHISH" else 0.0
            else:
                up = str(pred).strip().upper()
                if up in ("PHISH", "PHISHING", "MALICIOUS"):
                    label, score = "PHISH", 1.0
                else:
                    label, score = "LEGIT", 0.0
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})

    return {"label": label, "score": float(score)}
176
-
177
-