Perth0603 commited on
Commit
2b92082
·
verified ·
1 Parent(s): 9df8bf6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +190 -132
app.py CHANGED
@@ -1,164 +1,203 @@
1
  import os
2
- from typing import List, Optional, Dict, Tuple
3
 
4
  import torch
5
  from fastapi import FastAPI, HTTPException
6
  from pydantic import BaseModel
7
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
 
9
- # ====== 模型来源 ======
10
- # 默认从本地目录加载(你上传的文件在 /mnt/data)
11
- MODEL_DIR = os.environ.get("MODEL_DIR", "Perth0603/phishing-email-mobilebert")
 
 
 
12
 
13
- # 可选:当模型没写清标签且为二分类时,强制指定顺序(这里通常不需要)
14
- # 例:FORCE_BINARY_MAPPING="LEGIT,PHISH" "PHISH,LEGIT"
15
- FORCE_BINARY_MAPPING = os.environ.get("FORCE_BINARY_MAPPING", "").strip().upper()
 
 
 
 
 
 
 
 
 
16
 
17
- app = FastAPI(title="Phishing Text Classifier", version="1.3.1")
18
 
19
- # ====== Schemas ======
20
  class PredictPayload(BaseModel):
21
  inputs: str
22
 
 
23
  class BatchPredictPayload(BaseModel):
24
  inputs: List[str]
25
 
 
26
  class LabeledText(BaseModel):
27
  text: str
28
- label: Optional[str] = None # "0"/"1" 或文本
 
29
 
30
  class EvalPayload(BaseModel):
31
  samples: List[LabeledText]
32
 
33
- # ====== Globals ======
34
  _tokenizer = None
35
  _model = None
36
  _device = "cpu"
37
 
38
- _IDX_PHISH: Optional[int] = None
39
- _IDX_LEGIT: Optional[int] = None
40
- _NORM_LABELS_BY_IDX: Optional[List[str]] = None
41
- _USED_FORCED_MAPPING: bool = False
42
 
43
- # ====== Helpers ======
44
- def _normalize_label(txt: str) -> str:
 
 
 
 
45
  t = (str(txt) if txt is not None else "").strip().upper()
46
- if t in ("1", "PHISHING", "PHISH", "SPAM"):
47
  return "PHISH"
48
- if t in ("0", "LEGIT", "LEGITIMATE", "SAFE", "HAM"):
49
  return "LEGIT"
50
  return t
51
 
52
- def _try_force_binary_mapping(num_labels: int) -> Tuple[Optional[int], Optional[int], bool]:
53
- if num_labels != 2 or not FORCE_BINARY_MAPPING:
54
- return None, None, False
55
- parts = [p.strip() for p in FORCE_BINARY_MAPPING.split(",") if p.strip()]
56
- if len(parts) != 2 or any(p not in ("PHISH", "LEGIT") for p in parts):
57
- return None, None, False
58
- idx_legit = 0 if parts[0] == "LEGIT" else 1 if parts[1] == "LEGIT" else None
59
- idx_phish = 0 if parts[0] == "PHISH" else 1 if parts[1] == "PHISH" else None
60
- if idx_legit is None or idx_phish is None:
61
- return None, None, False
62
- return idx_phish, idx_legit, True
63
 
64
- def _load_model():
65
- global _tokenizer, _model, _device, _IDX_PHISH, _IDX_LEGIT, _NORM_LABELS_BY_IDX, _USED_FORCED_MAPPING
66
- if _tokenizer is not None and _model is not None:
67
- return
68
-
69
- _device = "cuda" if torch.cuda.is_available() else "cpu"
70
- _tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
71
- _model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
72
- _model.to(_device)
73
- _model.eval()
74
-
75
- with torch.no_grad():
76
- _ = _model(
77
- **_tokenizer(["warm up"], return_tensors="pt", padding=True, truncation=True, max_length=512).to(_device)
78
- ).logits
 
 
 
 
79
 
