Spaces:
Running
Running
| # ============================================================ | |
| # PhishGuard AI - gnn/gnn_model.py | |
| # GNN + MLP model definitions for phishing graph classification. | |
| # | |
| # PhishGNN: 3-layer GCN with global_mean_pool → Linear → Sigmoid | |
| # GCNConv(12→64) → ReLU → GCNConv(64→32) → ReLU → | |
| # GCNConv(32→16) → global_mean_pool → Linear(16→1) → Sigmoid | |
| # | |
| # PhishMLP: Fallback for single URL or when torch_geometric unavailable | |
| # Linear(12→64) → ReLU → Dropout(0.3) → Linear(64→1) → Sigmoid | |
| # ============================================================ | |
| from __future__ import annotations | |
| import os | |
| import logging | |
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
# Module-level logger; the name is a fixed string rather than __name__.
logger = logging.getLogger("phishguard.gnn.model")

# Default layer sizes used by the models defined in this module.
INPUT_DIM: int = 12   # width of each node-feature vector
HIDDEN_DIM: int = 64  # width of the first hidden layer
OUTPUT_DIM: int = 1   # single sigmoid output (binary classification)

# -- Optional dependency: PyTorch Geometric ---------------------------
# PYGEOM_AVAILABLE records whether the GCN building blocks could be
# imported; only ImportError triggers the fallback, anything else
# propagates to whoever imports this module.
# NOTE(review): the "β" inside the log messages looks like a mis-encoded
# dash/arrow; it is reproduced unchanged here - confirm the intended text.
try:
    from torch_geometric.nn import GCNConv, global_mean_pool
except ImportError:
    PYGEOM_AVAILABLE: bool = False
    logger.info("PyTorch Geometric not found β using MLP fallback")
else:
    PYGEOM_AVAILABLE = True
    logger.info("PyTorch Geometric found β using full GCN model")
