Perth0603 committed on
Commit
dfa9403
·
verified ·
1 Parent(s): 48a94e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -60
app.py CHANGED
@@ -25,10 +25,9 @@ MODEL_ID = "Perth0603/phishing-email-mobilebert"
25
 
26
  app = FastAPI(title="Phishing Text Classifier with Preprocessing", version="1.0.0")
27
 
28
- # Confidence calibration settings
29
- MIN_CONFIDENCE = 0.70 # Minimum confidence to report (70%)
30
- MAX_CONFIDENCE = 0.95 # Maximum confidence to report (95%)
31
- SMOOTHING_FACTOR = 0.15 # How much to smooth (0.1-0.3)
32
 
33
 
34
  # ============================================================================
@@ -104,11 +103,13 @@ class TextPreprocessor:
104
  class PredictPayload(BaseModel):
105
  inputs: str
106
  include_preprocessing: bool = True
 
107
 
108
 
109
  class BatchPredictPayload(BaseModel):
110
  inputs: List[str]
111
  include_preprocessing: bool = True
 
112
 
113
 
114
  class LabeledText(BaseModel):
@@ -142,33 +143,11 @@ def _normalize_label(txt: str) -> str:
142
  return t
143
 
144
 
145
- def _calibrate_probabilities(probs: torch.Tensor) -> torch.Tensor:
146
- """
147
- Calibrate overconfident probabilities to more realistic range.
148
- Uses label smoothing to reduce extreme confidence.
149
- """
150
- num_classes = probs.shape[-1]
151
-
152
- # Apply label smoothing
153
- smoothed_probs = probs * (1 - SMOOTHING_FACTOR) + (SMOOTHING_FACTOR / num_classes)
154
-
155
- # Clip to min/max confidence range
156
- max_prob, max_idx = torch.max(smoothed_probs, dim=-1, keepdim=True)
157
-
158
- # Scale to desired range
159
- if max_prob > MAX_CONFIDENCE:
160
- scale_factor = MAX_CONFIDENCE / max_prob
161
- smoothed_probs = smoothed_probs * scale_factor
162
-
163
- # Ensure minimum confidence for winner
164
- if max_prob < MIN_CONFIDENCE:
165
- scale_factor = MIN_CONFIDENCE / max_prob
166
- smoothed_probs = smoothed_probs * scale_factor
167
-
168
- # Renormalize to sum to 1
169
- smoothed_probs = smoothed_probs / smoothed_probs.sum(dim=-1, keepdim=True)
170
-
171
- return smoothed_probs
172
 
173
 
174
  def _load_model():
@@ -180,8 +159,7 @@ def _load_model():
180
  print(f"\n{'='*60}")
181
  print(f"Loading model: {MODEL_ID}")
182
  print(f"Device: {_device}")
183
- print(f"Confidence calibration: {MIN_CONFIDENCE*100:.0f}%-{MAX_CONFIDENCE*100:.0f}%")
184
- print(f"Smoothing factor: {SMOOTHING_FACTOR}")
185
  print(f"{'='*60}\n")
186
 
187
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
@@ -202,8 +180,41 @@ def _load_model():
202
  print(f"{'='*60}\n")
203
 
204
 
