import os
import gc
import re
import torch
import numpy as np
import firebase_admin
from firebase_admin import credentials, db
from fastapi import FastAPI, Form, Request
from fastapi.responses import HTMLResponse
from transformers import (
    AutoTokenizer,
    AutoConfig,
    DistilBertForSequenceClassification,
    T5ForConditionalGeneration,
    T5Tokenizer,
)
from lime.lime_text import LimeTextExplainer
import uuid
from datetime import datetime

# FIREBASE INITIALIZATION
# Ensure your serviceAccountKey.json is in the same directory
if not firebase_admin._apps:
    cred = credentials.Certificate("serviceAccountKey.json")
    firebase_admin.initialize_app(cred, {
        'databaseURL': 'https://your-project-id.firebaseio.com/'  # <--- Add Firebase URL here
    })

app = FastAPI()

# Global model handles; populated by the startup hook (load_models).
tokenizer = None
model = None
explainer = LimeTextExplainer(class_names=["LEGITIMATE", "PHISHING"])
gen_tokenizer = None
gen_model = None

MODEL_DIR = "phishing_model_optimized"
WEIGHTS_NAME = "quantized_model.pt"

# PHISHING INDICATOR PATTERNS (Original Rule-Based)
PHISHING_PATTERNS = {
    "urgency": [
        r"\bimmediate(ly)?\b", r"\burgent(ly)?\b", r"\bact now\b",
        r"\baction required\b", r"\bwithin \d+ (hour|minute|day)s?\b",
        r"\bexpir(e|es|ing|ed)\b", r"\bsuspended\b", r"\bcompromised\b",
        r"\bverify (now|immediately|your)\b", r"\bfailure to\b",
        r"\bpermanent(ly)?\b", r"\bdelete(d)?\b", r"\block(ed)?\b"
    ],
    "threat": [
        r"\baccount.*(suspend|terminat|delet|lock|compromis)\w*",
        r"\b(suspend|terminat|delet|lock|compromis)\w*.*account\b",
        r"\blegal action\b", r"\bpenalt(y|ies)\b", r"\bconsequences\b"
    ],
    "credential_request": [
        r"\bpassword\b", r"\bverify your (identity|account|email)\b",
        r"\bconfirm your\b", r"\bupdate.*(payment|billing|account)\b",
        r"\bssn\b", r"\bcredit card\b", r"\bbank account\b"
    ],
    "suspicious_links": [
        r"https?://[^\s]*\.(xyz|tk|ml|ga|cf|gq|top|club|online)/",
        r"https?://[^\s]*-[^\s]*\.(com|net|org)/",
        r"https?://\d+\.\d+\.\d+\.\d+",
        r"bit\.ly|tinyurl|short\.link|t\.co",
        r"click.*here|click.*below|click.*link"
    ],
    "impersonation": [
        r"\b(paypal|amazon|netflix|apple|microsoft|google|bank)\b",
        r"\bcustomer (service|support)\b",
        r"\bsecurity (team|department)\b"
    ]
}

# Compile every pattern once at import time: detect_phishing_indicators runs
# on every request and again inside the LIME perturbation loop, so per-call
# recompilation is wasted work. The input is lowercased before matching, so
# no IGNORECASE flag is needed.
_COMPILED_PATTERNS = {
    category: [re.compile(p) for p in patterns]
    for category, patterns in PHISHING_PATTERNS.items()
}


def detect_phishing_indicators(text: str) -> dict:
    """Scan *text* for rule-based phishing indicators.

    Returns a dict keyed by the categories of PHISHING_PATTERNS, each value a
    de-duplicated list (first-seen order) of the full substrings that matched.
    """
    text_lower = text.lower()
    detected = {cat: [] for cat in PHISHING_PATTERNS}
    for category, patterns in _COMPILED_PATTERNS.items():
        for pattern in patterns:
            # BUGFIX: the original used re.findall, which returns the contents
            # of capture groups for patterns such as r"\bimmediate(ly)?\b" --
            # i.e. "ly" or "" instead of the matched word. finditer + group(0)
            # always yields the full matched text regardless of grouping.
            for match in pattern.finditer(text_lower):
                detected[category].append(match.group(0))
    # De-duplicate while preserving order; list(set(...)) made the indicator
    # order (and therefore the generated explanations) nondeterministic.
    for category in detected:
        detected[category] = list(dict.fromkeys(detected[category]))
    return detected


