Spaces:
Configuration error
Configuration error
| import os | |
| import gc | |
| import re | |
| import torch | |
| import numpy as np | |
| import firebase_admin | |
| from firebase_admin import credentials, db | |
| from fastapi import FastAPI, Form, Request | |
| from fastapi.responses import HTMLResponse | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoConfig, | |
| DistilBertForSequenceClassification, | |
| T5ForConditionalGeneration, | |
| T5Tokenizer | |
| ) | |
| from lime.lime_text import LimeTextExplainer | |
| import uuid | |
| from datetime import datetime | |
# FIREBASE INITIALIZATION
# The service-account key is read from "serviceAccountKey.json" (relative
# path), so keep that file alongside this script.
# Initialisation is skipped when firebase_admin already has a registered app
# (presumably to avoid a double-initialisation error on re-import).
if not firebase_admin._apps:
    cred = credentials.Certificate("serviceAccountKey.json")
    firebase_options = {
        'databaseURL': 'https://your-project-id.firebaseio.com/',  # <--- Add Firebase URL here
    }
    firebase_admin.initialize_app(cred, firebase_options)
# ASGI application object that the route handlers attach to.
app = FastAPI()

# Location of the fine-tuned classifier, joined onto this file's directory
# by load_models().
MODEL_DIR = "phishing_model_optimized"
WEIGHTS_NAME = "quantized_model.pt"

# Model handles shared across requests; they stay None until load_models()
# assigns them.
tokenizer = None
model = None
gen_tokenizer = None
gen_model = None

# LIME text explainer; class index 0 is LEGITIMATE and index 1 is PHISHING,
# matching how hybrid_predict reads the classifier probabilities.
explainer = LimeTextExplainer(class_names=["LEGITIMATE", "PHISHING"])
# PHISHING INDICATOR PATTERNS (Original Rule-Based)
# Maps an indicator category to the regular expressions that signal it.
# detect_phishing_indicators() applies these to lower-cased email text, so
# every pattern is written in lower case.
PHISHING_PATTERNS = {
    # Time pressure and "act or lose access" wording.
    "urgency": [
        r"\bimmediate(ly)?\b", r"\burgent(ly)?\b", r"\bact now\b",
        r"\baction required\b", r"\bwithin \d+ (hour|minute|day)s?\b",
        r"\bexpir(e|es|ing|ed)\b", r"\bsuspended\b", r"\bcompromised\b",
        r"\bverify (now|immediately|your)\b", r"\bfailure to\b",
        r"\bpermanent(ly)?\b", r"\bdelete(d)?\b", r"\block(ed)?\b",
    ],
    # Threats against the recipient's account or of legal consequences.
    "threat": [
        r"\baccount.*(suspend|terminat|delet|lock|compromis)\w*",
        r"\b(suspend|terminat|delet|lock|compromis)\w*.*account\b",
        r"\blegal action\b", r"\bpenalt(y|ies)\b", r"\bconsequences\b",
    ],
    # Requests for passwords, identity confirmation or payment details.
    "credential_request": [
        r"\bpassword\b", r"\bverify your (identity|account|email)\b",
        r"\bconfirm your\b", r"\bupdate.*(payment|billing|account)\b",
        r"\bssn\b", r"\bcredit card\b", r"\bbank account\b",
    ],
    # Unusual TLDs, hyphenated hosts, raw IP URLs, shorteners, click prompts.
    "suspicious_links": [
        r"https?://[^\s]*\.(xyz|tk|ml|ga|cf|gq|top|club|online)/",
        r"https?://[^\s]*-[^\s]*\.(com|net|org)/",
        r"https?://\d+\.\d+\.\d+\.\d+",
        # Word boundaries are required here: an unanchored "t\.co" also
        # matches inside ordinary domains such as "microsoft.com".
        r"\b(?:bit\.ly|tinyurl|short\.link|t\.co)\b",
        r"click.*here|click.*below|click.*link",
    ],
    # Well-known brands and support/security department names.
    "impersonation": [
        r"\b(paypal|amazon|netflix|apple|microsoft|google|bank)\b",
        r"\bcustomer (service|support)\b", r"\bsecurity (team|department)\b",
    ],
}
def detect_phishing_indicators(text: str, patterns=None) -> dict:
    """Find rule-based phishing indicators in an email.

    Args:
        text: Raw email text. It is lower-cased before matching, so the
            returned matches are lower case.
        patterns: Optional mapping of category name to a list of regular
            expressions. Defaults to the module-level PHISHING_PATTERNS.

    Returns:
        A dict with one key per category. Each value is a list of the
        distinct matched substrings, in order of first appearance per
        pattern (empty when nothing matched).
    """
    if patterns is None:
        patterns = PHISHING_PATTERNS
    text_lower = text.lower()
    detected = {}
    for category, regexes in patterns.items():
        hits = []
        for pattern in regexes:
            # Use the whole match (group 0). re.findall would return only the
            # capture-group contents (e.g. "" or "ly" for r"\burgent(ly)?\b"),
            # losing the phrase and merging unrelated hits when de-duplicated.
            for match in re.finditer(pattern, text_lower, re.IGNORECASE):
                phrase = match.group(0)
                if phrase:
                    hits.append(phrase)
        # dict.fromkeys de-duplicates while keeping a deterministic order,
        # so the "first two examples" shown to users are stable.
        detected[category] = list(dict.fromkeys(hits))
    return detected
