# Uploaded by JLB-JLB via huggingface_hub ("Upload model.py with huggingface_hub"),
# commit 4be8e03 (verified).
"""
GNN graph classifiers for Android ransomware detection.
Architectures: GIN, GCN, GAT — all using BatchNorm + global mean pooling.
Uses `PyTorchModelHubMixin` for native `from_pretrained()` / `save_pretrained()`.
Usage:
    from model import GINClassifier
    # Load from a local directory saved by save_pretrained()
    model = GINClassifier.from_pretrained("./GIN/internal_only")
    model.eval()
"""
from typing import Optional

import torch
import torch.nn as nn
from torch_geometric.nn import GINConv, GCNConv, GATConv, global_mean_pool
from huggingface_hub import PyTorchModelHubMixin
class GINClassifier(nn.Module, PyTorchModelHubMixin):
    """Graph Isomorphism Network for graph-level binary classification.

    Stacks ``num_layers`` GIN convolutions (each wrapping a two-layer MLP),
    each followed by BatchNorm and ReLU, mean-pools the node embeddings of
    every graph, and maps the pooled vector to class logits with an MLP head.

    Args:
        in_dim: Number of input node features (input width of the first
            GIN MLP).
        hidden: Width of every node embedding and of the classifier's
            hidden layer.
        num_layers: Number of GIN message-passing layers.
        num_classes: Number of output logits per graph.
        dropout: Dropout probability, applied only inside the classifier
            head.
    """

    def __init__(
        self,
        in_dim: int = 5,
        hidden: int = 128,
        num_layers: int = 3,
        num_classes: int = 2,
        dropout: float = 0.5,
    ):
        super().__init__()
        # Attribute names and Sequential layouts below determine the
        # state_dict keys; keep them stable so checkpoints written by
        # save_pretrained() keep loading via from_pretrained().
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        for i in range(num_layers):
            dim_in = in_dim if i == 0 else hidden
            mlp = nn.Sequential(
                nn.Linear(dim_in, hidden),
                nn.ReLU(),
                nn.Linear(hidden, hidden),
            )
            self.convs.append(GINConv(mlp))
            self.bns.append(nn.BatchNorm1d(hidden))
        self.classifier = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_classes),
        )

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return unnormalized class logits, one row per graph.

        Args:
            x: Node feature matrix whose last dimension equals ``in_dim``.
            edge_index: Graph connectivity in the format PyG convolutions
                accept.
            batch: Vector assigning each node to its graph in the batch.
                ``None`` treats all nodes as one graph (relies on
                ``global_mean_pool`` accepting ``batch=None``, as in PyG 2.x).

        Returns:
            Raw logits (no softmax applied), one row per graph and
            ``num_classes`` columns.
        """
        for conv, bn in zip(self.convs, self.bns):
            x = torch.relu(bn(conv(x, edge_index)))
        return self.classifier(global_mean_pool(x, batch))
class GCNClassifier(nn.Module, PyTorchModelHubMixin):
    """GCN baseline for graph-level classification.

    Stacks ``num_layers`` GCN convolutions, each followed by BatchNorm and
    ReLU, mean-pools the node embeddings of every graph, and maps the pooled
    vector to class logits with an MLP head.

    Args:
        in_dim: Number of input node features (input width of the first
            GCN layer).
        hidden: Width of every node embedding and of the classifier's
            hidden layer.
        num_layers: Number of GCN message-passing layers.
        num_classes: Number of output logits per graph.
        dropout: Dropout probability, applied only inside the classifier
            head.
    """

    def __init__(
        self,
        in_dim: int = 5,
        hidden: int = 128,
        num_layers: int = 3,
        num_classes: int = 2,
        dropout: float = 0.5,
    ):
        super().__init__()
        # Attribute names and Sequential layouts below determine the
        # state_dict keys; keep them stable so checkpoints written by
        # save_pretrained() keep loading via from_pretrained().
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        for i in range(num_layers):
            dim_in = in_dim if i == 0 else hidden
            self.convs.append(GCNConv(dim_in, hidden))
            self.bns.append(nn.BatchNorm1d(hidden))
        self.classifier = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_classes),
        )

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return unnormalized class logits, one row per graph.

        Args:
            x: Node feature matrix whose last dimension equals ``in_dim``.
            edge_index: Graph connectivity in the format PyG convolutions
                accept.
            batch: Vector assigning each node to its graph in the batch.
                ``None`` treats all nodes as one graph (relies on
                ``global_mean_pool`` accepting ``batch=None``, as in PyG 2.x).

        Returns:
            Raw logits (no softmax applied), one row per graph and
            ``num_classes`` columns.
        """
        for conv, bn in zip(self.convs, self.bns):
            x = torch.relu(bn(conv(x, edge_index)))
        return self.classifier(global_mean_pool(x, batch))