def calculate_phishing_score(indicators: dict) -> float:
    """Combine per-category indicator counts into a weighted score in [0, 1].

    Each category contributes weight * min(count * 0.4, 1.0), i.e. a category
    saturates after ~3 distinct hits.
    """
    weights = {
        "urgency": 0.25,
        "threat": 0.25,
        "credential_request": 0.20,
        "suspicious_links": 0.20,
        "impersonation": 0.10,
    }
    score = 0.0
    for category, weight in weights.items():
        if indicators[category]:
            category_score = min(len(indicators[category]) * 0.4, 1.0)
            score += weight * category_score
    return min(score, 1.0)


def get_confidence_label(confidence: float) -> str:
    """Determine confidence level descriptor based on score"""
    if confidence >= 0.90:
        return "high confidence"
    elif confidence >= 0.70:
        return "moderate confidence"
    else:
        return "low confidence"


def generate_explanation_with_flan(indicators: dict, label: str, confidence: float, email_text: str) -> str:
    """Generate high-quality natural language explanation using FLAN-T5.

    Builds a label-specific prompt from the rule-based indicators, runs
    beam-search generation on the (module-global) FLAN-T5 model, and prefixes
    the result with the classification and confidence. Falls back to a
    template-based explanation if generation raises.
    """
    # Get confidence level descriptor
    confidence_level = get_confidence_label(confidence)

    # Build detailed indicator analysis
    indicator_details = []
    urgency_count = len(indicators["urgency"])
    threat_count = len(indicators["threat"])
    cred_count = len(indicators["credential_request"])
    link_count = len(indicators["suspicious_links"])
    imperson_count = len(indicators["impersonation"])

    if urgency_count > 0:
        urgency_examples = ', '.join([f'"{item}"' for item in indicators['urgency'][:2]])
        indicator_details.append(f"high urgency language ({urgency_count} instances: {urgency_examples})")
    if threat_count > 0:
        threat_examples = ', '.join([f'"{item}"' for item in indicators['threat'][:2]])
        indicator_details.append(f"threatening tone ({threat_count} instances: {threat_examples})")
    if cred_count > 0:
        cred_examples = ', '.join([f'"{item}"' for item in indicators['credential_request'][:2]])
        indicator_details.append(f"credential requests ({cred_count} instances: {cred_examples})")
    if link_count > 0:
        indicator_details.append(f"suspicious links ({link_count} detected)")
    if imperson_count > 0:
        brands = ', '.join(indicators['impersonation'][:2])
        indicator_details.append(f"brand impersonation attempts ({brands})")

    # Create sophisticated prompts based on label
    if label == "PHISHING":
        indicators_summary = "; ".join(indicator_details) if indicator_details else "general phishing patterns"
        prompt = f"""You are a cybersecurity expert. Explain why this email is phishing:
Email sample: "{email_text[:250]}"
Detected threats: {indicators_summary}
Write a clear 2-sentence explanation that: 1. States the classification with confidence level 2. Describes specific malicious tactics (urgency, social engineering, credential harvesting, clickbait) 3. Uses professional security terminology
Explanation:"""
    else:  # LEGITIMATE
        safe_indicators = []
        if not indicators["urgency"] and not indicators["threat"]:
            safe_indicators.append("no urgency or threat language")
        if not indicators["credential_request"]:
            safe_indicators.append("no credential requests")
        if not indicators["suspicious_links"]:
            safe_indicators.append("no suspicious links")
        if not indicators["impersonation"]:
            safe_indicators.append("no brand impersonation")
        safety_summary = ", ".join(safe_indicators) if safe_indicators else "standard communication patterns"
        prompt = f"""You are a cybersecurity expert. Explain why this email is legitimate:
Email sample: "{email_text[:250]}"
Safety indicators: {safety_summary}
Write a clear 2-sentence explanation that: 1. States the classification with confidence level 2. Notes the absence of social-engineering cues, suspicious tokens, or phishing tactics 3. Uses professional security terminology
Explanation:"""

    try:
        # Tokenize the prompt
        inputs = gen_tokenizer(
            prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True
        )

        # Generate with deterministic beam search. The original also passed
        # temperature=0.8 / top_p=0.92, but those are sampling parameters and
        # are ignored (with a warning) when do_sample=False, so they are
        # dropped here.
        with torch.no_grad():
            outputs = gen_model.generate(
                inputs.input_ids,
                max_length=180,
                min_length=40,
                num_beams=5,
                length_penalty=1.2,
                early_stopping=True,
                do_sample=False,
                no_repeat_ngram_size=3
            )

        # Decode the generated text
        explanation = gen_tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Post-process: add confidence level prefix
        explanation = f"The email was classified as {label} with {confidence_level} ({confidence:.2f}). {explanation}"
        return explanation

    except Exception as e:
        print(f"FLAN-T5 generation error: {e}")
        # Enhanced fallback explanations
        if label == "PHISHING":
            reasons = []
            if indicators["urgency"]:
                reasons.append("uses high urgency tactics")
            if indicators["threat"]:
                reasons.append("contains threatening language")
            if indicators["credential_request"]:
                reasons.append("attempts credential harvesting")
            if indicators["suspicious_links"]:
                reasons.append("includes clickbait keywords")
            reason_text = " and ".join(reasons) if reasons else "exhibits fraudulent patterns"
            return f"The email was classified as PHISHING with {confidence_level} ({confidence:.2f}). The email {reason_text} suggesting a social-engineering attempt to capture sensitive information."
        else:
            return f"The email was classified as LEGITIMATE with {confidence_level} ({confidence:.2f}). The message appears routine and contains no social-engineering cues or suspicious tokens."
