Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -25,9 +25,9 @@ MODEL_ID = "Perth0603/phishing-email-mobilebert"
|
|
| 25 |
|
| 26 |
app = FastAPI(title="Phishing Text Classifier with Preprocessing", version="1.0.0")
|
| 27 |
|
| 28 |
-
#
|
| 29 |
-
|
| 30 |
-
|
| 31 |
|
| 32 |
|
| 33 |
# ============================================================================
|
|
@@ -57,34 +57,122 @@ class TextPreprocessor:
|
|
| 57 |
"""Reduce tokens to lemmas"""
|
| 58 |
return [self.lemmatizer.lemmatize(token) for token in tokens]
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
def sentiment_analysis(self, text: str) -> Dict:
|
| 61 |
-
"""Analyze sentiment
|
| 62 |
blob = TextBlob(text)
|
| 63 |
polarity = blob.sentiment.polarity
|
| 64 |
subjectivity = blob.sentiment.subjectivity
|
| 65 |
|
| 66 |
-
phishing_indicators = {
|
| 67 |
-
"urgent_words": bool(re.search(r'\b(urgent|immediate|act now|verify|confirm|update|click|verify account)\b', text, re.IGNORECASE)),
|
| 68 |
-
"threat_words": bool(re.search(r'\b(suspend|limited|expire|locked|disabled|restricted)\b', text, re.IGNORECASE)),
|
| 69 |
-
"suspicious_urls": bool(re.search(r'http\S+|www\S+', text)),
|
| 70 |
-
"urgency_level": "HIGH" if re.search(r'\b(urgent|immediate|act now)\b', text, re.IGNORECASE) else "LOW"
|
| 71 |
-
}
|
| 72 |
-
|
| 73 |
return {
|
| 74 |
"polarity": round(polarity, 4),
|
| 75 |
"subjectivity": round(subjectivity, 4),
|
| 76 |
"sentiment": "positive" if polarity > 0.1 else "negative" if polarity < -0.1 else "neutral",
|
| 77 |
"is_persuasive": subjectivity > 0.5,
|
| 78 |
-
"phishing_indicators": phishing_indicators
|
| 79 |
}
|
| 80 |
|
| 81 |
def preprocess(self, text: str) -> Dict:
|
| 82 |
-
"""
|
| 83 |
tokens = self.tokenize(text)
|
| 84 |
tokens_no_stop = self.remove_stopwords(tokens)
|
| 85 |
stemmed = self.stem(tokens_no_stop)
|
| 86 |
lemmatized = self.lemmatize(tokens_no_stop)
|
| 87 |
sentiment = self.sentiment_analysis(text)
|
|
|
|
| 88 |
|
| 89 |
return {
|
| 90 |
"original_text": text,
|
|
@@ -93,6 +181,7 @@ class TextPreprocessor:
|
|
| 93 |
"stemmed_tokens": stemmed,
|
| 94 |
"lemmatized_tokens": lemmatized,
|
| 95 |
"sentiment": sentiment,
|
|
|
|
| 96 |
"token_count": len(tokens_no_stop)
|
| 97 |
}
|
| 98 |
|
|
@@ -103,13 +192,11 @@ class TextPreprocessor:
|
|
| 103 |
class PredictPayload(BaseModel):
|
| 104 |
inputs: str
|
| 105 |
include_preprocessing: bool = True
|
| 106 |
-
use_uncertainty: bool = True # Enable uncertainty estimation
|
| 107 |
|
| 108 |
|
| 109 |
class BatchPredictPayload(BaseModel):
|
| 110 |
inputs: List[str]
|
| 111 |
include_preprocessing: bool = True
|
| 112 |
-
use_uncertainty: bool = True
|
| 113 |
|
| 114 |
|
| 115 |
class LabeledText(BaseModel):
|
|
@@ -143,11 +230,47 @@ def _normalize_label(txt: str) -> str:
|
|
| 143 |
return t
|
| 144 |
|
| 145 |
|
| 146 |
-
def
|
| 147 |
-
"""
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
|
| 153 |
def _load_model():
|
|
@@ -159,7 +282,7 @@ def _load_model():
|
|
| 159 |
print(f"\n{'='*60}")
|
| 160 |
print(f"Loading model: {MODEL_ID}")
|
| 161 |
print(f"Device: {_device}")
|
| 162 |
-
print(f"
|
| 163 |
print(f"{'='*60}\n")
|
| 164 |
|
| 165 |
_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
|
@@ -180,49 +303,14 @@ def _load_model():
|
|
| 180 |
print(f"{'='*60}\n")
|
| 181 |
|
| 182 |
|
| 183 |
-
def
|
| 184 |
-
"""
|
| 185 |
-
Predict with uncertainty estimation using MC Dropout.
|
| 186 |
-
Returns: (mean_probs, std_probs)
|
| 187 |
-
"""
|
| 188 |
-
if not use_uncertainty:
|
| 189 |
-
# Standard prediction
|
| 190 |
-
with torch.no_grad():
|
| 191 |
-
logits = _model(**enc).logits
|
| 192 |
-
probs = F.softmax(logits, dim=-1)
|
| 193 |
-
return probs, torch.zeros_like(probs)
|
| 194 |
-
|
| 195 |
-
# Monte Carlo Dropout: multiple forward passes with dropout enabled
|
| 196 |
-
prob_samples = []
|
| 197 |
-
|
| 198 |
-
_enable_dropout(_model) # Enable dropout during inference
|
| 199 |
-
|
| 200 |
-
with torch.no_grad():
|
| 201 |
-
for _ in range(MC_SAMPLES):
|
| 202 |
-
logits = _model(**enc).logits
|
| 203 |
-
probs = F.softmax(logits, dim=-1)
|
| 204 |
-
prob_samples.append(probs)
|
| 205 |
-
|
| 206 |
-
_model.eval() # Restore eval mode
|
| 207 |
-
|
| 208 |
-
# Stack all samples and compute mean and std
|
| 209 |
-
prob_samples = torch.stack(prob_samples) # [MC_SAMPLES, batch_size, num_classes]
|
| 210 |
-
mean_probs = prob_samples.mean(dim=0) # Average predictions
|
| 211 |
-
std_probs = prob_samples.std(dim=0) # Uncertainty (variance)
|
| 212 |
-
|
| 213 |
-
return mean_probs, std_probs
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
def _predict_texts(texts: List[str], include_preprocessing: bool = True, use_uncertainty: bool = True) -> List[Dict]:
|
| 217 |
-
"""Predict with uncertainty-aware confidence scores"""
|
| 218 |
_load_model()
|
| 219 |
if not texts:
|
| 220 |
return []
|
| 221 |
|
| 222 |
-
# Get preprocessing info
|
| 223 |
-
preprocessing_info =
|
| 224 |
-
if include_preprocessing:
|
| 225 |
-
preprocessing_info = [_preprocessor.preprocess(text) for text in texts]
|
| 226 |
|
| 227 |
# Tokenize
|
| 228 |
enc = _tokenizer(
|
|
@@ -234,36 +322,39 @@ def _predict_texts(texts: List[str], include_preprocessing: bool = True, use_unc
|
|
| 234 |
)
|
| 235 |
enc = {k: v.to(_device) for k, v in enc.items()}
|
| 236 |
|
| 237 |
-
# Predict
|
| 238 |
-
|
|
|
|
|
|
|
| 239 |
|
| 240 |
# Get labels from model config
|
| 241 |
id2label = getattr(_model.config, "id2label", {0: "LEGIT", 1: "PHISH"})
|
| 242 |
|
| 243 |
outputs: List[Dict] = []
|
| 244 |
-
for text_idx in range(
|
| 245 |
-
|
| 246 |
-
|
|
|
|
| 247 |
|
| 248 |
# Get prediction
|
| 249 |
-
predicted_idx = int(torch.argmax(
|
| 250 |
predicted_label_raw = id2label.get(predicted_idx, f"CLASS_{predicted_idx}")
|
| 251 |
predicted_label_norm = _normalize_label(predicted_label_raw)
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
adjusted_confidence = max(0.5, predicted_prob - uncertainty_penalty) # Don't go below 50%
|
| 259 |
|
| 260 |
-
# Build probability breakdown
|
| 261 |
prob_breakdown = {}
|
| 262 |
-
|
| 263 |
-
for i in range(len(p_mean)):
|
| 264 |
label = _normalize_label(id2label.get(i, f"CLASS_{i}"))
|
| 265 |
-
|
| 266 |
-
|
|
|
|
|
|
|
| 267 |
|
| 268 |
output = {
|
| 269 |
"text": texts[text_idx][:100] + "..." if len(texts[text_idx]) > 100 else texts[text_idx],
|
|
@@ -272,14 +363,12 @@ def _predict_texts(texts: List[str], include_preprocessing: bool = True, use_unc
|
|
| 272 |
"is_phish": predicted_label_norm == "PHISH",
|
| 273 |
"confidence": round(adjusted_confidence * 100, 2),
|
| 274 |
"score": round(adjusted_confidence, 4),
|
| 275 |
-
"uncertainty": round(predicted_std * 100, 2), # Uncertainty as percentage
|
| 276 |
"probs": prob_breakdown,
|
| 277 |
-
"
|
| 278 |
-
"raw_confidence": round(predicted_prob * 100, 2), # Original model confidence
|
| 279 |
}
|
| 280 |
|
| 281 |
-
if include_preprocessing
|
| 282 |
-
output["preprocessing"] =
|
| 283 |
|
| 284 |
outputs.append(output)
|
| 285 |
|
|
@@ -298,11 +387,8 @@ def root():
|
|
| 298 |
"status": "ok",
|
| 299 |
"model": MODEL_ID,
|
| 300 |
"device": _device,
|
| 301 |
-
"
|
| 302 |
-
|
| 303 |
-
"mc_samples": MC_SAMPLES,
|
| 304 |
-
"dropout_rate": DROPOUT_RATE
|
| 305 |
-
}
|
| 306 |
}
|
| 307 |
|
| 308 |
|
|
@@ -336,9 +422,7 @@ def debug_preprocessing(payload: PredictPayload):
|
|
| 336 |
def predict(payload: PredictPayload):
|
| 337 |
"""Single prediction"""
|
| 338 |
try:
|
| 339 |
-
res = _predict_texts([payload.inputs],
|
| 340 |
-
include_preprocessing=payload.include_preprocessing,
|
| 341 |
-
use_uncertainty=payload.use_uncertainty)
|
| 342 |
return res[0]
|
| 343 |
except Exception as e:
|
| 344 |
raise HTTPException(status_code=500, detail=str(e))
|
|
@@ -348,9 +432,7 @@ def predict(payload: PredictPayload):
|
|
| 348 |
def predict_batch(payload: BatchPredictPayload):
|
| 349 |
"""Batch predictions"""
|
| 350 |
try:
|
| 351 |
-
return _predict_texts(payload.inputs,
|
| 352 |
-
include_preprocessing=payload.include_preprocessing,
|
| 353 |
-
use_uncertainty=payload.use_uncertainty)
|
| 354 |
except Exception as e:
|
| 355 |
raise HTTPException(status_code=500, detail=str(e))
|
| 356 |
|
|
@@ -361,7 +443,7 @@ def evaluate(payload: EvalPayload):
|
|
| 361 |
try:
|
| 362 |
texts = [s.text for s in payload.samples]
|
| 363 |
gts = [(_normalize_label(s.label) if s.label is not None else None) for s in payload.samples]
|
| 364 |
-
preds = _predict_texts(texts, include_preprocessing=False
|
| 365 |
|
| 366 |
total = len(preds)
|
| 367 |
correct = 0
|
|
|
|
| 25 |
|
| 26 |
app = FastAPI(title="Phishing Text Classifier with Preprocessing", version="1.0.0")
|
| 27 |
|
| 28 |
+
# Confidence adjustment settings
|
| 29 |
+
BASE_CONFIDENCE_MIN = 0.55 # Minimum confidence (55%)
|
| 30 |
+
BASE_CONFIDENCE_MAX = 0.85 # Maximum confidence (85%)
|
| 31 |
|
| 32 |
|
| 33 |
# ============================================================================
|
|
|
|
| 57 |
"""Reduce tokens to lemmas"""
|
| 58 |
return [self.lemmatizer.lemmatize(token) for token in tokens]
|
| 59 |
|
| 60 |
+
def analyze_phishing_indicators(self, text: str) -> Dict:
    """Scan ``text`` for common phishing cues and summarise them.

    Each cue is a boolean produced by a regular-expression (or helper)
    check against the raw text.  Three summary fields are then added:

    * ``urgency_level`` - "CRITICAL", "HIGH", "MEDIUM" or "LOW", derived
      from a weighted score of the urgency-related cues.
    * ``indicator_count`` - number of boolean cues that fired.
    * ``indicator_percentage`` - that count as a percentage of all boolean
      cues, rounded to one decimal place.

    Args:
        text: The message text to inspect.

    Returns:
        Dict mapping cue names to booleans, plus the three summary fields.
    """
    indicators = {
        "urgent_words": bool(re.search(
            r'\b(urgent|immediately|immediate|act now|right now|asap|verify now|'
            r'confirm now|update now|click now|respond now|expire soon|expiring|'
            r'time sensitive|limited time|hurry|quick|fast|today only)\b',
            text, re.IGNORECASE
        )),
        "threat_words": bool(re.search(
            r'\b(suspend|suspended|lock|locked|block|blocked|disable|disabled|'
            r'restrict|restricted|terminate|terminated|cancel|cancelled|close|closed|'
            r'freeze|frozen|ban|banned|deactivate|deactivated|remove|removed)\b',
            text, re.IGNORECASE
        )),
        "action_words": bool(re.search(
            r'\b(click here|click now|click below|click this|verify|confirm|update|'
            r'download|install|open attachment|validate|authenticate|reset password|'
            r'change password|provide|submit|enter|fill out|complete)\b',
            text, re.IGNORECASE
        )),
        "financial_words": bool(re.search(
            r'\b(payment|pay|money|credit card|bank account|billing|invoice|refund|'
            r'tax|irs|paypal|transaction|transfer|wire|deposit|account number|'
            r'social security|ssn|card number|cvv|pin)\b',
            text, re.IGNORECASE
        )),
        "authority_impersonation": bool(re.search(
            r'\b(paypal|amazon|microsoft|apple|google|facebook|instagram|netflix|'
            r'ebay|irs|fbi|cia|government|police|bank of america|chase|wells fargo|'
            r'citibank|security team|support team|admin|administrator)\b',
            text, re.IGNORECASE
        )),
        # URL schemes and "www." are case-insensitive, so match them that way.
        "suspicious_urls": bool(re.search(r'http[s]?://|www\.', text, re.IGNORECASE)),
        "suspicious_domain": bool(
            # Keyword-prefixed host names such as "secure-paypal.com".
            re.search(
                r'\b(bit\.ly|tinyurl|goo\.gl|short|link|redirect|verify-|secure-|account-|'
                r'update-|login-|signin-)\w+\.(com|net|org|info|xyz|tk|ml|ga|cf|gq)',
                text, re.IGNORECASE
            )
            # The pattern above demands "\w+.tld" after every alternative, so
            # plain shortener links ("bit.ly/abc") can never satisfy it;
            # check the shortener hosts on their own.
            or re.search(r'\b(?:bit\.ly|tinyurl\.com|goo\.gl)\b', text, re.IGNORECASE)
        ),
        "generic_greeting": bool(re.search(
            r'^(dear (customer|user|member|client|sir|madam)|hello|hi there|greetings)\b',
            text, re.IGNORECASE
        )),
        "poor_grammar": self._detect_poor_grammar(text),
        "excessive_punctuation": bool(re.search(r'[!?]{2,}', text)),
        "all_caps": len(re.findall(r'\b[A-Z]{3,}\b', text)) > 2,
        "currency_symbols": bool(re.search(r'[$£€¥₹]', text)),
    }

    # Count the cues now, before the summary fields are added to the dict.
    active_count = sum(indicators.values())
    total_count = len(indicators)

    # Urgent and threatening wording weigh double; the rest count once.
    urgency_score = sum([
        indicators["urgent_words"] * 2,
        indicators["threat_words"] * 2,
        indicators["action_words"],
        indicators["excessive_punctuation"],
        indicators["all_caps"],
    ])

    if urgency_score >= 4:
        urgency_level = "CRITICAL"
    elif urgency_score >= 2:
        urgency_level = "HIGH"
    elif urgency_score >= 1:
        urgency_level = "MEDIUM"
    else:
        urgency_level = "LOW"

    indicators["urgency_level"] = urgency_level
    indicators["indicator_count"] = active_count
    indicators["indicator_percentage"] = round((active_count / total_count) * 100, 1)

    return indicators
|
| 136 |
+
|
| 137 |
+
def _detect_poor_grammar(self, text: str) -> bool:
    """Heuristically decide whether ``text`` looks carelessly written.

    Three independent checks each contribute at most one "issue":

    1. a run of two or more spaces/tabs,
    2. punctuation immediately followed by a letter (missing space),
    3. a sentence longer than five characters that does not start with an
       uppercase character.

    Args:
        text: The message text to inspect.

    Returns:
        True when at least two of the three checks fire.
    """
    issues = 0

    # Only horizontal whitespace counts: "\s{2,}" would also match "\r\n"
    # line endings and blank lines between paragraphs.
    if re.search(r'[ \t]{2,}', text):
        issues += 1

    # Mask URLs and e-mail addresses so their dots are not mistaken for
    # sentence punctuation in the two checks below.
    prose = re.sub(r'(?:https?://|www\.)\S+|\S+@\S+', ' ', text, flags=re.IGNORECASE)

    # Missing space after punctuation.
    if re.search(r'[.,!?][a-zA-Z]', prose):
        issues += 1

    # Inconsistent capitalization: one offending sentence is enough.
    for sent in re.split(r'[.!?]+', prose):
        sent = sent.strip()
        if len(sent) > 5 and not sent[0].isupper():
            issues += 1
            break

    return issues >= 2
|
| 154 |
+
|
| 155 |
def sentiment_analysis(self, text: str) -> Dict:
    """Score the tone of ``text`` with TextBlob.

    Returns a dict with the polarity and subjectivity rounded to four
    decimals, a "positive" / "negative" / "neutral" label (polarity
    beyond +/-0.1 decides), and an ``is_persuasive`` flag that is set when
    subjectivity exceeds 0.5.
    """
    scores = TextBlob(text).sentiment
    pol = scores.polarity
    subj = scores.subjectivity

    if pol > 0.1:
        tone = "positive"
    elif pol < -0.1:
        tone = "negative"
    else:
        tone = "neutral"

    return {
        "polarity": round(pol, 4),
        "subjectivity": round(subj, 4),
        "sentiment": tone,
        "is_persuasive": subj > 0.5,
    }
|
| 167 |
|
| 168 |
def preprocess(self, text: str) -> Dict:
|
| 169 |
+
"""Full preprocessing pipeline"""
|
| 170 |
tokens = self.tokenize(text)
|
| 171 |
tokens_no_stop = self.remove_stopwords(tokens)
|
| 172 |
stemmed = self.stem(tokens_no_stop)
|
| 173 |
lemmatized = self.lemmatize(tokens_no_stop)
|
| 174 |
sentiment = self.sentiment_analysis(text)
|
| 175 |
+
phishing_indicators = self.analyze_phishing_indicators(text)
|
| 176 |
|
| 177 |
return {
|
| 178 |
"original_text": text,
|
|
|
|
| 181 |
"stemmed_tokens": stemmed,
|
| 182 |
"lemmatized_tokens": lemmatized,
|
| 183 |
"sentiment": sentiment,
|
| 184 |
+
"phishing_indicators": phishing_indicators,
|
| 185 |
"token_count": len(tokens_no_stop)
|
| 186 |
}
|
| 187 |
|
|
|
|
| 192 |
class PredictPayload(BaseModel):
|
| 193 |
inputs: str
|
| 194 |
include_preprocessing: bool = True
|
|
|
|
| 195 |
|
| 196 |
|
| 197 |
class BatchPredictPayload(BaseModel):
|
| 198 |
inputs: List[str]
|
| 199 |
include_preprocessing: bool = True
|
|
|
|
| 200 |
|
| 201 |
|
| 202 |
class LabeledText(BaseModel):
|
|
|
|
| 230 |
return t
|
| 231 |
|
| 232 |
|
| 233 |
+
def _adjust_confidence_with_indicators(base_prob: float, indicators: Dict, predicted_label: str) -> float:
    """Derive a reported confidence from the share of phishing indicators.

    The result depends only on ``indicators["indicator_percentage"]`` and on
    whether the prediction is "PHISH"; any other label takes the
    legitimate branch.  Agreement between label and indicators raises the
    value, disagreement lowers it, and the result is clamped to
    ``[BASE_CONFIDENCE_MIN, BASE_CONFIDENCE_MAX]``.

    NOTE(review): ``base_prob`` is accepted for interface compatibility but
    does not influence the result - the model's own probability is
    discarded here.  Confirm that this is intended.

    Args:
        base_prob: Raw model probability of the predicted class (unused).
        indicators: Indicator dict; a missing ``indicator_percentage`` is
            treated as 0.
        predicted_label: Normalised predicted label.

    Returns:
        The adjusted confidence as a fraction.
    """
    indicator_percentage = indicators.get("indicator_percentage", 0)

    if predicted_label == "PHISH":
        if indicator_percentage >= 40:
            # Strong support for the phishing call: 0.79 and up (clamped later).
            adjusted = 0.75 + (indicator_percentage / 100) * 0.10
        elif indicator_percentage >= 25:
            # Moderate support: roughly 0.675-0.69.
            adjusted = 0.65 + (indicator_percentage / 100) * 0.10
        else:
            # Little support: roughly 0.55-0.575.
            adjusted = 0.55 + (indicator_percentage / 100) * 0.10
    else:
        if indicator_percentage >= 40:
            # Many phishing cues contradict a legit call: 0.61 downwards.
            adjusted = 0.65 - (indicator_percentage / 100) * 0.10
        elif indicator_percentage >= 25:
            # Some cues present: roughly 0.68-0.6875.
            adjusted = 0.70 - (indicator_percentage / 100) * 0.05
        else:
            # Few cues: 0.825 and up (clamped later).
            adjusted = 0.75 + ((100 - indicator_percentage) / 100) * 0.10

    # Clamp to the configured reporting range.
    adjusted = max(BASE_CONFIDENCE_MIN, min(BASE_CONFIDENCE_MAX, adjusted))

    return adjusted
|
| 274 |
|
| 275 |
|
| 276 |
def _load_model():
|
|
|
|
| 282 |
print(f"\n{'='*60}")
|
| 283 |
print(f"Loading model: {MODEL_ID}")
|
| 284 |
print(f"Device: {_device}")
|
| 285 |
+
print(f"Confidence range: {BASE_CONFIDENCE_MIN*100:.0f}%-{BASE_CONFIDENCE_MAX*100:.0f}%")
|
| 286 |
print(f"{'='*60}\n")
|
| 287 |
|
| 288 |
_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
|
|
|
| 303 |
print(f"{'='*60}\n")
|
| 304 |
|
| 305 |
|
| 306 |
+
def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
|
| 307 |
+
"""Predict with indicator-based confidence adjustment"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
_load_model()
|
| 309 |
if not texts:
|
| 310 |
return []
|
| 311 |
|
| 312 |
+
# Get preprocessing info (always needed for indicators)
|
| 313 |
+
preprocessing_info = [_preprocessor.preprocess(text) for text in texts]
|
|
|
|
|
|
|
| 314 |
|
| 315 |
# Tokenize
|
| 316 |
enc = _tokenizer(
|
|
|
|
| 322 |
)
|
| 323 |
enc = {k: v.to(_device) for k, v in enc.items()}
|
| 324 |
|
| 325 |
+
# Predict
|
| 326 |
+
with torch.no_grad():
|
| 327 |
+
logits = _model(**enc).logits
|
| 328 |
+
probs = F.softmax(logits, dim=-1)
|
| 329 |
|
| 330 |
# Get labels from model config
|
| 331 |
id2label = getattr(_model.config, "id2label", {0: "LEGIT", 1: "PHISH"})
|
| 332 |
|
| 333 |
outputs: List[Dict] = []
|
| 334 |
+
for text_idx in range(probs.shape[0]):
|
| 335 |
+
p = probs[text_idx]
|
| 336 |
+
preprocessing = preprocessing_info[text_idx]
|
| 337 |
+
indicators = preprocessing["phishing_indicators"]
|
| 338 |
|
| 339 |
# Get prediction
|
| 340 |
+
predicted_idx = int(torch.argmax(p).item())
|
| 341 |
predicted_label_raw = id2label.get(predicted_idx, f"CLASS_{predicted_idx}")
|
| 342 |
predicted_label_norm = _normalize_label(predicted_label_raw)
|
| 343 |
+
raw_prob = float(p[predicted_idx].item())
|
| 344 |
+
|
| 345 |
+
# Adjust confidence based on indicators
|
| 346 |
+
adjusted_confidence = _adjust_confidence_with_indicators(
|
| 347 |
+
raw_prob, indicators, predicted_label_norm
|
| 348 |
+
)
|
|
|
|
| 349 |
|
| 350 |
+
# Build probability breakdown (adjusted)
|
| 351 |
prob_breakdown = {}
|
| 352 |
+
for i in range(len(p)):
|
|
|
|
| 353 |
label = _normalize_label(id2label.get(i, f"CLASS_{i}"))
|
| 354 |
+
if i == predicted_idx:
|
| 355 |
+
prob_breakdown[label] = round(adjusted_confidence, 4)
|
| 356 |
+
else:
|
| 357 |
+
prob_breakdown[label] = round(1.0 - adjusted_confidence, 4)
|
| 358 |
|
| 359 |
output = {
|
| 360 |
"text": texts[text_idx][:100] + "..." if len(texts[text_idx]) > 100 else texts[text_idx],
|
|
|
|
| 363 |
"is_phish": predicted_label_norm == "PHISH",
|
| 364 |
"confidence": round(adjusted_confidence * 100, 2),
|
| 365 |
"score": round(adjusted_confidence, 4),
|
|
|
|
| 366 |
"probs": prob_breakdown,
|
| 367 |
+
"model_raw_confidence": round(raw_prob * 100, 2),
|
|
|
|
| 368 |
}
|
| 369 |
|
| 370 |
+
if include_preprocessing:
|
| 371 |
+
output["preprocessing"] = preprocessing
|
| 372 |
|
| 373 |
outputs.append(output)
|
| 374 |
|
|
|
|
| 387 |
"status": "ok",
|
| 388 |
"model": MODEL_ID,
|
| 389 |
"device": _device,
|
| 390 |
+
"confidence_range": f"{BASE_CONFIDENCE_MIN*100:.0f}%-{BASE_CONFIDENCE_MAX*100:.0f}%",
|
| 391 |
+
"note": "Confidence adjusted based on phishing indicators"
|
|
|
|
|
|
|
|
|
|
| 392 |
}
|
| 393 |
|
| 394 |
|
|
|
|
| 422 |
def predict(payload: PredictPayload):
|
| 423 |
"""Single prediction"""
|
| 424 |
try:
|
| 425 |
+
res = _predict_texts([payload.inputs], include_preprocessing=payload.include_preprocessing)
|
|
|
|
|
|
|
| 426 |
return res[0]
|
| 427 |
except Exception as e:
|
| 428 |
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
| 432 |
def predict_batch(payload: BatchPredictPayload):
|
| 433 |
"""Batch predictions"""
|
| 434 |
try:
|
| 435 |
+
return _predict_texts(payload.inputs, include_preprocessing=payload.include_preprocessing)
|
|
|
|
|
|
|
| 436 |
except Exception as e:
|
| 437 |
raise HTTPException(status_code=500, detail=str(e))
|
| 438 |
|
|
|
|
| 443 |
try:
|
| 444 |
texts = [s.text for s in payload.samples]
|
| 445 |
gts = [(_normalize_label(s.label) if s.label is not None else None) for s in payload.samples]
|
| 446 |
+
preds = _predict_texts(texts, include_preprocessing=False)
|
| 447 |
|
| 448 |
total = len(preds)
|
| 449 |
correct = 0
|