| # Adapter Methods |
|
|
| PAWN is designed as a testbed for parameter-efficient fine-tuning. The frozen ~36M-parameter backbone provides learned chess representations from pretraining on random games; adapters specialize those representations for downstream tasks like predicting human moves at a given Elo level. |
|
|
| All adapter implementations live in `pawn/adapters/`. Each wraps a frozen `PAWNCLM` backbone and exposes a uniform interface: `forward_hidden()`, `project_head()`, `forward()`, and `forward_generate()` (with KV-cache). |
|
|
| ## Bottleneck ([Houlsby et al., 2019](https://arxiv.org/abs/1902.00751)) |
|
|
| **Module:** `pawn.adapters.bottleneck.BottleneckCLM` |
|
|
| Inserts small residual MLP bottlenecks after the attention sublayer and/or FFN sublayer within each transformer block, following "Parameter-Efficient Transfer Learning for NLP" (ICML 2019): |
|
|
| ``` |
| x = x + up(gelu(down(x))) |
| ``` |
|
|
| The up-projection is zero-initialized, so the model starts identical to the frozen backbone. `bottleneck_dim` controls the parameter budget. |
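
A minimal sketch of the residual update above, in plain Python with toy sizes, may help show why a zero-initialized up-projection yields an identity start. The helper names (`matvec`, `bottleneck_residual`), the dimensions, and the omission of biases are illustrative assumptions; this is not the `BottleneckCLM` implementation.

```python
import math
import random


def gelu(v: float) -> float:
    """Exact GELU via the Gaussian CDF."""
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def matvec(w, x):
    """Multiply a (rows x cols) nested-list matrix by a length-cols vector."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def bottleneck_residual(x, down, up):
    """x + up(gelu(down(x))) for a single token vector (biases omitted)."""
    hidden = [gelu(h) for h in matvec(down, x)]
    return [xi + ui for xi, ui in zip(x, matvec(up, hidden))]


rng = random.Random(0)
d_model, bottleneck_dim = 6, 2  # toy sizes, for illustration only

x = [rng.uniform(-1, 1) for _ in range(d_model)]
down = [[rng.gauss(0, 0.1) for _ in range(d_model)] for _ in range(bottleneck_dim)]
up_zero = [[0.0] * bottleneck_dim for _ in range(d_model)]  # zero-initialized up-projection

# With up == 0 the adapter term vanishes, so the block is an exact identity.
identity_at_init = bottleneck_residual(x, down, up_zero) == x
print(identity_at_init)  # True
```

Once training moves `up` away from zero, the second term becomes a learned correction added to `x`.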
|
|
| **Key parameters:** |
| - `bottleneck_dim` -- hidden dimension of the bottleneck (default: 8) |
| - `adapt_attn` / `adapt_ffn` -- which sublayers to adapt (default: both) |
| - `layers` -- which transformer layers to adapt (default: all) |
| - `attn_layers` / `ffn_layers` -- per-sublayer layer selection overrides |
|
|
| **Param count:** `n_positions * n_layers * 2 * d_model * bottleneck_dim` where n_positions is 2 (attn + ffn) by default (e.g. `2 * 8 * 2 * 512 * 8 = 131K` at dim=8, both positions, all 8 layers). |
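
The formula can be checked with a few lines of arithmetic. The function name below is hypothetical, and the count covers the down/up weight matrices only, as the formula does.

```python
def bottleneck_param_count(n_layers, d_model, bottleneck_dim,
                           adapt_attn=True, adapt_ffn=True):
    """n_positions * n_layers * 2 * d_model * bottleneck_dim (weights only)."""
    n_positions = int(adapt_attn) + int(adapt_ffn)
    return n_positions * n_layers * 2 * d_model * bottleneck_dim


# 8 layers, d_model=512, both positions adapted
for dim in (8, 32, 64):
    print(dim, bottleneck_param_count(n_layers=8, d_model=512, bottleneck_dim=dim))
# 8 131072
# 32 524288
# 64 1048576
```

The dim=32 and dim=64 values correspond to the 524K and 1.0M bottleneck rows in the Results Summary.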
| |
| In the results summarized below, Bottleneck is the best performer at low parameter budgets (up to roughly 1M trainable parameters). The GELU nonlinearity and full-rank projections provide the most expressive per-parameter adaptation. |
| |
| ## LoRA ([Hu et al., 2021](https://arxiv.org/abs/2106.09685)) |
| |
| **Module:** `pawn.adapters.lora.LoRACLM` |
| |
| Injects rank-r adapters into attention projections (and optionally FFN) in every transformer layer, following "LoRA: Low-Rank Adaptation of Large Language Models" (ICLR 2022): |
| |
| ``` |
| output = frozen_linear(x) + (x @ A^T) @ B^T * (alpha / rank) |
| ``` |
| |
| B is zero-initialized for identity start. A is Kaiming-initialized. LoRA modifies the linear projections in-place (replacing `nn.Linear` with `LoRALinear`), so the backbone's own forward pass automatically includes the LoRA contribution. |
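
As a rough illustration of the formula above, the plain-Python sketch below applies a LoRA-augmented linear map to one input vector. The function names, toy shapes, and the Gaussian stand-in for Kaiming initialization are assumptions made for the example; the real `LoRALinear` may differ in detail.

```python
import random


def matvec(w, x):
    """Multiply a (rows x cols) nested-list matrix by a length-cols vector."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def lora_linear(x, w_frozen, a, b, alpha, rank):
    """frozen_linear(x) + (x @ A^T) @ B^T * (alpha / rank) for one input vector."""
    scale = alpha / rank
    frozen_out = matvec(w_frozen, x)
    update = matvec(b, matvec(a, x))  # A is (rank x d_in), B is (d_out x rank)
    return [f + scale * u for f, u in zip(frozen_out, update)]


rng = random.Random(0)
d_in, d_out, rank = 4, 3, 2
alpha = float(rank)  # documented default: alpha equals rank, so the scale is 1.0

w = [[rng.gauss(0, 1) for _ in range(d_in)] for _ in range(d_out)]   # frozen weight
a = [[rng.gauss(0, 1) for _ in range(d_in)] for _ in range(rank)]    # stand-in for Kaiming init
b_zero = [[0.0] * rank for _ in range(d_out)]                        # zero-initialized B
x = [rng.uniform(-1, 1) for _ in range(d_in)]

# With B == 0 the low-rank term is zero, so the layer matches the frozen one.
starts_as_identity = lora_linear(x, w, a, b_zero, alpha, rank) == matvec(w, x)
print(starts_as_identity)  # True
```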
| |
| **Key parameters:** |
| - `rank` -- rank of the low-rank matrices (default: 4) |
| - `alpha` -- scaling factor (default: same as rank) |
| - `attn_targets` -- which projections: `"qkvo"`, `"qv"`, or `"qkv"` (default: `"qkvo"`) |
| - `adapt_ffn` -- also adapt FFN projections (w_gate, w_up, w_down) |
| - `layers` -- which transformer layers to adapt (default: all) |
| |
| **Param count:** `n_layers * n_targets * 2 * d_model * rank` (e.g. 131K at rank=4, qkvo, all 8 layers). |
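
A small helper makes the arithmetic explicit. The function name is hypothetical, and the sketch covers attention targets only (each a `d_model x d_model` projection), which is the case the formula describes.

```python
def lora_param_count(n_layers, d_model, rank, attn_targets="qkvo"):
    """n_layers * n_targets * 2 * d_model * rank for square attention projections."""
    n_targets = len(attn_targets)  # one adapted projection per letter
    return n_layers * n_targets * 2 * d_model * rank


print(lora_param_count(n_layers=8, d_model=512, rank=4, attn_targets="qkvo"))  # 131072
print(lora_param_count(n_layers=8, d_model=512, rank=4, attn_targets="qv"))    # 65536
```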
| |
| ## FiLM ([Perez et al., 2017](https://arxiv.org/abs/1709.07871)) |
| |
| **Module:** `pawn.adapters.film.FiLMCLM` |
| |
| Applies learned per-channel affine transforms after each transformer block and optionally on the output logits, following "FiLM: Visual Reasoning with a General Conditioning Layer" (AAAI 2018): |
| |
| ``` |
| h_adapted = gamma * h + beta (hidden layers, dim = d_model) |
| logits_adapted = gamma * logits + beta (output, dim = vocab_size) |
| ``` |
| |
| Identity-initialized: gamma=1, beta=0. |
| |
| **Key parameters:** |
| - `use_output_film` -- apply FiLM to output logits as well (default: True) |
| |
| **Param count:** `n_layers * 2 * d_model + 2 * vocab_size` = ~17K. The lightest adapter by far -- only diagonal (per-channel) modulation with no cross-channel mixing. |
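
The ~17K figure follows from the PAWN-Base dimensions quoted elsewhere in this document (8 layers, `d_model=512`, `V=4278`). The function name below is a hypothetical helper used only for this check.

```python
def film_param_count(n_layers, d_model, vocab_size, use_output_film=True):
    """One (gamma, beta) pair per channel per layer, plus an optional pair per logit."""
    hidden = n_layers * 2 * d_model
    output = 2 * vocab_size if use_output_film else 0
    return hidden + output


print(film_param_count(n_layers=8, d_model=512, vocab_size=4278))  # 16748
print(film_param_count(8, 512, 4278, use_output_film=False))       # 8192
```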
| |
| ## RoSA ([Nikdan et al., 2024](https://arxiv.org/abs/2401.04679)) |
| |
| **Module:** `pawn.adapters.rosa.RoSACLM` |
| |
| Implements Robust Sparse Adaptation from "RoSA: Accurate Parameter-Efficient Fine-Tuning via Robust Adaptation". Combines a low-rank adapter (LoRA) with a gradient-informed sparse adapter on each frozen projection matrix: |
| |
| ``` |
| output = frozen(x) + (x @ A^T) @ B^T * scaling + F.linear(x, delta * mask) |
| ``` |
| |
| Unlike the random masks used by the Sparse adapter, RoSA selects its sparse mask positions based on gradient information from a LoRA warm-up phase. Training proceeds in three phases: |
| |
| 1. **LoRA warm-up** -- train LoRA-only for `warmup_steps` steps to build gradient signal |
| 2. **Mask generation** -- accumulate squared gradient magnitudes over a small data subset, select top-k positions per weight matrix at the target density (Algorithm 1 from the paper) |
| 3. **Joint training** -- train both LoRA and sparse adapters simultaneously |
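
The mask-generation phase can be sketched in plain Python as follows. This is a simplified illustration under stated assumptions, not the repository's code or a faithful reproduction of the paper's Algorithm 1: gradients are given as flat lists for a single weight matrix, the function names are hypothetical, and the `max(1, ...)` floor on `k` is an arbitrary choice for the example.

```python
import heapq


def accumulate_scores(grad_batches, grad_alpha=2):
    """Sum |g| ** grad_alpha per weight position across batches."""
    scores = [0.0] * len(grad_batches[0])
    for grads in grad_batches:
        for i, g in enumerate(grads):
            scores[i] += abs(g) ** grad_alpha
    return scores


def topk_mask(scores, density):
    """Binary mask keeping the top-k scoring positions, k = density * size."""
    k = max(1, round(density * len(scores)))
    keep = set(heapq.nlargest(k, range(len(scores)), key=scores.__getitem__))
    return [1 if i in keep else 0 for i in range(len(scores))]


# Flattened gradients of one toy weight matrix over three batches
grad_batches = [
    [0.1, -2.0, 0.0, 0.3, 0.05, 1.5, -0.2, 0.0],
    [0.2, -1.0, 0.1, 0.1, 0.00, 2.5, -0.1, 0.0],
    [0.0, -3.0, 0.0, 0.2, 0.10, 0.5, -0.3, 0.1],
]
scores = accumulate_scores(grad_batches, grad_alpha=2)
mask = topk_mask(scores, density=0.25)
print(mask)  # [0, 1, 0, 0, 0, 1, 0, 0]
```

Positions 1 and 5 carry the largest accumulated squared gradients, so they are the two entries selected at density 0.25.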
| |
| The training script (`scripts/train_rosa.py`) also supports two **retrospective ablation** modes via `--mode`: |
| |
| - **`retro-sparse`** -- use LoRA purely as a probe for mask selection, then discard it and train sparse-only on a fresh backbone with the found masks |
| - **`retro-bottleneck`** -- same as retro-sparse, but adds bottleneck adapters (`RetroBottleneckCLM`) after each sublayer for nonlinearity that sparse-only cannot express |
| |
| In retrospective modes, warm-up LoRA weights are saved as a checkpoint for analysis before the backbone is reloaded. |
| |
| **Key parameters:** |
| - `rank` -- LoRA rank during warm-up and joint training (default: 4) |
| - `density` -- target sparse mask density (default: 0.01) |
| - `attn_targets` -- which attention projections: `"qkvo"`, `"qv"`, or `"qkv"` (default: `"qkvo"`) |
| - `adapt_ffn` -- also adapt FFN projections |
| - `warmup_steps` -- LoRA-only steps before mask generation (default: 128) |
| - `mask_samples` -- batches for gradient accumulation (default: 32) |
| - `grad_alpha` -- gradient exponent: 1=mean magnitude, 2=Fisher diagonal (default: 2) |
| - `bottleneck_dim` -- bottleneck dimension for retro-bottleneck mode (default: 8) |
| |
| **Param count:** Depends on mode and configuration. In `rosa` mode: `n_lora_params + density * total_weight_elements`. In `retro-sparse`: `density * total_weight_elements`. In `retro-bottleneck`: sparse params + `n_positions * n_layers * 2 * d_model * bottleneck_dim` (the same formula as the Bottleneck adapter, with n_positions = 2 for attn + ffn). |
| |
| ## Sparse |
| |
| **Module:** `pawn.adapters.sparse.SparseCLM` |
| |
| Perturbs a random subset of frozen weight elements (related to sparse fine-tuning ideas from the [lottery ticket hypothesis](https://arxiv.org/abs/1803.03635); Frankle & Carbin, 2018). A fixed binary mask selects which weight positions get a trainable additive delta (zero-initialized): |
| |
| ``` |
| W_effective = W_frozen + delta * mask |
| ``` |
| |
| The mask is generated once at initialization from a fixed seed and never changes. Only the delta values at mask-selected positions are learned. The full delta tensor is still stored, but entries outside the mask stay at zero and receive no gradient because the mask zeroes their contribution. |
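
A minimal sketch of a fixed, seeded mask and the resulting effective weight is shown below. The function names are hypothetical, and selecting exactly `round(density * size)` positions is an assumption for the example; `SparseCLM` may sample its mask differently.

```python
import random


def make_mask(n_rows, n_cols, density, seed=42):
    """Fixed binary mask with round(density * n_rows * n_cols) ones."""
    total = n_rows * n_cols
    k = round(density * total)
    chosen = set(random.Random(seed).sample(range(total), k))
    return [[1 if r * n_cols + c in chosen else 0 for c in range(n_cols)]
            for r in range(n_rows)]


def effective_weight(w_frozen, delta, mask):
    """W_effective = W_frozen + delta * mask, elementwise."""
    return [[w + d * m for w, d, m in zip(w_row, d_row, m_row)]
            for w_row, d_row, m_row in zip(w_frozen, delta, mask)]


mask = make_mask(4, 5, density=0.2)
n_active = sum(map(sum, mask))                               # 4 of 20 positions
same_seed_same_mask = mask == make_mask(4, 5, density=0.2)   # reproducible

w_frozen = [[1.0] * 5 for _ in range(4)]
delta = [[0.5] * 5 for _ in range(4)]  # pretend deltas, nonzero everywhere
w_eff = effective_weight(w_frozen, delta, mask)
print(n_active, same_seed_same_mask)  # 4 True
```

Even with a dense `delta`, only the mask-selected entries of `w_eff` differ from `w_frozen`.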
| |
| **Key parameters:** |
| - `density` -- fraction of weight elements selected by the mask, i.e. made trainable (default: 0.01) |
| - `attn_targets` -- which attention projections (default: qkvo) |
| - `adapt_ffn` -- also adapt FFN projections |
| - `layers` -- which layers (default: all) |
| - `seed` -- RNG seed for reproducible mask generation (default: 42) |
| |
| **Param count:** `density * total_weight_elements` in targeted projections. E.g. density=0.031 on qkvo gives ~65K active params; density=0.081 on qkvo+FFN gives ~2.7M. |
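
The count can be reproduced with simple arithmetic once the targeted projection shapes are known. The sketch below assumes all 8 layers are targeted, square `d_model x d_model` attention projections, and an FFN hidden width of `d_ff=2048`. That width is not stated in this document; it is an assumption chosen because it is consistent with the reported qkvo+FFN totals. The function name is hypothetical.

```python
def targeted_weight_elements(n_layers, d_model, d_ff,
                             attn_targets="qkvo", adapt_ffn=False):
    """Total elements in the targeted projection matrices."""
    per_layer = len(attn_targets) * d_model * d_model
    if adapt_ffn:
        per_layer += 3 * d_model * d_ff  # w_gate, w_up, w_down
    return n_layers * per_layer


# d_ff=2048 is an assumed value, not taken from the document
total = targeted_weight_elements(8, 512, 2048, adapt_ffn=True)
print(total)                  # 33554432
print(round(0.081 * total))   # ~2.7M active params
print(round(0.015 * total))   # ~503K active params
```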
| |
| Excels at high parameter budgets (above ~1M trainable parameters in the results below), where many small perturbations to existing weights can outperform structured adapters. |
| |
| ## Hybrid ([LoRA](https://arxiv.org/abs/2106.09685) + [FiLM](https://arxiv.org/abs/1709.07871)) |
| |
| **Module:** `pawn.adapters.hybrid.HybridCLM` |
| |
| Combines LoRA and FiLM on a single frozen backbone. LoRA modifies attention projections within layers (cross-channel mixing); FiLM rescales the residual stream between layers (diagonal modulation). Both are identity-initialized. |
| |
| **Key parameters:** Union of LoRA and FiLM parameters, plus: |
| - `lora_layers` / `film_layers` -- independent layer selection for each method |
| - `use_output_film` -- FiLM on logits (default: False) |
| |
| Supports separate learning rates for LoRA and FiLM parameters via the training script (`--lora-lr`, `--film-lr`). |
| |
| --- |
| |
| ## Design Patterns |
| |
| ### Identity initialization |
| |
| All adapters initialize to the identity function. Bottleneck and LoRA zero-initialize their output projections (the up-projection and the B matrix, respectively). FiLM initializes gamma=1, beta=0. Sparse initializes all deltas to zero. This guarantees that the adapted model produces outputs identical to those of the frozen backbone before any training. |
| |
| ### Sparse logit projection |
| |
| All adapter wrappers expose `forward_hidden()` (returns `(B, T, d_model)`) and `project_head()` (applies `lm_head`) as separate methods. Training scripts use this split to avoid materializing the full `(B, T, V)` logit tensor: only positions included in the loss mask are projected through `lm_head`. This reduces peak memory significantly since `V=4278` and most positions in a batch are padding. |
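
The idea can be illustrated with nested lists in plain Python: select the hidden vectors at loss-mask positions first, then project only those. The function name `project_selected` and the toy sizes are assumptions for the example; in a PyTorch training script the selection step would typically be a boolean-mask index on the hidden tensor followed by a call to `project_head()`.

```python
def project_selected(hidden, loss_mask, lm_head):
    """Project only loss-mask positions through the output head.

    hidden:    B x T x d_model nested lists
    loss_mask: B x T booleans (True = position contributes to the loss)
    lm_head:   V x d_model weight matrix
    Returns an N x V list of logits, where N is the number of True entries.
    """
    selected = [
        vec
        for seq, mask_row in zip(hidden, loss_mask)
        for vec, keep in zip(seq, mask_row)
        if keep
    ]
    return [[sum(w * h for w, h in zip(row, vec)) for row in lm_head]
            for vec in selected]


B, T, d_model, V = 2, 4, 3, 5  # toy sizes; the real vocabulary is V=4278
hidden = [[[float(b + t + k) for k in range(d_model)] for t in range(T)]
          for b in range(B)]
lm_head = [[0.1 * (v + 1)] * d_model for v in range(V)]
loss_mask = [
    [True, True, False, False],   # last two positions are padding
    [True, False, False, False],  # mostly padding
]

logits = project_selected(hidden, loss_mask, lm_head)
dense_elements = B * T * V          # size of the full (B, T, V) tensor
sparse_elements = len(logits) * V   # what is actually materialized
print(dense_elements, sparse_elements)  # 40 15
```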
| |
| ### Legal masking via Rust engine |
| |
| Evaluation uses `LegalMaskBuilder` to replay games through the Rust chess engine and produce per-position legal move masks. These are scattered into a pre-allocated GPU buffer as sparse indices, avoiding dense `(B, T, V)` boolean masks. |
| |
| ### DataLoader worker safety |
| |
| The chess engine uses rayon for internal parallelism. Python's default `fork` multiprocessing can deadlock if forked after rayon initializes its thread pool. All DataLoader usage must specify `multiprocessing_context='spawn'`. The `LegalMaskCollate` callable moves Rust replay work into spawned workers, and `LichessDataset.share_memory()` avoids per-worker copies of the dataset. |
| |
| ### Parameter management |
| |
| Each adapter exposes three helpers: |
| - `adapter_state_dict()` / `lora_state_dict()` / etc. -- extract only trainable weights for saving |
| - `load_adapter_state_dict()` -- restore adapter weights into a freshly wrapped backbone |
| - `adapter_weight_report()` -- per-layer weight norms for monitoring training dynamics |
| |
| --- |
| |
| ## Results Summary |
| |
| All results on the PAWN-Base backbone (~35.8M params, 8 layers, d_model=512), trained via behavioral cloning on Lichess games with legal-move-masked cross-entropy loss. |
| |
| ### Method comparison (~65K params, 1000-1100 Elo, 100K games) |
| |
| | Method | Params | Val top-1 | |
| |--------|--------|-----------| |
| | Bottleneck (dim=8) | 65K | 39.3% | |
| | Sparse (density=0.031) | 65K | 35.2% | |
| | Hybrid (LoRA+FiLM) | ~65K | 34.1% | |
| | FiLM | ~33K | 30.3% | |
| |
| At this ~65K budget on low-Elo data, Bottleneck leads the next-best method (Sparse) by 4.1 percentage points. |
| |
| ### Backbone leverage |
| |
| | Model | Params | Val top-1 | |
| |-------|--------|-----------| |
| | Standalone tiny transformer | 529K | 30.9% | |
| | Bottleneck on frozen PAWN | 524K | 42.2% | |
| |
| The frozen backbone provides ~11 percentage points of "free" accuracy (42.2% vs. 30.9%, a ~37% relative improvement) at a nearly identical trainable-parameter count. Adapters specialize existing representations rather than learning from scratch. |
| |
| ### Capacity scaling (1800-1900 Elo, 1M games) |
| |
| | Method | Params | Val top-1 | |
| |--------|--------|-----------| |
| | Sparse (density=0.081, qkvo+FFN) | 2.7M | 44.7% | |
| | Bottleneck (dim=64, all layers) | 1.0M | 43.5% | |
| | Bottleneck (dim=32, all layers) | 524K | 41.7% | |
| | Sparse (density=0.015, qkvo+FFN) | 503K | 40.2% | |
| |
| Below ~1M params, bottleneck is more parameter-efficient. Above ~1M, sparse catches up and overtakes -- likely because direct weight perturbation can modify more individual weight entries than a structured bottleneck at the same total parameter count. |
| |
| ### Data scaling (1000-1100 Elo, bottleneck dim=32) |
| |
| Going from 100K to 1M games (10x more data) raises val top-1 by 2.9pp (39.3% to 42.2%) and narrows the train/val gap from ~1pp to ~0.3pp. |
| |
| --- |
| |
| ## Quick Start |
| |
| All commands assume you are in the `pawn/` directory. `--checkpoint` points to a pretrained PAWN backbone and `--pgn` to a Lichess PGN file. |
| |
| ```bash |
| # Bottleneck (recommended default) |
| uv run python scripts/train_bottleneck.py \ |
| --checkpoint checkpoints/pawn.pt --pgn data/lichess_1800.pgn \ |
| --bottleneck-dim 32 --max-games 100000 --lr 3e-4 |
| |
| # LoRA |
| uv run python scripts/train_lora.py \ |
| --checkpoint checkpoints/pawn.pt --pgn data/lichess_1800.pgn \ |
| --lora-rank 8 --lora-targets qkvo --lr 3e-4 |
| |
| # FiLM |
| uv run python scripts/train_film.py \ |
| --checkpoint checkpoints/pawn.pt --pgn data/lichess_1800.pgn \ |
| --lr 1e-3 |
| |
| # Sparse |
| uv run python scripts/train_sparse.py \ |
| --checkpoint checkpoints/pawn.pt --pgn data/lichess_1800.pgn \ |
| --density 0.015 --sparse-ffn --lr 3e-4 |
| |
| # Hybrid (LoRA + FiLM) |
| uv run python scripts/train_hybrid.py \ |
| --checkpoint checkpoints/pawn.pt --pgn data/lichess_1800.pgn \ |
| --lora-rank 4 --lora-lr 3e-4 --film-lr 1e-3 |
| |
| # RoSA (standard: joint LoRA + gradient-informed sparse) |
| uv run python scripts/train_rosa.py \ |
| --checkpoint checkpoints/pawn.pt --pgn data/lichess_1800.pgn \ |
| --mode rosa --density 0.01 --lora-rank 4 --warmup-steps 128 --lr 3e-4 \ |
| --local-checkpoints |
| |
| # RoSA (retrospective sparse + bottleneck) |
| uv run python scripts/train_rosa.py \ |
| --checkpoint checkpoints/pawn.pt --pgn data/lichess_1800.pgn \ |
| --mode retro-bottleneck --density 0.01 --bottleneck-dim 8 --lr 3e-4 \ |
| --local-checkpoints |
| ``` |
| |
| All scripts share common flags: `--epochs`, `--batch-size`, `--patience` (early stopping), `--no-compile` (required on ROCm), `--device`, `--num-workers`, `--resume` (checkpoint path for resuming). |
|
|
| Run `--help` on any script for the full argument list. |
|
|