# app.py — phishing-email classifier API
# Source: Hugging Face Space by Perth0603 ("Update app.py", commit 8cfc19f, verified)
import os
from typing import List, Optional, Dict
import re
import torch
import torch.nn.functional as F
import nltk
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer, WordNetLemmatizer
from nltk.tokenize import word_tokenize
from textblob import TextBlob
# Download NLTK data required by TextPreprocessor.
# BUG FIX: the original only probed 'tokenizers/punkt' and downloaded the
# other corpora inside that except-branch, so if punkt was already present
# but 'stopwords'/'wordnet' were missing they were never fetched and
# TextPreprocessor.__init__ / lemmatization would raise LookupError.
# Probe each resource individually instead.
for _resource, _path in (
    ("punkt", "tokenizers/punkt"),
    ("stopwords", "corpora/stopwords"),
    ("wordnet", "corpora/wordnet"),
):
    try:
        nltk.data.find(_path)
    except LookupError:
        nltk.download(_resource)
# Hugging Face Hub model checkpoint used for classification.
MODEL_ID = "Perth0603/phishing-email-mobilebert"
app = FastAPI(title="Phishing Text Classifier with Preprocessing", version="1.0.0")
# Confidence adjustment settings: the confidence reported to clients is
# derived from the indicator analysis and clamped into this band,
# regardless of the model's raw softmax probability.
BASE_CONFIDENCE_MIN = 0.55 # Minimum confidence (55%)
BASE_CONFIDENCE_MAX = 0.85 # Maximum confidence (85%)
# ============================================================================
# TEXT PREPROCESSING CLASS
# ============================================================================
class TextPreprocessor:
    """NLP preprocessing for analysis and feature extraction.

    Combines NLTK tokenization / stemming / lemmatization, TextBlob
    sentiment, and regex-based phishing-indicator heuristics into a single
    pipeline; ``preprocess`` is the main entry point.
    """

    def __init__(self):
        # Requires the 'stopwords' and 'wordnet' NLTK corpora (downloaded
        # at module import time) and 'punkt' for word_tokenize.
        self.stemmer = PorterStemmer()
        self.lemmatizer = WordNetLemmatizer()
        self.stop_words = set(stopwords.words('english'))

    def tokenize(self, text: str) -> List[str]:
        """Break text into lower-cased word tokens."""
        return word_tokenize(text.lower())

    def remove_stopwords(self, tokens: List[str]) -> List[str]:
        """Remove common stop words; also drops non-alphanumeric tokens."""
        return [token for token in tokens if token.isalnum() and token not in self.stop_words]

    def stem(self, tokens: List[str]) -> List[str]:
        """Reduce tokens to stems (Porter stemmer)."""
        return [self.stemmer.stem(token) for token in tokens]

    def lemmatize(self, tokens: List[str]) -> List[str]:
        """Reduce tokens to lemmas (WordNet lemmatizer, default noun POS)."""
        return [self.lemmatizer.lemmatize(token) for token in tokens]

    def analyze_phishing_indicators(self, text: str) -> Dict:
        """Comprehensive phishing indicator analysis.

        Returns a dict of boolean indicators plus three derived fields:
        ``urgency_level`` (LOW/MEDIUM/HIGH/CRITICAL), ``indicator_count``
        (number of True indicators) and ``indicator_percentage``.
        """
        indicators = {
            # Urgency / time-pressure vocabulary.
            "urgent_words": bool(re.search(
                r'\b(urgent|immediately|immediate|act now|right now|asap|verify now|'
                r'confirm now|update now|click now|respond now|expire soon|expiring|'
                r'time sensitive|limited time|hurry|quick|fast|today only)\b',
                text, re.IGNORECASE
            )),
            # Threats of account suspension / termination.
            "threat_words": bool(re.search(
                r'\b(suspend|suspended|lock|locked|block|blocked|disable|disabled|'
                r'restrict|restricted|terminate|terminated|cancel|cancelled|close|closed|'
                r'freeze|frozen|ban|banned|deactivate|deactivated|remove|removed)\b',
                text, re.IGNORECASE
            )),
            # Calls to action (click/verify/download/etc.).
            "action_words": bool(re.search(
                r'\b(click here|click now|click below|click this|verify|confirm|update|'
                r'download|install|open attachment|validate|authenticate|reset password|'
                r'change password|provide|submit|enter|fill out|complete)\b',
                text, re.IGNORECASE
            )),
            # Financial / credential-harvesting vocabulary.
            "financial_words": bool(re.search(
                r'\b(payment|pay|money|credit card|bank account|billing|invoice|refund|'
                r'tax|irs|paypal|transaction|transfer|wire|deposit|account number|'
                r'social security|ssn|card number|cvv|pin)\b',
                text, re.IGNORECASE
            )),
            # Mentions of commonly-impersonated brands and authorities.
            "authority_impersonation": bool(re.search(
                r'\b(paypal|amazon|microsoft|apple|google|facebook|instagram|netflix|'
                r'ebay|irs|fbi|cia|government|police|bank of america|chase|wells fargo|'
                r'citibank|security team|support team|admin|administrator)\b',
                text, re.IGNORECASE
            )),
            # Any URL at all counts as "suspicious" for this heuristic.
            "suspicious_urls": bool(re.search(r'http[s]?://|www\.', text)),
            # URL shorteners and lookalike prefixes on throwaway TLDs.
            "suspicious_domain": bool(re.search(
                r'\b(bit\.ly|tinyurl|goo\.gl|short|link|redirect|verify-|secure-|account-|'
                r'update-|login-|signin-)\w+\.(com|net|org|info|xyz|tk|ml|ga|cf|gq)',
                text, re.IGNORECASE
            )),
            # Impersonal salutation; anchored to the start of the text
            # (no re.MULTILINE, so only the very beginning is checked).
            "generic_greeting": bool(re.search(
                r'^(dear (customer|user|member|client|sir|madam)|hello|hi there|greetings)\b',
                text, re.IGNORECASE
            )),
            "poor_grammar": self._detect_poor_grammar(text),
            "excessive_punctuation": bool(re.search(r'[!?]{2,}', text)),
            # More than two all-caps words of 3+ letters.
            "all_caps": len(re.findall(r'\b[A-Z]{3,}\b', text)) > 2,
            "currency_symbols": bool(re.search(r'[$£€¥₹]', text)),
        }
        # Count active indicators.
        # NOTE: this must run while every value is still a boolean — the
        # string/int summary fields are appended below.
        active_count = sum(indicators.values())
        total_count = len(indicators)
        # Determine urgency level: urgency/threat words are weighted double.
        urgency_score = sum([
            indicators["urgent_words"] * 2,
            indicators["threat_words"] * 2,
            indicators["action_words"],
            indicators["excessive_punctuation"],
            indicators["all_caps"]
        ])
        if urgency_score >= 4:
            urgency_level = "CRITICAL"
        elif urgency_score >= 2:
            urgency_level = "HIGH"
        elif urgency_score >= 1:
            urgency_level = "MEDIUM"
        else:
            urgency_level = "LOW"
        indicators["urgency_level"] = urgency_level
        indicators["indicator_count"] = active_count
        indicators["indicator_percentage"] = round((active_count / total_count) * 100, 1)
        return indicators

    def _detect_poor_grammar(self, text: str) -> bool:
        """Simple heuristic for poor grammar: flags when at least two of
        three cheap signals are present (double spaces, missing space after
        punctuation, a lower-cased sentence start)."""
        issues = 0
        # Multiple spaces
        if re.search(r'\s{2,}', text):
            issues += 1
        # Missing spaces after punctuation
        if re.search(r'[.,!?][a-zA-Z]', text):
            issues += 1
        # Inconsistent capitalization: one offending sentence is enough.
        sentences = re.split(r'[.!?]+', text)
        for sent in sentences:
            sent = sent.strip()
            if sent and len(sent) > 5 and not sent[0].isupper():
                issues += 1
                break
        return issues >= 2

    def sentiment_analysis(self, text: str) -> Dict:
        """Analyze sentiment with TextBlob (polarity in [-1, 1],
        subjectivity in [0, 1])."""
        blob = TextBlob(text)
        polarity = blob.sentiment.polarity
        subjectivity = blob.sentiment.subjectivity
        return {
            "polarity": round(polarity, 4),
            "subjectivity": round(subjectivity, 4),
            "sentiment": "positive" if polarity > 0.1 else "negative" if polarity < -0.1 else "neutral",
            # Highly subjective text is treated as persuasive/manipulative.
            "is_persuasive": subjectivity > 0.5,
        }

    def preprocess(self, text: str) -> Dict:
        """Full preprocessing pipeline: tokens, stems, lemmas, sentiment,
        and the phishing-indicator analysis, bundled into one dict."""
        tokens = self.tokenize(text)
        tokens_no_stop = self.remove_stopwords(tokens)
        stemmed = self.stem(tokens_no_stop)
        lemmatized = self.lemmatize(tokens_no_stop)
        sentiment = self.sentiment_analysis(text)
        phishing_indicators = self.analyze_phishing_indicators(text)
        return {
            "original_text": text,
            "tokens": tokens,
            "tokens_without_stopwords": tokens_no_stop,
            "stemmed_tokens": stemmed,
            "lemmatized_tokens": lemmatized,
            "sentiment": sentiment,
            "phishing_indicators": phishing_indicators,
            # token_count deliberately counts post-stopword tokens.
            "token_count": len(tokens_no_stop)
        }
