BrainStacks - Gemma 3 12B IT

Cross-Domain Cognitive Capabilities via Frozen MoE-LoRA Stacks for Continual LLM Learning

[Figure 1: LoRA vs. MoE-LoRA stacks]

Paper (arXiv:2604.01152) · GitHub · Paper Page

Author: Mohammad R. Abu Ayyash, Brains Build Research, Ramallah, Palestine


What is this?

Fine-tuned domain stacks + meta-router for the BrainStacks architecture on Gemma 3 12B IT.

BrainStacks packages domain expertise as frozen MoE-LoRA adapter stacks that compose additively on a shared frozen base model. An outcome-based meta-router selectively activates relevant stacks per prompt, enabling cross-domain composition with zero forgetting.

Key finding: Domain stacks learn transferable cognitive primitives (instruction-following clarity, numerical reasoning, procedural logic) rather than domain-specific knowledge. Medical prompts route optimally to the chat+math stacks 97% of the time, despite those stacks containing no medical data.

Contents

stacks/
├── chat/
│   ├── stack_1.pt          # Chat domain, round 1
│   └── stack_2.pt          # Chat domain, round 2 (residual)
├── code/
│   ├── stack_1.pt
│   └── stack_2.pt
├── math/
│   ├── stack_1.pt
│   └── stack_2.pt
├── medical/
│   ├── stack_1.pt
│   └── stack_2.pt
└── reasoning/
    ├── stack_1.pt
    └── stack_2.pt
meta_router.pt              # ~2M param sigmoid router
manifest.json               # Domain block metadata
code/                       # Full training + inference code
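
For reference, a plausible shape for manifest.json, inferred from the two fields the Quick Start code reads (name and stack_files); the actual file may carry additional metadata:

```json
{
  "domains": [
    {"name": "chat",      "stack_files": ["stacks/chat/stack_1.pt", "stacks/chat/stack_2.pt"]},
    {"name": "code",      "stack_files": ["stacks/code/stack_1.pt", "stacks/code/stack_2.pt"]},
    {"name": "math",      "stack_files": ["stacks/math/stack_1.pt", "stacks/math/stack_2.pt"]},
    {"name": "medical",   "stack_files": ["stacks/medical/stack_1.pt", "stacks/medical/stack_2.pt"]},
    {"name": "reasoning", "stack_files": ["stacks/reasoning/stack_1.pt", "stacks/reasoning/stack_2.pt"]}
  ]
}
```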

Architecture Details

Each stack is a MoELoRADelta module:

  • 4 LoRA experts per projection (rank 16, rsLoRA scaling alpha/sqrt(r))
  • Top-2 noisy routing with Shazeer-style gating
  • Applied to all 7 transformer projections (q, k, v, o, gate, up, down)
  • Trained under 4-bit NF4 quantization

These are NOT standard PEFT adapters. They are custom PyTorch state dicts. You must use the provided loading code.
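
To make the bullet points above concrete, here is a minimal sketch of what one MoELoRADelta module could look like: 4 rank-16 LoRA experts per projection, rsLoRA scaling alpha/sqrt(r), and Shazeer-style noisy top-2 gating. The internals (parameter shapes, init, gating details) are assumptions for illustration, not the released implementation; see code/brainstacks_inference.py for the real classes.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoELoRADelta(nn.Module):
    """Illustrative sketch of one stack's delta for a single projection
    (hypothetical internals; the shipped module may differ)."""
    def __init__(self, in_dim, out_dim, n_experts=4, rank=16, alpha=16.0, k=2):
        super().__init__()
        self.k = k
        self.scale = alpha / math.sqrt(rank)        # rsLoRA scaling alpha/sqrt(r)
        self.A = nn.Parameter(torch.randn(n_experts, in_dim, rank) * 0.01)
        self.B = nn.Parameter(torch.zeros(n_experts, rank, out_dim))  # zero-init: delta starts at 0
        self.gate = nn.Linear(in_dim, n_experts, bias=False)   # clean gating logits
        self.noise = nn.Linear(in_dim, n_experts, bias=False)  # Shazeer-style noise scale

    def forward(self, x):                           # x: (..., in_dim)
        logits = self.gate(x)
        if self.training:                           # noisy gating only during training
            logits = logits + torch.randn_like(logits) * F.softplus(self.noise(x))
        topv, topi = logits.topk(self.k, dim=-1)    # keep top-2 experts per token
        w = torch.zeros_like(logits).scatter(-1, topi, F.softmax(topv, dim=-1))
        # expert e contributes w_e * scale * (x A_e) B_e; the deltas sum additively
        delta = torch.einsum("...i,eir,ero->...eo", x, self.A, self.B)
        return self.scale * (w.unsqueeze(-1) * delta).sum(dim=-2)

m = MoELoRADelta(in_dim=8, out_dim=8).eval()
y = m(torch.randn(2, 3, 8))                         # same shape as the input
```

Because each frozen stack only emits an additive delta of this form, composing stacks is a sum of deltas on top of the frozen base projection, which is what makes zero-forgetting composition possible.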

Quick Start

These are custom MoE-LoRA state dicts, not PEFT adapters. Loading requires injecting StackedMoELoRALayer wrappers, then loading each .pt file into MoELoRADelta modules.

from huggingface_hub import snapshot_download
import torch, json, os
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# 1. Download all weights + code
local_dir = snapshot_download("MohammadAbuAyyash/brainstacks-gemma3-12b-it")

# 2. Load base model (4-bit NF4)
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-3-12b-it",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True,
    ),
    device_map="auto", torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-12b-it")
tokenizer.pad_token = tokenizer.eos_token
device = torch.device("cuda")

# 3. Inject StackedMoELoRALayer into all 7 projections
#    (replaces q/k/v/o/gate/up/down with wrappers that hold frozen stacks)
#    See code/brainstacks_inference.py for full class definitions
import sys; sys.path.insert(0, os.path.join(local_dir, "code"))
from brainstacks_inference import (
    inject_stacked_layers, load_single_stack,
    MetaRouter, set_domain_weights, clear_domain_weights, set_base_only
)

model, stacked_layers = inject_stacked_layers(model)

# 4. Load all domain stacks from manifest
with open(os.path.join(local_dir, "manifest.json")) as f:
    manifest = json.load(f)

domain_names = []
for block in manifest["domains"]:
    name = block["name"]
    for sf in block["stack_files"]:
        # Remap paths to local download dir
        stack_path = os.path.join(local_dir, "stacks", name, os.path.basename(sf))
        load_single_stack(model, stacked_layers, stack_path, device)
    domain_names.append(name)
    print(f"  Loaded {name}: {len(block['stack_files'])} stacks")

# Set domain stack counts on all layers
counts = [len(block["stack_files"]) for block in manifest["domains"]]
for layer in stacked_layers:
    layer._domain_stack_counts = counts

# 5. Load meta-router
ckpt = torch.load(os.path.join(local_dir, "meta_router.pt"), map_location=device, weights_only=False)
router = MetaRouter(token_dim=ckpt["token_dim"], n_domains=ckpt["n_domains"]).to(device)
router.load_state_dict(ckpt["state_dict"])
router.eval()

print(f"Ready: {len(domain_names)} domains, {sum(counts)} stacks/layer")

For interactive inference with disk-offloaded routing, use brainstacks_inference.py directly:

python brainstacks_inference.py

Test it

from brainstacks_inference import DiskOffloadEngine

domain_stack_paths = {}
for block in manifest["domains"]:
    name = block["name"]
    domain_stack_paths[name] = [os.path.join(local_dir, "stacks", name, os.path.basename(sf)) for sf in block["stack_files"]]

# Move stacks to GPU (they were loaded to CPU)
for layer in stacked_layers:
    for stack in layer.frozen_stacks:
        stack.to(device)

engine = DiskOffloadEngine(model, stacked_layers, router, tokenizer,
                           domain_names, domain_stack_paths, device)
engine._loaded_domains = set(domain_names)

for p in [
    "Explain what a neural network is in simple terms.",
    "Write a Python function to check if a number is prime.",
    "What are the symptoms of type 2 diabetes?",
    "If a train travels 120km in 2 hours, what is its speed?",
    "A patient needs 500mg of medication per day split into 3 doses. How many mg per dose?",
    "Prove that the square root of 2 is irrational.",
    "Write Python code to calculate BMI given weight and height.",
    "Explain the difference between type 1 and type 2 diabetes.",
]:
    resp, stats = engine.routed_generate(p)
    print(f"\n> {p}")
    print(f"  Route: [{stats['route']}]")
    print(f"  {resp[:500]}")
    print("-" * 70)

Important Notes

  1. Stacks are additive, not replacements. Each stack builds on all previous frozen stacks. Stack 2 in any domain learns the residual that stack 1 left behind; it cannot work alone.
  2. The meta-router is NOT optional. Without it, all stacks fire simultaneously, causing magnitude accumulation and degraded outputs.
  3. Domain training order matters: chat -> code -> math -> medical -> reasoning (curriculum dependency).
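
To illustrate point 2, a minimal sketch of what a sigmoid meta-router could look like: an independent per-domain gate over a pooled prompt embedding, so any subset of domains can fire. The hidden width, pooling, and layer shapes here are assumptions chosen so the parameter count lands near the ~2M quoted above; the real class is in code/brainstacks_inference.py.

```python
import torch
import torch.nn as nn

class MetaRouter(nn.Module):
    """Illustrative sigmoid router (hypothetical internals): one
    independent activation per domain, not a softmax over domains."""
    def __init__(self, token_dim, n_domains, hidden=512):
        super().__init__()
        # token_dim=3840, hidden=512, n_domains=5 gives ~1.97M params,
        # roughly matching the "~2M param sigmoid router" above
        self.net = nn.Sequential(
            nn.Linear(token_dim, hidden), nn.GELU(),
            nn.Linear(hidden, n_domains),
        )

    def forward(self, pooled):                  # pooled: (batch, token_dim)
        return torch.sigmoid(self.net(pooled))  # per-domain score in [0, 1]

router = MetaRouter(token_dim=3840, n_domains=5).eval()
probs = router(torch.randn(1, 3840))
active = [i for i, p in enumerate(probs[0]) if p > 0.5]  # stacks to activate
```

Gating each domain independently (sigmoid, not softmax) is what lets a single prompt activate chat+math together while leaving the other stacks silent, avoiding the magnitude accumulation described in point 2.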

Training Data

Domain     ~Samples   Sources
Chat       40K        Nemotron v2 chat, UltraFeedback SFT, Daring-Anteater
Code       48K        Python Code Instructions, Nemotron v2 code, OpenCodeReasoning, OpenThoughts
Math       53K        GSM8K, OpenMathReasoning CoT, NuminaMath-CoT, Nemotron v2 math
Medical    20K        MedQA-USMLE, medical-o1-reasoning-SFT, PubMedQA
Reasoning  50K        OpenThoughts-114k, Nemotron v2 STEM, Sky-T1, OpenMathReasoning tool

Benchmarks (Gemma 3 12B IT, 200 samples each)

Benchmark      Base    Routed  Delta
HellaSwag      0.670   0.650   -0.020
ARC-Easy       0.510   0.515   +0.005
ARC-Challenge  0.525   0.495   -0.030
TruthfulQA     0.350   0.370   +0.020
MMLU           0.450   0.435   -0.015
GSM8K          0.665   0.665    0.000
MedQA          0.385   0.350   -0.035
MedMCQA        0.330   0.360   +0.030

Hardware

Trained on Google Colab G4 (96GB VRAM).

Citation

@article{abuayyash2026brainstacks,
  title={BrainStacks: Cross-Domain Cognitive Capabilities via Frozen MoE-LoRA Stacks for Continual LLM Learning},
  author={Abu Ayyash, Mohammad R.},
  journal={arXiv preprint arXiv:2604.01152},
  year={2026}
}