File size: 3,436 Bytes
"""
GNN graph classifiers for Android ransomware detection.

Architectures: GIN, GCN, GAT — all using BatchNorm + global mean pooling.
Each class uses `PyTorchModelHubMixin` for native `from_pretrained()` / `save_pretrained()`.

Usage:
    from model import GINClassifier

    # Load from a local directory saved by save_pretrained()
    model = GINClassifier.from_pretrained("./GIN/internal_only")
    model.eval()
"""
import torch
import torch.nn as nn
from torch_geometric.nn import GINConv, GCNConv, GATConv, global_mean_pool
from huggingface_hub import PyTorchModelHubMixin
class GINClassifier(nn.Module, PyTorchModelHubMixin):
    """GIN-based model producing one score vector per graph.

    Every message-passing stage is a ``GINConv`` wrapping a two-layer MLP,
    followed by ``BatchNorm1d`` and ReLU. Node embeddings are then averaged
    per graph with ``global_mean_pool`` and passed through an MLP head that
    ends in a ``num_classes``-wide linear layer (no final activation).

    Args:
        in_dim: Input width of the first stage's MLP.
        hidden: Width of every stage output and of the head's hidden layer.
        num_layers: Number of GINConv + BatchNorm stages.
        num_classes: Output width of the final linear layer.
        dropout: Dropout probability used inside the classification head.
    """

    def __init__(self, in_dim=5, hidden=128, num_layers=3, num_classes=2, dropout=0.5):
        super().__init__()
        # Only the first stage consumes raw node features; later stages map
        # hidden -> hidden.
        stage_inputs = [in_dim if idx == 0 else hidden for idx in range(num_layers)]
        self.convs = nn.ModuleList(
            [
                GINConv(
                    nn.Sequential(
                        nn.Linear(width, hidden),
                        nn.ReLU(),
                        nn.Linear(hidden, hidden),
                    )
                )
                for width in stage_inputs
            ]
        )
        self.bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in stage_inputs])
        head_layers = [
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x, edge_index, batch):
        """Run all stages, mean-pool nodes per graph, and apply the head.

        Args:
            x: Node features (presumably ``[num_nodes, in_dim]`` — confirm
                against the data loader).
            edge_index: Graph connectivity passed unchanged to each conv.
            batch: Node-to-graph assignment forwarded to ``global_mean_pool``.

        Returns:
            The classification head's output for the pooled embeddings.
        """
        h = x
        for layer, norm in zip(self.convs, self.bns):
            h = layer(h, edge_index)
            h = norm(h)
            h = torch.relu(h)
        pooled = global_mean_pool(h, batch)
        return self.classifier(pooled)
class GCNClassifier(nn.Module, PyTorchModelHubMixin):
    """GCN baseline producing one score vector per graph.

    A stack of ``GCNConv`` layers, each followed by ``BatchNorm1d`` and ReLU,
    feeds a ``global_mean_pool`` readout and an MLP head whose last layer is
    ``num_classes`` wide (no final activation).

    Args:
        in_dim: Input width of the first convolution.
        hidden: Width of every convolution output and of the head's hidden layer.
        num_layers: Number of GCNConv + BatchNorm stages.
        num_classes: Output width of the final linear layer.
        dropout: Dropout probability used inside the classification head.
    """

    def __init__(self, in_dim=5, hidden=128, num_layers=3, num_classes=2, dropout=0.5):
        super().__init__()
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()

        # The first stage reads raw features; every later stage reads `hidden`.
        width = in_dim
        for _ in range(num_layers):
            self.convs.append(GCNConv(width, hidden))
            self.bns.append(nn.BatchNorm1d(hidden))
            width = hidden

        self.classifier = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x, edge_index, batch):
        """Run all stages, mean-pool nodes per graph, and apply the head.

        Args:
            x: Node features (presumably ``[num_nodes, in_dim]`` — confirm
                against the data loader).
            edge_index: Graph connectivity passed unchanged to each conv.
            batch: Node-to-graph assignment forwarded to ``global_mean_pool``.

        Returns:
            The classification head's output for the pooled embeddings.
        """
        h = x
        for stage, norm in zip(self.convs, self.bns):
            normalized = norm(stage(h, edge_index))
            h = torch.relu(normalized)
        graph_embedding = global_mean_pool(h, batch)
        return self.classifier(graph_embedding)
class GATClassifier(nn.Module, PyTorchModelHubMixin):
    """GAT baseline for graph-level classification.

    A stack of multi-head ``GATConv`` layers (heads concatenated), each
    followed by ``BatchNorm1d`` and ReLU, feeds a ``global_mean_pool`` readout
    and an MLP head whose last layer is ``num_classes`` wide (no final
    activation).

    Args:
        in_dim: Input width of the first attention layer.
        hidden: Total width of every layer output (all heads concatenated)
            and of the head's hidden layer. Must be divisible by ``heads``.
        num_layers: Number of GATConv + BatchNorm stages.
        num_classes: Output width of the final linear layer.
        dropout: Dropout probability used inside the classification head.
        heads: Number of attention heads per layer.

    Raises:
        ValueError: If ``heads`` is less than 1 or ``hidden`` is not divisible
            by ``heads``.
    """

    def __init__(self, in_dim=5, hidden=128, num_layers=3, num_classes=2, dropout=0.5, heads=4):
        super().__init__()
        # Each layer emits (hidden // heads) * heads features after head
        # concatenation. Unless `hidden` divides evenly that is narrower than
        # the `hidden` BatchNorm1d expects, and the model would only fail with
        # an opaque shape error on the first forward pass. Fail early instead.
        if heads < 1 or hidden % heads != 0:
            raise ValueError(
                "hidden must be divisible by heads (with heads >= 1); "
                f"got hidden={hidden}, heads={heads}"
            )
        per_head = hidden // heads

        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        for i in range(num_layers):
            # Only the first layer consumes raw node features.
            dim_in = in_dim if i == 0 else hidden
            self.convs.append(GATConv(dim_in, per_head, heads=heads, concat=True))
            self.bns.append(nn.BatchNorm1d(hidden))
        self.classifier = nn.Sequential(
            nn.Linear(hidden, hidden), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x, edge_index, batch):
        """Run all stages, mean-pool nodes per graph, and apply the head.

        Args:
            x: Node features (presumably ``[num_nodes, in_dim]`` — confirm
                against the data loader).
            edge_index: Graph connectivity passed unchanged to each conv.
            batch: Node-to-graph assignment forwarded to ``global_mean_pool``.

        Returns:
            The classification head's output for the pooled embeddings.
        """
        for conv, bn in zip(self.convs, self.bns):
            x = torch.relu(bn(conv(x, edge_index)))
        return self.classifier(global_mean_pool(x, batch))
# Maps the short architecture name to its classifier class.
MODEL_REGISTRY = {
    "GIN": GINClassifier,
    "GCN": GCNClassifier,
    "GAT": GATClassifier,
}
|