| import numpy as np | |
| import torch | |
| import torch_geometric | |
# Names exported by a star-import of this module: only GNNPolicy.
# BipartiteGraphConvolution is not listed, so star-imports omit it
# (it remains importable by its explicit name).
__all__ = ["GNNPolicy"]
class BipartiteGraphConvolution(torch_geometric.nn.MessagePassing):
    """
    Class Description:
    Bipartite graph half-convolution. Messages flow from the nodes on the left
    side of the bipartite graph (indexed by row 0 of ``edge_indices``) to the
    nodes on the right side (indexed by row 1). The messages are summed per
    right-side node and then combined with that node's current embedding to
    produce its updated embedding.
    """

    def __init__(self, *, emb_size=64, edge_nfeats=1):
        '''
        Function Description:
        Define the size of the encoding space, and build the half-convolution
        (message) layers and the output layer.
        Parameters:
        - emb_size: Width of the node embeddings on both sides. The node
          features passed to forward must have this width, and the output
          does too.
        - edge_nfeats: Number of features per edge.
        '''
        # Messages arriving at the same right-side node are summed.
        super().__init__(aggr="add")

        # The three projections below are summed, so one bias is enough.
        self.feature_module_left = torch.nn.Linear(emb_size, emb_size, bias=False)
        self.feature_module_edge = torch.nn.Linear(edge_nfeats, emb_size, bias=False)
        self.feature_module_right = torch.nn.Linear(emb_size, emb_size)
        self.feature_module_final = torch.nn.Sequential(
            torch.nn.LayerNorm(emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, emb_size),
        )

        # Normalize the aggregated messages before merging them with the
        # right-side node's current embedding.
        self.post_conv_module = torch.nn.LayerNorm(emb_size)

        # Output layer: [aggregated messages, current embedding] -> new embedding.
        self.output_module = torch.nn.Sequential(
            torch.nn.Linear(2 * emb_size, emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, emb_size),
        )

    def forward(self, left_features, edge_indices, edge_features, right_features):
        '''
        Function Description:
        Run one half-convolution from the left-side nodes to the right-side nodes.
        Parameters:
        - left_features: Embeddings of the left-side nodes, shape (n_left, emb_size).
        - edge_indices: Edge index tensor of shape (2, n_edges). Row 0 holds
          left-side node indices and row 1 holds right-side node indices.
        - edge_features: Edge features, shape (n_edges, edge_nfeats).
        - right_features: Embeddings of the right-side nodes, shape (n_right, emb_size).
        Return: Updated right-side embeddings, shape (n_right, emb_size).
        '''
        # size=(n_left, n_right) tells PyG the graph is bipartite, so the
        # messages are aggregated into n_right output rows.
        aggregated = self.propagate(
            edge_indices,
            size=(left_features.shape[0], right_features.shape[0]),
            node_features=(left_features, right_features),
            edge_features=edge_features,
        )
        merged = torch.cat([self.post_conv_module(aggregated), right_features], dim=-1)
        return self.output_module(merged)

    def message(self, node_features_i, node_features_j, edge_features):
        '''
        Function Description:
        Compute the message sent along each edge from its left-side endpoint
        to its right-side endpoint. PyTorch Geometric fills the ``_i``/``_j``
        arguments by indexing the ``node_features`` pair given to propagate.
        With the default source-to-target flow, ``_j`` is each edge's source
        (left) node and ``_i`` its target (right) node.
        Parameters:
        - node_features_i: Features of the right-side (target) node of each edge.
        - node_features_j: Features of the left-side (source) node of each edge.
        - edge_features: Features of each edge.
        Return: One message of width emb_size per edge.
        '''
        return self.feature_module_final(
            self.feature_module_right(node_features_i)
            + self.feature_module_edge(edge_features)
            + self.feature_module_left(node_features_j)
        )


