phishguard-api / gnn_inference.py
prashanth135's picture
Upload 38 files
bebe233 verified
# ============================================================
# PhishGuard AI - gnn/gnn_inference.py
# GNN inference wrapper for main.py.
# Loads model once at startup, reuses for every request.
# Supports: predict, hot-reload, incremental_update.
# ============================================================
from __future__ import annotations
import os
import sys
import random
import logging
from pathlib import Path
from typing import List, Optional, Tuple
import torch
# Module logger for GNN inference / incremental-update messages.
logger = logging.getLogger("phishguard.gnn.inference")

# Put this file's directory (gnn/) and its parent (the backend directory) on
# sys.path so the bare imports below resolve. Each insert goes to position 0,
# so the backend directory ends up searched before gnn/.
_GNN_DIR = Path(__file__).parent
_BACKEND_DIR = _GNN_DIR.parent
sys.path.insert(0, str(_GNN_DIR))
sys.path.insert(0, str(_BACKEND_DIR))

# NOTE(review): PhishMLP, INPUT_DIM and PhishGNN are not referenced elsewhere
# in this file; possibly imported for re-export — confirm before removing.
from domain_graph_builder import DomainGraphBuilder
from gnn_model import load_gnn_model, PhishMLP, PYGEOM_AVAILABLE, INPUT_DIM

# PhishGNN is presumably only importable when torch_geometric is available
# (PYGEOM_AVAILABLE) — verify in gnn_model.
if PYGEOM_AVAILABLE:
    from gnn_model import PhishGNN

