Gemma-4 E2B with LEARNED-ROUTER grouped expert routing (K=192, d=0.50)

Sparsified variant of google/gemma-4-e2b-it at 50% per-token density (96 of 192 expert groups fire per token). Routing is performed by a small learned linear projection R(x) = W_r @ x ∈ ℝ^192 instead of the standard max(|gelu(W_gate @ x)|) per-group reduction. This makes routing 32–64× cheaper to compute than the gate-then-reduce pipeline used in the original K=192 sibling, with a small quality cost.

Sister model: Cactus-Compute/gemma4-e2b-grouped-k192 uses the original max-gate routing: strictly higher quality, but more expensive routing compute.

Why learned routing?

The original max-gate approach computes the FULL gate_proj(x) (a D_MODEL × D_FFN matmul, ≈9.4M FLOPs/tok in early layers, ≈18.9M in late layers) just to decide which groups to fire, then either applies a mask (no compute saving) or skips groups (memory saving, but the full gate compute is still paid). Replacing this with a tiny learned router R(x) ∈ ℝ^192 (a D_MODEL × K matmul, ≈295k FLOPs/tok) lets a custom kernel skip both the gate AND the unselected group columns in up_proj/down_proj/gate_proj, opening a real inference-time speedup.
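To make the cost difference concrete, here is a minimal PyTorch sketch of the two scoring paths (illustrative only: tensor names, shapes, and the contiguous group layout are assumptions based on the description above, not the repo's code):

```python
import torch
import torch.nn.functional as F

D_MODEL, D_FFN, K, K_ACTIVE = 1536, 6144, 192, 96   # early-layer sizes
GROUP = D_FFN // K                                    # 32 neurons per group

x = torch.randn(D_MODEL)
W_gate = torch.randn(D_FFN, D_MODEL)   # gate_proj weight
W_r = torch.randn(K, D_MODEL)          # learned router weight

# Max-gate routing: pay the full gate matmul (~9.4M MACs) just to score groups.
# (Simplification: real groups come from k-means assignments in groups/, not
# contiguous blocks of 32, so the view() below is only for illustration.)
gate_act = F.gelu(W_gate @ x)                                 # (D_FFN,)
scores_maxgate = gate_act.abs().view(K, GROUP).amax(dim=1)    # (K,) per-group max

# Learned routing: one small matmul (~295k MACs) produces the same K scores.
scores_router = W_r @ x                                        # (K,)

# Either way, the top-K_active groups fire for this token.
active_groups = scores_router.topk(K_ACTIVE).indices
```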

Architecture (per layer)

  • 192 expert groups partitioned via k-means on activation profiles (same group assignments as the max-gate sibling, in groups/)
  • 96 of 192 groups fire per token (50% density)
  • Group sizes: 32 neurons (early, D_FFN=6144) / 64 (late, D_FFN=12288). At int8 these are 32 B / 64 B, which is DRAM sector aligned (32 B/sector on NVIDIA GPUs)
  • Routing: R(x) = W_r @ x with W_r ∈ ℝ^{192 × 1536} per layer (~295k params/layer × 35 layers = 10.3M total). Frozen at inference; trained once via least-squares against the max-gate oracle, then continued during the polish phase (a scoring sketch follows this list)
  • Polish recipe: LoRA r=256 on up_proj/down_proj + int4 QAT (group=32) + CE loss on chat-trajectory data, 3000 steps from the gate-only baseline.
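A minimal sketch of how a LearnedRouterMLP-style forward can turn the router scores into a neuron mask using the saved group assignments (illustrative only; class and attribute names here are assumptions, and the reference implementation lives in rung9_router.py):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LearnedRouterMLPSketch(nn.Module):
    """Illustrative masked SwiGLU-style MLP with a learned group router (not the repo's code)."""

    def __init__(self, d_model=1536, d_ffn=6144, K=192, K_active=96, group_ids=None):
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_ffn, bias=False)
        self.up_proj = nn.Linear(d_model, d_ffn, bias=False)
        self.down_proj = nn.Linear(d_ffn, d_model, bias=False)
        self.W_r = nn.Parameter(torch.zeros(K, d_model))   # learned router weight
        self.K, self.K_active = K, K_active
        # group_ids[j] in [0, K) maps FFN neuron j to its k-means group (from groups/).
        self.register_buffer("group_ids", group_ids if group_ids is not None
                             else torch.arange(d_ffn) % K)

    def forward(self, x):                                   # x: (batch, d_model)
        scores = x @ self.W_r.T                             # (batch, K) router scores
        top = scores.topk(self.K_active, dim=-1).indices    # top-96 groups per token
        group_mask = torch.zeros_like(scores).scatter_(-1, top, 1.0)
        neuron_mask = group_mask[:, self.group_ids]         # expand to (batch, d_ffn)
        h = F.gelu(self.gate_proj(x)) * self.up_proj(x) * neuron_mask
        return self.down_proj(h)
```

This masked formulation matches the quality path; the actual compute savings come from the sliced kernel described under the integration notes below.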

Results (PPL on gemma4_e2b_it_final_50k.jsonl, 200-seq subset)

| Variant | PPL | Routing FLOPs/tok |
|---|---|---|
| Base Gemma-4-IT (dense, no compression) | 1.27 | n/a (full gate) |
| Max-gate K=192 (sibling) | 2.10 | full gate (~9-19M) |
| Learned-router K=192 (this model) | 2.23 | router only (~295k) |
| K=192 baseline + naive router replacement (R1 lstsq, no polish) | 2.40 | router only |

Quality cost vs max-gate: +0.13 PPL (≈6% relative). Routing cost: 32-64× cheaper. Probe (172-question world-knowledge set): equivalent.

Usage

git lfs install
git clone https://huggingface.co/Cactus-Compute/gemma4-e2b-grouped-k192-router
cd gemma4-e2b-grouped-k192-router
pip install torch transformers
# Inference reference:
python inference_router_k192.py \
  --checkpoint Sw_grouped_50_K192_router_v27_lora256.pt \
  --group_assignments_dir groups \
  --group_tag s50 \
  --prompt "What is the capital of France?"

The inference script:

  1. Loads base Gemma-4-E2B-IT from HuggingFace
  2. Installs GroupedMaskedMLP + LearnedRouterMLP wrappers (cluster assignments from groups/)
  3. Applies int4 QAT (g=32) and LoRA r=256 wrappers
  4. Loads the trained state dict, including the learned router weights
  5. Verifies every layer has LearnedRouterMLP with K=192/K_active=96 and a correctly-shaped router.W_r before generation (sketched below)
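Step 5 boils down to a sanity loop like this (a sketch assuming a HuggingFace-style module tree; the exact attribute names in inference_router_k192.py may differ):

```python
# Hypothetical pre-generation check; module/attribute names are assumptions.
hidden = model.config.hidden_size  # 1536 for Gemma-4 E2B
for i, layer in enumerate(model.model.layers):
    mlp = layer.mlp
    assert type(mlp).__name__ == "LearnedRouterMLP", f"layer {i}: unexpected MLP {type(mlp).__name__}"
    assert mlp.K == 192 and mlp.K_active == 96, f"layer {i}: K/K_active mismatch"
    assert tuple(mlp.router.W_r.shape) == (192, hidden), f"layer {i}: router.W_r is {tuple(mlp.router.W_r.shape)}"
```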

Notes for inference engine integration

  • The router state lives in the checkpoint at model.layers[i].mlp.router.W_r ∈ ℝ^{192 × 1536} (one per layer).
  • At inference, replace group_score[:, g] = max_{j ∈ group_g} |gelu(W_gate[j] @ x)| with group_score = router.W_r @ x (a single matmul with K outputs), then take the top-K_active groups and apply the same mask as the max-gate variant.
  • The GroupedMaskedMLP.forward reference in rung8_grouped_g4.py and the LearnedRouterMLP.forward in rung9_router.py show the exact mask computation. To realize the FLOP savings, your kernel should (see the sketch after this list):
    • Compute R(x) first (cheap)
    • Get top-96 group ids
    • Compute gate_proj(x), up_proj(x), down_proj(h) ONLY on the rows/columns belonging to the selected groups
    • Apply gelu(gate_act) * up_act only on selected neurons
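A dense-math PyTorch sketch of that sliced forward pass for a single token (illustrative only; function and variable names are assumptions, and a real kernel would fuse these steps rather than gather weight rows):

```python
import torch
import torch.nn.functional as F

def sliced_mlp_forward(x, W_r, W_gate, W_up, W_down, group_ids, k_active=96):
    """One-token sketch: score groups with the router, then touch only selected columns.

    x:         (d_model,)
    W_r:       (K, d_model)       learned router
    W_gate:    (d_ffn, d_model)   gate_proj weight
    W_up:      (d_ffn, d_model)   up_proj weight
    W_down:    (d_model, d_ffn)   down_proj weight
    group_ids: (d_ffn,) int tensor mapping each FFN neuron to its group
    """
    scores = W_r @ x                                    # (K,)  ~295k MACs, the only "gate" work
    active = scores.topk(k_active).indices              # top-96 group ids
    sel = torch.isin(group_ids, active).nonzero(as_tuple=True)[0]   # selected neuron indices

    gate_act = F.gelu(W_gate[sel] @ x)                  # only selected rows of gate_proj
    up_act = W_up[sel] @ x                              # only selected rows of up_proj
    h = gate_act * up_act                               # (n_selected,)
    return W_down[:, sel] @ h                           # only selected columns of down_proj
```

Note that the row gathers (W_gate[sel], W_up[sel]) realize the arithmetic savings but still copy memory; a production kernel would read the selected group columns in place.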

Recipe summary (research detail)

  1. Start from Sw_gate_only_50.pt baseline (gate-only K=192 d=0.50 trained via rung 7 Sw_gate_only_50 recipe).
  2. Initialize router via closed-form least squares against per-token max-gate scores from the R0 oracle (200K calibration tokens); a least-squares sketch follows this list.
  3. Polish: install LearnedRouterMLP + LoRA r=256 + int4 QAT, train 3000 steps CE with --unfreeze_base --train_router at lr=5e-6 on gemma4_e2b_it_bulk_50k.jsonl.
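The closed-form initialization in step 2 can be sketched as an ordinary least-squares fit of W_r to the oracle's group scores (a sketch under assumptions: variable names, score normalization, and any regularization term are not specified in this card):

```python
import torch

def fit_router_lstsq(X, S):
    """Fit W_r so that X @ W_r.T approximates the oracle scores S.

    X: (N, d_model) calibration hidden states (e.g. 200K tokens)
    S: (N, K) oracle scores, S[n, g] = max over j in group_g of |gelu(W_gate[j] @ X[n])|
    Returns W_r with shape (K, d_model).
    """
    # Solve min_W ||X @ W - S||_F^2 column-wise; the solution has shape (d_model, K).
    solution = torch.linalg.lstsq(X, S).solution
    return solution.T.contiguous()
```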

Full design + ablations: see SWIGLU_GATE_TRANSFER.md section "Rung 9".

Citation

Cactus Compute (Noah Cylich) + Anthropic Claude, 2026-04. Companion research: https://github.com/kar-m/matryoshka-distil
