# Hugging Face Spaces page residue ("Spaces: Sleeping" status banner) captured
# during export; kept as a comment so the module remains valid Python.
| import os | |
| from typing import List, Optional, Dict | |
| import re | |
| import torch | |
| import torch.nn.functional as F | |
| import nltk | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from nltk.corpus import stopwords | |
| from nltk.stem import PorterStemmer, WordNetLemmatizer | |
| from nltk.tokenize import word_tokenize | |
| from textblob import TextBlob | |
# --- One-time NLTK resource setup ---
# Probe each corpus independently: the previous code only checked for
# 'punkt', so 'stopwords' and 'wordnet' were never downloaded on machines
# where punkt was already present.
for _resource, _path in (
    ("punkt", "tokenizers/punkt"),
    ("stopwords", "corpora/stopwords"),
    ("wordnet", "corpora/wordnet"),
):
    try:
        nltk.data.find(_path)
    except LookupError:
        nltk.download(_resource)

# Hugging Face model repository used for classification.
MODEL_ID = "Perth0603/phishing-email-mobilebert"

app = FastAPI(title="Phishing Text Classifier with Preprocessing", version="1.0.0")

# Confidence adjustment settings
BASE_CONFIDENCE_MIN = 0.55  # Minimum confidence (55%)
BASE_CONFIDENCE_MAX = 0.85  # Maximum confidence (85%)
| # ============================================================================ | |
| # TEXT PREPROCESSING CLASS | |
| # ============================================================================ | |
class TextPreprocessor:
    """NLP preprocessing for analysis and feature extraction.

    Wraps NLTK tokenization/stemming/lemmatization, TextBlob sentiment,
    and a set of regex heuristics that flag phishing indicators.
    """

    def __init__(self):
        # NLTK components shared by the pipeline methods below.
        self.stemmer = PorterStemmer()
        self.lemmatizer = WordNetLemmatizer()
        self.stop_words = set(stopwords.words('english'))

    def tokenize(self, text: str) -> List[str]:
        """Lowercase the text and split it into word tokens."""
        return word_tokenize(text.lower())

    def remove_stopwords(self, tokens: List[str]) -> List[str]:
        """Keep only alphanumeric tokens that are not English stop words."""
        kept: List[str] = []
        for tok in tokens:
            if tok.isalnum() and tok not in self.stop_words:
                kept.append(tok)
        return kept

    def stem(self, tokens: List[str]) -> List[str]:
        """Reduce each token to its Porter stem."""
        stem_one = self.stemmer.stem
        return [stem_one(tok) for tok in tokens]

    def lemmatize(self, tokens: List[str]) -> List[str]:
        """Reduce each token to its WordNet lemma."""
        lemma_one = self.lemmatizer.lemmatize
        return [lemma_one(tok) for tok in tokens]

    def analyze_phishing_indicators(self, text: str) -> Dict:
        """Run every regex heuristic and summarize the results.

        Returns a dict of boolean indicators plus three derived fields:
        urgency_level, indicator_count, and indicator_percentage.
        """

        def found(pattern: str, flags: int = re.IGNORECASE) -> bool:
            # Small closure so each indicator reads as a single expression.
            return bool(re.search(pattern, text, flags))

        indicators = {
            "urgent_words": found(
                r'\b(urgent|immediately|immediate|act now|right now|asap|verify now|'
                r'confirm now|update now|click now|respond now|expire soon|expiring|'
                r'time sensitive|limited time|hurry|quick|fast|today only)\b'
            ),
            "threat_words": found(
                r'\b(suspend|suspended|lock|locked|block|blocked|disable|disabled|'
                r'restrict|restricted|terminate|terminated|cancel|cancelled|close|closed|'
                r'freeze|frozen|ban|banned|deactivate|deactivated|remove|removed)\b'
            ),
            "action_words": found(
                r'\b(click here|click now|click below|click this|verify|confirm|update|'
                r'download|install|open attachment|validate|authenticate|reset password|'
                r'change password|provide|submit|enter|fill out|complete)\b'
            ),
            "financial_words": found(
                r'\b(payment|pay|money|credit card|bank account|billing|invoice|refund|'
                r'tax|irs|paypal|transaction|transfer|wire|deposit|account number|'
                r'social security|ssn|card number|cvv|pin)\b'
            ),
            "authority_impersonation": found(
                r'\b(paypal|amazon|microsoft|apple|google|facebook|instagram|netflix|'
                r'ebay|irs|fbi|cia|government|police|bank of america|chase|wells fargo|'
                r'citibank|security team|support team|admin|administrator)\b'
            ),
            # Case-sensitive searches below use flags=0, matching the
            # original behavior exactly.
            "suspicious_urls": found(r'http[s]?://|www\.', flags=0),
            "suspicious_domain": found(
                r'\b(bit\.ly|tinyurl|goo\.gl|short|link|redirect|verify-|secure-|account-|'
                r'update-|login-|signin-)\w+\.(com|net|org|info|xyz|tk|ml|ga|cf|gq)'
            ),
            "generic_greeting": found(
                r'^(dear (customer|user|member|client|sir|madam)|hello|hi there|greetings)\b'
            ),
            "poor_grammar": self._detect_poor_grammar(text),
            "excessive_punctuation": found(r'[!?]{2,}', flags=0),
            "all_caps": len(re.findall(r'\b[A-Z]{3,}\b', text)) > 2,
            "currency_symbols": found(r'[$£€¥₹]', flags=0),
        }

        # Summaries are computed before the derived keys are inserted.
        active = sum(indicators.values())
        total = len(indicators)

        # Urgency score weights urgency/threat language double.
        urgency = (
            2 * indicators["urgent_words"]
            + 2 * indicators["threat_words"]
            + indicators["action_words"]
            + indicators["excessive_punctuation"]
            + indicators["all_caps"]
        )
        if urgency >= 4:
            level = "CRITICAL"
        elif urgency >= 2:
            level = "HIGH"
        elif urgency >= 1:
            level = "MEDIUM"
        else:
            level = "LOW"

        indicators["urgency_level"] = level
        indicators["indicator_count"] = active
        indicators["indicator_percentage"] = round((active / total) * 100, 1)
        return indicators

    def _detect_poor_grammar(self, text: str) -> bool:
        """Heuristic: two or more formatting problems count as poor grammar."""
        problems = 0
        if re.search(r'\s{2,}', text):  # runs of whitespace
            problems += 1
        if re.search(r'[.,!?][a-zA-Z]', text):  # missing space after punctuation
            problems += 1
        # A single lowercase-starting sentence adds at most one problem.
        for sentence in re.split(r'[.!?]+', text):
            sentence = sentence.strip()
            if len(sentence) > 5 and not sentence[0].isupper():
                problems += 1
                break
        return problems >= 2

    def sentiment_analysis(self, text: str) -> Dict:
        """Compute TextBlob polarity/subjectivity and a coarse sentiment label."""
        blob_sentiment = TextBlob(text).sentiment
        polarity = blob_sentiment.polarity
        subjectivity = blob_sentiment.subjectivity
        if polarity > 0.1:
            label = "positive"
        elif polarity < -0.1:
            label = "negative"
        else:
            label = "neutral"
        return {
            "polarity": round(polarity, 4),
            "subjectivity": round(subjectivity, 4),
            "sentiment": label,
            "is_persuasive": subjectivity > 0.5,
        }

    def preprocess(self, text: str) -> Dict:
        """Run the full pipeline and bundle every intermediate result."""
        all_tokens = self.tokenize(text)
        content_tokens = self.remove_stopwords(all_tokens)
        return {
            "original_text": text,
            "tokens": all_tokens,
            "tokens_without_stopwords": content_tokens,
            "stemmed_tokens": self.stem(content_tokens),
            "lemmatized_tokens": self.lemmatize(content_tokens),
            "sentiment": self.sentiment_analysis(text),
            "phishing_indicators": self.analyze_phishing_indicators(text),
            "token_count": len(content_tokens),
        }
