# NOTE(review): the three lines that stood here ("File size: 8,059 Bytes",
# a commit hash, and a dump of line numbers) were extraction artifacts,
# not Python code; converted to a comment so the file parses.
import argparse
import pickle
from pathlib import Path
from typing import Union
import torch
import torch.nn.functional as F
import torch_geometric
from pytorch_metric_learning import losses
from model.graphcnn import GNNPolicy
__all__ = ["train"]
class BipartiteNodeData(torch_geometric.data.Data):
    """
    Class Description:
    This class encodes a node bipartite graph observation as returned by the
    `ecole.observation.NodeBipartite` observation function in a format
    understood by the pytorch geometric data handlers.
    """

    def __init__(
        self,
        constraint_features,
        edge_indices,
        edge_features,
        variable_features,
        assignment1,
        assignment2,
    ):
        super().__init__()
        # Constraint-node and variable-node feature matrices of the bipartite graph.
        self.constraint_features = constraint_features
        self.edge_index = edge_indices
        self.edge_attr = edge_features
        self.variable_features = variable_features
        # Two label vectors used by the metric-learning losses downstream
        # (partition labels and solution labels — see `process`).
        self.assignment1 = assignment1
        self.assignment2 = assignment2

    def __inc__(self, key, value, store, *args, **kwargs):
        """
        Function Description:
        Overload the pytorch geometric method that tells how to increment indices
        when concatenating graphs for those entries (edge index, candidates) for
        which this is not obvious.
        """
        if key == "edge_index":
            # Row 0 of edge_index addresses constraint nodes and row 1 addresses
            # variable nodes, so each row is offset by its own node count.
            return torch.tensor(
                [[self.constraint_features.size(0)], [self.variable_features.size(0)]]
            )
        elif key == "candidates":
            return self.variable_features.size(0)
        else:
            # Bug fix: `store` was accepted but silently dropped when delegating;
            # forward it so the parent sees the same arguments it passed in.
            return super().__inc__(key, value, store, *args, **kwargs)
class GraphDataset(torch_geometric.data.Dataset):
    """
    Class Description:
    A collection of bipartite-graph samples stored on disk, exposed through the
    pytorch geometric `Dataset` interface so the standard data loaders can
    consume it.
    """

    def __init__(self, sample_files):
        super().__init__(root=None, transform=None, pre_transform=None)
        self.sample_files = sample_files

    def len(self):
        # One sample per pickle file.
        return len(self.sample_files)

    def get(self, index):
        """
        Function Description:
        Load the pickled node bipartite graph observation saved during data
        collection and wrap it as a `BipartiteNodeData` instance.
        """
        with open(self.sample_files[index], "rb") as handle:
            (
                variable_features,
                constraint_features,
                edge_indices,
                edge_features,
                solution1,
                solution2,
            ) = pickle.load(handle)
        graph = BipartiteNodeData(
            torch.FloatTensor(constraint_features),
            torch.LongTensor(edge_indices),
            torch.FloatTensor(edge_features),
            torch.FloatTensor(variable_features),
            torch.LongTensor(solution1),
            torch.LongTensor(solution2),
        )
        # We must tell pytorch geometric how many nodes there are, for indexing purposes.
        n_cons = len(constraint_features)
        n_vars = len(variable_features)
        graph.num_nodes = n_cons + n_vars
        graph.cons_nodes = n_cons
        graph.vars_nodes = n_vars
        return graph
def pad_tensor(input_, pad_sizes, pad_value=-1e8):
    """
    Function Description:
    Split `input_` along dim 0 into chunks of lengths `pad_sizes`, right-pad each
    chunk with `pad_value` up to the longest chunk, and stack the results into a
    single 2-D tensor of shape (len(pad_sizes), max(pad_sizes)).

    Parameters:
    - input_: 1-D tensor containing all chunks concatenated back to back.
    - pad_sizes: 1-D integer tensor with the length of each chunk.
    - pad_value: fill value used for the padding positions.
    """
    # `.max()` yields a 0-dim tensor; take `.item()` so the widths handed to
    # F.pad are plain Python ints rather than tensors.
    max_pad_size = int(pad_sizes.max().item())
    # Tensor.tolist() already handles device transfer; no numpy round-trip needed.
    chunks = input_.split(pad_sizes.tolist())
    return torch.stack(
        [
            F.pad(chunk, (0, max_pad_size - chunk.size(0)), "constant", pad_value)
            for chunk in chunks
        ],
        dim=0,
    )