80
- id2label = getattr(_model.config, "id2label", {}) or {}
81
- num_labels = int(getattr(_model.config, "num_labels", 0) or 0)
82
- _NORM_LABELS_BY_IDX = [_normalize_label(id2label.get(i, f"LABEL_{i}")) for i in range(num_labels)]
83
 
84
- _IDX_PHISH = None
85
- _IDX_LEGIT = None
86
- try:
87
- _IDX_PHISH = _NORM_LABELS_BY_IDX.index("PHISH")
88
- except ValueError:
89
- pass
90
- try:
91
- _IDX_LEGIT = _NORM_LABELS_BY_IDX.index("LEGIT")
92
- except ValueError:
93
- pass
94
-
95
- _USED_FORCED_MAPPING = False
96
- if (_IDX_PHISH is None or _IDX_LEGIT is None) and num_labels == 2:
97
- fp, fl, used = _try_force_binary_mapping(num_labels)
98
- if used:
99
- _IDX_PHISH, _IDX_LEGIT = fp, fl
100
- _USED_FORCED_MAPPING = True
101
- # 你的模型文件已经写明:0=LEGIT, 1=PHISH,通常这里会自动识别出来。
102
-
103
- def _postprocess(texts: List[str]) -> List[Dict]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  _load_model()
105
  if not texts:
106
  return []
107
 
108
- enc = _tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
 
 
 
 
 
 
 
109
  enc = {k: v.to(_device) for k, v in enc.items()}
110
 
111
  with torch.no_grad():
112
  logits = _model(**enc).logits
113
- probs = torch.softmax(logits, dim=-1)
114
 
115
- id2label = getattr(_model.config, "id2label", {}) or {}
116
- labels_by_idx_raw = [id2label.get(i, f"LABEL_{i}") for i in range(probs.shape[-1])]
117
- labels_by_idx_norm = [_normalize_label(x) for x in labels_by_idx_raw]
118
 
119
- outs: List[Dict] = []
120
  for i in range(probs.shape[0]):
121
  p = probs[i]
122
  idx = int(torch.argmax(p).item())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
- raw_label = labels_by_idx_raw[idx]
125
- norm_label = labels_by_idx_norm[idx]
126
- prob_map = {labels_by_idx_norm[j]: float(p[j].item()) for j in range(len(labels_by_idx_norm))}
127
-
128
- # —— 核心:用明确的下标来给出“数据集标签”和“UI标签”
129
- can_map = (_IDX_PHISH is not None and _IDX_LEGIT is not None)
130
- if can_map:
131
- phish_prob = float(p[_IDX_PHISH].item())
132
- legit_prob = float(p[_IDX_LEGIT].item())
133
- is_phish = phish_prob >= legit_prob
134
- dataset_label = "1" if is_phish else "0" # 1=PHISH, 0=LEGIT
135
- display_label = "phishing" if is_phish else "legitimate"
136
- probs_by_dataset = {"1": phish_prob, "0": legit_prob}
137
- else:
138
- # 回退:用规范化标签
139
- is_phish = (norm_label == "PHISH")
140
- dataset_label = "1" if is_phish else "0"
141
- display_label = "phishing" if is_phish else "legitimate"
142
- probs_by_dataset = None
143
-
144
- outs.append({
145
- "is_phish": is_phish, # 前端用它来显示
146
- "dataset_label": dataset_label, # "1"=PHISH, "0"=LEGIT
147
- "display_label": display_label, # "phishing"/"legitimate"
148
- "label": norm_label, # 规范化(兼容/排错)
149
- "raw_label": raw_label, # 原始模型标签
150
- "score": float(p[idx].item()),
151
- "probs": prob_map,
152
- "predicted_index": idx,
153
- "predicted_dataset_label": 1 if is_phish else 0,
154
- "probs_by_dataset_label": probs_by_dataset,
155
- })
156
- return outs
157
-
158
- # ====== Routes ======
159
  @app.get("/")
160
  def root():
161
- return {"status": "ok", "model_dir": MODEL_DIR}
 
 
 
 
 
 
 
 
162
 