205
- def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
206
- """Predict with calibrated probabilities"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  _load_model()
208
  if not texts:
209
  return []
@@ -223,43 +234,48 @@ def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List
223
  )
224
  enc = {k: v.to(_device) for k, v in enc.items()}
225
 
226
- # Predict
227
- with torch.no_grad():
228
- logits = _model(**enc).logits
229
- probs = F.softmax(logits, dim=-1)
230
-
231
- # Apply calibration to reduce overconfidence
232
- calibrated_probs = _calibrate_probabilities(probs)
233
 
234
  # Get labels from model config
235
  id2label = getattr(_model.config, "id2label", {0: "LEGIT", 1: "PHISH"})
236
 
237
  outputs: List[Dict] = []
238
- for text_idx in range(calibrated_probs.shape[0]):
239
- p_original = probs[text_idx]
240
- p_calibrated = calibrated_probs[text_idx]
241
 
242
- # Get prediction from calibrated probs
243
- predicted_idx = int(torch.argmax(p_calibrated).item())
244
  predicted_label_raw = id2label.get(predicted_idx, f"CLASS_{predicted_idx}")
245
  predicted_label_norm = _normalize_label(predicted_label_raw)
246
- predicted_prob = float(p_calibrated[predicted_idx].item())
 
 
 
 
 
 
247
 
248
  # Build probability breakdown
249
  prob_breakdown = {}
250
- for i in range(len(p_calibrated)):
 
251
  label = _normalize_label(id2label.get(i, f"CLASS_{i}"))
252
- prob_breakdown[label] = round(float(p_calibrated[i].item()), 4)
 
253
 
254
  output = {
255
  "text": texts[text_idx][:100] + "..." if len(texts[text_idx]) > 100 else texts[text_idx],
256
  "label": predicted_label_norm,
257
  "raw_label": predicted_label_raw,
258
  "is_phish": predicted_label_norm == "PHISH",
259
- "confidence": round(predicted_prob * 100, 2),
260
- "score": round(predicted_prob, 4),
 
261
  "probs": prob_breakdown,
262
- "original_confidence": round(float(p_original[predicted_idx].item()) * 100, 2), # Show original for comparison
 
263
  }
264
 
265
  if include_preprocessing and preprocessing_info:
@@ -282,10 +298,10 @@ def root():
282
  "status": "ok",
283
  "model": MODEL_ID,
284
  "device": _device,
285
- "calibration": {
286
- "min_confidence": MIN_CONFIDENCE,
287
- "max_confidence": MAX_CONFIDENCE,
288
- "smoothing_factor": SMOOTHING_FACTOR
289
  }
290
  }
291
 
@@ -320,7 +336,9 @@ def debug_preprocessing(payload: PredictPayload):
320
  def predict(payload: PredictPayload):
321
  """Single prediction"""
322
  try:
323
- res = _predict_texts([payload.inputs], include_preprocessing=payload.include_preprocessing)
 
 
324
  return res[0]
325
  except Exception as e:
326
  raise HTTPException(status_code=500, detail=str(e))
@@ -330,7 +348,9 @@ def predict(payload: PredictPayload):
330
  def predict_batch(payload: BatchPredictPayload):
331
  """Batch predictions"""
332
  try:
333
- return _predict_texts(payload.inputs, include_preprocessing=payload.include_preprocessing)
 
 
334
  except Exception as e:
335
  raise HTTPException(status_code=500, detail=str(e))
336
 
@@ -341,7 +361,7 @@ def evaluate(payload: EvalPayload):
341
  try:
342
  texts = [s.text for s in payload.samples]
343
  gts = [(_normalize_label(s.label) if s.label is not None else None) for s in payload.samples]
344
- preds = _predict_texts(texts, include_preprocessing=False)
345
 
346
  total = len(preds)
347
  correct = 0
 
25
 
26
  app = FastAPI(title="Phishing Text Classifier with Preprocessing", version="1.0.0")
27
 
28
+ # Uncertainty estimation settings
29
+ MC_SAMPLES = 10 # Number of forward passes with dropout (more = smoother, slower)
30
+ DROPOUT_RATE = 0.1 # Dropout rate for uncertainty (0.05-0.15)
 
31
 
32
 
33
  # ============================================================================
 
103
  class PredictPayload(BaseModel):
104
  inputs: str
105
  include_preprocessing: bool = True
106
+ use_uncertainty: bool = True # Enable uncertainty estimation
107
 
108
 
109
  class BatchPredictPayload(BaseModel):
110
  inputs: List[str]
111
  include_preprocessing: bool = True
112
+ use_uncertainty: bool = True
113
 
114
 
115
  class LabeledText(BaseModel):
 
143
  return t
144
 
145
 
146
+ def _enable_dropout(model):
147
+ """Enable dropout layers during inference for uncertainty estimation"""
148
+ for module in model.modules():
149
+ if module.__class__.__name__.startswith('Dropout'):
150
+ module.train()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
 
153
  def _load_model():
 
159
  print(f"\n{'='*60}")
160
  print(f"Loading model: {MODEL_ID}")
161
  print(f"Device: {_device}")
162
+ print(f"MC Dropout samples: {MC_SAMPLES}")
 
163
  print(f"{'='*60}\n")
164
 
165
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
180
  print(f"{'='*60}\n")
181
 
182
 
183
+ def _predict_with_uncertainty(enc: Dict, use_uncertainty: bool = True) -> tuple:
184
+ """
185
+ Predict with uncertainty estimation using MC Dropout.
186
+ Returns: (mean_probs, std_probs)
187
+ """
188
+ if not use_uncertainty:
189
+ # Standard prediction
190
+ with torch.no_grad():
191
+ logits = _model(**enc).logits
192
+ probs = F.softmax(logits, dim=-1)
193
+ return probs, torch.zeros_like(probs)
194
+
195
+ # Monte Carlo Dropout: multiple forward passes with dropout enabled
196
+ prob_samples = []
197
+
198
+ _enable_dropout(_model) # Enable dropout during inference
199
+
200
+ with torch.no_grad():
201
+ for _ in range(MC_SAMPLES):
202
+ logits = _model(**enc).logits
203
+ probs = F.softmax(logits, dim=-1)
204
+ prob_samples.append(probs)
205
+
206
+ _model.eval() # Restore eval mode
207
+
208
+ # Stack all samples and compute mean and std
209
+ prob_samples = torch.stack(prob_samples) # [MC_SAMPLES, batch_size, num_classes]
210
+ mean_probs = prob_samples.mean(dim=0) # Average predictions
211
+ std_probs = prob_samples.std(dim=0) # Uncertainty (variance)
212
+
213
+ return mean_probs, std_probs
214
+
215
+
216
+ def _predict_texts(texts: List[str], include_preprocessing: bool = True, use_uncertainty: bool = True) -> List[Dict]:
217
+ """Predict with uncertainty-aware confidence scores"""
218
  _load_model()
219
  if not texts:
220
  return []
 
234
  )
235
  enc = {k: v.to(_device) for k, v in enc.items()}
236
 
237
+ # Predict with uncertainty
238
+ mean_probs, std_probs = _predict_with_uncertainty(enc, use_uncertainty)
 
 
 
 
 
239
 
240
  # Get labels from model config
241
  id2label = getattr(_model.config, "id2label", {0: "LEGIT", 1: "PHISH"})
242
 
243
  outputs: List[Dict] = []
244
+ for text_idx in range(mean_probs.shape[0]):
245
+ p_mean = mean_probs[text_idx]
246
+ p_std = std_probs[text_idx]
247
 
248
+ # Get prediction
249
+ predicted_idx = int(torch.argmax(p_mean).item())
250
  predicted_label_raw = id2label.get(predicted_idx, f"CLASS_{predicted_idx}")
251
  predicted_label_norm = _normalize_label(predicted_label_raw)
252
+ predicted_prob = float(p_mean[predicted_idx].item())
253
+ predicted_std = float(p_std[predicted_idx].item())
254
+
255
+ # Calculate uncertainty-adjusted confidence
256
+ # Higher uncertainty = lower reported confidence
257
+ uncertainty_penalty = predicted_std * 2.0 # Amplify uncertainty effect
258
+ adjusted_confidence = max(0.5, predicted_prob - uncertainty_penalty) # Don't go below 50%
259
 
260
  # Build probability breakdown
261
  prob_breakdown = {}
262
+ uncertainty_breakdown = {}
263
+ for i in range(len(p_mean)):
264
  label = _normalize_label(id2label.get(i, f"CLASS_{i}"))
265
+ prob_breakdown[label] = round(float(p_mean[i].item()), 4)
266
+ uncertainty_breakdown[label] = round(float(p_std[i].item()), 4)
267
 
268
  output = {
269
  "text": texts[text_idx][:100] + "..." if len(texts[text_idx]) > 100 else texts[text_idx],
270
  "label": predicted_label_norm,
271
  "raw_label": predicted_label_raw,
272
  "is_phish": predicted_label_norm == "PHISH",
273
+ "confidence": round(adjusted_confidence * 100, 2),
274
+ "score": round(adjusted_confidence, 4),
275
+ "uncertainty": round(predicted_std * 100, 2), # Uncertainty as percentage
276
  "probs": prob_breakdown,
277
+ "uncertainty_scores": uncertainty_breakdown,
278
+ "raw_confidence": round(predicted_prob * 100, 2), # Original model confidence
279
  }
280
 
281
  if include_preprocessing and preprocessing_info:
 
298
  "status": "ok",
299
  "model": MODEL_ID,
300
  "device": _device,
301
+ "uncertainty_estimation": {
302
+ "enabled": True,
303
+ "mc_samples": MC_SAMPLES,
304
+ "dropout_rate": DROPOUT_RATE
305
  }
306
  }
307
 
 
336
  def predict(payload: PredictPayload):
337
  """Single prediction"""
338
  try:
339
+ res = _predict_texts([payload.inputs],
340
+ include_preprocessing=payload.include_preprocessing,
341
+ use_uncertainty=payload.use_uncertainty)
342
  return res[0]
343
  except Exception as e:
344
  raise HTTPException(status_code=500, detail=str(e))
 
348
  def predict_batch(payload: BatchPredictPayload):
349
  """Batch predictions"""
350
  try:
351
+ return _predict_texts(payload.inputs,
352
+ include_preprocessing=payload.include_preprocessing,
353
+ use_uncertainty=payload.use_uncertainty)
354
  except Exception as e:
355
  raise HTTPException(status_code=500, detail=str(e))
356
 
 
361
  try:
362
  texts = [s.text for s in payload.samples]
363
  gts = [(_normalize_label(s.label) if s.label is not None else None) for s in payload.samples]
364
+ preds = _predict_texts(texts, include_preprocessing=False, use_uncertainty=True)
365
 
366
  total = len(preds)
367
  correct = 0