Perth0603 commited on
Commit
9b309b6
·
verified ·
1 Parent(s): 2b92082

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -110
app.py CHANGED
@@ -13,18 +13,7 @@ MODEL_ID = (
13
  or "Perth0603/phishing-email-mobilebert"
14
  )
15
 
16
- # =========================
17
- # 数据集 0/1 映射的可配置开关
18
- # =========================
19
- # 如果你的 CSV 中 1=PHISH,0=LEGIT(常见约定),保持默认即可
20
- # 如果你的 CSV 中 0=PHISH,1=LEGIT,请把 DATASET_PHISH_VALUE 设为 "0"
21
- DATASET_PHISH_VALUE = (os.environ.get("DATASET_PHISH_VALUE") or "1").strip()
22
- if DATASET_PHISH_VALUE not in {"0", "1"}:
23
- DATASET_PHISH_VALUE = "1" # 容错:非法值时回退到默认
24
-
25
- DATASET_LEGIT_VALUE = "0" if DATASET_PHISH_VALUE == "1" else "1"
26
-
27
- app = FastAPI(title="Phishing Text Classifier", version="1.3.0")
28
 
29
 
30
  class PredictPayload(BaseModel):
@@ -37,7 +26,7 @@ class BatchPredictPayload(BaseModel):
37
 
38
  class LabeledText(BaseModel):
39
  text: str
40
- label: Optional[str] = None # optional ground truth for quick eval (accepts "0"/"1" or text)
41
 
42
 
43
  class EvalPayload(BaseModel):
@@ -49,47 +38,25 @@ _model = None
49
  _device = "cpu"
50
 
51
  # Cached normalized mapping/meta
52
- _IDX_PHISH = None # model output index that corresponds to PHISH
53
- _IDX_LEGIT = None # model output index that corresponds to LEGIT
54
  _NORM_LABELS_BY_IDX = None # normalized labels ordered by model indices
55
 
56
 
57
  def _normalize_label_text_only(txt: str) -> str:
58
  """
59
- 仅做文字标准化,不解读 "0"/"1"。
60
- 用于模型 id2label -> 统一为 PHISH/LEGIT。
61
  """
62
  t = (str(txt) if txt is not None else "").strip().upper()
63
  if t in ("PHISHING", "PHISH", "SPAM"):
64
  return "PHISH"
65
  if t in ("LEGIT", "LEGITIMATE", "SAFE", "HAM"):
66
  return "LEGIT"
 
67
  return t
68
 
69
 
70
- def _normalize_label_from_dataset(txt: str) -> Optional[str]:
71
- """
72
- 把来自 CSV 的 "0"/"1" 或文字标签,统一成 PHISH/LEGIT。
73
- 这里会按 DATASET_PHISH_VALUE/LEGIT_VALUE 来解释 "0"/"1"。
74
- 返回 None 表示无法识别(比如空)。
75
- """
76
- if txt is None:
77
- return None
78
- t = str(txt).strip().upper()
79
- if t in ("0", "1"):
80
- if t == DATASET_PHISH_VALUE:
81
- return "PHISH"
82
- else:
83
- return "LEGIT"
84
- # 文字也支持
85
- t2 = _normalize_label_text_only(t)
86
- if t2 in ("PHISH", "LEGIT"):
87
- return t2
88
- return None
89
-
90
-
91
  def _load_model():
92
- global _tokenizer, _model, _device, _IDX_PHISH, _IDX_LEGIT, _NORM_LABELS_BY_IDX
93
 
94
  if _tokenizer is None or _model is None:
95
  _device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -98,36 +65,29 @@ def _load_model():
98
  _model.to(_device)
99
  _model.eval() # important: disable dropout etc.
100
 
101
- # Warm-up
102
  with torch.no_grad():
103
  _ = _model(
104
  **_tokenizer(["warm up"], return_tensors="pt", padding=True, truncation=True, max_length=512)
105
  .to(_device)
106
  ).logits
107
 
108
- # 读取并标准化模型标签(按索引顺序)
109
  id2label = getattr(_model.config, "id2label", {}) or {}
110
  num_labels = int(getattr(_model.config, "num_labels", 0) or 0)
111
  _NORM_LABELS_BY_IDX = [_normalize_label_text_only(id2label.get(i, f"LABEL_{i}")) for i in range(num_labels)]
112
 
