# ============================================================
# PhishGuard AI - main.py
# FastAPI orchestrator — Full 4-tier phishing detection pipeline
# with feedback-driven incremental retraining.
#
# Endpoints:
#   POST /analyze        → 4-tier URL phishing analysis
#   POST /analyze/email  → BERT-only email body analysis
#   POST /retrain        → Incremental model retraining
#   GET  /model_version  → Current model version info
#   GET  /health         → All model load statuses
#
# Architecture:
#   Tier 1: Whitelist O(1)                → SAFE exit (~55% traffic)
#   Tier 2: Heuristic 15 signals          → BLOCK if >= 80 (~15% blocked)
#   Tier 3: BERT+GNN parallel             → BLOCK/SAFE/escalate (~15% exits)
#   Tier 4: CNN visual + brand hash       → BLOCK/SAFE (~15% borderline)
# ============================================================

from __future__ import annotations

import os
import sys
import asyncio
import time
import hashlib
import logging
import logging.handlers
from collections import OrderedDict
from contextlib import asynccontextmanager
from pathlib import Path
from typing import List, Optional

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# ── Path setup ────────────────────────────────────────────────────────
# The optional "gnn" and "cnn" sub-packages are also made importable as flat
# modules: each directory that exists next to this file is pushed to the front
# of sys.path.
BASE_DIR = Path(__file__).parent
for sub_dir in ("gnn", "cnn"):
    sub_path = BASE_DIR / sub_dir
    if not sub_path.is_dir():
        continue
    sys.path.insert(0, str(sub_path))

# ── Logging ───────────────────────────────────────────────────────────
# Records go to a size-rotated file under <BASE_DIR>/logs (5 MiB per file,
# three backups) and to a default StreamHandler.
log_dir = BASE_DIR / "logs"
log_dir.mkdir(exist_ok=True)

