Perth0603 committed on
Commit
9e472d3
·
verified ·
1 Parent(s): 7690e39

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +166 -18
app.py CHANGED
@@ -1,31 +1,116 @@
1
- from fastapi import FastAPI
 
 
 
 
 
2
  from pydantic import BaseModel
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
- import torch
5
- import os
6
 
7
 
8
- MODEL_ID = os.environ.get("MODEL_ID", "dima806/phishing-email-detection")
9
- app = FastAPI(title="Phishing Text Classifier", version="1.0.0")
 
 
 
 
 
 
10
 
11
 
12
  class PredictPayload(BaseModel):
13
  inputs: str
14
 
15
 
16
- # Lazy singletons for model/tokenizer
 
 
 
 
 
 
 
 
 
 
 
 
17
  _tokenizer = None
18
  _model = None
 
 
 
 
 
 
 
 
 
 
 
19
 
20
 
21
  def _load_model():
22
- global _tokenizer, _model
23
  if _tokenizer is None or _model is None:
 
24
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
25
  _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
 
 
 
26
  # Warm-up
27
  with torch.no_grad():
28
- _ = _model(**_tokenizer(["warm up"], return_tensors="pt")).logits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
 
31
  @app.get("/")
@@ -33,15 +118,78 @@ def root():
33
  return {"status": "ok", "model": MODEL_ID}
34
 
35
 
 
 
 
 
 
 
 
 
 
 
 
36
  @app.post("/predict")
37
  def predict(payload: PredictPayload):