@app.on_event("startup")
def load_models():
    """Load the quantized DistilBERT classifier and FLAN-T5 generator once at startup.

    Populates the module-level tokenizer/model/gen_tokenizer/gen_model globals.
    Failures are logged rather than raised so the app still starts (endpoints
    will then fail at request time).
    """
    global tokenizer, model, gen_tokenizer, gen_model
    base_path = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(base_path, MODEL_DIR)
    weights_path = os.path.join(model_path, WEIGHTS_NAME)
    gc.collect()
    try:
        # Load DistilBERT for classification (local files only -- no hub fetch).
        tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
        config = AutoConfig.from_pretrained(model_path, local_files_only=True)
        base_model = DistilBertForSequenceClassification(config)
        # Dynamic int8 quantization of the Linear layers for CPU inference.
        model = torch.quantization.quantize_dynamic(base_model, {torch.nn.Linear}, dtype=torch.qint8)
        if os.path.exists(weights_path):
            # NOTE(review): torch.load can execute arbitrary code from an
            # untrusted checkpoint; only load weights you produced yourself,
            # or pass weights_only=True if the checkpoint format allows it.
            model.load_state_dict(torch.load(weights_path, map_location='cpu'))
        model.eval()

        # Load FLAN-T5 for explanation generation
        print("Loading FLAN-T5 for explanation generation...")
        gen_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small", legacy=False)
        gen_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")
        gen_model.eval()
        print("FLAN-T5 loaded successfully!")
    except Exception as e:
        print(f"Load Error: {e}")


def predictor(texts):
    """LIME-compatible prediction function.

    Takes an iterable of raw strings and returns an (N, 2) numpy array of
    [LEGITIMATE, PHISHING] softmax probabilities from the DistilBERT model.
    """
    probs_list = []
    for text in texts:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=256)
        with torch.no_grad():
            outputs = model(**inputs)
        probs = torch.softmax(outputs.logits, dim=1).numpy()
        probs_list.append(probs[0])
    return np.array(probs_list)


