Perth0603 commited on
Commit
5bdaec2
·
verified ·
1 Parent(s): 72eb3f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -133
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
- from typing import List, Optional, Dict
3
 
4
  import torch
5
  from fastapi import FastAPI, HTTPException
@@ -13,9 +13,16 @@ MODEL_ID = (
13
  or "Perth0603/phishing-email-mobilebert"
14
  )
15
 
16
- app = FastAPI(title="Phishing Text Classifier", version="1.2.0")
 
 
 
 
17
 
 
18
 
 
 
19
  class PredictPayload(BaseModel):
20
  inputs: str
21
 
@@ -26,25 +33,28 @@ class BatchPredictPayload(BaseModel):
26
 
27
  class LabeledText(BaseModel):
28
  text: str
29
- label: Optional[str] = None # optional ground truth for quick eval (accepts "0"/"1" or text)
30
 
31
 
32
  class EvalPayload(BaseModel):
33
  samples: List[LabeledText]
34
 
35
 
 
36
  _tokenizer = None
37
  _model = None
38
  _device = "cpu"
39
 
40
  # Cached normalized mapping/meta
41
- _IDX_PHISH = None # model output index that corresponds to PHISH
42
- _IDX_LEGIT = None # model output index that corresponds to LEGIT
43
- _NORM_LABELS_BY_IDX = None # normalized labels ordered by model indices
 
44
 
45
 
 
46
  def _normalize_label(txt: str) -> str:
47
- # Normalize common variants and accept "0"/"1" from CSVs
48
  t = (str(txt) if txt is not None else "").strip().upper()
49
  if t in ("1", "PHISHING", "PHISH", "SPAM"):
50
  return "PHISH"
@@ -53,50 +63,76 @@ def _normalize_label(txt: str) -> str:
53
  return t
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  def _load_model():
57
- global _tokenizer, _model, _device, _IDX_PHISH, _IDX_LEGIT, _NORM_LABELS_BY_IDX
58
-
59
- if _tokenizer is None or _model is None:
60
- _device = "cuda" if torch.cuda.is_available() else "cpu"
61
- _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
62
- _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
63
- _model.to(_device)
64
- _model.eval() # important: disable dropout etc.
65
-
66
- # Warm-up
67
- with torch.no_grad():
68
- _ = _model(
69
- **_tokenizer(["warm up"], return_tensors="pt", padding=True, truncation=True, max_length=512)
70
- .to(_device)
71
- ).logits
72
-
73
- # Derive normalized labels per index and cache PHISH/LEGIT indices
74
- id2label = getattr(_model.config, "id2label", {}) or {}
75
- num_labels = int(getattr(_model.config, "num_labels", 0) or 0)
76
- _NORM_LABELS_BY_IDX = [_normalize_label(id2label.get(i, f"LABEL_{i}")) for i in range(num_labels)]
77
-
78
- # Try to locate PHISH/LEGIT indices explicitly
79
- try:
80
- _IDX_PHISH = _NORM_LABELS_BY_IDX.index("PHISH")
81
- except ValueError:
82
- _IDX_PHISH = None
83
- try:
84
- _IDX_LEGIT = _NORM_LABELS_BY_IDX.index("LEGIT")
85
- except ValueError:
86
- _IDX_LEGIT = None
87
-
88
- # If labels are unknown but binary, you can optionally set a default mapping.
89
- # Commented out by default to avoid wrong assumptions:
90
- # if _IDX_PHISH is None and _IDX_LEGIT is None and num_labels == 2:
91
- # _IDX_LEGIT, _IDX_PHISH = 0, 1 # assumes index 1 = PHISH, index 0 = LEGIT
92
-
93
-
94
- def _predict_texts(texts: List[str]) -> List[Dict]:
 
 
 
 
 
 
 
 
 
 
 
 
95
  _load_model()
96
  if not texts:
97
  return []
98
 