_handler = logging.handlers.RotatingFileHandler(
    log_dir / "phishguard.log",
    maxBytes=5 * 1024 * 1024,
    backupCount=3,
    encoding="utf-8",
)
_handler.setFormatter(
    logging.Formatter(
        "%(asctime)s | %(levelname)-7s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
)

logger = logging.getLogger("phishguard")
logger.setLevel(logging.INFO)
logger.addHandler(_handler)
logger.addHandler(logging.StreamHandler())

# ── Import project modules ────────────────────────────────────────────
from url_heuristics import HeuristicScorer, HeuristicResult
from bert_analyzer import BERTPhishingClassifier

# GNN imports
# Try the package layout first, then the flat layout (the "gnn" directory may
# have been put on sys.path). GNN_AVAILABLE records whether either import
# succeeded; when both fail, the name GNNInference is left undefined.
GNN_AVAILABLE = False
gnn_inference = None
try:
    from gnn.gnn_inference import GNNInference
    GNN_AVAILABLE = True
except ImportError:
    try:
        from gnn_inference import GNNInference
        GNN_AVAILABLE = True
    except ImportError:
        logger.warning("GNN module not available")

# CNN imports
# Same two-step strategy as for the GNN module; CNN_AVAILABLE gates every use
# of CNNInference / BrandHashDetector / preprocess_screenshot.
CNN_AVAILABLE = False
cnn_inference = None
brand_detector = None
try:
    from cnn.cnn_inference import CNNInference
    from cnn.screenshot_hasher import BrandHashDetector
    from cnn.cnn_model import preprocess_screenshot
    CNN_AVAILABLE = True
except ImportError:
    try:
        from cnn_inference import CNNInference
        from screenshot_hasher import BrandHashDetector
        from cnn_model import preprocess_screenshot
        CNN_AVAILABLE = True
    except ImportError:
        logger.warning("CNN module not available")

from tier3_bert_gnn import Tier3Ensemble
from retraining_service import RetrainingService, FeedbackRecord, RetrainResult

# ── Whitelist (Tier 1) ────────────────────────────────────────────────
# Root domains that are reported SAFE without running any further tier.
WHITELIST: set[str] = {
    "google.com", "youtube.com", "facebook.com", "amazon.com", "wikipedia.org",
    "twitter.com", "instagram.com", "linkedin.com", "microsoft.com", "apple.com",
    "github.com", "stackoverflow.com", "reddit.com", "netflix.com", "paypal.com",
    "bankofamerica.com", "chase.com", "wellsfargo.com", "yahoo.com", "bing.com",
    "outlook.com", "office.com", "live.com", "adobe.com", "dropbox.com",
    "zoom.us", "slack.com", "spotify.com", "twitch.tv", "ebay.com",
    "walmart.com", "target.com", "bestbuy.com", "airbnb.com", "x.com",
    "tiktok.com", "pinterest.com", "quora.com", "medium.com",
}


def get_root_domain(url: str) -> str:
    """Return the last two labels of the URL's hostname.

    A single leading ``www.`` label is dropped first. Hosts with fewer than
    two labels are returned unchanged, and a URL without a hostname yields
    ``""``.

    Args:
        url: Absolute URL to inspect.

    Returns:
        The root domain (e.g. ``"example.com"`` for
        ``"https://a.b.example.com/x"``), or ``""`` when the URL cannot be
        parsed.
    """
    from urllib.parse import urlparse

    try:
        host = urlparse(url).hostname or ""
    except Exception:
        # Deliberate best-effort: an unparsable URL yields "", which is never
        # in WHITELIST, so the caller falls through to the scoring tiers.
        return ""

    # Strip "www." only as a leading label. Removing the substring anywhere in
    # the host (the previous str.replace) let hosts such as "google.www.com"
    # or "googwww.le.com" collapse to "google.com" and pass the whitelist.
    if host.startswith("www."):
        host = host[len("www."):]

    parts = host.split(".")
    return ".".join(parts[-2:]) if len(parts) >= 2 else host


# ── URL Cache (LRU, 30-min TTL) ───────────────────────────────────────
CACHE_TTL = 30 * 60  # seconds an entry stays valid
CACHE_MAX = 500  # maximum number of cached URLs


class URLCache:
    """In-memory LRU cache of analysis results with a per-entry time-to-live.

    Entries live in an ``OrderedDict`` ordered from least to most recently
    used. Expired entries are removed lazily, when they are next looked up.
    """

    def __init__(self, maxsize: int = CACHE_MAX, ttl: int = CACHE_TTL) -> None:
        """Create an empty cache holding at most *maxsize* entries for *ttl* seconds."""
        self._cache: OrderedDict = OrderedDict()
        self._maxsize = maxsize
        self._ttl = ttl

    def get(self, url: str) -> Optional[dict]:
        """Return the cached result for *url*, or ``None`` if absent or expired."""
        entry = self._cache.get(url)
        if entry is None:
            return None
        if time.time() - entry["ts"] < self._ttl:
            # A hit refreshes the entry's LRU position.
            self._cache.move_to_end(url)
            return entry["result"]
        # Expired: drop it so it no longer occupies a slot.
        del self._cache[url]
        return None

    def set(self, url: str, result: dict) -> None:
        """Store *result* for *url*, evicting the least recently used entry if full."""
        self._cache[url] = {"result": result, "ts": time.time()}
        self._cache.move_to_end(url)
        if len(self._cache) > self._maxsize:
            self._cache.popitem(last=False)

    def clear(self) -> None:
        """Remove every cached entry."""
        self._cache.clear()


_url_cache = URLCache()


# ── Request/Response Models ───────────────────────────────────────────

# Body of POST /analyze. `heuristic_score` is the browser-side score; the
# endpoint combines it with the server-side score by taking the maximum.
class AnalyzeRequest(BaseModel):
    url: str
    heuristic_score: float = 0.0
    page_title: str = ""
    page_snippet: str = ""
    related_urls: list = []


# Body of POST /analyze/email.
class EmailRequest(BaseModel):
    sender: str
    subject: str = ""
    body: str = ""
    urls: list = []
    timestamp: str = ""


# One labeled feedback sample inside a retrain request; converted field by
# field into a FeedbackRecord by the /retrain endpoint.
class FeedbackSample(BaseModel):
    url: str
    verdict: str = ""
    confidence: float = 0.0
    tier_used: int = 0
    heuristic_score: int = 0
    signals: list = []
    user_feedback: Optional[str] = None
    timestamp: str = ""
    feedback_ts: Optional[str] = None
    url_hash: str = ""
    session_id: str = ""


# Body of POST /retrain.
class RetrainRequest(BaseModel):
    samples: List[FeedbackSample]
    trigger: str = "count"
    session_id: str = ""
    extension_version: str = ""


# ── Global state ──────────────────────────────────────────────────────
# Populated by `lifespan` at startup; None until then (or when the
# corresponding optional module is unavailable).
_scorer: Optional[HeuristicScorer] = None
_bert: Optional[BERTPhishingClassifier] = None
_gnn: Optional[GNNInference] = None
_cnn: Optional[CNNInference] = None
_brand: Optional[BrandHashDetector] = None
_tier3: Optional[Tier3Ensemble] = None
_retrain_service: Optional[RetrainingService] = None
_retrain_lock = asyncio.Lock()


# ── Lifespan (startup/shutdown) ───────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load all models at startup, clean up at shutdown.

    Fills the module-level model handles. The GNN and CNN tiers are optional:
    when their modules could not be imported the handles stay ``None`` and the
    request pipeline degrades to the tiers that are available.
    """
    global _scorer, _bert, _gnn, _cnn, _brand, _tier3, _retrain_service

    logger.info("=== PhishGuard AI starting up ===")

    # Tier 2: Heuristic Scorer
    _scorer = HeuristicScorer()
    logger.info(" Tier 2: HeuristicScorer initialized")

    # Tier 3a: BERT
    _bert = BERTPhishingClassifier()
    _bert.load_model()
    logger.info(" Tier 3a: BERT classifier initialized and loaded")

    # Tier 3b: GNN
    if GNN_AVAILABLE:
        _gnn = GNNInference()
        _gnn.load()
        logger.info(f" Tier 3b: GNN loaded={_gnn.is_loaded}")
    else:
        _gnn = None
        logger.warning(" Tier 3b: GNN not available")

    # Tier 3 Ensemble
    if _gnn:
        _tier3 = Tier3Ensemble(_bert, _gnn)
        logger.info(" Tier 3: Ensemble initialized")
    else:
        _tier3 = None
        logger.warning(" Tier 3: Ensemble not available (GNN missing)")

    # Tier 4: CNN + Brand Detection
    if CNN_AVAILABLE:
        _cnn = CNNInference()
        _cnn.load()
        _brand = BrandHashDetector()
        logger.info(f" Tier 4: CNN loaded={_cnn.is_loaded}, Brand hash DB loaded")
    else:
        _cnn = None
        _brand = None
        logger.warning(" Tier 4: CNN not available")

    # Retraining Service
    # GNNInference is only a defined name when its import succeeded, so the
    # fallback constructor must be guarded exactly like the CNN one; the
    # unguarded form raised NameError at startup whenever GNN was missing.
    # NOTE(review): assumes RetrainingService tolerates gnn_inference=None the
    # same way it is handed cnn_inference=None — confirm against that class.
    _retrain_service = RetrainingService(
        bert_classifier=_bert,
        gnn_inference=_gnn or (GNNInference() if GNN_AVAILABLE else None),
        cnn_inference=_cnn or (CNNInference() if CNN_AVAILABLE else None),
    )
    logger.info(" Retraining service initialized")

    logger.info("=== PhishGuard AI ready ===")
    yield
    logger.info("=== PhishGuard AI shutting down ===")


# ── FastAPI App ───────────────────────────────────────────────────────
app = FastAPI(
    title="PhishGuard AI Backend",
    version="3.0",
    description="4-tier ML phishing detection with feedback-driven retraining",
    lifespan=lifespan,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


def _build_result(
    url: str,
    *,
    is_phishing: bool,
    confidence: float,
    method: str,
    tier: int,
    heuristic_score: int,
    signals: list,
    details: dict,
) -> dict:
    """Assemble the /analyze response dict; ``status`` is derived from *is_phishing*."""
    return {
        "url": url,
        "is_phishing": is_phishing,
        "confidence": confidence,
        "method": method,
        "status": "blocked" if is_phishing else "safe",
        "tier": tier,
        "heuristic_score": heuristic_score,
        "signals": signals,
        "details": details,
    }


# ── POST /analyze — Full 4-tier pipeline ──────────────────────────────
@app.post("/analyze")
async def analyze_endpoint(req: AnalyzeRequest) -> dict:
    """
    Analyze a URL through the 4-tier phishing detection pipeline.

    Tier 1: Whitelist → SAFE
    Tier 2: Heuristic → BLOCK if >= 80
    Tier 3: BERT+GNN ensemble → BLOCK/SAFE/escalate
    Tier 4: CNN visual + brand hash → BLOCK/SAFE

    Every verdict except the whitelist exit is stored in the URL cache, and a
    cached verdict is returned before any scoring tier runs.
    """
    url = req.url
    details: dict = {}

    # ── TIER 1: Whitelist ────────────────────────────────────────
    root = get_root_domain(url)
    if root in WHITELIST:
        return _build_result(
            url,
            is_phishing=False,
            confidence=0.0,
            method="whitelist",
            tier=1,
            heuristic_score=0,
            signals=[],
            details={"whitelisted_domain": root},
        )

    # ── Cache check ──────────────────────────────────────────────
    cached = _url_cache.get(url)
    if cached is not None:
        return cached

    # ── TIER 2: Heuristic scoring ────────────────────────────────
    h_result: HeuristicResult = _scorer.score(url)
    browser_score = int(req.heuristic_score)

    # Use the higher of server-side and browser-side heuristic scores
    h_score = max(h_result.score, browser_score)

    details["heuristic"] = {
        "score": h_result.score,
        "raw_score": h_result.raw_score,
        "signals": h_result.signals,
        "browser_score": browser_score,
        "combined_score": h_score,
    }

    if h_score >= 80:
        result = _build_result(
            url,
            is_phishing=True,
            confidence=h_score / 100.0,
            method="heuristic",
            tier=2,
            heuristic_score=h_score,
            signals=h_result.signals,
            details=details,
        )
        _url_cache.set(url, result)
        logger.info(f"Tier 2 BLOCK | url={url[:60]} | score={h_score}")
        return result

    # ── TIER 3: BERT + GNN Ensemble ──────────────────────────────
    if _tier3 is not None:
        try:
            p3 = await _tier3.predict(
                url=url,
                title=req.page_title,
                snippet=req.page_snippet,
                h_score=h_score,
            )
            details["tier3_score"] = p3
        except Exception as e:
            logger.error(f"Tier 3 error: {e}")
            p3 = h_score / 100.0  # fallback to heuristic
            details["tier3_error"] = str(e)
    else:
        # Tier 3 unavailable — use BERT alone + heuristic
        if _bert is not None:
            # Inside a coroutine the running loop is the right handle;
            # asyncio.get_event_loop() is discouraged here.
            loop = asyncio.get_running_loop()
            try:
                # The model call is synchronous, so it runs in the default
                # executor to keep the event loop responsive.
                p_bert = await loop.run_in_executor(
                    None, _bert.predict, url, req.page_title, req.page_snippet,
                )
            except Exception as e:
                # Best-effort: fall back to a neutral probability, but leave a
                # trace instead of swallowing the failure silently.
                logger.error(f"Tier 3 BERT fallback error: {e}")
                p_bert = 0.5
            h_norm = h_score / 100.0
            p3 = 0.60 * p_bert + 0.40 * h_norm
        else:
            p3 = h_score / 100.0
        details["tier3_score"] = p3
        details["tier3_note"] = "ensemble_unavailable"

    # Tier 3 decision
    decision = Tier3Ensemble.decide(p3)

    if decision in ("block", "safe"):
        blocked = decision == "block"
        result = _build_result(
            url,
            is_phishing=blocked,
            confidence=round(p3, 4),
            method="bert_gnn_ensemble",
            tier=3,
            heuristic_score=h_score,
            signals=h_result.signals,
            details=details,
        )
        _url_cache.set(url, result)
        logger.info(f"Tier 3 {'BLOCK' if blocked else 'SAFE'} | url={url[:60]} | P3={p3:.4f}")
        return result

    # ── TIER 4: CNN Visual + Brand Hash (borderline 0.40 ≤ P3 < 0.85)
    if _cnn is not None and _cnn.is_loaded:
        try:
            # Capture screenshot
            screenshot_bytes = await _capture_screenshot_for_tier4(url)

            if screenshot_bytes:
                # CNN prediction
                p_cnn = _cnn.predict(screenshot_bytes)
                details["cnn_prob"] = round(p_cnn, 4)

                # Brand hash check
                brand_boost = 0.0
                if _brand is not None:
                    is_impersonation, brand_name, brand_conf = _brand.detect(
                        screenshot_bytes, url
                    )
                    details["brand"] = {
                        "impersonation_detected": is_impersonation,
                        "brand": brand_name,
                        "confidence": round(brand_conf, 3),
                    }
                    if is_impersonation:
                        brand_boost = 0.25

                # P_final = 0.55·P3 + 0.30·P_cnn + brand_boost
                p_final = min((0.55 * p3) + (0.30 * p_cnn) + brand_boost, 1.0)
                details["tier4_score"] = round(p_final, 4)

                is_phishing = p_final >= 0.65
                result = _build_result(
                    url,
                    is_phishing=is_phishing,
                    confidence=round(p_final, 4),
                    method="full_ensemble_bert_gnn_cnn",
                    tier=4,
                    heuristic_score=h_score,
                    signals=h_result.signals,
                    details=details,
                )
                _url_cache.set(url, result)
                logger.info(f"Tier 4 {'BLOCK' if is_phishing else 'SAFE'} | url={url[:60]} | P_final={p_final:.4f}")
                return result

        except Exception as e:
            logger.error(f"Tier 4 error: {e}")
            details["tier4_error"] = str(e)

    # Tier 4 unavailable/failed — use Tier 3 score with conservative threshold
    is_phishing = p3 >= 0.65
    result = _build_result(
        url,
        is_phishing=is_phishing,
        confidence=round(p3, 4),
        method="bert_gnn_ensemble",
        tier=3,
        heuristic_score=h_score,
        signals=h_result.signals,
        details=details,
    )
    _url_cache.set(url, result)
    logger.info(f"Tier 4 fallback → Tier 3 | url={url[:60]} | P3={p3:.4f}")
    return result


async def _capture_screenshot_for_tier4(url: str) -> Optional[bytes]:
    """Capture screenshot for Tier 4 CNN analysis.

    Loads *url* in headless Chromium through Playwright and returns a PNG of
    the 1280x800 viewport. Any failure (Playwright missing, navigation
    timeout, ...) is logged as a warning and reported as ``None``.

    NOTE(review): this navigates the server to a caller-supplied URL; consider
    restricting schemes / internal addresses if the service is exposed.
    """
    try:
        from playwright.async_api import async_playwright

        async with async_playwright() as p:
            browser = await p.chromium.launch(headless=True)
            try:
                page = await browser.new_page(
                    viewport={"width": 1280, "height": 800},
                    user_agent=(
                        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                        "AppleWebKit/537.36 Chrome/120.0.0.0 Safari/537.36"
                    ),
                )

                # Block heavy resources
                await page.route(
                    "**/*.{woff,woff2,ttf,eot,mp4,webm,ogg,wav,mp3}",
                    lambda route: route.abort(),
                )

                await page.goto(url, wait_until="domcontentloaded", timeout=10000)
                return await page.screenshot(type="png")
            finally:
                # Close the browser on the failure path too (e.g. goto timeout).
                await browser.close()

    except Exception as e:
        logger.warning(f"Tier 4 screenshot failed: {e}")
        return None


# ── POST /analyze/email ───────────────────────────────────────────────
@app.post("/analyze/email")
async def analyze_email_endpoint(req: EmailRequest) -> dict:
    """BERT-only path for email body text analysis.

    Whitelisted sender domains are reported safe immediately. Emails without
    URLs are scored by BERT on subject and body; otherwise up to three URLs
    are run through the /analyze pipeline and the email is blocked if any of
    them is classified as phishing.
    """
    # Sender whitelist check
    # NOTE(review): the sender string is client-supplied and trusted as-is.
    sender_domain = req.sender.split("@")[-1].lower() if "@" in req.sender else ""
    if sender_domain in WHITELIST:
        return {
            "status": "safe",
            "analysis": {
                "isPhishing": False,
                "probability": 0.0,
                "reason": "Trusted sender domain",
            },
        }

    # Analyze embedded URLs
    MAX_URLS = 3
    urls_to_check = req.urls[:MAX_URLS]

    if not urls_to_check:
        # Text-only analysis
        if _bert is not None:
            combined = f"{req.subject} {req.body}"
            # Same executor hand-off as the /analyze BERT fallback, so the
            # synchronous model call does not block the event loop.
            loop = asyncio.get_running_loop()
            prob = await loop.run_in_executor(
                None, _bert.predict, combined, req.subject, req.body,
            )
            is_phishing = prob > 0.6
            return {
                "status": "blocked" if is_phishing else "safe",
                "analysis": {
                    "isPhishing": is_phishing,
                    "probability": prob,
                    "reason": "BERT text analysis (no URLs)",
                },
            }
        return {
            "status": "safe",
            "analysis": {
                "isPhishing": False,
                "probability": 0.1,
                "reason": "No URLs and no ML model available",
            },
        }

    # Analyze URLs through the main pipeline
    tasks = [
        analyze_endpoint(AnalyzeRequest(url=u, page_title=req.subject))
        for u in urls_to_check
    ]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    max_prob = 0.0
    phishing_detected = False
    flagged_urls = []

    for candidate, outcome in zip(urls_to_check, results):
        # gather(return_exceptions=True) can hand back any BaseException
        # (e.g. CancelledError); skip those, but log them rather than
        # dropping them silently.
        if isinstance(outcome, BaseException):
            logger.error(
                f"Email URL analysis failed | url={str(candidate)[:60]} | error={outcome!r}"
            )
            continue
        max_prob = max(max_prob, outcome.get("confidence", 0.0))
        if outcome.get("is_phishing"):
            phishing_detected = True
            flagged_urls.append(outcome.get("url", candidate))

    return {
        "status": "blocked" if phishing_detected else "safe",
        "analysis": {
            "isPhishing": phishing_detected,
            "probability": max_prob,
            "flagged_urls": flagged_urls,
            "reason": "URL analysis via ML ensemble",
        },
    }


# ── POST /retrain — Incremental retraining ────────────────────────────
@app.post("/retrain")
async def retrain_endpoint(req: RetrainRequest) -> dict:
    """
    Receive labeled feedback and incrementally update all models.
    Uses asyncio.Lock() to prevent concurrent retraining jobs.
    Timeout: 600s max.

    A request that arrives while a job is running is answered with
    ``"status": "skipped"`` instead of being queued. The URL cache is cleared
    after a successful run.
    """
    if _retrain_service is None:
        return {"status": "error", "message": "Retraining service not initialized"}

    # Prevent concurrent retraining
    if _retrain_lock.locked():
        return {
            "status": "skipped",
            "message": "Retraining already in progress",
            "models_updated": [],
        }

    async with _retrain_lock:
        # Convert Pydantic models to FeedbackRecord dataclasses
        records = [
            FeedbackRecord(
                url=s.url,
                verdict=s.verdict,
                confidence=s.confidence,
                tier_used=s.tier_used,
                heuristic_score=s.heuristic_score,
                signals=s.signals,
                user_feedback=s.user_feedback,
                timestamp=s.timestamp,
                feedback_ts=s.feedback_ts,
                url_hash=s.url_hash,
                session_id=s.session_id,
            )
            for s in req.samples
        ]

        try:
            result = await asyncio.wait_for(
                _retrain_service.retrain(records),
                timeout=600,
            )

            # Clear URL cache after retraining (stale results)
            if result.status == "success":
                _url_cache.clear()

            return {
                "status": result.status,
                "models_updated": result.models_updated,
                "samples_used": result.samples_used,
                "duration_seconds": result.duration_seconds,
                "accuracy_delta": result.accuracy_delta,
                "next_retrain_hint": result.next_retrain_hint,
            }

        except asyncio.TimeoutError:
            return {
                "status": "error",
                "message": "Retraining timed out (600s limit)",
            }
        except Exception as e:
            logger.error(f"Retrain endpoint error: {e}")
            return {
                "status": "error",
                "message": str(e),
            }


# ── GET /model_version ────────────────────────────────────────────────
@app.get("/model_version")
async def model_version_endpoint() -> dict:
    """Return current model version info for extension polling."""
    if _retrain_service:
        return _retrain_service.get_version_info()
    return {"version": 0, "updated_at": None, "accuracy": {}}


def _is_loaded(model: object) -> bool:
    """Return True when *model* exists and reports ``is_loaded``."""
    return model is not None and bool(model.is_loaded)


# ── GET /health ───────────────────────────────────────────────────────
@app.get("/health")
async def health_endpoint() -> dict:
    """Liveness probe with per-tier readiness and model statuses."""
    return {
        "status": "ok",
        "version": "3.0",
        "tier1": True,
        "tier2": _scorer is not None,
        "tier3": _tier3 is not None,
        "tier4": _is_loaded(_cnn),
        "retraining_in_progress": _retrain_lock.locked(),
        "model_version": _retrain_service.model_version if _retrain_service else 0,
        "modules": {
            "heuristic": _scorer is not None,
            "bert": _is_loaded(_bert),
            "bert_lazy": _bert is not None and not _bert.is_loaded,
            "gnn": _is_loaded(_gnn),
            "cnn": _is_loaded(_cnn),
            "brand_hash": _brand is not None,
        },
    }


# ── Legacy feedback endpoint (backward compat) ───────────────────────
@app.post("/feedback")
async def legacy_feedback_endpoint(req: dict) -> dict:
    """Legacy feedback endpoint for backward compatibility.

    The payload is ignored; the response only points clients at /retrain.
    """
    return {"status": "success", "message": "Use POST /retrain for feedback-driven retraining"}


# ── Run directly ──────────────────────────────────────────────────────
# uvicorn main:app --reload --port 8000