phishguard-api / main.py
prashanth135's picture
Upload 2 files
6663b5f verified
# ============================================================
# PhishGuard AI - main.py
# FastAPI orchestrator — Full 4-tier phishing detection pipeline
# with feedback-driven incremental retraining.
#
# Endpoints:
#   POST /analyze        → 4-tier URL phishing analysis
#   POST /analyze/email  → BERT-only email body analysis
#   POST /retrain        → Incremental model retraining
#   GET  /model_version  → Current model version info
#   GET  /health         → All model load statuses
#
# Architecture:
#   Tier 1: Whitelist O(1)          → SAFE exit (~55% traffic)
#   Tier 2: Heuristic 15 signals    → BLOCK if >= 80 (~15% blocked)
#   Tier 3: BERT+GNN parallel       → BLOCK/SAFE/escalate (~15% exits)
#   Tier 4: CNN visual + brand hash → BLOCK/SAFE (~15% borderline)
# ============================================================
from __future__ import annotations
import os
import sys
import asyncio
import time
import hashlib
import logging
import logging.handlers
from collections import OrderedDict
from contextlib import asynccontextmanager
from pathlib import Path
from typing import List, Optional
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# ── Path setup ────────────────────────────────────────────────────────
# Directory containing this file; anchors the sub-package and log paths below.
BASE_DIR = Path(__file__).parent
# Put the optional "gnn" and "cnn" folders (when they exist next to this file)
# on sys.path so their modules can also be imported by bare name.
# Each insert goes to index 0, so "cnn" ends up ahead of "gnn" when both exist.
for sub_dir in ["gnn", "cnn"]:
    sub_path = BASE_DIR / sub_dir
    if sub_path.is_dir():
        sys.path.insert(0, str(sub_path))
# ── Logging ───────────────────────────────────────────────────────────
# Log files are written to <BASE_DIR>/logs; the directory is created at import
# time if it does not exist yet.
log_dir = BASE_DIR / "logs"
log_dir.mkdir(exist_ok=True)
# Size-rotated file log: rolls over at 5 MiB, keeps 3 backups, UTF-8 encoded.
_handler = logging.handlers.RotatingFileHandler(
    log_dir / "phishguard.log",
    maxBytes=5 * 1024 * 1024,
    backupCount=3,
    encoding="utf-8",
)
# Record layout: timestamp | level (padded to 7 chars) | logger name | message.
_handler.setFormatter(logging.Formatter(
    "%(asctime)s | %(levelname)-7s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
))
# Shared application logger used by every endpoint in this module.
logger = logging.getLogger("phishguard")
logger.setLevel(logging.INFO)
logger.addHandler(_handler)
# Second handler with default settings: the formatter above is NOT applied to
# it, so stream output carries the bare message only.
logger.addHandler(logging.StreamHandler())
# ── Import project modules ───────────────────────────────────────────
# Mandatory project modules: an ImportError here aborts start-up.
from url_heuristics import HeuristicScorer, HeuristicResult
from bert_analyzer import BERTPhishingClassifier
# GNN imports
# Optional: try the package-qualified module first, then the bare module name
# (the path setup above puts the "gnn" folder on sys.path when it exists).
# GNN_AVAILABLE records whether either attempt succeeded.
GNN_AVAILABLE = False
gnn_inference = None
try:
    from gnn.gnn_inference import GNNInference
    GNN_AVAILABLE = True
except ImportError:
    try:
        from gnn_inference import GNNInference
        GNN_AVAILABLE = True
    except ImportError:
        # GNNInference is left undefined in this case; any use of the name
        # must be guarded by GNN_AVAILABLE.
        logger.warning("GNN module not available")
# CNN imports
# Optional, same two-step lookup as the GNN block. All three CNN names are
# imported together, so one missing module disables the whole tier.
CNN_AVAILABLE = False
cnn_inference = None
brand_detector = None
try:
    from cnn.cnn_inference import CNNInference
    from cnn.screenshot_hasher import BrandHashDetector
    from cnn.cnn_model import preprocess_screenshot
    CNN_AVAILABLE = True