99
- # Tokenize batch
100
  enc = _tokenizer(
101
  texts,
102
  return_tensors="pt",
@@ -110,45 +146,51 @@ def _predict_texts(texts: List[str]) -> List[Dict]:
110
  logits = _model(**enc).logits
111
  probs = torch.softmax(logits, dim=-1) # [batch, num_labels]
112
 
113
- # Use the model’s own mapping
114
- id2label = getattr(_model.config, "id2label", None) or {}
115
- # Build a stable label list by index
116
- labels_by_idx = [id2label.get(i, f"LABEL_{i}") for i in range(probs.shape[-1])]
117
 
118
  outputs: List[Dict] = []
119
  for i in range(probs.shape[0]):
120
  p = probs[i]
121
  idx = int(torch.argmax(p).item())
122
- raw_label = labels_by_idx[idx]
123
- norm_label = _normalize_label(raw_label)
124
-
125
- # Also expose per-label probabilities (normalized names where possible)
126
- prob_map = {_normalize_label(labels_by_idx[j]): float(p[j].item()) for j in range(len(labels_by_idx))}
127
-
128
- # Map to your dataset convention: PHISH=1, LEGIT=0
129
- ds_label = None
130
- if _IDX_PHISH is not None and _IDX_LEGIT is not None:
131
- if idx == _IDX_PHISH:
132
- ds_label = 1
133
- elif idx == _IDX_LEGIT:
134
- ds_label = 0
135
-
136
- # Per-dataset-label probabilities when both indices are known
137
- probs_by_dataset = None
138
- if _IDX_PHISH is not None and _IDX_LEGIT is not None:
139
- probs_by_dataset = {
140
- "1": float(p[_IDX_PHISH].item()), # PHISH
141
- "0": float(p[_IDX_LEGIT].item()), # LEGIT
142
- }
 
143
 
144
  outputs.append(
145
  {
146
- "label": norm_label, # normalized (e.g., PHISH/LEGIT)
147
- "raw_label": raw_label, # from model.config.id2label
148
- "score": float(p[idx].item()), # max class probability
149
- "probs": prob_map, # dict of normalized label -> probability
150
- "predicted_index": idx, # model argmax index
151
- "predicted_dataset_label": ds_label, # 1 for PHISH, 0 for LEGIT (your convention)
 
 
 
 
 
 
152
  "probs_by_dataset_label": probs_by_dataset,
153
  }
154
  )
@@ -156,6 +198,7 @@ def _predict_texts(texts: List[str]) -> List[Dict]:
156
  return outputs
157
 
158
 
 
159
  @app.get("/")
160
  def root():
161
  return {"status": "ok", "model": MODEL_ID}
@@ -175,10 +218,24 @@ def debug_labels():
175
  }
176
 
177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  @app.post("/predict")
179
  def predict(payload: PredictPayload):
180
  try:
181
- res = _predict_texts([payload.inputs])
182
  return res[0]
183
  except Exception as e:
184
  raise HTTPException(status_code=500, detail=f"Prediction error: {e}")
@@ -187,58 +244,4 @@ def predict(payload: PredictPayload):
187
  @app.post("/predict-batch")
188
  def predict_batch(payload: BatchPredictPayload):
189
  try:
190
- return _predict_texts(payload.inputs)
191
- except Exception as e:
192
- raise HTTPException(status_code=500, detail=f"Batch prediction error: {e}")
193
-
194
-
195
- @app.post("/evaluate")
196
- def evaluate(payload: EvalPayload):
197
- """
198
- Quick on-the-spot test with provided labeled samples.
199
-
200
- Request body:
201
- {
202
- "samples": [
203
- {"text": "Your parcel is held...", "label": "PHISH"}, # or "1"
204
- {"text": "Lunch at 12?", "label": "LEGIT"} # or "0"
205
- ]
206
- }
207
-
208
- Returns accuracy and per-class counts.
209
- """
210
- try:
211
- texts = [s.text for s in payload.samples]
212
- gts = [(_normalize_label(s.label) if s.label is not None else None) for s in payload.samples]
213
- preds = _predict_texts(texts)
214
-
215
- total = len(preds)
216
- correct = 0
217
- per_class: Dict[str, Dict[str, int]] = {}
218
-
219
- for gt, pr in zip(gts, preds):
220
- pred_label = pr["label"]
221
- if gt is not None:
222
- correct += int(gt == pred_label)
223
- per_class.setdefault(gt, {"tp": 0, "count": 0})
224
- per_class[gt]["count"] += 1
225
- if gt == pred_label:
226
- per_class[gt]["tp"] += 1
227
-
228
- has_gts = any(gt is not None for gt in gts)
229
- acc = (correct / sum(1 for gt in gts if gt is not None)) if has_gts else None
230
-
231
- return {
232
- "accuracy": acc, # None if no ground truths provided
233
- "total": total,
234
- "predictions": preds,
235
- "per_class": per_class,
236
- }
237
- except Exception as e:
238
- raise HTTPException(status_code=500, detail=f"Evaluation error: {e}")
239
-
240
-
241
- if __name__ == "__main__":
242
- # Run: uvicorn app:app --host 0.0.0.0 --port 8000 --reload
243
- import uvicorn
244
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
  import os
2
+ from typing import List, Optional, Dict, Tuple
3
 
4
  import torch
5
  from fastapi import FastAPI, HTTPException
 
13
  or "Perth0603/phishing-email-mobilebert"
14
  )
15
 
16
+ # Optional: force mapping when model labels are unclear (binary only).
17
+ # Example values:
18
+ # FORCE_BINARY_MAPPING="LEGIT,PHISH" (index0=LEGIT, index1=PHISH)
19
+ # FORCE_BINARY_MAPPING="PHISH,LEGIT" (index0=PHISH, index1=LEGIT)
20
+ FORCE_BINARY_MAPPING = os.environ.get("FORCE_BINARY_MAPPING", "").strip().upper()
21
 
