Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -25,8 +25,10 @@ MODEL_ID = "Perth0603/phishing-email-mobilebert"
|
|
| 25 |
|
| 26 |
app = FastAPI(title="Phishing Text Classifier with Preprocessing", version="1.0.0")
|
| 27 |
|
| 28 |
-
#
|
| 29 |
-
|
|
|
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
# ============================================================================
|
|
@@ -140,6 +142,35 @@ def _normalize_label(txt: str) -> str:
|
|
| 140 |
return t
|
| 141 |
|
| 142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
def _load_model():
|
| 144 |
"""Load model, tokenizer, and preprocessor"""
|
| 145 |
global _tokenizer, _model, _device, _preprocessor
|
|
@@ -149,7 +180,8 @@ def _load_model():
|
|
| 149 |
print(f"\n{'='*60}")
|
| 150 |
print(f"Loading model: {MODEL_ID}")
|
| 151 |
print(f"Device: {_device}")
|
| 152 |
-
print(f"
|
|
|
|
| 153 |
print(f"{'='*60}\n")
|
| 154 |
|
| 155 |
_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
|
@@ -171,7 +203,7 @@ def _load_model():
|
|
| 171 |
|
| 172 |
|
| 173 |
def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
|
| 174 |
-
"""Predict with
|
| 175 |
_load_model()
|
| 176 |
if not texts:
|
| 177 |
return []
|
|
@@ -191,31 +223,33 @@ def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List
|
|
| 191 |
)
|
| 192 |
enc = {k: v.to(_device) for k, v in enc.items()}
|
| 193 |
|
| 194 |
-
# Predict
|
| 195 |
with torch.no_grad():
|
| 196 |
logits = _model(**enc).logits
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
|
|
|
| 200 |
|
| 201 |
# Get labels from model config
|
| 202 |
id2label = getattr(_model.config, "id2label", {0: "LEGIT", 1: "PHISH"})
|
| 203 |
|
| 204 |
outputs: List[Dict] = []
|
| 205 |
-
for text_idx in range(
|
| 206 |
-
|
|
|
|
| 207 |
|
| 208 |
-
# Get prediction
|
| 209 |
-
predicted_idx = int(torch.argmax(
|
| 210 |
predicted_label_raw = id2label.get(predicted_idx, f"CLASS_{predicted_idx}")
|
| 211 |
predicted_label_norm = _normalize_label(predicted_label_raw)
|
| 212 |
-
predicted_prob = float(
|
| 213 |
|
| 214 |
# Build probability breakdown
|
| 215 |
prob_breakdown = {}
|
| 216 |
-
for i in range(len(
|
| 217 |
label = _normalize_label(id2label.get(i, f"CLASS_{i}"))
|
| 218 |
-
prob_breakdown[label] = round(float(
|
| 219 |
|
| 220 |
output = {
|
| 221 |
"text": texts[text_idx][:100] + "..." if len(texts[text_idx]) > 100 else texts[text_idx],
|
|
@@ -225,6 +259,7 @@ def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List
|
|
| 225 |
"confidence": round(predicted_prob * 100, 2),
|
| 226 |
"score": round(predicted_prob, 4),
|
| 227 |
"probs": prob_breakdown,
|
|
|
|
| 228 |
}
|
| 229 |
|
| 230 |
if include_preprocessing and preprocessing_info:
|
|
@@ -247,8 +282,11 @@ def root():
|
|
| 247 |
"status": "ok",
|
| 248 |
"model": MODEL_ID,
|
| 249 |
"device": _device,
|
| 250 |
-
"
|
| 251 |
-
|
|
|
|
|
|
|
|
|
|
| 252 |
}
|
| 253 |
|
| 254 |
|
|
@@ -264,7 +302,6 @@ def debug_labels():
|
|
| 264 |
"label2id": getattr(_model.config, "label2id", {}),
|
| 265 |
"num_labels": int(getattr(_model.config, "num_labels", 0)),
|
| 266 |
"device": _device,
|
| 267 |
-
"temperature": TEMPERATURE,
|
| 268 |
}
|
| 269 |
|
| 270 |
|
|
|
|
| 25 |
|
| 26 |
app = FastAPI(title="Phishing Text Classifier with Preprocessing", version="1.0.0")
|
| 27 |
|
| 28 |
+
# Confidence calibration settings
|
| 29 |
+
MIN_CONFIDENCE = 0.70 # Minimum confidence to report (70%)
|
| 30 |
+
MAX_CONFIDENCE = 0.95 # Maximum confidence to report (95%)
|
| 31 |
+
SMOOTHING_FACTOR = 0.15 # How much to smooth (0.1-0.3)
|
| 32 |
|
| 33 |
|
| 34 |
# ============================================================================
|
|
|
|
| 142 |
return t
|
| 143 |
|
| 144 |
|
| 145 |
+
def _calibrate_probabilities(probs: torch.Tensor) -> torch.Tensor:
|
| 146 |
+
"""
|
| 147 |
+
Calibrate overconfident probabilities to more realistic range.
|
| 148 |
+
Uses label smoothing to reduce extreme confidence.
|
| 149 |
+
"""
|
| 150 |
+
num_classes = probs.shape[-1]
|
| 151 |
+
|
| 152 |
+
# Apply label smoothing
|
| 153 |
+
smoothed_probs = probs * (1 - SMOOTHING_FACTOR) + (SMOOTHING_FACTOR / num_classes)
|
| 154 |
+
|
| 155 |
+
# Clip to min/max confidence range
|
| 156 |
+
max_prob, max_idx = torch.max(smoothed_probs, dim=-1, keepdim=True)
|
| 157 |
+
|
| 158 |
+
# Scale to desired range
|
| 159 |
+
if max_prob > MAX_CONFIDENCE:
|
| 160 |
+
scale_factor = MAX_CONFIDENCE / max_prob
|
| 161 |
+
smoothed_probs = smoothed_probs * scale_factor
|
| 162 |
+
|
| 163 |
+
# Ensure minimum confidence for winner
|
| 164 |
+
if max_prob < MIN_CONFIDENCE:
|
| 165 |
+
scale_factor = MIN_CONFIDENCE / max_prob
|
| 166 |
+
smoothed_probs = smoothed_probs * scale_factor
|
| 167 |
+
|
| 168 |
+
# Renormalize to sum to 1
|
| 169 |
+
smoothed_probs = smoothed_probs / smoothed_probs.sum(dim=-1, keepdim=True)
|
| 170 |
+
|
| 171 |
+
return smoothed_probs
|
| 172 |
+
|
| 173 |
+
|
| 174 |
def _load_model():
|
| 175 |
"""Load model, tokenizer, and preprocessor"""
|
| 176 |
global _tokenizer, _model, _device, _preprocessor
|
|
|
|
| 180 |
print(f"\n{'='*60}")
|
| 181 |
print(f"Loading model: {MODEL_ID}")
|
| 182 |
print(f"Device: {_device}")
|
| 183 |
+
print(f"Confidence calibration: {MIN_CONFIDENCE*100:.0f}%-{MAX_CONFIDENCE*100:.0f}%")
|
| 184 |
+
print(f"Smoothing factor: {SMOOTHING_FACTOR}")
|
| 185 |
print(f"{'='*60}\n")
|
| 186 |
|
| 187 |
_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
|
|
|
| 203 |
|
| 204 |
|
| 205 |
def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
|
| 206 |
+
"""Predict with calibrated probabilities"""
|
| 207 |
_load_model()
|
| 208 |
if not texts:
|
| 209 |
return []
|
|
|
|
| 223 |
)
|
| 224 |
enc = {k: v.to(_device) for k, v in enc.items()}
|
| 225 |
|
| 226 |
+
# Predict
|
| 227 |
with torch.no_grad():
|
| 228 |
logits = _model(**enc).logits
|
| 229 |
+
probs = F.softmax(logits, dim=-1)
|
| 230 |
+
|
| 231 |
+
# Apply calibration to reduce overconfidence
|
| 232 |
+
calibrated_probs = _calibrate_probabilities(probs)
|
| 233 |
|
| 234 |
# Get labels from model config
|
| 235 |
id2label = getattr(_model.config, "id2label", {0: "LEGIT", 1: "PHISH"})
|
| 236 |
|
| 237 |
outputs: List[Dict] = []
|
| 238 |
+
for text_idx in range(calibrated_probs.shape[0]):
|
| 239 |
+
p_original = probs[text_idx]
|
| 240 |
+
p_calibrated = calibrated_probs[text_idx]
|
| 241 |
|
| 242 |
+
# Get prediction from calibrated probs
|
| 243 |
+
predicted_idx = int(torch.argmax(p_calibrated).item())
|
| 244 |
predicted_label_raw = id2label.get(predicted_idx, f"CLASS_{predicted_idx}")
|
| 245 |
predicted_label_norm = _normalize_label(predicted_label_raw)
|
| 246 |
+
predicted_prob = float(p_calibrated[predicted_idx].item())
|
| 247 |
|
| 248 |
# Build probability breakdown
|
| 249 |
prob_breakdown = {}
|
| 250 |
+
for i in range(len(p_calibrated)):
|
| 251 |
label = _normalize_label(id2label.get(i, f"CLASS_{i}"))
|
| 252 |
+
prob_breakdown[label] = round(float(p_calibrated[i].item()), 4)
|
| 253 |
|
| 254 |
output = {
|
| 255 |
"text": texts[text_idx][:100] + "..." if len(texts[text_idx]) > 100 else texts[text_idx],
|
|
|
|
| 259 |
"confidence": round(predicted_prob * 100, 2),
|
| 260 |
"score": round(predicted_prob, 4),
|
| 261 |
"probs": prob_breakdown,
|
| 262 |
+
"original_confidence": round(float(p_original[predicted_idx].item()) * 100, 2), # Show original for comparison
|
| 263 |
}
|
| 264 |
|
| 265 |
if include_preprocessing and preprocessing_info:
|
|
|
|
| 282 |
"status": "ok",
|
| 283 |
"model": MODEL_ID,
|
| 284 |
"device": _device,
|
| 285 |
+
"calibration": {
|
| 286 |
+
"min_confidence": MIN_CONFIDENCE,
|
| 287 |
+
"max_confidence": MAX_CONFIDENCE,
|
| 288 |
+
"smoothing_factor": SMOOTHING_FACTOR
|
| 289 |
+
}
|
| 290 |
}
|
| 291 |
|
| 292 |
|
|
|
|
| 302 |
"label2id": getattr(_model.config, "label2id", {}),
|
| 303 |
"num_labels": int(getattr(_model.config, "num_labels", 0)),
|
| 304 |
"device": _device,
|
|
|
|
| 305 |
}
|
| 306 |
|
| 307 |
|