def process(policy, data_loader, device, optimizer=None):
    """
    Function Description:
    Process a whole epoch of training or validation, depending on whether an
    optimizer is provided (training when `optimizer` is not None).

    Parameters:
    - policy: the GNN policy producing logits from a batched bipartite graph.
    - data_loader: yields batched `BipartiteNodeData` graphs.
    - device: device the batches and losses are moved to.
    - optimizer: optimizer for the policy parameters, or None for evaluation.

    Returns the mean loss per sample (0 if the loader yielded no samples).
    """
    # Bug fix: ProxyAnchorLoss modules hold *learnable* proxy embeddings, and
    # both they and their SGD optimizers were previously re-created inside the
    # batch loop — resetting the proxies every iteration so they never trained.
    # Create them once per epoch instead, on the right device.
    #
    # Graph-partitioning related metric function: num_classes is the number of
    # partitions in the graph.
    loss_func_a = losses.ProxyAnchorLoss(num_classes=10, embedding_size=16).to(device)
    # Metric function related to the optimal solution. In general integer
    # programming problems, clustering the solution values and modifying
    # num_classes can be done.
    loss_func_b = losses.ProxyAnchorLoss(num_classes=2, embedding_size=16).to(device)
    loss_optimizer_a = torch.optim.SGD(loss_func_a.parameters(), lr=0.01)
    loss_optimizer_b = torch.optim.SGD(loss_func_b.parameters(), lr=0.01)

    mean_loss = 0.0
    n_samples_processed = 0
    with torch.set_grad_enabled(optimizer is not None):
        for batch in data_loader:
            batch = batch.to(device)
            # Compute the logits (i.e. pre-softmax activations) according to the
            # policy on the concatenated graphs.
            logits = policy(
                batch.constraint_features,
                batch.edge_index,
                batch.edge_attr,
                batch.variable_features,
            )
            loss = loss_func_a(
                logits, batch.assignment1.to(torch.int64)
            ) + loss_func_b(logits, batch.assignment2.to(torch.int64))
            if optimizer is not None:
                optimizer.zero_grad()
                # Bug fix: the proxy optimizers' gradients were never cleared.
                loss_optimizer_a.zero_grad()
                loss_optimizer_b.zero_grad()
                loss.backward()
                optimizer.step()
                loss_optimizer_a.step()
                loss_optimizer_b.step()
            mean_loss += loss.item() * batch.num_graphs
            n_samples_processed += batch.num_graphs
    # Guard against an empty data loader to avoid ZeroDivisionError.
    if n_samples_processed > 0:
        mean_loss /= n_samples_processed
    return mean_loss
def train(
    model_save_path: Union[str, Path],
    batch_size: int = 1,
    learning_rate: float = 1e-3,
    num_epochs: int = 20,
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    data_path: Union[str, Path] = "./example",
):
    """
    Function Description:
    This function trains a GNN policy on training data.
    Parameters:
    - model_save_path: Path to save the model.
    - batch_size: Batch size for training.
    - learning_rate: Learning rate for the optimizer.
    - num_epochs: Number of epochs to train for.
    - device: Device to use for training.
    - data_path: Directory containing the `pair*.pickle` sample files
      (previously hard-coded to './example'; kept as the default for
      backward compatibility).
    """
    # Load samples from data_path and divide them: first 60% for training,
    # last 10% for validation (the middle 30% is left unused here).
    sample_files = [str(path) for path in Path(data_path).glob("pair*.pickle")]
    print(sample_files)
    train_files = sample_files[: int(0.6 * len(sample_files))]
    valid_files = sample_files[int(0.9 * len(sample_files)):]
    train_data = GraphDataset(train_files)
    train_loader = torch_geometric.loader.DataLoader(
        train_data, batch_size=batch_size, shuffle=False
    )
    valid_data = GraphDataset(valid_files)
    valid_loader = torch_geometric.loader.DataLoader(
        valid_data, batch_size=batch_size, shuffle=False
    )
    policy = GNNPolicy().to(device)
    optimizer = torch.optim.Adam(policy.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        train_loss = process(policy, train_loader, device, optimizer)
        # Validation is currently disabled; restore the next line to re-enable:
        # valid_loss = process(policy, valid_loader, device, None)
        valid_loss = 0
        print(f"Epoch {epoch+1}: Train Loss: {train_loss:0.3f}, Valid Loss: {valid_loss:0.3f}")
    torch.save(policy.state_dict(), model_save_path)
    print(f"Trained parameters saved to {model_save_path}")
def parse_args():
    """
    Function Description:
    Parse the command-line options for the training entry point and return an
    `argparse.Namespace` whose attribute names match `train`'s keyword
    parameters.
    """
    # Pick CUDA automatically when available, mirroring `train`'s default.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model_save_path",
        type=str,
        default="trained_model.pkl",
        help="Path to save the model.",
    )
    cli.add_argument(
        "--batch_size", type=int, default=1, help="Batch size for training."
    )
    cli.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Learning rate for the optimizer.",
    )
    cli.add_argument(
        "--num_epochs", type=int, default=10, help="Number of epochs to train for."
    )
    cli.add_argument(
        "--device",
        type=str,
        default=default_device,
        help="Device to use for training.",
    )
    return cli.parse_args()
if __name__ == "__main__":
    # Script entry point: forward every CLI option straight into `train`.
    train(**vars(parse_args()))
|