# -- PhishGNN: full 3-layer graph convolutional network ---------------
if PYGEOM_AVAILABLE:

    class PhishGNN(nn.Module):
        """Graph-level phishing classifier built from three GCN layers.

        Pipeline as implemented:
            GCNConv(in -> hidden)               -> ReLU
            GCNConv(hidden -> hidden // 2)      -> ReLU
            GCNConv(hidden // 2 -> hidden // 4) -> ReLU
            global_mean_pool -> Linear(hidden // 4 -> out) -> Sigmoid

        With the default arguments the widths are 12 -> 64 -> 32 -> 16 -> 1.
        The class only exists when PyTorch Geometric could be imported.
        """

        def __init__(
            self,
            in_channels: int = INPUT_DIM,
            hidden: int = HIDDEN_DIM,
            out_channels: int = OUTPUT_DIM,
        ) -> None:
            """Create the three convolutions and the final linear head.

            Args:
                in_channels: Number of features per node.
                hidden: Output width of the first convolution; the next
                    two layers use ``hidden // 2`` and ``hidden // 4``.
                out_channels: Width of the final linear layer.
            """
            super().__init__()
            half = hidden // 2
            quarter = hidden // 4
            self.conv1 = GCNConv(in_channels, hidden)
            self.conv2 = GCNConv(hidden, half)
            self.conv3 = GCNConv(half, quarter)
            self.fc = nn.Linear(quarter, out_channels)

        def forward(
            self,
            x: torch.Tensor,
            edge_index: torch.Tensor,
            batch: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:
            """Return one sigmoid score per graph.

            Args:
                x: Node feature matrix.
                edge_index: Edge list; an empty tensor of any shape is
                    replaced by a well-formed ``(2, 0)`` long tensor.
                batch: Graph id of each node. When omitted, all nodes are
                    assigned to graph 0.

            Returns:
                Tensor of sigmoid outputs, one row per pooled graph.
            """
            # Normalise "no edges" to the (2, 0) long layout.
            if edge_index.numel() == 0:
                edge_index = torch.empty((2, 0), dtype=torch.long, device=x.device)

            # Every convolution, including the last, is followed by ReLU.
            for conv in (self.conv1, self.conv2, self.conv3):
                x = F.relu(conv(x, edge_index))

            # No batch vector means "treat all nodes as a single graph".
            if batch is None:
                batch = x.new_zeros(x.size(0), dtype=torch.long)

            pooled = global_mean_pool(x, batch)
            logits = self.fc(pooled)
            return torch.sigmoid(logits)

        def predict_proba(
            self,
            x: torch.Tensor,
            edge_index: torch.Tensor,
            batch: Optional[torch.Tensor] = None,
        ) -> float:
            """Score one graph and return the result as a Python float.

            Switches the module to eval mode (it is not switched back) and
            runs ``forward`` without gradient tracking. The output must
            hold exactly one element, otherwise ``.item()`` raises.
            """
            self.eval()
            with torch.no_grad():
                score = self.forward(x, edge_index, batch)
            return score.squeeze().item()
# -- PhishMLP: fallback for single URL or no torch_geometric ----------
class PhishMLP(nn.Module):
    """MLP fallback for phishing classification.

    Architecture: Linear(in -> 64) -> ReLU -> Dropout(0.3) -> Linear(64 -> 1)
    -> Sigmoid. Node features are mean-pooled before the MLP, so the model
    never looks at ``edge_index``.

    The original design notes say this model is meant for the case where
    torch_geometric is unavailable or the graph has fewer than two nodes.
    """

    def __init__(self, in_channels: int = INPUT_DIM) -> None:
        """Build the MLP.

        Args:
            in_channels: Number of features per node / per input vector.
        """
        super().__init__()
        # Layer order is kept as-is so the state_dict keys (net.0.*, net.3.*)
        # of previously saved weights still match.
        self.net = nn.Sequential(
            nn.Linear(in_channels, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
        )

    @staticmethod
    def _mean_per_graph(x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        """Average the rows of ``x`` separately for every graph id in ``batch``."""
        index = batch.to(device=x.device, dtype=torch.long)
        num_graphs = int(index.max().item()) + 1
        if num_graphs == 1:
            # Single graph: identical to the plain mean over all nodes.
            return x.mean(dim=0, keepdim=True)
        sums = x.new_zeros((num_graphs, x.size(1))).index_add(0, index, x)
        # clamp avoids division by zero for graph ids that own no node.
        counts = torch.bincount(index, minlength=num_graphs).clamp(min=1)
        return sums / counts.unsqueeze(1).to(sums.dtype)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: Optional[torch.Tensor] = None,
        batch: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return one sigmoid score per graph.

        Args:
            x: Either a single feature vector (1-D) or a node feature
                matrix (2-D, one row per node).
            edge_index: Accepted for signature compatibility; unused.
            batch: Optional 1-D graph id per row of ``x``. When it is given
                and matches the number of rows, rows are averaged per
                graph, so a mini-batch of several graphs yields one output
                row per graph. Without it, all rows are averaged into a
                single vector.

        Returns:
            Tensor of sigmoid outputs with a trailing dimension of 1.
        """
        if x.dim() == 1:
            # Single feature vector: add the batch dimension.
            x = x.unsqueeze(0)
        elif x.dim() == 2 and x.size(0) > 1:
            has_usable_batch = (
                batch is not None
                and batch.dim() == 1
                and batch.numel() == x.size(0)
            )
            if has_usable_batch:
                x = self._mean_per_graph(x, batch)
            else:
                x = x.mean(dim=0, keepdim=True)
        return torch.sigmoid(self.net(x))

    def predict_proba(
        self,
        x: torch.Tensor,
        edge_index: Optional[torch.Tensor] = None,
        batch: Optional[torch.Tensor] = None,
    ) -> float:
        """Score one input and return the result as a Python float.

        Switches the module to eval mode (it is not switched back), which
        disables dropout, and runs ``forward`` without gradient tracking.
        The output must hold exactly one element, otherwise ``.item()``
        raises.
        """
        self.eval()
        with torch.no_grad():
            output = self.forward(x, edge_index, batch)
        return output.squeeze().item()
# -- Model loading utility --------------------------------------------
def load_gnn_model(model_path: Optional[str] = None) -> Optional[nn.Module]:
    """Build the classifier and, when possible, restore trained weights.

    PhishGNN is instantiated if PyTorch Geometric was imported, otherwise
    PhishMLP. If construction raises, PhishMLP is attempted once more.

    Args:
        model_path: Optional path to a saved state dict. A missing file or
            a failed load is only logged; the model is still returned.

    Returns:
        The model in eval mode, or None when no model could be built or
        switching to eval mode raised.
    """
    # Step 1: construct the network, falling back to the MLP on error.
    try:
        if PYGEOM_AVAILABLE:
            net: Optional[nn.Module] = PhishGNN()
        else:
            net = PhishMLP()
    except Exception as build_err:
        logger.error(f"GNN model creation failed: {build_err}")
        try:
            net = PhishMLP()
        except Exception as fallback_err:
            logger.error(f"MLP fallback creation also failed: {fallback_err}")
            return None

    # Step 2: restore weights if a usable path was supplied.
    if not model_path:
        logger.info("No GNN weights path β using untrained model")
    elif not os.path.exists(model_path):
        logger.info(f"GNN weights file not found: {model_path}")
    else:
        try:
            weights = torch.load(model_path, map_location="cpu", weights_only=True)
            net.load_state_dict(weights)
            logger.info(f"GNN weights loaded from {model_path}")
        except RuntimeError as mismatch_err:
            # A RuntimeError here is reported as an architecture mismatch.
            logger.warning(f"GNN weights mismatch (architecture changed?): {mismatch_err}")
        except Exception as load_err:
            logger.warning(f"GNN weight load failed: {load_err}")

    # Step 3: hand the model back in inference mode.
    try:
        net.eval()
    except Exception as eval_err:
        logger.error(f"GNN eval() failed: {eval_err}")
        return None
    return net
# Legacy alias
def load_model(model_path: Optional[str] = None) -> Optional[nn.Module]:
    """Backward-compatible name that forwards to ``load_gnn_model``."""
    return load_gnn_model(model_path=model_path)