def calculate_phishing_score(indicators: dict) -> float:
    """Combine per-category indicator hits into a rule-based score in [0, 1].

    Each category contributes ``weight * min(hits * 0.4, 1.0)``, so three or
    more hits saturate a category. The weights sum to 1.0.

    Args:
        indicators: Mapping of category name to the list of detected matches;
            it must contain every weighted category.

    Returns:
        The weighted score, capped at 1.0.
    """
    category_weights = (
        ("urgency", 0.25),
        ("threat", 0.25),
        ("credential_request", 0.20),
        ("suspicious_links", 0.20),
        ("impersonation", 0.10),
    )
    total = 0.0
    for name, weight in category_weights:
        hits = indicators[name]
        if not hits:
            continue
        saturation = min(len(hits) * 0.4, 1.0)
        total += weight * saturation
    return min(total, 1.0)
def get_confidence_label(confidence: float) -> str:
    """Translate a numeric confidence into a coarse textual descriptor.

    Scores of at least 0.90 are "high confidence", at least 0.70 are
    "moderate confidence", and everything else is "low confidence".
    """
    thresholds = (
        (0.90, "high confidence"),
        (0.70, "moderate confidence"),
    )
    for minimum, descriptor in thresholds:
        if confidence >= minimum:
            return descriptor
    return "low confidence"
def _describe_indicators(indicators: dict) -> list:
    """Return short text fragments describing each non-empty indicator category."""
    details = []
    quoted_categories = (
        ("urgency", "high urgency language"),
        ("threat", "threatening tone"),
        ("credential_request", "credential requests"),
    )
    for key, description in quoted_categories:
        found = indicators[key]
        if found:
            # Quote at most two examples to keep the prompt short.
            examples = ', '.join(f'"{item}"' for item in found[:2])
            details.append(f"{description} ({len(found)} instances: {examples})")
    if indicators["suspicious_links"]:
        details.append(f"suspicious links ({len(indicators['suspicious_links'])} detected)")
    if indicators["impersonation"]:
        brands = ', '.join(indicators['impersonation'][:2])
        details.append(f"brand impersonation attempts ({brands})")
    return details


def _build_explanation_prompt(indicators: dict, label: str, email_text: str) -> str:
    """Build the FLAN-T5 prompt for a PHISHING or LEGITIMATE verdict."""
    if label == "PHISHING":
        indicator_details = _describe_indicators(indicators)
        indicators_summary = "; ".join(indicator_details) if indicator_details else "general phishing patterns"
        return f"""You are a cybersecurity expert. Explain why this email is phishing:
Email sample: "{email_text[:250]}"
Detected threats: {indicators_summary}
Write a clear 2-sentence explanation that:
1. States the classification with confidence level
2. Describes specific malicious tactics (urgency, social engineering, credential harvesting, clickbait)
3. Uses professional security terminology
Explanation:"""

    # LEGITIMATE: list which warning signs are absent.
    safe_indicators = []
    if not indicators["urgency"] and not indicators["threat"]:
        safe_indicators.append("no urgency or threat language")
    if not indicators["credential_request"]:
        safe_indicators.append("no credential requests")
    if not indicators["suspicious_links"]:
        safe_indicators.append("no suspicious links")
    if not indicators["impersonation"]:
        safe_indicators.append("no brand impersonation")
    safety_summary = ", ".join(safe_indicators) if safe_indicators else "standard communication patterns"
    return f"""You are a cybersecurity expert. Explain why this email is legitimate:
Email sample: "{email_text[:250]}"
Safety indicators: {safety_summary}
Write a clear 2-sentence explanation that:
1. States the classification with confidence level
2. Notes the absence of social-engineering cues, suspicious tokens, or phishing tactics
3. Uses professional security terminology
Explanation:"""


