File size: 3,091 Bytes
c84b37e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import numpy as np
import torch
import torch_geometric
__all__ = ["GNNPolicy"]
class BipartiteGraphConvolution(torch_geometric.nn.MessagePassing):
    """
    Class Description:
    Skeleton of a bipartite graph semi-convolution layer, declared as a
    subclass of torch_geometric's MessagePassing.

    Every method here is a placeholder that raises NotImplementedError; the
    docstrings record what each method is meant to do once it is written.
    """

    def __init__(self):
        """
        Function Description:
        Intended to define the size of the encoding space and to build the
        semi-convolution layer and the output layer.

        Raises: NotImplementedError, unconditionally (placeholder).
        """
        reason = 'BipartiteGraphConvolution __init__ method should be implemented'
        raise NotImplementedError(reason)

    def forward(self, left_features, edge_indices, edge_features, right_features):
        """
        Function Description:
        Intended to run the semi-convolution forward pass over the given node
        and edge features.

        Parameters:
        - left_features: Features of the nodes on the left side of the bipartite graph.
        - edge_indices: Edge information.
        - edge_features: Edge features.
        - right_features: Features of the nodes on the right side of the bipartite graph.

        Return: Meant to be the result of the forward propagation; nothing is
        returned at present.

        Raises: NotImplementedError, unconditionally (placeholder).
        """
        reason = 'BipartiteGraphConvolution forward method should be implemented'
        raise NotImplementedError(reason)

    def message(self, node_features_i, node_features_j, edge_features):
        """
        Function Description:
        Intended to compute the per-edge messages exchanged during the
        semi-convolution.

        Parameters:
        - node_features_i: Described by the original author as the features of
          the nodes on the left side of the bipartite graph.
        - node_features_j: Described by the original author as the features of
          the nodes on the right side of the bipartite graph.
        - edge_features: Edge features.

        NOTE(review): torch_geometric conventionally binds `_i` arguments to
        the receiving node and `_j` arguments to the sending node of each
        edge, so which graph side each one holds depends on how forward()
        ends up calling propagate() -- confirm when implementing.

        Return: Meant to be the result of the message passing step; nothing is
        returned at present.

        Raises: NotImplementedError, unconditionally (placeholder).
        """
        reason = 'BipartiteGraphConvolution message method should be implemented'
        raise NotImplementedError(reason)
class GNNPolicy(torch.nn.Module):
    """
    Class Description:
    Skeleton of the full GNN network that is meant to be assembled from
    semi-convolutional layers.

    Both methods are placeholders that raise NotImplementedError; the
    docstrings record what each one is meant to do once it is written.
    """

    def __init__(self):
        """
        Function Description:
        Intended to define the size of the encoding space, the encoding layers
        for decision variables, edge features and constraint features, two
        semi-convolutional layers, and the final output layer.

        Raises: NotImplementedError, unconditionally (placeholder).
        """
        reason = 'GNNPolicy __init__ method should be implemented'
        raise NotImplementedError(reason)

    def forward(
        self,
        constraint_features,
        edge_indices,
        edge_features,
        variable_features,
    ):
        """
        Function Description:
        Intended to run the GNN forward pass over the given constraint, edge,
        and variable features.

        Parameters:
        - constraint_features: Features of the constraint points.
        - edge_indices: Edge information.
        - edge_features: Edge features.
        - variable_features: Features of the variable points.

        Return: Meant to be the result of the forward propagation; nothing is
        returned at present.

        Raises: NotImplementedError, unconditionally (placeholder).
        """
        reason = 'GNNPolicy forward method should be implemented'
        raise NotImplementedError(reason)
|