| """ |
| PyTorch Geometric model for failure propagation prediction. |
| Falls back to dummy linear model if PyG not available. |
| """ |
| import torch |
| import torch.nn.functional as F |
|
|
| try: |
| from torch_geometric.nn import GCNConv |
| TORCH_GEOMETRIC_AVAILABLE = True |
| except ImportError: |
| TORCH_GEOMETRIC_AVAILABLE = False |
|
|
class FailureGNN(torch.nn.Module):
    """Node classifier for failure propagation prediction.

    When PyTorch Geometric is importable (``TORCH_GEOMETRIC_AVAILABLE``),
    the model is ``GCNConv -> ReLU -> dropout -> GCNConv``. Otherwise it
    degrades to a single ``torch.nn.Linear`` applied to each node's
    features independently, ignoring graph connectivity.

    Args:
        num_features: Size of each node's input feature vector.
        hidden: Width of the hidden GCN layer (unused by the linear fallback).
        num_classes: Number of output classes per node.
        dropout: Dropout probability applied between the two GCN layers
            while training. Defaults to 0.5, the rate previously used
            implicitly via ``F.dropout``'s default.
    """

    def __init__(self, num_features=5, hidden=16, num_classes=2, dropout=0.5):
        super().__init__()
        self.num_features = num_features
        self.hidden = hidden
        self.num_classes = num_classes
        self.dropout = dropout

        if TORCH_GEOMETRIC_AVAILABLE:
            self.conv1 = GCNConv(num_features, hidden)
            self.conv2 = GCNConv(hidden, num_classes)
        else:
            # Graph-agnostic fallback: one linear map applied per node.
            self.fc = torch.nn.Linear(num_features, num_classes)

    def forward(self, x, edge_index=None):
        """Return per-node class log-probabilities.

        Args:
            x: Node features, shape ``[num_nodes, num_features]``.
            edge_index: Graph connectivity, shape ``[2, num_edges]``.
                Optional. On the GCN path, ``None`` is treated as a graph
                with no edges, so each node is scored from its own features
                (GCNConv adds self-loops by default). Ignored by the linear
                fallback.

        Returns:
            Tensor of shape ``[num_nodes, num_classes]`` containing
            ``log_softmax`` scores over the class (last) dimension.
        """
        if not TORCH_GEOMETRIC_AVAILABLE:
            return F.log_softmax(self.fc(x), dim=-1)

        if edge_index is None:
            # Only the GCN layers exist on this path (there is no self.fc),
            # so run them over an edgeless graph instead of falling back.
            edge_index = torch.empty((2, 0), dtype=torch.long, device=x.device)

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=-1)