Perth0603 committed on
Commit
48a94e0
·
verified ·
1 Parent(s): 081ac88

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -18
app.py CHANGED
@@ -25,8 +25,10 @@ MODEL_ID = "Perth0603/phishing-email-mobilebert"
25
 
26
  app = FastAPI(title="Phishing Text Classifier with Preprocessing", version="1.0.0")
27
 
28
- # Temperature for softening predictions (1.0 = normal, >1.0 = softer, <1.0 = sharper)
29
- TEMPERATURE = 3.0 # Adjust this value (try 1.5 to 3.0)
 
 
30
 
31
 
32
  # ============================================================================
@@ -140,6 +142,35 @@ def _normalize_label(txt: str) -> str:
140
  return t
141
 
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  def _load_model():
144
  """Load model, tokenizer, and preprocessor"""
145
  global _tokenizer, _model, _device, _preprocessor
@@ -149,7 +180,8 @@ def _load_model():
149
  print(f"\n{'='*60}")
150
  print(f"Loading model: {MODEL_ID}")
151
  print(f"Device: {_device}")
152
- print(f"Temperature scaling: {TEMPERATURE}")
 
153
  print(f"{'='*60}\n")
154
 
155
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
@@ -171,7 +203,7 @@ def _load_model():
171
 
172
 
173
  def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
174
- """Predict with temperature-scaled probabilities"""
175
  _load_model()
176
  if not texts:
177
  return []
@@ -191,31 +223,33 @@ def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List
191
  )
192
  enc = {k: v.to(_device) for k, v in enc.items()}
193
 
194
- # Predict with temperature scaling
195
  with torch.no_grad():
196
  logits = _model(**enc).logits
197
- # Apply temperature scaling to soften probabilities
198
- scaled_logits = logits / TEMPERATURE
199
- probs = F.softmax(scaled_logits, dim=-1)
 
200
 
201
  # Get labels from model config
202
  id2label = getattr(_model.config, "id2label", {0: "LEGIT", 1: "PHISH"})
203
 
204
  outputs: List[Dict] = []
205
- for text_idx in range(probs.shape[0]):
206
- p = probs[text_idx]
 
207
 
208
- # Get prediction
209
- predicted_idx = int(torch.argmax(p).item())
210
  predicted_label_raw = id2label.get(predicted_idx, f"CLASS_{predicted_idx}")
211
  predicted_label_norm = _normalize_label(predicted_label_raw)
212
- predicted_prob = float(p[predicted_idx].item())
213
 
214
  # Build probability breakdown
215
  prob_breakdown = {}
216
- for i in range(len(p)):
217
  label = _normalize_label(id2label.get(i, f"CLASS_{i}"))
218
- prob_breakdown[label] = round(float(p[i].item()), 4)
219
 
220
  output = {
221
  "text": texts[text_idx][:100] + "..." if len(texts[text_idx]) > 100 else texts[text_idx],
@@ -225,6 +259,7 @@ def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List
225
  "confidence": round(predicted_prob * 100, 2),
226
  "score": round(predicted_prob, 4),
227
  "probs": prob_breakdown,
 
228
  }
229
 
230
  if include_preprocessing and preprocessing_info:
@@ -247,8 +282,11 @@ def root():
247
  "status": "ok",
248
  "model": MODEL_ID,
249
  "device": _device,
250
- "temperature": TEMPERATURE,
251
- "note": "Using temperature scaling to calibrate probabilities"
 
 
 
252
  }
253
 
254
 
@@ -264,7 +302,6 @@ def debug_labels():
264
  "label2id": getattr(_model.config, "label2id", {}),
265
  "num_labels": int(getattr(_model.config, "num_labels", 0)),
266
  "device": _device,
267
- "temperature": TEMPERATURE,
268
  }
269
 
270
 
 
25
 
26
  app = FastAPI(title="Phishing Text Classifier with Preprocessing", version="1.0.0")
27
 
28
+ # Confidence calibration settings
29
+ MIN_CONFIDENCE = 0.70 # Minimum confidence to report (70%)
30
+ MAX_CONFIDENCE = 0.95 # Maximum confidence to report (95%)
31
+ SMOOTHING_FACTOR = 0.15 # How much to smooth (0.1-0.3)
32
 
33
 
34
  # ============================================================================
 
142
  return t
143
 
144
 
