# ============================================================
# PhishGuard AI - gnn/gnn_model.py
# GNN + MLP model definitions for phishing graph classification.
#
# PhishGNN: 3-layer GCN with global_mean_pool → Linear → Sigmoid
#   GCNConv(12→64) → ReLU → GCNConv(64→32) → ReLU →
#   GCNConv(32→16) → ReLU → global_mean_pool → Linear(16→1) → Sigmoid
#
# PhishMLP: Fallback for single URL or when torch_geometric unavailable
#   mean-pool node features → Linear(12→64) → ReLU → Dropout(0.3) →
#   Linear(64→1) → Sigmoid
# ============================================================
from __future__ import annotations

import logging
import os
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger("phishguard.gnn.model")

INPUT_DIM: int = 12   # 12-dim node features
HIDDEN_DIM: int = 64
OUTPUT_DIM: int = 1   # binary: sigmoid output


# ── Try importing PyTorch Geometric ──────────────────────────────────
PYGEOM_AVAILABLE: bool = False
try:
    from torch_geometric.nn import GCNConv, global_mean_pool
except (ImportError, OSError) as _pyg_exc:
    # OSError is caught as well as ImportError: a PyG install whose compiled
    # dependencies fail to load can raise OSError at import time. In either
    # case the intent of this module is to fall back to the MLP, not crash.
    logger.info("PyTorch Geometric not found — using MLP fallback")
    logger.debug("torch_geometric import error: %s", _pyg_exc)
else:
    PYGEOM_AVAILABLE = True
    logger.info("PyTorch Geometric found — using full GCN model")


