The BaselineModel class in baselines.py is a full working Graph Neural Network (GNN) example using JAX and the DeepMind JAX Ecosystem of libraries. It allows training of multiple algorithms on a single processor, as described in the paper "A Generalist Neural Algorithmic Learner" (arXiv:2209.11142v2 [cs.LG], 3 Dec 2022). Below is an excerpt from the paper that describes the model:
Each algorithm in the CLRS benchmark [5] is specified by a number of inputs, hints and outputs. In a given sample, the inputs and outputs are fixed, while hints are time-series of intermediate states of the algorithm. Each sample for a particular task has a size, n, corresponding to the number of nodes in the GNN that will execute the algorithm.

A sample of every algorithm is represented as a graph, with each input, output and hint located in either the nodes, the edges, or the graph itself, and therefore has shape (excluding batch dimension, and, for hints, time dimension) n × f, n × n × f, or f, respectively, f being the dimensionality of the feature, which depends on its type. The CLRS benchmark defines five types of features: scalar, categorical, mask, mask_one and pointer, each with its own encoding and decoding strategy and loss function; e.g. a scalar type will be encoded and decoded directly by a single linear layer, and optimised using mean squared error.
Base Model
Encoder. We adopt the same encode-process-decode paradigm [33] presented with the CLRS benchmark [5]. At each time step, t, of a particular task τ (e.g. insertion sort), the task-based encoder f_τ, consisting of a linear encoder for each input and hint, embeds inputs and the current hints as high-dimensional vectors. These embeddings of inputs and hints located in the nodes all have the same dimension and are added together; the same happens with hints and inputs located in edges, and in the graph. In our experiments we use the same dimension, h = 128, for node, edge and graph embeddings. Thus, at the end of the encoding step for a time-step t of the algorithm, we have a single set of embeddings {x_i^(t), e_ij^(t), g^(t)}, of shapes n × h, n × n × h, and h, in the nodes, edges and graph, respectively. Note that this is independent of the number and type of the inputs and hints of the particular algorithm, allowing us to share this latent space across all thirty algorithms in CLRS. Further, note that at each step, the input encoding is fed directly to these embeddings; this recall mechanism significantly improves the model's robustness over long trajectories [34].
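As an illustration of the encoder's per-feature linear encoders and their summation into one shared embedding, here is a minimal pure-Python sketch. The helper names and toy dimensions are illustrative only, not the actual baselines.py implementation:

```python
import random

H = 4  # embedding dimensionality (the paper uses h = 128)

def linear(in_dim, out_dim, seed):
    """Toy linear layer parameters: small random weights, zero bias."""
    r = random.Random(seed)
    w = [[r.uniform(-0.1, 0.1) for _ in range(out_dim)] for _ in range(in_dim)]
    b = [0.0] * out_dim
    return w, b

def apply_linear(params, x):
    w, b = params
    return [sum(x[i] * w[i][j] for i in range(len(x))) + b[j]
            for j in range(len(b))]

# One linear encoder per node-located feature (e.g. a scalar input and a
# scalar hint); their embeddings share dimension H and are simply summed.
enc_input = linear(1, H, seed=1)
enc_hint = linear(1, H, seed=2)

def encode_node(input_val, hint_val):
    emb_in = apply_linear(enc_input, [input_val])
    emb_hint = apply_linear(enc_hint, [hint_val])
    return [a + b for a, b in zip(emb_in, emb_hint)]

x_i = encode_node(0.5, 0.25)
assert len(x_i) == H
```

Because all node-located features are projected to the same dimension, adding a new input or hint type only adds one more encoder whose output is summed into the same latent space.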
Processor. The embeddings are fed into a processor P, a GNN that performs one step of computation. The processor transforms the input node, edge and graph embeddings into processed node embeddings, h_i^(t). Additionally, the processor uses the processed node embeddings from the previous step, h_i^(t−1), as inputs. Importantly, the same processor model can operate on graphs of any size. We leverage the message-passing neural network [35, MPNN], using the max aggregation and passing messages over a fully-connected graph, as our base model. The MPNN computes processed embeddings as follows:

    z_i^(t) = x_i^(t) ‖ h_i^(t−1)
    m_i^(t) = max_{1≤j≤n} f_m(z_i^(t), z_j^(t), e_ij^(t), g^(t))
    h_i^(t) = f_r(z_i^(t), m_i^(t))        (1)

starting from h^(0) = 0. Here ‖ denotes concatenation, f_m : R^{2h} × R^{2h} × R^h × R^h → R^h is the message function (for which we use a three-layer MLP with ReLU activations), and f_r : R^{2h} × R^h → R^h is the readout function (for which we use a linear layer with ReLU activation). The use of the max aggregator is well-motivated by prior work [5, 9], and we use the fully connected graph, letting the neighbours j range over all nodes (1 ≤ j ≤ n), in order to allow the model to overcome situations where the input graph structure may be suboptimal. Layer normalisation [36] is applied to h_i^(t) before using them further. Further details on the MPNN processor may be found in Veličković et al. [5].
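The equations above can be sketched directly in pure Python, with a single fixed random linear map (plus ReLU) standing in for the learned functions f_m and f_r; names and dimensions here are illustrative only:

```python
import random

H = 4   # embedding dimensionality (the paper uses h = 128)
N = 3   # number of nodes

def toy_net(x, out_dim, seed):
    """Stand-in for f_m / f_r: one fixed random linear map followed by ReLU."""
    r = random.Random(seed)
    w = [[r.uniform(-0.5, 0.5) for _ in range(out_dim)] for _ in range(len(x))]
    y = [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(out_dim)]
    return [max(0.0, v) for v in y]

def mpnn_step(x, h_prev, e, g):
    """One processor step of Eq. 1 over a fully-connected graph."""
    z = [x[i] + h_prev[i] for i in range(N)]   # list + list = concatenation (‖)
    h_new = []
    for i in range(N):
        # messages from every node j (1 <= j <= n), reduced with elementwise max
        msgs = [toy_net(z[i] + z[j] + e[i][j] + g, H, seed=7) for j in range(N)]
        m_i = [max(m[d] for m in msgs) for d in range(H)]
        h_new.append(toy_net(z[i] + m_i, H, seed=11))
    return h_new

rng = random.Random(0)
x = [[rng.random() for _ in range(H)] for _ in range(N)]
h0 = [[0.0] * H for _ in range(N)]   # starting from h^(0) = 0
e = [[[rng.random() for _ in range(H)] for _ in range(N)] for _ in range(N)]
g = [rng.random() for _ in range(H)]

h1 = mpnn_step(x, h0, e, g)
assert len(h1) == N and all(len(v) == H for v in h1)
```

Note that the same message and readout parameters are shared across all nodes and edges, which is what lets the processor run on graphs of any size.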
Decoder. The processed embeddings are finally decoded with a task-based decoder g_τ, to predict the hints for the next step, and the outputs at the final step. Akin to the encoder, the task-based decoder relies mainly on a linear decoder for each hint and output, along with a mechanism to compute pairwise node similarities when appropriate. Specifically, the pointer type decoder computes a score, s_ij, for each pair of nodes, and then chooses the pointer of node i by taking either the argmax_j s_ij or softmax_j s_ij (depending on whether a hard or soft prediction is used).
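For instance, a pointer head over pairwise scores s_ij might be decoded as follows (a schematic pure-Python version, not the library's code):

```python
import math

def pointer_decode(s, hard=True):
    """Turn pairwise scores s[i][j] into pointer predictions for each node i."""
    if hard:
        # argmax_j s_ij: node i points to the highest-scoring node j
        return [max(range(len(row)), key=lambda j: row[j]) for row in s]
    # softmax_j s_ij: a soft (differentiable) distribution over targets
    out = []
    for row in s:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([v / total for v in exps])
    return out

scores = [[0.1, 2.0, -1.0],
          [3.0, 0.0, 0.5],
          [0.2, 0.1, 4.0]]
assert pointer_decode(scores) == [1, 0, 2]
```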
Loss. The decoded hints and outputs are used to compute the loss during training, according to their type [5]. For each sample in a batch, the hint prediction losses are averaged across hints and time, and the output loss is averaged across outputs (most algorithms have a single output, though some have two outputs). The hint loss and output loss are added together. Additionally, the hint predictions at each time step are fed back as inputs for the next step, except possibly at train time if teacher forcing is used (see Section 3.2.1).

We train the model on samples with sizes n ≤ 16, and periodically evaluate it on in-distribution samples of size n = 16. Also, periodically, we evaluate the model with the best in-distribution evaluation score so far on OOD samples of size n = 64. In what follows, we will be reporting only these OOD evaluation scores. Full details of the model, training and evaluation hyperparameters can be found in Appendix A.
3.2 Model improvements
As previously discussed, single-task improvements, especially in terms of learning stability, will empirically transfer well to multi-task algorithmic learning. We now describe, in a gradual manner, all the changes made to the model, which have led to an absolute improvement of over 20% on average across all 30 tasks in CLRS.
3.2.1 Dataset and training
Removing teacher forcing. At evaluation time, the model has no access to the step-by-step hints in the dataset, and has to rely on its own hint predictions. However, during training, it is sometimes advisable to stabilise the trajectories with teacher forcing [37], i.e. providing the ground-truth hint values instead of the network's own predictions. In the prior model [5], ground-truth hints were provided during training with probability 0.5, as, without teacher forcing, losses tended to grow unbounded along a trajectory when scalar hints were present, destabilising the training. In this work we incorporate several significant stabilising changes (described in the following paragraphs), which allow us to remove teacher forcing altogether, aligning training with evaluation, and avoiding the network becoming overconfident in always expecting correct hint predictions. With teacher forcing, performance deteriorates significantly in sorting algorithms and Kruskal's algorithm. Naïve String Matcher, on the other hand, improves with teacher forcing (see Appendix A, Figs. 7-9).
Augmenting the training data. To prevent our model from over-fitting to the statistics of the fixed CLRS training dataset [5], we augmented the training data in three key ways, without breaking the intended size distribution shift. Firstly, we used the on-line samplers in CLRS to generate new training examples on the fly, rather than using a fixed dataset, which is easier to overfit to. Secondly, we trained on examples of mixed sizes, n ≤ 16, rather than only 16, which helps the model anticipate a diverse range of sizes, rather than overfitting to the specifics of size n = 16. Lastly, for graph algorithms, we varied the connectivity probability p of the input graphs (generated by the Erdős-Rényi model [38]); and for string matching algorithms, we varied the length of the pattern to be matched. Both of these serve to expose the model to different trajectory lengths; for example, in many graph algorithms, the number of steps the algorithm should run for is related to the graph's diameter, and varying the connection probability in the graph generation allows for varying the expected diameter. These changes considerably increase training data variability, compared to the original dataset in Veličković et al. [5]. We provide a more detailed step-by-step overview of the data generation process in Appendix A.
Soft hint propagation. When predicted hints are fed back as inputs during training, gradients may or may not be allowed to flow through them. In previous work, only hints of the scalar type allowed gradients through, as all categoricals were post-processed from logits into the ground-truth format via argmax or thresholding before being fed back. Instead, in this work we use softmax for categorical, mask_one and pointer types, and the logistic sigmoid for mask types. Without these soft hints, performance degrades in sorting algorithms (similarly to the case of teacher forcing), as well as in Naïve String Matcher (Appendix A, Figs. 7-9).
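A schematic of this post-processing, assuming simple logit lists (the real model operates on batched tensors):

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [v / total for v in exps]

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def soft_hint(logits, hint_type):
    """Keep predicted hints 'soft' so gradients could flow through them."""
    if hint_type in ("categorical", "mask_one", "pointer"):
        return softmax(logits)               # instead of a hard argmax
    if hint_type == "mask":
        return [sigmoid(v) for v in logits]  # instead of 0/1 thresholding
    return logits                            # scalar hints pass through as-is

p = soft_hint([2.0, 0.5, -1.0], "mask_one")
assert abs(sum(p) - 1.0) < 1e-9 and p.index(max(p)) == 0
```

The soft outputs stay close to the ground-truth one-hot or binary format while remaining differentiable.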
Static hint elimination. Eleven algorithms in CLRS³ specify a fixed ordering of the nodes, common to every sample, via a node pointer hint that never changes along the trajectory. Prediction of this hint is trivial (the identity function), but it poses a potential problem for OOD generalisation, since the model can overfit to the fixed training values. We therefore turned this fixed hint into an input for these 11 algorithms, eliminating the need to predict it explicitly.
Improving training stability with encoder initialisation and gradient clipping. The scalar hints have unbounded values, in principle, and are optimised using mean-squared error, hence their gradients can quickly grow with increasing prediction error. Further, the predicted scalar hints then get re-encoded at every step, which can rapidly amplify errors throughout the trajectory, leading to exploding signals (and consequently gradients), even before any training takes place.

To rectify this issue, we use the Xavier initialisation [45], effectively reducing the initial weights for scalar hints whose input dimensionality is just 1, and revert to the default LeCun initialisation [46] elsewhere. This combination of initialisations proved important for the initial learning stability of our model over long trajectories. Relatedly, in preliminary experiments, we saw drastic improvements in learning stability, as well as significant increases in validation performance, with gradient clipping [47], which we subsequently employed in all experiments.
3.2.2 Encoders and decoders
Randomised position scalar. Across all algorithms in the dataset, there exists a position scalar input which uniquely indexes the nodes, with values linearly spaced between 0 and 1 along the node index. To avoid overfitting to these linearly spaced values during training, we replaced them with random values, uniformly sampled in [0, 1], sorted to match the initial order implied by the linearly spaced values. The benefit of this change is notable in algorithms where it would be easy to overfit to these positions, such as string matching. Namely, the model could learn to base all of its computations on the assumption that it will always be finding an m-character pattern inside an n-character string, even though at test time, m and n will increase fourfold.

³ Binary Search, Minimum, Max Subarray [39], Matrix Chain Order, LCS Length, Optimal BST [40], Activity Selector [41], Task Scheduling [42], Naïve String Matcher, Knuth-Morris-Pratt [43] and Jarvis' March [44].
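The replacement is straightforward to sketch: sample n uniform values and sort them, so the relative node order matches the original linearly spaced positions while the values themselves vary per sample:

```python
import random

def randomised_positions(n, rng):
    """Sorted uniform samples in [0, 1]: same node order as linearly spaced
    positions, but random values that change with every training sample."""
    return sorted(rng.uniform(0.0, 1.0) for _ in range(n))

rng = random.Random(42)
pos = randomised_positions(5, rng)
assert pos == sorted(pos)
assert all(0.0 <= p <= 1.0 for p in pos)
```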
Permutation decoders and the Sinkhorn operator. Sorting algorithms (Insertion Sort, Bubble Sort, Heapsort [48] and Quicksort [49]) always output a permutation of the input nodes. In the CLRS benchmark, this permutation is encoded as a pointer where each node points to its predecessor in the sorted order (the first node points to itself); this is represented as an n × n matrix P where each row is a one-hot vector, such that element (i, j) is 1 if node i points to node j. As with all types of pointers, such permutation pointers can be predicted using a row-wise softmax on unconstrained decoder outputs (logits), trained with cross-entropy (as in Veličković et al. [5]). However, this does not explicitly take advantage of the fact that the pointers encode a permutation, which the model has to learn instead. Our early experiments showed that the model was often failing to predict valid permutations OOD.
Accordingly, we enforce a permutation inductive bias in the output decoder of sorting algorithms, as follows. First, we modify the output representation by rewiring the first node to point to the last one, turning P into a permutation matrix, i.e., a matrix whose rows and columns are one-hot vectors. We also augment the representation with a one-hot vector of size n that specifies the first node, so we do not lose this information; this vector is treated like a regular mask_one feature. Second, we predict the permutation matrix P from unconstrained decoder outputs Y by replacing the usual row-wise softmax with the Sinkhorn operator S [32, 50-53]. S projects an arbitrary square matrix Y into a doubly stochastic matrix S(Y) (a non-negative matrix whose rows and columns sum to 1), by exponentiating and repeatedly normalising rows and columns so they sum to 1. Specifically, S is defined by:

    S^0(Y) = exp(Y)
    S^l(Y) = T_c(T_r(S^{l−1}(Y)))
    S(Y) = lim_{l→∞} S^l(Y)        (2)

where exp acts element-wise, and T_r and T_c denote row and column normalisation, respectively. Although the Sinkhorn operator produces a doubly stochastic matrix rather than a permutation matrix, we can obtain a permutation matrix by introducing a temperature parameter, τ > 0, and taking P = lim_{τ→0⁺} S(Y/τ); as long as there are no ties in the elements of Y, P is guaranteed to be a permutation matrix [52, Theorem 1].

In practice, we compute the Sinkhorn operator using a fixed number of iterations l_max. We use a smaller number of iterations, l_max = 10, for training, to limit vanishing and exploding gradients, and l_max = 60 for evaluation. A fixed temperature τ = 0.1 was experimentally found to give a good balance between speed of convergence and tie-breaking. We also encode the fact that no node points to itself, that is, that all diagonal elements of P should be 0, by setting the diagonal elements of Y to −∞. To avoid ties, we follow Mena et al. [53], injecting Gumbel noise into the elements of Y prior to applying the Sinkhorn operator, during training only. Finally, we transform the predicted matrix P, and the mask_one pointing to the first element, back into the original pointer representation used by CLRS.
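A minimal pure-Python version of the truncated Sinkhorn computation described above (omitting the Gumbel noise and the −∞ diagonal masking for brevity):

```python
import math

def sinkhorn(Y, iters, tau=0.1):
    """Truncated Sinkhorn operator on Y/tau: exponentiate element-wise, then
    alternately normalise rows (T_r) and columns (T_c)."""
    n = len(Y)
    S = [[math.exp(Y[i][j] / tau) for j in range(n)] for i in range(n)]
    for _ in range(iters):
        for i in range(n):                        # row normalisation T_r
            total = sum(S[i])
            S[i] = [v / total for v in S[i]]
        for j in range(n):                        # column normalisation T_c
            total = sum(S[i][j] for i in range(n))
            for i in range(n):
                S[i][j] /= total
    return S

Y = [[0.0, 1.0, -2.0],
     [2.0, -1.0, 0.0],
     [-1.0, 0.0, 3.0]]
P = sinkhorn(Y, iters=60)   # 60 iterations, as used at evaluation time
# at low temperature this approaches a permutation matrix: 0 -> 1, 1 -> 0, 2 -> 2
assert [max(range(3), key=lambda j: P[i][j]) for i in range(3)] == [1, 0, 2]
```

With τ = 0.1 the exponentiated matrix is already strongly peaked, so a few dozen normalisation sweeps suffice for the rows and columns to become near one-hot.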
3.2.3 Processor networks
Gating mechanisms. Many algorithms only require updating a few nodes at each time step, keeping the rest unchanged. However, the MPNN we use (Equation 1) is biased towards the opposite: it updates all hidden states in each step. Although it is theoretically possible for the network to keep the states unchanged, learning to do so is not easy. With this in mind, and motivated by its effectiveness in NDRs [54], we augment the network with an update gate, biased to be closed by default. We found that the gate stabilises learning on many of the tasks, and significantly increases the mean performance over all tasks on single-task training. Surprisingly, however, we did not find gating to be advantageous in the multi-task case.

To add gating to the MPNN model we produce a per-node gating vector from the same inputs that process the embeddings in Equation 1:

    g_i^(t) = f_g(z_i^(t), m_i^(t))        (3)

where f_g : R^{2h} × R^h → R^h is the gating function, for which we use a two-layer MLP, with ReLU activation for the hidden layer and logistic sigmoid activation for the output. Importantly, the final layer bias of f_g is initialised to a value of −3, which biases the network towards not updating its
representations, unless necessary. The processed gated embeddings, ĥ_i^(t), are computed as follows:

    ĥ_i^(t) = g_i^(t) ⊙ h_i^(t) + (1 − g_i^(t)) ⊙ h_i^(t−1)        (4)

(⊙ denoting elementwise multiplication) and are used instead of h_i^(t) in the subsequent steps, replacing z_i^(t) in Eq. 1 by z_i^(t) = x_i^(t) ‖ ĥ_i^(t−1).

[Figure 2: bar chart of the average score [%] on each of the 30 CLRS algorithms, for our model and the previous SOTA.] Figure 2: The OOD performance in single-task experiments before and after the improvements presented in this paper, sorted in descending order of current performance. Error bars represent standard error of the mean across seeds (3 seeds for previous SOTA experiments, 10 seeds for current). The previous SOTA values are the best of MPNN, PGN and Memnet models (see Table 2).
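Equation 4's convex blend is simple to illustrate. A toy per-dimension version follows; in the real model the gate g_i^(t) comes from f_g, whose −3 output bias keeps the gates mostly closed at initialisation:

```python
import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def gated_update(gate, h_new, h_prev):
    """Eq. 4 sketch: per-dimension blend of the new and previous node state."""
    return [g * n + (1.0 - g) * p for g, n, p in zip(gate, h_new, h_prev)]

h_prev = [1.0, 2.0, 3.0]
h_new = [5.0, 5.0, 5.0]

# closed gate keeps the state, open gate replaces it
assert gated_update([0.0, 0.0, 0.0], h_new, h_prev) == h_prev
assert gated_update([1.0, 1.0, 1.0], h_new, h_prev) == h_new

# with the final-layer bias initialised to -3, gates start near
# sigmoid(-3) ~ 0.047, so states barely change early in training
g0 = [sigmoid(-3.0)] * 3
blended = gated_update(g0, h_new, h_prev)
assert all(abs(b - p) < 0.2 for b, p in zip(blended, h_prev))
```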
Triplet reasoning. Several algorithms within CLRS-30 explicitly require edge-based reasoning, where edges store values and update them based on other edges' values. An example of this is the Floyd-Warshall algorithm [55], which computes all-pairs shortest paths in a weighted graph. The update rule for d_ij, its estimate for the best distance from node i to j, is d_ij = min_k (d_ik + d_kj), which roughly says "the best way to get from i to j is to find the optimal mid-point k, travel from i to k, then from k to j". Similar rules are pervasive across many CLRS-30 algorithms, especially in dynamic programming. Even though there are no node representations in the above update, all our processors are centered on passing messages between node representations h_i.

To rectify this situation, we augment our processor to perform message passing towards edges. Referring again to the update for d_ij, we note that the edge representations are updated by choosing an intermediate node, then aggregating over all possible choices. Accordingly, and as previously observed by Dudzik and Veličković [31], we introduce triplet reasoning: first computing representations over triplets of nodes, then reducing over one node to obtain edge latents:

    t_ijk = ψ_t(h_i, h_j, h_k, e_ij, e_ik, e_kj, g)
    h_ij = φ_t(max_k t_ijk)        (5)
Here, ψ_t is a triplet message function, mapping all relevant representations to a single vector for each triplet of nodes, and φ_t is an edge readout function, which transforms the aggregated triplets for each edge for later use. In line with prior findings on the CLRS benchmark [5], we use max aggregation to obtain edge representations. The computed h_ij vectors can then be used in any edge-based reasoning task, and empirically they are indeed significantly beneficial, even in tasks where we did not initially anticipate such benefits. One example is Kruskal's minimum spanning tree algorithm [56], where we presume that access to triplet reasoning allowed the model to more easily sort the edges by weight, as it selects how to augment the spanning forest at each step.

In order to keep the footprint of triplet embeddings as lightweight as possible, we compute only 8-dimensional features in ψ_t. φ_t then upscales the aggregated edge features back to 128 dimensions, to make them compatible with the rest of the architecture. Our initial experimentation demonstrated that the output dimensionality of ψ_t did not significantly affect downstream performance. Note that computing triplet representations has been a useful approach in general GNN design [57]; however, it has predominantly been studied in the context of GNNs over constant input features. Our study is among the first to verify their utility on reasoning tasks with well-specified initial features.
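To make Equation 5 concrete, here is a pure-Python sketch with pluggable ψ_t and φ_t. Instantiating them to mimic the Floyd-Warshall update (an illustrative hand-coded choice, not the learned functions) shows how a max over triplets can recover a min-plus style edge update:

```python
def triplet_messages(h, e, g, psi, phi):
    """Eq. 5 sketch: compute a vector per node triplet (i, j, k), max-reduce
    over the intermediate node k, then read out one latent per edge (i, j)."""
    n = len(h)
    out = [[None] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            trips = [psi(h[i], h[j], h[k], e[i][j], e[i][k], e[k][j], g)
                     for k in range(n)]
            agg = [max(t[d] for t in trips) for d in range(len(trips[0]))]
            out[i][j] = phi(agg)
    return out

# Illustrative instantiation mimicking Floyd-Warshall: psi = -(d_ik + d_kj)
# and a negating phi, so that max over k recovers min_k (d_ik + d_kj).
d = [[0.0, 5.0, 9.0],
     [5.0, 0.0, 2.0],
     [9.0, 2.0, 0.0]]
e = [[[d[i][j]] for j in range(3)] for i in range(3)]
h = [[0.0] for _ in range(3)]
g = [0.0]
psi = lambda hi, hj, hk, eij, eik, ekj, gv: [-(eik[0] + ekj[0])]
phi = lambda v: [-v[0]]

best = triplet_messages(h, e, g, psi, phi)
assert best[0][2] == [7.0]   # 0 -> 1 -> 2 (5 + 2) beats the direct edge (9)
```

In the actual model, ψ_t and φ_t are learned, ψ_t emits 8-dimensional triplet features, and φ_t upscales the max-aggregated result back to 128 dimensions.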
3.3 Results
By incorporating the changes described in the previous sections we arrived at a single model type, with a single set of hyper-parameters, that was trained to reach new state-of-the-art performance
Table 1: Single-task OOD micro-F1 score of previous SOTA Memnet, MPNN and PGN [5] and our best model Triplet-GMPNN with all our improvements, after 10,000 training steps.

| Alg. Type | Memnet [5] | MPNN [5] | PGN [5] | Triplet-GMPNN (ours) |
| --- | --- | --- | --- | --- |
| Div. & C. | 13.05% ± 0.14 | 20.30% ± 0.85 | 65.23% ± 4.44 | 76.36% ± 1.34 |
| DP | 67.94% ± 8.20 | 65.10% ± 6.44 | 70.58% ± 6.48 | 81.99% ± 4.98 |
| Geometry | 45.14% ± 11.95 | 73.11% ± 17.19 | 61.19% ± 7.01 | 94.09% ± 2.30 |
| Graphs | 24.12% ± 5.30 | 62.79% ± 8.75 | 60.25% ± 8.42 | 81.41% ± 6.21 |
| Greedy | 53.42% ± 20.82 | 82.39% ± 3.01 | 75.84% ± 6.59 | 91.21% ± 2.95 |
| Search | 34.35% ± 21.67 | 41.20% ± 19.87 | 56.11% ± 21.56 | 58.61% ± 24.34 |
| Sorting | 71.53% ± 1.41 | 11.83% ± 2.78 | 15.45% ± 8.46 | 60.37% ± 12.16 |
| Strings | 1.51% ± 0.46 | 3.21% ± 0.94 | 2.04% ± 0.20 | 49.09% ± 23.49 |
| Overall avg. | 38.88% | 44.99% | 50.84% | 74.14% |
| > 90% | 0/30 | 6/30 | 3/30 | 11/30 |
| > 80% | 3/30 | 9/30 | 7/30 | 17/30 |
| > 60% | 10/30 | 14/30 | 15/30 | 24/30 |
on CLRS-30 [5]. Tables 1 and 2 show the micro-F1 scores of our model, which we refer to as Triplet-GMPNN (an MPNN with gating and triplet edge processing), over the original CLRS-30 test set (computed identically to Veličković et al. [5], but with 10 repetitions instead of 3). Our baselines include the Memnet [58], MPNN [35] and PGN [59] models, taken directly from Veličković et al. [5]. Figure 2 displays the comparison between the improved model and the best model from Veličković et al. [5]. Our improvements lead to an overall average performance that is more than 20% higher (in absolute terms) compared to the next best model (see Table 1), and to a significant performance improvement in all but one algorithm family, compared to every other model. Further, our stabilising changes (such as gradient clipping) have empirically reduced the scale of our model's gradient updates across the 30 tasks, better preparing us for the numerical issues of the multi-task regime. Finally, we note that, though we do not show it in Tables 1 & 2, applying the same improvements to the PGN processor leads to an increase in overall performance from 50.84% (Table 1) to 69.31%.
There are two notable examples of algorithm families with significant OOD performance improvements. The first is geometric algorithms (Segments Intersect, Graham Scan [60] and Jarvis' March), now solved at approximately 94% OOD, compared to the previous best of about 73%; the second is string algorithms (Knuth-Morris-Pratt and Naïve String Matcher), for which our model now exceeds 49% compared to the previous best of approximately 3%.
The significant overall performance boost is reflected in the increased number of algorithms we can now solve at over 60%, 80% and 90% OOD performance, compared to previous SOTA [5]. Specifically, we now exceed 60% accuracy on 24 algorithms (15 algorithms previously), 80% on 17 algorithms (9 previously) and 90% on 11 algorithms (6 previously).