Perth0603 committed on
Commit
d2fbc7d
·
verified ·
1 Parent(s): ad3f1d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +202 -88
app.py CHANGED
@@ -1,98 +1,187 @@
1
  import os
2
  from typing import List, Optional, Dict
 
3
 
4
  import torch
 
 
5
  from fastapi import FastAPI, HTTPException
6
  from pydantic import BaseModel
7
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
-
9
- # Prefer MODEL_ID, fall back to HF_MODEL_ID, then default
10
- MODEL_ID = (
11
- os.environ.get("MODEL_ID")
12
- or os.environ.get("HF_MODEL_ID")
13
- or "Perth0603/phishing-email-mobilebert"
14
- )
15
-
16
- app = FastAPI(title="Phishing Text Classifier (Model-Authoritative)", version="1.0.0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
 
 
 
 
19
  class PredictPayload(BaseModel):
20
  inputs: str
 
21
 
22
 
23
  class BatchPredictPayload(BaseModel):
24
  inputs: List[str]
 
25
 
26
 
27
  class LabeledText(BaseModel):
28
  text: str
29
- label: Optional[str] = None # optional ground truth for quick eval (accepts text)
30
 
31
 
32
  class EvalPayload(BaseModel):
33
  samples: List[LabeledText]
34
 
35
 
 
 
 
36
  _tokenizer = None
37
  _model = None
38
  _device = "cpu"
39
-
40
- # Cached normalized mapping/meta
41
- _NORM_LABELS_BY_IDX = None # normalized labels ordered by model indices
42
 
43
 
44
- def _normalize_label_text_only(txt: str) -> str:
45
- """
46
- Normalize model label text to PHISH/LEGIT when possible.
47
- If unfamiliar, return the uppercased original token.
48
- """
49
  t = (str(txt) if txt is not None else "").strip().upper()
50
- if t in ("PHISHING", "PHISH", "SPAM"):
51
  return "PHISH"
52
- if t in ("LEGIT", "LEGITIMATE", "SAFE", "HAM"):
53
  return "LEGIT"
54
- # keep other label names as-is (uppercased) so we don't force an incorrect mapping
55
  return t
56
 
57
 
58
  def _load_model():
59
- global _tokenizer, _model, _device, _NORM_LABELS_BY_IDX
 
60
 
61
  if _tokenizer is None or _model is None:
62
  _device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
63
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
64
  _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
65
  _model.to(_device)
66
- _model.eval() # important: disable dropout etc.
 
67
 
68
- # Warm-up (silent)
69
  with torch.no_grad():
70
  _ = _model(
71
  **_tokenizer(["warm up"], return_tensors="pt", padding=True, truncation=True, max_length=512)
72
  .to(_device)
73
  ).logits
74
 
75
- # Read and normalize model labels (by index)
76
- id2label = getattr(_model.config, "id2label", {}) or {}
77
- num_labels = int(getattr(_model.config, "num_labels", 0) or 0)
78
- _NORM_LABELS_BY_IDX = [_normalize_label_text_only(id2label.get(i, f"LABEL_{i}")) for i in range(num_labels)]
79
-
80
-
81
- def _predict_texts(texts: List[str]) -> List[Dict]:
82
- """
83
- Predict and return strictly model-authoritative outputs:
84
- - label: normalized model label (PHISH/LEGIT or other model label uppercased)
85
- - raw_label: original id2label string from model.config
86
- - is_phish: boolean derived from normalized label (True if normalized == "PHISH")
87
- - score: probability of predicted class
88
- - probs: dict of normalized label -> probability (or CLASS_i keys if unknown)
89
- - predicted_index: argmax index
90
- """
91
  _load_model()
92
  if not texts:
93
  return []
94
 
95
- # Tokenize batch
 
 
 
 
 
96
  enc = _tokenizer(
97
  texts,
98
  return_tensors="pt",
@@ -102,95 +191,120 @@ def _predict_texts(texts: List[str]) -> List[Dict]:
102
  )
103
  enc = {k: v.to(_device) for k, v in enc.items()}
104
 
 
105
  with torch.no_grad():
106
  logits = _model(**enc).logits
107
- probs = torch.softmax(logits, dim=-1) # [batch, num_labels]
 
 
108
 
109
- # Use the model’s own mapping
110
- id2label = getattr(_model.config, "id2label", None) or {}
111
- labels_by_idx_raw = [id2label.get(i, f"LABEL_{i}") for i in range(probs.shape[-1])]
112
- # normalized labels where possible
113
- labels_by_idx_norm = [_normalize_label_text_only(lbl) for lbl in labels_by_idx_raw]
114
 
115
  outputs: List[Dict] = []
116
- for i in range(probs.shape[0]):
117
- p = probs[i]
118
- idx = int(torch.argmax(p).item())
119
-
120
- raw_label = labels_by_idx_raw[idx]
121
- norm_label = labels_by_idx_norm[idx] # normalized where possible
122
-
123
- # Build probability map keyed by normalized labels when available,
124
- # otherwise fallback to CLASS_i keys to avoid collision
125
- prob_map: Dict[str, float] = {}
126
- for j, lbl_norm in enumerate(labels_by_idx_norm):
127
- key = lbl_norm if lbl_norm in ("PHISH", "LEGIT") else f"CLASS_{j}"
128
- prob_map[key] = float(p[j].item())
129
-
130
- outputs.append(
131
- {
132
- "label": norm_label, # authoritative label (model-driven, normalized)
133
- "raw_label": raw_label, # original model id2label value
134
- "is_phish": True if norm_label == "PHISH" else False,
135
- "score": float(p[idx].item()), # probability of predicted class
136
- "probs": prob_map, # per-class probabilities (keys normalized or CLASS_i)
137
- "predicted_index": idx,
138
- }
139
- )
 
 
 
 
 
140
 
141
  return outputs
142
 
143
 
 
 
 
 
144
  @app.get("/")
145
  def root():
 
146
  _load_model()
147
  return {
148
  "status": "ok",
149
  "model": MODEL_ID,
150
- "note": "This service returns predictions exactly as the model decides (label derived from model.config.id2label). Frontend should use `label` or `is_phish` as authority."
 
 
151
  }
152
 
153
 
154
  @app.get("/debug/labels")
155
  def debug_labels():
 
156
  _load_model()
 
157
  return {
 
 
158
  "id2label": getattr(_model.config, "id2label", {}),
159
  "label2id": getattr(_model.config, "label2id", {}),
160
  "num_labels": int(getattr(_model.config, "num_labels", 0)),
161
  "device": _device,
162
- "norm_labels_by_idx": _NORM_LABELS_BY_IDX,
163
  }
164
 
165
 
 
 
 
 
 
 
 
 
 
 
 
166
  @app.post("/predict")
167
  def predict(payload: PredictPayload):
 
168
  try:
169
- res = _predict_texts([payload.inputs])
170
  return res[0]
171
  except Exception as e:
172
- raise HTTPException(status_code=500, detail=f"Prediction error: {e}")
173
 
174
 
175
  @app.post("/predict-batch")
176
  def predict_batch(payload: BatchPredictPayload):
 
177
  try:
178
- return _predict_texts(payload.inputs)
179
  except Exception as e:
180
- raise HTTPException(status_code=500, detail=f"Batch prediction error: {e}")
181
 
182
 
183
  @app.post("/evaluate")
184
  def evaluate(payload: EvalPayload):
185
- """
186
- Quick on-the-spot test with provided labeled samples.
187
- The provided labels are interpreted as text labels (PHISH/LEGIT/etc.) — evaluation is done
188
- by comparing normalized GT text to model's normalized prediction (no 0/1 dataset mapping applied).
189
- """
190
  try:
191
  texts = [s.text for s in payload.samples]
192
- gts = [(_normalize_label_text_only(s.label) if s.label is not None else None) for s in payload.samples]
193
- preds = _predict_texts(texts)
194
 
195
  total = len(preds)
196
  correct = 0
@@ -209,16 +323,16 @@ def evaluate(payload: EvalPayload):
209
  acc = (correct / sum(1 for gt in gts if gt is not None)) if has_gts else None
210
 
211
  return {
212
- "accuracy": acc,
213
  "total": total,
 
214
  "predictions": preds,
215
  "per_class": per_class,
216
  }
217
  except Exception as e:
218
- raise HTTPException(status_code=500, detail=f"Evaluation error: {e}")
219
 
220
 
221
  if __name__ == "__main__":
222
- # Run: uvicorn app:app --host 0.0.0.0 --port 8000 --reload
223
  import uvicorn
224
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
  import os
2
  from typing import List, Optional, Dict
3
+ import re
4
 
5
  import torch
6
+ import torch.nn.functional as F
7
+ import nltk
8
  from fastapi import FastAPI, HTTPException
9
  from pydantic import BaseModel
10
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
11
+ from nltk.corpus import stopwords
12
+ from nltk.stem import PorterStemmer, WordNetLemmatizer
13
+ from nltk.tokenize import word_tokenize
14
+ from textblob import TextBlob
15
+
16
# Download required NLTK corpora on first run.
# BUG FIX: the previous guard checked only 'tokenizers/punkt', so when punkt
# was already installed the except branch never ran and 'stopwords'/'wordnet'
# were never fetched. Each resource is now verified (and fetched) independently.
for _pkg, _path in (
    ("punkt", "tokenizers/punkt"),
    ("stopwords", "corpora/stopwords"),
    ("wordnet", "corpora/wordnet"),
):
    try:
        nltk.data.find(_path)
    except LookupError:
        nltk.download(_pkg)
23
+
24
# Hugging Face Hub id of the fine-tuned MobileBERT phishing classifier.
MODEL_ID = "Perth0603/phishing-email-mobilebert"

app = FastAPI(title="Phishing Text Classifier with Preprocessing", version="1.0.0")

# Temperature for softening predictions (1.0 = normal, >1.0 = softer, <1.0 = sharper)
TEMPERATURE = 2.5  # Adjust this value (try 1.5 to 3.0)
30
+
31
+
32
# ============================================================================
# TEXT PREPROCESSING CLASS
# ============================================================================
class TextPreprocessor:
    """NLP preprocessing for analysis and feature extraction.

    Provides tokenization, stop-word removal, stemming, lemmatization, and a
    lightweight sentiment / phishing-indicator analysis (TextBlob + regex).
    """

    # Patterns are compiled once at class-definition time so repeated calls to
    # sentiment_analysis() don't re-run re.search's pattern compile/cache
    # lookup on every request (they were inline literals before).
    _URGENT_RE = re.compile(r'\b(urgent|immediate|act now|verify|confirm|update|click|verify account)\b', re.IGNORECASE)
    _THREAT_RE = re.compile(r'\b(suspend|limited|expire|locked|disabled|restricted)\b', re.IGNORECASE)
    _URL_RE = re.compile(r'http\S+|www\S+')
    _URGENCY_RE = re.compile(r'\b(urgent|immediate|act now)\b', re.IGNORECASE)

    def __init__(self):
        self.stemmer = PorterStemmer()
        self.lemmatizer = WordNetLemmatizer()
        self.stop_words = set(stopwords.words('english'))

    def tokenize(self, text: str) -> List[str]:
        """Break text into lower-cased word tokens."""
        return word_tokenize(text.lower())

    def remove_stopwords(self, tokens: List[str]) -> List[str]:
        """Keep only alphanumeric tokens that are not English stop words."""
        return [token for token in tokens if token.isalnum() and token not in self.stop_words]

    def stem(self, tokens: List[str]) -> List[str]:
        """Reduce tokens to their Porter stems."""
        return [self.stemmer.stem(token) for token in tokens]

    def lemmatize(self, tokens: List[str]) -> List[str]:
        """Reduce tokens to their WordNet lemmas."""
        return [self.lemmatizer.lemmatize(token) for token in tokens]

    def sentiment_analysis(self, text: str) -> Dict:
        """Analyze sentiment (TextBlob polarity/subjectivity) plus simple
        regex-based phishing indicators (urgency, threats, URLs)."""
        blob = TextBlob(text)
        polarity = blob.sentiment.polarity
        subjectivity = blob.sentiment.subjectivity

        phishing_indicators = {
            "urgent_words": bool(self._URGENT_RE.search(text)),
            "threat_words": bool(self._THREAT_RE.search(text)),
            "suspicious_urls": bool(self._URL_RE.search(text)),
            "urgency_level": "HIGH" if self._URGENCY_RE.search(text) else "LOW",
        }

        return {
            "polarity": round(polarity, 4),
            "subjectivity": round(subjectivity, 4),
            "sentiment": "positive" if polarity > 0.1 else "negative" if polarity < -0.1 else "neutral",
            "is_persuasive": subjectivity > 0.5,
            "phishing_indicators": phishing_indicators,
        }

    def preprocess(self, text: str) -> Dict:
        """Run the full pipeline and return all intermediate artifacts."""
        tokens = self.tokenize(text)
        tokens_no_stop = self.remove_stopwords(tokens)
        stemmed = self.stem(tokens_no_stop)
        lemmatized = self.lemmatize(tokens_no_stop)
        sentiment = self.sentiment_analysis(text)

        return {
            "original_text": text,
            "tokens": tokens,
            "tokens_without_stopwords": tokens_no_stop,
            "stemmed_tokens": stemmed,
            "lemmatized_tokens": lemmatized,
            "sentiment": sentiment,
            "token_count": len(tokens_no_stop),
        }
97
 
98
 
99
# ============================================================================
# PYDANTIC MODELS
# ============================================================================
class PredictPayload(BaseModel):
    # Request body for /predict and /debug/preprocessing.
    inputs: str  # the text to classify
    include_preprocessing: bool = True  # attach NLP preprocessing artifacts to the response
105
 
106
 
107
class BatchPredictPayload(BaseModel):
    # Request body for /predict-batch.
    inputs: List[str]  # texts classified together in one batch
    include_preprocessing: bool = True  # attach NLP preprocessing artifacts per text
110
 
111
 
112
class LabeledText(BaseModel):
    # One evaluation sample: text plus optional ground-truth label.
    text: str
    label: Optional[str] = None  # free-form label text (normalized before comparison)
115
 
116
 
117
class EvalPayload(BaseModel):
    # Request body for /evaluate.
    samples: List[LabeledText]
119
 
120
 
121
# ============================================================================
# GLOBAL VARIABLES
# ============================================================================
# Lazily-initialized singletons; populated by _load_model() on first request.
_tokenizer = None  # transformers tokenizer for MODEL_ID
_model = None  # sequence-classification model (moved to _device)
_device = "cpu"  # becomes "cuda" in _load_model() when a GPU is available
_preprocessor = None  # TextPreprocessor instance
 
 
128
 
129
 
130
+ # ============================================================================
131
+ # HELPER FUNCTIONS
132
+ # ============================================================================
133
+ def _normalize_label(txt: str) -> str:
134
+ """Normalize label text"""
135
  t = (str(txt) if txt is not None else "").strip().upper()
136
+ if t in ("PHISHING", "PHISH", "SPAM", "1"):
137
  return "PHISH"
138
+ if t in ("LEGIT", "LEGITIMATE", "SAFE", "HAM", "0"):
139
  return "LEGIT"
 
140
  return t
141
 
142
 
143
  def _load_model():
144
+ """Load model, tokenizer, and preprocessor"""
145
+ global _tokenizer, _model, _device, _preprocessor
146
 
147
  if _tokenizer is None or _model is None:
148
  _device = "cuda" if torch.cuda.is_available() else "cpu"
149
+ print(f"\n{'='*60}")
150
+ print(f"Loading model: {MODEL_ID}")
151
+ print(f"Device: {_device}")
152
+ print(f"Temperature scaling: {TEMPERATURE}")
153
+ print(f"{'='*60}\n")
154
+
155
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
156
  _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
157
  _model.to(_device)
158
+ _model.eval()
159
+ _preprocessor = TextPreprocessor()
160
 
161
+ # Warm-up
162
  with torch.no_grad():
163
  _ = _model(
164
  **_tokenizer(["warm up"], return_tensors="pt", padding=True, truncation=True, max_length=512)
165
  .to(_device)
166
  ).logits
167
 
168
+ id2label = getattr(_model.config, "id2label", {})
169
+ print(f"Model labels: {id2label}")
170
+ print(f"{'='*60}\n")
171
+
172
+
173
+ def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
174
+ """Predict with temperature-scaled probabilities"""
 
 
 
 
 
 
 
 
 
175
  _load_model()
176
  if not texts:
177
  return []
178
 
179
+ # Get preprocessing info
180
+ preprocessing_info = None
181
+ if include_preprocessing:
182
+ preprocessing_info = [_preprocessor.preprocess(text) for text in texts]
183
+
184
+ # Tokenize
185
  enc = _tokenizer(
186
  texts,
187
  return_tensors="pt",
 
191
  )
192
  enc = {k: v.to(_device) for k, v in enc.items()}
193
 
194
+ # Predict with temperature scaling
195
  with torch.no_grad():
196
  logits = _model(**enc).logits
197
+ # Apply temperature scaling to soften probabilities
198
+ scaled_logits = logits / TEMPERATURE
199
+ probs = F.softmax(scaled_logits, dim=-1)
200
 
201
+ # Get labels from model config
202
+ id2label = getattr(_model.config, "id2label", {0: "LEGIT", 1: "PHISH"})
 
 
 
203
 
204
  outputs: List[Dict] = []
205
+ for text_idx in range(probs.shape[0]):
206
+ p = probs[text_idx]
207
+
208
+ # Get prediction
209
+ predicted_idx = int(torch.argmax(p).item())
210
+ predicted_label_raw = id2label.get(predicted_idx, f"CLASS_{predicted_idx}")
211
+ predicted_label_norm = _normalize_label(predicted_label_raw)
212
+ predicted_prob = float(p[predicted_idx].item())
213
+
214
+ # Build probability breakdown
215
+ prob_breakdown = {}
216
+ for i in range(len(p)):
217
+ label = _normalize_label(id2label.get(i, f"CLASS_{i}"))
218
+ prob_breakdown[label] = round(float(p[i].item()), 4)
219
+
220
+ output = {
221
+ "text": texts[text_idx][:100] + "..." if len(texts[text_idx]) > 100 else texts[text_idx],
222
+ "label": predicted_label_norm,
223
+ "raw_label": predicted_label_raw,
224
+ "is_phish": predicted_label_norm == "PHISH",
225
+ "confidence": round(predicted_prob * 100, 2),
226
+ "score": round(predicted_prob, 4),
227
+ "probs": prob_breakdown,
228
+ }
229
+
230
+ if include_preprocessing and preprocessing_info:
231
+ output["preprocessing"] = preprocessing_info[text_idx]
232
+
233
+ outputs.append(output)
234
 
235
  return outputs
236
 
237
 
238
+ # ============================================================================
239
+ # API ENDPOINTS
240
+ # ============================================================================
241
+
242
@app.get("/")
def root():
    """Health/info endpoint: reports model id, device, and temperature."""
    _load_model()
    info = {
        "status": "ok",
        "model": MODEL_ID,
        "device": _device,
        "temperature": TEMPERATURE,
        "note": "Using temperature scaling to calibrate probabilities",
    }
    return info
253
 
254
 
255
@app.get("/debug/labels")
def debug_labels():
    """Expose the loaded model's label mapping and runtime configuration."""
    _load_model()

    cfg = _model.config
    return {
        "status": "ok",
        "model_id": MODEL_ID,
        "id2label": getattr(cfg, "id2label", {}),
        "label2id": getattr(cfg, "label2id", {}),
        "num_labels": int(getattr(cfg, "num_labels", 0)),
        "device": _device,
        "temperature": TEMPERATURE,
    }
269
 
270
 
271
@app.post("/debug/preprocessing")
def debug_preprocessing(payload: PredictPayload):
    """Return the NLP preprocessing artifacts for a single input text."""
    try:
        _load_model()
        return _preprocessor.preprocess(payload.inputs)
    except Exception as e:
        # Boundary handler: surface any failure as an HTTP 500.
        raise HTTPException(status_code=500, detail=str(e))
280
+
281
+
282
@app.post("/predict")
def predict(payload: PredictPayload):
    """Classify a single text; returns the one prediction dict."""
    try:
        predictions = _predict_texts(
            [payload.inputs],
            include_preprocessing=payload.include_preprocessing,
        )
        return predictions[0]
    except Exception as e:
        # Boundary handler: surface any failure as an HTTP 500.
        raise HTTPException(status_code=500, detail=str(e))
290
 
291
 
292
@app.post("/predict-batch")
def predict_batch(payload: BatchPredictPayload):
    """Classify a list of texts in one batched forward pass."""
    try:
        results = _predict_texts(
            payload.inputs,
            include_preprocessing=payload.include_preprocessing,
        )
        return results
    except Exception as e:
        # Boundary handler: surface any failure as an HTTP 500.
        raise HTTPException(status_code=500, detail=str(e))
299
 
300
 
301
  @app.post("/evaluate")
302
  def evaluate(payload: EvalPayload):
303
+ """Evaluate on labeled samples"""
 
 
 
 
304
  try:
305
  texts = [s.text for s in payload.samples]
306
+ gts = [(_normalize_label(s.label) if s.label is not None else None) for s in payload.samples]
307
+ preds = _predict_texts(texts, include_preprocessing=False)
308
 
309
  total = len(preds)
310
  correct = 0
 
323
  acc = (correct / sum(1 for gt in gts if gt is not None)) if has_gts else None
324
 
325
  return {
326
+ "accuracy": round(acc, 4) if acc else None,
327
  "total": total,
328
+ "correct": correct,
329
  "predictions": preds,
330
  "per_class": per_class,
331
  }
332
  except Exception as e:
333
+ raise HTTPException(status_code=500, detail=str(e))
334
 
335
 
336
if __name__ == "__main__":
    # Local dev entry point; in production run:
    #   uvicorn app:app --host 0.0.0.0 --port 8000
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)