except ImportError:
    try:
        from cnn_inference import CNNInference
        from screenshot_hasher import BrandHashDetector
        from cnn_model import preprocess_screenshot
        CNN_AVAILABLE = True
    except ImportError:
        # CNNInference / BrandHashDetector / preprocess_screenshot are left
        # undefined in this case; guard uses with CNN_AVAILABLE.
        logger.warning("CNN module not available")
# Mandatory again: the ensemble and the retraining service are imported
# unconditionally.
from tier3_bert_gnn import Tier3Ensemble
from retraining_service import RetrainingService, FeedbackRecord, RetrainResult
# ── Whitelist (Tier 1) ────────────────────────────────────────────────
# Trusted root domains. /analyze returns a Tier 1 "safe" verdict when the
# URL's root domain is in this set, and /analyze/email treats a sender whose
# domain is in this set as trusted. Lookups are plain set-membership tests.
WHITELIST: set[str] = {
    "google.com", "youtube.com", "facebook.com", "amazon.com", "wikipedia.org",
    "twitter.com", "instagram.com", "linkedin.com", "microsoft.com", "apple.com",
    "github.com", "stackoverflow.com", "reddit.com", "netflix.com", "paypal.com",
    "bankofamerica.com", "chase.com", "wellsfargo.com", "yahoo.com", "bing.com",
    "outlook.com", "office.com", "live.com", "adobe.com", "dropbox.com",
    "zoom.us", "slack.com", "spotify.com", "twitch.tv", "ebay.com",
    "walmart.com", "target.com", "bestbuy.com", "airbnb.com",
    "x.com", "tiktok.com", "pinterest.com", "quora.com", "medium.com",
}
def get_root_domain(url: str) -> str:
    """Return the root domain (last two host labels) of ``url``.

    The hostname is taken from ``urllib.parse.urlparse`` (which lower-cases
    it), a single trailing root dot and a leading ``www.`` label are dropped,
    and the last two dot-separated labels are returned. Hosts with a single
    label (e.g. ``localhost``) are returned unchanged.

    Only a *leading* ``www.`` is removed. Removing the substring anywhere in
    the host would corrupt names such as ``evilwww.com``.

    Limitation: there is no public-suffix handling, so ``a.example.co.uk``
    yields ``co.uk``.

    Args:
        url: Absolute URL. Strings without a scheme/netloc have no hostname.

    Returns:
        The root domain, or ``""`` when the URL has no hostname or cannot be
        parsed.
    """
    from urllib.parse import urlparse

    try:
        host = urlparse(url).hostname or ""
    except Exception:
        # Best-effort helper: a malformed URL (e.g. a broken IPv6 literal)
        # simply has no root domain and therefore never matches the whitelist.
        return ""

    # "example.com." is the fully-qualified spelling of "example.com".
    if host.endswith("."):
        host = host[:-1]
    if host.startswith("www."):
        host = host[len("www."):]

    parts = host.split(".")
    return ".".join(parts[-2:]) if len(parts) >= 2 else host
# ── URL Cache (LRU, 30-min TTL) ──────────────────────────────────────
CACHE_TTL = 30 * 60  # seconds an entry stays valid
CACHE_MAX = 500  # entries kept before the oldest one is evicted