def hybrid_predict(email_text: str) -> tuple:
    """Combine ML probability and rule-based score into a final verdict.

    The rule score's weight grows with its own magnitude (0.3 / 0.5 / 0.7),
    so strong rule evidence can override a weak model signal.

    Returns (label, confidence, indicators) where confidence is the
    probability assigned to the returned label.
    """
    ml_probs = predictor([email_text])[0]
    indicators = detect_phishing_indicators(email_text)
    rule_score = calculate_phishing_score(indicators)

    # Combine scores based on your logic
    if rule_score > 0.5:
        combined_prob = (ml_probs[1] * 0.3) + (rule_score * 0.7)
    elif rule_score > 0.2:
        combined_prob = (ml_probs[1] * 0.5) + (rule_score * 0.5)
    else:
        combined_prob = (ml_probs[1] * 0.7) + (rule_score * 0.3)

    label = "PHISHING" if combined_prob >= 0.5 else "LEGITIMATE"
    confidence = combined_prob if label == "PHISHING" else 1 - combined_prob
    return label, confidence, indicators


@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve the landing page with the email-analysis form."""
    return """ Robust Explainable Phishing Classification
"""


@app.post("/predict", response_class=HTMLResponse)
async def predict(email_text: str = Form(...)):
    """Classify the submitted email and render the explained result page."""
    label, confidence, indicators = hybrid_predict(email_text)

    # Generate unique token for this analysis (8 hex chars is enough for a
    # human-copyable feedback reference; not a security credential).
    unique_token = str(uuid.uuid4())[:8].upper()
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Store analysis data in Firebase for reference
    analysis_ref = db.reference('/analysis_records')
    analysis_ref.child(unique_token).set({
        'timestamp': timestamp,
        'label': label,
        'confidence': float(confidence),
        'email_length': len(email_text)
    })

    # Get LIME explanation for technical keywords. Best-effort: the page is
    # still rendered if LIME fails, but only Exception is swallowed (the
    # original bare `except:` also caught KeyboardInterrupt/SystemExit).
    try:
        exp = explainer.explain_instance(email_text, predictor, num_features=6, num_samples=100)
        keyword_str = ", ".join([word for word, weight in exp.as_list() if abs(weight) > 0.01])
    except Exception:
        keyword_str = "analysis unavailable"

    # Generate natural language explanation using FLAN-T5
    clean_explanation = generate_explanation_with_flan(indicators, label, confidence, email_text)
    color = "#dc3545" if label == "PHISHING" else "#28a745"

    # HTML Result with Enhanced Feedback Form
    return f"""
{label}

Confidence: {confidence:.2%}

Security Analysis (FLAN-T5 Generated)

{clean_explanation}

Technical Triggers (LIME): {keyword_str}


Help Us Improve (XAI Evaluation)

Analysis Token: {unique_token}

Please save this token for your records

Evaluation Criteria 1 2 3 4 5
1. Decision Clarity:
The explanation helped me understand the result.
2. Information Focus:
The explanation was concise and essential.

Rating Scale: 1 = Strongly Disagree | 5 = Strongly Agree

← Analyze Another Email
"""


@app.post("/feedback")
async def save_feedback(token: str = Form(...), understanding: int = Form(...), clarity: int = Form(...)):
    """Persist an XAI-evaluation rating pair, keyed back to its analysis token."""
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Save feedback with token reference
    feedback_ref = db.reference('/xai_feedback')
    feedback_ref.push({
        'token': token,
        'understanding': understanding,
        'clarity': clarity,
        'timestamp': timestamp
    })
    return HTMLResponse(f"""
Feedback Received!

Thank you for contributing to our research.

Your Token: {token}

Go back to Home
""")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="127.0.0.1", port=8000)