File size: 3,091 Bytes
c84b37e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import numpy as np
import torch
import torch_geometric

# Public API of this module: only GNNPolicy is exported by ``from <module> import *``;
# BipartiteGraphConvolution stays importable by name but is not re-exported.
__all__ = ["GNNPolicy"]

class BipartiteGraphConvolution(torch_geometric.nn.MessagePassing):
    """Half-convolution over a bipartite graph (placeholder).

    Intended to propagate information from the nodes on one side of a
    bipartite graph to the nodes on the other side, combining node and edge
    features via PyG's ``MessagePassing`` machinery. Every method below is
    still a stub: calling any of them raises ``NotImplementedError``.
    """

    def __init__(self):
        """Set the embedding size and build the half-convolution and output layers.

        Raises:
            NotImplementedError: always; the constructor has not been written yet.
        """
        # NOTE(review): a real implementation must call super().__init__(...)
        # (MessagePassing, e.g. with its aggregation choice) before assigning
        # any torch submodules as attributes.
        pending = 'BipartiteGraphConvolution __init__ method should be implemented'
        raise NotImplementedError(pending)

    def forward(self, left_features, edge_indices, edge_features, right_features):
        """Run one half-convolution and return the resulting node representations.

        Args:
            left_features: features of the nodes on the left side of the
                bipartite graph.
            edge_indices: edge connectivity; presumably a ``2 x num_edges``
                index tensor as is usual in PyG — TODO confirm with callers.
            edge_features: per-edge features.
            right_features: features of the nodes on the right side of the
                bipartite graph.

        Returns:
            The output of the forward pass after the half-convolution.

        Raises:
            NotImplementedError: always; the forward pass has not been written yet.
        """
        pending = 'BipartiteGraphConvolution forward method should be implemented'
        raise NotImplementedError(pending)

    def message(self, node_features_i, node_features_j, edge_features):
        """Build the message carried along each edge during propagation.

        PyG invokes this hook from ``propagate``. Under PyG's default
        ``flow="source_to_target"``, arguments suffixed ``_i`` hold features of
        the receiving (target) node of each edge and ``_j`` those of the
        sending (source) node.

        NOTE(review): ``node_features_i`` has been described as the left side
        and ``node_features_j`` as the right side. With the default flow that
        only holds if ``forward`` passes the right side as the source of
        ``propagate`` — confirm once ``forward`` is implemented.

        Args:
            node_features_i: per-edge features of the target-side nodes.
            node_features_j: per-edge features of the source-side nodes.
            edge_features: per-edge features.

        Returns:
            The per-edge messages of the half-convolution.

        Raises:
            NotImplementedError: always; message construction has not been
                written yet.
        """
        pending = 'BipartiteGraphConvolution message method should be implemented'
        raise NotImplementedError(pending)


class GNNPolicy(torch.nn.Module):
    """Full GNN policy network built from bipartite half-convolutions (placeholder).

    Intended to embed constraint nodes, decision-variable nodes and edges,
    exchange information between the two node sets with two half-convolution
    layers, and map the result through a final output layer. Both methods are
    still stubs that raise ``NotImplementedError``.
    """

    def __init__(self):
        """Set the embedding size and build every sub-network.

        Planned components: encoders for decision-variable, edge and
        constraint features; two half-convolution layers; a final output layer.

        Raises:
            NotImplementedError: always; the constructor has not been written yet.
        """
        # NOTE(review): a real implementation must call super().__init__()
        # before assigning any torch submodules as attributes.
        pending = 'GNNPolicy __init__ method should be implemented'
        raise NotImplementedError(pending)

    def forward(
        self, constraint_features, edge_indices, edge_features, variable_features
    ):
        """Run the whole GNN on one bipartite constraint/variable graph.

        Args:
            constraint_features: features of the constraint nodes.
            edge_indices: edge connectivity between constraints and variables;
                presumably a ``2 x num_edges`` index tensor — TODO confirm.
            edge_features: per-edge features.
            variable_features: features of the decision-variable nodes.

        Returns:
            The output of the forward pass through the GNN.

        Raises:
            NotImplementedError: always; the forward pass has not been written yet.
        """
        pending = 'GNNPolicy forward method should be implemented'
        raise NotImplementedError(pending)