# Provenance (was raw HuggingFace page residue, which broke the file's syntax):
# THU-IAR — "Upload 28 files" — commit c84b37e
import argparse
import pickle
from pathlib import Path
from typing import Union
import torch
import torch.nn.functional as F
import torch_geometric
from pytorch_metric_learning import losses
from model.graphcnn import GNNPolicy
__all__ = ["train"]
class BipartiteNodeData(torch_geometric.data.Data):
    """
    Class Description:
    This class encodes a node bipartite graph observation as returned by the `ecole.observation.NodeBipartite`
    observation function in a format understood by the pytorch geometric data handlers.
    """
    def __init__(
        self,
        constraint_features,   # float tensor, one row per constraint node
        edge_indices,          # long tensor (2, num_edges): constraint row -> variable col
        edge_features,         # float tensor, one row per edge
        variable_features,     # float tensor, one row per variable node
        assignment1,           # long tensor of labels (graph-partition classes)
        assignment2            # long tensor of labels (solution values)
    ):
        super().__init__()
        self.constraint_features = constraint_features
        self.edge_index = edge_indices
        self.edge_attr = edge_features
        self.variable_features = variable_features
        self.assignment1 = assignment1
        self.assignment2 = assignment2
    def __inc__(self, key, value, store, *args, **kwargs):
        """
        Function Description:
        Overload the pytorch geometric method that tells how to increment indices when concatenating graphs for those entries (edge index, candidates) for which this is not obvious.
        """
        if key == "edge_index":
            # Row 0 indexes constraint nodes, row 1 indexes variable nodes, so
            # each side is shifted by its own node count when batching.
            return torch.tensor(
                [[self.constraint_features.size(0)], [self.variable_features.size(0)]]
            )
        elif key == "candidates":
            return self.variable_features.size(0)
        else:
            # Bug fix: forward `store` to the parent implementation instead of
            # silently dropping it (the original passed only key/value onward).
            return super().__inc__(key, value, store, *args, **kwargs)
class GraphDataset(torch_geometric.data.Dataset):
    """
    Class Description:
    This class encodes a collection of graphs, as well as a method to load such graphs from the disk.
    It can be used in turn by the data loaders provided by pytorch geometric.
    """
    def __init__(self, sample_files):
        super().__init__(root=None, transform=None, pre_transform=None)
        # One pickle file per graph sample.
        self.sample_files = sample_files
    def len(self):
        """Return the number of samples this dataset exposes."""
        return len(self.sample_files)
    def get(self, index):
        """
        Function Description:
        This method loads a node bipartite graph observation as saved on the disk during data collection.
        """
        with open(self.sample_files[index], "rb") as sample_file:
            (
                variable_features,
                constraint_features,
                edge_indices,
                edge_features,
                solution1,
                solution2,
            ) = pickle.load(sample_file)
        graph = BipartiteNodeData(
            constraint_features=torch.FloatTensor(constraint_features),
            edge_indices=torch.LongTensor(edge_indices),
            edge_features=torch.FloatTensor(edge_features),
            variable_features=torch.FloatTensor(variable_features),
            assignment1=torch.LongTensor(solution1),
            assignment2=torch.LongTensor(solution2),
        )
        # We must tell pytorch geometric how many nodes there are, for indexing purposes
        graph.cons_nodes = len(constraint_features)
        graph.vars_nodes = len(variable_features)
        graph.num_nodes = graph.cons_nodes + graph.vars_nodes
        return graph
def pad_tensor(input_, pad_sizes, pad_value=-1e8):
    """
    Function Description:
    This utility function splits a tensor and pads each split to make them all the same size, then stacks them.
    Parameters:
    - input_: 1-D tensor holding all splits concatenated end to end.
    - pad_sizes: 1-D integer tensor; pad_sizes[i] is the length of split i.
    - pad_value: Scalar written into the padded tail positions.
    Returns:
    - Tensor of shape (len(pad_sizes), max(pad_sizes)).
    """
    # Idiom fix: Tensor.tolist() replaces the roundabout .cpu().numpy().tolist()
    # (it already moves data to host), and the maximum is converted to a plain
    # int once instead of remaining a 0-d tensor used in arithmetic below.
    max_pad_size = int(pad_sizes.max())
    chunks = input_.split(pad_sizes.tolist())
    return torch.stack(
        [
            F.pad(chunk, (0, max_pad_size - chunk.size(0)), "constant", pad_value)
            for chunk in chunks
        ],
        dim=0,
    )
