petter2025 commited on
Commit
d0da04a
·
verified ·
1 Parent(s): 7707632

Create gnn_predictor.py

Browse files
Files changed (1) hide show
  1. gnn_predictor.py +20 -6
gnn_predictor.py CHANGED
@@ -1,6 +1,10 @@
1
- # gnn_predictor.py
 
 
 
2
  import torch
3
  import torch.nn.functional as F
 
4
  try:
5
  from torch_geometric.nn import GCNConv
6
  TORCH_GEOMETRIC_AVAILABLE = True
@@ -8,19 +12,29 @@ except ImportError:
8
  TORCH_GEOMETRIC_AVAILABLE = False
9
 
10
  class FailureGNN(torch.nn.Module):
11
- def __init__(self, num_features=5, hidden=16):
12
  super().__init__()
 
 
 
 
13
  if TORCH_GEOMETRIC_AVAILABLE:
14
  self.conv1 = GCNConv(num_features, hidden)
15
- self.conv2 = GCNConv(hidden, 2)
16
  else:
17
- self.dummy = torch.nn.Linear(num_features, 2)
18
-
 
19
  def forward(self, x, edge_index=None):
 
 
 
 
20
  if TORCH_GEOMETRIC_AVAILABLE and edge_index is not None:
21
  x = self.conv1(x, edge_index)
22
  x = F.relu(x)
 
23
  x = self.conv2(x, edge_index)
24
  else:
25
- x = self.dummy(x)
26
  return F.log_softmax(x, dim=1)
 
1
+ """
2
+ PyTorch Geometric model for failure propagation prediction.
3
+ Falls back to dummy linear model if PyG not available.
4
+ """
5
  import torch
6
  import torch.nn.functional as F
7
+
8
  try:
9
  from torch_geometric.nn import GCNConv
10
  TORCH_GEOMETRIC_AVAILABLE = True
 
12
  TORCH_GEOMETRIC_AVAILABLE = False
13
 
14
  class FailureGNN(torch.nn.Module):
15
+ def __init__(self, num_features=5, hidden=16, num_classes=2):
16
  super().__init__()
17
+ self.num_features = num_features
18
+ self.hidden = hidden
19
+ self.num_classes = num_classes
20
+
21
  if TORCH_GEOMETRIC_AVAILABLE:
22
  self.conv1 = GCNConv(num_features, hidden)
23
+ self.conv2 = GCNConv(hidden, num_classes)
24
  else:
25
+ # Fallback linear model (no graph structure)
26
+ self.fc = torch.nn.Linear(num_features, num_classes)
27
+
28
  def forward(self, x, edge_index=None):
29
+ """
30
+ x: node features [num_nodes, num_features]
31
+ edge_index: graph connectivity [2, num_edges] (optional)
32
+ """
33
  if TORCH_GEOMETRIC_AVAILABLE and edge_index is not None:
34
  x = self.conv1(x, edge_index)
35
  x = F.relu(x)
36
+ x = F.dropout(x, training=self.training)
37
  x = self.conv2(x, edge_index)
38
  else:
39
+ x = self.fc(x) # ignore graph structure
40
  return F.log_softmax(x, dim=1)