# ============================================================ # PhishGuard AI - tier3_bert_gnn.py # Tier 3: BERT + GNN Parallel Ensemble # # Triggered only when Tier 2 score < 80. # BERT and GNN run in PARALLEL via asyncio.gather + run_in_executor. # # Ensemble formula: # P3 = 0.45·P_bert + 0.35·P_gnn + 0.20·(H_score/100) # # Decision: # P3 >= 0.85 → BLOCK # P3 < 0.40 → SAFE # 0.40 <= P3 < 0.85 → escalate to Tier 4 # ============================================================ from __future__ import annotations import asyncio import logging from typing import Optional logger = logging.getLogger("phishguard.tier3") class Tier3Ensemble: """ Tier 3: BERT + GNN parallel ensemble classifier. Runs BERT and GNN inference in parallel using asyncio.gather with run_in_executor for non-blocking thread pool execution. """ # Ensemble weights W_BERT: float = 0.45 W_GNN: float = 0.35 W_HEURISTIC: float = 0.20 def __init__( self, bert_classifier, gnn_inference, ) -> None: self._bert = bert_classifier self._gnn = gnn_inference async def predict( self, url: str, title: str = "", snippet: str = "", h_score: int = 0, ) -> float: """ Run BERT + GNN in parallel and compute ensemble score. Args: url: The URL to analyze title: Page title (optional) snippet: Page content snippet (optional) h_score: Heuristic score from Tier 2 (0-100, passed through, NOT recomputed) Returns: P3 ∈ [0,1] — ensemble phishing probability """ loop = asyncio.get_event_loop() # Run BERT and GNN in parallel (both are CPU-bound, use thread pool) bert_task = self._bert_predict(url, title, snippet, loop) gnn_task = self._gnn_predict(url, loop) p_bert, p_gnn = await asyncio.gather(bert_task, gnn_task) # Ensemble: P3 = 0.45·P_bert + 0.35·P_gnn + 0.20·H_norm h_norm = h_score / 100.0 p3 = (self.W_BERT * p_bert) + (self.W_GNN * p_gnn) + (self.W_HEURISTIC * h_norm) logger.info( f"Tier3 ensemble | url={url[:60]} | " f"P_bert={p_bert:.4f} P_gnn={p_gnn:.4f} H_norm={h_norm:.4f} → P3={p3:.4f}" ) return round(min(max(p3, 0.0), 1.0), 4) async def _bert_predict( self, url: str, title: str, snippet: str, loop: asyncio.AbstractEventLoop, ) -> float: """ Run BERT inference in thread pool (non-blocking). Returns P_bert ∈ [0,1]. """ try: p_bert = await asyncio.wait_for( loop.run_in_executor( None, # Default thread pool self._bert.predict, url, title, snippet, ), timeout=10.0, ) return float(p_bert) except asyncio.TimeoutError: logger.warning(f"BERT timeout for {url[:50]}") return 0.5 # Neutral on timeout except Exception as e: logger.error(f"BERT predict error: {e}") return 0.5 async def _gnn_predict( self, url: str, loop: asyncio.AbstractEventLoop, ) -> float: """ Run GNN inference in thread pool (non-blocking). Returns P_gnn ∈ [0,1]. """ try: p_gnn = await asyncio.wait_for( loop.run_in_executor( None, self._gnn.predict, url, None, # related_urls ), timeout=5.0, ) return float(p_gnn) except asyncio.TimeoutError: logger.warning(f"GNN timeout for {url[:50]}") return 0.5 except Exception as e: logger.error(f"GNN predict error: {e}") return 0.5 @staticmethod def decide(p3: float) -> str: """ Make decision based on P3 score. Returns: 'block', 'safe', or 'escalate' """ if p3 >= 0.85: return "block" elif p3 < 0.40: return "safe" else: return "escalate"