22
+ app = FastAPI(title="Phishing Text Classifier", version="1.3.0")
23
 
24
+
25
+ # ---------- Schemas ----------
26
  class PredictPayload(BaseModel):
27
  inputs: str
28
 
 
33
 
34
  class LabeledText(BaseModel):
35
  text: str
36
+ label: Optional[str] = None # optional ground truth ("0"/"1" or text)
37
 
38
 
39
  class EvalPayload(BaseModel):
40
  samples: List[LabeledText]
41
 
42
 
43
+ # ---------- Globals / cache ----------
44
  _tokenizer = None
45
  _model = None
46
  _device = "cpu"
47
 
48
  # Cached normalized mapping/meta
49
+ _IDX_PHISH: Optional[int] = None # model output index that corresponds to PHISH
50
+ _IDX_LEGIT: Optional[int] = None # model output index that corresponds to LEGIT
51
+ _NORM_LABELS_BY_IDX: Optional[List[str]] = None # normalized labels ordered by model indices
52
+ _USED_FORCED_MAPPING: bool = False # whether FORCE_BINARY_MAPPING took effect
53
 
54
 
55
+ # ---------- Helpers ----------
56
  def _normalize_label(txt: str) -> str:
57
+ """Normalize common variants and accept "0"/"1" from CSVs."""
58
  t = (str(txt) if txt is not None else "").strip().upper()
59
  if t in ("1", "PHISHING", "PHISH", "SPAM"):
60
  return "PHISH"
 
63
  return t
64
 
65
 
66
+ def _try_force_binary_mapping(num_labels: int) -> Tuple[Optional[int], Optional[int], bool]:
67
+ """Apply FORCE_BINARY_MAPPING env var if provided and binary."""
68
+ if num_labels != 2 or not FORCE_BINARY_MAPPING:
69
+ return None, None, False
70
+ parts = [p.strip() for p in FORCE_BINARY_MAPPING.split(",") if p.strip()]
71
+ if len(parts) != 2 or any(p not in ("PHISH", "LEGIT") for p in parts):
72
+ return None, None, False
73
+ # parts[0] is index 0, parts[1] is index 1
74
+ idx_legit = 0 if parts[0] == "LEGIT" else 1 if parts[1] == "LEGIT" else None
75
+ idx_phish = 0 if parts[0] == "PHISH" else 1 if parts[1] == "PHISH" else None
76
+ if idx_legit is None or idx_phish is None:
77
+ return None, None, False
78
+ return idx_phish, idx_legit, True
79
+
80
+
81
  def _load_model():
82
+ """Load model/tokenizer and derive stable label mapping."""
83
+ global _tokenizer, _model, _device, _IDX_PHISH, _IDX_LEGIT, _NORM_LABELS_BY_IDX, _USED_FORCED_MAPPING
84
+
85
+ if _tokenizer is not None and _model is not None:
86
+ return
87
+
88
+ _device = "cuda" if torch.cuda.is_available() else "cpu"
89
+ _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
90
+ _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
91
+ _model.to(_device)
92
+ _model.eval()
93
+
94
+ # Warm-up
95
+ with torch.no_grad():
96
+ _ = _model(
97
+ **_tokenizer(["warm up"], return_tensors="pt", padding=True, truncation=True, max_length=512)
98
+ .to(_device)
99
+ ).logits
100
+
101
+ # Derive normalized labels per index
102
+ id2label = getattr(_model.config, "id2label", {}) or {}
103
+ num_labels = int(getattr(_model.config, "num_labels", 0) or 0)
104
+ _NORM_LABELS_BY_IDX = [_normalize_label(id2label.get(i, f"LABEL_{i}")) for i in range(num_labels)]
105
+
106
+ # 1) Try explicit indices from normalized labels
107
+ _IDX_PHISH = None
108
+ _IDX_LEGIT = None
109
+ try:
110
+ _IDX_PHISH = _NORM_LABELS_BY_IDX.index("PHISH")
111
+ except ValueError:
112
+ pass
113
+ try:
114
+ _IDX_LEGIT = _NORM_LABELS_BY_IDX.index("LEGIT")
115
+ except ValueError:
116
+ pass
117
+
118
+ # 2) If still unknown and binary, allow forced mapping
119
+ _USED_FORCED_MAPPING = False
120
+ if (_IDX_PHISH is None or _IDX_LEGIT is None) and num_labels == 2:
121
+ fp, fl, used = _try_force_binary_mapping(num_labels)
122
+ if used:
123
+ _IDX_PHISH, _IDX_LEGIT = fp, fl
124
+ _USED_FORCED_MAPPING = True
125
+
126
+ # 3) If still unknown, we keep them None and ONLY return model-native labels.
127
+ # (不进行臆测,避免再次搞反)
128
+
129
+
130
+ def _postprocess_batch_logits(texts: List[str]) -> List[Dict]:
131
+ """Compute predictions + provide robust, unambiguous fields for UI."""
132
  _load_model()