163
  @app.get("/debug/labels")
164
  def debug_labels():
@@ -171,64 +210,83 @@ def debug_labels():
171
  "norm_labels_by_idx": _NORM_LABELS_BY_IDX,
172
  "idx_phish": _IDX_PHISH,
173
  "idx_legit": _IDX_LEGIT,
 
 
 
 
174
  }
175
 
176
- @app.get("/debug/mapping")
177
- def debug_mapping():
178
- _load_model()
179
- num_labels = int(getattr(_model.config, "num_labels", 0))
180
- return {
181
- "forced_mapping_env": FORCE_BINARY_MAPPING or None,
182
- "used_forced_mapping": _USED_FORCED_MAPPING,
183
- "num_labels": num_labels,
184
- "can_map_dataset": (_IDX_PHISH is not None and _IDX_LEGIT is not None),
185
- "idx_phish": _IDX_PHISH,
186
- "idx_legit": _IDX_LEGIT,
187
- }
188
 
189
  @app.post("/predict")
190
  def predict(payload: PredictPayload):
191
  try:
192
- return _postprocess([payload.inputs])[0]
 
193
  except Exception as e:
194
  raise HTTPException(status_code=500, detail=f"Prediction error: {e}")
195
 
 
196
  @app.post("/predict-batch")
197
  def predict_batch(payload: BatchPredictPayload):
198
  try:
199
- return _postprocess(payload.inputs)
200
  except Exception as e:
201
  raise HTTPException(status_code=500, detail=f"Batch prediction error: {e}")
202
 
 
203
  @app.post("/evaluate")
204
  def evaluate(payload: EvalPayload):
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  try:
206
  texts = [s.text for s in payload.samples]
207
- gts = [(_normalize_label(s.label) if s.label is not None else None) for s in payload.samples]
208
- preds = _postprocess(texts)
 
209
 
210
  total = len(preds)
211
  correct = 0
212
  per_class: Dict[str, Dict[str, int]] = {}
213
 
214
  for gt, pr in zip(gts, preds):
215
- pred_norm = "PHISH" if pr["is_phish"] else "LEGIT"
216
- if gt is not None:
217
- correct += int(gt == pred_norm)
218
  per_class.setdefault(gt, {"tp": 0, "count": 0})
219
  per_class[gt]["count"] += 1
220
- if gt == pred_norm:
221
  per_class[gt]["tp"] += 1
222
 
223
  has_gts = any(gt is not None for gt in gts)
224
  denom = sum(1 for gt in gts if gt is not None)
225
  acc = (correct / denom) if (has_gts and denom > 0) else None
226
 
227
- return {"accuracy": acc, "total": total, "predictions": preds, "per_class": per_class}
 
 
 
 
 
 
 
 
 
228
  except Exception as e:
229
  raise HTTPException(status_code=500, detail=f"Evaluation error: {e}")
230
 
 
231
  if __name__ == "__main__":
232
- # 启动:uvicorn app:app --host 0.0.0.0 --port 8000 --reload
233
  import uvicorn
234
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
  import os
2
+ from typing import List, Optional, Dict
3
 
4
  import torch
5
  from fastapi import FastAPI, HTTPException
6
  from pydantic import BaseModel
7
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
 
9
+ # Prefer MODEL_ID, fall back to HF_MODEL_ID, then default
10
+ MODEL_ID = (
11
+ os.environ.get("MODEL_ID")
12
+ or os.environ.get("HF_MODEL_ID")
13
+ or "Perth0603/phishing-email-mobilebert"
14
+ )
15
 
