Perth0603 committed on
Commit
c4d989f
·
verified ·
1 Parent(s): 5bdaec2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -96
app.py CHANGED
@@ -6,55 +6,42 @@ from fastapi import FastAPI, HTTPException
6
  from pydantic import BaseModel
7
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
 
9
- # Prefer MODEL_ID, fall back to HF_MODEL_ID, then default
10
- MODEL_ID = (
11
- os.environ.get("MODEL_ID")
12
- or os.environ.get("HF_MODEL_ID")
13
- or "Perth0603/phishing-email-mobilebert"
14
- )
15
-
16
- # Optional: force mapping when model labels are unclear (binary only).
17
- # Example values:
18
- # FORCE_BINARY_MAPPING="LEGIT,PHISH" (index0=LEGIT, index1=PHISH)
19
- # FORCE_BINARY_MAPPING="PHISH,LEGIT" (index0=PHISH, index1=LEGIT)
20
- FORCE_BINARY_MAPPING = os.environ.get("FORCE_BINARY_MAPPING", "").strip().upper()
21
 
22
- app = FastAPI(title="Phishing Text Classifier", version="1.3.0")
 
 
23
 
 
24
 
25
- # ---------- Schemas ----------
26
  class PredictPayload(BaseModel):
27
  inputs: str
28
 
29
-
30
  class BatchPredictPayload(BaseModel):
31
  inputs: List[str]
32
 
33
-
34
  class LabeledText(BaseModel):
35
  text: str
36
- label: Optional[str] = None # optional ground truth ("0"/"1" or text)
37
-
38
 
39
  class EvalPayload(BaseModel):
40
  samples: List[LabeledText]
41
 
42
-
43
- # ---------- Globals / cache ----------
44
  _tokenizer = None
45
  _model = None
46
  _device = "cpu"
47
 
48
- # Cached normalized mapping/meta
49
- _IDX_PHISH: Optional[int] = None # model output index that corresponds to PHISH
50
- _IDX_LEGIT: Optional[int] = None # model output index that corresponds to LEGIT
51
- _NORM_LABELS_BY_IDX: Optional[List[str]] = None # normalized labels ordered by model indices
52
- _USED_FORCED_MAPPING: bool = False # whether FORCE_BINARY_MAPPING took effect
53
 
54
-
55
- # ---------- Helpers ----------
56
  def _normalize_label(txt: str) -> str:
57
- """Normalize common variants and accept "0"/"1" from CSVs."""
58
  t = (str(txt) if txt is not None else "").strip().upper()
59
  if t in ("1", "PHISHING", "PHISH", "SPAM"):
60
  return "PHISH"
@@ -62,48 +49,38 @@ def _normalize_label(txt: str) -> str:
62
  return "LEGIT"
63
  return t
64
 
65
-
66
  def _try_force_binary_mapping(num_labels: int) -> Tuple[Optional[int], Optional[int], bool]:
67
- """Apply FORCE_BINARY_MAPPING env var if provided and binary."""
68
  if num_labels != 2 or not FORCE_BINARY_MAPPING:
69
  return None, None, False
70
  parts = [p.strip() for p in FORCE_BINARY_MAPPING.split(",") if p.strip()]
71
  if len(parts) != 2 or any(p not in ("PHISH", "LEGIT") for p in parts):
72
  return None, None, False
73
- # parts[0] is index 0, parts[1] is index 1
74
  idx_legit = 0 if parts[0] == "LEGIT" else 1 if parts[1] == "LEGIT" else None
75
  idx_phish = 0 if parts[0] == "PHISH" else 1 if parts[1] == "PHISH" else None
76
  if idx_legit is None or idx_phish is None:
77
  return None, None, False
78
  return idx_phish, idx_legit, True
79
 
80
-
81
  def _load_model():
82
- """Load model/tokenizer and derive stable label mapping."""
83
  global _tokenizer, _model, _device, _IDX_PHISH, _IDX_LEGIT, _NORM_LABELS_BY_IDX, _USED_FORCED_MAPPING