# Default weights file, stored next to this module.
MODEL_PATH = _GNN_DIR / "gnn_weights.pt"
# Rolling buffer of past training graphs mixed into incremental updates;
# stored under <backend>/data/.
REPLAY_BUFFER_PATH = _BACKEND_DIR / "data" / "gnn_replay_buffer.pt"
class GNNInference:
    """
    GNN inference wrapper with hot-reload and incremental update support.

    One instance owns one model:

    * ``predict`` loads the model lazily on first use.
    * ``reload`` swaps in a freshly loaded model.
    * ``incremental_update`` fine-tunes a *copy* of the current model on
      feedback samples mixed with a replay buffer. The copy replaces the
      served model only after training and the weights save have succeeded.
    """

    # Maximum number of graphs kept in the rolling replay buffer file.
    _REPLAY_BUFFER_MAX = 500
    # Number of same-label feedback URLs grouped into one training graph.
    _FEEDBACK_CHUNK = 4

    def __init__(self, weights_path: Optional[Path] = None) -> None:
        # Path() lets callers pass a plain string as well as a Path.
        self._weights_path = Path(weights_path) if weights_path else MODEL_PATH
        self._model: Optional[torch.nn.Module] = None
        self._builder = DomainGraphBuilder()
        self._loaded = False

    # ── internal helpers ─────────────────────────────────────────────
    @staticmethod
    def _graph_to_tensors(graph: dict) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert a graph-builder dict into ``(x, edge_index)`` tensors.

        ``graph["features"]`` becomes the float tensor ``x``.
        ``graph["edges"]`` holds one ``(src, dst)`` row per edge and becomes
        a ``(2, E)`` long tensor. With no edges (empty or None), one
        self-loop per node is used instead, presumably so message-passing
        layers always receive a non-empty ``edge_index``.
        """
        x = torch.tensor(graph["features"], dtype=torch.float)
        edges = graph["edges"]
        if edges is not None and len(edges) > 0:
            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        else:
            edge_index = torch.arange(x.size(0)).unsqueeze(0).repeat(2, 1)
        return x, edge_index

    @staticmethod
    def _accuracy(
        model: torch.nn.Module, dataset: List[dict], device: torch.device
    ) -> float:
        """
        Return the fraction of ``dataset`` items that ``model`` classifies correctly.

        Switches ``model`` to eval mode. An output >= 0.5 counts as
        phishing (1). Returns 0.0 for an empty dataset.
        """
        if not dataset:
            return 0.0
        model.eval()
        correct = 0
        with torch.no_grad():
            for item in dataset:
                out = model(item["x"].to(device), item["edge_index"].to(device))
                pred = 1 if out.squeeze().item() >= 0.5 else 0
                correct += int(pred == int(item["y"].item()))
        return correct / len(dataset)

    @staticmethod
    def _atomic_torch_save(obj, path: Path) -> None:
        """
        ``torch.save`` ``obj`` to a sibling temp file, then rename it over ``path``.

        Because the temp file is in the same directory, ``os.replace`` swaps
        it in with a single rename. A crash mid-write therefore cannot leave
        a truncated file at ``path``. The parent directory is created if
        missing.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = path.with_name(path.name + ".tmp")
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)

    @staticmethod
    def _load_replay_buffer(buf_path: Path) -> Optional[list]:
        """
        Read the replay buffer file.

        Returns:
            The stored list of graph dicts; ``[]`` when no buffer file
            exists yet; or ``None`` when the file exists but cannot be read
            as a list. On ``None``, callers must leave the file untouched.
        """
        if not buf_path.exists():
            return []
        try:
            # NOTE(security): weights_only=False unpickles arbitrary objects.
            # This is only safe while the buffer is written solely by this
            # service.
            buf = torch.load(buf_path, map_location="cpu", weights_only=False)
        except Exception as e:
            logger.warning(f"Replay buffer load failed: {e}")
            return None
        if not isinstance(buf, list):
            logger.warning(
                f"Replay buffer load failed: expected a list, got {type(buf).__name__}"
            )
            return None
        return buf

    def _build_feedback_graphs(self, samples: List[Tuple[str, int]]) -> List[dict]:
        """
        Turn feedback samples into small labelled graphs.

        Phishing (label 1) URLs are processed first, then legitimate
        (label 0) ones. Each run of up to ``_FEEDBACK_CHUNK`` same-label
        URLs becomes one graph with a graph-level label ``y``. Samples with
        any other label are ignored.
        """
        # Uses its own builder instance rather than self._builder (used by predict).
        builder = DomainGraphBuilder()
        step = self._FEEDBACK_CHUNK
        graphs: List[dict] = []
        for label in (1, 0):
            urls = [u for u, lbl in samples if lbl == label]
            for start in range(0, len(urls), step):
                x, edge_index = self._graph_to_tensors(
                    builder.build_graph(urls[start:start + step])
                )
                graphs.append({
                    "x": x,
                    "edge_index": edge_index,
                    "y": torch.tensor([float(label)]),
                })
        return graphs

    # ── public API ───────────────────────────────────────────────────
    def load(self, weights_path: Optional[Path] = None) -> bool:
        """
        Load the model from a weights file.

        Args:
            weights_path: Optional override of the configured weights path.

        Returns:
            True if ``load_gnn_model`` returned a model.

        If the file does not exist, ``load_gnn_model(None)`` is called
        instead. What that returns is defined in gnn_model, not here.
        """
        path = Path(weights_path) if weights_path else self._weights_path
        exists = path.exists()
        self._model = load_gnn_model(str(path) if exists else None)
        self._loaded = self._model is not None
        if self._loaded:
            if exists:
                logger.info(f"GNN model loaded from {path}")
            else:
                logger.warning(
                    f"GNN weights not found at {path}; using model from load_gnn_model(None)"
                )
        return self._loaded

    def predict(self, url: str, related_urls: Optional[List[str]] = None) -> float:
        """
        Predict the phishing probability for ``url``.

        Args:
            url: URL to score.
            related_urls: Optional extra URLs. If given, a graph over
                ``[url] + related_urls`` is built; otherwise a single-node
                graph is built.

        Returns:
            P_gnn ∈ [0, 1] rounded to 4 decimals, or 0.5 (neutral) when no
            model could be loaded.
        """
        if not self._loaded:
            self.load()
        model = self._model  # read once: reload()/incremental_update() may swap it
        if model is None:
            return 0.5  # Neutral when model unavailable

        urls = [url, *(related_urls or [])]
        if len(urls) == 1:
            # Single URL → single-node graph (presumably the MLP fallback
            # path inside gnn_model — confirm).
            graph = self._builder.build_single_node_graph(url)
        else:
            graph = self._builder.build_graph(urls)

        x, edge_index = self._graph_to_tensors(graph)
        with torch.no_grad():  # inference only: no autograd graph needed
            prob = model.predict_proba(x, edge_index)
        return round(float(prob), 4)

    def reload(self, weights_path: Optional[Path] = None) -> bool:
        """
        Hot-reload the model with new weights (no server restart needed).

        The current model is kept if loading fails. Returns True on success.
        """
        path = Path(weights_path) if weights_path else self._weights_path
        new_model = load_gnn_model(str(path))
        if new_model is None:
            logger.warning(f"GNN hot-reload failed from {path}")
            return False
        self._model = new_model
        self._loaded = True
        logger.info(f"GNN model hot-reloaded from {path}")
        return True

    def incremental_update(
        self,
        samples: List[Tuple[str, int]],
        replay_buffer_path: Optional[Path] = None,
        lr: float = 5e-4,
        epochs: int = 5,
    ) -> Optional[float]:
        """
        Fine-tune the model on feedback samples plus a replay-buffer sample.

        Args:
            samples: List of ``(url, label)`` pairs, label 0 (legit) or
                1 (phish). Other labels are ignored. At least 5 samples are
                required.
            replay_buffer_path: Replay buffer file; defaults to
                REPLAY_BUFFER_PATH.
            lr: Adam learning rate.
            epochs: Number of passes over the merged dataset.

        Returns:
            Accuracy delta (post - pre) on the merged training dataset,
            rounded to 4 decimals. Returns None if no model is loaded, there
            are too few samples, the dataset is empty, or any step fails.
            When None is returned, the served model is unchanged.

        On success:
            * the trained model replaces the served one;
            * its state_dict is written to the weights path;
            * the new graphs are appended to the replay buffer, keeping the
              newest ``_REPLAY_BUFFER_MAX``.
        """
        if self._model is None:
            logger.warning("GNN not loaded, cannot incrementally update")
            return None
        if len(samples) < 5:
            logger.warning(f"Too few samples ({len(samples)}) for GNN update")
            return None

        try:
            import copy
            import torch.nn.functional as F

            device = torch.device("cpu")
            # Train a private copy. An exception part-way through training
            # then cannot leave half-updated, train-mode weights being served.
            model = copy.deepcopy(self._model).to(device)

            new_graphs = self._build_feedback_graphs(samples)

            # Mix in a random sample of 20% of the replay buffer's size.
            # NOTE(review): this is 20% of the buffer, not 20% of the merged
            # dataset, so replay graphs can dominate small feedback batches.
            buf_path = (
                Path(replay_buffer_path) if replay_buffer_path else REPLAY_BUFFER_PATH
            )
            stored = self._load_replay_buffer(buf_path)
            replay_graphs: List[dict] = []
            if stored:
                replay_count = max(1, len(stored) // 5)
                replay_graphs = random.sample(stored, min(replay_count, len(stored)))

            dataset = new_graphs + replay_graphs
            if not dataset:
                return None
            random.shuffle(dataset)

            pre_acc = self._accuracy(model, dataset, device)

            optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
            model.train()
            for epoch in range(epochs):
                random.shuffle(dataset)
                total_loss = 0.0
                for item in dataset:
                    x = item["x"].to(device)
                    ei = item["edge_index"].to(device)
                    y = item["y"].to(device)
                    optimizer.zero_grad()
                    out = model(x, ei)
                    # Output is treated as a probability (plain BCE, not BCE-with-logits).
                    loss = F.binary_cross_entropy(out.squeeze(), y.squeeze())
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
                logger.info(f"GNN incremental epoch {epoch+1}/{epochs}, loss={total_loss/len(dataset):.4f}")

            post_acc = self._accuracy(model, dataset, device)  # leaves model in eval mode
            delta = post_acc - pre_acc

            # Persist first: if the save fails, nothing in memory changes.
            self._atomic_torch_save(model.state_dict(), self._weights_path)
            self._model = model
            logger.info(f"GNN incremental update: {pre_acc:.4f} → {post_acc:.4f} (Δ={delta:+.4f})")

            # stored is None when an existing buffer could not be read. Leave
            # that file alone rather than overwrite it with only new graphs.
            if stored is None:
                logger.warning("Replay buffer update skipped: existing buffer unreadable")
            else:
                try:
                    combined = (stored + new_graphs)[-self._REPLAY_BUFFER_MAX:]
                    self._atomic_torch_save(combined, buf_path)
                except Exception as e:
                    logger.warning(f"Replay buffer update failed: {e}")

            return round(delta, 4)
        except Exception as e:
            logger.exception(f"GNN incremental update failed: {e}")
            return None

    @property
    def is_loaded(self) -> bool:
        """True once load() or reload() has produced a model."""
        return self._loaded
# ── Legacy compatibility functions ───────────────────────────────────
# Module-level GNNInference instance shared by the legacy wrappers below.
# Construction does not load weights; the wrappers (and predict) load lazily.
_inference = GNNInference()
def analyze_url_with_gnn(url: str, related_urls: Optional[List[str]] = None) -> dict:
    """
    Legacy wrapper for backward compatibility.

    Scores ``url`` (plus any ``related_urls``) with the module-level
    GNNInference instance, loading the model on first use.

    Returns:
        A dict with these keys:

        * ``gnn_phish_prob``: float, or None when no model could be loaded.
        * ``node_count``
        * ``edge_count``
        * ``graph_suspicious``: ``prob > 0.6``.

        The model-unavailable result also carries
        ``tier3_status="model_not_loaded"``.

    NOTE(review): ``edge_count`` is hard-coded to 0, and ``node_count``
    assumes one node per URL. Neither is read from the graph that predict()
    actually builds.
    """
    # Accept any iterable (e.g. a tuple) and normalise None to an empty list.
    related = list(related_urls or [])

    if not _inference.is_loaded:
        _inference.load()
    if not _inference.is_loaded:
        return {
            "gnn_phish_prob": None,
            "tier3_status": "model_not_loaded",
            "node_count": 0,
            "edge_count": 0,
            "graph_suspicious": False,
        }

    prob = _inference.predict(url, related)
    return {
        "gnn_phish_prob": prob,
        "node_count": 1 + len(related),
        "edge_count": 0,
        "graph_suspicious": prob > 0.6,
    }
def reload_model(new_weights_path: Optional[str] = None) -> bool:
    """
    Hot-reload the shared model (legacy wrapper around GNNInference.reload).

    Args:
        new_weights_path: Weights file to load. When omitted or empty, the
            instance's configured weights path is used.

    Returns:
        True if the new weights were loaded and swapped in; False otherwise
        (the previous model stays active).
    """
    path = Path(new_weights_path) if new_weights_path else None
    return _inference.reload(path)
def is_model_loaded() -> bool:
    """Report whether the shared legacy GNNInference instance currently holds a model."""
    loaded = _inference.is_loaded
    return loaded