def _fallback_explanation(indicators: dict, label: str, prefix: str) -> str:
    """Return a rule-based explanation for use when FLAN-T5 generation fails."""
    if label == "PHISHING":
        reasons = []
        if indicators["urgency"]:
            reasons.append("uses high urgency tactics")
        if indicators["threat"]:
            reasons.append("contains threatening language")
        if indicators["credential_request"]:
            reasons.append("attempts credential harvesting")
        if indicators["suspicious_links"]:
            reasons.append("includes clickbait keywords")
        if indicators["impersonation"]:
            # Previously detected but never reported in the fallback text.
            reasons.append("impersonates a known brand")
        reason_text = " and ".join(reasons) if reasons else "exhibits fraudulent patterns"
        return f"{prefix} The email {reason_text} suggesting a social-engineering attempt to capture sensitive information."
    return f"{prefix} The message appears routine and contains no social-engineering cues or suspicious tokens."


def generate_explanation_with_flan(indicators: dict, label: str, confidence: float, email_text: str) -> str:
    """Generate a natural-language explanation of the verdict using FLAN-T5.

    Args:
        indicators: Category -> list of detected matches; must contain the
            keys "urgency", "threat", "credential_request",
            "suspicious_links" and "impersonation".
        label: "PHISHING" or "LEGITIMATE"; any other value is treated as
            LEGITIMATE when choosing the prompt.
        confidence: Confidence in ``label``, formatted to two decimals.
        email_text: The analysed email; only the first 250 characters are
            placed in the prompt.

    Returns:
        A sentence stating the classification and confidence, followed by
        the model-generated explanation. If generation raises for any
        reason (including models that were never loaded), a rule-based
        fallback explanation is returned instead.
    """
    confidence_level = get_confidence_label(confidence)
    prefix = f"The email was classified as {label} with {confidence_level} ({confidence:.2f})."
    prompt = _build_explanation_prompt(indicators, label, email_text)

    try:
        inputs = gen_tokenizer(
            prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )
        with torch.no_grad():
            # Deterministic beam search. temperature/top_p are deliberately
            # not passed: they only apply when do_sample=True and otherwise
            # just trigger generation-config warnings.
            outputs = gen_model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=180,
                min_length=40,
                num_beams=5,
                length_penalty=1.2,
                early_stopping=True,
                do_sample=False,
                no_repeat_ngram_size=3,
            )
        explanation = gen_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return f"{prefix} {explanation}"
    except Exception as e:
        # Best effort: an explanation failure must not fail the whole request.
        print(f"FLAN-T5 generation error: {e}")
        return _fallback_explanation(indicators, label, prefix)
@app.on_event("startup")
def load_models():
    """Load the DistilBERT classifier and the FLAN-T5 generator into globals.

    Registered as an application startup handler so the models are loaded
    before requests are served; without a call site the module-level handles
    would stay ``None``.

    The classifier is built from the local config in MODEL_DIR, dynamically
    quantized (int8 Linear layers), and then populated from WEIGHTS_NAME if
    that file exists. FLAN-T5 is loaded by its hub name
    "google/flan-t5-small".

    Any exception is reported on stdout and swallowed, so the application
    still starts; the affected handles may then remain ``None``.
    """
    global tokenizer, model, gen_tokenizer, gen_model
    base_path = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(base_path, MODEL_DIR)
    weights_path = os.path.join(model_path, WEIGHTS_NAME)
    gc.collect()
    try:
        # Load DistilBERT for classification
        tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
        config = AutoConfig.from_pretrained(model_path, local_files_only=True)
        base_model = DistilBertForSequenceClassification(config)
        # Quantize first so the module layout matches the quantized state dict.
        model = torch.quantization.quantize_dynamic(base_model, {torch.nn.Linear}, dtype=torch.qint8)
        if os.path.exists(weights_path):
            # NOTE(security): torch.load unpickles the file; only load
            # weights from a trusted source.
            model.load_state_dict(torch.load(weights_path, map_location='cpu'))
        else:
            # Without the weights the classifier is randomly initialised and
            # its predictions are meaningless, so make that visible.
            print(f"Warning: weights file not found at {weights_path}; classifier is using untrained weights.")
        model.eval()
        # Load FLAN-T5 for explanation generation
        print("Loading FLAN-T5 for explanation generation...")
        gen_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small", legacy=False)
        gen_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")
        gen_model.eval()
        print("FLAN-T5 loaded successfully!")
    except Exception as e:
        print(f"Load Error: {e}")