84
-
85
  if _tokenizer is not None and _model is not None:
86
  return
87
 
88
  _device = "cuda" if torch.cuda.is_available() else "cpu"
89
- _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
90
- _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
91
  _model.to(_device)
92
  _model.eval()
93
 
94
- # Warm-up
95
  with torch.no_grad():
96
  _ = _model(
97
- **_tokenizer(["warm up"], return_tensors="pt", padding=True, truncation=True, max_length=512)
98
- .to(_device)
99
  ).logits
100
 
101
- # Derive normalized labels per index
102
  id2label = getattr(_model.config, "id2label", {}) or {}
103
  num_labels = int(getattr(_model.config, "num_labels", 0) or 0)
104
  _NORM_LABELS_BY_IDX = [_normalize_label(id2label.get(i, f"LABEL_{i}")) for i in range(num_labels)]
105
 
106
- # 1) Try explicit indices from normalized labels
107
  _IDX_PHISH = None
108
  _IDX_LEGIT = None
109
  try:
@@ -115,94 +92,73 @@ def _load_model():
115
  except ValueError:
116
  pass
117
 
118
- # 2) If still unknown and binary, allow forced mapping
119
  _USED_FORCED_MAPPING = False
120
  if (_IDX_PHISH is None or _IDX_LEGIT is None) and num_labels == 2:
121
  fp, fl, used = _try_force_binary_mapping(num_labels)
122
  if used:
123
  _IDX_PHISH, _IDX_LEGIT = fp, fl
124
  _USED_FORCED_MAPPING = True
 
125
 
126
- # 3) If still unknown, we keep them None and ONLY return model-native labels.
127
- # (不进行臆测,避免再次搞反)
128
-
129
-
130
- def _postprocess_batch_logits(texts: List[str]) -> List[Dict]:
131
- """Compute predictions + provide robust, unambiguous fields for UI."""
132
  _load_model()
133
  if not texts:
134
  return []
135
 
136
- enc = _tokenizer(
137
- texts,
138
- return_tensors="pt",
139
- padding=True,
140
- truncation=True,
141
- max_length=512,
142
- )
143
  enc = {k: v.to(_device) for k, v in enc.items()}
144
 
145
  with torch.no_grad():
146
  logits = _model(**enc).logits
147
- probs = torch.softmax(logits, dim=-1) # [batch, num_labels]
148
 
149
  id2label = getattr(_model.config, "id2label", {}) or {}
150
  labels_by_idx_raw = [id2label.get(i, f"LABEL_{i}") for i in range(probs.shape[-1])]
151
  labels_by_idx_norm = [_normalize_label(x) for x in labels_by_idx_raw]
152
 
153
- outputs: List[Dict] = []
154
  for i in range(probs.shape[0]):
155
  p = probs[i]
156
  idx = int(torch.argmax(p).item())
157
 
158
  raw_label = labels_by_idx_raw[idx]
159
  norm_label = labels_by_idx_norm[idx]
160
-
161
- # normalized probs dict
162
  prob_map = {labels_by_idx_norm[j]: float(p[j].item()) for j in range(len(labels_by_idx_norm))}
163
 
164
- # Default display (robust): if我们能确定 PHISH/LEGIT 下标,就用它;否则用norm_label回退
165
- can_map_dataset = (_IDX_PHISH is not None and _IDX_LEGIT is not None)
166
- if can_map_dataset:
167
  phish_prob = float(p[_IDX_PHISH].item())
168
  legit_prob = float(p[_IDX_LEGIT].item())
169
  is_phish = phish_prob >= legit_prob
170
- dataset_label = "1" if is_phish else "0" # 按你的数据集约定:1=PHISH, 0=LEGIT
171
  display_label = "phishing" if is_phish else "legitimate"
172
  probs_by_dataset = {"1": phish_prob, "0": legit_prob}
173
  else:
174
- # 回退策略:用当前pred的规范化标签
175
  is_phish = (norm_label == "PHISH")
176
  dataset_label = "1" if is_phish else "0"