class URLCache:
    """In-memory verdict cache keyed by URL, with LRU eviction and a TTL.

    Each entry stores the result dict together with the wall-clock time it was
    written. A successful ``get`` moves the entry to the most-recent end; an
    expired entry is removed when it is looked up.
    """

    def __init__(self, maxsize: int = CACHE_MAX, ttl: int = CACHE_TTL) -> None:
        self._maxsize = maxsize
        self._ttl = ttl
        self._cache: OrderedDict = OrderedDict()

    def get(self, url: str) -> Optional[dict]:
        """Return the cached result for ``url``, or None if absent or expired."""
        record = self._cache.get(url)
        if record is None:
            return None
        age = time.time() - record["ts"]
        if age < self._ttl:
            # Still fresh: mark as most recently used.
            self._cache.move_to_end(url)
            return record["result"]
        # Stale: drop it so it no longer counts toward the size limit.
        del self._cache[url]
        return None

    def set(self, url: str, result: dict) -> None:
        """Store ``result`` for ``url``; evict the oldest entry when over capacity."""
        stamped = {"result": result, "ts": time.time()}
        self._cache[url] = stamped
        self._cache.move_to_end(url)
        over_capacity = len(self._cache) > self._maxsize
        if over_capacity:
            self._cache.popitem(last=False)

    def clear(self) -> None:
        """Remove every cached entry."""
        self._cache.clear()


# Process-wide cache shared by the request handlers.
_url_cache = URLCache()
# ── Request/Response Models ───────────────────────────────────────────
# Request body for POST /analyze.
class AnalyzeRequest(BaseModel):
    url: str  # URL to classify (required)
    heuristic_score: float = 0.0  # client-side heuristic score; /analyze takes max() of it and the server score
    page_title: str = ""  # passed to the Tier 3 models as the page title
    page_snippet: str = ""  # passed to the Tier 3 models as the page snippet
    related_urls: list = []  # accepted, but not read by any handler visible in this file
# Request body for POST /analyze/email.
class EmailRequest(BaseModel):
    sender: str  # sender address; the part after the last "@" is checked against WHITELIST
    subject: str = ""  # used in the text-only BERT path and as page_title for URL analysis
    body: str = ""  # used in the text-only BERT path
    urls: list = []  # URLs found in the email; only the first 3 are analyzed
    timestamp: str = ""  # accepted, but not read by the handler visible in this file
# One labeled feedback item inside a /retrain request. /retrain copies every
# field below one-to-one into a FeedbackRecord.
class FeedbackSample(BaseModel):
    url: str
    verdict: str = ""
    confidence: float = 0.0
    tier_used: int = 0
    heuristic_score: int = 0
    signals: list = []
    user_feedback: Optional[str] = None
    timestamp: str = ""
    feedback_ts: Optional[str] = None
    url_hash: str = ""
    session_id: str = ""
# Request body for POST /retrain.
class RetrainRequest(BaseModel):
    samples: List[FeedbackSample]  # feedback items converted to FeedbackRecord and passed to the retraining service
    trigger: str = "count"  # accepted, but not read by the handler visible in this file
    session_id: str = ""  # accepted, but not read by the handler visible in this file
    extension_version: str = ""  # accepted, but not read by the handler visible in this file