# ============================================================================
# PYDANTIC MODELS
# ============================================================================
class PredictPayload(BaseModel):
    """Request body for /predict and /debug/preprocessing."""
    inputs: str  # the text to classify
    include_preprocessing: bool = True  # attach full NLP breakdown to the response
class BatchPredictPayload(BaseModel):
    """Request body for /predict-batch."""
    inputs: List[str]  # texts to classify in one forward pass
    include_preprocessing: bool = True  # attach full NLP breakdown per text
class LabeledText(BaseModel):
    """One evaluation sample: text plus an optional ground-truth label."""
    text: str
    label: Optional[str] = None  # e.g. "phishing"/"legit"; normalized server-side
class EvalPayload(BaseModel):
    """Request body for /evaluate."""
    samples: List[LabeledText]
# ============================================================================
# GLOBAL VARIABLES
# ============================================================================
# Lazily-initialized singletons, populated by _load_model() on first use.
_tokenizer = None  # HF tokenizer for MODEL_ID
_model = None  # HF sequence-classification model
_device = "cpu"  # switched to "cuda" in _load_model when available
_preprocessor = None  # shared TextPreprocessor instance
# ============================================================================
# HELPER FUNCTIONS
# ============================================================================
def _normalize_label(txt: str) -> str:
"""Normalize label text"""
t = (str(txt) if txt is not None else "").strip().upper()
if t in ("PHISHING", "PHISH", "SPAM", "1"):
return "PHISH"
if t in ("LEGIT", "LEGITIMATE", "SAFE", "HAM", "0"):
return "LEGIT"
return t
def _adjust_confidence_with_indicators(base_prob: float, indicators: Dict, predicted_label: str) -> float:
    """Derive the reported confidence from the phishing-indicator analysis.

    Note: ``base_prob`` (the model's raw softmax probability) is accepted
    for interface compatibility but does not enter the calculation — the
    result depends only on ``indicator_percentage`` and the predicted
    label, and is clamped to [BASE_CONFIDENCE_MIN, BASE_CONFIDENCE_MAX].

    Rationale: many indicators + PHISH verdict -> high confidence; many
    indicators + LEGIT verdict -> low confidence (the signals disagree
    with the model); few indicators mirror that logic in reverse.
    """
    pct = indicators.get("indicator_percentage", 0)

    if predicted_label == "PHISH":
        if pct >= 40:
            # Strong indicator coverage: 75-85%.
            raw = 0.75 + (pct / 100) * 0.10
        elif pct >= 25:
            # Moderate coverage: 65-75%.
            raw = 0.65 + (pct / 100) * 0.10
        else:
            # Weak coverage: 55-65%.
            raw = 0.55 + (pct / 100) * 0.10
    elif pct >= 40:
        # Many phishing signals yet a LEGIT verdict: 55-65% (uncertain).
        raw = 0.65 - (pct / 100) * 0.10
    elif pct >= 25:
        # Some signals: 65-75%.
        raw = 0.70 - (pct / 100) * 0.05
    else:
        # Clean text predicted LEGIT: 75-85%.
        raw = 0.75 + ((100 - pct) / 100) * 0.10

    # Clamp into the advertised confidence band.
    return max(BASE_CONFIDENCE_MIN, min(BASE_CONFIDENCE_MAX, raw))
