Mo²BERTa-v2 — Frozen KV Context for Encoder MoR (Prototype)

Mo²BERTa-v2 extends Mo²BERTa with a single focused architectural addition: frozen KV context from exited tokens. In the base model, tokens that exit early under expert-choice routing vanish from the attention context of deeper tokens entirely, a context isolation problem inherent to sparse recursive attention. This variant preserves their Keys and Values, contributing them to all subsequent recursion steps at zero additional Q or MLP cost.

The result is a measurable improvement in both validation loss and routing discrimination, and the first controlled measurement of context isolation as a quantifiable bottleneck in encoder-side Mixture-of-Recursions.

Status: Research prototype / Proof-of-Concept. Not intended for production use.
Readers unfamiliar with the base architecture (M-Cycle Middle, expert-choice routing, GQA+RoPE, MLM objective) should start with Mo²BERTa.

Usage

  • For model, training, and inference code, see MoRBERT_v2.ipynb in the repo

Model Details

| Field | Value |
| --- | --- |
| Architecture | Encoder-only Transformer with MoR + Frozen KV accumulator |
| Base | Extends Mo²BERTa; see v1 card for full architecture detail |
| Unique parameters | ~9.69M (identical to Mo²BERTa; frozen KV adds no parameters) |
| Effective depth | Variable per token: 1–4 recursions over a shared 2-layer block (~6.56 flat-layer equivalent) |
| Training data | TinyStories-valid (small subset, PoC scale) |
| Tokenizer | bert-base-uncased |
| Compute cap | 600 TFLOPs |
| Hardware | NVIDIA GeForce 930M (Compute Capability 5.0, consumer laptop GPU) |
| License | MIT |

What Changed from Mo²BERTa

1. The Context Isolation Problem

In Mo²BERTa (FULL_SKIP phase), attention at recursion depth r operates only over tokens still active at depth r. A function word that exits at recursion 1 disappears from the KV context of a [MASK] token still processing at recursion 3. The unique final layer provides one round of global mixing after all recursions complete, but during recursive refinement hard tokens progressively lose context from easy tokens that exited early.

This was documented as a known limitation in the v1 model card, with frozen KV listed as future work.

2. The Fix: Frozen KV Accumulator

When a token exits at recursion step r, its final hidden state x_block (post-block, the richest representation it achieved) is projected through wk and wv of the last shared block, RoPE-encoded using its original global position, and appended to a frozen KV accumulator. All subsequent recursion steps concatenate this accumulator into their attention context:

Q:  [k_active]               ← only continuing tokens query
K:  [k_active + k_frozen]    ← active + all previously exited
V:  [k_active + k_frozen]

Exited tokens contribute context but pay no Q projection, no MLP, and no attention score computation over their own position. The dominant FLOP savings from sparse routing are preserved.
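The layout above can be illustrated with a dependency-free toy. This is a minimal single-head sketch in plain Python, not the PyTorch code from the notebook; the function name `attend` and all vectors are made up for illustration. Queries come only from active tokens, while keys and values are the active ones followed by the frozen ones.

```python
import math

def attend(q_active, k_active, v_active, k_frozen=None, v_frozen=None):
    """Single-head attention: queries from active tokens only,
    keys/values from active tokens plus (optionally) frozen ones."""
    keys = k_active + (k_frozen or [])
    values = v_active + (v_frozen or [])
    scale = 1.0 / math.sqrt(len(q_active[0]))
    dim = len(values[0])
    outputs = []
    for q in q_active:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
        top = max(scores)                       # subtract max for numerical stability
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        outputs.append([sum(w * v[d] for w, v in zip(weights, values))
                        for d in range(dim)])
    return outputs

# Two active tokens and one exited token; every number here is hypothetical.
q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 0.0], [0.0, 1.0]]
k_fz = [[1.0, 1.0]]
v_fz = [[5.0, 5.0]]

isolated = attend(q, k, v)                  # FULL_SKIP-style context (active only)
with_frozen = attend(q, k, v, k_fz, v_fz)   # active + frozen context
```

Both calls return one output row per active token, so the query-side cost is unchanged; only the key/value dimension grows.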

Key design decisions:

  • Freeze from x_block (post-block), not x (pre-block): x_block is the richest state the token achieved, consistent with what the router judged "done enough to exit"
  • Use shared_blocks[-1].attn.wk/wv, because the last shared block is the one that most recently processed the exiting token; with N_LAYERS=4 giving n_shared=2, the two shared blocks have independent weights
  • Freeze before the gated residual update modifies x, so the accumulator captures the raw block output
  • RoPE on frozen K uses original global positions, so frozen keys compose correctly with active Q positions at deeper steps without positional integrity issues
  • frozen_kv=None is the default throughout all forward passes, so the flat baselines (IsoParam, IsoDepth) are completely unaffected: zero overhead, no code path changes
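The point about original global positions can be checked numerically. The sketch below assumes the common interleaved-pair RoPE formulation with base 10000; the model's actual RoPE layout may differ, and the helper names `rope` and `dot` are hypothetical. It shows that a query at position 10 and a frozen key encoded at its original position 3 produce the same score as relative offset 7, while re-indexing the frozen key to its accumulator slot would change the score.

```python
import math

def rope(vec, pos, base=10000.0):
    """Rotate consecutive (even, odd) pairs of `vec` by position-dependent angles.
    Assumes the interleaved-pair RoPE convention."""
    dim = len(vec)
    out = []
    for i in range(0, dim, 2):
        angle = pos * base ** (-i / dim)
        c, s = math.cos(angle), math.sin(angle)
        x, y = vec[i], vec[i + 1]
        out += [x * c - y * s, x * s + y * c]
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Hypothetical query/key vectors with head_dim = 4.
q = [1.0, 0.5, -0.3, 0.8]
k = [0.2, -0.7, 0.4, 0.1]

# Frozen key keeps its original global position (3); the query is at position 10.
score_original = dot(rope(q, 10), rope(k, 3))
# Same relative offset (7) gives the same score: RoPE depends only on the offset.
score_relative = dot(rope(q, 7), rope(k, 0))
# If the frozen key were re-indexed to accumulator slot 0, the offset would be wrong.
score_compacted = dot(rope(q, 10), rope(k, 0))
```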

3. Attention Optimization: Padded Batched SDPA

Both the FULL_SKIP baseline and FrozenKV in this repo use a refactored _attn_skip that replaces the original Python-level for b in range(B) loop with a single padded batched F.scaled_dot_product_attention call over a [B, H, k_max, head_dim] tensor. Padding positions are masked via attn_bias set to -inf on the key dimension.
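The masking idea can be shown without PyTorch. The following is a plain-Python sketch of the softmax step only, under the assumption that padded key slots receive an additive bias of -inf; `masked_softmax` and the score values are illustrative, not taken from the repo.

```python
import math

NEG_INF = float("-inf")

def masked_softmax(scores, key_is_real):
    """Softmax over keys where padded keys get an additive bias of -inf."""
    biased = [s if real else NEG_INF for s, real in zip(scores, key_is_real)]
    top = max(biased)                        # finite while at least one key is real
    exps = [math.exp(b - top) for b in biased]   # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

# Ragged per-item key scores: item 0 has 3 active keys, item 1 has 2.
batch_scores = [[0.2, 1.5, -0.3], [0.9, 0.1]]
k_max = max(len(s) for s in batch_scores)

padded, masks = [], []
for s in batch_scores:
    n_pad = k_max - len(s)
    padded.append(s + [0.0] * n_pad)                 # pad scores up to k_max
    masks.append([True] * len(s) + [False] * n_pad)  # False marks a padding slot

weights = [masked_softmax(s, m) for s, m in zip(padded, masks)]
unpadded_item1 = masked_softmax(batch_scores[1], [True, True])
```

Padded slots end up with exactly zero weight, so each item attends over its real keys as if it had been processed alone.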

This is compatible with Compute Capability 5.0 (no Triton or FlashAttention required) and produced a significant wall-clock improvement:

Original Mo²BERTa (Python loop, 600T run):  172.2 minutes  (~1.38s/step)
SDPA-optimized MoR-BERT (this repo):        123.4 minutes  (~1.44–1.70 step/s)
FrozenKV (SDPA + accumulator):              115.8 minutes  (~1.41–1.45 step/s)

The ~28% wall-clock reduction (172.2 to 123.4 minutes) comes purely from replacing B serial kernel launches with one parallel batched SDPA call. The frozen KV accumulator adds negligible overhead at this scale: the difference between 115.8 and 123.4 minutes is within run-to-run variance on consumer hardware.
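The relative reductions can be recomputed directly from the three wall-clock figures listed above. This is plain arithmetic; the helper name `reduction` is only illustrative.

```python
def reduction(before_min, after_min):
    """Fractional wall-clock reduction going from `before_min` to `after_min`."""
    return (before_min - after_min) / before_min

loop_min, sdpa_min, frozen_min = 172.2, 123.4, 115.8

sdpa_vs_loop = reduction(loop_min, sdpa_min)       # Python loop -> padded SDPA
frozen_vs_sdpa = reduction(sdpa_min, frozen_min)   # SDPA -> SDPA + accumulator

print(f"SDPA vs loop:     {sdpa_vs_loop:.1%}")
print(f"FrozenKV vs SDPA: {frozen_vs_sdpa:.1%}")
```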

Evaluation Results (600 TFLOP Final)

The name Mo²BERTa was chosen after training, so in this report (especially in plots) the model is still referred to as MoR-BERT.

Four models were trained to the same 600T cumulative FLOP budget on the same dataset and hardware. Validation ran on 20 batches of held-out TinyStories-valid every 50T. Best-checkpoint metrics are reported because all models exhibit some late-stage oscillation under constant LR.

Note: MoR-BERT here refers to the SDPA-optimized FULL_SKIP variant, not the original Python-loop Mo²BERTa from v1. The SDPA refactor does not change training behavior, only wall-clock speed.

Best Checkpoint Metrics

| Metric | FrozenKV | MoR-BERT (SDPA) | IsoParam (L4) | IsoDepth (L7) |
| --- | --- | --- | --- | --- |
| Best Val Loss | 1.8023 | 1.8427 | 1.9266 | 1.9549 |
| Best-loss checkpoint | 550T | 600T | 600T | 600T |
| Best Val Acc | 66.21% | 66.08% | 63.55% | 65.08% |
| Best-accuracy checkpoint | 600T | 550T | 550/600T | 550T |
| Unique parameters | 9.69M | 9.69M | 9.69M | 11.07M |
| Tokens seen @ 600T | 10.02M | 10.02M | 10.32M | 9.04M |
| Wall-clock (600T) | 115.8 min | 123.4 min | 63.7 min | 63.7 min |

FrozenKV has the best loss and accuracy of the four models. That includes IsoDepth, which has 14% more unique parameters but saw fewer tokens within the same FLOP budget (9.04M vs 10.02M).

Full Validation Trace (every 50T)

| FLOPs | FrozenKV loss | MoR loss | IsoParam loss | IsoDepth loss |
| --- | --- | --- | --- | --- |
| 50T | 3.4132 | 3.4944 | 3.7933 | 3.5955 |
| 100T | 2.6808 | 2.8536 | 2.8236 | 2.8789 |
| 150T | 2.4690 | 2.6806 | 2.6511 | 2.5675 |
| 200T | 2.3561 | 2.5125 | 2.4010 | 2.4736 |
| 250T | 2.2963 | 2.4132 | 2.3246 | 2.2268 |
| 300T | 2.1466 | 2.2928 | 2.1173 | 2.1801 |
| 350T | 2.0938 | 2.0468 | 2.0527 | 2.0693 |
| 400T | 1.9806 | 2.1447 | 2.0729 | 2.0864 |
| 450T | 1.9226 | 1.9144 | 2.0412 | 2.0717 |
| 500T | 1.9463 | 1.9443 | 2.0049 | 1.9613 |
| 550T | 1.8023 | 1.8523 | 2.0566 | 1.9630 |
| 600T | 1.8089 | 1.8427 | 1.9266 | 1.9549 |
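The best checkpoint per model can be read off the trace mechanically. The snippet below transcribes the trace values and picks the minimum loss for each model; the dictionary keys are just labels for this example.

```python
# Validation loss every 50T, transcribed from the trace above.
trace = {
    "FrozenKV": [3.4132, 2.6808, 2.4690, 2.3561, 2.2963, 2.1466,
                 2.0938, 1.9806, 1.9226, 1.9463, 1.8023, 1.8089],
    "MoR-BERT": [3.4944, 2.8536, 2.6806, 2.5125, 2.4132, 2.2928,
                 2.0468, 2.1447, 1.9144, 1.9443, 1.8523, 1.8427],
    "IsoParam": [3.7933, 2.8236, 2.6511, 2.4010, 2.3246, 2.1173,
                 2.0527, 2.0729, 2.0412, 2.0049, 2.0566, 1.9266],
    "IsoDepth": [3.5955, 2.8789, 2.5675, 2.4736, 2.2268, 2.1801,
                 2.0693, 2.0864, 2.0717, 1.9613, 1.9630, 1.9549],
}
budgets = [f"{t}T" for t in range(50, 650, 50)]   # "50T" ... "600T"

# min() over (loss, budget) pairs selects the lowest-loss checkpoint.
best = {name: min(zip(losses, budgets)) for name, losses in trace.items()}

for name, (loss, budget) in best.items():
    print(f"{name:9s} best val loss {loss:.4f} at {budget}")
```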

Training Regime Analysis

image

| Regime | Leader | Notes |
| --- | --- | --- |
| 0–150T | FrozenKV | Unambiguous lead. 50T: 3.41 vs 3.49–3.79. 100T: 2.68 vs 2.82–2.88. |
| 150–200T | FrozenKV (narrowing) | Lead shrinks; at 200T still ahead of IsoParam by 0.045. |
| 200–300T | Contested / IsoParam, IsoDepth | FrozenKV loses the lead: IsoDepth wins 250T, IsoParam wins 300T. |
| 300–400T | MoR-BERT | MoR's only clear win: 350T at 2.0468 (best loss so far). The 400T spike to 2.14 is volatility, not competition. |
| 400–500T | Contested / MoR | MoR wins 450T (1.9144) and 500T (1.9443). FrozenKV is a close second. |
| 500–600T | FrozenKV | Dominant from 550T: 1.80 (FrozenKV) vs 1.85 (MoR) vs 2.06 (IsoParam) vs 1.96 (IsoDepth). |

Late-stage stability is a secondary finding. MoR-BERT (SDPA) shows oscillation in the 500–600T regime. FrozenKV is measurably more stable; the richer attention context at each recursion step appears to act as a mild regularizer, reducing sensitivity to the constant LR.

Routing Behavior Analysis

Note on the uniform exit distribution (25%/25%/25%/25%): this is identical to Mo²BERTa v1 by architectural design. The capacity schedule hardcodes how many tokens exit at each step; what the router learns is which tokens exit. The depth gap and heatmap panels show this.
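A small sketch makes the distinction between "how many" and "which" concrete. It assumes keep fractions of 0.75, 0.5, and 0.25 of the sequence (inferred from the uniform 25% exit distribution, not read from the code) and uses made-up router scores; `route` is a hypothetical helper, not the repo's router.

```python
from collections import Counter

def route(scores_per_step, keep_fractions):
    """Exit depth of every token under expert-choice routing.

    scores_per_step[r][t] is the router score of token t at step r+1;
    keep_fractions[r] is the fraction of the sequence allowed to continue.
    """
    n_tokens = len(scores_per_step[0])
    active = list(range(n_tokens))
    depth = [0] * n_tokens
    for r, keep in enumerate(keep_fractions, start=1):
        k = round(keep * n_tokens)                    # capacity is fixed by the schedule
        ranked = sorted(active, key=lambda t: scores_per_step[r - 1][t], reverse=True)
        survivors = ranked[:k]                        # the router only decides *which* tokens
        for t in active:
            if t not in survivors:
                depth[t] = r                          # token exits at this step
        active = survivors
    for t in active:
        depth[t] = len(keep_fractions) + 1            # tokens that reach the final recursion
    return depth

keep = [0.75, 0.5, 0.25]                              # assumed schedule
scores_a = [[0.9, 0.1, 0.8, 0.2, 0.7, 0.3, 0.6, 0.4]] * 3
scores_b = [[0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6]] * 3

depth_a = route(scores_a, keep)
depth_b = route(scores_b, keep)
```

The two score sets send different tokens deep, yet the exit histogram is the same in both cases, which is why the 25/25/25/25 split carries no information about what the router learned.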

The Core Finding: Frozen KV Sharpens Routing Discrimination

Routing analysis at end of 600T training (step 9787), FULL_SKIP vs FrozenKV on matched batches:

                      FULL_SKIP    FrozenKV      Δ
[MASK] mean depth:      3.34         3.79       +0.45
Non-[MASK] depth:       2.35         2.28       −0.07
Depth gap:              0.99         1.51       +0.52  (+53%)

n([MASK]):               152          151
n(non-[MASK]):           872          873

FrozenKV makes the router more decisive in both directions simultaneously. [MASK] tokens go deeper (3.79 vs 3.34) and non-[MASK] tokens exit earlier (2.28 vs 2.35). A possible mechanism: frozen context from early-exit tokens gives the router richer signal to make confident decisions. Easy tokens can exit sooner because their contribution is preserved in the KV accumulator. Hard tokens get pushed deeper because the model now has denser context to refine them against.

This is the opposite of the hypothesis that frozen context would let [MASK] tokens exit earlier because context compensates. Instead the router uses richer context to justify more refinement on hard tokens, not less.
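The depth-gap figures can be recomputed from the per-class means and counts reported in the routing analysis. As a consistency check, the count-weighted overall mean depth should sit near 2.5 for both models, since a uniform 25% exit at each of four depths fixes the average regardless of which tokens go deep. The dictionary layout below is only for illustration.

```python
full_skip = {"mask": 3.34, "other": 2.35, "n_mask": 152, "n_other": 872}
frozen_kv = {"mask": 3.79, "other": 2.28, "n_mask": 151, "n_other": 873}

def summarize(run):
    """Return (depth gap, count-weighted overall mean depth) for one run."""
    gap = run["mask"] - run["other"]
    n = run["n_mask"] + run["n_other"]
    overall = (run["mask"] * run["n_mask"] + run["other"] * run["n_other"]) / n
    return gap, overall

gap_fs, overall_fs = summarize(full_skip)
gap_fk, overall_fk = summarize(frozen_kv)
relative_increase = (gap_fk - gap_fs) / gap_fs

print(f"gap FULL_SKIP {gap_fs:.2f}, gap FrozenKV {gap_fk:.2f}, "
      f"increase {relative_increase:.0%}")
print(f"overall mean depth: {overall_fs:.3f} vs {overall_fk:.3f}")
```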

Score distributions confirm increased router confidence. FrozenKV Router 1 shows a sharper spike at 0.0 than FULL_SKIP, indicating more tokens being confidently routed out at the first step. Both models remain strongly bimodal throughout.

FULL_SKIP (Mo2BERT-proto)

image

FROZEN_KV (Mo2BERT-v2-proto)

image

At Inference (Post-Training, TinyStories Val Sample)

FrozenKV inference routing on a held-out sequence:

  • [MASK] tokens: mean depth 3.60 / 4 (n=15)
  • Non-[MASK] tokens: mean depth 2.35 / 4 (n=113)
  • Depth gap: 1.25, larger than Mo²BERTa v1's inference gap of 1.05

Score distributions remain strongly bimodal at inference, confirming the routing behavior is genuinely learned and not an artifact of the auxiliary loss alone.

image

MLM Prediction Quality (Same Inference Sample)

14 masked positions evaluated:

| Position | True token | Rank | Notes |
| --- | --- | --- | --- |
| 13 | ? | 1 | Correct, high margin (10.97 vs 9.67 next) |
| 30 | jelly | 1 | Correct, low-frequency content word |
| 71 | the | 1 | Correct, confident (12.92 vs 8.84) |
| 90 | " | 1 | Correct, very high confidence (16.79) |
| 93 | . | 1 | Correct, very high confidence (16.49) |
| 96 | at | 1 | Correct (11.15 vs 8.42) |
| 99 | and | 1 | Correct (12.87 vs 7.76) |
| 109 | she | 1 | Correct (10.71 vs 8.13) |
| 126 | feel | 1 | Correct (9.92 vs 6.47) |
| 28 | to | 2 | "and" ranked first; both valid |
| 29 | spread | miss | "be" at rank 1; rare verb in TinyStories (n=339) |
| 36 | run | miss | "it" at rank 1; insufficient disambiguation |
| 51 | slice | miss | "top" at rank 1; semantically adjacent |
| 79 | fingers | miss | "face" at rank 1; body-part confusion |

Overall: rank-1 accuracy 9/14 (64%), top-5 accuracy 10/14 (71%). Misses cluster on low-frequency content words and semantic near-synonyms, expected failure modes at this scale and domain.
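The two accuracy figures follow from the ranks in the table. In the sketch below, `None` stands for a "miss" row, where the exact rank of the true token is not reported; treating those as outside the top 5 is consistent with the stated 10/14.

```python
# Ranks of the true token for the 14 evaluated positions:
# nine at rank 1, one at rank 2, four misses (rank not reported).
ranks = [1] * 9 + [2] + [None] * 4

def top_k_accuracy(ranks, k):
    """Fraction of positions whose true token is ranked within the top k."""
    hits = sum(1 for r in ranks if r is not None and r <= k)
    return hits / len(ranks)

top1 = top_k_accuracy(ranks, 1)
top5 = top_k_accuracy(ranks, 5)
print(f"rank-1: {top1:.0%}, top-5: {top5:.0%}")
```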

What This PoC Does and Does Not Prove

Supported claims (new in v2):

  • Frozen KV from exited tokens measurably reduces the context isolation penalty in encoder-side expert-choice MoR: best val loss improves by 0.04 (1.8427 to 1.8023) at equal parameter count and FLOP budget.
  • FrozenKV beats a 14%-larger flat baseline (IsoDepth L7, 11.07M params) on both loss and accuracy at equal compute, having seen more tokens (10.02M vs 9.04M).
  • Frozen context sharpens routing discrimination: the [MASK] vs non-[MASK] depth gap increases from 0.99 to 1.51 (+53%) compared to FULL_SKIP at identical parameter count and FLOP budget.
  • FrozenKV improves late-stage training stability compared to FULL_SKIP under constant LR.
  • Padded batched SDPA (CC 5.0 compatible) reduces wall-clock time by ~28% (172.2 to 123.4 minutes) vs the original Python-loop gather-scatter implementation, with theoretically identical outputs.
  • FrozenKV overhead is negligible at PoC scale: wall-clock is within run-to-run variance of the SDPA-only baseline (115.8 vs 123.4 minutes).

Inherited supported claims from Mo²BERTa v1:

  • MoR transfers from autoregressive decoders to bidirectional encoders without fundamental barriers.
  • Expert-choice routing with BCE auxiliary loss produces well-calibrated bimodal routing within ~200 training steps.
  • The router learns to allocate depth by semantic difficulty using only the MLM signal, with no token-type supervision.

Not supported / out of scope:

  • Wall-clock inference speedup from sparse routing (requires custom CUDA kernels; SDPA is faster than the Python loop but theoretical active-token FLOP savings still don't materialize as throughput without kernel-level support)
  • Scaling behavior (one model size, one dataset, one compute budget)
  • Comparison to production encoders (BERT-base, RoBERTa, DeBERTa, etc.)
  • Generalization beyond TinyStories domain
  • Optimal hyperparameter configuration (constant LR, fixed α=0.1, no ablation over N_recursion or capacity schedule shape)
  • Whether the depth gap increase under FrozenKV translates to better downstream task performance

Known Limitations

All limitations from Mo²BERTa v1 apply. Additional v2-specific notes:

Frozen KV memory overhead. The accumulator grows across recursion steps, holding [B, Hk, k_frozen, head_dim] tensors that concatenate at each step. At the scales tested (B=8, T=128, head_dim=32) this is negligible. At larger sequence lengths or batch sizes the accumulated KV tensors become non-trivial and a retention policy (sliding window, top-k, or importance-weighted eviction) would be needed.
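A back-of-the-envelope estimate shows why the overhead is negligible at PoC scale and where it stops being so. The sketch assumes fp32 storage, Hk=2 KV heads (the card does not state Hk, so this is a placeholder), and a worst case of 75% of the sequence frozen by the last recursion step; the larger configuration is purely hypothetical.

```python
def frozen_kv_bytes(batch, kv_heads, k_frozen, head_dim, bytes_per_elem=4):
    """Bytes held by the K and V accumulators of shape [B, Hk, k_frozen, head_dim]."""
    return 2 * batch * kv_heads * k_frozen * head_dim * bytes_per_elem

# PoC scale: B=8, T=128, head_dim=32; Hk=2 and fp32 are assumptions.
T = 128
poc = frozen_kv_bytes(batch=8, kv_heads=2, k_frozen=int(0.75 * T), head_dim=32)

# Hypothetical larger setting: B=64, T=4096, same head layout.
T_big = 4096
big = frozen_kv_bytes(batch=64, kv_heads=2, k_frozen=int(0.75 * T_big), head_dim=32)

print(f"PoC scale:    {poc / 2**10:.0f} KiB")
print(f"Larger scale: {big / 2**20:.0f} MiB")
```

The accumulator grows linearly in both batch size and sequence length, which is what motivates a retention policy at larger scales.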

Frozen KV FLOP accounting. The estimator treats FrozenKV identically to FULL_SKIP, ignoring the small overhead of K/V projections for exiting tokens and the extended key dimension in SDPA. Accurate at PoC scale; would need revision for precise isoFLOP comparison at larger scale.

Step count vs wall-clock asymmetry. Flat baselines (IsoParam, IsoDepth) complete more steps than MoR variants within the same TFLOP budget (IsoParam: 10,083 steps vs MoR: 9,787) because each flat step is computationally cheaper. This is correct behavior: the FLOP cap is the equalizer, not step count.
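The step counts line up with the token counts in the metrics table if each step processes one batch of B×T tokens. The check below assumes B=8 and T=128 (the scales quoted under the memory-overhead note) and that every position in the batch counts as a token seen.

```python
TOKENS_PER_STEP = 8 * 128   # B * T, assumed constant for every step

def tokens_seen(steps):
    """Total tokens processed after `steps` optimizer steps."""
    return steps * TOKENS_PER_STEP

iso_param_tokens = tokens_seen(10_083)   # flat IsoParam baseline
mor_tokens = tokens_seen(9_787)          # MoR variants (FULL_SKIP, FrozenKV)

print(f"IsoParam: {iso_param_tokens / 1e6:.2f}M tokens")
print(f"MoR:      {mor_tokens / 1e6:.2f}M tokens")
```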

IsoDepth result diverges from v1. In the original Mo²BERTa run, IsoDepth eventually reached 67.19% accuracy. In this run it peaks at 65.08%. This is attributed to run-to-run variance under constant LR rather than any architectural change, but a controlled repeat would be needed to confirm.

Known Bug (Fixed in Code, Not Rerun)

A scatter collision in _attn_skip caused token position 0 to occasionally receive a zeroed attention output when it was simultaneously active and used as a padding dummy index. Effect: token 0 behaved as if it exited early in affected steps. Both FULL_SKIP and FrozenKV were affected equally throughout the 600T runs, so relative comparisons remain valid. The fix (per-item scatter excluding padding slots) is present in the released code but the reported metrics reflect the buggy training runs.
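The collision can be reproduced in a few lines of plain Python. This is a simplified model of the description above, not the repo's `_attn_skip`: the function name, shapes, and values are hypothetical, and Python's deterministic last-write-wins stands in for the order-dependent behavior of a tensor scatter with duplicate indices.

```python
def scatter_padded(n_tokens, active_idx, values, k_max, skip_padding):
    """Write per-token attention outputs back to their sequence positions.

    Padding slots reuse index 0 as a dummy and carry zeros. With
    skip_padding=False those slots are scattered too (the buggy path).
    """
    pad = k_max - len(active_idx)
    idx = active_idx + [0] * pad
    vals = values + [0.0] * pad
    out = [None] * n_tokens                      # None = position not written
    n_write = len(active_idx) if skip_padding else k_max
    for i, v in zip(idx[:n_write], vals[:n_write]):
        out[i] = v                               # later writes overwrite earlier ones
    return out

# Token 0 is active in this item, which has 2 active tokens padded to k_max=4.
active_idx, values = [0, 3], [0.7, 0.4]

buggy = scatter_padded(5, active_idx, values, k_max=4, skip_padding=False)
fixed = scatter_padded(5, active_idx, values, k_max=4, skip_padding=True)
```

In the buggy path the padding slots overwrite position 0 with zeros; excluding padding slots from the scatter leaves the real output for token 0 intact.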

Known TODOs / Future Work

Architecture

  • Sliding window / top-k frozen KV retention for longer sequences
  • Token-choice routing variant (learned non-uniform exit distribution)
  • FrozenKV + trapezoidal LR schedule combined experiment (probably updated in v3, within a month)

Performance

  • Formal wall-clock benchmark: FrozenKV vs FULL_SKIP at larger B and T
  • Variable-length FlashAttention kernel for CC 7.0+ hardware
  • torch.allclose correctness assert: padded SDPA vs original Python loop outputs

Experiments

  • LR schedule matching (trapezoidal warmup-decay per MoR paper) @ 100T (probably updated in v3, within a month)
  • LR schedule matching (trapezoidal warmup-decay per MoR paper) @ 600T (probably updated in v3, within a month)
  • Examine the embedding structure vs. vanilla BERT (probably updated in v3, within a month)
  • Scale up: larger model, fuller dataset
  • Ablation: FrozenKV with token-choice routing
  • Downstream task evaluation to test if depth gap improvement transfers

Code

  • Modularize into model.py / train.py / router.py / dataset.py
  • Add config dataclass to replace module-level constants
  • Unit tests for frozen KV accumulator shape and masking correctness

Citation

If you use this work, please cite the original MoR paper, the base Mo²BERTa prototype, and this repository:

@misc{bae2025mixtureofrecursionslearningdynamicrecursive,
      title={Mixture-of-Recursions: Learning Dynamic Recursive Depths for Adaptive Token-Level Computation}, 
      author={Sangmin Bae and Yujin Kim and Reza Bayat and Sungnyun Kim and Jiyoun Ha and
              Tal Schuster and Adam Fisch and Hrayr Harutyunyan and Ziwei Ji and
              Aaron Courville and Se-Young Yun},
      year={2025},
      eprint={2507.10524},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2507.10524}, 
}

@software{mo2berta_v2_proto,
  author  = {GP Bayu},
  title   = {{Mo²BERTa-v2}: Frozen KV Context for Encoder Mixture-of-Recursions},
  url     = {https://huggingface.co/gbyuvd/Mo2BERTa-v2-proto},
  version = {0.1},
  year    = {2026},
}

@software{mo2berta_proto,
  author  = {GP Bayu},
  title   = {{Mo²BERTa}: Mixture-of-Recursions for Bidirectional MLM},
  url     = {https://huggingface.co/gbyuvd/Mo2BERTa-proto},
  version = {0.1},
  year    = {2026},
}

@misc{eldan2023tinystoriessmalllanguagemodels,
      title={TinyStories: How Small Can Language Models Be and Still Speak Coherent English?}, 
      author={Ronen Eldan and Yuanzhi Li},
      year={2023},
      eprint={2305.07759},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2305.07759}, 
}

Contact

For questions about this prototype, open an issue in the source repository.
For questions about the base architecture, see Mo²BERTa.