# ── Global state ──────────────────────────────────────────────────────
# Model and service singletons. All start as None and are assigned by
# lifespan() at application start-up; handlers read them as module globals.
_scorer: Optional[HeuristicScorer] = None
_bert: Optional[BERTPhishingClassifier] = None
_gnn: Optional[GNNInference] = None
_cnn: Optional[CNNInference] = None
_brand: Optional[BrandHashDetector] = None
_tier3: Optional[Tier3Ensemble] = None
_retrain_service: Optional[RetrainingService] = None
# Held for the duration of a /retrain job; /retrain and /health check .locked().
_retrain_lock = asyncio.Lock()
# ── Lifespan (startup/shutdown) ───────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every model at start-up and log shutdown.

    Populates the module-level singletons (``_scorer``, ``_bert``, ``_gnn``,
    ``_cnn``, ``_brand``, ``_tier3``, ``_retrain_service``) before the
    application starts serving. GNN and CNN are optional: when their modules
    could not be imported, the matching singletons stay ``None`` and the
    request handlers fall back accordingly.

    Args:
        app: The FastAPI application (required by the lifespan protocol;
            not used directly).
    """
    global _scorer, _bert, _gnn, _cnn, _brand, _tier3, _retrain_service
    logger.info("=== PhishGuard AI starting up ===")

    # Tier 2: Heuristic Scorer
    _scorer = HeuristicScorer()
    logger.info(" Tier 2: HeuristicScorer initialized")

    # Tier 3a: BERT
    _bert = BERTPhishingClassifier()
    _bert.load_model()
    logger.info(" Tier 3a: BERT classifier initialized and loaded")

    # Tier 3b: GNN
    if GNN_AVAILABLE:
        _gnn = GNNInference()
        _gnn.load()
        logger.info(f" Tier 3b: GNN loaded={_gnn.is_loaded}")
    else:
        _gnn = None
        logger.warning(" Tier 3b: GNN not available")

    # Tier 3 Ensemble
    if _gnn:
        _tier3 = Tier3Ensemble(_bert, _gnn)
        logger.info(" Tier 3: Ensemble initialized")
    else:
        _tier3 = None
        logger.warning(" Tier 3: Ensemble not available (GNN missing)")

    # Tier 4: CNN + Brand Detection
    if CNN_AVAILABLE:
        _cnn = CNNInference()
        _cnn.load()
        _brand = BrandHashDetector()
        logger.info(f" Tier 4: CNN loaded={_cnn.is_loaded}, Brand hash DB loaded")
    else:
        _cnn = None
        _brand = None
        logger.warning(" Tier 4: CNN not available")

    # Retraining Service
    # GNNInference only exists as a name when its import succeeded, so the
    # fallback must be guarded by GNN_AVAILABLE (same pattern as the CNN
    # argument). An unguarded GNNInference() raised NameError here and kept
    # the app from starting whenever the GNN module was missing.
    # NOTE(review): assumes RetrainingService tolerates gnn_inference=None the
    # way it is already handed cnn_inference=None — confirm.
    _retrain_service = RetrainingService(
        bert_classifier=_bert,
        gnn_inference=_gnn or (GNNInference() if GNN_AVAILABLE else None),
        cnn_inference=_cnn or (CNNInference() if CNN_AVAILABLE else None),
    )
    logger.info(" Retraining service initialized")
    logger.info("=== PhishGuard AI ready ===")

    yield

    logger.info("=== PhishGuard AI shutting down ===")
# ── FastAPI App ───────────────────────────────────────────────────────
# Application object. `lifespan` is registered so the models are loaded
# before requests are served.
app = FastAPI(
    title="PhishGuard AI Backend",
    version="3.0",
    description="4-tier ML phishing detection with feedback-driven retraining",
    lifespan=lifespan,
)
# CORS: every origin, method and header is allowed.
# NOTE(review): presumably left wide open so a browser extension can call the
# API — confirm this is intended outside development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── POST /analyze — Full 4-tier pipeline ──────────────────────────────
def _clamp_browser_score(raw: float) -> int:
    """Coerce the client-supplied heuristic score to an int in [0, 100]."""
    if raw != raw:  # NaN is the only value unequal to itself; treat as "no score"
        return 0
    return int(min(max(raw, 0.0), 100.0))


def _build_verdict(
    url: str,
    *,
    is_phishing: bool,
    confidence: float,
    method: str,
    tier: int,
    heuristic_score: int,
    signals: list,
    details: dict,
) -> dict:
    """Build the response payload shared by every exit of ``/analyze``."""
    return {
        "url": url,
        "is_phishing": is_phishing,
        "confidence": confidence,
        "method": method,
        "status": "blocked" if is_phishing else "safe",
        "tier": tier,
        "heuristic_score": heuristic_score,
        "signals": signals,
        "details": details,
    }


@app.post("/analyze")
async def analyze_endpoint(req: AnalyzeRequest) -> dict:
    """
    Analyze a URL through the 4-tier phishing detection pipeline.

    Tier 1: Whitelist → SAFE
    Tier 2: Heuristic → BLOCK if >= 80
    Tier 3: BERT+GNN ensemble → BLOCK/SAFE/escalate
    Tier 4: CNN visual + brand hash → BLOCK/SAFE

    Every non-whitelisted verdict is stored in the URL cache before it is
    returned; a cache hit short-circuits Tiers 2-4.

    Returns a dict with the keys ``url``, ``is_phishing``, ``confidence``,
    ``method``, ``status`` ("blocked"/"safe"), ``tier``, ``heuristic_score``,
    ``signals`` and ``details``.
    """
    url = req.url
    details: dict = {}

    # ── TIER 1: Whitelist ────────────────────────────────────────
    root = get_root_domain(url)
    if root in WHITELIST:
        return _build_verdict(
            url,
            is_phishing=False,
            confidence=0.0,
            method="whitelist",
            tier=1,
            heuristic_score=0,
            signals=[],
            details={"whitelisted_domain": root},
        )

    # ── Cache check ──────────────────────────────────────────────
    # NOTE(review): the cache is keyed by URL only, so a verdict influenced by
    # one caller's heuristic_score / page text is served to later callers.
    cached = _url_cache.get(url)
    if cached is not None:
        return cached

    # ── TIER 2: Heuristic scoring ────────────────────────────────
    h_result: HeuristicResult = _scorer.score(url)
    # The browser score is untrusted input: clamp it so it can neither push
    # "confidence" above 1.0 nor crash int() with NaN/inf.
    browser_score = _clamp_browser_score(req.heuristic_score)
    # Use the higher of server-side and browser-side heuristic scores
    h_score = max(h_result.score, browser_score)
    signals = h_result.signals
    details["heuristic"] = {
        "score": h_result.score,
        "raw_score": h_result.raw_score,
        "signals": signals,
        "browser_score": browser_score,
        "combined_score": h_score,
    }
    if h_score >= 80:
        result = _build_verdict(
            url,
            is_phishing=True,
            confidence=h_score / 100.0,
            method="heuristic",
            tier=2,
            heuristic_score=h_score,
            signals=signals,
            details=details,
        )
        _url_cache.set(url, result)
        logger.info(f"Tier 2 BLOCK | url={url[:60]} | score={h_score}")
        return result

    # ── TIER 3: BERT + GNN Ensemble ──────────────────────────────
    if _tier3 is not None:
        try:
            p3 = await _tier3.predict(
                url=url,
                title=req.page_title,
                snippet=req.page_snippet,
                h_score=h_score,
            )
            details["tier3_score"] = p3
        except Exception as e:
            logger.error(f"Tier 3 error: {e}")
            p3 = h_score / 100.0  # fallback to heuristic
            details["tier3_error"] = str(e)
    else:
        # Tier 3 unavailable — use BERT alone + heuristic
        if _bert is not None:
            loop = asyncio.get_running_loop()
            try:
                # Run the blocking model call off the event loop.
                p_bert = await loop.run_in_executor(
                    None, _bert.predict, url, req.page_title, req.page_snippet,
                )
            except Exception as e:
                # Neutral probability so the heuristic share decides.
                logger.error(f"Tier 3 BERT fallback error: {e}")
                p_bert = 0.5
            h_norm = h_score / 100.0
            p3 = 0.60 * p_bert + 0.40 * h_norm
        else:
            p3 = h_score / 100.0
        details["tier3_score"] = p3
        details["tier3_note"] = "ensemble_unavailable"

    # Tier 3 decision: "block" and "safe" exit here, anything else escalates.
    decision = Tier3Ensemble.decide(p3)
    if decision in ("block", "safe"):
        blocked = decision == "block"
        result = _build_verdict(
            url,
            is_phishing=blocked,
            confidence=round(p3, 4),
            method="bert_gnn_ensemble",
            tier=3,
            heuristic_score=h_score,
            signals=signals,
            details=details,
        )
        _url_cache.set(url, result)
        logger.info(f"Tier 3 {'BLOCK' if blocked else 'SAFE'} | url={url[:60]} | P3={p3:.4f}")
        return result

    # ── TIER 4: CNN Visual + Brand Hash ──────────────────────────
    # Borderline band per the original note: 0.40 <= P3 < 0.85 (the actual
    # thresholds live in Tier3Ensemble.decide).
    if _cnn is not None and _cnn.is_loaded:
        try:
            # Capture screenshot
            screenshot_bytes = await _capture_screenshot_for_tier4(url)
            if screenshot_bytes:
                loop = asyncio.get_running_loop()
                # CNN prediction (blocking inference, kept off the event loop)
                p_cnn = await loop.run_in_executor(
                    None, _cnn.predict, screenshot_bytes,
                )
                details["cnn_prob"] = round(p_cnn, 4)

                # Brand hash check
                brand_boost = 0.0
                if _brand is not None:
                    is_impersonation, brand_name, brand_conf = await loop.run_in_executor(
                        None, _brand.detect, screenshot_bytes, url,
                    )
                    details["brand"] = {
                        "impersonation_detected": is_impersonation,
                        "brand": brand_name,
                        "confidence": round(brand_conf, 3),
                    }
                    if is_impersonation:
                        brand_boost = 0.25

                # P_final = 0.55·P3 + 0.30·P_cnn + brand_boost, capped at 1.0
                p_final = min((0.55 * p3) + (0.30 * p_cnn) + brand_boost, 1.0)
                details["tier4_score"] = round(p_final, 4)
                is_phishing = p_final >= 0.65
                result = _build_verdict(
                    url,
                    is_phishing=is_phishing,
                    confidence=round(p_final, 4),
                    method="full_ensemble_bert_gnn_cnn",
                    tier=4,
                    heuristic_score=h_score,
                    signals=signals,
                    details=details,
                )
                _url_cache.set(url, result)
                logger.info(f"Tier 4 {'BLOCK' if is_phishing else 'SAFE'} | url={url[:60]} | P_final={p_final:.4f}")
                return result
        except Exception as e:
            logger.error(f"Tier 4 error: {e}")
            details["tier4_error"] = str(e)

    # Tier 4 unavailable/failed — use Tier 3 score with conservative threshold
    is_phishing = p3 >= 0.65
    result = _build_verdict(
        url,
        is_phishing=is_phishing,
        confidence=round(p3, 4),
        method="bert_gnn_ensemble",
        tier=3,
        heuristic_score=h_score,
        signals=signals,
        details=details,
    )
    _url_cache.set(url, result)
    # The original arrow in this message was mis-encoded; ASCII keeps the log
    # readable on any console encoding.
    logger.info(f"Tier 4 fallback -> Tier 3 | url={url[:60]} | P3={p3:.4f}")
    return result
async def _capture_screenshot_for_tier4(url: str) -> Optional[bytes]:
    """Capture a PNG screenshot of ``url`` for Tier 4 CNN analysis.

    Launches a headless Chromium via Playwright, opens the page in a
    1280x800 viewport with fonts and media blocked, waits for
    ``domcontentloaded`` (10 s timeout) and takes a screenshot.

    Args:
        url: Page to load.

    Returns:
        The PNG bytes, or ``None`` when Playwright is unavailable or any step
        fails (the failure is logged as a warning).
    """
    # NOTE(review): the URL comes straight from the request, so the server-side
    # browser can be pointed at internal or non-HTTP targets — consider
    # restricting schemes/hosts before navigating.
    try:
        from playwright.async_api import async_playwright

        async with async_playwright() as p:
            browser = await p.chromium.launch(headless=True)
            try:
                page = await browser.new_page(
                    viewport={"width": 1280, "height": 800},
                    user_agent=(
                        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                        "AppleWebKit/537.36 Chrome/120.0.0.0 Safari/537.36"
                    ),
                )
                # Block heavy resources
                await page.route(
                    "**/*.{woff,woff2,ttf,eot,mp4,webm,ogg,wav,mp3}",
                    lambda route: route.abort(),
                )
                await page.goto(url, wait_until="domcontentloaded", timeout=10000)
                return await page.screenshot(type="png")
            finally:
                # Close the browser even when navigation or the screenshot
                # fails (e.g. on timeout), so it is never left running.
                await browser.close()
    except Exception as e:
        logger.warning(f"Tier 4 screenshot failed: {e}")
        return None
# ── POST /analyze/email ───────────────────────────────────────────────
@app.post("/analyze/email")
async def analyze_email_endpoint(req: EmailRequest) -> dict:
    """Analyze an email: trusted-sender check, then URL or text analysis.

    1. A sender whose domain (text after the last "@") is in WHITELIST is
       reported safe immediately.
    2. With no usable URLs, subject and body are scored by BERT alone
       (phishing when the probability exceeds 0.6); without a BERT model the
       email is reported safe with probability 0.1.
    3. Otherwise the first 3 URLs go through the /analyze pipeline
       concurrently; the email is blocked when any URL is flagged, and the
       highest confidence seen is reported as the probability.

    Returns a dict with ``status`` ("blocked"/"safe") and an ``analysis``
    sub-dict.
    """
    # Sender whitelist check
    sender_domain = req.sender.split("@")[-1].lower() if "@" in req.sender else ""
    if sender_domain in WHITELIST:
        return {
            "status": "safe",
            "analysis": {
                "isPhishing": False,
                "probability": 0.0,
                "reason": "Trusted sender domain",
            },
        }

    # Analyze embedded URLs
    MAX_URLS = 3
    # `urls` is an untyped list: drop non-string entries up front, otherwise
    # building AnalyzeRequest below raises and the whole request fails.
    urls_to_check = [u for u in req.urls if isinstance(u, str)][:MAX_URLS]

    if not urls_to_check:
        # Text-only analysis
        if _bert:
            combined = f"{req.subject} {req.body}"
            # Blocking model inference: run it off the event loop.
            loop = asyncio.get_running_loop()
            prob = await loop.run_in_executor(
                None, _bert.predict, combined, req.subject, req.body,
            )
            is_phishing = prob > 0.6
            return {
                "status": "blocked" if is_phishing else "safe",
                "analysis": {
                    "isPhishing": is_phishing,
                    "probability": prob,
                    "reason": "BERT text analysis (no URLs)",
                },
            }
        return {
            "status": "safe",
            "analysis": {
                "isPhishing": False,
                "probability": 0.1,
                "reason": "No URLs and no ML model available",
            },
        }

    # Analyze URLs through the main pipeline
    tasks = [
        analyze_endpoint(AnalyzeRequest(url=u, page_title=req.subject))
        for u in urls_to_check
    ]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    max_prob = 0.0
    phishing_detected = False
    flagged_urls = []
    for candidate, outcome in zip(urls_to_check, results):
        # gather(return_exceptions=True) can also hand back BaseException
        # subclasses such as CancelledError.
        if isinstance(outcome, BaseException):
            # Skip the failed URL but leave a trace instead of dropping it silently.
            logger.error(f"Email URL analysis failed | url={candidate[:60]} | error={outcome}")
            continue
        max_prob = max(max_prob, outcome.get("confidence", 0.0))
        if outcome.get("is_phishing"):
            phishing_detected = True
            flagged_urls.append(outcome.get("url", candidate))

    return {
        "status": "blocked" if phishing_detected else "safe",
        "analysis": {
            "isPhishing": phishing_detected,
            "probability": max_prob,
            "flagged_urls": flagged_urls,
            "reason": "URL analysis via ML ensemble",
        },
    }
# ── POST /retrain — Incremental retraining ────────────────────────────
@app.post("/retrain")
async def retrain_endpoint(req: RetrainRequest) -> dict:
    """
    Accept labeled feedback samples and run one incremental retraining job.

    Only one job runs at a time: while the retrain lock is held, further
    requests are answered with status "skipped". The job itself is awaited
    with a 600 s timeout, and the URL cache is cleared after a successful run
    because cached verdicts may be stale.
    """
    if _retrain_service is None:
        return {"status": "error", "message": "Retraining service not initialized"}

    # A job is already running: report it instead of queueing behind the lock.
    if _retrain_lock.locked():
        return {
            "status": "skipped",
            "message": "Retraining already in progress",
            "models_updated": [],
        }

    def _as_record(sample):
        # Field-for-field copy of the Pydantic sample into the service's record type.
        return FeedbackRecord(
            url=sample.url,
            verdict=sample.verdict,
            confidence=sample.confidence,
            tier_used=sample.tier_used,
            heuristic_score=sample.heuristic_score,
            signals=sample.signals,
            user_feedback=sample.user_feedback,
            timestamp=sample.timestamp,
            feedback_ts=sample.feedback_ts,
            url_hash=sample.url_hash,
            session_id=sample.session_id,
        )

    async with _retrain_lock:
        feedback = [_as_record(sample) for sample in req.samples]
        try:
            outcome = await asyncio.wait_for(
                _retrain_service.retrain(feedback),
                timeout=600,
            )
            # Verdicts cached before a successful retrain may no longer hold.
            if outcome.status == "success":
                _url_cache.clear()
            summary = {
                "status": outcome.status,
                "models_updated": outcome.models_updated,
                "samples_used": outcome.samples_used,
                "duration_seconds": outcome.duration_seconds,
                "accuracy_delta": outcome.accuracy_delta,
                "next_retrain_hint": outcome.next_retrain_hint,
            }
            return summary
        except asyncio.TimeoutError:
            return {
                "status": "error",
                "message": "Retraining timed out (600s limit)",
            }
        except Exception as e:
            logger.error(f"Retrain endpoint error: {e}")
            return {
                "status": "error",
                "message": str(e),
            }
# ── GET /model_version ────────────────────────────────────────────────
@app.get("/model_version")
async def model_version_endpoint() -> dict:
    """Report the current model version; a zeroed placeholder when no retraining service exists."""
    if not _retrain_service:
        return {"version": 0, "updated_at": None, "accuracy": {}}
    return _retrain_service.get_version_info()
# ── GET /health ───────────────────────────────────────────────────────
@app.get("/health")
async def health_endpoint() -> dict:
    """Liveness probe reporting per-tier readiness and per-module load state."""
    scorer_ready = _scorer is not None
    bert_present = _bert is not None
    # A missing (falsy) model object counts as "not loaded".
    gnn_ready = _gnn.is_loaded if _gnn else False
    cnn_ready = _cnn.is_loaded if _cnn else False

    modules = {
        "heuristic": scorer_ready,
        "bert": bert_present and _bert.is_loaded,
        # Present but weights not loaded yet.
        "bert_lazy": bert_present and not _bert.is_loaded,
        "gnn": gnn_ready,
        "cnn": cnn_ready,
        "brand_hash": _brand is not None,
    }
    current_version = _retrain_service.model_version if _retrain_service else 0

    return {
        "status": "ok",
        "version": "3.0",
        "tier1": True,
        "tier2": scorer_ready,
        "tier3": _tier3 is not None,
        "tier4": cnn_ready,
        "retraining_in_progress": _retrain_lock.locked(),
        "model_version": current_version,
        "modules": modules,
    }
# ── Legacy feedback endpoint (backward compat) ───────────────────────
@app.post("/feedback")
async def legacy_feedback_endpoint(req: dict) -> dict:
    """Backward-compatible stub: accepts any JSON object and points callers to POST /retrain."""
    redirect_notice = "Use POST /retrain for feedback-driven retraining"
    return {"status": "success", "message": redirect_notice}
# ── Run directly ──────────────────────────────────────────────────────
# uvicorn main:app --reload --port 8000