def _load_model():
    """Lazily initialize tokenizer, model, device, and preprocessor.

    Idempotent: returns immediately once both tokenizer and model exist.
    Runs a one-batch warm-up forward pass so the first real request does
    not pay the graph/cache initialization cost.
    """
    global _tokenizer, _model, _device, _preprocessor
    if _tokenizer is not None and _model is not None:
        return

    _device = "cuda" if torch.cuda.is_available() else "cpu"
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"Loading model: {MODEL_ID}")
    print(f"Device: {_device}")
    print(f"Confidence range: {BASE_CONFIDENCE_MIN*100:.0f}%-{BASE_CONFIDENCE_MAX*100:.0f}%")
    print(f"{banner}\n")

    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
    _model.to(_device)
    _model.eval()
    _preprocessor = TextPreprocessor()

    # Warm-up: a single tiny batch through the full tokenize+forward path.
    warm_batch = _tokenizer(
        ["warm up"], return_tensors="pt", padding=True, truncation=True, max_length=512
    ).to(_device)
    with torch.no_grad():
        _model(**warm_batch)

    id2label = getattr(_model.config, "id2label", {})
    print(f"Model labels: {id2label}")
    print(f"{banner}\n")
def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
    """Classify a batch of texts, adjusting confidence via indicators.

    Returns one dict per input text with the normalized label, the
    indicator-adjusted confidence/score/probs, the model's raw confidence,
    and (when requested) the full preprocessing breakdown. Empty input
    yields an empty list.
    """
    _load_model()
    if not texts:
        return []

    # Preprocessing is always computed: the indicator analysis feeds the
    # confidence adjustment even when it is not returned to the caller.
    analyses = [_preprocessor.preprocess(t) for t in texts]

    # Batch-tokenize and move tensors to the active device.
    encoded = _tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    encoded = {key: tensor.to(_device) for key, tensor in encoded.items()}

    # Single forward pass for the whole batch.
    with torch.no_grad():
        probs = F.softmax(_model(**encoded).logits, dim=-1)

    # Fall back to the conventional binary mapping if the config lacks one.
    id2label = getattr(_model.config, "id2label", {0: "LEGIT", 1: "PHISH"})

    results: List[Dict] = []
    for text, analysis, row in zip(texts, analyses, probs):
        indicators = analysis["phishing_indicators"]

        top_idx = int(torch.argmax(row).item())
        raw_label = id2label.get(top_idx, f"CLASS_{top_idx}")
        label = _normalize_label(raw_label)
        model_prob = float(row[top_idx].item())

        # Replace the raw softmax confidence with the indicator-based one.
        confidence = _adjust_confidence_with_indicators(model_prob, indicators, label)

        # Rebuild the per-class probabilities around the adjusted value.
        breakdown = {}
        for cls_idx in range(len(row)):
            cls_label = _normalize_label(id2label.get(cls_idx, f"CLASS_{cls_idx}"))
            breakdown[cls_label] = round(
                confidence if cls_idx == top_idx else 1.0 - confidence, 4
            )

        entry = {
            # Preview only: long inputs are truncated in the echo field.
            "text": text[:100] + "..." if len(text) > 100 else text,
            "label": label,
            "raw_label": raw_label,
            "is_phish": label == "PHISH",
            "confidence": round(confidence * 100, 2),
            "score": round(confidence, 4),
            "probs": breakdown,
            "model_raw_confidence": round(model_prob * 100, 2),
        }
        if include_preprocessing:
            entry["preprocessing"] = analysis
        results.append(entry)

    return results
