# Neuroplastic Transformer
A small language model that grew its own architecture during training.
Instead of picking a fixed number of layers and heads upfront, this model started as a tiny 2-layer transformer and figured out what it needed by watching its own gradients. A plasticity controller checks every 500 steps which attention heads are pulling their weight and which layers are doing the most work, then splits, merges, or prunes accordingly.
All modifications are function-preserving. When a layer gets duplicated, the new copy starts with zero residual weight so it doesn't change the output. It has to earn its place through gradient descent.
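The function-preserving split can be sketched in a few lines of PyTorch. This is a minimal reconstruction, not the repo's actual code: `ResidualLayer` and `split_layer` are hypothetical names standing in for the real classes.

```python
import copy
import torch
import torch.nn as nn

class ResidualLayer(nn.Module):
    """Toy residual layer with a learnable scale alpha (names hypothetical)."""
    def __init__(self, d_model=128):
        super().__init__()
        self.ff = nn.Linear(d_model, d_model)
        self.alpha = nn.Parameter(torch.ones(1))

    def forward(self, x):
        return x + self.alpha * self.ff(x)

def split_layer(layers, idx):
    """Duplicate layers[idx]; the copy starts at alpha=0, so the stack's
    output is unchanged and the new layer must earn weight via gradients."""
    new_layer = copy.deepcopy(layers[idx])
    with torch.no_grad():
        new_layer.alpha.zero_()
    layers.insert(idx + 1, new_layer)

def run(layers, x):
    for layer in layers:
        x = layer(x)
    return x

layers = nn.ModuleList(ResidualLayer() for _ in range(2))
x = torch.randn(1, 4, 128)
before = run(layers, x)
split_layer(layers, 0)
after = run(layers, x)
assert torch.allclose(before, after)  # the split preserved the function
```

The key property: at the instant of the split the new layer contributes exactly zero, so training loss doesn't jump.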
## What came out of it
| Metric | Start | End |
|---|---|---|
| Layers | 2 | 19 |
| Parameters | 6.9M | 10.4M |
| Total heads | 4 | 42 |
| d_model | 128 | 128 |
| Loss | 10.69 | 4.17 |
The model made 236 structural changes over 30k training steps. Most were layer splits followed by prunes of layers that didn't develop useful weights. The final architecture is non-uniform: layers 9-12 have 3 attention heads, everything else has 2. Nobody told it to do that.
Several layers learned negative alpha (residual scaling) values, meaning the model subtracts their contribution instead of adding it. Layer 17 has the highest alpha at 1.50.
## Training details
- Dataset: FineWeb-Edu sample-10BT (educational web text)
- Tokenizer: GPT-2 BPE (50,257 vocab)
- Hardware: single A100 40GB (GCP a2-highgpu-1g, europe-west4)
- Wall time: about 2 hours 16 minutes
- Optimizer: AdamW, peak lr 6e-4, cosine schedule, 300 step warmup
- Batch: 64 sequences x 512 tokens x 2 gradient accumulation steps ≈ 65k tokens/step
- Precision: float16 autocast
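The stated optimizer settings (AdamW, peak lr 6e-4, cosine schedule, 300-step warmup) can be sketched in PyTorch like so. This is an assumed reconstruction of the schedule, not the repo's exact training loop:

```python
import math
import torch

model = torch.nn.Linear(128, 128)  # stand-in for the transformer
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4)

warmup_steps, total_steps = 300, 30_000

def lr_lambda(step):
    # Linear warmup to the peak lr, then cosine decay toward zero.
    if step < warmup_steps:
        return step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
```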
## How to use

```python
import torch
from huggingface_hub import hf_hub_download

# grab the checkpoint
ckpt_path = hf_hub_download("theabmehta/neuroplastic-transformer", "step_30000.pt")
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

print(ckpt["config"])                         # architecture config
print(list(ckpt["model_state"].keys())[:10])  # weight keys
```
For full inference with text generation, clone the repo:
```shell
git clone https://github.com/TheAbMehta/neuroplastic-transformer
cd neuroplastic-transformer
pip install -r requirements.txt
python chat.py --checkpoint /path/to/step_30000.pt
```
Then just type prompts. It works best with expository/educational text since that's what it trained on.
```text
>>> Photosynthesis is the process by which
it is carried out in the atmosphere. This can be done by: The first
step is to get a clear picture of the sun...
```
It's a 10M-parameter model, so don't expect GPT-4, but the output is grammatical and stays on topic.
## Checkpoint contents
The `.pt` file contains:

- `model_state`: the weights (load with the `NeuroplasticTransformer` class from the repo)
- `optimizer_state`: AdamW state, in case you want to resume training
- `config`: all hyperparameters as a dict
- `arch_history`: architecture snapshots at each structural change
- `plasticity_history`: a log of every split/merge/prune event
## Architecture
The key difference from a normal transformer: each attention head is its own independent module with separate Q/K/V/O projections. This makes it possible to add or remove a single head without touching shared weight matrices. Each layer also has a learnable scalar alpha that scales its residual contribution.
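A minimal sketch of that per-head design (class names hypothetical, causal masking and dropout omitted for brevity):

```python
import torch
import torch.nn as nn

class IndependentHead(nn.Module):
    """One attention head with its own Q/K/V/O projections."""
    def __init__(self, d_model=128, d_head=32):
        super().__init__()
        self.q = nn.Linear(d_model, d_head, bias=False)
        self.k = nn.Linear(d_model, d_head, bias=False)
        self.v = nn.Linear(d_model, d_head, bias=False)
        self.o = nn.Linear(d_head, d_model, bias=False)

    def forward(self, x):
        q, k, v = self.q(x), self.k(x), self.v(x)
        att = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
        return self.o(att @ v)

class PlasticAttention(nn.Module):
    """Head outputs are summed, so a head can be added or removed
    without touching any shared weight matrix."""
    def __init__(self, n_heads=2, d_model=128):
        super().__init__()
        self.heads = nn.ModuleList(IndependentHead(d_model) for _ in range(n_heads))

    def forward(self, x):
        return sum(h(x) for h in self.heads)
```

Because each head's output enters the sum independently, adding or pruning a head is a one-line `ModuleList` edit rather than a surgery on fused projection matrices.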
The plasticity controller tracks:
- Per-head utility: `gradient_norm * output_norm` (how much each head matters)
- Per-layer complexity: mean gradient magnitude (how hard a layer is working)
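The utility signal can be read off after any backward pass. A hedged sketch with a stand-in linear layer for one head (the repo's exact normalization and averaging window may differ):

```python
import torch
import torch.nn as nn

head = nn.Linear(16, 16)         # stand-in for one head's projection
x = torch.randn(8, 16)
out = head(x)
out.pow(2).sum().backward()      # any scalar loss flowing through the head

# Per-head utility as described above: gradient norm times output norm.
# A head that barely affects the loss, or barely fires, scores low.
utility = head.weight.grad.norm().item() * out.norm().item()
```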
Based on those signals it can split heads (clone + halve output projections), merge similar heads (average Q/K/V, sum O), prune low-utility heads, split high-complexity layers (deep copy with alpha=0), or prune near-zero-alpha layers.
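The clone-and-halve head split can be demonstrated with a minimal stand-in head; only the output projection matters for the function-preserving argument (names hypothetical):

```python
import copy
import torch
import torch.nn as nn

class Head(nn.Module):
    """Stand-in head: a single output projection suffices to show the trick."""
    def __init__(self, d=16):
        super().__init__()
        self.o = nn.Linear(d, d, bias=False)

    def forward(self, x):
        return self.o(x)

def split_head(heads, idx):
    # Clone head idx, then halve the output projection on both copies:
    # the two halves now sum to exactly what the original produced.
    twin = copy.deepcopy(heads[idx])
    with torch.no_grad():
        heads[idx].o.weight.mul_(0.5)
        twin.o.weight.mul_(0.5)
    heads.append(twin)

heads = nn.ModuleList([Head(), Head()])
x = torch.randn(3, 16)
before = sum(h(x) for h in heads)
split_head(heads, 0)
after = sum(h(x) for h in heads)
assert torch.allclose(before, after, atol=1e-6)
```

After the split the two copies see identical gradients at first, but weight noise and the optimizer's state quickly let them specialize.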
## Paper

There's a full writeup in the GitHub repo under `paper/neuroplastic_transformer.pdf` covering the method, all the math, and an analysis of the training run.
## Citation

```bibtex
@misc{mehta2026neuroplastic,
  author = {Mehta, Abhisar},
  title  = {Neuroplastic Transformers: Dynamic Architecture Adaptation During Language Model Training},
  year   = {2026},
  url    = {https://github.com/TheAbMehta/neuroplastic-transformer}
}
```