|
|
import argparse |
|
|
import pickle |
|
|
from pathlib import Path |
|
|
from typing import Union |
|
|
|
|
|
import os |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import torch_geometric |
|
|
from pytorch_metric_learning import losses |
|
|
|
|
|
from model.graphcnn import GNNPolicy |
|
|
|
|
|
class BipartiteNodeData(torch_geometric.data.Data):
    """Bipartite constraint/variable graph sample for pytorch geometric.

    Wraps a node bipartite observation (as returned by the
    `ecole.observation.NodeBipartite` observation function) in a format
    understood by the pytorch geometric data handlers, so samples can be
    batched by the standard loaders.
    """

    def __init__(
        self,
        constraint_features,
        edge_indices,
        edge_features,
        variable_features,
        assignment1,
        assignment2
    ):
        super().__init__()
        # One feature matrix per node set, the connecting edges, and two
        # reference assignments attached to this sample.
        self.constraint_features = constraint_features
        self.edge_index = edge_indices
        self.edge_attr = edge_features
        self.variable_features = variable_features
        self.assignment1 = assignment1
        self.assignment2 = assignment2

    def __inc__(self, key, value, store, *args, **kwargs):
        """Tell pytorch geometric how to increment indices when batching.

        Edge indices must be offset by the node count of the side they
        refer to (constraints for row 0, variables for row 1); candidate
        indices are offset by the variable count.  Every other entry
        falls back to the default increment.
        """
        if key == "candidates":
            return self.variable_features.size(0)
        if key == "edge_index":
            n_constraints = self.constraint_features.size(0)
            n_variables = self.variable_features.size(0)
            return torch.tensor([[n_constraints], [n_variables]])
        return super().__inc__(key, value, *args, **kwargs)
|
|
|
|
|
|
|
|
class GraphDataset(torch_geometric.data.Dataset):
    """A disk-backed collection of bipartite graph samples.

    Each entry is a pickle file produced during data collection; samples
    are loaded lazily on access, so the dataset can be consumed by the
    data loaders provided by pytorch geometric.
    """

    def __init__(self, sample_files):
        super().__init__(root=None, transform=None, pre_transform=None)
        # Paths of the pickled samples, one file per graph.
        self.sample_files = sample_files

    def len(self):
        # Dataset size equals the number of sample files.
        return len(self.sample_files)

    def get(self, index):
        """Load the `index`-th node bipartite graph observation from disk."""
        with open(self.sample_files[index], "rb") as f:
            (variable_features, constraint_features, edge_indices,
             edge_features, solution1, solution2) = pickle.load(f)

        graph = BipartiteNodeData(
            torch.FloatTensor(constraint_features),
            torch.LongTensor(edge_indices),
            torch.FloatTensor(edge_features),
            torch.FloatTensor(variable_features),
            torch.LongTensor(solution1),
            torch.LongTensor(solution2),
        )

        # Record node counts: total, constraint side, and variable side.
        n_constraints = len(constraint_features)
        n_variables = len(variable_features)
        graph.num_nodes = n_constraints + n_variables
        graph.cons_nodes = n_constraints
        graph.vars_nodes = n_variables

        return graph
|
|
|
|
|
def make(number: int,
         model_path : str,
         device : Union[str, torch.device] = torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """
    Function Description:
    Obtain the encoding information of the decision variables based on the input problem data and package the output.

    Parameters:
        number: how many `./example/pair<i>.pickle` input files to process.
        model_path: path of the trained GNNPolicy state dict on disk.
        device: device to run the policy on (a string such as "cpu"/"cuda"
            or a torch.device); defaults to CUDA when available.

    Side effects:
        For each sample i, writes `./example/node<i>.pickle` containing
        the policy logits and the fifth field of `./example/sample<i>.pickle`.
        Prints progress to stdout; returns early if any input file is missing.
    """
    policy = GNNPolicy().to(device)
    # BUG FIX: torch.load's second positional argument is `map_location`.
    # The original code passed `policy.state_dict()` there, which is not a
    # valid map_location; load the checkpoint onto the requested device.
    policy.load_state_dict(torch.load(model_path, map_location=device))

    sample_files = []
    for num in range(number):
        # Use one consistent path form (the original mixed './example/...'
        # and 'example/...'; both resolve to the same file relative to cwd).
        pair_path = './example/pair' + str(num) + '.pickle'
        if not os.path.exists(pair_path):
            print("No input file!")
            return
        sample_files.append(pair_path)

    data = GraphDataset(sample_files)
    loader = torch_geometric.loader.DataLoader(data, batch_size = 1)

    now_site = 0
    for batch in loader:
        batch = batch.to(device)

        # Forward pass: per-variable logits for this bipartite graph.
        logits = policy(
            batch.constraint_features,
            batch.edge_index,
            batch.edge_attr,
            batch.variable_features,
        )
        print(logits)
        # NOTE(review): pickle.load is only safe on trusted, locally
        # generated files — do not point this at untrusted data.
        with open('./example/sample' + str(now_site) + '.pickle', "rb") as f:
            solution = pickle.load(f)
        # Package the logits together with the sample's fifth field
        # (presumably a solution/label entry — confirm against the writer).
        with open('./example/node' + str(now_site) + '.pickle', 'wb') as f:
            pickle.dump([logits.tolist(), solution[4]], f)
        print(now_site)
        now_site += 1
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line options for the encoding script.

    Parameters:
        argv: optional list of argument strings to parse; when None,
            argparse falls back to ``sys.argv[1:]`` (its standard
            behaviour), so existing callers are unaffected while the
            parser becomes callable programmatically.

    Returns:
        argparse.Namespace with ``number`` (int, default 30),
        ``model_path`` (str, default "trained_model.pkl") and
        ``device`` (str, "cuda" when available else "cpu").
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--number", type=int, default=30)
    parser.add_argument("--model_path", type=str, default="trained_model.pkl")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use for training.")
    return parser.parse_args(argv)
|
|
|
|
|
# Script entry point: parse CLI options and run the encoding pass.
if __name__ == "__main__":


    args = parse_args()


    # The Namespace fields (number, model_path, device) match make()'s
    # parameter names, so the namespace can be unpacked directly.
    make(**vars(args))