113
- # 找出 PHISH/LEGIT 在 logits 中的索引
114
- try:
115
- _IDX_PHISH = _NORM_LABELS_BY_IDX.index("PHISH")
116
- except ValueError:
117
- _IDX_PHISH = None
118
- try:
119
- _IDX_LEGIT = _NORM_LABELS_BY_IDX.index("LEGIT")
120
- except ValueError:
121
- _IDX_LEGIT = None
122
-
123
- # 若模型没提供可识别的标签,但只有 2 类,给出安全的保守默认(不强行假设)
124
- # 这里不自动假设 0/1 的含义,避免再次反转;保留 None,让下游概率照常返回。
125
- # 你也可以按需启用:
126
- # if _IDX_PHISH is None and _IDX_LEGIT is None and num_labels == 2:
127
- # _IDX_LEGIT, _IDX_PHISH = 0, 1
128
-
129
 
130
  def _predict_texts(texts: List[str]) -> List[Dict]:
 
 
 
 
 
 
 
 
 
131
  _load_model()
132
  if not texts:
133
  return []
@@ -148,39 +108,33 @@ def _predict_texts(texts: List[str]) -> List[Dict]:
148
 
149
  # Use the model’s own mapping
150
  id2label = getattr(_model.config, "id2label", None) or {}
151
- labels_by_idx = [_normalize_label_text_only(id2label.get(i, f"LABEL_{i}")) for i in range(probs.shape[-1])]
 
 
152
 
153
  outputs: List[Dict] = []
154
  for i in range(probs.shape[0]):
155
  p = probs[i]
156
  idx = int(torch.argmax(p).item())
157
- norm_label = labels_by_idx[idx] # 已标准化为 PHISH/LEGIT 或原样回传
158
 
159
- # 构建(标准化后的)各类概率映射
 
 
 
 
160
  prob_map: Dict[str, float] = {}
161
- for j, lbl in enumerate(labels_by_idx):
162
- key = lbl if lbl in ("PHISH", "LEGIT") else f"CLASS_{j}"
163
  prob_map[key] = float(p[j].item())
164
 
165
- # ——把预测映射回你的 CSV 0/1——
166
- # 只有在我们确实知道哪个 index 是 PHISH / LEGIT 时才赋值;否则返回 None,避免误导
167
- ds_label: Optional[int] = None
168
- probs_by_dataset: Optional[Dict[str, float]] = None
169
- if _IDX_PHISH is not None and _IDX_LEGIT is not None:
170
- ds_label = int(DATASET_PHISH_VALUE) if idx == _IDX_PHISH else int(DATASET_LEGIT_VALUE)
171
- probs_by_dataset = {
172
- DATASET_PHISH_VALUE: float(p[_IDX_PHISH].item()), # 数据集里代表 PHISH 的数值("0" 或 "1")
173
- DATASET_LEGIT_VALUE: float(p[_IDX_LEGIT].item()), # 数据集里代表 LEGIT 的数值
174
- }
175
-
176
  outputs.append(
177
  {
178
- "label": norm_label if norm_label in ("PHISH", "LEGIT") else norm_label, # 文字结果
179
- "score": float(p[idx].item()), # max class probability
180
- "probs": prob_map, # 每类概率(键为 PHISH/LEGIT CLASS_k)
181
- "predicted_index": idx, # 模型 argmax 索引
182
- "predicted_dataset_label": ds_label, # 用你的数据集 0/1 表示的预测(对齐到 DATASET_*_VALUE)
183
- "probs_by_dataset_label": probs_by_dataset,
184
  }
185
  )
186
 
@@ -189,13 +143,11 @@ def _predict_texts(texts: List[str]) -> List[Dict]:
189
 
190
  @app.get("/")
191
  def root():
 
192
  return {
193
  "status": "ok",
194
  "model": MODEL_ID,
195
- "dataset_mapping": {
196
- "PHISH_VALUE": DATASET_PHISH_VALUE,
197
- "LEGIT_VALUE": DATASET_LEGIT_VALUE,
198
- },
199
  }
200
 
201
 
@@ -208,12 +160,6 @@ def debug_labels():
208
  "num_labels": int(getattr(_model.config, "num_labels", 0)),
209
  "device": _device,
210
  "norm_labels_by_idx": _NORM_LABELS_BY_IDX,
211
- "idx_phish": _IDX_PHISH,
212
- "idx_legit": _IDX_LEGIT,
213
- "dataset_mapping": {
214
- "PHISH_VALUE": DATASET_PHISH_VALUE,
215
- "LEGIT_VALUE": DATASET_LEGIT_VALUE,
216
- },
217
  }
218
 
219
 
