Gemma-4 E2B with LEARNED-ROUTER grouped expert routing (K=192, d=0.50)
Sparsified variant of google/gemma-4-e2b-it at 50% per-token density (96 of
192 expert groups fire per token). Routing is performed by a small learned
linear projection R(x) = W_r @ x ∈ ℝ^192 instead of the standard
max(|gelu(W_gate @ x)|) per-group reduction. This makes routing 32–64×
cheaper to compute than the gate-then-reduce pipeline used in the original
K=192 sibling, with a small quality cost.
Sister model: Cactus-Compute/gemma4-e2b-grouped-k192
uses the original max-gate routing: strictly higher quality, but with more
expensive routing compute.
Why learned routing?
The original max-gate approach computes the FULL gate_proj(x) (a
D_MODEL × D_FFN matmul, ≈9.4M FLOPs/tok in early layers, ≈18.9M in late) just to
decide which groups to fire, then either applies a mask (no compute saving)
or skips groups (memory saving, but the full gate compute is still paid). Replacing
this with a tiny learned router R(x) ∈ ℝ^192 (a D_MODEL × K matmul, ≈295k
FLOPs/tok) lets a custom kernel skip both the gate AND the unselected
group columns in up_proj/down_proj/gate_proj, opening a real
inference-time speedup.
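For concreteness, here is a minimal PyTorch sketch contrasting the two scoring paths. The shapes (D_MODEL=1536, D_FFN=6144, K=192) come from this card; the contiguous group layout is a simplification for illustration only, since the real assignments come from the k-means clusters in groups/.

```python
import torch
import torch.nn.functional as F

D_MODEL, D_FFN, K, K_ACTIVE = 1536, 6144, 192, 96   # early-layer shapes from this card
GROUP = D_FFN // K                                   # 32 neurons per group

x = torch.randn(D_MODEL)
W_gate = torch.randn(D_FFN, D_MODEL)   # standard gate_proj weight
W_r = torch.randn(K, D_MODEL)          # tiny learned router (this model)

# Max-gate routing (sibling model): full gate_proj matmul (~D_MODEL*D_FFN ~= 9.4M FLOPs),
# then a per-group max-|gelu| reduction. Group-contiguous ordering assumed for brevity.
max_gate_scores = F.gelu(W_gate @ x).abs().view(K, GROUP).amax(dim=-1)   # (K,)

# Learned routing (this model): one small matmul (~D_MODEL*K ~= 295k FLOPs, ~32x fewer).
router_scores = W_r @ x                                                   # (K,)

# Either way, the 96 highest-scoring groups fire for this token.
active_groups = router_scores.topk(K_ACTIVE).indices
```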
Architecture (per layer)
- 192 expert groups partitioned via k-means on activation profiles (same group assignments as the max-gate sibling, in `groups/`)
- 96 of 192 groups fire per token (50% density)
- Group sizes: 32 neurons (early, D_FFN=6144) / 64 neurons (late, D_FFN=12288). At int8 these are 32 B / 64 B, i.e. DRAM sector aligned (32 B/sector on NVIDIA GPUs)
- Routing: R(x) = W_r @ x with W_r ∈ ℝ^{192 × 1536} per layer (~295k params/layer × 35 layers = 10.3M total). Frozen at inference; trained once via least-squares against the max-gate oracle, then continued during the polish phase.
- Polish recipe: LoRA r=256 on `up_proj`/`down_proj` + int4 QAT (group=32) + CE loss on chat-trajectory data, 3000 steps from the gate-only baseline.
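As a sketch of how a group selection turns into a per-neuron mask under arbitrary k-means assignments: the `group_ids` tensor below stands in for whatever format `groups/` actually uses, and reading the real assignment files is out of scope here.

```python
import torch

K, K_ACTIVE, D_FFN = 192, 96, 6144

# group_ids[j] = expert-group index of FFN neuron j, i.e. the k-means assignment
# shared with the max-gate sibling. Random placeholder instead of reading groups/.
group_ids = torch.randint(K, (D_FFN,))

router_scores = torch.randn(K)                   # stands in for R(x) = W_r @ x
active = router_scores.topk(K_ACTIVE).indices    # the 96 groups that fire

# Per-group boolean mask, gathered per neuron: neuron j is live iff its group fired.
group_mask = torch.zeros(K, dtype=torch.bool)
group_mask[active] = True
neuron_mask = group_mask[group_ids]              # (D_FFN,), ~50% of entries True
```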
Results (PPL on gemma4_e2b_it_final_50k.jsonl, 200-seq subset)
| Variant | PPL | Routing FLOPs/tok |
|---|---|---|
| Base Gemma-4-IT (dense, no compression) | 1.27 | n/a (full gate) |
| Max-gate K=192 (sibling) | 2.10 | full gate (~9-19M) |
| Learned-router K=192 (this) | 2.23 | router only (~295k) |
| K=192 baseline + naive replace router (R1 lstsq, no polish) | 2.40 | router only |
Quality cost vs max-gate: +0.13 PPL (6% relative). Routing cost: 32–64× cheaper. Probe (172q world knowledge): equivalent.
Usage
```bash
git lfs install
git clone https://huggingface.co/Cactus-Compute/gemma4-e2b-grouped-k192-router
cd $(basename Cactus-Compute/gemma4-e2b-grouped-k192-router)
pip install torch transformers

# Inference reference:
python inference_router_k192.py \
  --checkpoint Sw_grouped_50_K192_router_v27_lora256.pt \
  --group_assignments_dir groups \
  --group_tag s50 \
  --prompt "What is the capital of France?"
```
The inference script:
- Loads base Gemma-4-E2B-IT from HuggingFace
- Installs `GroupedMaskedMLP` + `LearnedRouterMLP` wrappers (cluster assignments from `groups/`)
- Applies int4 QAT (g=32) and LoRA r=256 wrappers
- Loads the trained state dict, including the learned router weights
- Verifies every layer has `LearnedRouterMLP` with K=192/K_active=96 and a correctly shaped `router.W_r` before generation
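A hedged sketch of that verification step; the attribute names (`K`, `K_active`, `router.W_r`) follow the description in this card, but the exact internals of the script's `LearnedRouterMLP` wrapper are an assumption here.

```python
def verify_router_layers(model, d_model=1536, k=192, k_active=96):
    """Assert every decoder layer carries a LearnedRouterMLP with the expected
    group config and a correctly shaped router weight (illustrative check only)."""
    for i, layer in enumerate(model.layers):
        mlp = layer.mlp
        assert type(mlp).__name__ == "LearnedRouterMLP", f"layer {i}: wrapper not installed"
        assert (mlp.K, mlp.K_active) == (k, k_active), f"layer {i}: unexpected group config"
        assert tuple(mlp.router.W_r.shape) == (k, d_model), f"layer {i}: bad router shape"
```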
Notes for inference engine integration
- The router state lives in the checkpoint at `model.layers[i].mlp.router.W_r` ∈ ℝ^{192 × 1536} (one per layer).
- At inference, replace `group_score[:, g] = max_{j ∈ group_g} |gelu(W_gate[j] @ x)|` with `group_score = router.W_r @ x` (a single matmul with K outputs), then take the top-K_active and apply the same mask as the max-gate variant.
- The `GroupedMaskedMLP.forward` reference in `rung8_grouped_g4.py` and the `LearnedRouterMLP.forward` in `rung9_router.py` show the exact mask computation. To realize the FLOP savings, your kernel should (see the sketch below):
  - Compute R(x) first (cheap)
  - Get the top-96 group ids
  - Compute gate_proj(x), up_proj(x), down_proj(h) ONLY on the rows/columns belonging to the selected groups
  - Apply gelu(gate_act) * up_act only on the selected neurons
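A minimal PyTorch emulation of those steps (plain indexing stands in for a fused kernel; weight names mirror a standard SwiGLU MLP, and the group-contiguous column layout is again a simplification rather than the real `groups/` assignment):

```python
import torch
import torch.nn.functional as F

D_MODEL, D_FFN, K, K_ACTIVE = 1536, 6144, 192, 96
GROUP = D_FFN // K   # 32 neurons per group (early layers)

x = torch.randn(D_MODEL)
W_r    = torch.randn(K, D_MODEL)
W_gate = torch.randn(D_FFN, D_MODEL)
W_up   = torch.randn(D_FFN, D_MODEL)
W_down = torch.randn(D_MODEL, D_FFN)

# 1) Cheap router, 2) top-96 group ids.
active = (W_r @ x).topk(K_ACTIVE).indices

# Expand group ids to the live FFN rows/columns (contiguous layout assumed for brevity).
idx = (active[:, None] * GROUP + torch.arange(GROUP)).flatten()   # (K_ACTIVE*GROUP,)

# 3) gate/up restricted to selected rows, 4) gated activation on those neurons only,
# then down_proj restricted to the matching columns.
h = F.gelu(W_gate[idx] @ x) * (W_up[idx] @ x)     # (K_ACTIVE*GROUP,)
y = W_down[:, idx] @ h                            # (D_MODEL,)
```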
Recipe summary (research detail)
- Start from the `Sw_gate_only_50.pt` baseline (gate-only K=192 d=0.50, trained via the rung 7 `Sw_gate_only_50` recipe).
- Initialize the router via closed-form least squares against per-token max-gate scores from the R0 oracle (200K calibration tokens).
- Polish: install `LearnedRouterMLP` + LoRA r=256 + int4 QAT, train 3000 steps CE with `--unfreeze_base --train_router` at lr=5e-6 on `gemma4_e2b_it_bulk_50k.jsonl`.
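A hedged sketch of the closed-form step for a single layer, assuming X holds the calibration activations entering that layer's MLP and S the oracle's per-token max-gate group scores; the names and the plain lstsq solver are illustrative, and the actual fitting code lives in the repo.

```python
import torch

D_MODEL, K = 1536, 192
N = 4096   # placeholder; the recipe uses 200K calibration tokens

X = torch.randn(N, D_MODEL)   # calibration-token activations entering the MLP
S = torch.randn(N, K)         # oracle targets: max_{j in group} |gelu(W_gate[j] @ x)|

# Fit W_r so that X @ W_r.T ~= S, i.e. R(x) = W_r @ x mimics the max-gate scores.
# torch.linalg.lstsq minimizes ||X @ B - S||_F, so the router weight is B.T.
B = torch.linalg.lstsq(X, S).solution   # (D_MODEL, K)
W_r = B.T                               # (K, D_MODEL)
```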
Full design + ablations: see SWIGLU_GATE_TRANSFER.md
section "Rung 9".
Citation
Cactus Compute (Noah Cylich) + Anthropic Claude, 2026-04. Companion research: https://github.com/kar-m/matryoshka-distil