File size: 6,590 Bytes
bebe233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
# ============================================================
# PhishGuard AI - gnn/gnn_model.py
# GNN + MLP model definitions for phishing graph classification.
#
# PhishGNN: 3-layer GCN with global_mean_pool → Linear → Sigmoid
#   GCNConv(12→64) → ReLU → GCNConv(64→32) → ReLU →
#   GCNConv(32→16) → ReLU → global_mean_pool → Linear(16→1) → Sigmoid
#
# PhishMLP: Fallback for single URL or when torch_geometric is unavailable
#   Linear(12→64) → ReLU → Dropout(0.3) → Linear(64→1) → Sigmoid
# ============================================================

from __future__ import annotations

import os
import logging
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger("phishguard.gnn.model")

# Default layer sizes used as constructor defaults by the models below.
INPUT_DIM: int = 12   # 12-dim node features
HIDDEN_DIM: int = 64  # width of the first hidden layer
OUTPUT_DIM: int = 1   # binary: sigmoid output

# ── Try importing PyTorch Geometric ──────────────────────────────────
# torch_geometric is an optional dependency. PYGEOM_AVAILABLE records
# whether the import succeeded so the rest of the module can decide
# between the GCN model and the MLP fallback.
PYGEOM_AVAILABLE: bool = False
try:
    from torch_geometric.nn import GCNConv, global_mean_pool
except ImportError:
    # Flag keeps its initial False value; only a missing package is
    # treated as "unavailable" — other import-time errors propagate.
    logger.info("PyTorch Geometric not found — using MLP fallback")
else:
    PYGEOM_AVAILABLE = True
    logger.info("PyTorch Geometric found — using full GCN model")


