File size: 4,968 Bytes
c84b37e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import argparse
import pickle
from pathlib import Path
from typing import Union

import os
import torch
import torch.nn.functional as F
import torch_geometric
from pytorch_metric_learning import losses

from model.graphcnn import GNNPolicy

class BipartiteNodeData(torch_geometric.data.Data):
    """
    Class Description:
    Container for one bipartite (constraint / variable) graph, stored in a form the pytorch geometric
    data handlers can batch. The layout follows the `ecole.observation.NodeBipartite` observation
    (presumably the source of the features; the collection code is not part of this file).

    Attributes set by the constructor:
        constraint_features: features of the constraint nodes (first dimension = number of constraints).
        edge_index: constraint/variable index pairs; during batching row 0 is offset by the number of
            constraints and row 1 by the number of variables (see `__inc__`).
        edge_attr: features attached to the edges.
        variable_features: features of the variable nodes (first dimension = number of variables).
        assignment1, assignment2: the two assignment tensors stored alongside the graph.
    """

    def __init__(
        self,
        constraint_features,
        edge_indices,
        edge_features,
        variable_features,
        assignment1,
        assignment2
    ):
        super().__init__()
        self.constraint_features = constraint_features
        self.edge_index = edge_indices
        self.edge_attr = edge_features
        self.variable_features = variable_features
        self.assignment1 = assignment1
        self.assignment2 = assignment2

    def __inc__(self, key, value, store, *args, **kwargs):
        """
        Function Description:
        Tell pytorch geometric by how much the entries of attribute `key` must be shifted when this
        graph is concatenated after others in a batch.

        Returns:
            For "edge_index": a 2x1 tensor holding (number of constraints, number of variables), so
            each row of the index is shifted by the size of its own node set.
            For "candidates": the number of variables (this class never sets such an attribute itself;
            the branch only matters if one is attached externally).
            For anything else: whatever the parent class decides.
        """
        if key == "edge_index":
            return torch.tensor(
                [[self.constraint_features.size(0)], [self.variable_features.size(0)]]
            )
        if key == "candidates":
            return self.variable_features.size(0)
        # Pass `store` through as well: the parent must see the same arguments the
        # batching code handed to us, not a truncated list.
        return super().__inc__(key, value, store, *args, **kwargs)


class GraphDataset(torch_geometric.data.Dataset):
    """
    Class Description:
    Dataset backed by a list of pickle files, one bipartite graph per file. Each item is read from
    disk on demand, which lets the pytorch geometric data loaders iterate over it directly.
    """

    def __init__(self, sample_files):
        # No root directory and no transforms: samples are read straight from the given paths.
        super().__init__(root=None, transform=None, pre_transform=None)
        self.sample_files = sample_files

    def len(self):
        """Number of sample files in the dataset."""
        return len(self.sample_files)

    def get(self, index):
        """
        Function Description:
        Read the pickle at position `index` and wrap its contents in a `BipartiteNodeData` graph.
        The pickle must hold exactly six entries, in this order: variable features, constraint
        features, edge indices, edge features, first solution, second solution.
        """
        # NOTE: pickle executes arbitrary code on load; only point this at trusted sample files.
        with open(self.sample_files[index], "rb") as handle:
            var_feats, cons_feats, edge_idx, edge_feats, first_sol, second_sol = pickle.load(handle)

        sample = BipartiteNodeData(
            torch.FloatTensor(cons_feats),
            torch.LongTensor(edge_idx),
            torch.FloatTensor(edge_feats),
            torch.FloatTensor(var_feats),
            torch.LongTensor(first_sol),
            torch.LongTensor(second_sol),
        )

        # pytorch geometric cannot infer node counts from custom attributes, so record them explicitly
        n_cons = len(cons_feats)
        n_vars = len(var_feats)
        sample.num_nodes = n_cons + n_vars
        sample.cons_nodes = n_cons
        sample.vars_nodes = n_vars

        return sample

def make(number: int,
         model_path : str,
         device : torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """
    Function Description:
    Obtain the encoding information of the decision variables based on the input problem data and package the output.

    For every i in [0, number) the function reads `example/pair{i}.pickle`, runs the trained policy on the
    graph, and writes `[policy output as nested lists, element 4 of example/sample{i}.pickle]` to
    `example/node{i}.pickle`.

    Parameters:
        number: how many `pair{i}` / `sample{i}` files to process, numbered from 0.
        model_path: path of the saved policy weights (a state dict readable by `torch.load`).
        device: device (or device string) on which the policy is evaluated.

    Returns:
        None. If any `pair{i}.pickle` is missing, prints "No input file!" and returns before doing any work.
    """
    policy = GNNPolicy().to(device)
    # The second positional argument of torch.load is `map_location`; remap the saved tensors onto
    # the requested device so a checkpoint written on GPU can also be loaded on a CPU-only machine.
    policy.load_state_dict(torch.load(model_path, map_location=device))
    # Inference only: switch off training-time behaviour of the layers.
    policy.eval()

    sample_files = []
    for num in range(number):
        pair_path = 'example/pair' + str(num) + '.pickle'
        if not os.path.exists('./' + pair_path):
            print("No input file!")
            return
        sample_files.append(pair_path)

    data = GraphDataset(sample_files)
    # batch_size = 1 without shuffling keeps loader position i aligned with file index i.
    loader = torch_geometric.loader.DataLoader(data, batch_size = 1)

    # No gradients are needed; skipping the autograd graph saves memory.
    with torch.no_grad():
        for now_site, batch in enumerate(loader):
            batch = batch.to(device)
            # Compute the logits (i.e. pre-softmax activations) according to the policy on the concatenated graphs
            logits = policy(
                batch.constraint_features,
                batch.edge_index,
                batch.edge_attr,
                batch.variable_features,
            )
            print(logits)
            # NOTE: pickle executes arbitrary code on load; only use trusted sample files.
            with open('./example/sample' + str(now_site) + '.pickle', "rb") as f:
                solution = pickle.load(f)
            # Element 4 of the sample pickle is stored next to the logits; its meaning is defined
            # by the data-collection code, which is not part of this file.
            with open('./example/node' + str(now_site) + '.pickle', 'wb') as f:
                pickle.dump([logits.tolist(), solution[4]], f)
            print(now_site)

def parse_args():
    """
    Function Description:
    Build the command-line parser and parse the process arguments.
    Options: --number (int, default 30), --model_path (str, default "trained_model.pkl") and
    --device (str, default "cuda" when CUDA is available, otherwise "cpu").
    """
    fallback_device = "cuda" if torch.cuda.is_available() else "cpu"
    option_table = (
        ("--number", {"type": int, "default": 30}),
        ("--model_path", {"type": str, "default": "trained_model.pkl"}),
        ("--device", {"type": str, "default": fallback_device, "help": "Device to use for training."}),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

if __name__ == "__main__":
    # Script entry point: every parsed command-line option is forwarded to make() by keyword.
    make(**vars(parse_args()))