"""GNN graph classifiers for Android ransomware detection.

Architectures: GIN, GCN, GAT — all using BatchNorm + global mean pooling.
Uses `PyTorchModelHubMixin` for native `from_pretrained()` /
`save_pretrained()`.

Usage:
    from model import GINClassifier

    # Load from a local directory saved by save_pretrained()
    model = GINClassifier.from_pretrained("./GIN/internal_only")
    model.eval()
"""

import torch
import torch.nn as nn
from torch_geometric.nn import GINConv, GCNConv, GATConv, global_mean_pool
from huggingface_hub import PyTorchModelHubMixin


def _build_head(hidden: int, num_classes: int, dropout: float) -> nn.Sequential:
    """Build the graph-level MLP head shared by all three classifiers.

    The layer order (Linear -> ReLU -> Dropout -> Linear) is kept fixed so
    that the ``classifier.<idx>.*`` state-dict keys stay stable.
    """
    return nn.Sequential(
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, num_classes),
    )


class GINClassifier(nn.Module, PyTorchModelHubMixin):
    """Graph Isomorphism Network for graph-level classification.

    Stacks ``num_layers`` GINConv layers (each wrapping a two-layer MLP),
    every one followed by BatchNorm1d and ReLU, then mean-pools node
    embeddings per graph and feeds them to an MLP head.

    Args:
        in_dim: Size of the input node feature vectors.
        hidden: Width of every hidden layer and of the pooled embedding.
        num_layers: Number of GINConv + BatchNorm layers.
        num_classes: Number of output logits (2, i.e. binary, by default).
        dropout: Dropout probability used inside the classification head.
    """

    def __init__(
        self,
        in_dim: int = 5,
        hidden: int = 128,
        num_layers: int = 3,
        num_classes: int = 2,
        dropout: float = 0.5,
    ):
        super().__init__()
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        for layer_idx in range(num_layers):
            # Only the first layer sees the raw node features.
            dim_in = in_dim if layer_idx == 0 else hidden
            mlp = nn.Sequential(
                nn.Linear(dim_in, hidden),
                nn.ReLU(),
                nn.Linear(hidden, hidden),
            )
            self.convs.append(GINConv(mlp))
            self.bns.append(nn.BatchNorm1d(hidden))
        self.classifier = _build_head(hidden, num_classes, dropout)

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor, batch: torch.Tensor
    ) -> torch.Tensor:
        """Return per-graph class logits.

        Args:
            x: Node features; the last dimension must equal ``in_dim``.
            edge_index: Graph connectivity passed to each conv layer
                (presumably PyG COO format of shape ``[2, num_edges]``).
            batch: Node-to-graph assignment passed to ``global_mean_pool``.

        Returns:
            Logits produced by the classification head, one row per graph.
        """
        for conv, bn in zip(self.convs, self.bns):
            x = torch.relu(bn(conv(x, edge_index)))
        return self.classifier(global_mean_pool(x, batch))


class GCNClassifier(nn.Module, PyTorchModelHubMixin):
    """GCN baseline for graph-level classification.

    Stacks ``num_layers`` GCNConv layers, each followed by BatchNorm1d and
    ReLU, then mean-pools node embeddings per graph and feeds them to an
    MLP head.

    Args:
        in_dim: Size of the input node feature vectors.
        hidden: Width of every hidden layer and of the pooled embedding.
        num_layers: Number of GCNConv + BatchNorm layers.
        num_classes: Number of output logits.
        dropout: Dropout probability used inside the classification head.
    """

    def __init__(
        self,
        in_dim: int = 5,
        hidden: int = 128,
        num_layers: int = 3,
        num_classes: int = 2,
        dropout: float = 0.5,
    ):
        super().__init__()
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        for layer_idx in range(num_layers):
            # Only the first layer sees the raw node features.
            dim_in = in_dim if layer_idx == 0 else hidden
            self.convs.append(GCNConv(dim_in, hidden))
            self.bns.append(nn.BatchNorm1d(hidden))
        self.classifier = _build_head(hidden, num_classes, dropout)

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor, batch: torch.Tensor
    ) -> torch.Tensor:
        """Return per-graph class logits.

        Args:
            x: Node features; the last dimension must equal ``in_dim``.
            edge_index: Graph connectivity passed to each conv layer
                (presumably PyG COO format of shape ``[2, num_edges]``).
            batch: Node-to-graph assignment passed to ``global_mean_pool``.

        Returns:
            Logits produced by the classification head, one row per graph.
        """
        for conv, bn in zip(self.convs, self.bns):
            x = torch.relu(bn(conv(x, edge_index)))
        return self.classifier(global_mean_pool(x, batch))


class GATClassifier(nn.Module, PyTorchModelHubMixin):
    """GAT baseline for graph-level classification.

    Stacks ``num_layers`` multi-head GATConv layers with concatenated
    heads, each followed by BatchNorm1d and ReLU, then mean-pools node
    embeddings per graph and feeds them to an MLP head.

    Args:
        in_dim: Size of the input node feature vectors.
        hidden: Width of every hidden layer and of the pooled embedding.
            Must be a positive multiple of ``heads``.
        num_layers: Number of GATConv + BatchNorm layers.
        num_classes: Number of output logits.
        dropout: Dropout probability used inside the classification head.
        heads: Number of attention heads per GATConv layer.

    Raises:
        ValueError: If ``heads`` is not positive or ``hidden`` is not
            divisible by ``heads``.
    """

    def __init__(
        self,
        in_dim: int = 5,
        hidden: int = 128,
        num_layers: int = 3,
        num_classes: int = 2,
        dropout: float = 0.5,
        heads: int = 4,
    ):
        super().__init__()
        # With concat=True each layer emits (hidden // heads) * heads features.
        # That only equals `hidden` (what BatchNorm and the next layer expect)
        # when the division is exact, so reject bad combinations up front
        # instead of failing with a shape error at forward time.
        if heads < 1:
            raise ValueError(f"heads must be a positive integer, got {heads}")
        if hidden % heads != 0:
            raise ValueError(
                f"hidden ({hidden}) must be divisible by heads ({heads})"
            )
        per_head = hidden // heads

        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        for layer_idx in range(num_layers):
            # Only the first layer sees the raw node features.
            dim_in = in_dim if layer_idx == 0 else hidden
            self.convs.append(GATConv(dim_in, per_head, heads=heads, concat=True))
            self.bns.append(nn.BatchNorm1d(hidden))
        self.classifier = _build_head(hidden, num_classes, dropout)

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor, batch: torch.Tensor
    ) -> torch.Tensor:
        """Return per-graph class logits.

        Args:
            x: Node features; the last dimension must equal ``in_dim``.
            edge_index: Graph connectivity passed to each conv layer
                (presumably PyG COO format of shape ``[2, num_edges]``).
            batch: Node-to-graph assignment passed to ``global_mean_pool``.

        Returns:
            Logits produced by the classification head, one row per graph.
        """
        for conv, bn in zip(self.convs, self.bns):
            x = torch.relu(bn(conv(x, edge_index)))
        return self.classifier(global_mean_pool(x, batch))


# Lookup table from architecture name to its classifier class.
MODEL_REGISTRY = {
    "GIN": GINClassifier,
    "GCN": GCNClassifier,
    "GAT": GATClassifier,
}