import torch
import torch.nn.functional as F

# torch_geometric is an optional dependency: try the import and record the
# outcome so the model classes below can fall back to a plain linear layer
# when graph convolutions are unavailable.
try:
    from torch_geometric.nn import GCNConv
    TORCH_GEOMETRIC_AVAILABLE = True
except ImportError:
    TORCH_GEOMETRIC_AVAILABLE = False
|
class FailureGNN(torch.nn.Module):
    """Two-layer GCN node classifier with a linear fallback.

    When ``torch_geometric`` is installed and an ``edge_index`` is supplied,
    runs two ``GCNConv`` layers (ReLU in between). Otherwise falls back to a
    single ``Linear`` layer applied per-node. Output is log-probabilities
    over 2 classes.

    Args:
        num_features: Size of each input node-feature vector.
        hidden: Width of the hidden GCN layer.
    """

    def __init__(self, num_features=5, hidden=16):
        super().__init__()
        if TORCH_GEOMETRIC_AVAILABLE:
            self.conv1 = GCNConv(num_features, hidden)
            self.conv2 = GCNConv(hidden, 2)
        # BUG FIX: always create the linear fallback. forward() uses it
        # whenever edge_index is None — even when torch_geometric IS
        # installed — but the original only defined it in the else-branch,
        # so that path raised AttributeError.
        self.dummy = torch.nn.Linear(num_features, 2)

    def forward(self, x, edge_index=None):
        """Return per-node class log-probabilities, shape ``(N, 2)``.

        Args:
            x: Node features, shape ``(N, num_features)``.
            edge_index: Optional graph connectivity for the GCN path;
                when ``None`` the linear fallback is used.
        """
        if TORCH_GEOMETRIC_AVAILABLE and edge_index is not None:
            x = self.conv1(x, edge_index)
            x = F.relu(x)
            x = self.conv2(x, edge_index)
        else:
            x = self.dummy(x)
        return F.log_softmax(x, dim=1)