38
- _load_model()
39
- with torch.no_grad():
40
- logits = _model(**_tokenizer([payload.inputs], return_tensors="pt")).logits
41
- probs = torch.softmax(logits, dim=-1)[0]
42
- score, idx = torch.max(probs, dim=0)
43
-
44
- # Map common ids to labels (kept generic; your config also has these)
45
- id2label = {0: "LEGIT", 1: "PHISH"}
46
- label = id2label.get(int(idx), str(int(idx)))
47
- return {"label": label, "score": float(score)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ from typing import List, Optional, Dict
4
+
5
+ import torch
6
+ from fastapi import FastAPI, HTTPException
7
  from pydantic import BaseModel
8
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
 
9
 
10
 
11
# Prefer MODEL_ID, fall back to HF_MODEL_ID, then default.
# MODEL_ID is the Hugging Face repo id the tokenizer/model are loaded from
# in _load_model(); first non-empty env var wins, else the bundled default.
MODEL_ID = (
    os.environ.get("MODEL_ID")
    or os.environ.get("HF_MODEL_ID")
    or "Perth0603/phishing-email-mobilebert"
)

# FastAPI application object; endpoints below attach to it via decorators.
app = FastAPI(title="Phishing Text Classifier", version="1.1.0")
 
20
 
21
class PredictPayload(BaseModel):
    """Request body for /predict: a single text to classify."""

    # The raw text to classify.
    inputs: str
23
 
24
 
25
class BatchPredictPayload(BaseModel):
    """Request body for /predict-batch: several texts classified in one forward pass."""

    # Texts to classify; order of results matches this order.
    inputs: List[str]
27
+
28
+
29
class LabeledText(BaseModel):
    """One sample for /evaluate: a text plus an optional ground-truth label."""

    # Text to classify.
    text: str
    # Optional ground truth for quick eval; normalized via _normalize_label
    # before comparison (e.g. "spam" -> "PHISH").
    label: Optional[str] = None  # optional ground truth for quick eval
32
+
33
+
34
class EvalPayload(BaseModel):
    """Request body for /evaluate: a list of (text, optional label) samples."""

    # Samples to run through the classifier.
    samples: List[LabeledText]
36
+
37
+
38
# Lazy singletons: both stay None until the first request triggers _load_model().
_tokenizer = None
_model = None
# Inference device; _load_model() switches this to "cuda" when available.
_device = "cpu"
41
+
42
+
43
+ def _normalize_label(txt: str) -> str:
44
+ # Optional: normalize common variants for simpler downstream use
45
+ t = (txt or "").strip().upper()
46
+ if t in ("PHISHING", "PHISH", "SPAM"):
47
+ return "PHISH"
48
+ if t in ("LEGIT", "LEGITIMATE", "SAFE", "HAM"):
49
+ return "LEGIT"
50
+ return t
51
 
52
 
53
def _load_model():
    """Lazily load tokenizer + model once, move them to the best device, warm up.

    Idempotent: returns immediately when both singletons are already populated.
    """
    global _tokenizer, _model, _device
    if _tokenizer is not None and _model is not None:
        return

    _device = "cuda" if torch.cuda.is_available() else "cpu"
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
    _model.to(_device)
    _model.eval()  # important: disable dropout etc.

    # Warm-up forward pass so the first real request doesn't pay JIT/cache cost.
    warm_batch = _tokenizer(
        ["warm up"], return_tensors="pt", padding=True, truncation=True, max_length=512
    ).to(_device)
    with torch.no_grad():
        _ = _model(**warm_batch).logits
68
+
69
+
70
def _predict_texts(texts: List[str]) -> List[Dict]:
    """Classify a batch of texts and return one result dict per input.

    Each dict carries the normalized label, the raw config label, the max-class
    probability, the full label->probability map, and the argmax index.
    Returns [] for an empty input batch.
    """
    _load_model()
    if not texts:
        return []

    # Tokenize the whole batch and move tensors to the active device.
    batch = _tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    batch = {name: tensor.to(_device) for name, tensor in batch.items()}

    with torch.no_grad():
        probs = torch.softmax(_model(**batch).logits, dim=-1)  # [batch, num_labels]

    # Resolve class indices through the model's own config mapping.
    id2label = getattr(_model.config, "id2label", None) or {}
    label_names = [id2label.get(j, f"LABEL_{j}") for j in range(probs.shape[-1])]

    results: List[Dict] = []
    for row in probs:
        best = int(torch.argmax(row).item())
        raw_label = label_names[best]
        # Expose every class probability, keyed by normalized label name.
        prob_map = {
            _normalize_label(label_names[j]): float(row[j].item())
            for j in range(len(label_names))
        }
        results.append(
            {
                "label": _normalize_label(raw_label),  # normalized (e.g., PHISH/LEGIT)
                "raw_label": raw_label,                # from model.config.id2label
                "score": float(row[best].item()),      # max class probability
                "probs": prob_map,                     # dict of label -> probability
                "predicted_index": best,
            }
        )
    return results
114
 
115
 
116
  @app.get("/")
 
118
  return {"status": "ok", "model": MODEL_ID}
119
 
120
 
121
@app.get("/debug/labels")
def debug_labels():
    """Expose the loaded model's label mappings and the active inference device."""
    _load_model()
    cfg = _model.config
    return {
        "id2label": getattr(cfg, "id2label", {}),
        "label2id": getattr(cfg, "label2id", {}),
        "num_labels": int(getattr(cfg, "num_labels", 0)),
        "device": _device,
    }
130
+
131
+
132
@app.post("/predict")
def predict(payload: PredictPayload):
    """Classify one text; returns the single result dict from _predict_texts."""
    try:
        return _predict_texts([payload.inputs])[0]
    except Exception as e:
        # Service boundary: surface any failure as an HTTP 500 with context.
        raise HTTPException(status_code=500, detail=f"Prediction error: {e}")
139
+
140
+
141
@app.post("/predict-batch")
def predict_batch(payload: BatchPredictPayload):
    """Classify several texts in one model forward pass; results keep input order."""
    try:
        return _predict_texts(payload.inputs)
    except Exception as e:
        # Service boundary: surface any failure as an HTTP 500 with context.
        raise HTTPException(status_code=500, detail=f"Batch prediction error: {e}")
147
+
148
+
149
@app.post("/evaluate")
def evaluate(payload: EvalPayload):
    """
    Quick on-the-spot test with provided labeled samples.
    Request body:
        {
          "samples": [
            {"text": "Your parcel is held...", "label": "PHISH"},
            {"text": "Lunch at 12?", "label": "LEGIT"}
          ]
        }
    Returns accuracy (computed over labeled samples only), all predictions,
    and per-class tp/count tallies. "accuracy" is None when no sample carries
    a ground-truth label.
    """
    try:
        texts = [s.text for s in payload.samples]
        # Normalize ground truths so "spam"/"ham" etc. compare against model output.
        gts = [(_normalize_label(s.label) if s.label else None) for s in payload.samples]
        preds = _predict_texts(texts)

        total = len(preds)
        correct = 0
        # FIX: accuracy previously divided by ALL samples, so unlabeled samples
        # silently diluted the score. Count labeled samples for the denominator.
        labeled = 0
        per_class: Dict[str, Dict[str, int]] = {}

        for gt, pr in zip(gts, preds):
            if gt is None:
                # Unlabeled samples still appear in "predictions" but do not
                # participate in accuracy or per-class stats.
                continue
            labeled += 1
            hit = int(gt == pr["label"])
            correct += hit
            stats = per_class.setdefault(gt, {"tp": 0, "count": 0})
            stats["count"] += 1
            stats["tp"] += hit

        acc = (correct / labeled) if labeled else None

        return {
            "accuracy": acc,  # None if no ground truths provided
            "total": total,
            "predictions": preds,
            "per_class": per_class,
        }
    except Exception as e:
        # Service boundary: surface any failure as an HTTP 500 with context.
        raise HTTPException(status_code=500, detail=f"Evaluation error: {e}")
190
+
191
+
192
+ if __name__ == "__main__":
193
+ # Run: uvicorn app:app --host 0.0.0.0 --port 8000 --reload
194
+ import uvicorn
195
+ uvicorn.run(app, host="0.0.0.0", port=8000)