Spaces:
Running
Running
File size: 6,590 Bytes
bebe233 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 | # ============================================================
# PhishGuard AI - gnn/gnn_model.py
# GNN + MLP model definitions for phishing graph classification.
#
# PhishGNN: 3-layer GCN with global_mean_pool -> Linear -> Sigmoid
# GCNConv(12->64) -> ReLU -> GCNConv(64->32) -> ReLU ->
# GCNConv(32->16) -> ReLU -> global_mean_pool -> Linear(16->1) -> Sigmoid
#
# PhishMLP: Fallback for a single URL or when torch_geometric is unavailable
# Linear(12->64) -> ReLU -> Dropout(0.3) -> Linear(64->1) -> Sigmoid
# ============================================================
from __future__ import annotations
import os
import logging
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
# Module-wide logger; every message emitted below goes to this channel.
logger = logging.getLogger("phishguard.gnn.model")

# Default layer widths used by the model classes in this module.
INPUT_DIM: int = 12  # length of each node's feature vector
HIDDEN_DIM: int = 64  # width of PhishGNN's first hidden layer
OUTPUT_DIM: int = 1  # a single sigmoid unit -> binary phishing probability
# ── Try importing PyTorch Geometric ──────────────────────────────────
# torch_geometric is optional. When it is missing, or installed but unable to
# load, PYGEOM_AVAILABLE stays False and the module falls back to PhishMLP
# instead of failing to import.
PYGEOM_AVAILABLE: bool = False
try:
    from torch_geometric.nn import GCNConv, global_mean_pool
except (ImportError, OSError) as exc:
    # OSError covers an installed PyG whose compiled extensions cannot be
    # loaded (e.g. built against a different torch/CUDA version); treat that
    # exactly like "not installed" rather than crashing the import.
    logger.info("PyTorch Geometric not found β using MLP fallback")
    logger.debug("torch_geometric import error: %s", exc)
else:
    PYGEOM_AVAILABLE = True
    logger.info("PyTorch Geometric found β using full GCN model")
