|
|
import argparse |
|
|
import pickle |
|
|
from pathlib import Path |
|
|
from typing import Union |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import torch_geometric |
|
|
from pytorch_metric_learning import losses |
|
|
|
|
|
from model.graphcnn import GNNPolicy |
|
|
|
|
|
__all__ = ["train"] |
|
|
|
|
|
class BipartiteNodeData(torch_geometric.data.Data):
    """
    Class Description:
    This class encode a node bipartite graph observation as returned by the `ecole.observation.NodeBipartite`
    observation function in a format understood by the pytorch geometric data handlers.

    Attributes stored on the Data object:
    - constraint_features: per-constraint node feature matrix.
    - edge_index: constraint-to-variable edge index pairs.
    - edge_attr: per-edge feature matrix.
    - variable_features: per-variable node feature matrix.
    - assignment1 / assignment2: label tensors for the two metric-learning
      losses used in process() (10-class and 2-class respectively, per the
      ProxyAnchorLoss configuration there).
    """

    def __init__(
        self,
        constraint_features,
        edge_indices,
        edge_features,
        variable_features,
        assignment1,
        assignment2
    ):
        super().__init__()
        # Store under the attribute names torch_geometric expects
        # (edge_index / edge_attr) so batching machinery recognizes them.
        self.constraint_features = constraint_features
        self.edge_index = edge_indices
        self.edge_attr = edge_features
        self.variable_features = variable_features
        self.assignment1 = assignment1
        self.assignment2 = assignment2

    def __inc__(self, key, value, store, *args, **kwargs):
        """
        Function Description:
        Overload the pytorch geometric method that tells how to increment indices when concatenating graphs for those entries (edge index, candidates) for which this is not obvious.
        """
        if key == "edge_index":
            # Row 0 indexes constraint nodes, row 1 indexes variable nodes,
            # so each side is offset by its own node count when batching.
            return torch.tensor(
                [[self.constraint_features.size(0)], [self.variable_features.size(0)]]
            )
        elif key == "candidates":
            # Candidates index into the variable nodes only.
            return self.variable_features.size(0)
        else:
            # NOTE(review): 'store' is not forwarded to super().__inc__ here;
            # this matches newer PyG signatures where store travels in *args —
            # confirm against the installed torch_geometric version.
            return super().__inc__(key, value, *args, **kwargs)
|
|
|
|
|
|
|
|
class GraphDataset(torch_geometric.data.Dataset):
    """
    Class Description:
    This class encodes a collection of graphs, as well as a method to load such graphs from the disk.
    It can be used in turn by the data loaders provided by pytorch geometric.
    """

    def __init__(self, sample_files):
        super().__init__(root=None, transform=None, pre_transform=None)
        # Paths to pickled samples; one file per graph, loaded lazily in get().
        self.sample_files = sample_files

    def len(self):
        # One graph per sample file.
        return len(self.sample_files)

    def get(self, index):
        """
        Function Description:
        This method loads a node bipartite graph observation as saved on the disk during data collection.
        """
        path = self.sample_files[index]
        # NOTE: pickle.load is only safe on trusted, self-generated files.
        with open(path, "rb") as handle:
            (variable_features, constraint_features, edge_indices,
             edge_features, solution1, solution2) = pickle.load(handle)

        graph = BipartiteNodeData(
            torch.FloatTensor(constraint_features),
            torch.LongTensor(edge_indices),
            torch.FloatTensor(edge_features),
            torch.FloatTensor(variable_features),
            torch.LongTensor(solution1),
            torch.LongTensor(solution2),
        )

        # Record node counts so torch_geometric does not have to infer them.
        n_constraints = len(constraint_features)
        n_variables = len(variable_features)
        graph.num_nodes = n_constraints + n_variables
        graph.cons_nodes = n_constraints
        graph.vars_nodes = n_variables

        return graph
|
|
|
|
|
|
|
|
def pad_tensor(input_, pad_sizes, pad_value=-1e8):
    """
    Function Description:
    This utility function splits a tensor and pads each split to make them all the same size, then stacks them.

    Parameters:
    - input_: 1-D tensor containing all segments concatenated back to back.
    - pad_sizes: 1-D integer tensor giving the length of each segment.
    - pad_value: fill value used to right-pad the shorter segments.

    Returns:
    - A 2-D tensor of shape (len(pad_sizes), pad_sizes.max()).
    """
    target_length = pad_sizes.max()
    segments = input_.split(pad_sizes.cpu().numpy().tolist())
    padded_segments = []
    for segment in segments:
        # Right-pad each segment up to the longest one.
        shortfall = target_length - segment.size(0)
        padded_segments.append(F.pad(segment, (0, shortfall), "constant", pad_value))
    return torch.stack(padded_segments, dim=0)
