"""
GNN graph classifiers for Android ransomware detection.

Architectures: GIN, GCN, GAT — all using BatchNorm + global mean pooling.
Each classifier mixes in ``PyTorchModelHubMixin``, which provides native
``from_pretrained()`` / ``save_pretrained()``.

Usage::

    from model import GINClassifier

    # Load from a local directory written by save_pretrained()
    model = GINClassifier.from_pretrained("./GIN/internal_only")
    model.eval()
"""
| import torch | |
| import torch.nn as nn | |
| from torch_geometric.nn import GINConv, GCNConv, GATConv, global_mean_pool | |
| from huggingface_hub import PyTorchModelHubMixin | |
class GINClassifier(nn.Module, PyTorchModelHubMixin):
    """Graph Isomorphism Network for graph-level classification (binary by default).

    Stacks ``num_layers`` ``GINConv`` layers, each wrapping a two-layer MLP and
    followed by BatchNorm + ReLU. Node embeddings are then mean-pooled per
    graph and passed to a small MLP classification head.

    Args:
        in_dim: Number of input node features.
        hidden: Width of every node embedding and of the head's hidden layer.
        num_layers: Number of GIN message-passing layers.
        num_classes: Number of output logits per graph.
        dropout: Dropout probability used inside the classification head.
    """

    def __init__(self, in_dim=5, hidden=128, num_layers=3, num_classes=2, dropout=0.5):
        super().__init__()
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        for i in range(num_layers):
            dim_in = in_dim if i == 0 else hidden
            # Sequential indices become state_dict keys; keep this order so
            # checkpoints written by save_pretrained() still load.
            mlp = nn.Sequential(
                nn.Linear(dim_in, hidden),
                nn.ReLU(),
                nn.Linear(hidden, hidden),
            )
            self.convs.append(GINConv(mlp))
            self.bns.append(nn.BatchNorm1d(hidden))
        self.classifier = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x, edge_index, batch=None):
        """Compute per-graph class logits.

        Args:
            x: Node feature tensor whose last dimension equals ``in_dim``.
            edge_index: Graph connectivity in PyG's ``[2, num_edges]`` format.
            batch: Vector mapping each node to its graph index (PyG batching
                convention). ``None`` treats all nodes as a single graph.

        Returns:
            Logits of shape ``(num_graphs, num_classes)``.
        """
        if batch is None:
            # Single-graph inference: assign every node to graph 0.
            batch = x.new_zeros(x.size(0), dtype=torch.long)
        for conv, bn in zip(self.convs, self.bns):
            x = torch.relu(bn(conv(x, edge_index)))
        return self.classifier(global_mean_pool(x, batch))
class GCNClassifier(nn.Module, PyTorchModelHubMixin):
    """GCN baseline for graph-level classification.

    Stacks ``num_layers`` ``GCNConv`` layers, each followed by BatchNorm + ReLU.
    Node embeddings are then mean-pooled per graph and passed to a small MLP
    classification head.

    Args:
        in_dim: Number of input node features.
        hidden: Width of every node embedding and of the head's hidden layer.
        num_layers: Number of GCN message-passing layers.
        num_classes: Number of output logits per graph.
        dropout: Dropout probability used inside the classification head.
    """

    def __init__(self, in_dim=5, hidden=128, num_layers=3, num_classes=2, dropout=0.5):
        super().__init__()
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        for i in range(num_layers):
            dim_in = in_dim if i == 0 else hidden
            self.convs.append(GCNConv(dim_in, hidden))
            self.bns.append(nn.BatchNorm1d(hidden))
        # Sequential indices become state_dict keys; keep this order so
        # checkpoints written by save_pretrained() still load.
        self.classifier = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x, edge_index, batch=None):
        """Compute per-graph class logits.

        Args:
            x: Node feature tensor whose last dimension equals ``in_dim``.
            edge_index: Graph connectivity in PyG's ``[2, num_edges]`` format.
            batch: Vector mapping each node to its graph index (PyG batching
                convention). ``None`` treats all nodes as a single graph.

        Returns:
            Logits of shape ``(num_graphs, num_classes)``.
        """
        if batch is None:
            # Single-graph inference: assign every node to graph 0.
            batch = x.new_zeros(x.size(0), dtype=torch.long)
        for conv, bn in zip(self.convs, self.bns):
            x = torch.relu(bn(conv(x, edge_index)))
        return self.classifier(global_mean_pool(x, batch))
class GATClassifier(nn.Module, PyTorchModelHubMixin):
    """GAT baseline for graph-level classification.

    Stacks ``num_layers`` multi-head ``GATConv`` layers, each followed by
    BatchNorm + ReLU. Each layer has ``heads`` heads of width
    ``hidden // heads``, concatenated back to ``hidden``. Node embeddings are
    then mean-pooled per graph and passed to a small MLP classification head.

    Args:
        in_dim: Number of input node features.
        hidden: Width of every node embedding (after head concatenation) and
            of the head's hidden layer. Must be divisible by ``heads``.
        num_layers: Number of GAT message-passing layers.
        num_classes: Number of output logits per graph.
        dropout: Dropout probability used inside the classification head.
        heads: Number of attention heads per GAT layer (must be >= 1).

    Raises:
        ValueError: If ``heads < 1`` or ``hidden`` is not divisible by ``heads``.
    """

    def __init__(self, in_dim=5, hidden=128, num_layers=3, num_classes=2, dropout=0.5, heads=4):
        super().__init__()
        if heads < 1:
            raise ValueError(f"heads must be >= 1, got {heads}")
        if hidden % heads != 0:
            # With concat=True each layer emits (hidden // heads) * heads
            # features; anything but exactly `hidden` breaks BatchNorm1d(hidden).
            raise ValueError(
                f"hidden ({hidden}) must be divisible by heads ({heads}) so the "
                f"concatenated attention heads match the BatchNorm width"
            )
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        for i in range(num_layers):
            dim_in = in_dim if i == 0 else hidden
            self.convs.append(GATConv(dim_in, hidden // heads, heads=heads, concat=True))
            self.bns.append(nn.BatchNorm1d(hidden))
        # Sequential indices become state_dict keys; keep this order so
        # checkpoints written by save_pretrained() still load.
        self.classifier = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x, edge_index, batch=None):
        """Compute per-graph class logits.

        Args:
            x: Node feature tensor whose last dimension equals ``in_dim``.
            edge_index: Graph connectivity in PyG's ``[2, num_edges]`` format.
            batch: Vector mapping each node to its graph index (PyG batching
                convention). ``None`` treats all nodes as a single graph.

        Returns:
            Logits of shape ``(num_graphs, num_classes)``.
        """
        if batch is None:
            # Single-graph inference: assign every node to graph 0.
            batch = x.new_zeros(x.size(0), dtype=torch.long)
        for conv, bn in zip(self.convs, self.bns):
            x = torch.relu(bn(conv(x, edge_index)))
        return self.classifier(global_mean_pool(x, batch))
# Architecture name -> classifier class, so callers can pick a model by name,
# e.g. ``MODEL_REGISTRY["GIN"].from_pretrained(path)``.
MODEL_REGISTRY = dict(
    GIN=GINClassifier,
    GCN=GCNClassifier,
    GAT=GATClassifier,
)