| # ============================================================================ | |
| # PYDANTIC MODELS | |
| # ============================================================================ | |
class PredictPayload(BaseModel):
    """Request body for single-text prediction."""
    inputs: str  # raw text to classify
    include_preprocessing: bool = True  # attach the full NLP preprocessing dict to the response
class BatchPredictPayload(BaseModel):
    """Request body for batch prediction."""
    inputs: List[str]  # texts to classify, one result per entry
    include_preprocessing: bool = True  # attach preprocessing output to each result
class LabeledText(BaseModel):
    """A sample text with an optional ground-truth label."""
    text: str  # sample to classify
    label: Optional[str] = None  # e.g. "PHISH"/"LEGIT"/"spam"/"ham"; normalized before comparison
class EvalPayload(BaseModel):
    """Request body for the evaluation endpoint: a list of labeled samples."""
    samples: List[LabeledText]  # samples without a label are predicted but excluded from accuracy
| # ============================================================================ | |
| # GLOBAL VARIABLES | |
| # ============================================================================ | |
# Lazily-initialized singletons, populated by _load_model() on first use.
_tokenizer = None  # HF tokenizer for MODEL_ID
_model = None  # HF sequence-classification model
_device = "cpu"  # switched to "cuda" at load time when available
_preprocessor = None  # shared TextPreprocessor instance
| # ============================================================================ | |
| # HELPER FUNCTIONS | |
| # ============================================================================ | |
| def _normalize_label(txt: str) -> str: | |
| """Normalize label text""" | |
| t = (str(txt) if txt is not None else "").strip().upper() | |
| if t in ("PHISHING", "PHISH", "SPAM", "1"): | |
| return "PHISH" | |
| if t in ("LEGIT", "LEGITIMATE", "SAFE", "HAM", "0"): | |
| return "LEGIT" | |
| return t | |
def _adjust_confidence_with_indicators(base_prob: float, indicators: Dict, predicted_label: str) -> float:
    """Map phishing-indicator coverage onto a banded confidence value.

    More indicators mean the context suggests phishing, so confidence rises
    for a PHISH prediction and falls for a LEGIT one. The result is clamped
    to [BASE_CONFIDENCE_MIN, BASE_CONFIDENCE_MAX].

    Args:
        base_prob: Raw model probability for the predicted class. Kept for
            interface compatibility; the current heuristic does not use it.
        indicators: Output of TextPreprocessor.analyze_phishing_indicators;
            only "indicator_percentage" is read (defaults to 0 if absent).
        predicted_label: Normalized label — "PHISH" or anything else (legit).

    Returns:
        Adjusted confidence in [BASE_CONFIDENCE_MIN, BASE_CONFIDENCE_MAX].
    """
    indicator_percentage = indicators.get("indicator_percentage", 0)

    if predicted_label == "PHISH":
        # Phishing prediction: more indicators -> more confident.
        if indicator_percentage >= 40:    # strong indicators: ~75-85%
            adjusted = 0.75 + (indicator_percentage / 100) * 0.10
        elif indicator_percentage >= 25:  # moderate indicators: ~65-75%
            adjusted = 0.65 + (indicator_percentage / 100) * 0.10
        else:                             # weak indicators: ~55-65%
            adjusted = 0.55 + (indicator_percentage / 100) * 0.10
    else:
        # Legitimate prediction: many indicators contradict the label.
        if indicator_percentage >= 40:    # contradictory signals: ~55-65%
            adjusted = 0.65 - (indicator_percentage / 100) * 0.10
        elif indicator_percentage >= 25:  # some indicators: ~65-75%
            adjusted = 0.70 - (indicator_percentage / 100) * 0.05
        else:                             # few indicators: ~75-85%
            adjusted = 0.75 + ((100 - indicator_percentage) / 100) * 0.10

    # Clamp to the configured confidence range.
    return max(BASE_CONFIDENCE_MIN, min(BASE_CONFIDENCE_MAX, adjusted))
