petter2025 commited on
Commit
39cd578
·
verified ·
1 Parent(s): df0e97b

Create gnn_predictor.py

Browse files
Files changed (1) hide show
  1. gnn_predictor.py +26 -0
gnn_predictor.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # gnn_predictor.py
2
+ import torch
3
+ import torch.nn.functional as F
4
+ try:
5
+ from torch_geometric.nn import GCNConv
6
+ TORCH_GEOMETRIC_AVAILABLE = True
7
+ except ImportError:
8
+ TORCH_GEOMETRIC_AVAILABLE = False
9
+
10
+ class FailureGNN(torch.nn.Module):
11
+ def __init__(self, num_features=5, hidden=16):
12
+ super().__init__()
13
+ if TORCH_GEOMETRIC_AVAILABLE:
14
+ self.conv1 = GCNConv(num_features, hidden)
15
+ self.conv2 = GCNConv(hidden, 2)
16
+ else:
17
+ self.dummy = torch.nn.Linear(num_features, 2)
18
+
19
+ def forward(self, x, edge_index=None):
20
+ if TORCH_GEOMETRIC_AVAILABLE and edge_index is not None:
21
+ x = self.conv1(x, edge_index)
22
+ x = F.relu(x)
23
+ x = self.conv2(x, edge_index)
24
+ else:
25
+ x = self.dummy(x)
26
+ return F.log_softmax(x, dim=1)