---
license: mit
language:
- en
tags:
- julia
- lux
- transformer
- monarch-mixer
- language-model
- chinchilla
- bpe
datasets:
- LisaMegaWatts/philosophy-corpus
pipeline_tag: text-generation
---

# Julia SLM: Small Language Models in Pure Julia

Transformer and Monarch Mixer language models built entirely in Julia using [Lux.jl](https://github.com/LuxDL/Lux.jl), trained on the [philosophy-corpus](https://huggingface.co/datasets/LisaMegaWatts/philosophy-corpus) dataset.

## Models

### Head-to-Head Comparison

| Metric | Transformer (`5m-chinchilla/`) | Monarch Mixer (`5m-monarch/`) |
|--------|--------------------------------|-------------------------------|
| Parameters | 5,037,312 (5.04M) | 4,983,040 (4.98M) |
| Blocks | 6 | 8 |
| Sequence mixing | Softmax attention (4 heads) | Multi-head Monarch (8 heads) + causal conv |
| Channel mixing | SwiGLU (256 → 640 → 256) | SwiGLU (256 → 640 → 256) |
| Positional encoding | RoPE | None (learned via Monarch factors) |
| **Val loss** | **3.54** | **3.65** |
| **Val PPL** | **34.5** | **38.4** |
| Training time | 66 min | 89 min |
| Throughput | ~26K tok/s | ~19K tok/s |

Both models were trained identically: AdamW (lr = 6e-4), cosine decay, 12,305 steps, batch size 32, on an RTX 3060 12GB.
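
The cosine-decay schedule used for both runs can be sketched in a few lines of Julia. This is a minimal illustration of the technique only; the training script's exact schedule (any warmup, minimum-LR floor, or step offsets) is an assumption, and `cosine_lr` is an invented name:

```julia
# Cosine learning-rate decay from lr_max at step 0 down to lr_min at total_steps.
# A sketch of the standard schedule; not the repository's actual implementation.
function cosine_lr(step, total_steps; lr_max=6e-4, lr_min=0.0)
    t = clamp(step / total_steps, 0.0, 1.0)   # fraction of training completed
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * t))
end

cosine_lr(0, 12_305)        # 6e-4 at the start
cosine_lr(12_305, 12_305)   # decays to lr_min at the final step
```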

---

### 5M Chinchilla Transformer (`5m-chinchilla/`)

A 5.04M-parameter decoder-only transformer trained to the Chinchilla-optimal budget of ~100M tokens (20 tokens per parameter).

| Param | Value |
|-------|-------|
| Parameters | 5,037,312 |
| Architecture | Decoder-only Transformer |
| Embedding dim | 256 |
| Layers | 6 |
| Attention heads | 4 |
| Head dim | 64 |
| FFN multiplier | 4x (SwiGLU) |
| Context length | 256 |
| Vocab size | 2,000 (BPE) |
| Weight tying | Yes |
| Normalization | RMSNorm (pre-norm) |
| Positional encoding | RoPE |
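
Both models normalize with pre-norm RMSNorm, which is small enough to sketch in plain Julia. This is illustrative only; the repository's Lux-based layer has its own interface, and `rmsnorm` is an invented name:

```julia
# RMSNorm: divide a feature vector by its root-mean-square, then apply a
# learned per-channel gain. Unlike LayerNorm, no mean subtraction and no bias.
function rmsnorm(x::AbstractVector, gain::AbstractVector; eps=1e-6)
    rms = sqrt(sum(abs2, x) / length(x) + eps)
    return gain .* (x ./ rms)
end

x = [3.0, -4.0]        # RMS = sqrt((9 + 16)/2) ≈ 3.5355
rmsnorm(x, ones(2))    # ≈ [0.8485, -1.1314]
```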

**Loss curve:**

| Step | Train Loss | Val Loss | Val PPL |
|------|-----------|----------|---------|
| 500 | 6.69 | 5.01 | 149.6 |
| 2,000 | 4.09 | 4.02 | 56.0 |
| 6,000 | 3.72 | 3.70 | 40.4 |
| 10,000 | 3.58 | 3.57 | 35.4 |
| 12,305 | 3.55 | 3.54 | 34.5 |
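
The Val PPL column is the exponential of the validation cross-entropy loss (in nats), up to rounding of the reported losses:

```julia
# Perplexity from cross-entropy loss in nats.
perplexity(loss) = exp(loss)

perplexity(3.54)   # ≈ 34.5, matching the final checkpoint above
```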

---

### 5M Monarch Mixer (`5m-monarch/`)

A 4.98M-parameter Monarch Mixer variant using sub-quadratic sequence mixing built from structured matrices.

| Param | Value |
|-------|-------|
| Parameters | 4,983,040 |
| Architecture | Monarch Mixer |
| Embedding dim | 256 |
| Layers | 8 |
| Monarch heads | 8 |
| Conv kernel | 4 (causal depthwise) |
| FFN multiplier | 4x (SwiGLU) |
| Context length | 256 |
| Vocab size | 2,000 (BPE) |
| Weight tying | Yes |
| Normalization | RMSNorm (pre-norm) |
| Gating | Learned sigmoid gate |

**How Monarch Mixer works:**

A Monarch matrix of size T×T (T = p² = 256, p = 16) factorizes as:
```
M = Pᵀ · BlockDiag(L1) · P · BlockDiag(L2)
```
where BlockDiag(L1) and BlockDiag(L2) are block-diagonal matrices, each built from p blocks of size p×p, and P is a reshape-transpose permutation. Parameter count: 2p³ = 2T^{3/2} (8,192 vs. 65,536 for a dense T×T matrix).
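
The factorization can be exercised directly as a matrix-vector product. Below is a minimal pure-Julia sketch at a toy size; `monarch_mv` is an invented name, and the repository's batched, GPU-friendly implementation is assumed to differ:

```julia
# Monarch matrix-vector product: M*x with M = Pᵀ·BlockDiag(L1)·P·BlockDiag(L2).
# P is the reshape-transpose (stride) permutation of a length-p² vector;
# for a square p×p reshape, P is its own inverse, so Pᵀ is the same operation.
function monarch_mv(L1, L2, x, p)
    # BlockDiag(L2) · x : each p×p block acts on its own slice of x
    y = reduce(vcat, [L2[i] * x[(i-1)*p+1 : i*p] for i in 1:p])
    # P · y : reshape to p×p, transpose, flatten
    y = vec(permutedims(reshape(y, p, p)))
    # BlockDiag(L1) · y
    y = reduce(vcat, [L1[i] * y[(i-1)*p+1 : i*p] for i in 1:p])
    # Pᵀ · y : undo the permutation
    return vec(permutedims(reshape(y, p, p)))
end

p  = 4                              # toy size: T = p² = 16
L1 = [randn(p, p) for _ in 1:p]     # p blocks of p×p → p³ parameters
L2 = [randn(p, p) for _ in 1:p]     # total 2p³ = 128 vs. 256 for dense
x  = randn(p^2)
monarch_mv(L1, L2, x, p)            # length-16 mixed output
```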

Each block uses 8 independent Monarch heads (each mixing 32 channels over 256 positions) combined with a causal depthwise convolution for local n-gram patterns, gated by a learned sigmoid.
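
The causal depthwise convolution filters every channel independently and left-pads the sequence so that position t never sees future positions. A minimal sketch of the technique; the function name and channels × time array layout are assumptions, not the repository's API:

```julia
# Causal depthwise 1-D convolution: one length-k kernel per channel, with
# zero left-padding so output t depends only on inputs t-k+1 through t.
function causal_depthwise_conv(x::AbstractMatrix, w::AbstractMatrix)
    C, T = size(x)                               # channels × time
    _, k = size(w)                               # C per-channel kernels of length k
    xpad = hcat(zeros(eltype(x), C, k - 1), x)   # pad k-1 zeros on the left
    y = similar(x)
    for t in 1:T, c in 1:C
        # w[c, end] taps the current timestep; earlier columns tap the past
        y[c, t] = sum(w[c, j] * xpad[c, t + j - 1] for j in 1:k)
    end
    return y
end

x = reshape(collect(1.0:6.0), 1, 6)   # one channel, T = 6
w = reshape([0.5, 0.5], 1, 2)         # kernel 2: average current and previous
causal_depthwise_conv(x, w)           # [0.5 1.5 2.5 3.5 4.5 5.5]
```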

**Loss curve:**

| Step | Train Loss | Val Loss | Val PPL |
|------|-----------|----------|---------|
| 500 | 6.31 | 5.26 | 192.4 |
| 2,000 | 4.15 | 4.15 | 63.4 |
| 6,000 | 3.77 | 3.79 | 44.3 |
| 10,000 | 3.62 | 3.67 | 39.3 |
| 12,305 | 3.62 | 3.65 | 38.4 |

**Key findings:**
- Monarch lands within ~3% of the baseline val loss (3.65 vs. 3.54) while using O(T^{3/2}) parameters for sequence mixing
- Uses **4x fewer parameters per block** for sequence mixing (67K vs. 262K), which is what allows 8 blocks instead of 6 at equal total size
- Generates coherent English text with dialogue, grammar, and narrative structure
- First known Julia implementation of Monarch Mixer for language modeling

## Architecture

### Transformer
```
JuliaGPTModel
├── tok_emb: Embedding(2000 → 256)   # weight-tied with output head
├── rope: RotaryPositionalEncoding(256)
├── blocks × 6:
│   ├── ln1: RMSNorm(256)
│   ├── attn: MultiHeadAttention(4 heads, 64 dim each)
│   │   ├── wq, wk, wv: Dense(256 → 256)
│   │   └── wo: Dense(256 → 256)
│   ├── ln2: RMSNorm(256)
│   └── ffn: SwiGLU(256 → 640 → 256)
├── ln_f: RMSNorm(256)
└── head: TiedEmbeddingHead → (2000,)
```

### Monarch Mixer
```
JuliaGPTModel
├── tok_emb: Embedding(2000 → 256)   # weight-tied with output head
├── blocks × 8:
│   ├── ln1: RMSNorm(256)
│   ├── seq_mixer: MonarchSequenceMixer
│   │   ├── conv: CausalDepthwiseConv1d(256, kernel=4)
│   │   ├── monarchs × 8: MonarchMatrix(256, L1/L2 ∈ ℝ^{16×16×16})
│   │   └── gate: LearnedGate(256)
│   ├── ln2: RMSNorm(256)
│   └── ffn: SwiGLU(256 → 640 → 256)
├── ln_f: RMSNorm(256)
└── head: TiedEmbeddingHead → (2000,)
```

## Usage

### Load and generate (Transformer)

```julia
using Pkg; Pkg.activate("julia-slm")
include("src/JuliaGPT.jl")
using .JuliaGPT
using .JuliaGPT: Lux, CUDA

tok = BPETokenizer("path/to/vocab.json", "path/to/merges.txt")
device = Lux.gpu_device()
ps, st, _, step, val_loss = load_checkpoint("5m-chinchilla/final.jld2"; device)

model = create_model(ModelConfig(;
    vocab_size=vocab_size(tok), embed_dim=256, n_layers=6,
    n_heads=4, head_dim=64, ffn_mult=4, context_length=256,
    weight_tying=true,
))

text = generate(model, ps, st, tok, "the nature of ";
    max_new_tokens=200, temperature=0.8, top_k=40)
println(text)
```
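
The `temperature` and `top_k` keywords above control the decoding step. A pure-Julia sketch of what top-k sampling with temperature does per token (the repository's actual sampler is an assumption; `sample_top_k` is an invented name):

```julia
# Keep only the top_k largest logits, rescale by temperature, softmax over the
# survivors, then draw one token index by inverse-CDF sampling.
function sample_top_k(logits::Vector{Float64}; temperature=0.8, top_k=40)
    order  = sortperm(logits; rev=true)[1:min(top_k, length(logits))]
    scaled = logits[order] ./ temperature
    probs  = exp.(scaled .- maximum(scaled))   # stable softmax numerator
    probs ./= sum(probs)
    r, acc = rand(), 0.0
    for (i, pr) in pairs(probs)
        acc += pr
        acc >= r && return order[i]
    end
    return order[end]
end

logits = [2.0, 0.1, -1.0, 1.5]
sample_top_k(logits; top_k=2)   # returns 1 or 4: only the two largest survive
```

Lower temperatures sharpen the distribution over the retained tokens; `top_k` caps how far down the ranking a sample can reach.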

### Load and generate (Monarch Mixer)

```julia
ps, st, _, step, val_loss = load_checkpoint("5m-monarch/final.jld2"; device)

model = create_model(ModelConfig(;
    arch="monarch",
    vocab_size=vocab_size(tok), embed_dim=256, n_layers=8,
    n_heads=4, head_dim=64, ffn_mult=4, context_length=256,
    weight_tying=true, n_monarch_heads=8, conv_kernel_size=4,
))

text = generate(model, ps, st, tok, "the nature of ";
    max_new_tokens=200, temperature=0.8, top_k=40)
println(text)
```

### Train from scratch

```bash
# Transformer baseline
julia --project scripts/train.jl --config config/5m.toml

# Monarch Mixer
julia --project scripts/train.jl --config config/5m-monarch.toml
```

## Dataset

Trained on [LisaMegaWatts/philosophy-corpus](https://huggingface.co/datasets/LisaMegaWatts/philosophy-corpus): 981 source texts (BookCorpus, WikiText-103, PG-19, classical philosophy) processed through a custom text pipeline with deduplication and quality scoring.

- **Train tokens**: 794.9M (pre-encoded as `train.bin`)
- **Val tokens**: 88.2M (pre-encoded as `val.bin`)
- **Tokenizer**: Byte-level BPE, 2,000 vocab
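
A BPE tokenizer encodes text by starting from single characters (bytes, in the byte-level case) and applying its learned merge rules in order. A toy sketch of that merge loop; the merge pairs below are invented for the example, and the real 2,000-entry tokenizer in `vocab.json`/`merges.txt` is far larger:

```julia
# Apply BPE merges in learned order: repeatedly fuse adjacent token pairs
# that match the current merge rule, then move to the next rule.
function bpe_encode(word::String, merges::Vector{Tuple{String,String}})
    tokens = string.(collect(word))   # start from individual characters
    for (a, b) in merges
        i = 1
        while i < length(tokens)
            if tokens[i] == a && tokens[i+1] == b
                tokens[i] = a * b     # fuse the pair in place
                deleteat!(tokens, i + 1)
            else
                i += 1
            end
        end
    end
    return tokens
end

merges = [("t", "h"), ("th", "e")]    # hypothetical merge table
bpe_encode("the", merges)             # ["the"]
```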

## Framework

Built with:
- [Lux.jl](https://github.com/LuxDL/Lux.jl): explicit-parameter neural networks
- [Zygote.jl](https://github.com/FluxML/Zygote.jl): automatic differentiation
- [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl): GPU acceleration
- [NNlib.jl](https://github.com/FluxML/NNlib.jl): batched matrix multiply, activations
- [Optimisers.jl](https://github.com/FluxML/Optimisers.jl): AdamW with cosine LR
|
| | ## Files |
| |
|
| | ``` |
| | 5m-chinchilla/ # Baseline transformer |
| | βββ config.toml |
| | βββ final.jld2 # Step 12,305 |
| | βββ step_12000.jld2 |
| | |
| | 5m-monarch/ # Monarch Mixer variant |
| | βββ config.toml |
| | βββ final.jld2 # Step 12,305 |
| | βββ step_12000.jld2 |
| | ``` |
| |
|
| | Checkpoints are JLD2 format containing: model parameters (`ps`), model state (`st`), optimizer state, step number, and best validation loss. |

## References

- [Monarch Mixer (Fu et al., 2023)](https://arxiv.org/abs/2310.12109): sub-quadratic GEMM-based architecture
- [Chinchilla (Hoffmann et al., 2022)](https://arxiv.org/abs/2203.15556): compute-optimal training scaling

## License

MIT