def _load_model():
    """Lazily load tokenizer, model, and preprocessor into module globals.

    Idempotent: once _tokenizer and _model are set, later calls are no-ops.
    Selects CUDA when available, runs one warm-up forward pass, and prints
    the model's label map for debugging.
    """
    global _tokenizer, _model, _device, _preprocessor
    if _tokenizer is None or _model is None:
        _device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"\n{'='*60}")
        print(f"Loading model: {MODEL_ID}")
        print(f"Device: {_device}")
        print(f"Confidence range: {BASE_CONFIDENCE_MIN*100:.0f}%-{BASE_CONFIDENCE_MAX*100:.0f}%")
        print(f"{'='*60}\n")
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
        _model.to(_device)
        _model.eval()
        _preprocessor = TextPreprocessor()
        # Warm-up forward pass so the first real request doesn't pay the
        # lazy initialization cost.
        with torch.no_grad():
            _ = _model(
                **_tokenizer(["warm up"], return_tensors="pt", padding=True, truncation=True, max_length=512)
                .to(_device)
            ).logits
        id2label = getattr(_model.config, "id2label", {})
        print(f"Model labels: {id2label}")
        print(f"{'='*60}\n")
def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
    """Classify texts and attach indicator-adjusted confidences.

    Args:
        texts: Raw input strings; an empty list short-circuits to [].
        include_preprocessing: When True, each result also carries the
            full dict from TextPreprocessor.preprocess.

    Returns:
        One dict per input: normalized label, adjusted confidence
        (percentage and 0-1 score), per-class probabilities, and the raw
        model confidence before adjustment.
    """
    _load_model()
    if not texts:
        return []
    # Preprocess every text up front: the indicators are needed for the
    # confidence adjustment even when preprocessing isn't returned.
    preprocessing_info = [_preprocessor.preprocess(text) for text in texts]
    # Tokenize the whole batch at once (truncated to the 512-token limit).
    enc = _tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    enc = {k: v.to(_device) for k, v in enc.items()}
    # Forward pass without gradient tracking.
    with torch.no_grad():
        logits = _model(**enc).logits
        probs = F.softmax(logits, dim=-1)
    # Label map from the model config; fallback assumes binary LEGIT/PHISH.
    id2label = getattr(_model.config, "id2label", {0: "LEGIT", 1: "PHISH"})
    outputs: List[Dict] = []
    for text_idx in range(probs.shape[0]):
        p = probs[text_idx]
        preprocessing = preprocessing_info[text_idx]
        indicators = preprocessing["phishing_indicators"]
        # Highest-probability class is the prediction.
        predicted_idx = int(torch.argmax(p).item())
        predicted_label_raw = id2label.get(predicted_idx, f"CLASS_{predicted_idx}")
        predicted_label_norm = _normalize_label(predicted_label_raw)
        raw_prob = float(p[predicted_idx].item())
        # The heuristic value replaces the raw model confidence below.
        adjusted_confidence = _adjust_confidence_with_indicators(
            raw_prob, indicators, predicted_label_norm
        )
        # NOTE(review): 1 - adjusted is assigned to every non-predicted
        # class, which only sums to 1 for a two-label model — presumably
        # this model is binary; confirm against its config.
        prob_breakdown = {}
        for i in range(len(p)):
            label = _normalize_label(id2label.get(i, f"CLASS_{i}"))
            if i == predicted_idx:
                prob_breakdown[label] = round(adjusted_confidence, 4)
            else:
                prob_breakdown[label] = round(1.0 - adjusted_confidence, 4)
        output = {
            # Long inputs are preview-truncated to keep responses small.
            "text": texts[text_idx][:100] + "..." if len(texts[text_idx]) > 100 else texts[text_idx],
            "label": predicted_label_norm,
            "raw_label": predicted_label_raw,
            "is_phish": predicted_label_norm == "PHISH",
            "confidence": round(adjusted_confidence * 100, 2),  # percent
            "score": round(adjusted_confidence, 4),  # 0-1 scale
            "probs": prob_breakdown,
            "model_raw_confidence": round(raw_prob * 100, 2),  # pre-adjustment
        }
        if include_preprocessing:
            output["preprocessing"] = preprocessing
        outputs.append(output)
    return outputs
