# ============================================================
# PhishGuard AI - gnn/gnn_inference.py
# GNN inference wrapper for main.py.
# Loads model once at startup, reuses for every request.
# Supports: predict, hot-reload, incremental_update.
# ============================================================

from __future__ import annotations

import copy
import os
import sys
import random
import logging
from pathlib import Path
from typing import List, Optional, Tuple

import torch

logger = logging.getLogger("phishguard.gnn.inference")

# Add parent paths
_GNN_DIR = Path(__file__).parent
_BACKEND_DIR = _GNN_DIR.parent
sys.path.insert(0, str(_GNN_DIR))
sys.path.insert(0, str(_BACKEND_DIR))

from domain_graph_builder import DomainGraphBuilder
from gnn_model import load_gnn_model, PhishMLP, PYGEOM_AVAILABLE, INPUT_DIM

if PYGEOM_AVAILABLE:
    from gnn_model import PhishGNN

MODEL_PATH = _GNN_DIR / "gnn_weights.pt"
REPLAY_BUFFER_PATH = _BACKEND_DIR / "data" / "gnn_replay_buffer.pt"

# Tunables for incremental_update.
_MIN_UPDATE_SAMPLES = 5    # fewer feedback samples than this -> update is skipped
_GRAPH_CHUNK_SIZE = 4      # number of same-label URLs grouped into one training graph
_REPLAY_BUFFER_MAX = 500   # rolling cap on the number of graphs kept in the replay buffer