177
  display_label = "phishing" if is_phish else "legitimate"
178
- probs_by_dataset = None # unknown mapping
179
-
180
- outputs.append(
181
- {
182
- # —— 建议前端优先使用这三个字段,不会搞反 ——
183
- "is_phish": is_phish,
184
- "dataset_label": dataset_label, # "1"=PHISH, "0"=LEGIT
185
- "display_label": display_label, # "phishing"/"legitimate"
186
-
187
- # —— 诊断/兼容字段 ——
188
- "label": norm_label, # 规范化后的(PHISH/LEGIT/未知)
189
- "raw_label": raw_label, # 来自 model.config.id2label
190
- "score": float(p[idx].item()), # argmax 概率
191
- "probs": prob_map, # 规范化名 -> 概率
192
- "predicted_index": idx, # 模型 argmax 下标
193
- "predicted_dataset_label": (1 if is_phish else 0), # int,等价于上面的字符串
194
- "probs_by_dataset_label": probs_by_dataset,
195
- }
196
- )
197
-
198
- return outputs
199
-
200
-
201
- # ---------- Routes ----------
202
  @app.get("/")
203
  def root():
204
- return {"status": "ok", "model": MODEL_ID}
205
-
206
 
207
  @app.get("/debug/labels")
208
  def debug_labels():
@@ -217,7 +173,6 @@ def debug_labels():
217
  "idx_legit": _IDX_LEGIT,
218
  }
219
 
220
-
221
  @app.get("/debug/mapping")
222
  def debug_mapping():
223
  _load_model()
@@ -231,17 +186,49 @@ def debug_mapping():
231
  "idx_legit": _IDX_LEGIT,
232
  }
233
 
234
-
235
  @app.post("/predict")
236
  def predict(payload: PredictPayload):
237
  try:
238
- res = _postprocess_batch_logits([payload.inputs])
239
- return res[0]
240
  except Exception as e:
241
  raise HTTPException(status_code=500, detail=f"Prediction error: {e}")
242
 
243
-
244
  @app.post("/predict-batch")
245
  def predict_batch(payload: BatchPredictPayload):
246
  try:
247
- return _postprocess_batch_logits(payload.inp_
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from pydantic import BaseModel
7
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
 
9
# ====== Model source ======
# Load from a local directory by default (the uploaded model files live in /mnt/data).
MODEL_DIR = os.environ.get("MODEL_DIR", "/mnt/data")

# Optional: for a binary model whose config does not spell the labels out,
# force the index order explicitly, e.g.
#   FORCE_BINARY_MAPPING="LEGIT,PHISH"  or  FORCE_BINARY_MAPPING="PHISH,LEGIT"
FORCE_BINARY_MAPPING = os.environ.get("FORCE_BINARY_MAPPING", "").strip().upper()

app = FastAPI(title="Phishing Text Classifier", version="1.3.1")
18
 
19
+ # ====== Schemas ======
20
# ====== Schemas ======
class PredictPayload(BaseModel):
    """Request body for /predict: a single text to classify."""

    inputs: str
22
 
 
23
class BatchPredictPayload(BaseModel):
    """Request body for /predict-batch: a list of texts to classify."""

    inputs: List[str]
25
 
 
26
class LabeledText(BaseModel):
    """One evaluation sample: the text plus an optional ground-truth label."""

    text: str
    label: Optional[str] = None  # "0"/"1" or a textual label
 
29
 
30
class EvalPayload(BaseModel):
    """Request body for /evaluate: the labeled samples to score the model on."""

    samples: List[LabeledText]
32
 
33
+ # ====== Globals ======
 
34
  _tokenizer = None
35
  _model = None
36
  _device = "cpu"
37
 
38
+ _IDX_PHISH: Optional[int] = None
39
+ _IDX_LEGIT: Optional[int] = None
40
+ _NORM_LABELS_BY_IDX: Optional[List[str]] = None
41
+ _USED_FORCED_MAPPING: bool = False
 
42
 
43
+ # ====== Helpers ======
 
44
  def _normalize_label(txt: str) -> str:
 
45
  t = (str(txt) if txt is not None else "").strip().upper()
46
  if t in ("1", "PHISHING", "PHISH", "SPAM"):
47
  return "PHISH"
 
49
  return "LEGIT"
50
  return t
51
 
 
52
  def _try_force_binary_mapping(num_labels: int) -> Tuple[Optional[int], Optional[int], bool]:
 
53
  if num_labels != 2 or not FORCE_BINARY_MAPPING:
54
  return None, None, False
55
  parts = [p.strip() for p in FORCE_BINARY_MAPPING.split(",") if p.strip()]
56
  if len(parts) != 2 or any(p not in ("PHISH", "LEGIT") for p in parts):
57
  return None, None, False
 
58
  idx_legit = 0 if parts[0] == "LEGIT" else 1 if parts[1] == "LEGIT" else None
59
  idx_phish = 0 if parts[0] == "PHISH" else 1 if parts[1] == "PHISH" else None
60
  if idx_legit is None or idx_phish is None:
61
  return None, None, False
62
  return idx_phish, idx_legit, True
63
 
 
64
  def _load_model():
 
65
  global _tokenizer, _model, _device, _IDX_PHISH, _IDX_LEGIT, _NORM_LABELS_BY_IDX, _USED_FORCED_MAPPING
 
66
  if _tokenizer is not None and _model is not None:
67
  return
68
 
69
  _device = "cuda" if torch.cuda.is_available() else "cpu"
70
+ _tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
71
+ _model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
72
  _model.to(_device)
73
  _model.eval()
74
 
 
75
  with torch.no_grad():
76
  _ = _model(
77
+ **_tokenizer(["warm up"], return_tensors="pt", padding=True, truncation=True, max_length=512).to(_device)
 
78
  ).logits
79
 
 
80
  id2label = getattr(_model.config, "id2label", {}) or {}
81
  num_labels = int(getattr(_model.config, "num_labels", 0) or 0)
82
  _NORM_LABELS_BY_IDX = [_normalize_label(id2label.get(i, f"LABEL_{i}")) for i in range(num_labels)]
83
 
 
84
  _IDX_PHISH = None
85
  _IDX_LEGIT = None
86
  try:
 
92
  except ValueError:
93
  pass
94
 
 
95
  _USED_FORCED_MAPPING = False
96
  if (_IDX_PHISH is None or _IDX_LEGIT is None) and num_labels == 2:
97
  fp, fl, used = _try_force_binary_mapping(num_labels)
98
  if used:
99
  _IDX_PHISH, _IDX_LEGIT = fp, fl
100
  _USED_FORCED_MAPPING = True
101
+ # 你的模型文件已经写明:0=LEGIT, 1=PHISH,通常这里会自动识别出来。
102
 
103
def _postprocess(texts: List[str]) -> List[Dict]:
    """Run the classifier on *texts* and build one result dict per input.

    Each dict carries both dataset-convention fields ("1"=PHISH, "0"=LEGIT)
    and diagnostic fields (raw model label, per-label probabilities), so the
    frontend never has to guess the label orientation.
    """
    _load_model()
    if not texts:
        return []

    enc = _tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
    enc = {k: v.to(_device) for k, v in enc.items()}

    with torch.no_grad():
        logits = _model(**enc).logits
        probs = torch.softmax(logits, dim=-1)

    id2label = getattr(_model.config, "id2label", {}) or {}
    raw_names = [id2label.get(i, f"LABEL_{i}") for i in range(probs.shape[-1])]
    norm_names = [_normalize_label(name) for name in raw_names]

    results: List[Dict] = []
    for row in probs:
        top = int(torch.argmax(row).item())
        raw_label = raw_names[top]
        norm_label = norm_names[top]
        prob_map = {name: float(v.item()) for name, v in zip(norm_names, row)}

        # Core idea: when the PHISH/LEGIT output indices are known, derive the
        # dataset label and UI label from them; otherwise fall back to the
        # normalized argmax label.
        if _IDX_PHISH is not None and _IDX_LEGIT is not None:
            phish_prob = float(row[_IDX_PHISH].item())
            legit_prob = float(row[_IDX_LEGIT].item())
            is_phish = phish_prob >= legit_prob
            dataset_label = "1" if is_phish else "0"  # 1=PHISH, 0=LEGIT
            display_label = "phishing" if is_phish else "legitimate"
            probs_by_dataset = {"1": phish_prob, "0": legit_prob}
        else:
            # Fallback: trust the normalized argmax label.
            is_phish = norm_label == "PHISH"
            dataset_label = "1" if is_phish else "0"
            display_label = "phishing" if is_phish else "legitimate"
            probs_by_dataset = None

        results.append({
            "is_phish": is_phish,                 # preferred field for the frontend
            "dataset_label": dataset_label,       # "1"=PHISH, "0"=LEGIT
            "display_label": display_label,       # "phishing"/"legitimate"
            "label": norm_label,                  # normalized label (compat/debug)
            "raw_label": raw_label,               # from model.config.id2label
            "score": float(row[top].item()),      # argmax probability
            "probs": prob_map,                    # normalized name -> probability
            "predicted_index": top,               # model argmax index
            "predicted_dataset_label": 1 if is_phish else 0,  # int twin of dataset_label
            "probs_by_dataset_label": probs_by_dataset,
        })
    return results
157
+
158
+ # ====== Routes ======
 
 
 
 
 
 
 
159
# ====== Routes ======
@app.get("/")
def root():
    """Health check: reports the directory the model is loaded from."""
    return {"status": "ok", "model_dir": MODEL_DIR}
 
162
 
163
  @app.get("/debug/labels")
164
  def debug_labels():
 
173
  "idx_legit": _IDX_LEGIT,
174
  }