16
+ # =========================
17
+ # 数据集 0/1 映射的可配置开关
18
+ # =========================
19
+ # 如果你的 CSV 中 1=PHISH,0=LEGIT(常见约定),保持默认即可
20
+ # 如果你的 CSV 中 0=PHISH,1=LEGIT,请把 DATASET_PHISH_VALUE 设为 "0"
21
+ DATASET_PHISH_VALUE = (os.environ.get("DATASET_PHISH_VALUE") or "1").strip()
22
+ if DATASET_PHISH_VALUE not in {"0", "1"}:
23
+ DATASET_PHISH_VALUE = "1" # 容错:非法值时回退到默认
24
+
25
+ DATASET_LEGIT_VALUE = "0" if DATASET_PHISH_VALUE == "1" else "1"
26
+
27
+ app = FastAPI(title="Phishing Text Classifier", version="1.3.0")
28
 
 
29
 
 
30
  class PredictPayload(BaseModel):
31
  inputs: str
32
 
33
+
34
  class BatchPredictPayload(BaseModel):
35
  inputs: List[str]
36
 
37
+
38
  class LabeledText(BaseModel):
39
  text: str
40
+ label: Optional[str] = None # optional ground truth for quick eval (accepts "0"/"1" or text)
41
+
42
 
43
  class EvalPayload(BaseModel):
44
  samples: List[LabeledText]
45
 
46
+
47
  _tokenizer = None
48
  _model = None
49
  _device = "cpu"
50
 
51
+ # Cached normalized mapping/meta
52
+ _IDX_PHISH = None # model output index that corresponds to PHISH
53
+ _IDX_LEGIT = None # model output index that corresponds to LEGIT
54
+ _NORM_LABELS_BY_IDX = None # normalized labels ordered by model indices
55
 
56
+
57
+ def _normalize_label_text_only(txt: str) -> str:
58
+ """
59
+ 仅做文字标准化,不解读 "0"/"1"。
60
+ 用于模型 id2label -> 统一为 PHISH/LEGIT。
61
+ """
62
  t = (str(txt) if txt is not None else "").strip().upper()
63
+ if t in ("PHISHING", "PHISH", "SPAM"):
64
  return "PHISH"
65
+ if t in ("LEGIT", "LEGITIMATE", "SAFE", "HAM"):
66
  return "LEGIT"
67
  return t
68
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
+ def _normalize_label_from_dataset(txt: str) -> Optional[str]:
71
+ """
72
+ 把来自 CSV "0"/"1" 或文字标签,统一成 PHISH/LEGIT。
73
+ 这里会按 DATASET_PHISH_VALUE/LEGIT_VALUE 来解释 "0"/"1"。
74
+ 返回 None 表示无法识别(比如空)。
75
+ """
76
+ if txt is None:
77
+ return None
78
+ t = str(txt).strip().upper()
79
+ if t in ("0", "1"):
80
+ if t == DATASET_PHISH_VALUE:
81
+ return "PHISH"
82
+ else:
83
+ return "LEGIT"
84
+ # 文字也支持
85
+ t2 = _normalize_label_text_only(t)
86
+ if t2 in ("PHISH", "LEGIT"):
87
+ return t2
88
+ return None
89
 
 
 
 
90
 
91
+ def _load_model():
92
+ global _tokenizer, _model, _device, _IDX_PHISH, _IDX_LEGIT, _NORM_LABELS_BY_IDX
93
+
94
+ if _tokenizer is None or _model is None:
95
+ _device = "cuda" if torch.cuda.is_available() else "cpu"
96
+ _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
97
+ _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
98
+ _model.to(_device)
99
+ _model.eval() # important: disable dropout etc.
100
+
101
+ # Warm-up
102
+ with torch.no_grad():
103
+ _ = _model(
104
+ **_tokenizer(["warm up"], return_tensors="pt", padding=True, truncation=True, max_length=512)
105
+ .to(_device)
106
+ ).logits
107
+
108
+ # 读取并标准化模型标签(按索引顺序)
109
+ id2label = getattr(_model.config, "id2label", {}) or {}
110
+ num_labels = int(getattr(_model.config, "num_labels", 0) or 0)
111
+ _NORM_LABELS_BY_IDX = [_normalize_label_text_only(id2label.get(i, f"LABEL_{i}")) for i in range(num_labels)]
112
+
113
+ # 找出 PHISH/LEGIT 在 logits 中的索引
114
+ try:
115
+ _IDX_PHISH = _NORM_LABELS_BY_IDX.index("PHISH")
116
+ except ValueError:
117
+ _IDX_PHISH = None
118
+ try:
119
+ _IDX_LEGIT = _NORM_LABELS_BY_IDX.index("LEGIT")
120
+ except ValueError:
121
+ _IDX_LEGIT = None
122
+
123
+ # 若模型没提供可识别的标签,但只有 2 类,给出安全的保守默认(不强行假设)
124
+ # 这里不自动假设 0/1 的含义,避免再次反转;保留 None,让下游概率照常返回。
125
+ # 你也可以按需启用:
126
+ # if _IDX_PHISH is None and _IDX_LEGIT is None and num_labels == 2:
127
+ # _IDX_LEGIT, _IDX_PHISH = 0, 1
128
+
129
+
130
+ def _predict_texts(texts: List[str]) -> List[Dict]:
131
  _load_model()
