Perth0603 committed on
Commit
113b42d
·
verified ·
1 Parent(s): 9e472d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -104
app.py CHANGED
@@ -1,4 +1,3 @@
1
- # app.py
2
  import os
3
  from typing import List, Optional, Dict
4
 
@@ -7,7 +6,6 @@ from fastapi import FastAPI, HTTPException
7
  from pydantic import BaseModel
8
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
9
 
10
-
11
  # Prefer MODEL_ID, fall back to HF_MODEL_ID, then default
12
  MODEL_ID = (
13
  os.environ.get("MODEL_ID")
@@ -15,7 +13,7 @@ MODEL_ID = (
15
  or "Perth0603/phishing-email-mobilebert"
16
  )
17
 
18
- app = FastAPI(title="Phishing Text Classifier", version="1.1.0")
19
 
20
 
21
  class PredictPayload(BaseModel):
@@ -28,7 +26,7 @@ class BatchPredictPayload(BaseModel):
28
 
29
class LabeledText(BaseModel):
    """One evaluation sample: a text plus an optional ground-truth label."""
    text: str
    label: Optional[str] = None  # optional ground truth for quick eval; None = unlabeled
32
 
33
 
34
  class EvalPayload(BaseModel):
@@ -39,19 +37,25 @@ _tokenizer = None
39
  _model = None
40
  _device = "cpu"
41
 
 
 
 
 
 
42
 
43
  def _normalize_label(txt: str) -> str:
44
- # Optional: normalize common variants for simpler downstream use
45
- t = (txt or "").strip().upper()
46
- if t in ("PHISHING", "PHISH", "SPAM"):
47
  return "PHISH"
48
- if t in ("LEGIT", "LEGITIMATE", "SAFE", "HAM"):
49
  return "LEGIT"
50
  return t
51
 
52
 
53
  def _load_model():
54
- global _tokenizer, _model, _device
 
55
  if _tokenizer is None or _model is None:
56
  _device = "cuda" if torch.cuda.is_available() else "cpu"
57
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
@@ -66,6 +70,26 @@ def _load_model():
66
  .to(_device)
67
  ).logits
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  def _predict_texts(texts: List[str]) -> List[Dict]:
71
  _load_model()
@@ -98,98 +122,5 @@ def _predict_texts(texts: List[str]) -> List[Dict]:
98
  raw_label = labels_by_idx[idx]
99
  norm_label = _normalize_label(raw_label)
100
 
101
- # Also expose per-label probabilities
102
- prob_map = { _normalize_label(labels_by_idx[j]): float(p[j].item()) for j in range(len(labels_by_idx)) }
103
-
104
- outputs.append(
105
- {
106
- "label": norm_label, # normalized (e.g., PHISH/LEGIT)
107
- "raw_label": raw_label, # from model.config.id2label
108
- "score": float(p[idx].item()), # max class probability
109
- "probs": prob_map, # dict of label -> probability
110
- "predicted_index": idx,
111
- }
112
- )
113
- return outputs
114
-
115
-
116
@app.get("/")
def root():
    """Liveness probe: report service status and the configured model id."""
    payload = {"status": "ok", "model": MODEL_ID}
    return payload
119
-
120
-
121
@app.get("/debug/labels")
def debug_labels():
    """Expose the loaded model's label mapping, label count, and device.

    Forces lazy model loading first so the config is available.
    """
    _load_model()
    cfg = _model.config
    return {
        "id2label": getattr(cfg, "id2label", {}),
        "label2id": getattr(cfg, "label2id", {}),
        "num_labels": int(getattr(cfg, "num_labels", 0)),
        "device": _device,
    }
130
-
131
-
132
@app.post("/predict")
def predict(payload: PredictPayload):
    """Classify a single text and return its prediction dict.

    Any failure is surfaced to the client as an HTTP 500 with the
    underlying error message.
    """
    try:
        return _predict_texts([payload.inputs])[0]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction error: {e}")
