---
license: gemma
base_model: google/gemma-4-e2b-it
tags:
- gemma
- moe
- grouped-experts
- pruning
---

# Gemma-4 E2B with grouped expert routing (K=96, d=0.50)

This model is a sparsified variant of `google/gemma-4-e2b-it`. Each MLP's FFN dimension is partitioned into **96 groups** by k-means clustering on activation profiles. At inference, the top **48 of 96 groups** fire for each token (50% density). Entire groups of neurons can therefore be skipped, and their rows of `up_proj`/`gate_proj` and columns of `down_proj` do not need to be read from memory.

Group size is D_FFN / 96. That gives 64 neurons per group in early layers (D_FFN=6144) and 128 in late layers (D_FFN=12288). In bf16, 64 elements occupy 64 × 2 B = 128 B, which is exactly one L2 cache line on NVIDIA GPUs; a 128-element group spans two lines. A contiguously stored group therefore maps onto whole cache lines: **memory-bandwidth-aligned routing**.

## Training recipe

1. Train the `Sw_gate_only_50` baseline: 5000 gate-only steps at density=0.50, lr=1e-4, tau annealed 1.0→0.01, with int4 quantization-aware training (QAT), on `gemma4_e2b_it_bulk_50k.jsonl`.
2. Cluster activations: run 4 sequences × 2048 tokens through the baseline. Then run k-means with K=96 per layer on the `|gate*up|` activation profile.
3. Polish: install `GroupedMaskedMLP` (K=96, K_active=48) and add rank-128 LoRA on `up_proj`/`down_proj`. Then train for 1000 steps with cross-entropy (CE) loss on `gemma4_e2b_it_final_50k.jsonl`.
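The snippet below checks the group-size and cache-line arithmetic. It is a minimal sketch whose only inputs are the 2-byte bf16 element size and a 128 B cache line. `group_bytes` is a hypothetical helper name.

```python
BF16_BYTES = 2          # bytes per bfloat16 element
CACHE_LINE_BYTES = 128  # L2 cache line size on NVIDIA GPUs
NUM_GROUPS = 96

def group_bytes(d_ffn: int, num_groups: int = NUM_GROUPS) -> int:
    """Bytes spanned by one group's contiguous slice of a bf16 weight row or column."""
    assert d_ffn % num_groups == 0, "D_FFN must split evenly into groups"
    return (d_ffn // num_groups) * BF16_BYTES

for d_ffn in (6144, 12288):
    size = group_bytes(d_ffn)
    print(f"D_FFN={d_ffn}: {d_ffn // NUM_GROUPS} elems/group, "
          f"{size} B = {size // CACHE_LINE_BYTES} cache line(s)")
# D_FFN=6144: 64 elems/group, 128 B = 1 cache line(s)
# D_FFN=12288: 128 elems/group, 256 B = 2 cache line(s)
```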
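Grouped top-k routing can be sketched as a masked MLP forward pass in NumPy. This is not the card's `GroupedMaskedMLP`; `grouped_masked_mlp` and `gelu_tanh` are hypothetical names. The card does not say how groups are scored, so two choices here are assumptions:

- Each group is ranked per token by its summed `|gate*up|` magnitude.
- The activation is GELU with the tanh approximation.

Weights follow the `nn.Linear` layout, so a neuron is a row of `gate`/`up` and a column of `down`.

```python
import numpy as np

def gelu_tanh(x):
    # GELU, tanh approximation (assumed to match the base model's activation)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def grouped_masked_mlp(x, w_gate, w_up, w_down, group_ids, num_groups=96, k_active=48):
    """Reference (masked, not memory-saving) forward pass with grouped top-k routing.

    x: [T, d_model]; w_gate, w_up: [d_ffn, d_model]; w_down: [d_model, d_ffn];
    group_ids: [d_ffn] ints in [0, num_groups), mapping each FFN neuron to a group.
    Returns (output [T, d_model], neuron_mask [T, d_ffn]).
    """
    h = gelu_tanh(x @ w_gate.T) * (x @ w_up.T)          # [T, d_ffn]
    onehot = np.eye(num_groups)[group_ids]              # [d_ffn, G]
    # Hypothetical scoring rule: summed |activation| of each group, per token.
    scores = np.abs(h) @ onehot                         # [T, G]
    # Indices of the k_active highest-scoring groups for each token.
    top = np.argpartition(-scores, k_active - 1, axis=1)[:, :k_active]
    group_mask = np.zeros_like(scores, dtype=bool)
    np.put_along_axis(group_mask, top, True, axis=1)
    neuron_mask = group_mask[:, group_ids]              # expand groups to neurons
    out = (h * neuron_mask) @ w_down.T                  # dropped groups contribute 0
    return out, neuron_mask

rng = np.random.default_rng(0)
T, d_model, d_ffn = 3, 16, 96 * 4          # toy sizes: 96 groups of 4 neurons
group_ids = np.repeat(np.arange(96), 4)    # contiguous groups, for illustration
x = rng.standard_normal((T, d_model))
w_gate, w_up = rng.standard_normal((2, d_ffn, d_model))
w_down = rng.standard_normal((d_model, d_ffn))
out, mask = grouped_masked_mlp(x, w_gate, w_up, w_down, group_ids)
print(out.shape, mask.sum(axis=1))         # (3, 16) [192 192 192]
```

This masked reference still computes every neuron. The memory saving described above requires a kernel that loads only the selected groups' weight slices.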
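Step 1 anneals tau from 1.0 to 0.01 over 5000 steps, but the card does not give the schedule's shape. The sketch below assumes a geometric (log-linear) anneal. `tau_schedule` is a hypothetical helper.

```python
def tau_schedule(step: int, total_steps: int = 5000,
                 tau_start: float = 1.0, tau_end: float = 0.01) -> float:
    """Geometric anneal (assumed shape): tau_start at step 0, tau_end from total_steps on."""
    frac = min(max(step, 0) / total_steps, 1.0)
    return tau_start * (tau_end / tau_start) ** frac

for step in (0, 1250, 2500, 5000):
    print(step, round(tau_schedule(step), 4))
```

A geometric schedule spends equal numbers of steps per decade of temperature. A linear schedule from 1.0 would instead spend only about the last 1% of steps below 0.05.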
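Step 2 can be sketched as k-means over per-neuron activation profiles. Each row is one FFN neuron's `|gate*up|` magnitudes across the calibration tokens (4 × 2048 = 8192 in the recipe); using the raw magnitudes as features is an assumption. Plain k-means does not produce equal-sized clusters, yet the card states every group has D_FFN / 96 neurons. This sketch therefore adds a greedy capacity-constrained assignment step to Lloyd's iterations. That mechanism is an assumption, not the authors' method, and `balanced_kmeans` is a hypothetical name.

```python
import numpy as np

def balanced_kmeans(profiles, num_groups=96, iters=10, seed=0):
    """Partition FFN neurons into equal-size groups by activation profile.

    profiles: [d_ffn, n_tokens] float array, row i = |gate*up| of neuron i per token.
    Returns group_ids: [d_ffn] ints with exactly d_ffn // num_groups neurons per group.
    """
    n = profiles.shape[0]
    assert n % num_groups == 0, "d_ffn must be divisible by num_groups"
    cap = n // num_groups
    rng = np.random.default_rng(seed)
    # Initialise centroids from randomly chosen neurons.
    centroids = profiles[rng.choice(n, num_groups, replace=False)].astype(np.float64)
    sq_norms = (profiles ** 2).sum(axis=1)
    for _ in range(iters):
        # Squared Euclidean distance from every neuron to every centroid: [n, G].
        d2 = (sq_norms[:, None] - 2.0 * profiles @ centroids.T
              + (centroids ** 2).sum(axis=1)[None, :])
        # Capacity-limited assignment: visit (neuron, group) pairs from nearest
        # to farthest and give each neuron the first group that still has room.
        group_ids = np.full(n, -1)
        counts = np.zeros(num_groups, dtype=int)
        for flat in np.argsort(d2, axis=None):
            i, g = divmod(int(flat), num_groups)
            if group_ids[i] < 0 and counts[g] < cap:
                group_ids[i] = g
                counts[g] += 1
        # Update step: move each centroid to the mean profile of its members.
        for g in range(num_groups):
            centroids[g] = profiles[group_ids == g].mean(axis=0)
    return group_ids

rng = np.random.default_rng(0)
profiles = np.abs(rng.standard_normal((96 * 4, 64)))   # toy: 384 neurons, 64 tokens
ids = balanced_kmeans(profiles)
counts = np.bincount(ids, minlength=96)
print(counts.min(), counts.max())                       # 4 4
```

The greedy step is a simple heuristic. An exact balanced assignment would need a transportation or min-cost-flow solver. Because total capacity equals the neuron count, every neuron is assigned and no group is left empty.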
## Results (eval on `gemma4_e2b_it_final_50k.jsonl`)

| Metric | Value | vs Base | vs `polish_50_per_neuron` |
|---|---|---|---|
| PPL (lower is better) | 2.95 | 1.31× | 0.85× |
| MMLU | 30.6 | +0.2 | +1.0 |
| HellaSwag | 54.0 | +2.6 | -1.8 |
| ARC-Challenge | 35.2 | -2.2 | +1.6 |
| ARC-Easy | 40.6 | +1.2 | -0.6 |
| BoolQ | 74.2 | -2.4 | +3.2 |
| **5-task avg** | **46.9** | **-0.1** | **+0.7** |
| World-knowledge probe (172 questions) | 58.1 | -2.9 | +5.2 |

PPL comparisons are ratios (this model ÷ reference). All other comparisons are absolute differences in score points.

## Usage

```bash
git lfs install
git clone https://huggingface.co/Cactus-Compute/gemma4-e2b-grouped-k96
cd gemma4-e2b-grouped-k96
pip install torch transformers
python inference_k96.py --prompt "The capital of France is"
```

The inference script first loads the base Gemma-4 weights from Hugging Face. It then installs the grouped routing wrappers, loading cluster assignments from `groups/`. Finally, it applies int4 QAT, the LoRA adapters, and the trained state dict. Before generation starts, `verify_grouped_routing()` asserts that every layer is a `GroupedMaskedMLP` with K=96.

## Citation

Internal research (Anthropic Claude + Noah Cylich, Cactus Compute), 2026-04.