132
  if not texts:
133
  return []
134
 
135
+ # Tokenize batch
136
+ enc = _tokenizer(
137
+ texts,
138
+ return_tensors="pt",
139
+ padding=True,
140
+ truncation=True,
141
+ max_length=512,
142
+ )
143
  enc = {k: v.to(_device) for k, v in enc.items()}
144
 
145
  with torch.no_grad():
146
  logits = _model(**enc).logits
147
+ probs = torch.softmax(logits, dim=-1) # [batch, num_labels]
148
 
149
+ # Use the model’s own mapping
150
+ id2label = getattr(_model.config, "id2label", None) or {}
151
+ labels_by_idx = [_normalize_label_text_only(id2label.get(i, f"LABEL_{i}")) for i in range(probs.shape[-1])]
152
 
153
+ outputs: List[Dict] = []
154
  for i in range(probs.shape[0]):
155
  p = probs[i]
156
  idx = int(torch.argmax(p).item())
157
+ norm_label = labels_by_idx[idx] # 已标准化为 PHISH/LEGIT 或原样回传
158
+
159
+ # 构建(标准化后的)各类概率映射
160
+ prob_map: Dict[str, float] = {}
161
+ for j, lbl in enumerate(labels_by_idx):
162
+ key = lbl if lbl in ("PHISH", "LEGIT") else f"CLASS_{j}"
163
+ prob_map[key] = float(p[j].item())
164
+
165
+ # ——把预测映射回你的 CSV 0/1——
166
+ # 只有在我们确实知道哪个 index 是 PHISH / LEGIT 时才赋值;否则返回 None,避免误导
167
+ ds_label: Optional[int] = None
168
+ probs_by_dataset: Optional[Dict[str, float]] = None
169
+ if _IDX_PHISH is not None and _IDX_LEGIT is not None:
170
+ ds_label = int(DATASET_PHISH_VALUE) if idx == _IDX_PHISH else int(DATASET_LEGIT_VALUE)
171
+ probs_by_dataset = {
172
+ DATASET_PHISH_VALUE: float(p[_IDX_PHISH].item()), # 数据集里代表 PHISH 的数值("0" 或 "1")
173
+ DATASET_LEGIT_VALUE: float(p[_IDX_LEGIT].item()), # 数据集里代表 LEGIT 的数值
174
+ }
175
+
176
+ outputs.append(
177
+ {
178
+ "label": norm_label if norm_label in ("PHISH", "LEGIT") else norm_label, # 文字结果
179
+ "score": float(p[idx].item()), # max class probability
180
+ "probs": prob_map, # 每类概率(键为 PHISH/LEGIT 或 CLASS_k)
181
+ "predicted_index": idx, # 模型 argmax 索引
182
+ "predicted_dataset_label": ds_label, # 用你的数据集 0/1 表示的预测(对齐到 DATASET_*_VALUE)
183
+ "probs_by_dataset_label": probs_by_dataset,
184
+ }
185
+ )
186
+
187
+ return outputs
188
+
189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  @app.get("/")
191
  def root():
