phishguard-api / tier3_bert_gnn.py
prashanth135's picture
Upload 38 files
bebe233 verified
# ============================================================
# PhishGuard AI - tier3_bert_gnn.py
# Tier 3: BERT + GNN Parallel Ensemble
#
# Triggered only when Tier 2 score < 80.
# BERT and GNN run in PARALLEL via asyncio.gather + run_in_executor.
#
# Ensemble formula:
# P3 = 0.45·P_bert + 0.35·P_gnn + 0.20·(H_score/100)
#
# Decision:
# P3 >= 0.85 → BLOCK
# P3 < 0.40 → SAFE
# 0.40 <= P3 < 0.85 → escalate to Tier 4
# ============================================================
from __future__ import annotations
import asyncio
import logging
from typing import Optional
logger = logging.getLogger("phishguard.tier3")


class Tier3Ensemble:
    """
    Tier 3: BERT + GNN ensemble classifier.

    Both model calls are submitted to the event loop's default executor and
    awaited together with asyncio.gather, so the event loop is not blocked
    while they run.

    Ensemble formula:
        P3 = W_BERT * P_bert + W_GNN * P_gnn + W_HEURISTIC * (h_score / 100)

    A model that times out, raises, or returns NaN contributes NEUTRAL_SCORE
    instead of failing the whole request.
    """

    # Ensemble weights. They sum to 1.0, so P3 stays in [0, 1] as long as
    # every component is in [0, 1].
    W_BERT: float = 0.45
    W_GNN: float = 0.35
    W_HEURISTIC: float = 0.20

    # Decision thresholds used by decide().
    BLOCK_THRESHOLD: float = 0.85
    SAFE_THRESHOLD: float = 0.40

    # Score substituted for a model that timed out, raised, or returned NaN.
    NEUTRAL_SCORE: float = 0.5

    def __init__(
        self,
        bert_classifier,
        gnn_inference,
        *,
        bert_timeout: float = 10.0,
        gnn_timeout: float = 5.0,
    ) -> None:
        """
        Args:
            bert_classifier: Object exposing predict(url, title, snippet).
            gnn_inference: Object exposing predict(url, related_urls).
            bert_timeout: Seconds to wait for the BERT call before falling
                back to NEUTRAL_SCORE.
            gnn_timeout: Seconds to wait for the GNN call before falling
                back to NEUTRAL_SCORE.
        """
        self._bert = bert_classifier
        self._gnn = gnn_inference
        self._bert_timeout = bert_timeout
        self._gnn_timeout = gnn_timeout

    async def predict(
        self,
        url: str,
        title: str = "",
        snippet: str = "",
        h_score: int = 0,
    ) -> float:
        """
        Run BERT and GNN concurrently and compute the ensemble score.

        Args:
            url: The URL to analyze.
            title: Page title (optional).
            snippet: Page content snippet (optional).
            h_score: Heuristic score from Tier 2 (0-100). It is passed
                through, not recomputed; values outside 0-100 are clamped.

        Returns:
            P3 in [0, 1], rounded to 4 decimal places.
        """
        # get_running_loop() is the supported way to obtain the loop from
        # inside a coroutine.
        loop = asyncio.get_running_loop()

        p_bert, p_gnn = await asyncio.gather(
            self._bert_predict(url, title, snippet, loop),
            self._gnn_predict(url, loop),
        )

        # Clamp the heuristic so an out-of-range score cannot contribute
        # more than its W_HEURISTIC share.
        h_norm = self._as_probability(h_score / 100.0, "H_norm", fallback=0.0)

        p3 = (
            (self.W_BERT * p_bert)
            + (self.W_GNN * p_gnn)
            + (self.W_HEURISTIC * h_norm)
        )
        logger.info(
            "Tier3 ensemble | url=%s | "
            "P_bert=%.4f P_gnn=%.4f H_norm=%.4f → P3=%.4f",
            url[:60],
            p_bert,
            p_gnn,
            h_norm,
            p3,
        )
        # The components are already clamped; this guards against
        # floating-point drift just outside [0, 1].
        return round(min(max(p3, 0.0), 1.0), 4)

    async def _bert_predict(
        self,
        url: str,
        title: str,
        snippet: str,
        loop: asyncio.AbstractEventLoop,
    ) -> float:
        """
        Run BERT inference in the default executor.

        Returns P_bert in [0, 1], or NEUTRAL_SCORE on timeout or error.
        """
        return await self._run_model(
            "BERT", url, loop, self._bert_timeout, self._bert, url, title, snippet
        )

    async def _gnn_predict(
        self,
        url: str,
        loop: asyncio.AbstractEventLoop,
    ) -> float:
        """
        Run GNN inference in the default executor.

        Returns P_gnn in [0, 1], or NEUTRAL_SCORE on timeout or error.
        """
        # The second positional argument (related_urls) is always None here.
        return await self._run_model(
            "GNN", url, loop, self._gnn_timeout, self._gnn, url, None
        )

    async def _run_model(
        self,
        label: str,
        url: str,
        loop: asyncio.AbstractEventLoop,
        timeout: float,
        model,
        *args,
    ) -> float:
        """
        Call model.predict(*args) in the default executor with a timeout.

        Args:
            label: Model name used in log messages ("BERT" or "GNN").
            url: URL being analyzed (only used for logging).
            loop: Event loop whose default executor runs the call.
            timeout: Seconds to wait for the result.
            model: Object exposing a predict method.
            *args: Positional arguments forwarded to model.predict.

        Returns:
            The model output as a probability in [0, 1], or NEUTRAL_SCORE
            when the call times out, raises, or yields an unusable value.
        """
        try:
            # model.predict is looked up inside the try so that a missing
            # model (e.g. None) degrades to the neutral score as well.
            # NOTE: on timeout only the wait is abandoned; the worker thread
            # is not interrupted and keeps running until predict returns.
            raw = await asyncio.wait_for(
                loop.run_in_executor(None, model.predict, *args),
                timeout=timeout,
            )
            return self._as_probability(raw, label, fallback=self.NEUTRAL_SCORE)
        except asyncio.TimeoutError:
            logger.warning("%s timeout for %s", label, url[:50])
            return self.NEUTRAL_SCORE
        except Exception as exc:
            # Deliberate best-effort: one failing model must not fail the
            # whole ensemble. The traceback is logged for diagnosis.
            logger.error("%s predict error: %s", label, exc, exc_info=True)
            return self.NEUTRAL_SCORE

    @staticmethod
    def _as_probability(value, label: str, fallback: float) -> float:
        """
        Convert value to a float in [0, 1].

        Out-of-range values are clamped; NaN is replaced by fallback.
        Raises whatever float(value) raises for non-numeric input.
        """
        score = float(value)
        if 0.0 <= score <= 1.0:
            return score
        if score != score:  # NaN is the only float unequal to itself
            logger.warning("%s produced NaN; using %.2f instead", label, fallback)
            return fallback
        logger.warning("%s value %s is outside [0, 1]; clamping", label, score)
        return min(max(score, 0.0), 1.0)

    @staticmethod
    def decide(p3: float) -> str:
        """
        Map a P3 score to a decision.

        Returns:
            'block' when p3 >= BLOCK_THRESHOLD, 'safe' when
            p3 < SAFE_THRESHOLD, otherwise 'escalate' (this includes NaN,
            for which both comparisons are false).
        """
        if p3 >= Tier3Ensemble.BLOCK_THRESHOLD:
            return "block"
        if p3 < Tier3Ensemble.SAFE_THRESHOLD:
            return "safe"
        return "escalate"