petter2025's picture
Create gnn_predictor.py
d0da04a verified
raw
history blame
1.36 kB
"""
PyTorch Geometric model for failure propagation prediction.
Falls back to dummy linear model if PyG not available.
"""
import torch
import torch.nn.functional as F
try:
    from torch_geometric.nn import GCNConv

    TORCH_GEOMETRIC_AVAILABLE = True
except ImportError:
    # PyTorch Geometric is optional; FailureGNN degrades to a linear model.
    TORCH_GEOMETRIC_AVAILABLE = False


class FailureGNN(torch.nn.Module):
    """Node classifier for failure-propagation prediction.

    With PyTorch Geometric installed, the model is
    ``GCNConv -> ReLU -> dropout -> GCNConv``. Without it, the model falls
    back to a single ``torch.nn.Linear`` layer applied to each node's
    features independently, with no graph structure.

    Args:
        num_features: Number of input features per node.
        hidden: Width of the hidden GCN layer (unused by the linear fallback).
        num_classes: Number of output classes per node.
        dropout: Dropout probability applied between the two GCN layers
            while training (unused by the linear fallback). The default of
            0.5 matches the ``F.dropout`` default used previously.
    """

    def __init__(self, num_features=5, hidden=16, num_classes=2, dropout=0.5):
        super().__init__()
        self.num_features = num_features
        self.hidden = hidden
        self.num_classes = num_classes
        self.dropout = dropout
        if TORCH_GEOMETRIC_AVAILABLE:
            self.conv1 = GCNConv(num_features, hidden)
            self.conv2 = GCNConv(hidden, num_classes)
        else:
            # Fallback linear model (no graph structure)
            self.fc = torch.nn.Linear(num_features, num_classes)

    def forward(self, x, edge_index=None):
        """Return per-node class log-probabilities.

        Args:
            x: Node features of shape ``[num_nodes, num_features]``. The
                linear fallback also accepts extra leading batch dimensions.
            edge_index: Graph connectivity of shape ``[2, num_edges]``.
                Optional: if omitted, the GCN path runs with no edges, so
                each node aggregates only its own features through the
                self-loop that GCNConv adds. Ignored by the linear fallback.

        Returns:
            Log-softmax scores over the last dimension, shape
            ``[..., num_classes]``.
        """
        if not TORCH_GEOMETRIC_AVAILABLE:
            x = self.fc(x)  # ignore graph structure
            # dim=-1 keeps the class axis correct for batched inputs too.
            return F.log_softmax(x, dim=-1)

        if edge_index is None:
            # self.fc does not exist when PyG is available, so falling back
            # to it here would raise AttributeError. Instead, run the GCN
            # with an empty edge set, which leaves the parameters and the
            # state_dict unchanged.
            edge_index = torch.empty((2, 0), dtype=torch.long, device=x.device)

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=-1)