# ── Shared helpers ───────────────────────────────────────────────────
def _predict_proba(
    model: nn.Module,
    x: torch.Tensor,
    edge_index: Optional[torch.Tensor],
    batch: Optional[torch.Tensor],
) -> float:
    """Run ``model.forward`` without gradients and return its output as a float.

    The model is put in eval mode for the call, and its previous train/eval
    mode is restored afterwards. Probing a model from inside a training loop
    therefore does not leave dropout disabled.

    ``Tensor.item()`` raises if the output has more than one element, e.g.
    when ``batch`` describes several graphs. This helper is meant for a
    single graph / single URL.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model.forward(x, edge_index, batch)
    finally:
        model.train(was_training)
    return output.squeeze().item()


def _mean_pool_by_graph(x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    """Average the rows of ``x`` per graph id in ``batch``.

    Returns a tensor of shape ``(max(batch) + 1, x.size(1))``. It is the
    per-graph mean that PhishGNN gets from ``global_mean_pool``, written with
    plain torch so PhishMLP can pool batched input without PyG. Graph ids with
    no nodes get a zero row (their count is clamped to 1).
    """
    batch = batch.to(device=x.device, dtype=torch.long)
    num_graphs = int(batch.max().item()) + 1
    sums = x.new_zeros((num_graphs, x.size(1))).index_add_(0, batch, x)
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
    return sums / counts.unsqueeze(1).to(x.dtype)


# ── PhishGNN: Full 3-layer Graph Convolutional Network ───────────────
if PYGEOM_AVAILABLE:

    class PhishGNN(nn.Module):
        """
        3-layer GCN for graph-level phishing classification.

        Architecture (ReLU follows every convolution, including the last):
            GCNConv(12→64) → ReLU → GCNConv(64→32) → ReLU →
            GCNConv(32→16) → ReLU → global_mean_pool → Linear(16→1) → Sigmoid

        Layer widths are derived from ``hidden`` as hidden, hidden // 2 and
        hidden // 4. The numbers above are for the default arguments.
        """

        def __init__(
            self,
            in_channels: int = INPUT_DIM,
            hidden: int = HIDDEN_DIM,
            out_channels: int = OUTPUT_DIM,
        ) -> None:
            super().__init__()
            self.conv1 = GCNConv(in_channels, hidden)          # 12 → 64
            self.conv2 = GCNConv(hidden, hidden // 2)          # 64 → 32
            self.conv3 = GCNConv(hidden // 2, hidden // 4)     # 32 → 16
            self.fc = nn.Linear(hidden // 4, out_channels)     # 16 → 1

        def forward(
            self,
            x: torch.Tensor,
            edge_index: torch.Tensor,
            batch: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:
            """Return phishing probabilities, one row per graph, in [0, 1].

            Args:
                x: Node features, presumably (num_nodes, in_channels).
                edge_index: Graph connectivity. Any empty tensor means
                    "no edges".
                batch: Graph id of each node. ``None`` means all nodes belong
                    to a single graph.
            """
            if edge_index.numel() == 0:
                # Normalise any empty edge tensor to the (2, 0) long layout
                # passed to GCNConv below.
                edge_index = torch.zeros((2, 0), dtype=torch.long, device=x.device)

            x = F.relu(self.conv1(x, edge_index))
            x = F.relu(self.conv2(x, edge_index))
            # The ReLU after conv3 is part of the model as implemented; removing
            # it would change the outputs of any weights trained with this code.
            x = F.relu(self.conv3(x, edge_index))

            if batch is None:
                batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

            x = global_mean_pool(x, batch)   # (batch_size, 16)
            x = self.fc(x)                   # (batch_size, 1)
            return torch.sigmoid(x)          # [0, 1]

        def predict_proba(
            self,
            x: torch.Tensor,
            edge_index: torch.Tensor,
            batch: Optional[torch.Tensor] = None,
        ) -> float:
            """Return P_gnn ∈ [0,1] — probability of phishing (single graph).

            The module's train/eval mode is restored after the call.
            """
            return _predict_proba(self, x, edge_index, batch)


# ── PhishMLP: Fallback for single URL or no torch_geometric ──────────
class PhishMLP(nn.Module):
    """
    MLP fallback for phishing classification.
    Used when torch_geometric is unavailable or graph has < 2 nodes.

    Architecture:
        mean-pool node features → Linear(12→64) → ReLU → Dropout(0.3) →
        Linear(64→1) → Sigmoid

    Graph structure (``edge_index``) is ignored. It is accepted only so the
    call signature matches PhishGNN.
    """

    def __init__(self, in_channels: int = INPUT_DIM) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_channels, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
        )

    def forward(
        self,
        x: torch.Tensor,
        edge_index: Optional[torch.Tensor] = None,
        batch: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return phishing probabilities in [0, 1].

        Pooling depends on the input:
        - A 1-D ``x`` is treated as a single node and pooled to (1, F).
        - A 2-D ``x`` with ``batch`` given is mean-pooled per graph id, giving
          one row per graph, as PhishGNN does.
        - A 2-D ``x`` without ``batch`` has all rows mean-pooled into one.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)
        elif x.dim() == 2 and batch is not None and x.size(0) > 0:
            x = _mean_pool_by_graph(x, batch)
        elif x.dim() == 2 and x.size(0) > 1:
            x = x.mean(dim=0, keepdim=True)
        out = self.net(x)
        return torch.sigmoid(out)

    def predict_proba(
        self,
        x: torch.Tensor,
        edge_index: Optional[torch.Tensor] = None,
        batch: Optional[torch.Tensor] = None,
    ) -> float:
        """Return P_gnn ∈ [0,1] — probability of phishing (single graph).

        The module's train/eval mode is restored after the call.
        """
        return _predict_proba(self, x, edge_index, batch)


# ── Model loading utility ────────────────────────────────────────────
def _try_load_weights(model: nn.Module, model_path: str) -> None:
    """Best-effort load of a saved state_dict into ``model``; failures are logged."""
    try:
        # weights_only=True restricts unpickling to tensors and plain containers.
        state = torch.load(model_path, map_location="cpu", weights_only=True)
    except Exception as e:
        # Reading/unpickling failed (missing permissions, corrupt file, ...).
        # Reported separately so it is not mistaken for an architecture mismatch.
        logger.warning("GNN weight load failed: %s", e)
        return

    try:
        model.load_state_dict(state)
    except RuntimeError as e:
        # load_state_dict raises RuntimeError on missing/unexpected keys or
        # shape mismatches.
        logger.warning("GNN weights mismatch (architecture changed?): %s", e)
    except Exception as e:
        logger.warning("GNN weight load failed: %s", e)
    else:
        logger.info("GNN weights loaded from %s", model_path)


def load_gnn_model(model_path: Optional[str] = None) -> Optional[nn.Module]:
    """
    Load GNN or MLP model with optional trained weights.

    Builds PhishGNN when PyTorch Geometric is available, otherwise PhishMLP.
    If PhishGNN construction fails, it falls back to PhishMLP.

    Args:
        model_path: Optional path to a state_dict saved with ``torch.save``.
            A missing file or a failed load is logged, and the model keeps its
            freshly initialised weights.

    Returns:
        The model in eval mode, or None if creation fails.
    """
    model: Optional[nn.Module] = None
    try:
        model = PhishGNN() if PYGEOM_AVAILABLE else PhishMLP()
    except Exception as e:
        logger.error("GNN model creation failed: %s", e)
        try:
            model = PhishMLP()
        except Exception as e2:
            logger.error("MLP fallback creation also failed: %s", e2)
            return None

    if model_path and os.path.exists(model_path):
        _try_load_weights(model, model_path)
    elif model_path:
        logger.info("GNN weights file not found: %s", model_path)
    else:
        logger.info("No GNN weights path — using untrained model")

    try:
        model.eval()
    except Exception as e:
        logger.error("GNN eval() failed: %s", e)
        return None

    return model


# Legacy alias
def load_model(model_path: Optional[str] = None) -> Optional[nn.Module]:
    """Backward-compatible alias for :func:`load_gnn_model`."""
    return load_gnn_model(model_path)