class GATClassifier(nn.Module, PyTorchModelHubMixin):
    """GAT baseline for graph-level classification.

    Stacks ``num_layers`` multi-head GAT convolutions (head outputs
    concatenated), each followed by BatchNorm and ReLU, mean-pools the node
    embeddings of every graph, and maps the pooled vector to class logits
    with an MLP head.

    Args:
        in_dim: Number of input node features (input width of the first
            GAT layer).
        hidden: Width of every node embedding after head concatenation,
            and of the classifier's hidden layer.
        num_layers: Number of GAT message-passing layers.
        num_classes: Number of output logits per graph.
        dropout: Dropout probability, applied only inside the classifier
            head.
        heads: Attention heads per layer. Each head produces
            ``hidden // heads`` channels, so ``heads`` must divide ``hidden``.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide
            ``hidden``.
    """

    def __init__(
        self,
        in_dim: int = 5,
        hidden: int = 128,
        num_layers: int = 3,
        num_classes: int = 2,
        dropout: float = 0.5,
        heads: int = 4,
    ):
        super().__init__()
        # With concat=True each GATConv emits heads * (hidden // heads)
        # channels. Unless heads divides hidden that is fewer than `hidden`,
        # and BatchNorm1d(hidden) / the next layer would only fail with a
        # shape error at forward time (heads == 0 would raise
        # ZeroDivisionError). Fail fast with a clear message instead.
        if heads < 1 or hidden % heads != 0:
            raise ValueError(
                "heads must be a positive divisor of hidden; "
                f"got hidden={hidden}, heads={heads}"
            )
        # Attribute names and Sequential layouts below determine the
        # state_dict keys; keep them stable so checkpoints written by
        # save_pretrained() keep loading via from_pretrained().
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        for i in range(num_layers):
            dim_in = in_dim if i == 0 else hidden
            self.convs.append(
                GATConv(dim_in, hidden // heads, heads=heads, concat=True)
            )
            self.bns.append(nn.BatchNorm1d(hidden))
        self.classifier = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_classes),
        )

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return unnormalized class logits, one row per graph.

        Args:
            x: Node feature matrix whose last dimension equals ``in_dim``.
            edge_index: Graph connectivity in the format PyG convolutions
                accept.
            batch: Vector assigning each node to its graph in the batch.
                ``None`` treats all nodes as one graph (relies on
                ``global_mean_pool`` accepting ``batch=None``, as in PyG 2.x).

        Returns:
            Raw logits (no softmax applied), one row per graph and
            ``num_classes`` columns.
        """
        for conv, bn in zip(self.convs, self.bns):
            x = torch.relu(bn(conv(x, edge_index)))
        return self.classifier(global_mean_pool(x, batch))
# Lookup table from architecture name to classifier class, so callers can
# pick a model by string, e.g. MODEL_REGISTRY["GIN"].from_pretrained(path).
MODEL_REGISTRY = dict(
    GIN=GINClassifier,
    GCN=GCNClassifier,
    GAT=GATClassifier,
)