# ============================================================================
# API ENDPOINTS
# ============================================================================
@app.get("/")
def root():
    """Health/info endpoint; also triggers the lazy model load."""
    _load_model()
    return {
        "status": "ok",
        "model": MODEL_ID,
        "device": _device,
        "confidence_range": f"{BASE_CONFIDENCE_MIN*100:.0f}%-{BASE_CONFIDENCE_MAX*100:.0f}%",
        "note": "Confidence adjusted based on phishing indicators"
    }
@app.get("/debug/labels")
def debug_labels():
    """View model configuration (label mappings, label count, device)."""
    _load_model()
    return {
        "status": "ok",
        "model_id": MODEL_ID,
        "id2label": getattr(_model.config, "id2label", {}),
        "label2id": getattr(_model.config, "label2id", {}),
        "num_labels": int(getattr(_model.config, "num_labels", 0)),
        "device": _device,
    }
@app.post("/debug/preprocessing")
def debug_preprocessing(payload: PredictPayload):
    """Run only the preprocessing pipeline on the input (no model call).

    Raises HTTP 500 with the underlying error message on any failure.
    """
    try:
        _load_model()
        preprocessing = _preprocessor.preprocess(payload.inputs)
        return preprocessing
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/predict")
def predict(payload: PredictPayload):
    """Classify a single text; returns the first (only) batch result.

    Raises HTTP 500 with the underlying error message on any failure.
    """
    try:
        res = _predict_texts([payload.inputs], include_preprocessing=payload.include_preprocessing)
        return res[0]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/predict-batch")