def predictor(texts):
    """Return class probabilities for each text as a NumPy array.

    Each text is tokenized (truncated to 256 tokens) and scored on its own by
    the global classifier; the softmax of the logits becomes one row of the
    result. This is the callable handed to LIME.

    Args:
        texts: Iterable of strings to classify.

    Returns:
        ``np.ndarray`` with one row of probabilities per input text.
    """
    rows = []
    with torch.no_grad():
        for sample in texts:
            encoded = tokenizer(sample, return_tensors="pt", truncation=True, padding=True, max_length=256)
            logits = model(**encoded).logits
            distribution = torch.softmax(logits, dim=1).numpy()
            rows.append(distribution[0])
    return np.array(rows)
def hybrid_predict(email_text: str) -> tuple:
    """Classify an email by blending the ML probability with the rule score.

    The stronger the rule-based evidence, the more weight it receives:
    rule score > 0.5 -> 30% ML / 70% rules, > 0.2 -> 50/50,
    otherwise 70% ML / 30% rules.

    Returns:
        ``(label, confidence, indicators)`` where ``label`` is "PHISHING"
        when the blended probability is at least 0.5 and "LEGITIMATE"
        otherwise, ``confidence`` is the probability of the chosen label,
        and ``indicators`` is the detected rule-based indicator dict.
    """
    class_probs = predictor([email_text])[0]
    indicators = detect_phishing_indicators(email_text)
    rule_score = calculate_phishing_score(indicators)

    # Pick the blend weights from the strength of the rule evidence.
    if rule_score > 0.5:
        ml_weight, rule_weight = 0.3, 0.7
    elif rule_score > 0.2:
        ml_weight, rule_weight = 0.5, 0.5
    else:
        ml_weight, rule_weight = 0.7, 0.3
    combined_prob = (class_probs[1] * ml_weight) + (rule_score * rule_weight)

    if combined_prob >= 0.5:
        return "PHISHING", combined_prob, indicators
    return "LEGITIMATE", 1 - combined_prob, indicators
# Registered on "/" (the target of the "Analyze Another Email" and
# "Go back to Home" links) and served as HTML rather than a JSON string.
@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve the landing page containing the email submission form.

    The form posts the ``email_text`` field to ``/predict``.
    """
    return """
<html>
<head><title>Robust Explainable Phishing Classification</title></head>
<body style="font-family: sans-serif; background: #f4f7f6; display: flex; justify-content: center; padding-top: 50px;">
<div style="background: white; padding: 40px; border-radius: 15px; box-shadow: 0 10px 30px rgba(0,0,0,0.1); width: 800px;">
<form action="/predict" method="post">
<textarea name="email_text" rows="10" style="width: 100%; padding: 10px;" placeholder="Paste email here..." required></textarea>
<button type="submit" style="width: 100%; background: #007bff; color: white; padding: 15px; margin-top: 10px; border: none; cursor: pointer; border-radius: 8px;">Analyze & Explain</button>
</form>
</div>
</body>
</html>
"""
# Registered on POST /predict, the action of the landing-page form; the
# returned markup is served as HTML rather than a JSON-encoded string.
@app.post("/predict", response_class=HTMLResponse)
async def predict(email_text: str = Form(...)):
    """Analyse a submitted email and render the result page.

    Steps: run the hybrid classifier, store a summary record in Firebase
    under a short token, compute LIME keywords (best effort), generate a
    natural-language explanation, and return an HTML page that includes a
    feedback form posting to ``/feedback``.

    Args:
        email_text: The email body submitted from the form.

    Returns:
        The result page as an HTML string.
    """
    # Local import keeps this block self-contained (stdlib only).
    from html import escape

    label, confidence, indicators = hybrid_predict(email_text)
    # Generate unique token for this analysis
    unique_token = str(uuid.uuid4())[:8].upper()
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # Store analysis data in Firebase for reference (the email text itself
    # is not stored, only its length).
    analysis_ref = db.reference('/analysis_records')
    analysis_ref.child(unique_token).set({
        'timestamp': timestamp,
        'label': label,
        'confidence': float(confidence),
        'email_length': len(email_text)
    })
    # Get LIME explanation for technical keywords
    try:
        exp = explainer.explain_instance(email_text, predictor, num_features=6, num_samples=100)
        keyword_str = ", ".join([word for word, weight in exp.as_list() if abs(weight) > 0.01])
    except Exception as e:
        # Best effort: the page is still useful without LIME keywords.
        # (A bare ``except:`` would also trap KeyboardInterrupt/SystemExit.)
        print(f"LIME explanation error: {e}")
        keyword_str = "analysis unavailable"
    # Generate natural language explanation using FLAN-T5
    clean_explanation = generate_explanation_with_flan(indicators, label, confidence, email_text)
    # Both strings derive from user-supplied email content (the model can
    # echo it), so escape them before embedding in HTML.
    safe_explanation = escape(clean_explanation)
    safe_keywords = escape(keyword_str)
    color = "#dc3545" if label == "PHISHING" else "#28a745"
    # HTML Result with Enhanced Feedback Form
    return f"""
