RSK Transformer
A transformer that learns inverse combinatorial bijections: the Robinson-Schensted-Knuth correspondence (permutations and matrices), the Hillman-Grassl correspondence (reverse plane partitions), and the cylindric growth diagram bijection (cylindric plane partitions). The same architecture handles all tasks without architectural modification.
Achieves 100% exact-match accuracy on held-out test data for permutations at n=10, 99.99% at n=15 (1.3 trillion permutations), 100% on 3×3 matrix RSK, 100% on reverse plane partitions of shape (4,3,2,1), and 100% on cylindric plane partitions with profiles of length up to 6. These results improve significantly on the PNNL ML4AlgComb benchmark. The approach scales to 5×5 matrices (96.8% exact match on a space of more than 10¹⁴ matrices).
- Paper: paper.pdf
- Code: github.com/RaggedR/rsk-transformer
- Thesis: Langer (2013), Cylindric plane partitions, Lambda determinants, Commutants in semicircular systems. This is the mathematical foundation for the cylindric growth diagram bijection (§4.2–4.3) and for generalized RSK via Fomin growth diagrams (§2.1–2.2).
Results
Experiment 1: Permutation RSK
Given a pair of standard Young tableaux (P, Q), predict the permutation σ.
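The ground-truth map that the model approximates can be computed exactly. Below is a minimal pure-Python sketch of Schensted row insertion (σ → (P, Q)) and its inverse by reverse bumping ((P, Q) → σ). It was written for this explanation and is not the repository's data pipeline; the function names `rsk` and `inverse_rsk` are hypothetical.

```python
from bisect import bisect_left, bisect_right
from itertools import permutations


def rsk(perm):
    """Schensted row insertion: a permutation of 1..n -> (P, Q)."""
    P, Q = [], []
    for step, x in enumerate(perm, start=1):
        row = 0
        while True:
            if row == len(P):
                P.append([x])  # x falls off the bottom: start a new row
                Q.append([step])
                break
            r = P[row]
            k = bisect_right(r, x)  # leftmost entry greater than x
            if k == len(r):
                r.append(x)  # nothing to bump: x goes at the end of this row
                Q[row].append(step)
                break
            r[k], x = x, r[k]  # bump r[k] into the next row
            row += 1
    return P, Q


def inverse_rsk(P, Q):
    """Reverse bumping: (P, Q) -> the permutation that produced them."""
    P = [row[:] for row in P]
    Q = [row[:] for row in Q]
    n = sum(len(row) for row in Q)
    perm = [0] * n
    for step in range(n, 0, -1):
        # The largest remaining label of Q sits at the end of some row (a corner).
        row = next(i for i, r in enumerate(Q) if r and r[-1] == step)
        Q[row].pop()
        x = P[row].pop()
        for i in range(row - 1, -1, -1):
            r = P[i]
            k = bisect_left(r, x) - 1  # rightmost entry smaller than x
            r[k], x = x, r[k]  # un-bump one row up
        perm[step - 1] = x
        while P and not P[-1]:  # drop rows that became empty
            P.pop()
            Q.pop()
    return perm


# Round trip over every permutation of length 6.
assert all(inverse_rsk(*rsk(list(p))) == list(p) for p in permutations(range(1, 7)))
```

The inverse walks the labels of Q from n down to 1. At each step the previous insertion is undone deterministically, and this step-by-step undoing is the procedure the encoder must reproduce in a single forward pass.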
| n | \|S_n\| | Training data | Test exact match | Best epoch |
|---|---|---|---|---|
| 8 | 40,320 | 29,031 (72% of S_n) | 99.95% | 23 |
| 10 | 3,628,800 | 500,000 (14% of S_n) | 100.00% | 28 |
| 15 | 1.3 × 10¹² | 500,000 (0.00004%) | 99.99% | 52 |
The n=10 result rules out memorisation: a 1.2M-parameter model trained on 14% of the permutation space achieves perfect accuracy on 50,000 unseen test permutations. At n=15 (1.3 trillion permutations), training on 0.00004% of the space yields 99.99%, which is unambiguous algorithmic generalisation.
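The coverage fractions quoted above follow directly from |S_n| = n!. The short check below uses only the training-set sizes from the table.

```python
from math import factorial

# Training-set sizes from the table above, keyed by n.
train_sizes = {8: 29_031, 10: 500_000, 15: 500_000}
coverage = {n: size / factorial(n) for n, size in train_sizes.items()}

for n, frac in coverage.items():
    print(f"n={n:2d}  |S_n| = {factorial(n):>17,}  coverage = {100 * frac:.5f}%")
```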
Experiment 2: Full Matrix RSK
Given a pair of semistandard Young tableaux (P, Q) from Knuth's full RSK on non-negative integer matrices, recover the biword. The model architecture and embedding are unchanged; only the task flag differs.
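The forward direction can be sketched as follows. This is a minimal version of Knuth's RSK in one common convention, which may differ from the repository's. The matrix is first expanded into a lexicographically ordered two-line array (biword). P is then built by row-inserting the bottom line, bumping the leftmost entry strictly greater than the inserted value, and Q records the top line. The names `matrix_to_biword` and `rsk_knuth` are hypothetical.

```python
from bisect import bisect_right


def matrix_to_biword(A):
    """Two-line array of a non-negative integer matrix: the column (i, j)
    appears A[i][j] times, columns in lexicographic order (1-indexed)."""
    top, bottom = [], []
    for i, row in enumerate(A, start=1):
        for j, count in enumerate(row, start=1):
            top.extend([i] * count)
            bottom.extend([j] * count)
    return top, bottom


def rsk_knuth(top, bottom):
    """Row-insert the bottom line into P, recording the top line in Q."""
    P, Q = [], []
    for label, x in zip(top, bottom):
        row = 0
        while True:
            if row == len(P):
                P.append([x])
                Q.append([label])
                break
            r = P[row]
            k = bisect_right(r, x)  # leftmost entry strictly greater than x
            if k == len(r):
                r.append(x)
                Q[row].append(label)
                break
            r[k], x = x, r[k]  # bump into the next row
            row += 1
    return P, Q


A = [[1, 0, 2],
     [0, 1, 1],
     [2, 0, 0]]
P, Q = rsk_knuth(*matrix_to_biword(A))
print(P)  # [[1, 1, 1, 3], [2, 3], [3]]
print(Q)  # [[1, 1, 1, 2], [2, 3], [3]]
```

Both tableaux have |λ| equal to the entry sum of the matrix (here 7). This is why the matrix experiments below are indexed by N.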
| Experiment | \|λ\| | Classes | Training data | Test exact match | Per-position | Best epoch |
|---|---|---|---|---|---|---|
| 3×3, N=10 | 10 | 3 | 500,000 | 100.00% | 100.00% | 18 |
| 4×4, N=20 | 20 | 4 | 500,000 | 99.32% | 99.96% | 20 |
| 5×5, N=30 | 30 | 5 | 2,000,000 | 96.79% | 99.87% | 16 |
Results appear to be data-limited rather than architecture-limited. Per-position accuracy is at least 99.87%, and the exact-match gaps are consistent with roughly independent per-position errors compounding across positions ((0.9987)^30 ≈ 96.2%, against an observed 96.79%). The space of 5×5 non-negative integer matrices with entry sum 30 has C(54, 24) ≈ 1.4 × 10¹⁵ elements, so 2M training samples cover about 10⁻⁹ of it. More data would likely improve results, but with limited computational resources (a single Apple M4 Max laptop) we prioritised moving on to qualitatively new experiments (reverse plane partitions via Fomin growth diagrams).
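The compounding argument is a one-line calculation. If each of the L output positions is correct independently with probability p, exact match is p^L. The sketch below plugs in the per-position accuracies from the table and assumes, following the 5×5 calculation above, that the sequence length equals N.

```python
# (per-position accuracy, assumed sequence length N, observed exact match) from the table above
runs = {
    "4x4, N=20": (0.9996, 20, 0.9932),
    "5x5, N=30": (0.9987, 30, 0.9679),
}
predicted = {name: p ** length for name, (p, length, _) in runs.items()}

for name, (p, length, observed) in runs.items():
    print(f"{name}: independent-error prediction {predicted[name]:.2%}, observed {observed:.2%}")
```

Both observed rates sit slightly above the independence prediction (about 99.20% and 96.17%). This suggests that errors tend to cluster in the same samples rather than being fully independent.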
Ablation: Transformer vs MLP (Permutations)
| n | Model | Parameters | Greedy exact | Argmax exact | Per-position |
|---|---|---|---|---|---|
| 10 | RSKEncoder (transformer) | 1,207,012 | 100.00% | 100.00% | 100.00% |
| 10 | BaselineMLP (flat) | 133,604 | 95.67% | 90.31% | 98.85% |
| 15 | RSKEncoder (transformer) | 1,225,057 | 99.99% | 99.98% | 100.00% |
| 15 | BaselineMLP (flat) | 133,604 | 3.07% | 0.04% | 62.02% |
The MLP collapses from 95.67% to 3.07% exact match as n increases from 10 to 15, while the transformer barely changes (100% → 99.99%). We attribute this to the flat input. Without 2D positional structure, the MLP appears unable to coordinate its predictions across positions.
Experiment 3: Reverse Plane Partitions (Hillman-Grassl)
Given a reverse plane partition (RPP) of shape λ, recover the corresponding arbitrary filling via the inverse Hillman-Grassl correspondence. The model architecture is the same. The only change is that the input is a single filling rather than a pair, so `tableau_emb(0)` is used for all tokens.
| Shape λ | Type | \|λ\| | Classes | Training data | Test exact match | Per-position | Best epoch |
|---|---|---|---|---|---|---|---|
| (4,3,2,1) | Staircase | 10 | 5 | 500,000 | 100.00% | 100.00% | 23 |
| (6,4,2) | Wide | 12 | 5 | 500,000 | 99.99% | 100.00% | 17 |
| (2,2,2,2,2,1) | Tall | 11 | 5 | 500,000 | 99.99% | 100.00% | 36 |
The Hillman-Grassl bijection is fundamentally different from RSK: it strips zigzag paths through the Young diagram rather than performing Schensted insertion. Even so, the same transformer architecture learns it to near-perfect accuracy. The tall shape takes longer to reach its best epoch (36 vs 17–23), plausibly because longer zigzag paths create longer-range dependencies.
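The direction the model learns, from an RPP to its filling, can be computed by repeatedly stripping zigzag paths. The sketch below is our reading of the classical procedure in English notation, with entries weakly increasing rightwards and downwards. Each pass starts at the bottom of the leftmost column containing a nonzero entry. It moves north when the entry above is equal and east otherwise, and stops at the end of a row. Each path covers exactly one hook, so the RPP's sum equals the hook-weighted size of the filling. Conventions vary between sources, and the repository may orient things differently. The helper names are hypothetical, and the brute-force enumeration is only practical for small shapes.

```python
from itertools import product


def hook_lengths(shape):
    """Hook length (arm + leg + 1) of every cell of a partition."""
    conj = [sum(1 for part in shape if part > j) for j in range(shape[0])]
    return [[(shape[i] - j - 1) + (conj[j] - i - 1) + 1 for j in range(shape[i])]
            for i in range(len(shape))]


def hillman_grassl(rpp):
    """Map a reverse plane partition (list of rows) to its filling."""
    pi = [row[:] for row in rpp]
    shape = [len(row) for row in pi]
    filling = [[0] * n for n in shape]
    while any(v for row in pi for v in row):
        # Leftmost column with a nonzero entry; start at its bottom cell.
        j0 = min(j for row in pi for j, v in enumerate(row) if v)
        i = max(r for r in range(len(shape)) if shape[r] > j0)
        j = j0
        path = [(i, j)]
        while True:
            if i > 0 and pi[i - 1][j] == pi[i][j]:
                i -= 1  # equal entry above: move north
            elif j + 1 < shape[i]:
                j += 1  # otherwise move east
            else:
                break  # reached the end of row i
            path.append((i, j))
        for r, c in path:
            pi[r][c] -= 1
        filling[i][j0] += 1  # the path has exactly the length of the hook at (i, j0)
    return filling


def all_rpps(shape, max_entry):
    """Brute-force enumeration of RPPs of a shape with entries <= max_entry."""
    cells = [(i, j) for i, n in enumerate(shape) for j in range(n)]
    for values in product(range(max_entry + 1), repeat=len(cells)):
        pi = [[0] * n for n in shape]
        for (i, j), v in zip(cells, values):
            pi[i][j] = v
        if all((j == 0 or pi[i][j - 1] <= pi[i][j]) and (i == 0 or pi[i - 1][j] <= pi[i][j])
               for i, j in cells):
            yield pi


shape = (3, 2, 1)
hooks = hook_lengths(shape)
rpps = list(all_rpps(shape, max_entry=2))
images = set()
for pi in rpps:
    f = hillman_grassl(pi)
    # Weight preservation: |pi| equals the hook-weighted size of the filling.
    assert sum(map(sum, pi)) == sum(f[i][j] * hooks[i][j] for i, n in enumerate(shape) for j in range(n))
    images.add(tuple(map(tuple, f)))
assert len(images) == len(rpps)  # distinct RPPs give distinct fillings
```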
Experiment 4: Cylindric Plane Partitions (Growth Diagrams)
Given a cylindric plane partition (CPP) with binary profile π, recover the base partition γ and the ALCD face labels via the inverse cylindric growth diagram bijection. This uses the Burge local rule applied recursively through a cylindric growth diagram, as described in Langer (2013), §4.2–4.3. The model architecture is the same.
| Profile π | T | ALCD labels | Training data | Test exact match | Per-position | Best epoch |
|---|---|---|---|---|---|---|
| (1,0,1,0) | 4 | 3 | 500,000 | 100.00% | 100.00% | 2 |
| (1,0,1,0,0) | 5 | 5 | 500,000 | 100.00% | 100.00% | 7 |
| (1,0,1,0,1,0) | 6 | 6 | 500,000 | 100.00% | 100.00% | 3 |
| (1,0,1,0,1,0,1,0) | 8 | 10 | 500,000 | 99.98% | 100.00% | 9 |
The cylindric bijection is qualitatively different from all previous experiments because there is no direct closed-form algorithm. The bijection is defined implicitly by the Burge local rule applied at each face of the cylindric growth diagram. The model must therefore learn to invert a recursive process (the step-by-step composition in Langer 2013, §4.2) that peels off one ALCD label at each step by solving a local Burge equation. Despite this complexity, the transformer achieves 100% exact match on the three shorter profiles and 99.98% on (1,0,1,0,1,0,1,0).
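The Burge local rule used for the cylindric bijection is not reproduced here. To illustrate what "a local rule applied at each face" means, the sketch below implements a different and simpler rule: Fomin's standard RSK local rule for permutations. In this rule each shape in an (n+1)×(n+1) grid is computed from its three neighbours, `rho = lam[i-1][j-1]`, `mu = lam[i-1][j]` and `nu = lam[i][j-1]`, plus a 0/1 marker for whether σ(i) = j. Reading off the last row and the last column then recovers the same (P, Q) pair as row insertion. All function names are hypothetical.

```python
from itertools import zip_longest


def add_box(shape, row):
    """Partition obtained by adding one box at the end of `row`."""
    parts = list(shape)
    if row == len(parts):
        parts.append(1)
    else:
        parts[row] += 1
    return tuple(parts)


def changed_row(smaller, larger):
    """Row in which `larger` has one more box than `smaller`."""
    return next(r for r in range(len(larger)) if r >= len(smaller) or larger[r] != smaller[r])


def local_rule(rho, mu, nu, x):
    if x:
        return add_box(rho, 0)  # for a permutation, x = 1 forces rho == mu == nu
    if mu == rho:
        return nu
    if nu == rho:
        return mu
    if mu != nu:
        return tuple(max(a, b) for a, b in zip_longest(mu, nu, fillvalue=0))  # union
    return add_box(mu, changed_row(rho, mu) + 1)  # mu == nu != rho: push the box one row down


def growth_diagram(perm):
    n = len(perm)
    lam = [[() for _ in range(n + 1)] for _ in range(n + 1)]
    for i in range(1, n + 1):
        for j in range(1, n + 1):
            x = 1 if perm[i - 1] == j else 0
            lam[i][j] = local_rule(lam[i - 1][j - 1], lam[i - 1][j], lam[i][j - 1], x)
    return lam


def chain_to_tableau(chain):
    """Label the box added between chain[k-1] and chain[k] with k."""
    T = []
    for k in range(1, len(chain)):
        row = changed_row(chain[k - 1], chain[k])
        if row == len(T):
            T.append([])
        T[row].append(k)
    return T


def rsk_via_growth(perm):
    n = len(perm)
    lam = growth_diagram(perm)
    P = chain_to_tableau([lam[n][j] for j in range(n + 1)])  # grow by value
    Q = chain_to_tableau([lam[i][n] for i in range(n + 1)])  # grow by position
    return P, Q


print(rsk_via_growth([2, 3, 1]))  # ([[1, 3], [2]], [[1, 2], [3]])
```

The inverse direction runs the same kind of rule backwards, face by face. In the cylindric setting it is this face-by-face recursion, with Burge's rule in place of Fomin's, that the model has to learn.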
Key Idea: Structured 2D Token Embeddings
Previous work encoded tableaux as flat bracket strings, destroying 2D geometry. We encode each tableau entry as a token with four learned embeddings:
embedding(entry) = value_emb(v) + row_emb(r) + col_emb(c) + tableau_emb(P or Q)
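A minimal PyTorch sketch of this summed embedding follows. It mirrors the `TokenEmbedding` stage named in the Architecture section but is not the repository's code. It assumes each tableau is a list of rows and that values, row indices and column indices are small non-negative integers. The helper `tableau_to_tokens` and the vocabulary sizes are illustrative.

```python
import torch
import torch.nn as nn


class TokenEmbedding(nn.Module):
    """Sum of four learned embeddings per tableau entry (illustrative sketch)."""

    def __init__(self, num_values, max_rows, max_cols, d_model=128, num_tableaux=2):
        super().__init__()
        self.value_emb = nn.Embedding(num_values, d_model)
        self.row_emb = nn.Embedding(max_rows, d_model)
        self.col_emb = nn.Embedding(max_cols, d_model)
        self.tableau_emb = nn.Embedding(num_tableaux, d_model)

    def forward(self, tokens):
        # tokens: LongTensor of shape (batch, seq_len, 4) holding (value, row, col, tableau_id)
        v, r, c, t = tokens.unbind(dim=-1)
        return self.value_emb(v) + self.row_emb(r) + self.col_emb(c) + self.tableau_emb(t)


def tableau_to_tokens(T, tableau_id):
    """Flatten a tableau (list of rows) into (value, row, col, tableau_id) tuples."""
    return [(v, r, c, tableau_id) for r, row in enumerate(T) for c, v in enumerate(row)]


# Example: the RSK pair of the permutation 2 3 1 as 2n = 6 tokens (P -> id 0, Q -> id 1).
P, Q = [[1, 3], [2]], [[1, 2], [3]]
tokens = torch.tensor([tableau_to_tokens(P, 0) + tableau_to_tokens(Q, 1)])  # (1, 6, 4)
embed = TokenEmbedding(num_values=4, max_rows=3, max_cols=3)
out = embed(tokens)  # (1, 6, 128)
```

Unlike a bracket string, two entries that share a row or a column also share an additive component of their embedding. Attention can therefore pick up the 2D layout directly.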
Architecture
```
Input: (P, Q) as 2n structured tokens
  → TokenEmbedding (value + row + col + tableau_id)
  → 6-layer TransformerEncoder (d=128, 8 heads, pre-norm, GELU)
  → Mean pool over all 2n tokens
  → n parallel classification heads → logits (batch, n, n)
  → Masked greedy decode → predicted σ
```
~1.2M parameters. Encoder-only (no autoregressive decoding needed).
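The README does not spell out the decoding order. The sketch below shows one plausible reading of "masked greedy decode", which is an assumption rather than the repository's implementation. It walks the positions left to right and, at each one, picks the highest-scoring value index that no earlier position has taken, so every output is a valid permutation. A plain per-position argmax can repeat an index. If this reading is right, it would explain why "Greedy exact" is never below "Argmax exact" in the ablation table. Value index k is assumed to stand for σ-value k + 1.

```python
import torch


def masked_greedy_decode(logits):
    """logits: (batch, n, n) scores for 'position i takes value index k'.

    Positions are decoded left to right; each picks the best value index that
    no earlier position has taken, so every output row is a permutation of 0..n-1.
    """
    batch, n, _ = logits.shape
    used = torch.zeros(batch, n, dtype=torch.bool, device=logits.device)
    out = torch.empty(batch, n, dtype=torch.long, device=logits.device)
    rows = torch.arange(batch, device=logits.device)
    for i in range(n):
        scores = logits[:, i, :].masked_fill(used, float("-inf"))
        choice = scores.argmax(dim=-1)
        out[:, i] = choice
        used[rows, choice] = True
    return out


logits = torch.tensor([[[5.0, 1.0, 0.0],
                        [4.0, 3.0, 0.0],
                        [0.0, 0.0, 1.0]]])
print(logits.argmax(dim=-1).tolist())         # [[0, 0, 2]]  (index 0 used twice)
print(masked_greedy_decode(logits).tolist())  # [[0, 1, 2]]
```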
Checkpoints
Experiment 1: Permutation RSK
| File | Description | Parameters |
|---|---|---|
| `checkpoints/encoder_n8/best.pt` | RSKEncoder trained on S₈ (HuggingFace data) | 1,202,368 |
| `checkpoints/encoder_n10/best.pt` | RSKEncoder trained on S₁₀ (sampled) | 1,207,012 |
| `checkpoints/encoder_n15/best.pt` | RSKEncoder trained on S₁₅ (sampled) | 1,225,057 |
| `checkpoints/mlp_n10/best.pt` | Baseline MLP on S₁₀ (for ablation) | 133,604 |
| `checkpoints/mlp_n15/best.pt` | Baseline MLP on S₁₅ (for ablation) | 133,604 |
Experiment 2: Full Matrix RSK
| File | Description | Parameters |
|---|---|---|
| `checkpoints/encoder_matrix_a3_b3_N10/best.pt` | RSKEncoder on 3×3 matrices, N=10 | ~1.2M |
| `checkpoints/encoder_matrix_a4_b4_N20/best.pt` | RSKEncoder on 4×4 matrices, N=20 | ~1.2M |
| `checkpoints/encoder_matrix_a5_b5_N30/best.pt` | RSKEncoder on 5×5 matrices, N=30 | ~1.2M |
Experiment 3: Reverse Plane Partitions (Hillman-Grassl)
| File | Description | Parameters |
|---|---|---|
| `checkpoints/encoder_rpp_4x3x2x1_m4/best.pt` | RSKEncoder on RPP shape (4,3,2,1), max_entry=4 | ~1.2M |
| `checkpoints/encoder_rpp_6x4x2_m4/best.pt` | RSKEncoder on RPP shape (6,4,2), max_entry=4 | ~1.2M |
| `checkpoints/encoder_rpp_2x2x2x2x2x1_m4/best.pt` | RSKEncoder on RPP shape (2,2,2,2,2,1), max_entry=4 | ~1.2M |
Experiment 4: Cylindric Plane Partitions
| File | Description | Parameters |
|---|---|---|
| `checkpoints/encoder_cyl_1010_m3/best.pt` | RSKEncoder on CPP profile (1,0,1,0), max_label=3 | ~1.2M |
| `checkpoints/encoder_cyl_10100_m3/best.pt` | RSKEncoder on CPP profile (1,0,1,0,0), max_label=3 | ~1.2M |
| `checkpoints/encoder_cyl_101010_m3/best.pt` | RSKEncoder on CPP profile (1,0,1,0,1,0), max_label=3 | ~1.2M |
Loading a checkpoint
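```python
import torch

from config import ModelConfig
from model import RSKEncoder

# Load the n=10 model. weights_only=False lets torch.load unpickle arbitrary
# Python objects (such as the stored model config), so only load trusted files.
ckpt = torch.load("checkpoints/encoder_n10/best.pt", map_location="cpu", weights_only=False)
config = ckpt["model_config"]
model = RSKEncoder(config)
model.load_state_dict(ckpt["model_state_dict"])
model.eval()  # switch to inference mode (e.g. disables dropout)
```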
Training
```bash
pip install torch datasets

# --- Experiment 1: Permutation RSK ---
python train.py --model encoder --n 10 --source sample --train-size 500000 --batch-size 512
python train.py --model encoder --n 8 --source hf

# --- Experiment 2: Full Matrix RSK ---
python train.py --model encoder --task matrix --a-dim 3 --b-dim 3 --total-n 10 \
    --source sample --train-size 500000
python train.py --model encoder --task matrix --a-dim 4 --b-dim 4 --total-n 20 \
    --source sample --train-size 500000
python train.py --model encoder --task matrix --a-dim 5 --b-dim 5 --total-n 30 \
    --source sample --train-size 2000000 --batch-size 512

# --- Experiment 3: Reverse Plane Partitions ---
python train.py --model encoder --task rpp --shape 4,3,2,1 --max-entry 4 \
    --source sample --train-size 500000
python train.py --model encoder --task rpp --shape 6,4,2 --max-entry 4 \
    --source sample --train-size 500000

# --- Experiment 4: Cylindric Plane Partitions ---
python train.py --model encoder --task cylindric --profile 1010 --max-label 3 \
    --source sample --train-size 500000
python train.py --model encoder --task cylindric --profile 101010 --max-label 3 \
    --source sample --train-size 500000
```
Citation
```bibtex
@software{rsk_transformer,
  author={Langer, Robin},
  title={Learning the RSK Correspondence with Transformers},
  year={2026},
  url={https://github.com/RaggedR/rsk-transformer}
}
```
Acknowledgements
- PNNL ML4AlgComb for the original benchmark
- ACDRepo for pre-computed RSK datasets