145
+ def _calibrate_probabilities(probs: torch.Tensor) -> torch.Tensor:
146
+ """
147
+ Calibrate overconfident probabilities to more realistic range.
148
+ Uses label smoothing to reduce extreme confidence.
149
+ """
150
+ num_classes = probs.shape[-1]
151
+
152
+ # Apply label smoothing
153
+ smoothed_probs = probs * (1 - SMOOTHING_FACTOR) + (SMOOTHING_FACTOR / num_classes)
154
+
155
+ # Clip to min/max confidence range
156
+ max_prob, max_idx = torch.max(smoothed_probs, dim=-1, keepdim=True)
157
+
158
+ # Scale to desired range
159
+ if max_prob > MAX_CONFIDENCE:
160
+ scale_factor = MAX_CONFIDENCE / max_prob
161
+ smoothed_probs = smoothed_probs * scale_factor
162
+
163
+ # Ensure minimum confidence for winner
164
+ if max_prob < MIN_CONFIDENCE:
165
+ scale_factor = MIN_CONFIDENCE / max_prob
166
+ smoothed_probs = smoothed_probs * scale_factor
167
+
168
+ # Renormalize to sum to 1
169
+ smoothed_probs = smoothed_probs / smoothed_probs.sum(dim=-1, keepdim=True)
170
+
171
+ return smoothed_probs
172
+
173
+
174
  def _load_model():
175
  """Load model, tokenizer, and preprocessor"""
176
  global _tokenizer, _model, _device, _preprocessor
 
180
  print(f"\n{'='*60}")
181
  print(f"Loading model: {MODEL_ID}")
182
  print(f"Device: {_device}")
183
+ print(f"Confidence calibration: {MIN_CONFIDENCE*100:.0f}%-{MAX_CONFIDENCE*100:.0f}%")
184
+ print(f"Smoothing factor: {SMOOTHING_FACTOR}")
185
  print(f"{'='*60}\n")
186
 
187
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
203
 
204
 
205
  def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
206
+ """Predict with calibrated probabilities"""
207
  _load_model()
208
  if not texts:
209
  return []
 
223
  )
224
  enc = {k: v.to(_device) for k, v in enc.items()}
225
 
226
+ # Predict
227
  with torch.no_grad():
228
  logits = _model(**enc).logits
229
+ probs = F.softmax(logits, dim=-1)
230
+
231
+ # Apply calibration to reduce overconfidence
232
+ calibrated_probs = _calibrate_probabilities(probs)
233
 
234
  # Get labels from model config
235
  id2label = getattr(_model.config, "id2label", {0: "LEGIT", 1: "PHISH"})
236
 
237
  outputs: List[Dict] = []
238
+ for text_idx in range(calibrated_probs.shape[0]):
239
+ p_original = probs[text_idx]
240
+ p_calibrated = calibrated_probs[text_idx]
241
 
242
+ # Get prediction from calibrated probs
243
+ predicted_idx = int(torch.argmax(p_calibrated).item())
244
  predicted_label_raw = id2label.get(predicted_idx, f"CLASS_{predicted_idx}")
245
  predicted_label_norm = _normalize_label(predicted_label_raw)
246
+ predicted_prob = float(p_calibrated[predicted_idx].item())
247
 
248
  # Build probability breakdown
249
  prob_breakdown = {}
250
+ for i in range(len(p_calibrated)):
251
  label = _normalize_label(id2label.get(i, f"CLASS_{i}"))
252
+ prob_breakdown[label] = round(float(p_calibrated[i].item()), 4)
253
 
254
  output = {
255
  "text": texts[text_idx][:100] + "..." if len(texts[text_idx]) > 100 else texts[text_idx],
 
259
  "confidence": round(predicted_prob * 100, 2),
260
  "score": round(predicted_prob, 4),
261
  "probs": prob_breakdown,
262
+ "original_confidence": round(float(p_original[predicted_idx].item()) * 100, 2), # Show original for comparison
263
  }
264
 
265
  if include_preprocessing and preprocessing_info:
 
282
  "status": "ok",
283
  "model": MODEL_ID,
284
  "device": _device,
285
+ "calibration": {
286
+ "min_confidence": MIN_CONFIDENCE,
287
+ "max_confidence": MAX_CONFIDENCE,
288
+ "smoothing_factor": SMOOTHING_FACTOR
289
+ }
290
  }
291
 
292
 
 
302
  "label2id": getattr(_model.config, "label2id", {}),
303
  "num_labels": int(getattr(_model.config, "num_labels", 0)),
304
  "device": _device,
 
305
  }
306
 
307