|
|
|
|
|
def process(policy, data_loader, device, optimizer=None):
    """
    Function Description:
    This function will process a whole epoch of training or validation, depending on whether an optimizer is provided.

    Parameters:
    - policy: GNN model mapping a batched bipartite graph to per-sample embeddings.
    - data_loader: Iterable yielding torch_geometric batches with the
      attributes accessed below (constraint_features, edge_index, edge_attr,
      variable_features, assignment1, assignment2, num_graphs).
    - device: Device on which the batch and loss proxies are placed.
    - optimizer: Optimizer for the policy parameters; if None, the epoch is
      run without gradients (validation mode).

    Returns:
    - Mean loss per graph over the epoch (0 if the loader is empty).
    """
    # BUG FIX: the loss modules and their optimizers were previously created
    # inside the batch loop, which re-initialized the learnable proxy vectors
    # every iteration so they could never actually be trained. Create them
    # once per epoch instead, and move them to `device` so their proxy
    # parameters live on the same device as the logits.
    loss_funcA = losses.ProxyAnchorLoss(num_classes=10, embedding_size=16).to(device)
    loss_funcB = losses.ProxyAnchorLoss(num_classes=2, embedding_size=16).to(device)
    loss_optimizerA = torch.optim.SGD(loss_funcA.parameters(), lr=0.01)
    loss_optimizerB = torch.optim.SGD(loss_funcB.parameters(), lr=0.01)

    mean_loss = 0
    n_samples_processed = 0
    with torch.set_grad_enabled(optimizer is not None):
        for batch in data_loader:
            batch = batch.to(device)

            logits = policy(
                batch.constraint_features,
                batch.edge_index,
                batch.edge_attr,
                batch.variable_features,
            )

            # Joint objective: a 10-class and a 2-class proxy-anchor loss on
            # the same embeddings, each with its own label tensor.
            loss = (
                loss_funcA(logits, batch.assignment1.to(torch.int64))
                + loss_funcB(logits, batch.assignment2.to(torch.int64))
            )

            if optimizer is not None:
                optimizer.zero_grad()
                # BUG FIX: the proxy parameters accumulate gradients too;
                # clear them before backward or their steps compound stale
                # gradients across batches.
                loss_optimizerA.zero_grad()
                loss_optimizerB.zero_grad()
                loss.backward()
                optimizer.step()
                loss_optimizerA.step()
                loss_optimizerB.step()

            mean_loss += loss.item() * batch.num_graphs
            n_samples_processed += batch.num_graphs

    # Guard against an empty data loader (avoids ZeroDivisionError).
    mean_loss /= max(n_samples_processed, 1)
    return mean_loss
|
|
|
|
|
def train(
    model_save_path: Union[str, Path],
    batch_size: int = 1,
    learning_rate: float = 1e-3,
    num_epochs: int = 20,
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
):
    """
    Function Description:
    This function trains a GNN policy on training data and reports the
    validation loss after every epoch.

    Parameters:
    - model_save_path: Path to save the model.
    - batch_size: Batch size for training.
    - learning_rate: Learning rate for the optimizer.
    - num_epochs: Number of epochs to train for.
    - device: Device to use for training.
    """
    # Samples are pickled pairs produced by data collection under ./example.
    sample_files = [str(path) for path in Path('./example').glob("pair*.pickle")]
    print(sample_files)
    # First 60% for training, last 10% for validation; the middle 30% is
    # skipped — presumably reserved for testing elsewhere (TODO confirm).
    train_files = sample_files[: int(0.6 * len(sample_files))]
    valid_files = sample_files[int(0.9 * len(sample_files)) :]

    train_data = GraphDataset(train_files)
    train_loader = torch_geometric.loader.DataLoader(train_data, batch_size=batch_size, shuffle=False)
    valid_data = GraphDataset(valid_files)
    valid_loader = torch_geometric.loader.DataLoader(valid_data, batch_size=batch_size, shuffle=False)

    policy = GNNPolicy().to(device)

    optimizer = torch.optim.Adam(policy.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        train_loss = process(policy, train_loader, device, optimizer)

        # BUG FIX: valid_loader was built but never used and the reported
        # validation loss was hard-coded to 0. Evaluate for real; passing
        # optimizer=None makes process() run without gradients.
        valid_loss = process(policy, valid_loader, device, None)
        print(f"Epoch {epoch+1}: Train Loss: {train_loss:0.3f}, Valid Loss: {valid_loss:0.3f}")

    torch.save(policy.state_dict(), model_save_path)
    print(f"Trained parameters saved to {model_save_path}")
|
|
|
|
|
def parse_args(argv=None):
    """
    Function Description:
    This function parses the command line arguments.

    Parameters:
    - argv: Optional list of argument strings to parse. Defaults to None,
      in which case argparse reads sys.argv[1:] — backward compatible with
      the original zero-argument call, and makes the function testable.

    Returns:
    - An argparse.Namespace with model_save_path, batch_size, learning_rate,
      num_epochs and device.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_save_path", type=str, default="trained_model.pkl", help="Path to save the model.")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate for the optimizer.")
    parser.add_argument("--num_epochs", type=int, default=10, help="Number of epochs to train for.")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use for training.")
    return parser.parse_args(argv)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    args = parse_args()
    params = vars(args)
    # argparse yields the device as a plain string; convert it so train()
    # receives the torch.device its annotation declares.
    params["device"] = torch.device(params["device"])
    train(**params)
|
|
|