def predict_batch(payload: BatchPredictPayload):
    """Classify a list of texts in one forward pass.

    Raises HTTP 500 with the underlying error message on any failure.
    """
    try:
        return _predict_texts(payload.inputs, include_preprocessing=payload.include_preprocessing)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/evaluate")
def evaluate(payload: EvalPayload):
    """Evaluate the classifier on labeled samples.

    Accuracy is computed over labeled samples only and is None when no
    sample carries a ground-truth label. ``per_class`` maps each ground
    truth label to {"tp": correct predictions, "count": occurrences}.

    Raises HTTP 500 with the underlying error message on any failure.
    """
    try:
        texts = [s.text for s in payload.samples]
        gts = [(_normalize_label(s.label) if s.label is not None else None) for s in payload.samples]
        preds = _predict_texts(texts, include_preprocessing=False)

        total = len(preds)
        correct = 0
        per_class: Dict[str, Dict[str, int]] = {}
        for gt, pr in zip(gts, preds):
            if gt is None:
                continue  # unlabeled samples do not affect the metrics
            pred_label = pr["label"]
            correct += int(gt == pred_label)
            stats = per_class.setdefault(gt, {"tp": 0, "count": 0})
            stats["count"] += 1
            if gt == pred_label:
                stats["tp"] += 1

        labeled = sum(1 for gt in gts if gt is not None)
        acc = (correct / labeled) if labeled else None
        return {
            # BUG FIX: the original used `if acc`, which reported None for a
            # legitimate accuracy of exactly 0.0; test against None instead.
            "accuracy": round(acc, 4) if acc is not None else None,
            "total": total,
            "correct": correct,
            "predictions": preds,
            "per_class": per_class,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Local entry point: serve on all interfaces, port 8000.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)