---
license: gemma
base_model: google/gemma-4-e2b-it
tags:
- gemma
- moe
- grouped-experts
- pruning
---
# Gemma-4 E2B with grouped expert routing (K=96, d=0.50)
This model is a sparsified variant of `google/gemma-4-e2b-it` in which each MLP's
FFN dimension is partitioned into **96 groups** by k-means clustering of neuron
activation profiles. At inference, only the top **48 of 96 groups** fire per token
(50% density), so entire groups of neurons can be skipped: their rows of
`up_proj`/`gate_proj` and columns of `down_proj` never need to be read from memory.
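
A minimal sketch of per-token group routing, in plain Python so it runs without dependencies. The card does not say how group scores are produced, so `scores` below is a stand-in for whatever the trained router outputs; the tanh-approximated GELU and all function names (`top_groups`, `grouped_mlp`, `dense_masked_mlp`) are assumptions for illustration. It shows that picking the top `K_active` groups and iterating only over their neurons gives the same output as computing the full FFN and masking it, without ever reading the weights of inactive groups.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def gelu_tanh(x):
    # Tanh-approximated GELU (assumed activation for this sketch).
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))


def top_groups(scores, k_active):
    """Indices of the k_active highest-scoring groups for one token."""
    order = sorted(range(len(scores)), key=lambda g: scores[g], reverse=True)
    return set(order[:k_active])


def grouped_mlp(x, gate_w, up_w, down_w, groups, active):
    """FFN forward that only reads weights belonging to active groups.

    gate_w, up_w: d_ffn rows of length d_model (row n = neuron n).
    down_w: d_model rows of length d_ffn (column n = neuron n).
    groups: one list of neuron indices per group.
    """
    y = [0.0] * len(down_w)
    for g in sorted(active):
        for n in groups[g]:  # rows/columns of inactive groups are never touched
            h = gelu_tanh(dot(gate_w[n], x)) * dot(up_w[n], x)
            for i, row in enumerate(down_w):
                y[i] += row[n] * h
    return y


def dense_masked_mlp(x, gate_w, up_w, down_w, groups, active):
    """Reference: compute every neuron, then zero out inactive groups."""
    mask = [0.0] * len(gate_w)
    for g in active:
        for n in groups[g]:
            mask[n] = 1.0
    h = [m * gelu_tanh(dot(gw, x)) * dot(uw, x) for m, gw, uw in zip(mask, gate_w, up_w)]
    return [dot(row, h) for row in down_w]


# Toy sizes: d_model=4, d_ffn=8, K=4 groups of 2 neurons, K_active=2.
rng = random.Random(0)
d_model, d_ffn, K, K_active = 4, 8, 4, 2


def rand():
    return rng.uniform(-1.0, 1.0)


x = [rand() for _ in range(d_model)]
gate_w = [[rand() for _ in range(d_model)] for _ in range(d_ffn)]
up_w = [[rand() for _ in range(d_model)] for _ in range(d_ffn)]
down_w = [[rand() for _ in range(d_ffn)] for _ in range(d_model)]
groups = [list(range(g * 2, g * 2 + 2)) for g in range(K)]
scores = [rand() for _ in range(K)]  # stand-in router scores

active = top_groups(scores, K_active)
sparse = grouped_mlp(x, gate_w, up_w, down_w, groups, active)
dense = dense_masked_mlp(x, gate_w, up_w, down_w, groups, active)
```

The dense masked version is only a correctness reference; the memory saving comes from the sparse loop, which touches 50% of the FFN weights at K_active=48 of K=96.
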
Group sizes are D_FFN / 96: 64 neurons per group in early layers (D_FFN=6144) and
128 in late layers (D_FFN=12288). A 64-element group in bf16 occupies 64 × 2 B = 128 B,
exactly one L2 cache line on NVIDIA GPUs (a 128-element group spans exactly two),
which makes the routing **memory-bandwidth-aligned**.
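
The group-size arithmetic above can be checked in a few lines. This sketch assumes a group's neurons are stored contiguously, so that one weight row holds a group's 64 (or 128) bf16 values back to back; `group_layout`, `BF16_BYTES` and `CACHE_LINE_BYTES` are illustrative names, not part of this repository.

```python
BF16_BYTES = 2
CACHE_LINE_BYTES = 128  # cache-line size cited in this card for NVIDIA GPUs


def group_layout(d_ffn, num_groups=96):
    """Return (neurons per group, bytes per bf16 group slice, cache lines per slice)."""
    if d_ffn % num_groups:
        raise ValueError("D_FFN must divide evenly into groups")
    group_size = d_ffn // num_groups
    slice_bytes = group_size * BF16_BYTES
    return group_size, slice_bytes, slice_bytes / CACHE_LINE_BYTES


for d_ffn in (6144, 12288):
    size, nbytes, lines = group_layout(d_ffn)
    print(f"D_FFN={d_ffn}: {size} neurons/group, {nbytes} B per slice, {lines:g} cache line(s)")
```
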
## Training recipe
1. Train the `Sw_gate_only_50` baseline: 5000 gate-only steps at density 0.50, lr=1e-4,
tau annealed from 1.0 to 0.01, with int4 quantization-aware training (QAT), on `gemma4_e2b_it_bulk_50k.jsonl`.
2. Cluster activations: run 4 sequences × 2048 tokens through the baseline, then run
k-means with K=96 per layer on each neuron's `|gate*up|` activation profile.
3. Polish: install `GroupedMaskedMLP` (K=96, K_active=48), add rank-128 LoRA adapters on
`up_proj`/`down_proj`, and train for 1000 steps with cross-entropy (CE) loss on `gemma4_e2b_it_final_50k.jsonl`.
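
Step 2 could look roughly like the following plain-Python sketch. Two things are assumptions: that a neuron's profile is the literal `|gate*up|` value of that neuron on each calibration token (with 4 × 2048 calibration tokens, each profile would have 8192 entries), and that clusters are forced to equal size, which the card implies (every group has exactly D_FFN / 96 neurons) but plain k-means does not guarantee. The greedy capacity-constrained assignment used here is one simple way to balance clusters, not necessarily the method used for this model.

```python
import random


def activation_profiles(gate_acts, up_acts):
    """Per-neuron profile of |gate*up| over all calibration tokens.

    gate_acts, up_acts: token-major lists (n_tokens x d_ffn).
    Returns d_ffn profiles, each of length n_tokens.
    """
    d_ffn = len(gate_acts[0])
    return [[abs(g[n] * u[n]) for g, u in zip(gate_acts, up_acts)] for n in range(d_ffn)]


def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def balanced_kmeans(profiles, k, iters=10, seed=0):
    """k-means with exactly len(profiles) // k neurons per cluster."""
    n = len(profiles)
    if n % k:
        raise ValueError("number of neurons must be divisible by k")
    cap = n // k  # e.g. 6144 // 96 = 64
    rng = random.Random(seed)
    centroids = [list(profiles[i]) for i in rng.sample(range(n), k)]
    assign = []
    for _ in range(iters):
        # Visit (neuron, cluster) pairs from closest to farthest and place each
        # neuron in the nearest cluster that still has room.
        pairs = sorted(
            (sq_dist(p, c), i, j)
            for i, p in enumerate(profiles)
            for j, c in enumerate(centroids)
        )
        assign = [-1] * n
        counts = [0] * k
        for _, i, j in pairs:
            if assign[i] == -1 and counts[j] < cap:
                assign[i] = j
                counts[j] += 1
        # Recompute each centroid as the mean profile of its members.
        for j in range(k):
            members = [profiles[i] for i in range(n) if assign[i] == j]
            centroids[j] = [sum(col) / len(members) for col in zip(*members)]
    return assign


# Toy calibration run: 16 tokens, d_ffn=8, K=4 groups of 2 neurons.
rng = random.Random(0)
gate_acts = [[rng.gauss(0.0, 1.0) for _ in range(8)] for _ in range(16)]
up_acts = [[rng.gauss(0.0, 1.0) for _ in range(8)] for _ in range(16)]
group_of = balanced_kmeans(activation_profiles(gate_acts, up_acts), k=4)
groups = [[n for n, g in enumerate(group_of) if g == j] for j in range(4)]
```

At real scale (8192-token profiles, 96 clusters per layer) a tensor library would replace the Python loops, but the balancing logic is the same.
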
## Results (eval on `gemma4_e2b_it_final_50k.jsonl`)
| Metric | Value | vs base model | vs `polish_50_per_neuron` |
|---|---|---|---|
| PPL | 2.95 | +1.31× | +0.85× |
| MMLU | 30.6 | +0.2 | +1.0 |
| HellaSwag | 54.0 | +2.6 | -1.8 |
| ARC-Challenge | 35.2 | -2.2 | +1.6 |
| ARC-Easy | 40.6 | +1.2 | -0.6 |
| BoolQ | 74.2 | -2.4 | +3.2 |
| **5-task avg** | **46.9** | **-0.1** | **+0.7** |
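| World-knowledge probe (172 questions) | 58.1 | -2.9 | +5.2 |

Benchmark values are accuracies in percent, and their deltas are absolute percentage points; lower PPL is better.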
## Usage
```bash
# Git LFS is needed to fetch the large weight files
git lfs install
git clone https://huggingface.co/Cactus-Compute/gemma4-e2b-grouped-k96
cd gemma4-e2b-grouped-k96
pip install torch transformers
# Generate with grouped (K=96) routing
python inference_k96.py --prompt "The capital of France is"
```
The inference script loads the base Gemma-4 weights from Hugging Face, installs the
grouped routing wrappers (loading cluster assignments from `groups/`), applies int4
QAT and the LoRA adapters, and then loads the trained state dict. Before generation,
`verify_grouped_routing()` asserts that every layer's MLP is a `GroupedMaskedMLP` with K=96.
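
The internals of `verify_grouped_routing()` are not shown in this card. As a hypothetical illustration of what such a check can cover, the sketch below uses a stand-in `GroupedMaskedMLP` class (not the repository's) and a `check_grouped_routing` helper that rejects any layer that is not grouped, has the wrong K or K_active, or whose groups fail to partition the FFN dimension into equal-sized sets. `contiguous_groups` only builds placeholder assignments for the demo; real assignments come from k-means and are not contiguous.

```python
class GroupedMaskedMLP:
    """Hypothetical stand-in for the repository's grouped MLP wrapper."""

    def __init__(self, groups, k_active):
        self.groups = groups  # one list of neuron indices per group
        self.num_groups = len(groups)
        self.k_active = k_active


def check_grouped_routing(mlps, k=96, k_active=48):
    """Raise AssertionError unless every MLP is a correctly configured GroupedMaskedMLP."""
    for idx, mlp in enumerate(mlps):
        if not isinstance(mlp, GroupedMaskedMLP):
            raise AssertionError(f"layer {idx}: got {type(mlp).__name__}")
        if mlp.num_groups != k or mlp.k_active != k_active:
            raise AssertionError(f"layer {idx}: K={mlp.num_groups}, K_active={mlp.k_active}")
        # Groups must cover each FFN neuron exactly once and all be the same size.
        neurons = sorted(n for group in mlp.groups for n in group)
        if neurons != list(range(len(neurons))):
            raise AssertionError(f"layer {idx}: groups do not partition the FFN dimension")
        if len({len(group) for group in mlp.groups}) != 1:
            raise AssertionError(f"layer {idx}: unequal group sizes")


def contiguous_groups(d_ffn, k=96):
    """Placeholder assignment: neuron n goes to group n // (d_ffn // k)."""
    size = d_ffn // k
    return [list(range(g * size, (g + 1) * size)) for g in range(k)]


# Two early layers (D_FFN=6144) and one late layer (D_FFN=12288).
layers = [GroupedMaskedMLP(contiguous_groups(d), k_active=48) for d in (6144, 6144, 12288)]
check_grouped_routing(layers)  # passes silently
```

Raising explicitly rather than using `assert` statements keeps the check active even when Python runs with `-O`.
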
## Citation
Internal research (Anthropic Claude + Noah Cylich, Cactus Compute), 2026-04.