---
license: mit
language:
- en
tags:
- julia
- lux
- transformer
- monarch-mixer
- language-model
- chinchilla
- bpe
datasets:
- LisaMegaWatts/philosophy-corpus
pipeline_tag: text-generation
---
# Julia SLM — Small Language Models in Pure Julia
Transformer and Monarch Mixer language models built entirely in Julia using [Lux.jl](https://github.com/LuxDL/Lux.jl), trained on the [philosophy-corpus](https://huggingface.co/datasets/LisaMegaWatts/philosophy-corpus) dataset.
## Models
### Head-to-Head Comparison
| Metric | Transformer (`5m-chinchilla/`) | Monarch Mixer (`5m-monarch/`) |
|--------|------|------|
| Parameters | 5,037,312 (5.04M) | 4,983,040 (4.98M) |
| Blocks | 6 | 8 |
| Sequence mixing | Softmax attention (4 heads) | Multi-head Monarch (8 heads) + causal conv |
| Channel mixing | SwiGLU (256 → 640 → 256) | SwiGLU (256 → 640 → 256) |
| Positional encoding | RoPE | None (learned via Monarch factors) |
| **Val loss** | **3.54** | **3.65** |
| **Val PPL** | **34.5** | **38.4** |
| Training time | 66 min | 89 min |
| Throughput | ~26K tok/s | ~19K tok/s |
Both models were trained identically: AdamW (lr = 6e-4) with cosine decay, 12,305 steps at batch size 32, on an RTX 3060 (12 GB).
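The training budget above lines up with the Chinchilla-optimal ratio claimed for the transformer. A back-of-the-envelope check (assuming every batch is packed to the full context length):

```julia
# Tokens seen during training, from the settings above
steps, batch, ctx = 12_305, 32, 256
tokens_seen = steps * batch * ctx          # 100,802,560 ≈ 100.8M tokens

# Tokens per parameter for the 5.04M-parameter transformer
params = 5_037_312
ratio = tokens_seen / params               # ≈ 20.0, the Chinchilla-optimal ratio
```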
---
### 5M Chinchilla Transformer (`5m-chinchilla/`)
5.04M parameter decoder-only transformer trained to Chinchilla-optimal (100M tokens at 20 tokens/param).
| Param | Value |
|-------|-------|
| Parameters | 5,037,312 |
| Architecture | Decoder-only Transformer |
| Embedding dim | 256 |
| Layers | 6 |
| Attention heads | 4 |
| Head dim | 64 |
| FFN multiplier | 4x (SwiGLU) |
| Context length | 256 |
| Vocab size | 2,000 (BPE) |
| Weight tying | Yes |
| Normalization | RMSNorm (pre-norm) |
| Positional encoding | RoPE |
**Loss curve:**
| Step | Train Loss | Val Loss | Val PPL |
|------|-----------|----------|---------|
| 500 | 6.69 | 5.01 | 149.6 |
| 2,000 | 4.09 | 4.02 | 56.0 |
| 6,000 | 3.72 | 3.70 | 40.4 |
| 10,000 | 3.58 | 3.57 | 35.4 |
| 12,305 | 3.55 | 3.54 | 34.5 |
---
### 5M Monarch Mixer (`5m-monarch/`)
4.98M parameter Monarch Mixer variant using sub-quadratic sequence mixing with structured matrices.
| Param | Value |
|-------|-------|
| Parameters | 4,983,040 |
| Architecture | Monarch Mixer |
| Embedding dim | 256 |
| Layers | 8 |
| Monarch heads | 8 |
| Conv kernel | 4 (causal depthwise) |
| FFN multiplier | 4x (SwiGLU) |
| Context length | 256 |
| Vocab size | 2,000 (BPE) |
| Weight tying | Yes |
| Normalization | RMSNorm (pre-norm) |
| Gating | Learned sigmoid gate |
**How Monarch Mixer works:**
A Monarch matrix of size T×T (T = p² = 256, p = 16) factorizes as:
```
M = Pᵀ · BlockDiag(L1) · P · BlockDiag(L2)
```
where L1 and L2 are block-diagonal matrices, each consisting of p dense p×p blocks, and P is a reshape-transpose (stride) permutation. Parameter count: 2p³ = 2T^{3/2} (8,192 vs 65,536 for a dense T×T matrix).
Each block uses 8 independent Monarch heads (each mixing 32 channels over 256 positions) combined with a causal depthwise convolution for local n-gram patterns, gated by a learned sigmoid.
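The factorization above can be sketched in a few lines of plain Julia. This is a minimal illustration, not the repo's implementation (which batches heads and runs on GPU); `p = 4` is used here for readability, whereas the model uses p = 16 (T = 256):

```julia
# Apply M = Pᵀ · BlockDiag(L1) · P · BlockDiag(L2) to a length-T vector,
# where T = p² and L1, L2 each hold p dense p×p blocks.
p = 4
T = p^2
L1 = randn(p, p, p)   # p blocks of size p×p
L2 = randn(p, p, p)

function monarch_apply(x, L1, L2, p)
    y = reshape(x, p, p)                                # columns = contiguous chunks of x
    y = hcat((L2[:, :, i] * y[:, i] for i in 1:p)...)   # BlockDiag(L2)
    y = permutedims(y)                                  # P: reshape-transpose permutation
    y = hcat((L1[:, :, i] * y[:, i] for i in 1:p)...)   # BlockDiag(L1)
    return vec(permutedims(y))                          # Pᵀ: undo the permutation
end

y = monarch_apply(randn(T), L1, L2, p)
# The factorization uses 2p³ = 128 weights here, vs T² = 256 for a dense matrix.
```

The permutation is what lets every output position depend on every input position despite only ever multiplying by small p×p blocks.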
**Loss curve:**
| Step | Train Loss | Val Loss | Val PPL |
|------|-----------|----------|---------|
| 500 | 6.31 | 5.26 | 192.4 |
| 2,000 | 4.15 | 4.15 | 63.4 |
| 6,000 | 3.77 | 3.79 | 44.3 |
| 10,000 | 3.62 | 3.67 | 39.3 |
| 12,305 | 3.62 | 3.65 | 38.4 |
**Key findings:**
- Monarch reaches near-baseline quality (3.65 vs 3.54 val loss, ~11% higher perplexity) with O(T^{3/2}) parameter complexity in sequence mixing
- Uses **4x fewer parameters per block** in sequence mixing (67K vs 262K), enabling 8 blocks instead of 6
- Generates coherent English text with dialogue, grammar, and narrative structure
- First known Julia implementation of Monarch Mixer for language modeling
## Architecture
### Transformer
```
JuliaGPTModel
├── tok_emb: Embedding(2000 → 256)   # weight-tied with output head
├── rope: RotaryPositionalEncoding(256)
├── blocks × 6:
│   ├── ln1: RMSNorm(256)
│   ├── attn: MultiHeadAttention(4 heads, 64 dim each)
│   │   ├── wq, wk, wv: Dense(256 → 256)
│   │   └── wo: Dense(256 → 256)
│   ├── ln2: RMSNorm(256)
│   └── ffn: SwiGLU(256 → 640 → 256)
├── ln_f: RMSNorm(256)
└── head: TiedEmbeddingHead → (2000,)
```
### Monarch Mixer
```
JuliaGPTModel
├── tok_emb: Embedding(2000 → 256)   # weight-tied with output head
├── blocks × 8:
│   ├── ln1: RMSNorm(256)
│   ├── seq_mixer: MonarchSequenceMixer
│   │   ├── conv: CausalDepthwiseConv1d(256, kernel=4)
│   │   ├── monarchs × 8: MonarchMatrix(256, L1/L2 ∈ ℝ^{16×16×16})
│   │   └── gate: LearnedGate(256)
│   ├── ln2: RMSNorm(256)
│   └── ffn: SwiGLU(256 → 640 → 256)
├── ln_f: RMSNorm(256)
└── head: TiedEmbeddingHead → (2000,)
```
## Usage
### Load and generate (Transformer)
```julia
using Pkg; Pkg.activate("julia-slm")
include("src/JuliaGPT.jl")
using .JuliaGPT
using .JuliaGPT: Lux, CUDA
tok = BPETokenizer("path/to/vocab.json", "path/to/merges.txt")
device = Lux.gpu_device()
ps, st, _, step, val_loss = load_checkpoint("5m-chinchilla/final.jld2"; device)
model = create_model(ModelConfig(;
vocab_size=vocab_size(tok), embed_dim=256, n_layers=6,
n_heads=4, head_dim=64, ffn_mult=4, context_length=256,
weight_tying=true,
))
text = generate(model, ps, st, tok, "the nature of ";
max_new_tokens=200, temperature=0.8, top_k=40)
println(text)
```
### Load and generate (Monarch Mixer)
```julia
ps, st, _, step, val_loss = load_checkpoint("5m-monarch/final.jld2"; device)
model = create_model(ModelConfig(;
arch="monarch",
vocab_size=vocab_size(tok), embed_dim=256, n_layers=8,
n_heads=4, head_dim=64, ffn_mult=4, context_length=256,
weight_tying=true, n_monarch_heads=8, conv_kernel_size=4,
))
text = generate(model, ps, st, tok, "the nature of ";
max_new_tokens=200, temperature=0.8, top_k=40)
println(text)
```
### Train from scratch
```bash
# Transformer baseline
julia --project scripts/train.jl --config config/5m.toml
# Monarch Mixer
julia --project scripts/train.jl --config config/5m-monarch.toml
```
## Dataset
Trained on [LisaMegaWatts/philosophy-corpus](https://huggingface.co/datasets/LisaMegaWatts/philosophy-corpus) — 981 source texts (BookCorpus, WikiText-103, PG-19, classical philosophy) processed through a custom text pipeline with deduplication and quality scoring.
- **Train tokens**: 794.9M (pre-encoded as `train.bin`)
- **Val tokens**: 88.2M (pre-encoded as `val.bin`)
- **Tokenizer**: ByteLevel BPE, 2,000 vocab
## Framework
Built with:
- [Lux.jl](https://github.com/LuxDL/Lux.jl) — explicit-parameter neural networks
- [Zygote.jl](https://github.com/FluxML/Zygote.jl) — automatic differentiation
- [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) — GPU acceleration
- [NNlib.jl](https://github.com/FluxML/NNlib.jl) — batched matrix multiply, activations
- [Optimisers.jl](https://github.com/FluxML/Optimisers.jl) — AdamW with cosine LR
## Files
```
5m-chinchilla/ # Baseline transformer
β”œβ”€β”€ config.toml
β”œβ”€β”€ final.jld2 # Step 12,305
└── step_12000.jld2
5m-monarch/ # Monarch Mixer variant
β”œβ”€β”€ config.toml
β”œβ”€β”€ final.jld2 # Step 12,305
└── step_12000.jld2
```
Checkpoints are JLD2 format containing: model parameters (`ps`), model state (`st`), optimizer state, step number, and best validation loss.
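For inspection without the repo's `load_checkpoint` helper, a checkpoint can be opened directly with JLD2. The key names inside the file are an assumption here; `keys(f)` reveals the actual ones:

```julia
# Sketch: peek inside a checkpoint file with JLD2 directly.
using JLD2

jldopen("5m-chinchilla/final.jld2", "r") do f
    @show keys(f)   # lists the stored entries (parameters, state, optimizer, step, ...)
end
```

Prefer `load_checkpoint` for actual use, since it also moves parameters to the right device.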
## References
- [Monarch Mixer (Dao et al., 2023)](https://arxiv.org/abs/2310.12109) — sub-quadratic GEMM-based architecture
- [Chinchilla (Hoffmann et al., 2022)](https://arxiv.org/abs/2203.15556) — compute-optimal training scaling
## License
MIT