phishguard-api / gnn_model.py
prashanth135's picture
Upload 38 files
bebe233 verified
# ============================================================
# PhishGuard AI - gnn/gnn_model.py
# GNN + MLP model definitions for phishing graph classification.
#
# PhishGNN: 3-layer GCN with global_mean_pool → Linear → Sigmoid
#   GCNConv(12→64) → ReLU → GCNConv(64→32) → ReLU →
#   GCNConv(32→16) → ReLU → global_mean_pool → Linear(16→1) → Sigmoid
#
# PhishMLP: Fallback for single URL or when torch_geometric unavailable
#   Linear(12→64) → ReLU → Dropout(0.3) → Linear(64→1) → Sigmoid
# ============================================================
from __future__ import annotations
import os
import logging
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
# Module-level logger shared by everything in this file.
logger = logging.getLogger("phishguard.gnn.model")

INPUT_DIM: int = 12   # 12-dim node features
HIDDEN_DIM: int = 64  # default width of the first hidden layer
OUTPUT_DIM: int = 1   # binary: sigmoid output

# ── Try importing PyTorch Geometric ──────────────────────────────────
# torch_geometric is optional. PYGEOM_AVAILABLE records whether the import
# succeeded so the rest of the module can choose between the GCN and the
# MLP fallback.
PYGEOM_AVAILABLE: bool = False
try:
    from torch_geometric.nn import GCNConv, global_mean_pool
except ImportError:
    # The flag keeps its initial False value.
    logger.info("PyTorch Geometric not found \u2014 using MLP fallback")
else:
    PYGEOM_AVAILABLE = True
    logger.info("PyTorch Geometric found \u2014 using full GCN model")
