---
license: mit
language:
- en
tags:
- julia
- lux
- transformer
- monarch-mixer
- language-model
- chinchilla
- bpe
datasets:
- LisaMegaWatts/philosophy-corpus
pipeline_tag: text-generation
---

# Julia SLM — Small Language Models in Pure Julia

Transformer and Monarch Mixer language models built entirely in Julia using [Lux.jl](https://github.com/LuxDL/Lux.jl), trained on the [philosophy-corpus](https://huggingface.co/datasets/LisaMegaWatts/philosophy-corpus) dataset.

## Models

### Head-to-Head Comparison

| Metric | Transformer (`5m-chinchilla/`) | Monarch Mixer (`5m-monarch/`) |
|--------|--------------------------------|-------------------------------|
| Parameters | 5,037,312 (5.04M) | 4,983,040 (4.98M) |
| Blocks | 6 | 8 |
| Sequence mixing | Softmax attention (4 heads) | Multi-head Monarch (8 heads) + causal conv |
| Channel mixing | SwiGLU (256→640→256) | SwiGLU (256→640→256) |
| Positional encoding | RoPE | None (learned via Monarch factors) |
| **Val loss** | **3.54** | **3.65** |
| **Val PPL** | **34.5** | **38.4** |
| Training time | 66 min | 89 min |
| Throughput | ~26K tok/s | ~19K tok/s |

Both models were trained identically: AdamW (lr=6e-4), cosine decay, 12,305 steps, batch size 32, on an RTX 3060 12GB.

---

### 5M Chinchilla Transformer (`5m-chinchilla/`)

A 5.04M-parameter decoder-only transformer trained to the Chinchilla-optimal budget (~100M tokens, i.e. 20 tokens per parameter).
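The Chinchilla budget is simple arithmetic; a quick Python sketch (illustrative only, not part of this repo) shows how the step count falls out of the parameter count, batch size, and context length quoted above:

```python
# Back-of-envelope check of the Chinchilla-style token budget used here.
# All numbers come from the tables in this card; the card's 12,305 steps
# land within ~0.1% of this estimate.

params = 5_037_312          # transformer parameter count
tokens_per_param = 20       # Chinchilla-optimal ratio
batch_size = 32
context_length = 256

token_budget = params * tokens_per_param        # 100,746,240 tokens (~100.7M)
tokens_per_step = batch_size * context_length   # 8,192 tokens per optimizer step
steps = token_budget / tokens_per_step

print(f"token budget: {token_budget:,}")
print(f"steps at batch 32 x seq 256: {steps:,.0f}")
```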
| Param | Value |
|-------|-------|
| Parameters | 5,037,312 |
| Architecture | Decoder-only Transformer |
| Embedding dim | 256 |
| Layers | 6 |
| Attention heads | 4 |
| Head dim | 64 |
| FFN multiplier | 4x (SwiGLU) |
| Context length | 256 |
| Vocab size | 2,000 (BPE) |
| Weight tying | Yes |
| Normalization | RMSNorm (pre-norm) |
| Positional encoding | RoPE |

**Loss curve:**

| Step | Train Loss | Val Loss | Val PPL |
|------|-----------|----------|---------|
| 500 | 6.69 | 5.01 | 149.6 |
| 2,000 | 4.09 | 4.02 | 56.0 |
| 6,000 | 3.72 | 3.70 | 40.4 |
| 10,000 | 3.58 | 3.57 | 35.4 |
| 12,305 | 3.55 | 3.54 | 34.5 |

---

### 5M Monarch Mixer (`5m-monarch/`)

A 4.98M-parameter Monarch Mixer variant that replaces attention with sub-quadratic sequence mixing built from structured matrices.

| Param | Value |
|-------|-------|
| Parameters | 4,983,040 |
| Architecture | Monarch Mixer |
| Embedding dim | 256 |
| Layers | 8 |
| Monarch heads | 8 |
| Conv kernel | 4 (causal depthwise) |
| FFN multiplier | 4x (SwiGLU) |
| Context length | 256 |
| Vocab size | 2,000 (BPE) |
| Weight tying | Yes |
| Normalization | RMSNorm (pre-norm) |
| Gating | Learned sigmoid gate |

**How Monarch Mixer works:** A Monarch matrix of size T×T (here T = p² = 256, p = 16) factorizes as

```
M = Pᵀ · BlockDiag(L1) · P · BlockDiag(L2)
```

where L1 and L2 are block-diagonal matrices, each consisting of p blocks of size p×p, and P is a reshape-transpose permutation. This costs 2p³ = 2T^{3/2} parameters (8,192 vs 65,536 for a dense T×T matrix). Each block uses 8 independent Monarch heads (each mixing 32 channels over 256 positions) combined with a causal depthwise convolution for local n-gram patterns, gated by a learned sigmoid.
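The factorization and its parameter count can be checked numerically. Below is a minimal NumPy sketch (illustrative only, independent of the Julia code in this repo) that assembles a dense T×T Monarch matrix from its block-diagonal factors and the reshape-transpose permutation:

```python
import numpy as np

rng = np.random.default_rng(0)
p = 16
T = p * p  # 256

# Two sets of p blocks, each p x p: 2 * p^3 = 8,192 learnable parameters.
L1 = rng.normal(size=(p, p, p))
L2 = rng.normal(size=(p, p, p))

def block_diag(blocks):
    """Assemble p blocks of shape (p, p) into a (T, T) block-diagonal matrix."""
    M = np.zeros((T, T))
    for i, B in enumerate(blocks):
        M[i * p:(i + 1) * p, i * p:(i + 1) * p] = B
    return M

# P: the reshape-transpose ("perfect shuffle") permutation on T indices.
perm = np.arange(T).reshape(p, p).T.ravel()
P = np.eye(T)[perm]

# M = P^T . BlockDiag(L1) . P . BlockDiag(L2), a dense T x T mixing matrix.
M = P.T @ block_diag(L1) @ P @ block_diag(L2)

print("shape:", M.shape)                          # (256, 256)
print("learnable params:", L1.size + L2.size)     # 8192, vs 65536 dense
```

The 8,192-parameter count matches 2T^{3/2} for T = 256, which is where the sub-quadratic claim in the table comes from.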
**Loss curve:**

| Step | Train Loss | Val Loss | Val PPL |
|------|-----------|----------|---------|
| 500 | 6.31 | 5.26 | 192.4 |
| 2,000 | 4.15 | 4.15 | 63.4 |
| 6,000 | 3.77 | 3.79 | 44.3 |
| 10,000 | 3.62 | 3.67 | 39.3 |
| 12,305 | 3.62 | 3.65 | 38.4 |

**Key findings:**

- Monarch reaches **94% of baseline quality** (3.65 vs 3.54 val loss) with O(T^{3/2}) parameter complexity in sequence mixing
- Uses **4x fewer parameters per block** in sequence mixing (67K vs 262K), enabling 8 blocks instead of 6
- Generates coherent English text with dialogue, grammar, and narrative structure
- First known Julia implementation of Monarch Mixer for language modeling

## Architecture

### Transformer

```
JuliaGPTModel
├── tok_emb: Embedding(2000 → 256)    # weight-tied with output head
├── rope: RotaryPositionalEncoding(256)
├── blocks × 6:
│   ├── ln1: RMSNorm(256)
│   ├── attn: MultiHeadAttention(4 heads, 64 dim each)
│   │   ├── wq, wk, wv: Dense(256 → 256)
│   │   └── wo: Dense(256 → 256)
│   ├── ln2: RMSNorm(256)
│   └── ffn: SwiGLU(256 → 640 → 256)
├── ln_f: RMSNorm(256)
└── head: TiedEmbeddingHead → (2000,)
```

### Monarch Mixer

```
JuliaGPTModel
├── tok_emb: Embedding(2000 → 256)    # weight-tied with output head
├── blocks × 8:
│   ├── ln1: RMSNorm(256)
│   ├── seq_mixer: MonarchSequenceMixer
│   │   ├── conv: CausalDepthwiseConv1d(256, kernel=4)
│   │   ├── monarchs × 8: MonarchMatrix(256, L1/L2 ∈ ℝ^{16×16×16})
│   │   └── gate: LearnedGate(256)
│   ├── ln2: RMSNorm(256)
│   └── ffn: SwiGLU(256 → 640 → 256)
├── ln_f: RMSNorm(256)
└── head: TiedEmbeddingHead → (2000,)
```

## Usage

### Load and generate (Transformer)

```julia
using Pkg; Pkg.activate("julia-slm")
include("src/JuliaGPT.jl")
using .JuliaGPT
using .JuliaGPT: Lux, CUDA

tok = BPETokenizer("path/to/vocab.json", "path/to/merges.txt")
device = Lux.gpu_device()

ps, st, _, step, val_loss = load_checkpoint("5m-chinchilla/final.jld2"; device)
model = create_model(ModelConfig(;
    vocab_size=vocab_size(tok),
    embed_dim=256,
    n_layers=6,
    n_heads=4,
    head_dim=64,
    ffn_mult=4,
    context_length=256,
    weight_tying=true,
))

text = generate(model, ps, st, tok, "the nature of ";
    max_new_tokens=200, temperature=0.8, top_k=40)
println(text)
```

### Load and generate (Monarch Mixer)

```julia
ps, st, _, step, val_loss = load_checkpoint("5m-monarch/final.jld2"; device)
model = create_model(ModelConfig(;
    arch="monarch",
    vocab_size=vocab_size(tok),
    embed_dim=256,
    n_layers=8,
    n_heads=4,
    head_dim=64,
    ffn_mult=4,
    context_length=256,
    weight_tying=true,
    n_monarch_heads=8,
    conv_kernel_size=4,
))

text = generate(model, ps, st, tok, "the nature of ";
    max_new_tokens=200, temperature=0.8, top_k=40)
println(text)
```

### Train from scratch

```bash
# Transformer baseline
julia --project scripts/train.jl --config config/5m.toml

# Monarch Mixer
julia --project scripts/train.jl --config config/5m-monarch.toml
```

## Dataset

Trained on [LisaMegaWatts/philosophy-corpus](https://huggingface.co/datasets/LisaMegaWatts/philosophy-corpus) — 981 source texts (BookCorpus, WikiText-103, PG-19, classical philosophy) processed through a custom text pipeline with deduplication and quality scoring.
- **Train tokens**: 794.9M (pre-encoded as `train.bin`)
- **Val tokens**: 88.2M (pre-encoded as `val.bin`)
- **Tokenizer**: ByteLevel BPE, 2,000-token vocab

## Framework

Built with:

- [Lux.jl](https://github.com/LuxDL/Lux.jl) — explicit-parameter neural networks
- [Zygote.jl](https://github.com/FluxML/Zygote.jl) — automatic differentiation
- [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) — GPU acceleration
- [NNlib.jl](https://github.com/FluxML/NNlib.jl) — batched matrix multiply, activations
- [Optimisers.jl](https://github.com/FluxML/Optimisers.jl) — AdamW with cosine LR schedule

## Files

```
5m-chinchilla/       # Baseline transformer
├── config.toml
├── final.jld2       # Step 12,305
└── step_12000.jld2
5m-monarch/          # Monarch Mixer variant
├── config.toml
├── final.jld2       # Step 12,305
└── step_12000.jld2
```

Checkpoints are in JLD2 format and contain the model parameters (`ps`), model state (`st`), optimizer state, step number, and best validation loss.

## References

- [Monarch Mixer (Dao et al., 2023)](https://arxiv.org/abs/2310.12109) — sub-quadratic, GEMM-based architecture
- [Chinchilla (Hoffmann et al., 2022)](https://arxiv.org/abs/2203.15556) — compute-optimal training scaling

## License

MIT