LisaMegaWatts committed · Commit fbd85d7 · verified · 1 parent: a61fa31

Upload README.md with huggingface_hub

Files changed (1): README.md (+132 −43)
 
tags:
- julia
- lux
- transformer
- monarch-mixer
- language-model
- chinchilla
- bpe
 

# Julia SLM — Small Language Models in Pure Julia

Transformer and Monarch Mixer language models built entirely in Julia using [Lux.jl](https://github.com/LuxDL/Lux.jl), trained on the [philosophy-corpus](https://huggingface.co/datasets/LisaMegaWatts/philosophy-corpus) dataset.

## Models

### Head-to-Head Comparison

| Metric | Transformer (`5m-chinchilla/`) | Monarch Mixer (`5m-monarch/`) |
|--------|--------------------------------|-------------------------------|
| Parameters | 5,037,312 (5.04M) | 4,983,040 (4.98M) |
| Blocks | 6 | 8 |
| Sequence mixing | Softmax attention (4 heads) | Multi-head Monarch (8 heads) + causal conv |
| Channel mixing | SwiGLU (256 → 640 → 256) | SwiGLU (256 → 640 → 256) |
| Positional encoding | RoPE | None (learned via Monarch factors) |
| **Val loss** | **3.54** | **3.65** |
| **Val PPL** | **34.5** | **38.4** |
| Training time | 66 min | 89 min |
| Throughput | ~26K tok/s | ~19K tok/s |

Both models were trained identically: AdamW (lr=6e-4), cosine decay with 500-step warmup, 12,305 steps, batch size 32, on an RTX 3060 12GB.

---

### 5M Chinchilla Transformer (`5m-chinchilla/`)

A 5.04M-parameter decoder-only transformer trained to the Chinchilla-optimal budget (100M tokens at 20 tokens per parameter).
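The token budget follows directly from the 20-tokens-per-parameter heuristic:

```julia
# Chinchilla heuristic: ~20 training tokens per model parameter
params = 5_037_312
println(20 * params)   # 100746240, i.e. ~100M tokens
```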

| Param | Value |
|-------|-------|
| Parameters | 5,037,312 |
| Embedding dim | 256 |
| Layers | 6 |
| Attention heads | 4 |
| Head dim | 64 |
| FFN multiplier | 4x (SwiGLU) |
| Context length | 256 |
| Vocab size | 2,000 (BPE) |
| Weight tying | Yes |
| Normalization | RMSNorm (pre-norm) |
| Positional encoding | RoPE |

**Loss curve:**

| Step | Train Loss | Val Loss | Val PPL |
|------|-----------|----------|---------|
| 10,000 | 3.58 | 3.57 | 35.4 |
| 12,305 | 3.55 | 3.54 | 34.5 |
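The PPL column is just the exponential of the (natural-log) cross-entropy loss, which reproduces the final row:

```julia
# Perplexity = exp(cross-entropy loss), assuming natural-log loss
val_loss = 3.54
println(round(exp(val_loss); digits = 1))   # 34.5
```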

---

### 5M Monarch Mixer (`5m-monarch/`)

A 4.98M-parameter Monarch Mixer variant using sub-quadratic sequence mixing with structured matrices.

| Param | Value |
|-------|-------|
| Parameters | 4,983,040 |
| Architecture | Monarch Mixer |
| Embedding dim | 256 |
| Layers | 8 |
| Monarch heads | 8 |
| Conv kernel | 4 (causal depthwise) |
| FFN multiplier | 4x (SwiGLU) |
| Context length | 256 |
| Vocab size | 2,000 (BPE) |
| Weight tying | Yes |
| Normalization | RMSNorm (pre-norm) |
| Gating | Learned sigmoid gate |
**How Monarch Mixer works:**

A Monarch matrix of size T×T (here T = p² = 256 with p = 16) factorizes as

```
M = Pᵀ · BlockDiag(L1) · P · BlockDiag(L2)
```

where L1 and L2 are each block-diagonal with p blocks of size p×p, and P is a reshape-transpose permutation. The factorization needs only 2p³ = 2T^{3/2} parameters (8,192, versus 65,536 for a dense T×T matrix).

Each block uses 8 independent Monarch heads (each mixing 32 channels over 256 positions), combined with a causal depthwise convolution for local n-gram patterns, gated by a learned sigmoid.
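The factorization can be applied without ever materializing the dense T×T matrix. A self-contained sketch in plain Julia (`monarch_mul` is a hypothetical name for illustration, not the repo's API):

```julia
# Apply y = M*x with M = Pᵀ·BlockDiag(L1)·P·BlockDiag(L2), T = p² positions.
# Each factor is stored as p dense p×p blocks; total 2p³ parameters.
function monarch_mul(L1, L2, x, p)
    X = reshape(x, p, p)            # column j = j-th length-p segment of x
    Y = similar(X)
    for j in 1:p
        Y[:, j] = L2[:, :, j] * X[:, j]   # BlockDiag(L2)
    end
    Y = permutedims(Y)              # P: reshape-transpose permutation
    Z = similar(Y)
    for j in 1:p
        Z[:, j] = L1[:, :, j] * Y[:, j]   # BlockDiag(L1)
    end
    return vec(permutedims(Z))      # Pᵀ: permute back, flatten to length T
end

p = 16
L1 = randn(p, p, p)                 # p blocks of p×p
L2 = randn(p, p, p)
x  = randn(p^2)                     # a length-256 "sequence"

y = monarch_mul(L1, L2, x, p)
println(length(y))                  # 256
println(2 * p^3)                    # 8192 parameters, vs 256^2 = 65536 dense
```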

**Loss curve:**

| Step | Train Loss | Val Loss | Val PPL |
|------|-----------|----------|---------|
| 500 | 6.31 | 5.26 | 192.4 |
| 2,000 | 4.15 | 4.15 | 63.4 |
| 6,000 | 3.77 | 3.79 | 44.3 |
| 10,000 | 3.62 | 3.67 | 39.3 |
| 12,305 | 3.62 | 3.65 | 38.4 |

**Key findings:**

- Monarch reaches **94% of baseline quality** (3.65 vs 3.54 val loss) with O(T^{3/2}) parameter complexity in sequence mixing
- It uses **4x fewer parameters per block** for sequence mixing (67K vs 262K), which pays for 8 blocks instead of 6
- It generates coherent English text with dialogue, grammar, and narrative structure
- This is the first known Julia implementation of Monarch Mixer for language modeling
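The per-block parameter counts quoted above check out arithmetically. The Monarch-side breakdown (conv and gate sizes) is inferred from the architecture description:

```julia
d, p, heads = 256, 16, 8

attn    = 4 * d^2              # wq, wk, wv, wo projections
monarch = heads * 2 * p^3      # 8 Monarch heads, factor pair L1/L2 each
conv    = d * 4                # depthwise conv, kernel 4
gate    = d                    # learned sigmoid gate

println(attn)                  # 262144  (~262K per transformer block)
println(monarch + conv + gate) # 66816   (~67K per Monarch block, ~4x fewer)
```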
## Architecture

### Transformer

```
JuliaGPTModel
├── tok_emb: Embedding(2000 → 256)      # weight-tied with output head
├── blocks × 6:
│   ├── ln1: RMSNorm(256)
│   ├── attn: causal self-attention (4 heads)
│   │   ├── wq, wk, wv: Dense(256 → 256)
│   │   └── wo: Dense(256 → 256)
│   ├── ln2: RMSNorm(256)
│   └── ffn: SwiGLU(256 → 640 → 256)
├── ln_f: RMSNorm(256)
└── head: TiedEmbeddingHead → (2000,)
```

### Monarch Mixer

```
JuliaGPTModel
├── tok_emb: Embedding(2000 → 256)      # weight-tied with output head
├── blocks × 8:
│   ├── ln1: RMSNorm(256)
│   ├── seq_mixer: MonarchSequenceMixer
│   │   ├── conv: CausalDepthwiseConv1d(256, kernel=4)
│   │   ├── monarchs × 8: MonarchMatrix(256, L1/L2 ∈ ℝ^{16×16×16})
│   │   └── gate: LearnedGate(256)
│   ├── ln2: RMSNorm(256)
│   └── ffn: SwiGLU(256 → 640 → 256)
├── ln_f: RMSNorm(256)
└── head: TiedEmbeddingHead → (2000,)
```
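Both variants share the same SwiGLU channel mixer. A minimal plain-array sketch of the 256 → 640 → 256 block, using the `w1` (gate), `v` (value), `w2` (down-projection) naming from the trees above; this is illustrative, not the repo's Lux.jl layer:

```julia
# SwiGLU(x) = w2 * (silu(w1*x) .* (v*x))
silu(x) = x .* (1 ./ (1 .+ exp.(-x)))

struct SwiGLU
    w1::Matrix{Float64}   # gate projection,  640 × 256
    v::Matrix{Float64}    # value projection, 640 × 256
    w2::Matrix{Float64}   # down projection,  256 × 640
end

SwiGLU(d::Int, h::Int) =
    SwiGLU(randn(h, d) / sqrt(d), randn(h, d) / sqrt(d), randn(d, h) / sqrt(h))

(m::SwiGLU)(x) = m.w2 * (silu(m.w1 * x) .* (m.v * x))

ffn = SwiGLU(256, 640)
x = randn(256)
println(length(ffn(x)))   # 256
```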

## Usage

### Load and generate (Transformer)

```julia
using Pkg; Pkg.activate("julia-slm")
include("src/JuliaGPT.jl")
using .JuliaGPT
using .JuliaGPT: Lux, CUDA

# Load tokenizer and checkpoint
tok = BPETokenizer("path/to/vocab.json", "path/to/merges.txt")
device = Lux.gpu_device()   # or Lux.cpu_device()
ps, st, _, step, val_loss = load_checkpoint("5m-chinchilla/final.jld2"; device)

# Create the model (must match the checkpoint architecture)
model = create_model(ModelConfig(;
    vocab_size=vocab_size(tok), embed_dim=256, n_layers=6,
    n_heads=4, head_dim=64, ffn_mult=4, context_length=256,
    weight_tying=true,
))

# Generate text
text = generate(model, ps, st, tok, "the nature of ";
    max_new_tokens=200, temperature=0.8, top_k=40)
println(text)
```

### Load and generate (Monarch Mixer)

```julia
# Reuses `tok` and `device` from the transformer example
ps, st, _, step, val_loss = load_checkpoint("5m-monarch/final.jld2"; device)

model = create_model(ModelConfig(;
    arch="monarch",
    vocab_size=vocab_size(tok), embed_dim=256, n_layers=8,
    n_heads=4, head_dim=64, ffn_mult=4, context_length=256,
    weight_tying=true, n_monarch_heads=8, conv_kernel_size=4,
))

text = generate(model, ps, st, tok, "the nature of ";
    max_new_tokens=200, temperature=0.8, top_k=40)
println(text)
```

### Train from scratch
 
```bash
# Transformer baseline
julia --project scripts/train.jl --config config/5m.toml

# Monarch Mixer
julia --project scripts/train.jl --config config/5m-monarch.toml
```
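For orientation, a hypothetical sketch of what a config such as `config/5m.toml` could contain. All key names here are assumptions mirroring the `ModelConfig` fields and the training summary above, not the repo's actual schema:

```toml
# Hypothetical config sketch; key names are illustrative assumptions
[model]
embed_dim      = 256
n_layers       = 6
n_heads        = 4
head_dim       = 64
ffn_mult       = 4
context_length = 256
weight_tying   = true

[training]
lr           = 6e-4
min_lr       = 6e-5
weight_decay = 0.1
warmup_steps = 500
batch_size   = 32
max_steps    = 12305
```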

## Dataset

Trained on [LisaMegaWatts/philosophy-corpus](https://huggingface.co/datasets/LisaMegaWatts/philosophy-corpus) — 981 source texts (BookCorpus, WikiText-103, PG-19, classical philosophy) processed through a custom text pipeline with deduplication and quality scoring.

- **Train tokens**: 794.9M (pre-encoded as `train.bin`)
- **Val tokens**: 88.2M (pre-encoded as `val.bin`)
- **Tokenizer**: ByteLevel BPE, 2,000 vocab

## Framework

Built with:

- [Lux.jl](https://github.com/LuxDL/Lux.jl) — Explicit-parameter neural networks
- [Zygote.jl](https://github.com/FluxML/Zygote.jl) — Automatic differentiation
- [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) — GPU acceleration
- [NNlib.jl](https://github.com/FluxML/NNlib.jl) — Batched matrix multiply, activations
- [Optimisers.jl](https://github.com/FluxML/Optimisers.jl) — AdamW with cosine LR
 

## Files

```
5m-chinchilla/          # Baseline transformer
├── config.toml
├── final.jld2          # Final checkpoint (step 12,305)
└── step_12000.jld2     # Intermediate checkpoint

5m-monarch/             # Monarch Mixer variant
├── config.toml
├── final.jld2          # Final checkpoint (step 12,305)
└── step_12000.jld2     # Intermediate checkpoint
```

Checkpoints are saved in JLD2 format and contain the model parameters (`ps`), model state (`st`), optimizer state, step number, and best validation loss.
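JLD2 files can also be inspected directly with [JLD2.jl](https://github.com/JuliaIO/JLD2.jl). A hedged sketch with made-up keys; the actual key names inside `final.jld2` are an assumption, so in practice use `load_checkpoint` as shown under Usage:

```julia
using JLD2   # third-party package: pkg> add JLD2

# Write a toy checkpoint: each keyword becomes a named dataset in the file
path = tempname() * ".jld2"
jldsave(path; step = 12_305, best_val_loss = 3.54)

# Read it back as a Dict of dataset name => value
data = load(path)
println(data["step"])   # 12305
```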

## References

- [Monarch Mixer (Fu et al., 2023)](https://arxiv.org/abs/2310.12109) — Sub-quadratic GEMM-based architecture
- [Chinchilla (Hoffmann et al., 2022)](https://arxiv.org/abs/2203.15556) — Compute-optimal training scaling

## License