class GNNPolicy(torch.nn.Module):
    """
    Class Description:
    Graph neural network over the constraint/variable bipartite graph, built
    from two half-convolution layers.
    - Embed constraint, edge and variable features into a shared encoding space.
    - Run one half-convolution from variables to constraints, then one from
      constraints to variables.
    - Score every variable with a final MLP.
    """

    def __init__(self, *, emb_size=64, cons_nfeats=5, edge_nfeats=1, var_nfeats=19):
        '''
        Function Description:
        Define the size of the encoding space, and define the layers for
        decision variable encoding, edge feature encoding, and constraint
        feature encoding. Define two half-convolution layers and the final
        output layer.
        Parameters:
        - emb_size: Width of the shared node embedding space.
        - cons_nfeats: Number of input features per constraint node.
        - edge_nfeats: Number of input features per edge.
        - var_nfeats: Number of input features per variable node.
        '''
        super().__init__()
        # NOTE(review): the default feature counts (5 constraint, 1 edge,
        # 19 variable) follow the Ecole / Gasse et al. (2019) reference model.
        # Confirm they match the feature extractor that feeds this policy.

        # Constraint encoding.
        self.cons_embedding = torch.nn.Sequential(
            torch.nn.LayerNorm(cons_nfeats),
            torch.nn.Linear(cons_nfeats, emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, emb_size),
            torch.nn.ReLU(),
        )

        # Edge encoding. LayerNorm over a single feature maps every value to
        # its bias, because a value minus its own mean is always 0. That would
        # erase the edge coefficients, so it is only used for several features.
        self.edge_embedding = (
            torch.nn.LayerNorm(edge_nfeats) if edge_nfeats > 1 else torch.nn.Identity()
        )

        # Decision variable encoding.
        self.var_embedding = torch.nn.Sequential(
            torch.nn.LayerNorm(var_nfeats),
            torch.nn.Linear(var_nfeats, emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, emb_size),
            torch.nn.ReLU(),
        )

        # The two half-convolution layers.
        self.conv_v_to_c = BipartiteGraphConvolution(
            emb_size=emb_size, edge_nfeats=edge_nfeats
        )
        self.conv_c_to_v = BipartiteGraphConvolution(
            emb_size=emb_size, edge_nfeats=edge_nfeats
        )

        # Final output layer: one score per variable.
        self.output_module = torch.nn.Sequential(
            torch.nn.Linear(emb_size, emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, 1, bias=False),
        )

    def forward(
        self, constraint_features, edge_indices, edge_features, variable_features
    ):
        '''
        Function Description:
        Based on the given constraint, edge, and variable features, output the
        result of forward propagation through the GNN.
        Parameters:
        - constraint_features: Features of the constraint nodes, shape (n_cons, cons_nfeats).
        - edge_indices: Edge index tensor of shape (2, n_edges).
        - edge_features: Edge features, shape (n_edges, edge_nfeats). A 1-D
          tensor of shape (n_edges,) is also accepted and treated as a single
          feature per edge.
        - variable_features: Features of the variable nodes, shape (n_vars, var_nfeats).
        Return: One unnormalized score per variable node, shape (n_vars,).
        '''
        # NOTE(review): assumes edge_indices[0] holds constraint indices and
        # edge_indices[1] holds variable indices (Ecole convention). Confirm
        # against the code that builds the graph.
        reversed_edge_indices = edge_indices.flip(0)
        if edge_features.dim() == 1:
            edge_features = edge_features.unsqueeze(-1)

        # Project every node/edge type into the shared encoding space.
        constraint_features = self.cons_embedding(constraint_features)
        edge_features = self.edge_embedding(edge_features)
        variable_features = self.var_embedding(variable_features)

        # Variables -> constraints: variables are the left (source) side here.
        constraint_features = self.conv_v_to_c(
            variable_features, reversed_edge_indices, edge_features, constraint_features
        )
        # Constraints -> variables, using the freshly updated constraint embeddings.
        variable_features = self.conv_c_to_v(
            constraint_features, edge_indices, edge_features, variable_features
        )

        return self.output_module(variable_features).squeeze(-1)