175
 
 
176
  @app.get("/debug/mapping")
177
  def debug_mapping():
178
  _load_model()
 
186
  "idx_legit": _IDX_LEGIT,
187
  }
188
 
 
189
@app.post("/predict")
def predict(payload: PredictPayload):
    """Classify a single text and return its prediction dict."""
    try:
        results = _postprocess([payload.inputs])
        return results[0]
    except Exception as exc:
        # Surface any inference failure as an HTTP 500 with the cause attached.
        raise HTTPException(status_code=500, detail=f"Prediction error: {exc}")
195
 
 
196
@app.post("/predict-batch")
def predict_batch(payload: BatchPredictPayload):
    """Classify a batch of texts; returns one prediction dict per input."""
    try:
        return _postprocess(payload.inputs)
    except Exception as exc:
        # Surface any inference failure as an HTTP 500 with the cause attached.
        raise HTTPException(status_code=500, detail=f"Batch prediction error: {exc}")
202
+
203
@app.post("/evaluate")
def evaluate(payload: EvalPayload):
    """Score the model on labeled samples.

    Accuracy is computed only over samples that carry a ground-truth label
    and is None when no sample has one. Per-class true-positive counts are
    keyed by the normalized ground-truth label.
    """
    try:
        texts = [sample.text for sample in payload.samples]
        truths = [
            _normalize_label(sample.label) if sample.label is not None else None
            for sample in payload.samples
        ]
        preds = _postprocess(texts)

        correct = 0
        labeled = 0
        per_class: Dict[str, Dict[str, int]] = {}
        for truth, pred in zip(truths, preds):
            predicted = "PHISH" if pred["is_phish"] else "LEGIT"
            if truth is None:
                continue  # unlabeled samples contribute predictions only
            labeled += 1
            stats = per_class.setdefault(truth, {"tp": 0, "count": 0})
            stats["count"] += 1
            if truth == predicted:
                correct += 1
                stats["tp"] += 1

        acc = correct / labeled if labeled else None
        return {"accuracy": acc, "total": len(preds), "predictions": preds, "per_class": per_class}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Evaluation error: {exc}")
230
+
231
if __name__ == "__main__":
    # Local dev entry point; equivalent to:
    #   uvicorn app:app --host 0.0.0.0 --port 8000
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)