Perth0603 committed on
Commit
8cfc19f
·
verified ·
1 Parent(s): dfa9403

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +177 -95
app.py CHANGED
@@ -25,9 +25,9 @@ MODEL_ID = "Perth0603/phishing-email-mobilebert"
25
 
26
  app = FastAPI(title="Phishing Text Classifier with Preprocessing", version="1.0.0")
27
 
28
- # Uncertainty estimation settings
29
- MC_SAMPLES = 10 # Number of forward passes with dropout (more = smoother, slower)
30
- DROPOUT_RATE = 0.1 # Dropout rate for uncertainty (0.05-0.15)
31
 
32
 
33
  # ============================================================================
@@ -57,34 +57,122 @@ class TextPreprocessor:
57
  """Reduce tokens to lemmas"""
58
  return [self.lemmatizer.lemmatize(token) for token in tokens]
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  def sentiment_analysis(self, text: str) -> Dict:
61
- """Analyze sentiment and phishing indicators"""
62
  blob = TextBlob(text)
63
  polarity = blob.sentiment.polarity
64
  subjectivity = blob.sentiment.subjectivity
65
 
66
- phishing_indicators = {
67
- "urgent_words": bool(re.search(r'\b(urgent|immediate|act now|verify|confirm|update|click|verify account)\b', text, re.IGNORECASE)),
68
- "threat_words": bool(re.search(r'\b(suspend|limited|expire|locked|disabled|restricted)\b', text, re.IGNORECASE)),
69
- "suspicious_urls": bool(re.search(r'http\S+|www\S+', text)),
70
- "urgency_level": "HIGH" if re.search(r'\b(urgent|immediate|act now)\b', text, re.IGNORECASE) else "LOW"
71
- }
72
-
73
  return {
74
  "polarity": round(polarity, 4),
75
  "subjectivity": round(subjectivity, 4),
76
  "sentiment": "positive" if polarity > 0.1 else "negative" if polarity < -0.1 else "neutral",
77
  "is_persuasive": subjectivity > 0.5,
78
- "phishing_indicators": phishing_indicators
79
  }
80
 
81
  def preprocess(self, text: str) -> Dict:
82
- """Preprocessing for analysis"""
83
  tokens = self.tokenize(text)
84
  tokens_no_stop = self.remove_stopwords(tokens)
85
  stemmed = self.stem(tokens_no_stop)
86
  lemmatized = self.lemmatize(tokens_no_stop)
87
  sentiment = self.sentiment_analysis(text)
 
88
 
89
  return {
90
  "original_text": text,
@@ -93,6 +181,7 @@ class TextPreprocessor:
93
  "stemmed_tokens": stemmed,
94
  "lemmatized_tokens": lemmatized,
95
  "sentiment": sentiment,
 
96
  "token_count": len(tokens_no_stop)
97
  }
98
 
@@ -103,13 +192,11 @@ class TextPreprocessor:
103
  class PredictPayload(BaseModel):
104
  inputs: str
105
  include_preprocessing: bool = True
106
- use_uncertainty: bool = True # Enable uncertainty estimation
107
 
108
 
109
  class BatchPredictPayload(BaseModel):
110
  inputs: List[str]
111
  include_preprocessing: bool = True
112
- use_uncertainty: bool = True
113
 
114
 
115
  class LabeledText(BaseModel):
@@ -143,11 +230,47 @@ def _normalize_label(txt: str) -> str:
143
  return t
144
 
145
 
146
- def _enable_dropout(model):
147
- """Enable dropout layers during inference for uncertainty estimation"""
148
- for module in model.modules():
149
- if module.__class__.__name__.startswith('Dropout'):
150
- module.train()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
 
153
  def _load_model():
@@ -159,7 +282,7 @@ def _load_model():
159
  print(f"\n{'='*60}")
160
  print(f"Loading model: {MODEL_ID}")
161
  print(f"Device: {_device}")
162
- print(f"MC Dropout samples: {MC_SAMPLES}")
163
  print(f"{'='*60}\n")
164
 
