File size: 4,968 Bytes
c84b37e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import argparse
import pickle
from pathlib import Path
from typing import Union
import os
import torch
import torch.nn.functional as F
import torch_geometric
from pytorch_metric_learning import losses
from model.graphcnn import GNNPolicy
class BipartiteNodeData(torch_geometric.data.Data):
    """
    Class Description:
    A bipartite constraint/variable graph observation (as produced by
    `ecole.observation.NodeBipartite`) packaged so that the pytorch geometric
    data handlers can batch it.
    """

    def __init__(
        self,
        constraint_features,
        edge_indices,
        edge_features,
        variable_features,
        assignment1,
        assignment2
    ):
        super().__init__()
        # Constraint-side and variable-side node features of the bipartite graph.
        self.constraint_features = constraint_features
        self.variable_features = variable_features
        # Edge connectivity and per-edge attributes (PyG canonical attribute names).
        self.edge_index = edge_indices
        self.edge_attr = edge_features
        # Two stored solutions/assignments attached to the observation.
        self.assignment1 = assignment1
        self.assignment2 = assignment2

    def __inc__(self, key, value, store, *args, **kwargs):
        """
        Function Description:
        Override of the pytorch geometric increment hook: when graphs are
        concatenated into a batch, tells PyG by how much to shift the indices
        stored under `key` (edge indices and candidate indices are the
        non-obvious cases).
        """
        if key == "candidates":
            # Candidate indices refer to variable nodes only.
            return self.variable_features.size(0)
        if key == "edge_index":
            # Row 0 indexes constraint nodes, row 1 indexes variable nodes,
            # so each row shifts by its own node count.
            num_cons = self.constraint_features.size(0)
            num_vars = self.variable_features.size(0)
            return torch.tensor([[num_cons], [num_vars]])
        return super().__inc__(key, value, *args, **kwargs)
class GraphDataset(torch_geometric.data.Dataset):
    """
    Class Description:
    A disk-backed collection of bipartite graph observations. Each sample is a
    pickle file; graphs are loaded lazily in `get`, which makes this class
    directly usable with the pytorch geometric data loaders.
    """

    def __init__(self, sample_files):
        super().__init__(root=None, transform=None, pre_transform=None)
        # List of paths to pickled samples; one graph per file.
        self.sample_files = sample_files

    def len(self):
        # One dataset item per sample file.
        return len(self.sample_files)

    def get(self, index):
        """
        Function Description:
        Load the `index`-th pickled observation from disk and wrap it in a
        BipartiteNodeData graph.
        """
        with open(self.sample_files[index], "rb") as f:
            (variable_features,
             constraint_features,
             edge_indices,
             edge_features,
             solution1,
             solution2) = pickle.load(f)

        graph = BipartiteNodeData(
            torch.FloatTensor(constraint_features),
            torch.LongTensor(edge_indices),
            torch.FloatTensor(edge_features),
            torch.FloatTensor(variable_features),
            torch.LongTensor(solution1),
            torch.LongTensor(solution2),
        )
        # PyG needs the total node count for batching/indexing purposes;
        # also record the per-side counts for downstream use.
        graph.cons_nodes = len(constraint_features)
        graph.vars_nodes = len(variable_features)
        graph.num_nodes = graph.cons_nodes + graph.vars_nodes
        return graph
def make(number: int,
         model_path : str,
         device : Union[str, torch.device] = torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """
    Function Description:
    Run the trained GNN policy over `number` pickled samples
    ('./example/pair<i>.pickle'), printing the per-variable logits and writing
    ['./example/node<i>.pickle'] files containing [logits, solution[4]] pairs
    read from the matching './example/sample<i>.pickle' files.

    :param number: how many consecutive sample files to process (0..number-1).
    :param model_path: path to the saved GNNPolicy state_dict checkpoint.
    :param device: device to run inference on; accepts a string (as passed by
        parse_args) or a torch.device.
    """
    # Normalize so a string from argparse and a torch.device both work.
    device = torch.device(device)
    policy = GNNPolicy().to(device)
    # BUG FIX: the second positional argument of torch.load is `map_location`;
    # passing a state_dict there was meaningless and broke loading GPU-saved
    # checkpoints on CPU-only hosts.
    policy.load_state_dict(torch.load(model_path, map_location=device))
    policy.eval()  # inference mode: freeze dropout/batch-norm statistics

    sample_paths = []
    for num in range(number):
        path = './example/pair' + str(num) + '.pickle'
        if not os.path.exists(path):
            print("No input file!")
            return
        sample_paths.append(path)

    data = GraphDataset(sample_paths)
    loader = torch_geometric.loader.DataLoader(data, batch_size = 1)
    now_site = 0
    for batch in loader:
        batch = batch.to(device)
        # Compute the logits (i.e. pre-softmax activations) according to the policy on the concatenated graphs
        with torch.no_grad():  # no gradients needed for export
            logits = policy(
                batch.constraint_features,
                batch.edge_index,
                batch.edge_attr,
                batch.variable_features,
            )
        print(logits)
        with open('./example/sample' + str(now_site) + '.pickle', "rb") as f:
            solution = pickle.load(f)
        # NOTE(review): solution[4] — presumably the label/solution entry of the
        # sample tuple; confirm against the data-collection writer.
        with open('./example/node' + str(now_site) + '.pickle', 'wb') as f:
            pickle.dump([logits.tolist(), solution[4]], f)
        print(now_site)
        now_site += 1
def parse_args(argv=None):
    """
    Function Description:
    Build and parse the command-line options for this script.

    :param argv: optional list of argument strings; defaults to None, in which
        case argparse reads sys.argv (backward compatible with the original
        zero-argument call). Accepting an explicit list makes the function
        testable and usable programmatically.
    :return: argparse.Namespace with `number`, `model_path`, and `device`.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--number", type=int, default=30)
    parser.add_argument("--model_path", type=str, default="trained_model.pkl")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use for training.")
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: parse the CLI options and run the export pipeline.
    make(**vars(parse_args()))