def process(policy, data_loader, device, optimizer=None):
    """
    Function Description:
    This function will process a whole epoch of training or validation, depending on whether an optimizer is provided.
    Parameters:
    - policy: GNN producing 16-dim embeddings ("logits") for each batch.
    - data_loader: Loader yielding BipartiteNodeData batches.
    - device: Device the batches (and loss proxies) are moved to.
    - optimizer: Optimizer for the policy; if None, runs a no-grad validation pass.
    Returns:
    - Mean loss per graph over the epoch (0 if the loader yields nothing).
    """
    mean_loss = 0
    n_samples_processed = 0
    # Bug fix: build the metric-learning losses and their proxy optimizers ONCE
    # per epoch. The original re-instantiated them inside the batch loop, which
    # reset their learnable proxy vectors every iteration so stepping their
    # optimizers had no lasting effect. They are also moved to `device` so the
    # proxies live alongside the logits (NOTE(review): original kept them on
    # CPU, which would fail on CUDA — confirm intended device placement).
    # Graph-partitioning metric loss: num_classes is the number of partitions.
    loss_funcA = losses.ProxyAnchorLoss(num_classes=10, embedding_size=16).to(device)
    # Optimal-solution metric loss. For general integer programs, cluster the
    # solution values and adjust num_classes accordingly.
    loss_funcB = losses.ProxyAnchorLoss(num_classes=2, embedding_size=16).to(device)
    loss_optimizerA = torch.optim.SGD(loss_funcA.parameters(), lr=0.01)
    loss_optimizerB = torch.optim.SGD(loss_funcB.parameters(), lr=0.01)
    with torch.set_grad_enabled(optimizer is not None):
        for batch in data_loader:
            batch = batch.to(device)
            # Compute the logits (i.e. pre-softmax activations) according to the policy on the concatenated graphs
            logits = policy(
                batch.constraint_features,
                batch.edge_index,
                batch.edge_attr,
                batch.variable_features,
            )
            loss = loss_funcA(logits, batch.assignment1.to(torch.int64)) + loss_funcB(
                logits, batch.assignment2.to(torch.int64)
            )
            if optimizer is not None:
                optimizer.zero_grad()
                # Bug fix: also clear the proxy gradients, otherwise they
                # accumulate across batches.
                loss_optimizerA.zero_grad()
                loss_optimizerB.zero_grad()
                loss.backward()
                optimizer.step()
                loss_optimizerA.step()
                loss_optimizerB.step()
            mean_loss += loss.item() * batch.num_graphs
            n_samples_processed += batch.num_graphs
    # Robustness: avoid ZeroDivisionError on an empty data loader.
    if n_samples_processed > 0:
        mean_loss /= n_samples_processed
    return mean_loss
def train(
    model_save_path: Union[str, Path],
    batch_size: int = 1,
    learning_rate: float = 1e-3,
    num_epochs: int = 20,
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    data_path: Union[str, Path] = "./example",
):
    """
    Function Description:
    This function trains a GNN policy on training data.
    Parameters:
    - model_save_path: Path to save the model.
    - batch_size: Batch size for training.
    - learning_rate: Learning rate for the optimizer.
    - num_epochs: Number of epochs to train for.
    - device: Device to use for training.
    - data_path: Directory containing the `pair*.pickle` sample files.
      (Generalization: the docstring promised a data path but the directory was
      hard-coded; the new trailing parameter defaults to the old "./example".)
    """
    # load samples from data_path and divide them
    sample_files = [str(path) for path in Path(data_path).glob("pair*.pickle")]
    print(sample_files)
    # NOTE(review): samples in the 60%..90% range belong to neither split —
    # possibly a held-out test set; confirm against the data-collection code.
    train_files = sample_files[: int(0.6 * len(sample_files))]
    valid_files = sample_files[int(0.9 * len(sample_files)) :]
    train_data = GraphDataset(train_files)
    train_loader = torch_geometric.loader.DataLoader(train_data, batch_size=batch_size, shuffle = False)
    valid_data = GraphDataset(valid_files)
    valid_loader = torch_geometric.loader.DataLoader(valid_data, batch_size=batch_size, shuffle = False)
    policy = GNNPolicy().to(device)
    optimizer = torch.optim.Adam(policy.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        train_loss = process(policy, train_loader, device, optimizer)
        # Validation pass is currently disabled; restore the call below to re-enable it.
        #valid_loss = process(policy, valid_loader, device, None)
        valid_loss = 0
        print(f"Epoch {epoch+1}: Train Loss: {train_loss:0.3f}, Valid Loss: {valid_loss:0.3f}")
    torch.save(policy.state_dict(), model_save_path)
    print(f"Trained parameters saved to {model_save_path}")
def parse_args():
    """
    Function Description:
    This function parses the command line arguments.
    """
    # Resolve the device default once, up front.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--model_save_path", type=str, default="trained_model.pkl", help="Path to save the model.")
    arg_parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.")
    arg_parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate for the optimizer.")
    arg_parser.add_argument("--num_epochs", type=int, default=10, help="Number of epochs to train for.")
    arg_parser.add_argument("--device", type=str, default=default_device, help="Device to use for training.")
    return arg_parser.parse_args()
if __name__ == "__main__":
    # Script entry point: forward the parsed CLI options straight into train().
    cli_options = vars(parse_args())
    train(**cli_options)