| # ============================================================================ | |
| # API ENDPOINTS | |
| # ============================================================================ | |
# NOTE(review): this handler had no route decorator, so it was never
# registered on `app`; the path is inferred from the function's role.
@app.get("/")
def root():
    """Health/info endpoint: reports model id, device, and confidence range."""
    _load_model()
    return {
        "status": "ok",
        "model": MODEL_ID,
        "device": _device,
        "confidence_range": f"{BASE_CONFIDENCE_MIN*100:.0f}%-{BASE_CONFIDENCE_MAX*100:.0f}%",
        "note": "Confidence adjusted based on phishing indicators"
    }
# NOTE(review): route decorator was missing — the handler was never
# registered on `app`; path inferred from the function name.
@app.get("/debug/labels")
def debug_labels():
    """View model label configuration (id2label, label2id, num_labels)."""
    _load_model()
    return {
        "status": "ok",
        "model_id": MODEL_ID,
        "id2label": getattr(_model.config, "id2label", {}),
        "label2id": getattr(_model.config, "label2id", {}),
        "num_labels": int(getattr(_model.config, "num_labels", 0)),
        "device": _device,
    }
# NOTE(review): route decorator was missing — the handler was never
# registered on `app`; path inferred from the function name.
@app.post("/debug/preprocessing")
def debug_preprocessing(payload: PredictPayload):
    """Return the full preprocessing breakdown for one text."""
    try:
        _load_model()
        return _preprocessor.preprocess(payload.inputs)
    except Exception as e:
        # Boundary handler: surface any failure as an HTTP 500.
        raise HTTPException(status_code=500, detail=str(e))