# ── PhishGNN: Full 3-layer Graph Convolutional Network ───────────────
# Only defined when torch_geometric imported successfully.
if PYGEOM_AVAILABLE:

    class PhishGNN(nn.Module):
        """
        3-layer GCN for graph-level phishing classification.

        Architecture (a ReLU follows every GCNConv, including the third):
            GCNConv(12->64) -> ReLU -> GCNConv(64->32) -> ReLU ->
            GCNConv(32->16) -> ReLU -> global_mean_pool -> Linear(16->1) -> Sigmoid

        The widths shown are the defaults. In general the hidden layers are
        ``hidden``, ``hidden // 2`` and ``hidden // 4`` wide.
        """

        def __init__(
            self,
            in_channels: int = INPUT_DIM,
            hidden: int = HIDDEN_DIM,
            out_channels: int = OUTPUT_DIM,
        ) -> None:
            super().__init__()
            self.conv1 = GCNConv(in_channels, hidden)  # 12 -> 64
            self.conv2 = GCNConv(hidden, hidden // 2)  # 64 -> 32
            self.conv3 = GCNConv(hidden // 2, hidden // 4)  # 32 -> 16
            self.fc = nn.Linear(hidden // 4, out_channels)  # 16 -> 1

        def forward(
            self,
            x: torch.Tensor,
            edge_index: torch.Tensor,
            batch: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:
            """
            Return per-graph phishing probabilities in [0, 1].

            Args:
                x: Node feature matrix, presumably (num_nodes, in_channels)
                    -- TODO confirm against the feature builder.
                edge_index: Connectivity passed to every GCNConv. Any empty
                    tensor is normalised to an empty (2, 0) long tensor.
                batch: Graph id of each node. ``None`` means all nodes belong
                    to a single graph.

            Returns:
                Tensor of shape (num_graphs, out_channels) after the sigmoid.
            """
            if edge_index.numel() == 0:
                # Normalise empty edge tensors of any shape/dtype to (2, 0) long.
                edge_index = torch.zeros((2, 0), dtype=torch.long, device=x.device)
            x = F.relu(self.conv1(x, edge_index))
            x = F.relu(self.conv2(x, edge_index))
            x = F.relu(self.conv3(x, edge_index))
            if batch is None:
                # Every node belongs to graph 0.
                batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
            x = global_mean_pool(x, batch)  # (num_graphs, hidden // 4)
            x = self.fc(x)  # (num_graphs, out_channels)
            return torch.sigmoid(x)  # [0, 1]

        def predict_proba(
            self,
            x: torch.Tensor,
            edge_index: torch.Tensor,
            batch: Optional[torch.Tensor] = None,
        ) -> float:
            """
            Return P_gnn in [0, 1], the probability of phishing.

            Runs ``forward`` in eval mode without gradient tracking. The
            module's previous train/eval mode is restored afterwards, so
            calling this mid-training does not silently switch training off.

            Raises:
                torch raises from ``.item()`` when the output has more than
                one element (e.g. a ``batch`` containing several graphs).
            """
            was_training = self.training
            self.eval()
            try:
                with torch.no_grad():
                    output = self.forward(x, edge_index, batch)
            finally:
                self.train(was_training)
            return output.squeeze().item()
# ── PhishMLP: Fallback for single URL or no torch_geometric ──────────
class PhishMLP(nn.Module):
    """
    MLP fallback for phishing classification.

    Used when torch_geometric is unavailable or the graph has < 2 nodes.
    Architecture (defaults):
        Linear(12->64) -> ReLU -> Dropout(0.3) -> Linear(64->1) -> Sigmoid

    Graph structure is ignored. ``edge_index`` and ``batch`` are accepted but
    unused (presumably to keep the call signature interchangeable with
    PhishGNN). A 2-D input with more than one row is averaged into a single
    feature vector, so all nodes are pooled together regardless of ``batch``.
    """

    def __init__(
        self,
        in_channels: int = INPUT_DIM,
        hidden: int = HIDDEN_DIM,
        dropout: float = 0.3,
    ) -> None:
        """
        Args:
            in_channels: Length of each node feature vector.
            hidden: Width of the single hidden layer.
            dropout: Dropout probability applied after the hidden ReLU.
        """
        super().__init__()
        # Layer order is kept as-is so state-dict keys (net.0.*, net.3.*)
        # stay compatible with previously saved weights.
        self.net = nn.Sequential(
            nn.Linear(in_channels, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, OUTPUT_DIM),
        )

    def forward(
        self,
        x: torch.Tensor,
        edge_index: Optional[torch.Tensor] = None,
        batch: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Return phishing probabilities in [0, 1].

        Args:
            x: Node features. A 1-D tensor is treated as a single node. A
                2-D tensor with more than one row is mean-pooled over rows.
            edge_index: Ignored.
            batch: Ignored.
        """
        if x.dim() == 2 and x.size(0) > 1:
            # Collapse all node rows into one graph-level vector.
            x = x.mean(dim=0, keepdim=True)
        elif x.dim() == 1:
            x = x.unsqueeze(0)
        out = self.net(x)
        return torch.sigmoid(out)

    def predict_proba(
        self,
        x: torch.Tensor,
        edge_index: Optional[torch.Tensor] = None,
        batch: Optional[torch.Tensor] = None,
    ) -> float:
        """
        Return P_gnn in [0, 1], the probability of phishing.

        Runs ``forward`` in eval mode (dropout disabled) without gradient
        tracking. The module's previous train/eval mode is restored
        afterwards, so calling this mid-training does not leave dropout off.

        Raises:
            torch raises from ``.item()`` when the output does not have
            exactly one element (e.g. an empty 2-D ``x``).
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                output = self.forward(x, edge_index, batch)
        finally:
            self.train(was_training)
        return output.squeeze().item()
# ── Model loading utility ────────────────────────────────────────────
def load_gnn_model(model_path: Optional[str] = None) -> Optional[nn.Module]:
    """
    Load GNN or MLP model with optional trained weights.

    Builds ``PhishGNN`` when torch_geometric is available. Otherwise, or if
    that construction fails, it builds ``PhishMLP``. If ``model_path`` names
    an existing file, that file is read with
    ``torch.load(..., weights_only=True)`` and loaded as a state dict. On any
    loading failure the model keeps its freshly initialised (untrained)
    weights rather than a partially loaded mix.

    Args:
        model_path: Optional path to a saved state dict.

    Returns:
        The model in eval mode, or None if no model could be created.
    """
    model: Optional[nn.Module] = None
    try:
        model = PhishGNN() if PYGEOM_AVAILABLE else PhishMLP()
    except Exception as e:
        logger.error("GNN model creation failed: %s", e)
        try:
            model = PhishMLP()
        except Exception as e2:
            logger.error("MLP fallback creation also failed: %s", e2)
            return None

    if model_path and os.path.exists(model_path):
        try:
            state = torch.load(model_path, map_location="cpu", weights_only=True)
        except Exception as e:
            # Unreadable or corrupt file: nothing was copied into the model.
            logger.warning("GNN weight load failed: %s", e)
        else:
            # load_state_dict copies every compatible tensor before raising on
            # the incompatible ones. Snapshot the fresh weights so a failed
            # load can be rolled back instead of serving a half-trained model.
            initial_state = {
                key: tensor.detach().clone()
                for key, tensor in model.state_dict().items()
            }
            try:
                model.load_state_dict(state)
            except RuntimeError as e:
                model.load_state_dict(initial_state)
                logger.warning("GNN weights mismatch (architecture changed?): %s", e)
            except Exception as e:
                model.load_state_dict(initial_state)
                logger.warning("GNN weight load failed: %s", e)
            else:
                logger.info("GNN weights loaded from %s", model_path)
    elif model_path:
        logger.info("GNN weights file not found: %s", model_path)
    else:
        logger.info("No GNN weights path β using untrained model")

    try:
        model.eval()
    except Exception as e:
        logger.error("GNN eval() failed: %s", e)
        return None
    return model
# Legacy alias, kept for callers that still import ``load_model``.
def load_model(model_path: Optional[str] = None) -> Optional[nn.Module]:
    """
    Backward-compatible alias for ``load_gnn_model``.

    Args:
        model_path: Passed straight through to ``load_gnn_model``.

    Returns:
        Exactly what ``load_gnn_model(model_path)`` returns.
    """
    result = load_gnn_model(model_path=model_path)
    return result
|