<div style="font-family: sans-serif; max-width: 900px; margin: auto; padding: 20px;">
<div style="background: {color}; color: white; padding: 20px; border-radius: 20px; text-align: center;">
<h1 style="margin: 0;">{label}</h1>
<p>Confidence: {confidence:.2%}</p>
</div>
<div style="background: white; padding: 30px; border-radius: 15px; margin-top: 30px; box-shadow: 0 5px 15px rgba(0,0,0,0.05);">
<h3>Security Analysis (FLAN-T5 Generated)</h3>
<p style="font-size: 1.1em;">{safe_explanation}</p>
<p><b>Technical Triggers (LIME):</b> {safe_keywords}</p>
<hr style="margin: 40px 0;">
<h3>Help Us Improve (XAI Evaluation)</h3>
<div style="background: #f8f9fa; padding: 15px; border-radius: 8px; margin-bottom: 20px;">
<p style="margin: 5px 0;"><b>Analysis Token:</b> <span style="font-family: monospace; font-size: 1.2em; color: {color}; font-weight: bold;">{unique_token}</span></p>
<p style="margin: 5px 0; font-size: 0.9em; color: #666;">Please save this token for your records</p>
</div>
<form action="/feedback" method="post">
<input type="hidden" name="token" value="{unique_token}">
<table style="width: 100%; border-collapse: collapse; margin: 20px 0;">
<thead>
<tr style="background: #f1f3f5;">
<th style="padding: 15px; text-align: left; border: 1px solid #dee2e6; width: 50%;">Evaluation Criteria</th>
<th style="padding: 10px; text-align: center; border: 1px solid #dee2e6; width: 10%;">1</th>
<th style="padding: 10px; text-align: center; border: 1px solid #dee2e6; width: 10%;">2</th>
<th style="padding: 10px; text-align: center; border: 1px solid #dee2e6; width: 10%;">3</th>
<th style="padding: 10px; text-align: center; border: 1px solid #dee2e6; width: 10%;">4</th>
<th style="padding: 10px; text-align: center; border: 1px solid #dee2e6; width: 10%;">5</th>
</tr>
</thead>
<tbody>
<tr>
<td style="padding: 15px; border: 1px solid #dee2e6; background: #fff;">
<b>1. Decision Clarity:</b><br>
<span style="font-size: 0.9em; color: #666;">The explanation helped me understand the result.</span>
</td>
<td style="padding: 10px; text-align: center; border: 1px solid #dee2e6;">
<input type="radio" name="understanding" value="1" required style="width: 20px; height: 20px; cursor: pointer;">
</td>
<td style="padding: 10px; text-align: center; border: 1px solid #dee2e6;">
<input type="radio" name="understanding" value="2" style="width: 20px; height: 20px; cursor: pointer;">
</td>
<td style="padding: 10px; text-align: center; border: 1px solid #dee2e6;">
<input type="radio" name="understanding" value="3" style="width: 20px; height: 20px; cursor: pointer;">
</td>
<td style="padding: 10px; text-align: center; border: 1px solid #dee2e6;">
<input type="radio" name="understanding" value="4" style="width: 20px; height: 20px; cursor: pointer;">
</td>
<td style="padding: 10px; text-align: center; border: 1px solid #dee2e6;">
<input type="radio" name="understanding" value="5" style="width: 20px; height: 20px; cursor: pointer;">
</td>
</tr>
<tr>
<td style="padding: 15px; border: 1px solid #dee2e6; background: #f8f9fa;">
<b>2. Information Focus:</b><br>
<span style="font-size: 0.9em; color: #666;">The explanation was concise and essential.</span>
</td>
<td style="padding: 10px; text-align: center; border: 1px solid #dee2e6; background: #f8f9fa;">
<input type="radio" name="clarity" value="1" required style="width: 20px; height: 20px; cursor: pointer;">
</td>
<td style="padding: 10px; text-align: center; border: 1px solid #dee2e6; background: #f8f9fa;">
<input type="radio" name="clarity" value="2" style="width: 20px; height: 20px; cursor: pointer;">
</td>
<td style="padding: 10px; text-align: center; border: 1px solid #dee2e6; background: #f8f9fa;">
<input type="radio" name="clarity" value="3" style="width: 20px; height: 20px; cursor: pointer;">
</td>
<td style="padding: 10px; text-align: center; border: 1px solid #dee2e6; background: #f8f9fa;">
<input type="radio" name="clarity" value="4" style="width: 20px; height: 20px; cursor: pointer;">
</td>
<td style="padding: 10px; text-align: center; border: 1px solid #dee2e6; background: #f8f9fa;">
<input type="radio" name="clarity" value="5" style="width: 20px; height: 20px; cursor: pointer;">
</td>
</tr>
</tbody>
</table>
<p style="font-size: 0.85em; color: #666; text-align: center; margin-top: 10px;">
Rating Scale: 1 = Strongly Disagree | 5 = Strongly Agree
</p>
<button type="submit" style="width: 100%; background: #28a745; color: white; padding: 12px; border: none; border-radius: 5px; cursor: pointer; font-size: 1.05em; margin-top: 15px;">Submit Feedback to Firebase</button>
</form>
</div>
<div style="text-align: center; margin-top: 20px;">
<a href="/" style="color: #007bff; text-decoration: none;">← Analyze Another Email</a>
</div>
</div>
"""
# Registered on POST /feedback, the action of the feedback form rendered by
# the result page.
@app.post("/feedback")
async def save_feedback(token: str = Form(...), understanding: int = Form(...), clarity: int = Form(...)):
    """Store a user's explanation ratings in Firebase and confirm receipt.

    Args:
        token: Analysis token echoed back from the result page.
        understanding: Rating submitted for the "Decision Clarity" question.
        clarity: Rating submitted for the "Information Focus" question.

    Returns:
        An ``HTMLResponse`` thanking the user and showing the token.
    """
    # Local import keeps this block self-contained (stdlib only).
    from html import escape

    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # Save feedback with token reference
    feedback_ref = db.reference('/xai_feedback')
    feedback_ref.push({
        'token': token,
        'understanding': understanding,
        'clarity': clarity,
        'timestamp': timestamp
    })
    # The token is a client-supplied form field, so escape it before
    # reflecting it into the page (prevents reflected XSS).
    safe_token = escape(token)
    return HTMLResponse(f"""
<div style="font-family: sans-serif; max-width: 800px; margin: 100px auto; padding: 40px; background: white; border-radius: 15px; box-shadow: 0 10px 30px rgba(0,0,0,0.1); text-align: center;">
<h2 style="color: #28a745;">Feedback Received!</h2>
<p>Thank you for contributing to our research.</p>
<p><b>Your Token:</b> <span style="font-family: monospace; font-size: 1.2em; color: #007bff;">{safe_token}</span></p>
<a href="/" style="display: inline-block; margin-top: 20px; padding: 12px 30px; background: #007bff; color: white; text-decoration: none; border-radius: 5px;">Go back to Home</a>
</div>
""")
# Script entry point: serve the app with uvicorn on the loopback interface
# when this file is executed directly.
if __name__ == "__main__":
    import uvicorn

    server_options = {"host": "127.0.0.1", "port": 8000}
    uvicorn.run(app, **server_options)