---
license: gemma
base_model: google/gemma-4-e2b-it
tags:
- gemma
- moe
- grouped-experts
- pruning
---

# Gemma-4 E2B with grouped expert routing (K=96, d=0.50)

This model is a sparsified variant of `google/gemma-4-e2b-it` in which each MLP's
FFN dimension is partitioned into **96 groups** via k-means on activation
profiles. At inference, the top **48 of 96 groups** fire per token (50% density),
so inactive groups are skipped entirely: the corresponding rows of
`up_proj`/`gate_proj` and columns of `down_proj` never need to be read from memory.

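As a rough illustration of the routing step, the sketch below picks the top `k_active` groups from per-group scores and expands that choice into a per-neuron mask. How the real model scores groups is not described in this card, so `group_scores`, `group_of_neuron`, and both function names are hypothetical.

```python
def top_groups(group_scores, k_active):
    """Return the indices of the k_active highest-scoring groups (ties go to the lower index)."""
    order = sorted(range(len(group_scores)), key=lambda g: (-group_scores[g], g))
    return set(order[:k_active])


def neuron_mask(group_of_neuron, active_groups):
    """1.0 for neurons whose group fired, 0.0 for neurons that can be skipped."""
    return [1.0 if g in active_groups else 0.0 for g in group_of_neuron]


# Toy layer: 8 FFN neurons in 4 groups, 2 groups active per token (50% density).
group_of_neuron = [0, 0, 1, 1, 2, 2, 3, 3]   # hypothetical cluster assignment
group_scores = [0.9, 0.1, 0.4, 0.7]          # hypothetical per-token group scores

active = top_groups(group_scores, k_active=2)
mask = neuron_mask(group_of_neuron, active)
density = sum(mask) / len(mask)
print(sorted(active), mask, density)
```

In the model itself K=96 and K_active=48. A masked multiply like this only shows which neurons contribute; an efficient kernel would presumably skip the weight rows and columns of inactive groups instead of multiplying by zero.
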
Group sizes (D_FFN / 96): 64 in early layers (D_FFN=6144), 128 in late layers
(D_FFN=12288). A 64-element group in bf16 is exactly 128 B (64 × 2 bytes), one L2
cache line on NVIDIA GPUs, so routing is **memory-bandwidth-aligned**.

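The group sizes and the 128 B figure follow from simple division, which the short sketch below reproduces. `group_layout` is a hypothetical helper written for this illustration, not part of the released code.

```python
K = 96               # groups per MLP, from the model card
BYTES_PER_BF16 = 2   # bfloat16 stores each value in 16 bits


def group_layout(d_ffn, k=K):
    """Return (neurons per group, bytes for one group's worth of bf16 values)."""
    group_size, remainder = divmod(d_ffn, k)
    if remainder:
        raise ValueError(f"D_FFN={d_ffn} does not split evenly into {k} groups")
    return group_size, group_size * BYTES_PER_BF16


for d_ffn in (6144, 12288):
    size, nbytes = group_layout(d_ffn)
    print(f"D_FFN={d_ffn}: {size} neurons per group, {nbytes} B in bf16")
```

By the same arithmetic, the 128-element groups in the late layers come to 256 B, or two 128 B lines.
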
## Training recipe

1. Train the `Sw_gate_only_50` baseline: 5000 steps gate-only, density=0.50, lr=1e-4,
   tau annealed 1.0→0.01, int4 QAT, on `gemma4_e2b_it_bulk_50k.jsonl`.
2. Cluster activations: run 4 sequences × 2048 tokens through the baseline, then
   run k-means at K=96 per layer on the `|gate*up|` activation profile.
3. Polish: install `GroupedMaskedMLP` (K=96, K_active=48), add LoRA (rank 128) on
   `up_proj`/`down_proj`, and train 1000 steps with cross-entropy (CE) loss on
   `gemma4_e2b_it_final_50k.jsonl`.
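
Step 1 anneals tau from 1.0 to 0.01 over 5000 steps but does not say which curve it follows. The sketch below assumes a geometric (log-linear) schedule purely for illustration; `tau_at` and its parameters are hypothetical, and the actual run may have used a different shape.

```python
def tau_at(step, total_steps=5000, tau_start=1.0, tau_end=0.01):
    """Temperature at a given step under an ASSUMED geometric schedule."""
    # Clamp progress to [0, 1] so steps outside the run return the endpoints.
    progress = min(max(step / total_steps, 0.0), 1.0)
    return tau_start * (tau_end / tau_start) ** progress


for step in (0, 2500, 5000):
    print(step, tau_at(step))
```

Under this assumption the temperature drops by a constant factor per step, so the midpoint of training sits near 0.1 rather than near 0.5 as a linear schedule would give.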
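To make step 2 more concrete, here is a minimal pure-Python k-means over toy activation profiles. It assumes each FFN neuron is described by a vector of `|gate*up|` magnitudes, one entry per sampled token; that layout, the toy data, and the `kmeans` helper are all assumptions made for this sketch rather than details of the actual pipeline.

```python
import random


def kmeans(points, k, iters=20, seed=0):
    """Plain Lloyd's k-means; returns one cluster id per point."""
    rng = random.Random(seed)
    centers = [list(p) for p in rng.sample(points, k)]
    assign = [0] * len(points)
    for _ in range(iters):
        # Assignment step: nearest center by squared Euclidean distance.
        for i, p in enumerate(points):
            assign[i] = min(
                range(k),
                key=lambda c: sum((a - b) ** 2 for a, b in zip(p, centers[c])),
            )
        # Update step: move each center to the mean of its members.
        for c in range(k):
            members = [points[i] for i in range(len(points)) if assign[i] == c]
            if members:  # an empty cluster keeps its previous center
                centers[c] = [sum(col) / len(members) for col in zip(*members)]
    return assign


# Toy profiles: one row per FFN neuron, one column per token.
profiles = [
    [0.9, 0.8, 0.0, 0.1],   # neurons 0 and 1 respond to the first two tokens
    [1.0, 0.9, 0.1, 0.0],
    [0.0, 0.1, 0.9, 1.0],   # neurons 2 and 3 respond to the last two tokens
    [0.1, 0.0, 0.8, 0.9],
]
groups = kmeans(profiles, k=2)
print(groups)
```

Plain k-means does not guarantee equal-sized clusters, while this card reports uniform group sizes of D_FFN / 96, so the real procedure presumably includes a balancing step that is not described here.
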
## Results (eval on `gemma4_e2b_it_final_50k.jsonl`)

| Metric | Value | vs Base | vs polish_50_per_neuron |
|---|---|---|---|
| PPL | 2.95 | +1.31× | +0.85× |
| MMLU | 30.6 | +0.2 | +1.0 |
| HellaSwag | 54.0 | +2.6 | -1.8 |
| ARC-Challenge | 35.2 | -2.2 | +1.6 |
| ARC-Easy | 40.6 | +1.2 | -0.6 |
| BoolQ | 74.2 | -2.4 | +3.2 |
| **5-task avg** | **46.9** | **-0.1** | **+0.7** |
| World-knowledge probe (172q) | 58.1 | -2.9 | +5.2 |

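The 5-task average row can be rechecked from the per-task rows. The snippet below is only a sanity check on numbers copied from the table; it assumes the average is an unweighted mean over the five benchmarks.

```python
# Per-task scores and deltas copied from the table above.
scores = {"MMLU": 30.6, "HellaSwag": 54.0, "ARC-Challenge": 35.2, "ARC-Easy": 40.6, "BoolQ": 74.2}
delta_vs_base = {"MMLU": 0.2, "HellaSwag": 2.6, "ARC-Challenge": -2.2, "ARC-Easy": 1.2, "BoolQ": -2.4}
delta_vs_per_neuron = {"MMLU": 1.0, "HellaSwag": -1.8, "ARC-Challenge": 1.6, "ARC-Easy": -0.6, "BoolQ": 3.2}


def mean(values):
    """Unweighted arithmetic mean of an iterable of numbers."""
    values = list(values)
    return sum(values) / len(values)


avg_score = round(mean(scores.values()), 1)
avg_vs_base = round(mean(delta_vs_base.values()), 1)
avg_vs_per_neuron = round(mean(delta_vs_per_neuron.values()), 1)
print(avg_score, avg_vs_base, avg_vs_per_neuron)
```
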
## Usage

```bash
git lfs install
git clone https://huggingface.co/Cactus-Compute/gemma4-e2b-grouped-k96
cd gemma4-e2b-grouped-k96
pip install torch transformers
python inference_k96.py --prompt "The capital of France is"
```

The inference script loads the base Gemma-4 weights from Hugging Face, installs
the grouped routing wrappers (loading cluster assignments from `groups/`), then
applies int4 QAT, LoRA, and the trained state dict. Before generation,
`verify_grouped_routing()` asserts that every layer's MLP is a `GroupedMaskedMLP`
with K=96.

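For readers who want a feel for this kind of pre-generation check, the sketch below shows one way it could be written. It is not the code from `inference_k96.py`: the stand-in `GroupedMaskedMLP` class, its `num_groups` attribute, and the `check_grouped_routing` helper are hypothetical and exist only for this illustration.

```python
class GroupedMaskedMLP:
    """Stand-in for the real wrapper; only the attribute checked below is modeled."""

    def __init__(self, num_groups):
        self.num_groups = num_groups  # hypothetical attribute name


def check_grouped_routing(mlps, expected_groups=96):
    """Assert every MLP is a GroupedMaskedMLP with the expected K; return the layer count."""
    for idx, mlp in enumerate(mlps):
        assert isinstance(mlp, GroupedMaskedMLP), f"layer {idx}: MLP was not wrapped"
        assert mlp.num_groups == expected_groups, f"layer {idx}: K={mlp.num_groups}"
    return len(mlps)


# Three toy "layers", all wrapped with K=96.
layers = [GroupedMaskedMLP(96) for _ in range(3)]
print(check_grouped_routing(layers), "layers verified")
```

Failing loudly before generation guards against silently evaluating the dense base model when a wrapper failed to install.
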
## Citation

Internal research (Anthropic Claude + Noah Cylich, Cactus Compute), 2026-04.