# ── PhishGNN: Full 3-layer Graph Convolutional Network ───────────────
if PYGEOM_AVAILABLE:
    class PhishGNN(nn.Module):
        """
        Graph-level phishing classifier built from three GCN layers.

        Layer stack as executed by ``forward`` (sizes for the defaults):
          GCNConv(12→64) → ReLU → GCNConv(64→32) → ReLU →
          GCNConv(32→16) → ReLU → global_mean_pool → Linear(16→1) → Sigmoid

        Only defined when torch_geometric could be imported.
        """

        def __init__(
            self,
            in_channels: int = INPUT_DIM,
            hidden: int = HIDDEN_DIM,
            out_channels: int = OUTPUT_DIM,
        ) -> None:
            """
            Create the three convolutions and the final linear head.

            Args:
                in_channels: Size of each node feature vector.
                hidden: Output width of the first convolution; the second
                    and third use ``hidden // 2`` and ``hidden // 4``.
                out_channels: Width of the linear head's output.
            """
            super().__init__()
            half_width = hidden // 2
            quarter_width = hidden // 4
            # Attribute names are part of the state_dict keys — keep them.
            self.conv1 = GCNConv(in_channels, hidden)
            self.conv2 = GCNConv(hidden, half_width)
            self.conv3 = GCNConv(half_width, quarter_width)
            self.fc = nn.Linear(quarter_width, out_channels)

        def forward(
            self,
            x: torch.Tensor,
            edge_index: torch.Tensor,
            batch: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:
            """
            Compute one sigmoid score per graph.

            Args:
                x: Node features (presumably ``(num_nodes, in_channels)``
                    — confirm against the caller).
                edge_index: Graph connectivity; an empty tensor of any
                    shape is accepted and treated as "no edges".
                batch: Node-to-graph assignment vector. When omitted,
                    every node is assigned to graph 0.

            Returns:
                Tensor of sigmoid outputs, one row per pooled graph.
            """
            # An empty edge tensor is replaced by a well-formed (2, 0) one
            # so the convolutions always receive the same layout.
            if edge_index.numel() == 0:
                edge_index = torch.zeros((2, 0), dtype=torch.long, device=x.device)

            # Each convolution is followed by a ReLU, including the last.
            for conv in (self.conv1, self.conv2, self.conv3):
                x = F.relu(conv(x, edge_index))

            if batch is None:
                # Single-graph input: all nodes belong to graph index 0.
                batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

            pooled = global_mean_pool(x, batch)
            logits = self.fc(pooled)
            return torch.sigmoid(logits)

        def predict_proba(
            self,
            x: torch.Tensor,
            edge_index: torch.Tensor,
            batch: Optional[torch.Tensor] = None,
        ) -> float:
            """
            Return P_gnn ∈ [0,1] — probability of phishing.

            Switches the module to eval mode (it is not switched back) and
            runs ``forward`` without gradient tracking. The output must
            hold exactly one element, since it is converted with ``.item()``.
            """
            self.eval()
            with torch.no_grad():
                score = self.forward(x, edge_index, batch)
            return score.squeeze().item()


# ── PhishMLP: Fallback for single URL or no torch_geometric ──────────
class PhishMLP(nn.Module):
    """
    MLP fallback for phishing classification.

    Intended for use when torch_geometric is unavailable or the graph has
    fewer than two nodes.
    Architecture (defaults): Linear(12→64) → ReLU → Dropout(0.3) →
    Linear(64→1) → Sigmoid
    """

    def __init__(
        self,
        in_channels: int = INPUT_DIM,
        hidden: int = HIDDEN_DIM,
        dropout: float = 0.3,
    ) -> None:
        """
        Build the two-layer perceptron.

        Args:
            in_channels: Size of each input feature vector.
            hidden: Width of the hidden layer (defaults to the module-wide
                ``HIDDEN_DIM``, matching ``PhishGNN``'s ``hidden`` argument).
            dropout: Dropout probability applied after the hidden ReLU.
        """
        super().__init__()
        # Layer positions inside the Sequential determine the state_dict
        # keys (net.0.*, net.3.*), so the order must stay as is.
        self.net = nn.Sequential(
            nn.Linear(in_channels, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        )

    def forward(
        self,
        x: torch.Tensor,
        edge_index: Optional[torch.Tensor] = None,
        batch: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Score the input after mean-pooling all rows into one vector.

        Args:
            x: Either a single feature vector (1-D) or a matrix with one
                row per node (2-D).
            edge_index: Accepted for signature parity with the GCN model;
                not used.
            batch: Accepted for signature parity with the GCN model; not
                used — all rows are pooled together regardless of it.

        Returns:
            Sigmoid output of the network for the pooled vector.
        """
        if x.dim() == 2 and x.size(0) > 1:
            # Pool all node features to a single vector via mean.
            x = x.mean(dim=0, keepdim=True)
        elif x.dim() == 1:
            # Promote a lone feature vector to a one-row batch.
            x = x.unsqueeze(0)
        out = self.net(x)
        return torch.sigmoid(out)

    def predict_proba(
        self,
        x: torch.Tensor,
        edge_index: Optional[torch.Tensor] = None,
        batch: Optional[torch.Tensor] = None,
    ) -> float:
        """
        Return P_gnn ∈ [0,1] — probability of phishing.

        Switches the module to eval mode (disabling dropout; the mode is
        not switched back) and runs ``forward`` without gradient tracking.
        """
        self.eval()
        with torch.no_grad():
            output = self.forward(x, edge_index, batch)
            return output.squeeze().item()


# ── Model loading utility ────────────────────────────────────────────
def load_gnn_model(model_path: Optional[str] = None) -> Optional[nn.Module]:
    """
    Build the phishing model and optionally load trained weights into it.

    ``PhishGNN`` is created when ``PYGEOM_AVAILABLE`` is true, otherwise
    ``PhishMLP``. If creation raises, ``PhishMLP`` is tried as a fallback.

    Args:
        model_path: Path to a saved state dict. When empty/None, or when
            the path does not exist, the model keeps its initial weights.

    Returns:
        The model in eval mode, or ``None`` if no model could be created
        or switching to eval mode failed. Weight-loading problems are
        logged as warnings and do not prevent a model from being returned.
    """
    model: Optional[nn.Module] = None

    try:
        model = PhishGNN() if PYGEOM_AVAILABLE else PhishMLP()
    except Exception as e:
        logger.error("GNN model creation failed: %s", e)
        try:
            model = PhishMLP()
        except Exception as e2:
            logger.error("MLP fallback creation also failed: %s", e2)
            return None

    if not model_path:
        logger.info("No GNN weights path — using untrained model")
    elif not os.path.exists(model_path):
        logger.info("GNN weights file not found: %s", model_path)
    else:
        try:
            # weights_only=True limits what the unpickler will accept;
            # map_location keeps loading independent of the saving device.
            state = torch.load(model_path, map_location="cpu", weights_only=True)
            model.load_state_dict(state)
            logger.info("GNN weights loaded from %s", model_path)
        except RuntimeError as e:
            # NOTE(review): load_state_dict may already have copied the
            # tensors that did match before raising — confirm whether a
            # partially loaded model is acceptable to callers.
            logger.warning("GNN weights mismatch (architecture changed?): %s", e)
        except Exception as e:
            # Best-effort by design: fall through with whatever weights
            # the model currently holds.
            logger.warning("GNN weight load failed: %s", e)

    try:
        model.eval()
    except Exception as e:
        logger.error("GNN eval() failed: %s", e)
        return None

    return model


# Legacy alias
def load_model(model_path: Optional[str] = None) -> Optional[nn.Module]:
    """Backward-compatible name for ``load_gnn_model``; forwards ``model_path`` unchanged."""
    model = load_gnn_model(model_path)
    return model