# THU-IAR's picture
# Upload 28 files
# c84b37e
import argparse
import pickle
from pathlib import Path
from typing import Union
import os
import torch
import torch.nn.functional as F
import torch_geometric
from pytorch_metric_learning import losses
from model.graphcnn import GNNPolicy
class BipartiteNodeData(torch_geometric.data.Data):
    """
    Class Description:
    Container for one bipartite (constraint / variable) graph sample, stored as a
    pytorch geometric `Data` object so that the pytorch geometric loaders can batch it.
    Besides the graph tensors it carries two assignment tensors (`assignment1`,
    `assignment2`) supplied by the caller.
    """

    def __init__(
        self,
        constraint_features,
        edge_indices,
        edge_features,
        variable_features,
        assignment1,
        assignment2
    ):
        super().__init__()
        # Attribute order is kept as-is; edge data is stored under the names
        # pytorch geometric conventionally uses (`edge_index`, `edge_attr`).
        self.constraint_features = constraint_features
        self.edge_index = edge_indices
        self.edge_attr = edge_features
        self.variable_features = variable_features
        self.assignment1 = assignment1
        self.assignment2 = assignment2

    def __inc__(self, key, value, store, *args, **kwargs):
        """
        Function Description:
        Return the offset that must be added to the entry `key` when several graphs
        are concatenated into one batch. Only `edge_index` and `candidates` get a
        custom offset; every other key is delegated to the parent class.
        """
        if key == "edge_index":
            # Row 0 is offset by the number of constraint rows, row 1 by the number
            # of variable rows (the two sides of the graph are numbered separately).
            num_constraints = self.constraint_features.size(0)
            num_variables = self.variable_features.size(0)
            return torch.tensor([[num_constraints], [num_variables]])

        if key == "candidates":
            return self.variable_features.size(0)

        return super().__inc__(key, value, *args, **kwargs)
class GraphDataset(torch_geometric.data.Dataset):
    """
    Class Description:
    Dataset over a list of pickled graph samples. Each item is read from disk on
    demand and converted to a `BipartiteNodeData`, so it can be consumed by the
    data loaders provided by pytorch geometric.
    """

    def __init__(self, sample_files):
        super().__init__(root=None, transform=None, pre_transform=None)
        self.sample_files = sample_files

    def len(self):
        """
        Function Description:
        Number of sample files held by the dataset.
        """
        return len(self.sample_files)

    def get(self, index):
        """
        Function Description:
        Read the pickle at position `index` and build the corresponding graph object.
        The pickle is expected to hold six entries, in this order: variable features,
        constraint features, edge indices, edge features, solution1, solution2.
        NOTE(review): `pickle.load` executes arbitrary code; only use trusted sample files.
        """
        sample_path = self.sample_files[index]
        with open(sample_path, "rb") as handle:
            (
                variable_features,
                constraint_features,
                edge_indices,
                edge_features,
                solution1,
                solution2,
            ) = pickle.load(handle)

        n_constraints = len(constraint_features)
        n_variables = len(variable_features)

        sample = BipartiteNodeData(
            constraint_features=torch.FloatTensor(constraint_features),
            edge_indices=torch.LongTensor(edge_indices),
            edge_features=torch.FloatTensor(edge_features),
            variable_features=torch.FloatTensor(variable_features),
            assignment1=torch.LongTensor(solution1),
            assignment2=torch.LongTensor(solution2),
        )

        # pytorch geometric needs the total node count for indexing; the
        # per-side counts are stored alongside it.
        sample.num_nodes = n_constraints + n_variables
        sample.cons_nodes = n_constraints
        sample.vars_nodes = n_variables
        return sample
def make(number: int,
         model_path: str,
         device: Union[str, torch.device] = torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """
    Function Description:
    Run the trained GNN policy on the problem instances `example/pair<i>.pickle`
    (i = 0 .. number-1) and, for each instance, write `example/node<i>.pickle`
    containing `[logits as nested list, entry 4 of example/sample<i>.pickle]`.

    Parameters:
    - number: How many instances to process, numbered from 0.
    - model_path: Path of the saved state dict of the policy network.
    - device: Device (string or `torch.device`) on which the model and data are placed.

    Return:
    None. If any `pair<i>.pickle` is missing, "No input file!" is printed and the
    function returns before running the model on anything.
    """
    policy = GNNPolicy().to(device)
    # The second positional argument of torch.load is `map_location`; passing the
    # target device lets a checkpoint saved on GPU be loaded on a CPU-only host.
    policy.load_state_dict(torch.load(model_path, map_location=device))
    # Inference only: switch off training-mode behaviour of the layers.
    policy.eval()

    pair_files = []
    for num in range(number):
        pair_file = f"./example/pair{num}.pickle"
        if not os.path.exists(pair_file):
            print("No input file!")
            return
        pair_files.append(pair_file)

    data = GraphDataset(pair_files)
    # batch_size=1 without shuffling, so the batch position equals the instance number.
    loader = torch_geometric.loader.DataLoader(data, batch_size=1)

    # No gradients are needed to export the encodings.
    with torch.no_grad():
        for now_site, batch in enumerate(loader):
            batch = batch.to(device)
            # Compute the logits (i.e. pre-softmax activations) according to the policy
            logits = policy(
                batch.constraint_features,
                batch.edge_index,
                batch.edge_attr,
                batch.variable_features,
            )
            print(logits)

            # NOTE(review): pickle.load executes arbitrary code; only use trusted files.
            with open(f"./example/sample{now_site}.pickle", "rb") as f:
                solution = pickle.load(f)
            with open(f"./example/node{now_site}.pickle", "wb") as f:
                pickle.dump([logits.tolist(), solution[4]], f)
            print(now_site)
def parse_args():
    """
    Function Description:
    Parse the command line options `--number`, `--model_path` and `--device`
    and return them as an `argparse.Namespace`.
    """
    # Fall back to the CPU when no CUDA device is available.
    fallback_device = "cuda" if torch.cuda.is_available() else "cpu"

    cli = argparse.ArgumentParser()
    cli.add_argument("--number", default=30, type=int)
    cli.add_argument("--model_path", default="trained_model.pkl", type=str)
    cli.add_argument(
        "--device",
        default=fallback_device,
        type=str,
        help="Device to use for training.",
    )
    return cli.parse_args()
if __name__ == "__main__":
    # Script entry point: forward the parsed command line options to make().
    cli_args = parse_args()
    make(
        number=cli_args.number,
        model_path=cli_args.model_path,
        device=cli_args.device,
    )