# NOTE(review): route decorator was missing — the handler was never
# registered on `app`; path inferred from the function name.
@app.post("/predict")
def predict(payload: PredictPayload):
    """Classify a single text and return its result dict."""
    try:
        res = _predict_texts([payload.inputs], include_preprocessing=payload.include_preprocessing)
        return res[0]
    except Exception as e:
        # Boundary handler: surface any failure as an HTTP 500.
        raise HTTPException(status_code=500, detail=str(e))
# NOTE(review): route decorator was missing — the handler was never
# registered on `app`; path inferred from the function name.
@app.post("/predict/batch")
def predict_batch(payload: BatchPredictPayload):
    """Classify a batch of texts; returns one result dict per input."""
    try:
        return _predict_texts(payload.inputs, include_preprocessing=payload.include_preprocessing)
    except Exception as e:
        # Boundary handler: surface any failure as an HTTP 500.
        raise HTTPException(status_code=500, detail=str(e))
# NOTE(review): route decorator was missing — the handler was never
# registered on `app`; path inferred from the function name.
@app.post("/evaluate")
def evaluate(payload: EvalPayload):
    """Evaluate predictions against labeled samples.

    Accuracy is computed only over samples that carry a ground-truth
    label; unlabeled samples are still predicted but excluded from it.
    """
    try:
        texts = [s.text for s in payload.samples]
        gts = [(_normalize_label(s.label) if s.label is not None else None) for s in payload.samples]
        preds = _predict_texts(texts, include_preprocessing=False)
        total = len(preds)
        correct = 0
        per_class: Dict[str, Dict[str, int]] = {}
        for gt, pr in zip(gts, preds):
            pred_label = pr["label"]
            if gt is not None:
                correct += int(gt == pred_label)
                per_class.setdefault(gt, {"tp": 0, "count": 0})
                per_class[gt]["count"] += 1
                if gt == pred_label:
                    per_class[gt]["tp"] += 1
        has_gts = any(gt is not None for gt in gts)
        acc = (correct / sum(1 for gt in gts if gt is not None)) if has_gts else None
        return {
            # `is not None` so a legitimate 0.0 accuracy is reported as 0.0,
            # not silently turned into null (bug in the previous truthiness check).
            "accuracy": round(acc, 4) if acc is not None else None,
            "total": total,
            "correct": correct,
            "predictions": preds,
            "per_class": per_class,
        }
    except Exception as e:
        # Boundary handler: surface any failure as an HTTP 500.
        raise HTTPException(status_code=500, detail=str(e))
# Run a local development server when this module is executed directly.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)