165
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
@@ -180,49 +303,14 @@ def _load_model():
180
  print(f"{'='*60}\n")
181
 
182
 
183
- def _predict_with_uncertainty(enc: Dict, use_uncertainty: bool = True) -> tuple:
184
- """
185
- Predict with uncertainty estimation using MC Dropout.
186
- Returns: (mean_probs, std_probs)
187
- """
188
- if not use_uncertainty:
189
- # Standard prediction
190
- with torch.no_grad():
191
- logits = _model(**enc).logits
192
- probs = F.softmax(logits, dim=-1)
193
- return probs, torch.zeros_like(probs)
194
-
195
- # Monte Carlo Dropout: multiple forward passes with dropout enabled
196
- prob_samples = []
197
-
198
- _enable_dropout(_model) # Enable dropout during inference
199
-
200
- with torch.no_grad():
201
- for _ in range(MC_SAMPLES):
202
- logits = _model(**enc).logits
203
- probs = F.softmax(logits, dim=-1)
204
- prob_samples.append(probs)
205
-
206
- _model.eval() # Restore eval mode
207
-
208
- # Stack all samples and compute mean and std
209
- prob_samples = torch.stack(prob_samples) # [MC_SAMPLES, batch_size, num_classes]
210
- mean_probs = prob_samples.mean(dim=0) # Average predictions
211
- std_probs = prob_samples.std(dim=0) # Uncertainty (variance)
212
-
213
- return mean_probs, std_probs
214
-
215
-
216
- def _predict_texts(texts: List[str], include_preprocessing: bool = True, use_uncertainty: bool = True) -> List[Dict]:
217
- """Predict with uncertainty-aware confidence scores"""
218
  _load_model()
219
  if not texts:
220
  return []
221
 
222
- # Get preprocessing info
223
- preprocessing_info = None
224
- if include_preprocessing:
225
- preprocessing_info = [_preprocessor.preprocess(text) for text in texts]
226
 
227
  # Tokenize