@@ -238,21 +184,12 @@ def predict_batch(payload: BatchPredictPayload):
238
  def evaluate(payload: EvalPayload):
239
  """
240
  Quick on-the-spot test with provided labeled samples.
241
-
242
- Request body:
243
- {
244
- "samples": [
245
- {"text": "Your parcel is held...", "label": "PHISH"}, # or "0"/"1"(按你的数据集约定)
246
- {"text": "Lunch at 12?", "label": "LEGIT"} # or "0"/"1"
247
- ]
248
- }
249
-
250
- Returns accuracy and per-class counts (labels normalized to PHISH/LEGIT).
251
  """
252
  try:
253
  texts = [s.text for s in payload.samples]
254
- # 这里用数据集映射把 "0"/"1" 转成人类可读的 PHISH/LEGIT
255
- gts = [_normalize_label_from_dataset(s.label) if s.label is not None else None for s in payload.samples]
256
  preds = _predict_texts(texts)
257
 
258
  total = len(preds)
@@ -260,8 +197,8 @@ def evaluate(payload: EvalPayload):
260
  per_class: Dict[str, Dict[str, int]] = {}
261
 
262
  for gt, pr in zip(gts, preds):
263
- pred_label = pr["label"] if pr["label"] in ("PHISH", "LEGIT") else None
264
- if gt is not None and pred_label is not None:
265
  correct += int(gt == pred_label)
266
  per_class.setdefault(gt, {"tp": 0, "count": 0})
267
  per_class[gt]["count"] += 1
@@ -269,18 +206,13 @@ def evaluate(payload: EvalPayload):
269
  per_class[gt]["tp"] += 1
270
 
271
  has_gts = any(gt is not None for gt in gts)
272
- denom = sum(1 for gt in gts if gt is not None)
273
- acc = (correct / denom) if (has_gts and denom > 0) else None
274
 
275
  return {
276
- "accuracy": acc, # None if no ground truths provided
277
  "total": total,
278
  "predictions": preds,
279
  "per_class": per_class,
280
- "dataset_mapping": {
281
- "PHISH_VALUE": DATASET_PHISH_VALUE,
282
- "LEGIT_VALUE": DATASET_LEGIT_VALUE,
283
- },
284
  }
285
  except Exception as e:
286
  raise HTTPException(status_code=500, detail=f"Evaluation error: {e}")
 
13
  or "Perth0603/phishing-email-mobilebert"
14
  )
15
 
16
+ app = FastAPI(title="Phishing Text Classifier (Model-Authoritative)", version="1.0.0")
 
 
 
 
 
 
 
 
 
 
 
17
 
18
 
19
  class PredictPayload(BaseModel):
 
26
 
27
  class LabeledText(BaseModel):
28
  text: str
29
+ label: Optional[str] = None # optional ground truth for quick eval (accepts text)
30
 
31
 
32
  class EvalPayload(BaseModel):
 
38
  _device = "cpu"
39
 
40
  # Cached normalized mapping/meta
 
 
41
  _NORM_LABELS_BY_IDX = None # normalized labels ordered by model indices
42
 
43
 
44
  def _normalize_label_text_only(txt: str) -> str:
45
  """
46
+ Normalize model label text to PHISH/LEGIT when possible.
47
+ If unfamiliar, return the uppercased original token.
48
  """
49
  t = (str(txt) if txt is not None else "").strip().upper()
50
  if t in ("PHISHING", "PHISH", "SPAM"):
51
  return "PHISH"
52
  if t in ("LEGIT", "LEGITIMATE", "SAFE", "HAM"):
53
  return "LEGIT"
54
+ # keep other label names as-is (uppercased) so we don't force an incorrect mapping
55
  return t
56
 
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def _load_model():
59
+ global _tokenizer, _model, _device, _NORM_LABELS_BY_IDX
60
 
61
  if _tokenizer is None or _model is None:
62
  _device = "cuda" if torch.cuda.is_available() else "cpu"
 
65
  _model.to(_device)
66
  _model.eval() # important: disable dropout etc.
67
 
68
+ # Warm-up (silent)
69
  with torch.no_grad():
70
  _ = _model(
71
  **_tokenizer(["warm up"], return_tensors="pt", padding=True, truncation=True, max_length=512)
72
  .to(_device)
73
  ).logits
74
 
75
+ # Read and normalize model labels (by index)
76
  id2label = getattr(_model.config, "id2label", {}) or {}
77
  num_labels = int(getattr(_model.config, "num_labels", 0) or 0)