133
  if not texts:
134
  return []
135
 
 
136
  enc = _tokenizer(
137
  texts,
138
  return_tensors="pt",
 
146
  logits = _model(**enc).logits
147
  probs = torch.softmax(logits, dim=-1) # [batch, num_labels]
148
 
149
+ id2label = getattr(_model.config, "id2label", {}) or {}
150
+ labels_by_idx_raw = [id2label.get(i, f"LABEL_{i}") for i in range(probs.shape[-1])]
151
+ labels_by_idx_norm = [_normalize_label(x) for x in labels_by_idx_raw]
 
152
 
153
  outputs: List[Dict] = []
154
  for i in range(probs.shape[0]):
155
  p = probs[i]
156
  idx = int(torch.argmax(p).item())
157
+
158
+ raw_label = labels_by_idx_raw[idx]
159
+ norm_label = labels_by_idx_norm[idx]
160
+
161
+ # normalized probs dict
162
+ prob_map = {labels_by_idx_norm[j]: float(p[j].item()) for j in range(len(labels_by_idx_norm))}
163
+
164
+ # Default display (robust): if我们能确定 PHISH/LEGIT 下标,就用它;否则用norm_label回退
165
+ can_map_dataset = (_IDX_PHISH is not None and _IDX_LEGIT is not None)
166
+ if can_map_dataset:
167
+ phish_prob = float(p[_IDX_PHISH].item())
168
+ legit_prob = float(p[_IDX_LEGIT].item())
169
+ is_phish = phish_prob >= legit_prob
170
+ dataset_label = "1" if is_phish else "0" # 按你的数据集约定:1=PHISH, 0=LEGIT
171
+ display_label = "phishing" if is_phish else "legitimate"
172
+ probs_by_dataset = {"1": phish_prob, "0": legit_prob}
173
+ else:
174
+ # 回退策略:用当前pred的规范化标签
175
+ is_phish = (norm_label == "PHISH")
176
+ dataset_label = "1" if is_phish else "0"
177
+ display_label = "phishing" if is_phish else "legitimate"
178
+ probs_by_dataset = None # unknown mapping
179
 
180
  outputs.append(
181
  {
182
+ # —— 建议前端优先使用这三个字段,不会搞反 ——
183
+ "is_phish": is_phish,
184
+ "dataset_label": dataset_label, # "1"=PHISH, "0"=LEGIT
185
+ "display_label": display_label, # "phishing"/"legitimate"
186
+
187
+ # —— 诊断/兼容字段 ——
188
+ "label": norm_label, # 规范化后的(PHISH/LEGIT/未知)
189
+ "raw_label": raw_label, # 来自 model.config.id2label
190
+ "score": float(p[idx].item()), # argmax 概率
191
+ "probs": prob_map, # 规范化名 -> 概率
192
+ "predicted_index": idx, # 模型 argmax 下标
193
+ "predicted_dataset_label": (1 if is_phish else 0), # int,等价于上面的字符串
194
  "probs_by_dataset_label": probs_by_dataset,
195
  }
196
  )
 
198
  return outputs
199
 
200
 
201
+ # ---------- Routes ----------
202
  @app.get("/")
203
  def root():
204
  return {"status": "ok", "model": MODEL_ID}
 
218
  }
219
 
220
 
221
+ @app.get("/debug/mapping")
222
+ def debug_mapping():
223
+ _load_model()
224
+ num_labels = int(getattr(_model.config, "num_labels", 0))
225
+ return {
226
+ "forced_mapping_env": FORCE_BINARY_MAPPING or None,
227
+ "used_forced_mapping": _USED_FORCED_MAPPING,
228
+ "num_labels": num_labels,
229
+ "can_map_dataset": (_IDX_PHISH is not None and _IDX_LEGIT is not None),
230
+ "idx_phish": _IDX_PHISH,
231
+ "idx_legit": _IDX_LEGIT,
232
+ }
233
+
234
+
235
  @app.post("/predict")
236
  def predict(payload: PredictPayload):
237
  try:
238
+ res = _postprocess_batch_logits([payload.inputs])
239
  return res[0]
240
  except Exception as e:
241
  raise HTTPException(status_code=500, detail=f"Prediction error: {e}")
 
244
  @app.post("/predict-batch")
245
  def predict_batch(payload: BatchPredictPayload):
246
  try:
247
+ return _postprocess_batch_logits(payload.inp_