def _graph_to_tensors(graph: dict) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert a builder graph dict into ``(x, edge_index)`` tensors.

    Reads ``graph["features"]`` and ``graph["edges"]``. The edge list is
    transposed, so it is presumably a list of ``[src, dst]`` pairs - verify
    against DomainGraphBuilder. When the graph has no edges, one ``(i, i)``
    pair per node is generated so that ``edge_index`` is never empty.
    """
    x = torch.tensor(graph["features"], dtype=torch.float)
    edges = graph["edges"]
    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        node_count = x.size(0)
        edge_index = torch.arange(node_count).unsqueeze(0).repeat(2, 1)
    return x, edge_index


def _build_feedback_graphs(
    builder: DomainGraphBuilder,
    samples: List[Tuple[str, int]],
) -> List[dict]:
    """Group feedback URLs by label and build one training graph per chunk.

    Phishing (label 1) chunks come first, then legitimate (label 0) chunks.
    Samples with any other label are ignored. Each item is a dict with keys
    ``"x"``, ``"edge_index"`` and ``"y"`` (a 1-element float tensor).
    """
    graphs: List[dict] = []
    for label in (1, 0):
        urls = [url for url, sample_label in samples if sample_label == label]
        for start in range(0, len(urls), _GRAPH_CHUNK_SIZE):
            chunk = urls[start:start + _GRAPH_CHUNK_SIZE]
            x, edge_index = _graph_to_tensors(builder.build_graph(chunk))
            graphs.append({
                "x": x,
                "edge_index": edge_index,
                "y": torch.tensor([float(label)]),
            })
    return graphs


def _load_replay_buffer(path: Path) -> list:
    """Load the list of stored training graphs from ``path``.

    NOTE(security): ``weights_only=False`` lets ``torch.load`` unpickle
    arbitrary Python objects. The replay buffer file must only ever be
    written by this service; never point this at an untrusted file.
    """
    return torch.load(path, map_location="cpu", weights_only=False)


def _dataset_accuracy(
    model: torch.nn.Module,
    dataset: List[dict],
    device: torch.device,
) -> float:
    """Return the fraction of ``dataset`` graphs the model classifies correctly.

    Puts the model in eval mode. A graph counts as phishing when the squeezed
    model output is >= 0.5. ``dataset`` must be non-empty.
    """
    model.eval()
    correct = 0
    with torch.no_grad():
        for item in dataset:
            out = model(item["x"].to(device), item["edge_index"].to(device))
            pred = 1 if out.squeeze().item() >= 0.5 else 0
            correct += int(pred == int(item["y"].item()))
    return correct / len(dataset)


class GNNInference:
    """
    GNN inference wrapper with hot-reload and incremental update support.
    """

    def __init__(self, weights_path: Optional[Path] = None) -> None:
        """Remember the weights location; the model itself is loaded lazily.

        Args:
            weights_path: Weights file to use; defaults to ``MODEL_PATH``.
        """
        # Path(...) also accepts a plain string, so str callers keep working.
        self._weights_path = Path(weights_path or MODEL_PATH)
        self._model: Optional[torch.nn.Module] = None
        self._builder = DomainGraphBuilder()
        self._loaded = False

    def load(self, weights_path: Optional[Path] = None) -> bool:
        """Load GNN model from weights file.

        If the weights file does not exist, ``load_gnn_model`` is called with
        ``None`` instead of a path.

        Args:
            weights_path: Optional override for the configured weights path.

        Returns:
            True if ``load_gnn_model`` returned a model, False otherwise.
        """
        path = Path(weights_path or self._weights_path)
        weights_found = path.exists()
        self._model = load_gnn_model(str(path) if weights_found else None)
        self._loaded = self._model is not None
        if self._loaded:
            if weights_found:
                logger.info(f"GNN model loaded from {path}")
            else:
                # Do not claim the weights were loaded when the file is absent.
                logger.warning(
                    f"GNN weights file not found at {path}; "
                    "model created without a weights path"
                )
        return self._loaded

    def predict(self, url: str, related_urls: Optional[List[str]] = None) -> float:
        """
        Predict phishing probability for a URL.

        Loads the model on first use. A lone URL is turned into a single-node
        graph; otherwise a graph is built from ``url`` plus ``related_urls``.

        Args:
            url: The URL to score.
            related_urls: Optional extra URLs to include in the graph.

        Returns:
            The model's ``predict_proba`` output rounded to 4 decimals, or
            0.5 (neutral) when no model is available.
        """
        if not self._loaded:
            self.load()
        if self._model is None:
            return 0.5  # Neutral when model unavailable

        urls = [url] + (related_urls or [])

        # Single URL -> single-node graph path
        if len(urls) == 1:
            graph = self._builder.build_single_node_graph(url)
        else:
            graph = self._builder.build_graph(urls)

        x, edge_index = _graph_to_tensors(graph)
        prob = self._model.predict_proba(x, edge_index)
        return round(float(prob), 4)

    def reload(self, weights_path: Optional[Path] = None) -> bool:
        """Hot-reload model with new weights (no server restart needed).

        The currently served model is only replaced when the weights file
        exists and ``load_gnn_model`` returns a model.

        Args:
            weights_path: Optional override for the configured weights path.

        Returns:
            True if the model was replaced, False otherwise.
        """
        path = Path(weights_path or self._weights_path)
        # A missing file must never displace the model that is being served.
        new_model = load_gnn_model(str(path)) if path.exists() else None
        if new_model is not None:
            self._model = new_model
            self._loaded = True
            logger.info(f"GNN model hot-reloaded from {path}")
            return True
        logger.warning(f"GNN hot-reload failed from {path}")
        return False

    def incremental_update(
        self,
        samples: List[Tuple[str, int]],
        replay_buffer_path: Optional[Path] = None,
        lr: float = 5e-4,
        epochs: int = 5,
    ) -> Optional[float]:
        """
        Incremental update on feedback samples + replay buffer.

        Training runs on a private copy of the model. The copy replaces the
        served model only after its weights were written to disk, so a failure
        at any earlier point leaves the served model untouched.

        Args:
            samples: list of (url, label) where label is 0 or 1.
            replay_buffer_path: Replay buffer file; defaults to
                ``REPLAY_BUFFER_PATH``.
            lr: Adam learning rate.
            epochs: Number of passes over the merged dataset.

        Returns:
            Post-update minus pre-update accuracy, rounded to 4 decimals and
            measured on the same graphs used for training, or None if the
            update was skipped or failed.
        """
        if self._model is None:
            logger.warning("GNN not loaded, cannot incrementally update")
            return None
        if len(samples) < _MIN_UPDATE_SAMPLES:
            logger.warning(f"Too few samples ({len(samples)}) for GNN update")
            return None

        try:
            import torch.nn.functional as F

            device = torch.device("cpu")
            # Work on a copy: the served model keeps answering predict() with
            # consistent weights and stays in its current mode if we fail.
            model = copy.deepcopy(self._model).to(device)

            # Build graphs from new feedback
            new_graphs = _build_feedback_graphs(DomainGraphBuilder(), samples)

            # Mix in a random 20% of the stored replay buffer (at least one
            # graph when the buffer is non-empty).
            buf_path = Path(replay_buffer_path or REPLAY_BUFFER_PATH)
            replay_graphs: List[dict] = []
            if buf_path.exists():
                try:
                    all_replay = _load_replay_buffer(buf_path)
                    replay_count = max(1, len(all_replay) // 5)
                    replay_graphs = random.sample(
                        all_replay, min(replay_count, len(all_replay))
                    )
                except Exception as e:
                    # Best effort: train on the new feedback alone.
                    logger.warning(f"Replay buffer load failed: {e}")

            dataset = new_graphs + replay_graphs
            random.shuffle(dataset)
            if not dataset:
                return None

            # Pre-update accuracy
            pre_acc = _dataset_accuracy(model, dataset, device)

            # Train
            optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
            model.train()
            for epoch in range(epochs):
                random.shuffle(dataset)
                total_loss = 0.0
                for item in dataset:
                    x = item["x"].to(device)
                    ei = item["edge_index"].to(device)
                    y = item["y"].to(device)
                    optimizer.zero_grad()
                    out = model(x, ei)
                    loss = F.binary_cross_entropy(out.squeeze(), y.squeeze())
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
                logger.info(f"GNN incremental epoch {epoch+1}/{epochs}, loss={total_loss/len(dataset):.4f}")

            # Post-update accuracy (also puts the model back into eval mode)
            post_acc = _dataset_accuracy(model, dataset, device)
            delta = post_acc - pre_acc

            # Persist first, then publish, so memory and disk stay in sync.
            torch.save(model.state_dict(), self._weights_path)
            self._model = model
            logger.info(f"GNN incremental update: {pre_acc:.4f} → {post_acc:.4f} (Δ={delta:+.4f})")

            # Update replay buffer (rolling window of the newest graphs)
            try:
                existing = _load_replay_buffer(buf_path) if buf_path.exists() else []
                combined = existing + new_graphs
                if len(combined) > _REPLAY_BUFFER_MAX:
                    combined = combined[-_REPLAY_BUFFER_MAX:]
                buf_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save(combined, buf_path)
            except Exception as e:
                # Best effort: the model update itself already succeeded.
                logger.warning(f"Replay buffer update failed: {e}")

            return round(delta, 4)

        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception(f"GNN incremental update failed: {e}")
            return None

    @property
    def is_loaded(self) -> bool:
        """Whether a model is currently available."""
        return self._loaded


# ── Legacy compatibility functions ───────────────────────────────────

_inference = GNNInference()


def analyze_url_with_gnn(url: str, related_urls: Optional[List[str]] = None) -> dict:
    """Legacy wrapper for backward compatibility.

    Args:
        url: The URL to score.
        related_urls: Optional extra URLs passed through to ``predict``.

    Returns:
        A dict with ``gnn_phish_prob``, ``node_count``, ``edge_count`` and
        ``graph_suspicious``. When no model can be loaded the probability is
        None and a ``tier3_status`` of ``"model_not_loaded"`` is included.
        ``edge_count`` is always reported as 0 by this wrapper.
    """
    if not _inference.is_loaded:
        _inference.load()

    if not _inference.is_loaded:
        return {
            "gnn_phish_prob": None,
            "tier3_status": "model_not_loaded",
            "node_count": 0,
            "edge_count": 0,
            "graph_suspicious": False,
        }

    prob = _inference.predict(url, related_urls)
    return {
        "gnn_phish_prob": prob,
        "node_count": 1 + len(related_urls or []),
        "edge_count": 0,
        "graph_suspicious": prob > 0.6,
    }


def reload_model(new_weights_path: Optional[str] = None) -> bool:
    """Hot-reload the module-level model; returns True on success.

    Args:
        new_weights_path: Weights file to load; when omitted the path the
            module-level instance was configured with is used.
    """
    path = Path(new_weights_path) if new_weights_path else None
    return _inference.reload(path)


def is_model_loaded() -> bool:
    """Report whether the module-level model is loaded."""
    return _inference.is_loaded