# ── PhishGNN: Full 3-layer Graph Convolutional Network ───────────────
if PYGEOM_AVAILABLE:

    class PhishGNN(nn.Module):
        """3-layer GCN for graph-level phishing classification.

        Layer stack as implemented (sizes shown for the default arguments):
            GCNConv(12→64) → ReLU → GCNConv(64→32) → ReLU →
            GCNConv(32→16) → ReLU → global_mean_pool → Linear(16→1) → Sigmoid

        Note that a ReLU is applied after *every* convolution, including
        the third one.
        """

        def __init__(
            self,
            in_channels: int = INPUT_DIM,
            hidden: int = HIDDEN_DIM,
            out_channels: int = OUTPUT_DIM,
        ) -> None:
            """Build the convolution stack and the final linear head.

            Args:
                in_channels: Size of each input node feature vector.
                hidden: Output width of the first convolution; the second
                    and third use ``hidden // 2`` and ``hidden // 4``.
                out_channels: Width of the final linear layer's output.
            """
            super().__init__()
            self.conv1 = GCNConv(in_channels, hidden)        # 12 → 64
            self.conv2 = GCNConv(hidden, hidden // 2)        # 64 → 32
            self.conv3 = GCNConv(hidden // 2, hidden // 4)   # 32 → 16
            self.fc = nn.Linear(hidden // 4, out_channels)   # 16 → 1

        def forward(
            self,
            x: torch.Tensor,
            edge_index: torch.Tensor,
            batch: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:
            """Run the GCN and return one sigmoid score per graph.

            Args:
                x: Node feature matrix (presumably ``(num_nodes,
                    in_channels)`` — confirm against the caller).
                edge_index: Graph connectivity passed to each ``GCNConv``.
                    A tensor with zero elements is replaced by an empty
                    ``(2, 0)`` long tensor on ``x``'s device.
                batch: Graph id for every node. ``None`` means all nodes
                    belong to a single graph.

            Returns:
                Sigmoid of the linear head applied to the mean-pooled node
                embeddings; one row per graph.
            """
            # Normalise any zero-element edge tensor to the canonical
            # empty shape so the convolutions see a well-formed input.
            if edge_index.numel() == 0:
                edge_index = torch.zeros(
                    (2, 0), dtype=torch.long, device=x.device
                )

            x = F.relu(self.conv1(x, edge_index))
            x = F.relu(self.conv2(x, edge_index))
            x = F.relu(self.conv3(x, edge_index))

            # Without an explicit assignment, treat every node as part of
            # graph 0 so pooling yields exactly one row.
            if batch is None:
                batch = torch.zeros(
                    x.size(0), dtype=torch.long, device=x.device
                )

            x = global_mean_pool(x, batch)   # (num_graphs, 16) by default
            x = self.fc(x)                   # (num_graphs, 1) by default
            return torch.sigmoid(x)          # values in [0, 1]

        def predict_proba(
            self,
            x: torch.Tensor,
            edge_index: torch.Tensor,
            batch: Optional[torch.Tensor] = None,
        ) -> float:
            """Return P_gnn ∈ [0,1] — probability of phishing.

            The forward pass runs in eval mode with gradients disabled.
            The module's previous train/eval mode is restored afterwards,
            so calling this in the middle of a training loop does not
            silently leave the model in eval mode.

            The output must hold exactly one value (a single graph), since
            it is converted with ``.item()``.
            """
            was_training = self.training
            self.eval()
            try:
                with torch.no_grad():
                    output = self.forward(x, edge_index, batch)
            finally:
                # Put the module back in the mode the caller left it in.
                self.train(was_training)
            return output.squeeze().item()
# ── PhishMLP: Fallback for single URL or no torch_geometric ──────────
class PhishMLP(nn.Module):
    """MLP fallback for phishing classification.

    Used when torch_geometric is unavailable or graph has < 2 nodes.
    Architecture: Linear(12→64) → ReLU → Dropout(0.3) → Linear(64→1) → Sigmoid
    (the sigmoid is applied in ``forward``, not inside ``self.net``).
    """

    def __init__(self, in_channels: int = INPUT_DIM) -> None:
        """Build the two-layer perceptron.

        Args:
            in_channels: Size of each input feature vector.
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_channels, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
        )

    @staticmethod
    def _pool_nodes(
        x: torch.Tensor,
        batch: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Mean-pool a 2-D node feature matrix into one row per graph.

        If ``batch`` is a 1-D tensor with one entry per row of ``x`` and it
        names more than one graph, rows are averaged per graph id. In every
        other case all rows are averaged into a single row, which is the
        behaviour when no usable ``batch`` is given.
        """
        usable = (
            isinstance(batch, torch.Tensor)
            and batch.dim() == 1
            and batch.numel() == x.size(0)
        )
        if usable:
            graph_ids = batch.to(device=x.device, dtype=torch.long)
            num_graphs = int(graph_ids.max().item()) + 1
            if num_graphs > 1:
                # Segment mean: sum the rows of each graph, then divide by
                # the graph's node count.
                sums = x.new_zeros((num_graphs, x.size(1)))
                sums.index_add_(0, graph_ids, x)
                counts = torch.bincount(graph_ids, minlength=num_graphs)
                # clamp avoids 0/0 for graph ids that own no node.
                counts = counts.clamp(min=1).to(x.dtype).unsqueeze(1)
                return sums / counts
        return x.mean(dim=0, keepdim=True)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: Optional[torch.Tensor] = None,
        batch: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Score pooled node features with the MLP.

        Args:
            x: A single feature vector (1-D) or a matrix with one row per
                node (2-D).
            edge_index: Accepted for signature parity with ``PhishGNN``;
                not used.
            batch: Optional graph id per row of ``x``. When it names more
                than one graph the rows are pooled per graph; otherwise
                all rows are pooled together.

        Returns:
            Sigmoid of the MLP output, one row per pooled input row.
        """
        if x.dim() == 1:
            # A lone feature vector becomes a batch of one.
            x = x.unsqueeze(0)
        elif x.dim() == 2 and x.size(0) > 1:
            x = self._pool_nodes(x, batch)
        out = self.net(x)
        return torch.sigmoid(out)

    def predict_proba(
        self,
        x: torch.Tensor,
        edge_index: Optional[torch.Tensor] = None,
        batch: Optional[torch.Tensor] = None,
    ) -> float:
        """Return P_gnn ∈ [0,1] — probability of phishing.

        The forward pass runs in eval mode (dropout disabled) with
        gradients off. The module's previous train/eval mode is restored
        afterwards, so calling this during training does not leave dropout
        switched off.

        The output must hold exactly one value (a single graph), since it
        is converted with ``.item()``.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                output = self.forward(x, edge_index, batch)
        finally:
            # Put the module back in the mode the caller left it in.
            self.train(was_training)
        return output.squeeze().item()
# ── Model loading utility ────────────────────────────────────────────
def load_gnn_model(model_path: Optional[str] = None) -> Optional[nn.Module]:
    """Load GNN or MLP model with optional trained weights.

    A ``PhishGNN`` is built when ``PYGEOM_AVAILABLE`` is true, otherwise a
    ``PhishMLP``. If construction raises, a ``PhishMLP`` is tried as a
    fallback.

    Weight loading is best-effort: a missing path, a missing file, or any
    error while reading or applying the weights is logged and the function
    still returns the model.

    Args:
        model_path: Path to a saved ``state_dict``. ``None`` or an empty
            string skips weight loading.

    Returns:
        The model in eval mode, or ``None`` if no model could be created
        or switching to eval mode failed.
    """
    model: Optional[nn.Module] = None
    try:
        model = PhishGNN() if PYGEOM_AVAILABLE else PhishMLP()
    except Exception as e:
        logger.error("GNN model creation failed: %s", e)
        try:
            model = PhishMLP()
        except Exception as e2:
            logger.error("MLP fallback creation also failed: %s", e2)
            return None

    if not model_path:
        logger.info("No GNN weights path \u2014 using untrained model")
    elif not os.path.exists(model_path):
        logger.info("GNN weights file not found: %s", model_path)
    else:
        try:
            # weights_only=True restricts unpickling to tensor data.
            state = torch.load(
                model_path, map_location="cpu", weights_only=True
            )
            model.load_state_dict(state)
            logger.info("GNN weights loaded from %s", model_path)
        except RuntimeError as e:
            logger.warning(
                "GNN weights mismatch (architecture changed?): %s", e
            )
        except Exception as e:
            logger.warning("GNN weight load failed: %s", e)

    try:
        model.eval()
    except Exception as e:
        logger.error("GNN eval() failed: %s", e)
        return None
    return model
# Legacy alias
def load_model(model_path: Optional[str] = None) -> Optional[nn.Module]:
    """Backward-compatible name for ``load_gnn_model``.

    Forwards ``model_path`` unchanged and returns whatever
    ``load_gnn_model`` returns.
    """
    loaded = load_gnn_model(model_path)
    return loaded