192
+ return {
193
+ "status": "ok",
194
+ "model": MODEL_ID,
195
+ "dataset_mapping": {
196
+ "PHISH_VALUE": DATASET_PHISH_VALUE,
197
+ "LEGIT_VALUE": DATASET_LEGIT_VALUE,
198
+ },
199
+ }
200
+
201
 
202
  @app.get("/debug/labels")
203
  def debug_labels():
 
210
  "norm_labels_by_idx": _NORM_LABELS_BY_IDX,
211
  "idx_phish": _IDX_PHISH,
212
  "idx_legit": _IDX_LEGIT,
213
+ "dataset_mapping": {
214
+ "PHISH_VALUE": DATASET_PHISH_VALUE,
215
+ "LEGIT_VALUE": DATASET_LEGIT_VALUE,
216
+ },
217
  }
218
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
  @app.post("/predict")
221
  def predict(payload: PredictPayload):
222
  try:
223
+ res = _predict_texts([payload.inputs])
224
+ return res[0]
225
  except Exception as e:
226
  raise HTTPException(status_code=500, detail=f"Prediction error: {e}")
227
 
228
+
229
  @app.post("/predict-batch")
230
  def predict_batch(payload: BatchPredictPayload):
231
  try:
232
+ return _predict_texts(payload.inputs)
233
  except Exception as e:
234
  raise HTTPException(status_code=500, detail=f"Batch prediction error: {e}")
235
 
236
+
237
  @app.post("/evaluate")
238
  def evaluate(payload: EvalPayload):
239
+ """
240
+ Quick on-the-spot test with provided labeled samples.
241
+
242
+ Request body:
243
+ {
244
+ "samples": [
245
+ {"text": "Your parcel is held...", "label": "PHISH"}, # or "0"/"1"(按你的数据集约定)
246
+ {"text": "Lunch at 12?", "label": "LEGIT"} # or "0"/"1"
247
+ ]
248
+ }
249
+
250
+ Returns accuracy and per-class counts (labels normalized to PHISH/LEGIT).
251
+ """
252
  try:
253
  texts = [s.text for s in payload.samples]
254
+ # 这里用数据集映射把 "0"/"1" 转成人类可读的 PHISH/LEGIT
255
+ gts = [_normalize_label_from_dataset(s.label) if s.label is not None else None for s in payload.samples]
256
+ preds = _predict_texts(texts)
257
 
258
  total = len(preds)
259
  correct = 0
260
  per_class: Dict[str, Dict[str, int]] = {}
261
 
262
  for gt, pr in zip(gts, preds):
263
+ pred_label = pr["label"] if pr["label"] in ("PHISH", "LEGIT") else None
264
+ if gt is not None and pred_label is not None:
265
+ correct += int(gt == pred_label)
266
  per_class.setdefault(gt, {"tp": 0, "count": 0})
267
  per_class[gt]["count"] += 1
268
+ if gt == pred_label:
269
  per_class[gt]["tp"] += 1
270
 
271
  has_gts = any(gt is not None for gt in gts)
272
  denom = sum(1 for gt in gts if gt is not None)
273
  acc = (correct / denom) if (has_gts and denom > 0) else None
274
 
275
+ return {
276
+ "accuracy": acc, # None if no ground truths provided
277
+ "total": total,
278
+ "predictions": preds,
279
+ "per_class": per_class,
280
+ "dataset_mapping": {
281
+ "PHISH_VALUE": DATASET_PHISH_VALUE,
282
+ "LEGIT_VALUE": DATASET_LEGIT_VALUE,
283
+ },
284
+ }
285
  except Exception as e:
286
  raise HTTPException(status_code=500, detail=f"Evaluation error: {e}")
287
 
288
+
289
  if __name__ == "__main__":
290
+ # Run: uvicorn app:app --host 0.0.0.0 --port 8000 --reload
291
  import uvicorn
292
  uvicorn.run(app, host="0.0.0.0", port=8000)