139
-
140
-
141
@app.post("/predict-batch")
def predict_batch(payload: BatchPredictPayload):
    """Classify a list of texts; returns one prediction dict per input.

    Failures become an HTTP 500 carrying the underlying error message.
    """
    try:
        results = _predict_texts(payload.inputs)
        return results
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch prediction error: {e}")
147
-
148
-
149
@app.post("/evaluate")
def evaluate(payload: EvalPayload):
    """
    Quick on-the-spot test with provided labeled samples.
    Request body:
    {
        "samples": [
            {"text": "Your parcel is held...", "label": "PHISH"},
            {"text": "Lunch at 12?", "label": "LEGIT"}
        ]
    }
    Returns accuracy and per-class counts. Samples without a "label" are
    still predicted but excluded from the accuracy computation.
    """
    try:
        texts = [s.text for s in payload.samples]
        gts = [(_normalize_label(s.label) if s.label else None) for s in payload.samples]
        preds = _predict_texts(texts)

        total = len(preds)
        labeled = 0  # number of samples carrying a ground-truth label
        correct = 0
        per_class = {}

        for gt, pr in zip(gts, preds):
            pred_label = pr["label"]
            if gt is not None:
                labeled += 1
                correct += int(gt == pred_label)
                per_class.setdefault(gt, {"tp": 0, "count": 0})
                per_class[gt]["count"] += 1
                if gt == pred_label:
                    per_class[gt]["tp"] += 1

        # BUG FIX: accuracy must be averaged over *labeled* samples only.
        # The previous code divided by len(preds), deflating accuracy
        # whenever some samples came without a ground-truth label.
        acc = (correct / labeled) if labeled else None

        return {
            "accuracy": acc,  # None if no ground truths provided
            "total": total,
            "predictions": preds,
            "per_class": per_class,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Evaluation error: {e}")
190
-
191
-
192
if __name__ == "__main__":
    # Local dev entry point; equivalent to:
    #   uvicorn app:app --host 0.0.0.0 --port 8000 --reload
    import uvicorn

    host, port = "0.0.0.0", 8000
    uvicorn.run(app, host=host, port=port)
 
 
1
  import os
2
  from typing import List, Optional, Dict
3
 
 
6
  from pydantic import BaseModel
7
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
 
 
9
  # Prefer MODEL_ID, fall back to HF_MODEL_ID, then default
10
  MODEL_ID = (
11
  os.environ.get("MODEL_ID")
 
13
  or "Perth0603/phishing-email-mobilebert"
14
  )
15
 
16
+ app = FastAPI(title="Phishing Text Classifier", version="1.2.0")
17
 
18
 
19
  class PredictPayload(BaseModel):
 
26
 
27
  class LabeledText(BaseModel):
28
  text: str
29
+ label: Optional[str] = None # optional ground truth for quick eval (accepts "0"/"1" or text)
30
 
31
 
32
  class EvalPayload(BaseModel):
 
37
  _model = None
38
  _device = "cpu"
39
 
40
+ # Cached normalized mapping/meta
41
+ _IDX_PHISH = None # model output index that corresponds to PHISH
42
+ _IDX_LEGIT = None # model output index that corresponds to LEGIT
43
+ _NORM_LABELS_BY_IDX = None # normalized labels ordered by model indices
44
+
45
 
46
  def _normalize_label(txt: str) -> str:
47
+ # Normalize common variants and accept "0"/"1" from CSVs
48
+ t = (str(txt) if txt is not None else "").strip().upper()
49
+ if t in ("1", "PHISHING", "PHISH", "SPAM"):
50
  return "PHISH"
51
+ if t in ("0", "LEGIT", "LEGITIMATE", "SAFE", "HAM"):
52
  return "LEGIT"
53
  return t
54
 
55
 
56
  def _load_model():
57
+ global _tokenizer, _model, _device, _IDX_PHISH, _IDX_LEGIT, _NORM_LABELS_BY_IDX
58
+
59
  if _tokenizer is None or _model is None:
60
  _device = "cuda" if torch.cuda.is_available() else "cpu"
61
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
70
  .to(_device)
71
  ).logits
72
 
73
+ # Derive normalized labels per index and cache PHISH/LEGIT indices
74
+ id2label = getattr(_model.config, "id2label", {}) or {}
75
+ num_labels = int(getattr(_model.config, "num_labels", 0) or 0)
76
+ _NORM_LABELS_BY_IDX = [_normalize_label(id2label.get(i, f"LABEL_{i}")) for i in range(num_labels)]
77
+
78
+ # Try to locate PHISH/LEGIT indices explicitly
79
+ try:
80
+ _IDX_PHISH = _NORM_LABELS_BY_IDX.index("PHISH")
81
+ except ValueError:
82
+ _IDX_PHISH = None
83
+ try:
84
+ _IDX_LEGIT = _NORM_LABELS_BY_IDX.index("LEGIT")
85
+ except ValueError:
86
+ _IDX_LEGIT = None
87
+
88
+ # If labels are unknown but binary, you can optionally set a default mapping.
89
+ # Commented out by default to avoid wrong assumptions:
90
+ # if _IDX_PHISH is None and _IDX_LEGIT is None and num_labels == 2:
91
+ # _IDX_LEGIT, _IDX_PHISH = 0, 1 # assumes index 1 = PHISH, index 0 = LEGIT
92
+
93
 
94
  def _predict_texts(texts: List[str]) -> List[Dict]:
95
  _load_model()
 
122
  raw_label = labels_by_idx[idx]
123
  norm_label = _normalize_label(raw_label)
124
 
125
+ # Also expose per-label probabilities (normalized names where possible)
126
+ prob_map = {_normalize_label(labels_by