78
  _NORM_LABELS_BY_IDX = [_normalize_label_text_only(id2label.get(i, f"LABEL_{i}")) for i in range(num_labels)]
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  def _predict_texts(texts: List[str]) -> List[Dict]:
82
+ """
83
+ Predict and return strictly model-authoritative outputs:
84
+ - label: normalized model label (PHISH/LEGIT or other model label uppercased)
85
+ - raw_label: original id2label string from model.config
86
+ - is_phish: boolean derived from normalized label (True if normalized == "PHISH")
87
+ - score: probability of predicted class
88
+ - probs: dict of normalized label -> probability (or CLASS_i keys if unknown)
89
+ - predicted_index: argmax index
90
+ """
91
  _load_model()
92
  if not texts:
93
  return []
 
108
 
109
  # Use the model’s own mapping
110
  id2label = getattr(_model.config, "id2label", None) or {}
111
+ labels_by_idx_raw = [id2label.get(i, f"LABEL_{i}") for i in range(probs.shape[-1])]
112
+ # normalized labels where possible
113
+ labels_by_idx_norm = [_normalize_label_text_only(lbl) for lbl in labels_by_idx_raw]
114
 
115
  outputs: List[Dict] = []
116
  for i in range(probs.shape[0]):
117
  p = probs[i]
118
  idx = int(torch.argmax(p).item())
 
119
 
120
+ raw_label = labels_by_idx_raw[idx]
121
+ norm_label = labels_by_idx_norm[idx] # normalized where possible
122
+
123
+ # Build probability map keyed by normalized labels when available,
124
+ # otherwise fallback to CLASS_i keys to avoid collision
125
  prob_map: Dict[str, float] = {}
126
+ for j, lbl_norm in enumerate(labels_by_idx_norm):
127
+ key = lbl_norm if lbl_norm in ("PHISH", "LEGIT") else f"CLASS_{j}"
128
  prob_map[key] = float(p[j].item())
129
 
 
 
 
 
 
 
 
 
 
 
 
130
  outputs.append(
131
  {
132
+ "label": norm_label, # authoritative label (model-driven, normalized)
133
+ "raw_label": raw_label, # original model id2label value
134
+ "is_phish": True if norm_label == "PHISH" else False,
135
+ "score": float(p[idx].item()), # probability of predicted class
136
+ "probs": prob_map, # per-class probabilities (keys normalized or CLASS_i)
137
+ "predicted_index": idx,
138
  }
139
  )
140
 
 
143
 
144
  @app.get("/")
145
  def root():
146
+ _load_model()
147
  return {
148
  "status": "ok",
149
  "model": MODEL_ID,
150
+ "note": "This service returns predictions exactly as the model decides (label derived from model.config.id2label). Frontend should use `label` or `is_phish` as authority."
 
 
 
151
  }
152
 
153
 
 
160
  "num_labels": int(getattr(_model.config, "num_labels", 0)),
161
  "device": _device,
162
  "norm_labels_by_idx": _NORM_LABELS_BY_IDX,
 
 
 
 
 
 
163
  }
164
 
165
 
 
184
  def evaluate(payload: EvalPayload):
185
  """
186
  Quick on-the-spot test with provided labeled samples.
187
+ The provided labels are interpreted as text labels (PHISH/LEGIT/etc.) — evaluation is done
188
+ by comparing normalized GT text to model's normalized prediction (no 0/1 dataset mapping applied).
 
 
 
 
 
 
 
 
189
  """
190
  try:
191
  texts = [s.text for s in payload.samples]
192
+ gts = [(_normalize_label_text_only(s.label) if s.label is not None else None) for s in payload.samples]
 
193
  preds = _predict_texts(texts)
194
 
195
  total = len(preds)
 
197
  per_class: Dict[str, Dict[str, int]] = {}
198
 
199
  for gt, pr in zip(gts, preds):
200
+ pred_label = pr["label"]
201
+ if gt is not None:
202
  correct += int(gt == pred_label)
203
  per_class.setdefault(gt, {"tp": 0, "count": 0})
204
  per_class[gt]["count"] += 1
 
206
  per_class[gt]["tp"] += 1
207
 
208
  has_gts = any(gt is not None for gt in gts)
209
+ acc = (correct / sum(1 for gt in gts if gt is not None)) if has_gts else None
 
210
 
211
  return {
212
+ "accuracy": acc,
213
  "total": total,
214
  "predictions": preds,
215
  "per_class": per_class,
 
 
 
 
216
  }
217
  except Exception as e:
218
  raise HTTPException(status_code=500, detail=f"Evaluation error: {e}")