# Julia SLM – Small Language Models in Pure Julia
Transformer and Monarch Mixer language models built entirely in Julia using Lux.jl, trained on the philosophy-corpus dataset.
## Models
### Head-to-Head Comparison
| Metric | Transformer (`5m-chinchilla/`) | Monarch Mixer (`5m-monarch/`) |
|---|---|---|
| Parameters | 5,037,312 (5.04M) | 4,983,040 (4.98M) |
| Blocks | 6 | 8 |
| Sequence mixing | Softmax attention (4 heads) | Multi-head Monarch (8 heads) + causal conv |
| Channel mixing | SwiGLU (256 → 640 → 256) | SwiGLU (256 → 640 → 256) |
| Positional encoding | RoPE | None (learned via Monarch factors) |
| Val loss | 3.54 | 3.65 |
| Val PPL | 34.5 | 38.4 |
| Training time | 66 min | 89 min |
| Throughput | ~26K tok/s | ~19K tok/s |
Both models were trained identically: AdamW (lr = 6e-4) with cosine decay, 12,305 steps, batch size 32, on an RTX 3060 (12 GB).
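The schedule is simple enough to sketch in plain Julia. The warmup length and floor learning rate below are illustrative assumptions, not values from the repo's config:

```julia
# Cosine learning-rate decay to a floor, with linear warmup.
# lr_max = 6e-4 matches the text; warmup and lr_min are assumed.
function lr_at(step; lr_max=6e-4, lr_min=6e-5, warmup=300, total=12_305)
    step < warmup && return lr_max * step / warmup     # linear warmup
    progress = (step - warmup) / (total - warmup)      # 0 at peak, 1 at end
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * progress))
end
```

In the training loop, a value like this would typically be pushed into the AdamW optimiser state each step (Optimisers.jl provides `Optimisers.adjust!` for exactly this).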
### 5M Chinchilla Transformer (`5m-chinchilla/`)
A 5.04M-parameter decoder-only transformer trained to the Chinchilla-optimal budget of 100M tokens (20 tokens per parameter).
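The budget checks out with quick arithmetic against the training setup (12,305 steps, batch 32, context 256):

```julia
params = 5_037_312
target = 20 * params               # Chinchilla: ~20 tokens per parameter
seen   = 12_305 * 32 * 256         # steps × batch size × context length
# seen = 100,802,560 ≈ target = 100,746,240, i.e. ~100M tokens
```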
| Param | Value |
|---|---|
| Parameters | 5,037,312 |
| Architecture | Decoder-only Transformer |
| Embedding dim | 256 |
| Layers | 6 |
| Attention heads | 4 |
| Head dim | 64 |
| FFN multiplier | 4x (SwiGLU) |
| Context length | 256 |
| Vocab size | 2,000 (BPE) |
| Weight tying | Yes |
| Normalization | RMSNorm (pre-norm) |
| Positional encoding | RoPE |
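RoPE rotates each pair of channels by a position-dependent angle, so relative offsets show up as relative rotations in the attention dot product. A minimal plain-Julia sketch; the repo's Lux.jl layer will differ in layout and batching:

```julia
# Rotary positional encoding: channel pair (i, i+1) at position t is
# rotated by θ = t / 10000^((i-1)/d), lower channels spinning faster.
function rope(x)                               # x :: positions × channels
    T, d = size(x)
    y = similar(x)
    for t in 1:T, i in 1:2:d
        θ = (t - 1) / 10_000^((i - 1) / d)
        c, s = cos(θ), sin(θ)
        y[t, i]     = c * x[t, i] - s * x[t, i + 1]
        y[t, i + 1] = s * x[t, i] + c * x[t, i + 1]
    end
    return y
end
```

Because each pair is only rotated, RoPE adds no parameters and preserves vector norms.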
Loss curve:
| Step | Train Loss | Val Loss | Val PPL |
|---|---|---|---|
| 500 | 6.69 | 5.01 | 149.6 |
| 2,000 | 4.09 | 4.02 | 56.0 |
| 6,000 | 3.72 | 3.70 | 40.4 |
| 10,000 | 3.58 | 3.57 | 35.4 |
| 12,305 | 3.55 | 3.54 | 34.5 |
### 5M Monarch Mixer (`5m-monarch/`)
A 4.98M-parameter Monarch Mixer variant using sub-quadratic sequence mixing built from structured matrices.
| Param | Value |
|---|---|
| Parameters | 4,983,040 |
| Architecture | Monarch Mixer |
| Embedding dim | 256 |
| Layers | 8 |
| Monarch heads | 8 |
| Conv kernel | 4 (causal depthwise) |
| FFN multiplier | 4x (SwiGLU) |
| Context length | 256 |
| Vocab size | 2,000 (BPE) |
| Weight tying | Yes |
| Normalization | RMSNorm (pre-norm) |
| Gating | Learned sigmoid gate |
How Monarch Mixer works:
A Monarch matrix of size T×T (T = p² = 256, p = 16) factorizes as:

M = Pᵀ · BlockDiag(L1) · P · BlockDiag(L2)

where L1 and L2 are block-diagonal with p blocks of size p×p each, and P is a reshape-transpose permutation. This takes 2p³ = 2T^{3/2} = 8,192 parameters, versus T² = 65,536 for a dense matrix.
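In other words, applying a Monarch matrix is two batches of p×p multiplies separated by a transpose, and the dense T×T matrix is never materialized. A plain-Julia sketch; the repo's implementation is batched over heads and channels:

```julia
# Apply M = Pᵀ · BlockDiag(L1) · P · BlockDiag(L2) to a length-T vector,
# where L1[:, :, j] and L2[:, :, j] are the j-th p×p diagonal blocks.
function monarch_apply(x, L1, L2)
    p = size(L1, 1)
    X = reshape(x, p, p)                           # columns = input blocks
    Y = reduce(hcat, [L2[:, :, j] * X[:, j] for j in 1:p])  # BlockDiag(L2)
    Y = permutedims(Y)                             # P: reshape-transpose
    Z = reduce(hcat, [L1[:, :, j] * Y[:, j] for j in 1:p])  # BlockDiag(L1)
    return vec(permutedims(Z))                     # Pᵀ: undo the permutation
end

p = 16                                             # T = p² = 256
L1, L2 = randn(p, p, p), randn(p, p, p)            # 2p³ = 8,192 parameters
y = monarch_apply(randn(p^2), L1, L2)
```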
Each block uses 8 independent Monarch heads (each mixing 32 channels over 256 positions) combined with a causal depthwise convolution for local n-gram patterns, gated by a learned sigmoid.
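The local convolution path can be sketched the same way; the weight shapes and gate here are illustrative, inferred from the table above rather than taken from the repo:

```julia
# Causal depthwise Conv1d (kernel k): output at position t for channel c
# mixes only positions t-k+1 … t of channel c, capturing local n-grams.
function causal_dwconv(x, w)                   # x :: T × C, w :: k × C
    k, C = size(w)
    T = size(x, 1)
    xp = vcat(zeros(k - 1, C), x)              # left-pad: no future leakage
    return [sum(w[i, c] * xp[t + i - 1, c] for i in 1:k) for t in 1:T, c in 1:C]
end

σ(z) = 1 / (1 + exp(-z))                       # learned sigmoid gate
T, C, k = 256, 256, 4
x, w, g = randn(T, C), randn(k, C), randn(C)
y = causal_dwconv(x, w) .* σ.(g)'              # gate broadcast over positions
```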
Loss curve:
| Step | Train Loss | Val Loss | Val PPL |
|---|---|---|---|
| 500 | 6.31 | 5.26 | 192.4 |
| 2,000 | 4.15 | 4.15 | 63.4 |
| 6,000 | 3.77 | 3.79 | 44.3 |
| 10,000 | 3.62 | 3.67 | 39.3 |
| 12,305 | 3.62 | 3.65 | 38.4 |
Key findings:
- Monarch comes within ~3% of baseline val loss (3.65 vs 3.54) with O(T^{3/2}) parameter complexity in sequence mixing
- Uses 4x fewer parameters per block in sequence mixing (67K vs 262K), enabling 8 blocks instead of 6
- Generates coherent English text with dialogue, grammar, and narrative structure
- First known Julia implementation of Monarch Mixer for language modeling
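The per-block counts behind the second finding are easy to reproduce; the conv and gate terms are inferred from the architecture table, so treat them as approximate:

```julia
d, heads, p, k = 256, 8, 16, 4
attn_mix    = 4 * d * d                    # wq, wk, wv, wo: 262,144
monarch_mix = heads * 2 * p^3 + k * d + d  # 8 Monarch heads + conv + gate
# monarch_mix = 66,816 ≈ 67K, roughly 4× fewer than 262K
```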
## Architecture

### Transformer
```
JuliaGPTModel
├── tok_emb: Embedding(2000 → 256)   # weight-tied with output head
├── rope: RotaryPositionalEncoding(256)
├── blocks × 6:
│   ├── ln1: RMSNorm(256)
│   ├── attn: MultiHeadAttention(4 heads, 64 dim each)
│   │   ├── wq, wk, wv: Dense(256 → 256)
│   │   └── wo: Dense(256 → 256)
│   ├── ln2: RMSNorm(256)
│   └── ffn: SwiGLU(256 → 640 → 256)
├── ln_f: RMSNorm(256)
└── head: TiedEmbeddingHead → (2000,)
```
### Monarch Mixer
```
JuliaGPTModel
├── tok_emb: Embedding(2000 → 256)   # weight-tied with output head
├── blocks × 8:
│   ├── ln1: RMSNorm(256)
│   ├── seq_mixer: MonarchSequenceMixer
│   │   ├── conv: CausalDepthwiseConv1d(256, kernel=4)
│   │   ├── monarchs × 8: MonarchMatrix(256, L1/L2 ∈ ℝ^{16×16×16})
│   │   └── gate: LearnedGate(256)
│   ├── ln2: RMSNorm(256)
│   └── ffn: SwiGLU(256 → 640 → 256)
├── ln_f: RMSNorm(256)
└── head: TiedEmbeddingHead → (2000,)
```
## Usage

### Load and generate (Transformer)
```julia
using Pkg; Pkg.activate("julia-slm")
include("src/JuliaGPT.jl")
using .JuliaGPT
using .JuliaGPT: Lux, CUDA

tok = BPETokenizer("path/to/vocab.json", "path/to/merges.txt")
device = Lux.gpu_device()
ps, st, _, step, val_loss = load_checkpoint("5m-chinchilla/final.jld2"; device)

model = create_model(ModelConfig(;
    vocab_size=vocab_size(tok), embed_dim=256, n_layers=6,
    n_heads=4, head_dim=64, ffn_mult=4, context_length=256,
    weight_tying=true,
))

text = generate(model, ps, st, tok, "the nature of ";
    max_new_tokens=200, temperature=0.8, top_k=40)
println(text)
```
### Load and generate (Monarch Mixer)
```julia
ps, st, _, step, val_loss = load_checkpoint("5m-monarch/final.jld2"; device)

model = create_model(ModelConfig(;
    arch="monarch",
    vocab_size=vocab_size(tok), embed_dim=256, n_layers=8,
    n_heads=4, head_dim=64, ffn_mult=4, context_length=256,
    weight_tying=true, n_monarch_heads=8, conv_kernel_size=4,
))

text = generate(model, ps, st, tok, "the nature of ";
    max_new_tokens=200, temperature=0.8, top_k=40)
println(text)
```
### Train from scratch
```bash
# Transformer baseline
julia --project scripts/train.jl --config config/5m.toml

# Monarch Mixer
julia --project scripts/train.jl --config config/5m-monarch.toml
```
## Dataset

Trained on LisaMegaWatts/philosophy-corpus – 981 source texts (BookCorpus, WikiText-103, PG-19, classical philosophy) processed through a custom text pipeline with deduplication and quality scoring.
- Train tokens: 794.9M (pre-encoded as `train.bin`)
- Val tokens: 88.2M (pre-encoded as `val.bin`)
- Tokenizer: ByteLevel BPE, 2,000 vocab
## Framework
Built with:
- Lux.jl – Explicit-parameter neural networks
- Zygote.jl – Automatic differentiation
- CUDA.jl – GPU acceleration
- NNlib.jl – Batched matrix multiply, activations
- Optimisers.jl – AdamW with cosine LR
## Files
```
5m-chinchilla/        # Baseline transformer
├── config.toml
├── final.jld2        # Step 12,305
└── step_12000.jld2

5m-monarch/           # Monarch Mixer variant
├── config.toml
├── final.jld2        # Step 12,305
└── step_12000.jld2
```
Checkpoints are JLD2 files containing model parameters (`ps`), model state (`st`), optimizer state, step number, and best validation loss.
## References
- Monarch Mixer (Fu et al., 2023) – Sub-quadratic GEMM-based architecture
- Chinchilla (Hoffmann et al., 2022) – Compute-optimal training scaling
## License
MIT