Julia SLM β€” Small Language Models in Pure Julia

Transformer and Monarch Mixer language models built entirely in Julia using Lux.jl, trained on the philosophy-corpus dataset.

Models

Head-to-Head Comparison

| Metric | Transformer (5m-chinchilla/) | Monarch Mixer (5m-monarch/) |
|---|---|---|
| Parameters | 5,037,312 (5.04M) | 4,983,040 (4.98M) |
| Blocks | 6 | 8 |
| Sequence mixing | Softmax attention (4 heads) | Multi-head Monarch (8 heads) + causal conv |
| Channel mixing | SwiGLU (256→640→256) | SwiGLU (256→640→256) |
| Positional encoding | RoPE | None (learned via Monarch factors) |
| Val loss | 3.54 | 3.65 |
| Val PPL | 34.5 | 38.4 |
| Training time | 66 min | 89 min |
| Throughput | ~26K tok/s | ~19K tok/s |

Both models used an identical training recipe: AdamW (lr = 6e-4) with cosine decay, 12,305 steps, batch size 32, on an RTX 3060 (12 GB).
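
The cosine decay can be sketched in one line of Julia. This is illustrative only: the card states lr = 6e-4 with cosine decay, while the floor `lr_min` and the exact decay shape are our assumptions.

```julia
# Cosine LR decay from lr_max down to lr_min over `total` steps.
# lr_min = lr_max / 10 is a common convention, assumed here.
cosine_lr(step, total; lr_max=6e-4, lr_min=6e-5) =
    lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * min(step, total) / total))
```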


5M Chinchilla Transformer (5m-chinchilla/)

5.04M parameter decoder-only transformer trained to Chinchilla-optimal (100M tokens at 20 tokens/param).
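
The Chinchilla budget checks out against the training schedule (batch 32, context 256, 12,305 steps):

```julia
# Tokens seen = batch × context × steps; target = 20 tokens per parameter.
tokens_seen = 32 * 256 * 12_305    # 100,802,560
tokens_goal = 20 * 5_037_312       # 100,746,240 (≈ 100M, within 0.1%)
```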

| Param | Value |
|---|---|
| Parameters | 5,037,312 |
| Architecture | Decoder-only Transformer |
| Embedding dim | 256 |
| Layers | 6 |
| Attention heads | 4 |
| Head dim | 64 |
| FFN multiplier | 4x (SwiGLU) |
| Context length | 256 |
| Vocab size | 2,000 (BPE) |
| Weight tying | Yes |
| Normalization | RMSNorm (pre-norm) |
| Positional encoding | RoPE |

Loss curve:

| Step | Train Loss | Val Loss | Val PPL |
|---|---|---|---|
| 500 | 6.69 | 5.01 | 149.6 |
| 2,000 | 4.09 | 4.02 | 56.0 |
| 6,000 | 3.72 | 3.70 | 40.4 |
| 10,000 | 3.58 | 3.57 | 35.4 |
| 12,305 | 3.55 | 3.54 | 34.5 |

5M Monarch Mixer (5m-monarch/)

4.98M parameter Monarch Mixer variant using sub-quadratic sequence mixing with structured matrices.

| Param | Value |
|---|---|
| Parameters | 4,983,040 |
| Architecture | Monarch Mixer |
| Embedding dim | 256 |
| Layers | 8 |
| Monarch heads | 8 |
| Conv kernel | 4 (causal depthwise) |
| FFN multiplier | 4x (SwiGLU) |
| Context length | 256 |
| Vocab size | 2,000 (BPE) |
| Weight tying | Yes |
| Normalization | RMSNorm (pre-norm) |
| Gating | Learned sigmoid gate |

How Monarch Mixer works:

A Monarch matrix of size TΓ—T (T=pΒ²=256, p=16) factorizes as:

M = Pα΅€ Β· BlockDiag(L1) Β· P Β· BlockDiag(L2)

where BlockDiag(L1) and BlockDiag(L2) are block-diagonal matrices built from p blocks of size p×p each, and P is the reshape-transpose permutation. This costs 2p³ = 2T^{3/2} parameters (8,192 vs 65,536 for a dense T×T matrix).
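
A minimal sketch of that factorized matrix-vector product (our own illustrative code, not the repo's API; `L1` and `L2` are stored as p×p×p arrays holding p blocks of size p×p):

```julia
# Apply M = Pᵀ · BlockDiag(L1) · P · BlockDiag(L2) to a length-T vector x,
# where T = p² and P is the reshape-transpose permutation (its own inverse).
function monarch_apply(L1, L2, x)
    p = size(L1, 1)
    X = reshape(x, p, p)             # column i = block i of x
    Y = similar(X, p, p)
    for i in 1:p                     # BlockDiag(L2) · x, block by block
        Y[:, i] = L2[:, :, i] * X[:, i]
    end
    Y = permutedims(Y)               # P: reshape-transpose
    Z = similar(Y, p, p)
    for i in 1:p                     # BlockDiag(L1) · (P · BlockDiag(L2) · x)
        Z[:, i] = L1[:, :, i] * Y[:, i]
    end
    vec(permutedims(Z))              # Pᵀ, then flatten back to length T
end
```

Because transposing twice is the identity, P is an involution, so the final Pᵀ is again just a reshape-transpose.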

Each block uses 8 independent Monarch heads (each mixing 32 channels over 256 positions) combined with a causal depthwise convolution for local n-gram patterns, gated by a learned sigmoid.
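The causal depthwise convolution can be sketched as follows (illustrative only; the `(T, C)` layout and the absence of a bias are our assumptions, not the repo's exact implementation):

```julia
# Causal depthwise conv: each channel c is filtered independently by its own
# length-k kernel w[:, c], and position t only sees positions t-k+1 … t.
function causal_dwconv(x, w)
    T, C = size(x)
    k = size(w, 1)
    y = zeros(eltype(x), T, C)
    for c in 1:C, t in 1:T, j in 1:k
        s = t - (k - j)              # source position; never in the future
        s >= 1 && (y[t, c] += w[j, c] * x[s, c])
    end
    return y
end
```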

Loss curve:

| Step | Train Loss | Val Loss | Val PPL |
|---|---|---|---|
| 500 | 6.31 | 5.26 | 192.4 |
| 2,000 | 4.15 | 4.15 | 63.4 |
| 6,000 | 3.77 | 3.79 | 44.3 |
| 10,000 | 3.62 | 3.67 | 39.3 |
| 12,305 | 3.62 | 3.65 | 38.4 |

Key findings:

  • Monarch comes within ~3% of the baseline's val loss (3.65 vs 3.54) while using only O(T^{3/2}) parameters for sequence mixing
  • Uses 4x fewer parameters per block in sequence mixing (67K vs 262K), enabling 8 blocks instead of 6
  • Generates coherent English text with dialogue, grammar, and narrative structure
  • First known Julia implementation of Monarch Mixer for language modeling

Architecture

Transformer

JuliaGPTModel
β”œβ”€β”€ tok_emb: Embedding(2000 β†’ 256)     # weight-tied with output head
β”œβ”€β”€ rope: RotaryPositionalEncoding(256)
β”œβ”€β”€ blocks Γ— 6:
β”‚   β”œβ”€β”€ ln1: RMSNorm(256)
β”‚   β”œβ”€β”€ attn: MultiHeadAttention(4 heads, 64 dim each)
β”‚   β”‚   β”œβ”€β”€ wq, wk, wv: Dense(256 β†’ 256)
β”‚   β”‚   └── wo: Dense(256 β†’ 256)
β”‚   β”œβ”€β”€ ln2: RMSNorm(256)
β”‚   └── ffn: SwiGLU(256 β†’ 640 β†’ 256)
β”œβ”€β”€ ln_f: RMSNorm(256)
└── head: TiedEmbeddingHead β†’ (2000,)
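
The 5,037,312 figure is consistent with the tree above, assuming biasless Dense layers and noting that the weight-tied output head adds no parameters of its own:

```julia
emb   = 2_000 * 256                  # token embedding, shared with the head
attn  = 4 * 256 * 256                # wq, wk, wv, wo
ffn   = 3 * 256 * 640                # SwiGLU gate, up, and down projections
norms = 2 * 256                      # ln1 + ln2 RMSNorm scales
total = emb + 6 * (attn + ffn + norms) + 256   # + ln_f
```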

Monarch Mixer

JuliaGPTModel
β”œβ”€β”€ tok_emb: Embedding(2000 β†’ 256)     # weight-tied with output head
β”œβ”€β”€ blocks Γ— 8:
β”‚   β”œβ”€β”€ ln1: RMSNorm(256)
β”‚   β”œβ”€β”€ seq_mixer: MonarchSequenceMixer
β”‚   β”‚   β”œβ”€β”€ conv: CausalDepthwiseConv1d(256, kernel=4)
β”‚   β”‚   β”œβ”€β”€ monarchs Γ— 8: MonarchMatrix(256, L1/L2 ∈ ℝ^{16Γ—16Γ—16})
β”‚   β”‚   └── gate: LearnedGate(256)
β”‚   β”œβ”€β”€ ln2: RMSNorm(256)
β”‚   └── ffn: SwiGLU(256 β†’ 640 β†’ 256)
β”œβ”€β”€ ln_f: RMSNorm(256)
└── head: TiedEmbeddingHead β†’ (2000,)
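
The 4,983,040 figure is likewise consistent with the tree above, assuming biasless layers and a per-channel gate vector (the gate's exact parameterization is our assumption; the card does not spell it out):

```julia
emb     = 2_000 * 256                # token embedding, shared with the head
monarch = 8 * 2 * 16^3               # 8 heads × (L1 + L2), each 16 blocks of 16×16
conv    = 256 * 4                    # depthwise kernel: 256 channels × width 4
gate    = 256                        # assumed per-channel sigmoid gate vector
ffn     = 3 * 256 * 640              # SwiGLU gate, up, and down projections
norms   = 2 * 256                    # ln1 + ln2 RMSNorm scales
total   = emb + 8 * (monarch + conv + gate + ffn + norms) + 256   # + ln_f
```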

Usage

Load and generate (Transformer)

using Pkg; Pkg.activate("julia-slm")
include("src/JuliaGPT.jl")
using .JuliaGPT
using .JuliaGPT: Lux, CUDA

tok = BPETokenizer("path/to/vocab.json", "path/to/merges.txt")
device = Lux.gpu_device()
ps, st, _, step, val_loss = load_checkpoint("5m-chinchilla/final.jld2"; device)

model = create_model(ModelConfig(;
    vocab_size=vocab_size(tok), embed_dim=256, n_layers=6,
    n_heads=4, head_dim=64, ffn_mult=4, context_length=256,
    weight_tying=true,
))

text = generate(model, ps, st, tok, "the nature of ";
                max_new_tokens=200, temperature=0.8, top_k=40)
println(text)

Load and generate (Monarch Mixer)

ps, st, _, step, val_loss = load_checkpoint("5m-monarch/final.jld2"; device)

model = create_model(ModelConfig(;
    arch="monarch",
    vocab_size=vocab_size(tok), embed_dim=256, n_layers=8,
    n_heads=4, head_dim=64, ffn_mult=4, context_length=256,
    weight_tying=true, n_monarch_heads=8, conv_kernel_size=4,
))

text = generate(model, ps, st, tok, "the nature of ";
                max_new_tokens=200, temperature=0.8, top_k=40)
println(text)

Train from scratch

# Transformer baseline
julia --project scripts/train.jl --config config/5m.toml

# Monarch Mixer
julia --project scripts/train.jl --config config/5m-monarch.toml

Dataset

Trained on LisaMegaWatts/philosophy-corpus β€” 981 source texts (BookCorpus, WikiText-103, PG-19, classical philosophy) processed through a custom text pipeline with deduplication and quality scoring.

  • Train tokens: 794.9M (pre-encoded as train.bin)
  • Val tokens: 88.2M (pre-encoded as val.bin)
  • Tokenizer: ByteLevel BPE, 2,000 vocab

Framework

Built with:

  • Lux.jl β€” Explicit-parameter neural networks
  • Zygote.jl β€” Automatic differentiation
  • CUDA.jl β€” GPU acceleration
  • NNlib.jl β€” Batched matrix multiply, activations
  • Optimisers.jl β€” AdamW with cosine LR

Files

5m-chinchilla/           # Baseline transformer
β”œβ”€β”€ config.toml
β”œβ”€β”€ final.jld2           # Step 12,305
└── step_12000.jld2

5m-monarch/              # Monarch Mixer variant
β”œβ”€β”€ config.toml
β”œβ”€β”€ final.jld2           # Step 12,305
└── step_12000.jld2

Checkpoints are JLD2 format containing: model parameters (ps), model state (st), optimizer state, step number, and best validation loss.

License

MIT
