Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -25,10 +25,9 @@ MODEL_ID = "Perth0603/phishing-email-mobilebert"
|
|
| 25 |
|
| 26 |
app = FastAPI(title="Phishing Text Classifier with Preprocessing", version="1.0.0")
|
| 27 |
|
| 28 |
-
#
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
SMOOTHING_FACTOR = 0.15 # How much to smooth (0.1-0.3)
|
| 32 |
|
| 33 |
|
| 34 |
# ============================================================================
|
|
@@ -104,11 +103,13 @@ class TextPreprocessor:
|
|
| 104 |
class PredictPayload(BaseModel):
|
| 105 |
inputs: str
|
| 106 |
include_preprocessing: bool = True
|
|
|
|
| 107 |
|
| 108 |
|
| 109 |
class BatchPredictPayload(BaseModel):
|
| 110 |
inputs: List[str]
|
| 111 |
include_preprocessing: bool = True
|
|
|
|
| 112 |
|
| 113 |
|
| 114 |
class LabeledText(BaseModel):
|
|
@@ -142,33 +143,11 @@ def _normalize_label(txt: str) -> str:
|
|
| 142 |
return t
|
| 143 |
|
| 144 |
|
| 145 |
-
def
|
| 146 |
-
"""
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
num_classes = probs.shape[-1]
|
| 151 |
-
|
| 152 |
-
# Apply label smoothing
|
| 153 |
-
smoothed_probs = probs * (1 - SMOOTHING_FACTOR) + (SMOOTHING_FACTOR / num_classes)
|
| 154 |
-
|
| 155 |
-
# Clip to min/max confidence range
|
| 156 |
-
max_prob, max_idx = torch.max(smoothed_probs, dim=-1, keepdim=True)
|
| 157 |
-
|
| 158 |
-
# Scale to desired range
|
| 159 |
-
if max_prob > MAX_CONFIDENCE:
|
| 160 |
-
scale_factor = MAX_CONFIDENCE / max_prob
|
| 161 |
-
smoothed_probs = smoothed_probs * scale_factor
|
| 162 |
-
|
| 163 |
-
# Ensure minimum confidence for winner
|
| 164 |
-
if max_prob < MIN_CONFIDENCE:
|
| 165 |
-
scale_factor = MIN_CONFIDENCE / max_prob
|
| 166 |
-
smoothed_probs = smoothed_probs * scale_factor
|
| 167 |
-
|
| 168 |
-
# Renormalize to sum to 1
|
| 169 |
-
smoothed_probs = smoothed_probs / smoothed_probs.sum(dim=-1, keepdim=True)
|
| 170 |
-
|
| 171 |
-
return smoothed_probs
|
| 172 |
|
| 173 |
|
| 174 |
def _load_model():
|
|
@@ -180,8 +159,7 @@ def _load_model():
|
|
| 180 |
print(f"\n{'='*60}")
|
| 181 |
print(f"Loading model: {MODEL_ID}")
|
| 182 |
print(f"Device: {_device}")
|
| 183 |
-
print(f"
|
| 184 |
-
print(f"Smoothing factor: {SMOOTHING_FACTOR}")
|
| 185 |
print(f"{'='*60}\n")
|
| 186 |
|
| 187 |
_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
|
@@ -202,8 +180,41 @@ def _load_model():
|
|
| 202 |
print(f"{'='*60}\n")
|
| 203 |
|
| 204 |
|
| 205 |
-
def
|
| 206 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
_load_model()
|
| 208 |
if not texts:
|
| 209 |
return []
|
|
@@ -223,43 +234,48 @@ def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List
|
|
| 223 |
)
|
| 224 |
enc = {k: v.to(_device) for k, v in enc.items()}
|
| 225 |
|
| 226 |
-
# Predict
|
| 227 |
-
|
| 228 |
-
logits = _model(**enc).logits
|
| 229 |
-
probs = F.softmax(logits, dim=-1)
|
| 230 |
-
|
| 231 |
-
# Apply calibration to reduce overconfidence
|
| 232 |
-
calibrated_probs = _calibrate_probabilities(probs)
|
| 233 |
|
| 234 |
# Get labels from model config
|
| 235 |
id2label = getattr(_model.config, "id2label", {0: "LEGIT", 1: "PHISH"})
|
| 236 |
|
| 237 |
outputs: List[Dict] = []
|
| 238 |
-
for text_idx in range(
|
| 239 |
-
|
| 240 |
-
|
| 241 |
|
| 242 |
-
# Get prediction
|
| 243 |
-
predicted_idx = int(torch.argmax(
|
| 244 |
predicted_label_raw = id2label.get(predicted_idx, f"CLASS_{predicted_idx}")
|
| 245 |
predicted_label_norm = _normalize_label(predicted_label_raw)
|
| 246 |
-
predicted_prob = float(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
|
| 248 |
# Build probability breakdown
|
| 249 |
prob_breakdown = {}
|
| 250 |
-
|
|
|
|
| 251 |
label = _normalize_label(id2label.get(i, f"CLASS_{i}"))
|
| 252 |
-
prob_breakdown[label] = round(float(
|
|
|
|
| 253 |
|
| 254 |
output = {
|
| 255 |
"text": texts[text_idx][:100] + "..." if len(texts[text_idx]) > 100 else texts[text_idx],
|
| 256 |
"label": predicted_label_norm,
|
| 257 |
"raw_label": predicted_label_raw,
|
| 258 |
"is_phish": predicted_label_norm == "PHISH",
|
| 259 |
-
"confidence": round(
|
| 260 |
-
"score": round(
|
|
|
|
| 261 |
"probs": prob_breakdown,
|
| 262 |
-
"
|
|
|
|
| 263 |
}
|
| 264 |
|
| 265 |
if include_preprocessing and preprocessing_info:
|
|
@@ -282,10 +298,10 @@ def root():
|
|
| 282 |
"status": "ok",
|
| 283 |
"model": MODEL_ID,
|
| 284 |
"device": _device,
|
| 285 |
-
"
|
| 286 |
-
"
|
| 287 |
-
"
|
| 288 |
-
"
|
| 289 |
}
|
| 290 |
}
|
| 291 |
|
|
@@ -320,7 +336,9 @@ def debug_preprocessing(payload: PredictPayload):
|
|
| 320 |
def predict(payload: PredictPayload):
|
| 321 |
"""Single prediction"""
|
| 322 |
try:
|
| 323 |
-
res = _predict_texts([payload.inputs],
|
|
|
|
|
|
|
| 324 |
return res[0]
|
| 325 |
except Exception as e:
|
| 326 |
raise HTTPException(status_code=500, detail=str(e))
|
|
@@ -330,7 +348,9 @@ def predict(payload: PredictPayload):
|
|
| 330 |
def predict_batch(payload: BatchPredictPayload):
|
| 331 |
"""Batch predictions"""
|
| 332 |
try:
|
| 333 |
-
return _predict_texts(payload.inputs,
|
|
|
|
|
|
|
| 334 |
except Exception as e:
|
| 335 |
raise HTTPException(status_code=500, detail=str(e))
|
| 336 |
|
|
@@ -341,7 +361,7 @@ def evaluate(payload: EvalPayload):
|
|
| 341 |
try:
|
| 342 |
texts = [s.text for s in payload.samples]
|
| 343 |
gts = [(_normalize_label(s.label) if s.label is not None else None) for s in payload.samples]
|
| 344 |
-
preds = _predict_texts(texts, include_preprocessing=False)
|
| 345 |
|
| 346 |
total = len(preds)
|
| 347 |
correct = 0
|
|
|
|
| 25 |
|
| 26 |
app = FastAPI(title="Phishing Text Classifier with Preprocessing", version="1.0.0")
|
| 27 |
|
| 28 |
+
# Uncertainty estimation settings
|
| 29 |
+
MC_SAMPLES = 10 # Number of forward passes with dropout (more = smoother, slower)
|
| 30 |
+
DROPOUT_RATE = 0.1 # Dropout rate for uncertainty (0.05-0.15)
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
# ============================================================================
|
|
|
|
| 103 |
class PredictPayload(BaseModel):
|
| 104 |
inputs: str
|
| 105 |
include_preprocessing: bool = True
|
| 106 |
+
use_uncertainty: bool = True # Enable uncertainty estimation
|
| 107 |
|
| 108 |
|
| 109 |
class BatchPredictPayload(BaseModel):
|
| 110 |
inputs: List[str]
|
| 111 |
include_preprocessing: bool = True
|
| 112 |
+
use_uncertainty: bool = True
|
| 113 |
|
| 114 |
|
| 115 |
class LabeledText(BaseModel):
|
|
|
|
| 143 |
return t
|
| 144 |
|
| 145 |
|
| 146 |
+
def _enable_dropout(model):
|
| 147 |
+
"""Enable dropout layers during inference for uncertainty estimation"""
|
| 148 |
+
for module in model.modules():
|
| 149 |
+
if module.__class__.__name__.startswith('Dropout'):
|
| 150 |
+
module.train()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
|
| 153 |
def _load_model():
|
|
|
|
| 159 |
print(f"\n{'='*60}")
|
| 160 |
print(f"Loading model: {MODEL_ID}")
|
| 161 |
print(f"Device: {_device}")
|
| 162 |
+
print(f"MC Dropout samples: {MC_SAMPLES}")
|
|
|
|
| 163 |
print(f"{'='*60}\n")
|
| 164 |
|
| 165 |
_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
|
|
|
| 180 |
print(f"{'='*60}\n")
|
| 181 |
|
| 182 |
|
| 183 |
+
def _predict_with_uncertainty(enc: Dict, use_uncertainty: bool = True) -> tuple:
|
| 184 |
+
"""
|
| 185 |
+
Predict with uncertainty estimation using MC Dropout.
|
| 186 |
+
Returns: (mean_probs, std_probs)
|
| 187 |
+
"""
|
| 188 |
+
if not use_uncertainty:
|
| 189 |
+
# Standard prediction
|
| 190 |
+
with torch.no_grad():
|
| 191 |
+
logits = _model(**enc).logits
|
| 192 |
+
probs = F.softmax(logits, dim=-1)
|
| 193 |
+
return probs, torch.zeros_like(probs)
|
| 194 |
+
|
| 195 |
+
# Monte Carlo Dropout: multiple forward passes with dropout enabled
|
| 196 |
+
prob_samples = []
|
| 197 |
+
|
| 198 |
+
_enable_dropout(_model) # Enable dropout during inference
|
| 199 |
+
|
| 200 |
+
with torch.no_grad():
|
| 201 |
+
for _ in range(MC_SAMPLES):
|
| 202 |
+
logits = _model(**enc).logits
|
| 203 |
+
probs = F.softmax(logits, dim=-1)
|
| 204 |
+
prob_samples.append(probs)
|
| 205 |
+
|
| 206 |
+
_model.eval() # Restore eval mode
|
| 207 |
+
|
| 208 |
+
# Stack all samples and compute mean and std
|
| 209 |
+
prob_samples = torch.stack(prob_samples) # [MC_SAMPLES, batch_size, num_classes]
|
| 210 |
+
mean_probs = prob_samples.mean(dim=0) # Average predictions
|
| 211 |
+
std_probs = prob_samples.std(dim=0) # Uncertainty (variance)
|
| 212 |
+
|
| 213 |
+
return mean_probs, std_probs
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def _predict_texts(texts: List[str], include_preprocessing: bool = True, use_uncertainty: bool = True) -> List[Dict]:
|
| 217 |
+
"""Predict with uncertainty-aware confidence scores"""
|
| 218 |
_load_model()
|
| 219 |
if not texts:
|
| 220 |
return []
|
|
|
|
| 234 |
)
|
| 235 |
enc = {k: v.to(_device) for k, v in enc.items()}
|
| 236 |
|
| 237 |
+
# Predict with uncertainty
|
| 238 |
+
mean_probs, std_probs = _predict_with_uncertainty(enc, use_uncertainty)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
|
| 240 |
# Get labels from model config
|
| 241 |
id2label = getattr(_model.config, "id2label", {0: "LEGIT", 1: "PHISH"})
|
| 242 |
|
| 243 |
outputs: List[Dict] = []
|
| 244 |
+
for text_idx in range(mean_probs.shape[0]):
|
| 245 |
+
p_mean = mean_probs[text_idx]
|
| 246 |
+
p_std = std_probs[text_idx]
|
| 247 |
|
| 248 |
+
# Get prediction
|
| 249 |
+
predicted_idx = int(torch.argmax(p_mean).item())
|
| 250 |
predicted_label_raw = id2label.get(predicted_idx, f"CLASS_{predicted_idx}")
|
| 251 |
predicted_label_norm = _normalize_label(predicted_label_raw)
|
| 252 |
+
predicted_prob = float(p_mean[predicted_idx].item())
|
| 253 |
+
predicted_std = float(p_std[predicted_idx].item())
|
| 254 |
+
|
| 255 |
+
# Calculate uncertainty-adjusted confidence
|
| 256 |
+
# Higher uncertainty = lower reported confidence
|
| 257 |
+
uncertainty_penalty = predicted_std * 2.0 # Amplify uncertainty effect
|
| 258 |
+
adjusted_confidence = max(0.5, predicted_prob - uncertainty_penalty) # Don't go below 50%
|
| 259 |
|
| 260 |
# Build probability breakdown
|
| 261 |
prob_breakdown = {}
|
| 262 |
+
uncertainty_breakdown = {}
|
| 263 |
+
for i in range(len(p_mean)):
|
| 264 |
label = _normalize_label(id2label.get(i, f"CLASS_{i}"))
|
| 265 |
+
prob_breakdown[label] = round(float(p_mean[i].item()), 4)
|
| 266 |
+
uncertainty_breakdown[label] = round(float(p_std[i].item()), 4)
|
| 267 |
|
| 268 |
output = {
|
| 269 |
"text": texts[text_idx][:100] + "..." if len(texts[text_idx]) > 100 else texts[text_idx],
|
| 270 |
"label": predicted_label_norm,
|
| 271 |
"raw_label": predicted_label_raw,
|
| 272 |
"is_phish": predicted_label_norm == "PHISH",
|
| 273 |
+
"confidence": round(adjusted_confidence * 100, 2),
|
| 274 |
+
"score": round(adjusted_confidence, 4),
|
| 275 |
+
"uncertainty": round(predicted_std * 100, 2), # Uncertainty as percentage
|
| 276 |
"probs": prob_breakdown,
|
| 277 |
+
"uncertainty_scores": uncertainty_breakdown,
|
| 278 |
+
"raw_confidence": round(predicted_prob * 100, 2), # Original model confidence
|
| 279 |
}
|
| 280 |
|
| 281 |
if include_preprocessing and preprocessing_info:
|
|
|
|
| 298 |
"status": "ok",
|
| 299 |
"model": MODEL_ID,
|
| 300 |
"device": _device,
|
| 301 |
+
"uncertainty_estimation": {
|
| 302 |
+
"enabled": True,
|
| 303 |
+
"mc_samples": MC_SAMPLES,
|
| 304 |
+
"dropout_rate": DROPOUT_RATE
|
| 305 |
}
|
| 306 |
}
|
| 307 |
|
|
|
|
| 336 |
def predict(payload: PredictPayload):
|
| 337 |
"""Single prediction"""
|
| 338 |
try:
|
| 339 |
+
res = _predict_texts([payload.inputs],
|
| 340 |
+
include_preprocessing=payload.include_preprocessing,
|
| 341 |
+
use_uncertainty=payload.use_uncertainty)
|
| 342 |
return res[0]
|
| 343 |
except Exception as e:
|
| 344 |
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
| 348 |
def predict_batch(payload: BatchPredictPayload):
|
| 349 |
"""Batch predictions"""
|
| 350 |
try:
|
| 351 |
+
return _predict_texts(payload.inputs,
|
| 352 |
+
include_preprocessing=payload.include_preprocessing,
|
| 353 |
+
use_uncertainty=payload.use_uncertainty)
|
| 354 |
except Exception as e:
|
| 355 |
raise HTTPException(status_code=500, detail=str(e))
|
| 356 |
|
|
|
|
| 361 |
try:
|
| 362 |
texts = [s.text for s in payload.samples]
|
| 363 |
gts = [(_normalize_label(s.label) if s.label is not None else None) for s in payload.samples]
|
| 364 |
+
preds = _predict_texts(texts, include_preprocessing=False, use_uncertainty=True)
|
| 365 |
|
| 366 |
total = len(preds)
|
| 367 |
correct = 0
|