228
  enc = _tokenizer(
@@ -234,36 +322,39 @@ def _predict_texts(texts: List[str], include_preprocessing: bool = True, use_unc
234
  )
235
  enc = {k: v.to(_device) for k, v in enc.items()}
236
 
237
- # Predict with uncertainty
238
- mean_probs, std_probs = _predict_with_uncertainty(enc, use_uncertainty)
 
 
239
 
240
  # Get labels from model config
241
  id2label = getattr(_model.config, "id2label", {0: "LEGIT", 1: "PHISH"})
242
 
243
  outputs: List[Dict] = []
244
- for text_idx in range(mean_probs.shape[0]):
245
- p_mean = mean_probs[text_idx]
246
- p_std = std_probs[text_idx]
 
247
 
248
  # Get prediction
249
- predicted_idx = int(torch.argmax(p_mean).item())
250
  predicted_label_raw = id2label.get(predicted_idx, f"CLASS_{predicted_idx}")
251
  predicted_label_norm = _normalize_label(predicted_label_raw)
252
- predicted_prob = float(p_mean[predicted_idx].item())
253
- predicted_std = float(p_std[predicted_idx].item())
254
-
255
- # Calculate uncertainty-adjusted confidence
256
- # Higher uncertainty = lower reported confidence
257
- uncertainty_penalty = predicted_std * 2.0 # Amplify uncertainty effect
258
- adjusted_confidence = max(0.5, predicted_prob - uncertainty_penalty) # Don't go below 50%
259
 
260
- # Build probability breakdown
261
  prob_breakdown = {}
262
- uncertainty_breakdown = {}
263
- for i in range(len(p_mean)):
264
  label = _normalize_label(id2label.get(i, f"CLASS_{i}"))
265
- prob_breakdown[label] = round(float(p_mean[i].item()), 4)
266
- uncertainty_breakdown[label] = round(float(p_std[i].item()), 4)
 
 
267
 
268
  output = {
269
  "text": texts[text_idx][:100] + "..." if len(texts[text_idx]) > 100 else texts[text_idx],
@@ -272,14 +363,12 @@ def _predict_texts(texts: List[str], include_preprocessing: bool = True, use_unc
272
  "is_phish": predicted_label_norm == "PHISH",
273
  "confidence": round(adjusted_confidence * 100, 2),
274
  "score": round(adjusted_confidence, 4),
275
- "uncertainty": round(predicted_std * 100, 2), # Uncertainty as percentage
276
  "probs": prob_breakdown,
277
- "uncertainty_scores": uncertainty_breakdown,
278
- "raw_confidence": round(predicted_prob * 100, 2), # Original model confidence
279
  }
280
 
281
- if include_preprocessing and preprocessing_info:
282
- output["preprocessing"] = preprocessing_info[text_idx]
283
 
284
  outputs.append(output)
285
 
@@ -298,11 +387,8 @@ def root():
298
  "status": "ok",
299
  "model": MODEL_ID,
300
  "device": _device,
301
- "uncertainty_estimation": {
302
- "enabled": True,
303
- "mc_samples": MC_SAMPLES,
304
- "dropout_rate": DROPOUT_RATE
305
- }
306
  }
307
 
308
 
@@ -336,9 +422,7 @@ def debug_preprocessing(payload: PredictPayload):
336
  def predict(payload: PredictPayload):
337
  """Single prediction"""
338
  try:
339
- res = _predict_texts([payload.inputs],
340
- include_preprocessing=payload.include_preprocessing,
341
- use_uncertainty=payload.use_uncertainty)
342
  return res[0]
343
  except Exception as e:
344
  raise HTTPException(status_code=500, detail=str(e))
@@ -348,9 +432,7 @@ def predict(payload: PredictPayload):
348
  def predict_batch(payload: BatchPredictPayload):
349
  """Batch predictions"""
350
  try:
351
- return _predict_texts(payload.inputs,
352
- include_preprocessing=payload.include_preprocessing,
353
- use_uncertainty=payload.use_uncertainty)
354
  except Exception as e:
355
  raise HTTPException(status_code=500, detail=str(e))
356
 
@@ -361,7 +443,7 @@ def evaluate(payload: EvalPayload):
361
  try:
362
  texts = [s.text for s in payload.samples]
363
  gts = [(_normalize_label(s.label) if s.label is not None else None) for s in payload.samples]
364
- preds = _predict_texts(texts, include_preprocessing=False, use_uncertainty=True)
365
 
366
  total = len(preds)
367
  correct = 0
 
25
 
26
  app = FastAPI(title="Phishing Text Classifier with Preprocessing", version="1.0.0")
27
 
28
+ # Confidence adjustment settings
29
+ BASE_CONFIDENCE_MIN = 0.55 # Minimum confidence (55%)
30
+ BASE_CONFIDENCE_MAX = 0.85 # Maximum confidence (85%)
31
 
32
 
33
  # ============================================================================
 
57
  """Reduce tokens to lemmas"""
58
  return [self.lemmatizer.lemmatize(token) for token in tokens]
59
 
60
+ def analyze_phishing_indicators(self, text: str) -> Dict:
61
+ """Comprehensive phishing indicator analysis"""
62
+ indicators = {
63
+ "urgent_words": bool(re.search(
64
+ r'\b(urgent|immediately|immediate|act now|right now|asap|verify now|'
65
+ r'confirm now|update now|click now|respond now|expire soon|expiring|'
66
+ r'time sensitive|limited time|hurry|quick|fast|today only)\b',
67
+ text, re.IGNORECASE
68
+ )),
69
+ "threat_words": bool(re.search(
70
+ r'\b(suspend|suspended|lock|locked|block|blocked|disable|disabled|'
71
+ r'restrict|restricted|terminate|terminated|cancel|cancelled|close|closed|'
72
+ r'freeze|frozen|ban|banned|deactivate|deactivated|remove|removed)\b',
73
+ text, re.IGNORECASE
74
+ )),
75
+ "action_words": bool(re.search(
76
+ r'\b(click here|click now|click below|click this|verify|confirm|update|'
77
+ r'download|install|open attachment|validate|authenticate|reset password|'
78
+ r'change password|provide|submit|enter|fill out|complete)\b',
79
+ text, re.IGNORECASE
80
+ )),
81
+ "financial_words": bool(re.search(
82
+ r'\b(payment|pay|money|credit card|bank account|billing|invoice|refund|'
83
+ r'tax|irs|paypal|transaction|transfer|wire|deposit|account number|'
84
+ r'social security|ssn|card number|cvv|pin)\b',
85
+ text, re.IGNORECASE
86
+ )),
87
+ "authority_impersonation": bool(re.search(
88
+ r'\b(paypal|amazon|microsoft|apple|google|facebook|instagram|netflix|'
89
+ r'ebay|irs|fbi|cia|government|police|bank of america|chase|wells fargo|'
90
+ r'citibank|security team|support team|admin|administrator)\b',
91
+ text, re.IGNORECASE
92
+ )),
93
+ "suspicious_urls": bool(re.search(r'http[s]?://|www\.', text)),
94
+ "suspicious_domain": bool(re.search(
95
+ r'\b(bit\.ly|tinyurl|goo\.gl|short|link|redirect|verify-|secure-|account-|'
96
+ r'update-|login-|signin-)\w+\.(com|net|org|info|xyz|tk|ml|ga|cf|gq)',
97
+ text, re.IGNORECASE
98
+ )),
99
+ "generic_greeting": bool(re.search(
100
+ r'^(dear (customer|user|member|client|sir|madam)|hello|hi there|greetings)\b',
101
+ text, re.IGNORECASE
102
+ )),
103
+ "poor_grammar": self._detect_poor_grammar(text),
104
+ "excessive_punctuation": bool(re.search(r'[!?]{2,}', text)),
105
+ "all_caps": len(re.findall(r'\b[A-Z]{3,}\b', text)) > 2,
106
+ "currency_symbols": bool(re.search(r'[$£€¥₹]', text)),
107
+ }
108
+
109
+ # Count active indicators
110
+ active_count = sum(indicators.values())
111
+ total_count = len(indicators)
112
+
113
+ # Determine urgency level
114
+ urgency_score = sum([
115
+ indicators["urgent_words"] * 2,
116
+ indicators["threat_words"] * 2,
117
+ indicators["action_words"],
118
+ indicators["excessive_punctuation"],
119
+ indicators["all_caps"]
120
+ ])
121
+
122
+ if urgency_score >= 4:
123
+ urgency_level = "CRITICAL"
124
+ elif urgency_score >= 2:
125
+ urgency_level = "HIGH"
126
+ elif urgency_score >= 1:
127
+ urgency_level = "MEDIUM"
128
+ else:
129
+ urgency_level = "LOW"
130
+
131
+ indicators["urgency_level"] = urgency_level
132
+ indicators["indicator_count"] = active_count
133
+ indicators["indicator_percentage"] = round((active_count / total_count) * 100, 1)
134
+
135
+ return indicators
136
+
137
+ def _detect_poor_grammar(self, text: str) -> bool:
138
+ """Simple heuristic for poor grammar"""
139
+ issues = 0
140
+ # Multiple spaces
141
+ if re.search(r'\s{2,}', text):
142
+ issues += 1
143
+ # Missing spaces after punctuation
144
+ if re.search(r'[.,!?][a-zA-Z]', text):
145
+ issues += 1
146
+ # Inconsistent capitalization
147
+ sentences = re.split(r'[.!?]+', text)
148
+ for sent in sentences:
149
+ sent = sent.strip()
150
+ if sent and len(sent) > 5 and not sent[0].isupper():
151
+ issues += 1
152
+ break
153
+ return issues >= 2
154
+
155
  def sentiment_analysis(self, text: str) -> Dict:
156
+ """Analyze sentiment"""
157
  blob = TextBlob(text)
158
  polarity = blob.sentiment.polarity
159
  subjectivity = blob.sentiment.subjectivity
160
 
 
 
 
 
 
 
 
161
  return {
162
  "polarity": round(polarity, 4),
163
  "subjectivity": round(subjectivity, 4),
164
  "sentiment": "positive" if polarity > 0.1 else "negative" if polarity < -0.1 else "neutral",
165
  "is_persuasive": subjectivity > 0.5,
 
166
  }
167
 
168
  def preprocess(self, text: str) -> Dict:
169
+ """Full preprocessing pipeline"""
170
  tokens = self.tokenize(text)
171
  tokens_no_stop = self.remove_stopwords(tokens)
172
  stemmed = self.stem(tokens_no_stop)
173
  lemmatized = self.lemmatize(tokens_no_stop)
174
  sentiment = self.sentiment_analysis(text)
175
+ phishing_indicators = self.analyze_phishing_indicators(text)
176
 
177
  return {
178
  "original_text": text,
 
181
  "stemmed_tokens": stemmed,
182
  "lemmatized_tokens": lemmatized,
183
  "sentiment": sentiment,
184
+ "phishing_indicators": phishing_indicators,
185
  "token_count": len(tokens_no_stop)
186
  }
187
 
 
192
  class PredictPayload(BaseModel):
193
  inputs: str
194
  include_preprocessing: bool = True
 
195
 
196
 
197
  class BatchPredictPayload(BaseModel):
198
  inputs: List[str]
199
  include_preprocessing: bool = True
 
200
 
201
 
202
  class LabeledText(BaseModel):
 
230
  return t
231
 
232
 
233
+ def _adjust_confidence_with_indicators(base_prob: float, indicators: Dict, predicted_label: str) -> float:
234
+ """
235
+ Adjust confidence based on phishing indicators.
236
+ More indicators = context suggests phishing, so confidence varies based on prediction
237
+ """
238
+ indicator_count = indicators.get("indicator_count", 0)
239
+ indicator_percentage = indicators.get("indicator_percentage", 0)
240
+
241
+ # Base adjustment from indicator count
242
+ # If predicting PHISH and many indicators: more confident (but cap at 85%)
243
+ # If predicting LEGIT with many indicators: less confident (uncertainty)
244
+ # If predicting PHISH with few indicators: less confident (might be wrong)
245
+ # If predicting LEGIT with few indicators: more confident
246
+
247
+ if predicted_label == "PHISH":
248
+ # Phishing prediction
249
+ if indicator_percentage >= 40: # Strong indicators
250
+ # High confidence: 75-85%
251
+ adjusted = 0.75 + (indicator_percentage / 100) * 0.10
252
+ elif indicator_percentage >= 25: # Moderate indicators
253
+ # Medium confidence: 65-75%
254
+ adjusted = 0.65 + (indicator_percentage / 100) * 0.10
255
+ else: # Weak indicators
256
+ # Lower confidence: 55-65%
257
+ adjusted = 0.55 + (indicator_percentage / 100) * 0.10
258
+ else:
259
+ # Legitimate prediction
260
+ if indicator_percentage >= 40: # Many phishing indicators but predicting legit?
261
+ # Low confidence: 55-65% (uncertain)
262
+ adjusted = 0.65 - (indicator_percentage / 100) * 0.10
263
+ elif indicator_percentage >= 25: # Some indicators
264
+ # Medium confidence: 65-75%
265
+ adjusted = 0.70 - (indicator_percentage / 100) * 0.05
266
+ else: # Few indicators
267
+ # High confidence: 75-85%
268
+ adjusted = 0.75 + ((100 - indicator_percentage) / 100) * 0.10
269
+
270
+ # Clamp to min/max range
271
+ adjusted = max(BASE_CONFIDENCE_MIN, min(BASE_CONFIDENCE_MAX, adjusted))
272
+
273
+ return adjusted
274
 
275
 
276
  def _load_model():
 
282
  print(f"\n{'='*60}")
283
  print(f"Loading model: {MODEL_ID}")
284
  print(f"Device: {_device}")
285
+ print(f"Confidence range: {BASE_CONFIDENCE_MIN*100:.0f}%-{BASE_CONFIDENCE_MAX*100:.0f}%")
286
  print(f"{'='*60}\n")
287
 
288
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
303
  print(f"{'='*60}\n")
304
 
305
 
306
+ def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
307
+ """Predict with indicator-based confidence adjustment"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
  _load_model()
309
  if not texts:
310
  return []
311
 
312
+ # Get preprocessing info (always needed for indicators)
313
+ preprocessing_info = [_preprocessor.preprocess(text) for text in texts]
 
 
314
 
315
  # Tokenize
316
  enc = _tokenizer(
 
322
  )
323
  enc = {k: v.to(_device) for k, v in enc.items()}
324
 
325
+ # Predict
326
+ with torch.no_grad():
327
+ logits = _model(**enc).logits
328
+ probs = F.softmax(logits, dim=-1)
329
 
330
  # Get labels from model config
331
  id2label = getattr(_model.config, "id2label", {0: "LEGIT", 1: "PHISH"})
332
 
333
  outputs: List[Dict] = []
334
+ for text_idx in range(probs.shape[0]):
335
+ p = probs[text_idx]
336
+ preprocessing = preprocessing_info[text_idx]
337
+ indicators = preprocessing["phishing_indicators"]
338
 
339
  # Get prediction
340
+ predicted_idx = int(torch.argmax(p).item())
341
  predicted_label_raw = id2label.get(predicted_idx, f"CLASS_{predicted_idx}")
342
  predicted_label_norm = _normalize_label(predicted_label_raw)
343
+ raw_prob = float(p[predicted_idx].item())
344
+
345
+ # Adjust confidence based on indicators
346
+ adjusted_confidence = _adjust_confidence_with_indicators(
347
+ raw_prob, indicators, predicted_label_norm
348
+ )
 
349
 
350
+ # Build probability breakdown (adjusted)
351
  prob_breakdown = {}
352
+ for i in range(len(p)):
 
353
  label = _normalize_label(id2label.get(i, f"CLASS_{i}"))
354
+ if i == predicted_idx:
355
+ prob_breakdown[label] = round(adjusted_confidence, 4)
356
+ else:
357
+ prob_breakdown[label] = round(1.0 - adjusted_confidence, 4)
358
 
359
  output = {
360
  "text": texts[text_idx][:100] + "..." if len(texts[text_idx]) > 100 else texts[text_idx],
 
363
  "is_phish": predicted_label_norm == "PHISH",
364
  "confidence": round(adjusted_confidence * 100, 2),
365
  "score": round(adjusted_confidence, 4),
 
366
  "probs": prob_breakdown,
367
+ "model_raw_confidence": round(raw_prob * 100, 2),
 
368
  }
369
 
370
+ if include_preprocessing:
371
+ output["preprocessing"] = preprocessing
372
 
373
  outputs.append(output)
374
 
 
387
  "status": "ok",
388
  "model": MODEL_ID,
389
  "device": _device,
390
+ "confidence_range": f"{BASE_CONFIDENCE_MIN*100:.0f}%-{BASE_CONFIDENCE_MAX*100:.0f}%",
391
+ "note": "Confidence adjusted based on phishing indicators"
 
 
 
392
  }
393
 
394
 
 
422
  def predict(payload: PredictPayload):
423
  """Single prediction"""
424
  try:
425
+ res = _predict_texts([payload.inputs], include_preprocessing=payload.include_preprocessing)
 
 
426
  return res[0]
427
  except Exception as e:
428
  raise HTTPException(status_code=500, detail=str(e))
 
432
  def predict_batch(payload: BatchPredictPayload):
433
  """Batch predictions"""
434
  try:
435
+ return _predict_texts(payload.inputs, include_preprocessing=payload.include_preprocessing)
 
 
436
  except Exception as e:
437
  raise HTTPException(status_code=500, detail=str(e))
438
 
 
443
  try:
444
  texts = [s.text for s in payload.samples]
445
  gts = [(_normalize_label(s.label) if s.label is not None else None) for s in payload.samples]
446
+ preds = _predict_texts(texts, include_preprocessing=False)
447
 
448
  total = len(preds)
449
  correct = 0