BrainStacks - Gemma 3 12B IT
Cross-Domain Cognitive Capabilities via Frozen MoE-LoRA Stacks for Continual LLM Learning
Paper (arXiv:2604.01152) · GitHub · Paper Page
Author: Mohammad R. Abu Ayyash, Brains Build Research, Ramallah, Palestine
What is this?
Fine-tuned domain stacks + meta-router for the BrainStacks architecture on Gemma 3 12B IT.
BrainStacks packages domain expertise as frozen MoE-LoRA adapter stacks that compose additively on a shared frozen base model. An outcome-based meta-router selectively activates relevant stacks per prompt, enabling cross-domain composition with zero forgetting.
Key finding: Domain stacks learn transferable cognitive primitives (instruction-following clarity, numerical reasoning, procedural logic) rather than domain-specific knowledge. Medical prompts optimally route to chat+math stacks 97% of the time, with zero medical data in those stacks.
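The additive, router-gated composition described above can be sketched as follows. This is an illustration only; `composed_projection`, `domain_deltas`, and `gates` are our names, not the repo's API:

```python
import torch

def composed_projection(x, W0, domain_deltas, gates):
    """Illustrative only: frozen base projection plus a gated, additive
    sum of frozen per-domain stack deltas selected by the meta-router."""
    y = x @ W0.T                        # frozen base path (never updated)
    for domain, stacks in domain_deltas.items():
        g = gates.get(domain, 0.0)      # meta-router activation in [0, 1]
        if g > 0:
            for delta in stacks:        # stacks within a domain add residually
                y = y + g * delta(x)
    return y

# Toy usage: one "chat" stack doubling the input, gated at 0.5
W0 = torch.eye(3)
y = composed_projection(torch.ones(1, 3), W0,
                        {"chat": [lambda x: 2 * x]}, {"chat": 0.5})
```

Because every stack is frozen and only added on top of the frozen base, adding a new domain can never overwrite an old one, which is where the zero-forgetting claim comes from.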
Contents
stacks/
├── chat/
│   ├── stack_1.pt          # Chat domain, round 1
│   └── stack_2.pt          # Chat domain, round 2 (residual)
├── code/
│   ├── stack_1.pt
│   └── stack_2.pt
├── math/
│   ├── stack_1.pt
│   └── stack_2.pt
├── medical/
│   ├── stack_1.pt
│   └── stack_2.pt
└── reasoning/
    ├── stack_1.pt
    └── stack_2.pt
meta_router.pt              # ~2M param sigmoid router
manifest.json               # Domain block metadata
code/                       # Full training + inference code
Architecture Details
Each stack is a MoELoRADelta module:
- 4 LoRA experts per projection (rank 16, rsLoRA scaling alpha/sqrt(r))
- Top-2 noisy routing with Shazeer-style gating
- Applied to all 7 transformer projections (q, k, v, o, gate, up, down)
- Trained under 4-bit NF4 quantization
These are NOT standard PEFT adapters. They are custom PyTorch state dicts. You must use the provided loading code.
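For intuition, the shape of such a module can be sketched as below. This is a simplified re-implementation under the hyperparameters listed above, not the `MoELoRADelta` class shipped in code/ (initialization, noise handling, and tensor layout may differ):

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoELoRADeltaSketch(nn.Module):
    """Simplified sketch: 4 rank-16 LoRA experts per projection with
    rsLoRA scaling (alpha / sqrt(r)) and top-2 noisy (Shazeer-style) gating."""
    def __init__(self, in_features, out_features, n_experts=4, rank=16, alpha=32.0):
        super().__init__()
        self.scale = alpha / math.sqrt(rank)                 # rsLoRA scaling
        self.A = nn.Parameter(torch.randn(n_experts, in_features, rank) * 0.01)
        self.B = nn.Parameter(torch.zeros(n_experts, rank, out_features))  # zero-init: delta starts at 0
        self.gate = nn.Linear(in_features, n_experts, bias=False)
        self.noise = nn.Linear(in_features, n_experts, bias=False)

    def forward(self, x):                                    # x: (batch, seq, in_features)
        logits = self.gate(x)
        if self.training:                                    # noisy gating (training only)
            logits = logits + torch.randn_like(logits) * F.softplus(self.noise(x))
        top_w, top_i = logits.topk(2, dim=-1)                # top-2 expert routing
        top_w = top_w.softmax(dim=-1)
        delta = x.new_zeros(x.shape[:-1] + (self.B.shape[-1],))
        for slot in range(2):
            A = self.A[top_i[..., slot]]                     # per-token expert A: (b, s, in, r)
            B = self.B[top_i[..., slot]]                     # per-token expert B: (b, s, r, out)
            w = top_w[..., slot].unsqueeze(-1)
            delta = delta + w * torch.einsum("bsi,bsir,bsro->bso", x, A, B)
        return self.scale * delta                            # added to the frozen projection output
```

In the full model one such delta is added to each of the 7 frozen projection outputs, and "stacking" means several frozen modules of this kind contribute additively to the same projection.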
Quick Start
These are custom MoE-LoRA state dicts, not PEFT adapters. Loading requires injecting StackedMoELoRALayer wrappers, then loading each .pt file into MoELoRADelta modules.
from huggingface_hub import snapshot_download
import torch, json, os
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# bitsandbytes must be installed: transformers uses it internally for 4-bit NF4 loading
# 1. Download all weights + code
local_dir = snapshot_download("MohammadAbuAyyash/brainstacks-gemma3-12b-it")
# 2. Load base model (4-bit NF4)
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-3-12b-it",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True,
    ),
    device_map="auto", torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-12b-it")
tokenizer.pad_token = tokenizer.eos_token
device = torch.device("cuda")
# 3. Inject StackedMoELoRALayer into all 7 projections
# (replaces q/k/v/o/gate/up/down with wrappers that hold frozen stacks)
# See code/brainstacks_inference.py for full class definitions
import sys; sys.path.insert(0, os.path.join(local_dir, "code"))
from brainstacks_inference import (
    inject_stacked_layers, load_single_stack,
    MetaRouter, set_domain_weights, clear_domain_weights, set_base_only,
)
model, stacked_layers = inject_stacked_layers(model)
# 4. Load all domain stacks from manifest
with open(os.path.join(local_dir, "manifest.json")) as f:
    manifest = json.load(f)

domain_names = []
for block in manifest["domains"]:
    name = block["name"]
    for sf in block["stack_files"]:
        # Remap paths to the local download dir
        stack_path = os.path.join(local_dir, "stacks", name, os.path.basename(sf))
        load_single_stack(model, stacked_layers, stack_path, device)
    domain_names.append(name)
    print(f" Loaded {name}: {len(block['stack_files'])} stacks")
# Set domain stack counts on all layers
counts = [len(block["stack_files"]) for block in manifest["domains"]]
for layer in stacked_layers:
    layer._domain_stack_counts = counts
# 5. Load meta-router
ckpt = torch.load(os.path.join(local_dir, "meta_router.pt"), map_location=device, weights_only=False)
router = MetaRouter(token_dim=ckpt["token_dim"], n_domains=ckpt["n_domains"]).to(device)
router.load_state_dict(ckpt["state_dict"])
router.eval()
print(f"Ready: {len(domain_names)} domains, {sum(counts)} stacks/layer")
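For intuition about the "~2M param sigmoid router", a stand-in of roughly that size can be sketched as below. The hidden width and pooling are our guesses; the actual `MetaRouter` definition lives in code/brainstacks_inference.py:

```python
import torch
import torch.nn as nn

class MetaRouterSketch(nn.Module):
    """Guessed shape of a ~2M-param sigmoid meta-router: mean-pooled token
    embeddings -> small MLP -> independent per-domain gates in [0, 1].
    Sigmoid (not softmax) lets several domains fire at once, e.g. chat+math."""
    def __init__(self, token_dim, n_domains, hidden=512):   # hidden=512 is an assumption
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(token_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_domains),
        )

    def forward(self, token_embeds):            # (batch, seq, token_dim)
        pooled = token_embeds.mean(dim=1)       # mean-pool the prompt tokens
        return torch.sigmoid(self.net(pooled))  # per-domain activation in [0, 1]
```

The independent sigmoid gates are what make multi-domain routes like the medical-to-chat+math composition possible: each domain is switched on or off on its own merits rather than competing in a softmax.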
For interactive inference with disk-offloaded routing, use brainstacks_inference.py directly:
python brainstacks_inference.py
Test it
from brainstacks_inference import DiskOffloadEngine
domain_stack_paths = {}
for block in manifest["domains"]:
    name = block["name"]
    domain_stack_paths[name] = [
        os.path.join(local_dir, "stacks", name, os.path.basename(sf))
        for sf in block["stack_files"]
    ]

# Move stacks to GPU (they were loaded to CPU)
for layer in stacked_layers:
    for stack in layer.frozen_stacks:
        stack.to(device)

engine = DiskOffloadEngine(model, stacked_layers, router, tokenizer,
                           domain_names, domain_stack_paths, device)
engine._loaded_domains = set(domain_names)
for p in [
    "Explain what a neural network is in simple terms.",
    "Write a Python function to check if a number is prime.",
    "What are the symptoms of type 2 diabetes?",
    "If a train travels 120km in 2 hours, what is its speed?",
    "A patient needs 500mg of medication per day split into 3 doses. How many mg per dose?",
    "Prove that the square root of 2 is irrational.",
    "Write Python code to calculate BMI given weight and height.",
    "Explain the difference between type 1 and type 2 diabetes.",
]:
    resp, stats = engine.routed_generate(p)
    print(f"\n> {p}")
    print(f" Route: [{stats['route']}]")
    print(f" {resp[:500]}")
    print("-" * 70)
Important Notes
- Stacks are additive, not replacements. Each stack builds on all previous frozen stacks. Stack 2 in any domain learns the residual that stack 1 left behind, so it cannot work alone.
- The meta-router is NOT optional. Without it, all stacks fire simultaneously, causing magnitude accumulation and degraded outputs.
- Domain training order matters: chat -> code -> math -> medical -> reasoning (curriculum dependency).
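The residual relationship in the first note can be seen in a toy least-squares analogue: train stage 1, freeze it, then train stage 2 only on the error stage 1 left behind. This illustrates the principle, not the BrainStacks training loop:

```python
import torch

torch.manual_seed(0)
X = torch.randn(256, 4)
Y = X @ torch.randn(4, 4)                 # target behavior to absorb

def fit_delta(residual, steps, lr=0.1):
    """Fit one additive linear delta to the current residual by gradient descent."""
    D = torch.zeros(4, 4, requires_grad=True)
    opt = torch.optim.SGD([D], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        ((X @ D - residual) ** 2).mean().backward()
        opt.step()
    return D.detach()

D1 = fit_delta(Y, steps=40)               # stack 1: deliberately under-trained, then frozen
D2 = fit_delta(Y - X @ D1, steps=400)     # stack 2: trained only on stack 1's residual

err1 = ((X @ D1 - Y) ** 2).mean()         # stack 1 alone
err12 = ((X @ (D1 + D2) - Y) ** 2).mean() # both stacks composed additively
err2 = ((X @ D2 - Y) ** 2).mean()         # stack 2 alone
```

Composing both deltas fits the target better than stack 1 alone, while stack 2 by itself lands far from the target, mirroring the "it cannot work alone" note above.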
Training Data
| Domain | ~Samples | Sources |
|---|---|---|
| Chat | 40K | Nemotron v2 chat, UltraFeedback SFT, Daring-Anteater |
| Code | 48K | Python Code Instructions, Nemotron v2 code, OpenCodeReasoning, OpenThoughts |
| Math | 53K | GSM8K, OpenMathReasoning CoT, NuminaMath-CoT, Nemotron v2 math |
| Medical | 20K | MedQA-USMLE, medical-o1-reasoning-SFT, PubMedQA |
| Reasoning | 50K | OpenThoughts-114k, Nemotron v2 STEM, Sky-T1, OpenMathReasoning tool |
Benchmarks (Gemma 3 12B IT, 200 samples each)
| Benchmark | Base | Routed | Delta |
|---|---|---|---|
| HellaSwag | 0.670 | 0.650 | -0.020 |
| ARC-Easy | 0.510 | 0.515 | +0.005 |
| ARC-Challenge | 0.525 | 0.495 | -0.030 |
| TruthfulQA | 0.350 | 0.370 | +0.020 |
| MMLU | 0.450 | 0.435 | -0.015 |
| GSM8K | 0.665 | 0.665 | 0.000 |
| MedQA | 0.385 | 0.350 | -0.035 |
| MedMCQA | 0.330 | 0.360 | +0.030 |
Hardware
Trained on Google Colab G4 (96GB VRAM).
Citation
@misc{abuayyash2026brainstacks,
  title={BrainStacks: Cross-Domain Cognitive Capabilities via Frozen MoE-LoRA Stacks for Continual LLM Learning},
  author={Abu Ayyash, Mohammad R.},
  year={2026},
  eprint={2604.01152},
  archivePrefix={arXiv}
}
