Spaces:
Running
Running
# ============================================================
# PhishGuard AI - tier3_bert_gnn.py
# Tier 3: BERT + GNN Parallel Ensemble
#
# Triggered only when Tier 2 score < 80.
# BERT and GNN run in PARALLEL via asyncio.gather + run_in_executor.
#
# Ensemble formula:
#   P3 = 0.45·P_bert + 0.35·P_gnn + 0.20·(H_score/100)
#
# Decision:
#   P3 >= 0.85        → BLOCK
#   P3 <  0.40        → SAFE
#   0.40 <= P3 < 0.85 → escalate to Tier 4
# ============================================================
from __future__ import annotations

import asyncio
import logging
import math
from typing import Optional
# Named child of the "phishguard" logger (logging uses dotted-name hierarchies),
# so configuration applied to "phishguard" also reaches Tier 3 records.
logger = logging.getLogger(name="phishguard.tier3")
class Tier3Ensemble:
    """
    Tier 3: BERT + GNN parallel ensemble classifier.

    Each model's blocking ``predict`` call is submitted to the event loop's
    default executor (``run_in_executor(None, ...)``) and the two awaitables
    are combined with ``asyncio.gather`` so neither blocks the event loop.

    Ensemble formula::

        P3 = W_BERT * P_bert + W_GNN * P_gnn + W_HEURISTIC * (h_score / 100)

    A model that times out, raises, or returns a non-finite value contributes
    ``NEUTRAL_PROB`` instead of failing the whole prediction.
    """

    # Ensemble weights (they sum to 1.0, so P3 stays in [0, 1] for valid inputs).
    W_BERT: float = 0.45
    W_GNN: float = 0.35
    W_HEURISTIC: float = 0.20

    # Probability substituted for a model that times out or errors.
    NEUTRAL_PROB: float = 0.5

    # Default per-model inference timeouts, in seconds.
    BERT_TIMEOUT_S: float = 10.0
    GNN_TIMEOUT_S: float = 5.0

    def __init__(
        self,
        bert_classifier,
        gnn_inference,
        *,
        bert_timeout: float = BERT_TIMEOUT_S,
        gnn_timeout: float = GNN_TIMEOUT_S,
    ) -> None:
        """
        Args:
            bert_classifier: Object with a blocking ``predict(url, title, snippet)``
                whose result is convertible to float (expected in [0, 1]).
            gnn_inference: Object with a blocking ``predict(url, related_urls)``
                whose result is convertible to float (expected in [0, 1]).
            bert_timeout: Seconds to wait for BERT before using NEUTRAL_PROB.
            gnn_timeout: Seconds to wait for the GNN before using NEUTRAL_PROB.
        """
        self._bert = bert_classifier
        self._gnn = gnn_inference
        self._bert_timeout = bert_timeout
        self._gnn_timeout = gnn_timeout

    async def predict(
        self,
        url: str,
        title: str = "",
        snippet: str = "",
        h_score: int = 0,
    ) -> float:
        """
        Run BERT + GNN concurrently and compute the ensemble score.

        Args:
            url: The URL to analyze.
            title: Page title (optional).
            snippet: Page content snippet (optional).
            h_score: Heuristic score from Tier 2 (0-100, passed through, NOT
                recomputed). Values outside 0-100 are clamped into range.

        Returns:
            P3 in [0, 1], rounded to 4 decimals — ensemble phishing probability.
        """
        # Inside a coroutine, get_running_loop() is preferred over get_event_loop().
        loop = asyncio.get_running_loop()

        # Both model calls are blocking; run them side by side in the executor.
        p_bert, p_gnn = await asyncio.gather(
            self._bert_predict(url, title, snippet, loop),
            self._gnn_predict(url, loop),
        )

        # Clamp so an out-of-range heuristic cannot outweigh the model terms.
        h_norm = min(max(h_score, 0), 100) / 100.0
        p3 = (
            self.W_BERT * p_bert
            + self.W_GNN * p_gnn
            + self.W_HEURISTIC * h_norm
        )
        p3 = round(min(max(p3, 0.0), 1.0), 4)

        # Log the same value that is returned (previously the unclamped one).
        logger.info(
            "Tier3 ensemble | url=%s | "
            "P_bert=%.4f P_gnn=%.4f H_norm=%.4f → P3=%.4f",
            url[:60], p_bert, p_gnn, h_norm, p3,
        )
        return p3

    async def _run_model(
        self,
        label: str,
        job,
        url: str,
        timeout: float,
        loop: Optional[asyncio.AbstractEventLoop],
    ) -> float:
        """
        Run the zero-argument blocking callable *job* in the default executor.

        Returns its result as a float clamped into [0, 1], or NEUTRAL_PROB if
        the call exceeds *timeout* seconds, raises, or yields a non-finite
        number. On timeout only the awaiting future is cancelled; the worker
        thread itself keeps running until *job* returns.
        """
        if loop is None:
            loop = asyncio.get_running_loop()
        try:
            raw = await asyncio.wait_for(
                loop.run_in_executor(None, job),  # None -> default thread pool
                timeout=timeout,
            )
            prob = float(raw)
        except asyncio.TimeoutError:
            logger.warning("%s timeout for %s", label, url[:50])
            return self.NEUTRAL_PROB
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("%s predict error: %s", label, e)
            return self.NEUTRAL_PROB

        # NaN/inf would otherwise poison P3 (NaN survives min/max clamping).
        if not math.isfinite(prob):
            logger.warning(
                "%s returned non-finite probability %r for %s",
                label, prob, url[:50],
            )
            return self.NEUTRAL_PROB
        # Guard the ensemble against models that drift outside [0, 1].
        return min(max(prob, 0.0), 1.0)

    async def _bert_predict(
        self,
        url: str,
        title: str,
        snippet: str,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> float:
        """Run BERT inference off the event loop; returns P_bert in [0, 1]."""
        # The lambda defers the ``predict`` lookup into the guarded worker call,
        # so a misconfigured classifier degrades to NEUTRAL_PROB like any error.
        return await self._run_model(
            "BERT",
            lambda: self._bert.predict(url, title, snippet),
            url,
            self._bert_timeout,
            loop,
        )

    async def _gnn_predict(
        self,
        url: str,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> float:
        """Run GNN inference off the event loop; returns P_gnn in [0, 1]."""
        return await self._run_model(
            "GNN",
            lambda: self._gnn.predict(url, None),  # None -> related_urls
            url,
            self._gnn_timeout,
            loop,
        )
def decide(
    p3: float,
    *,
    block_threshold: float = 0.85,
    safe_threshold: float = 0.40,
) -> str:
    """
    Map a Tier 3 ensemble score to a routing decision.

    Args:
        p3: Ensemble phishing probability, normally in [0, 1].
        block_threshold: Scores >= this value are blocked (default 0.85).
        safe_threshold: Scores < this value are considered safe (default 0.40).

    Returns:
        'block' if p3 >= block_threshold, 'safe' if p3 < safe_threshold,
        otherwise 'escalate' (hand off to Tier 4). A NaN score fails both
        comparisons and therefore escalates.

    Raises:
        ValueError: if safe_threshold > block_threshold (no consistent band).
    """
    if safe_threshold > block_threshold:
        raise ValueError(
            f"safe_threshold ({safe_threshold}) must not exceed "
            f"block_threshold ({block_threshold})"
        )
    if p3 >= block_threshold:
        return "block"
    if p3 < safe_threshold:
        return "safe"
    return "escalate"