JLB-JLB commited on
Commit
8d5f9ba
·
verified ·
1 Parent(s): 29ce481

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +96 -0
model.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GNN graph classifiers for Android ransomware detection.
3
+
4
+ Architectures: GIN, GCN, GAT — all using BatchNorm + global mean pooling.
5
+ Uses `PyTorchModelHubMixin` for native `from_pretrained()` / `save_pretrained()`.
6
+
7
+ Usage:
8
+ from model import GINClassifier
9
+
10
+ # Load from a local directory saved by save_pretrained()
11
+ model = GINClassifier.from_pretrained("./GIN/internal_only")
12
+
13
+ # Or load directly from the Hugging Face Hub
14
+ model = GINClassifier.from_pretrained(
15
+ "USER/android-ransomware-gnn-baseline",
16
+ subfolder="GIN/internal_only",
17
+ )
18
+ model.eval()
19
+ """
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ from torch_geometric.nn import GINConv, GCNConv, GATConv, global_mean_pool
24
+ from huggingface_hub import PyTorchModelHubMixin
25
+
26
+
27
+ class GINClassifier(nn.Module, PyTorchModelHubMixin):
28
+ """Graph Isomorphism Network for graph-level binary classification."""
29
+
30
+ def __init__(self, in_dim=5, hidden=128, num_layers=3, num_classes=2, dropout=0.5):
31
+ super().__init__()
32
+ self.convs = nn.ModuleList()
33
+ self.bns = nn.ModuleList()
34
+ for i in range(num_layers):
35
+ dim_in = in_dim if i == 0 else hidden
36
+ mlp = nn.Sequential(
37
+ nn.Linear(dim_in, hidden), nn.ReLU(), nn.Linear(hidden, hidden),
38
+ )
39
+ self.convs.append(GINConv(mlp))
40
+ self.bns.append(nn.BatchNorm1d(hidden))
41
+ self.classifier = nn.Sequential(
42
+ nn.Linear(hidden, hidden), nn.ReLU(), nn.Dropout(dropout),
43
+ nn.Linear(hidden, num_classes),
44
+ )
45
+
46
+ def forward(self, x, edge_index, batch):
47
+ for conv, bn in zip(self.convs, self.bns):
48
+ x = torch.relu(bn(conv(x, edge_index)))
49
+ return self.classifier(global_mean_pool(x, batch))
50
+
51
+
52
+ class GCNClassifier(nn.Module, PyTorchModelHubMixin):
53
+ """GCN baseline for graph-level classification."""
54
+
55
+ def __init__(self, in_dim=5, hidden=128, num_layers=3, num_classes=2, dropout=0.5):
56
+ super().__init__()
57
+ self.convs = nn.ModuleList()
58
+ self.bns = nn.ModuleList()
59
+ for i in range(num_layers):
60
+ dim_in = in_dim if i == 0 else hidden
61
+ self.convs.append(GCNConv(dim_in, hidden))
62
+ self.bns.append(nn.BatchNorm1d(hidden))
63
+ self.classifier = nn.Sequential(
64
+ nn.Linear(hidden, hidden), nn.ReLU(), nn.Dropout(dropout),
65
+ nn.Linear(hidden, num_classes),
66
+ )
67
+
68
+ def forward(self, x, edge_index, batch):
69
+ for conv, bn in zip(self.convs, self.bns):
70
+ x = torch.relu(bn(conv(x, edge_index)))
71
+ return self.classifier(global_mean_pool(x, batch))
72
+
73
+
74
+ class GATClassifier(nn.Module, PyTorchModelHubMixin):
75
+ """GAT baseline for graph-level classification."""
76
+
77
+ def __init__(self, in_dim=5, hidden=128, num_layers=3, num_classes=2, dropout=0.5, heads=4):
78
+ super().__init__()
79
+ self.convs = nn.ModuleList()
80
+ self.bns = nn.ModuleList()
81
+ for i in range(num_layers):
82
+ dim_in = in_dim if i == 0 else hidden
83
+ self.convs.append(GATConv(dim_in, hidden // heads, heads=heads, concat=True))
84
+ self.bns.append(nn.BatchNorm1d(hidden))
85
+ self.classifier = nn.Sequential(
86
+ nn.Linear(hidden, hidden), nn.ReLU(), nn.Dropout(dropout),
87
+ nn.Linear(hidden, num_classes),
88
+ )
89
+
90
+ def forward(self, x, edge_index, batch):
91
+ for conv, bn in zip(self.convs, self.bns):
92
+ x = torch.relu(bn(conv(x, edge_index)))
93
+ return self.classifier(global_mean_pool(x, batch))
94
+
95
+
96
+ MODEL_REGISTRY = {"GIN": GINClassifier, "GCN": GCNClassifier, "GAT": GATClassifier}