Fix model card: match actual HF checkpoint (d=512, 8L, 8Q/2KV, ~23M params, ctx=256, FFN=1344)
README.md
CHANGED
@@ -22,38 +22,38 @@ datasets:

# JuliaFluxGPT

A ~23M parameter LLaMA-style decoder-only model with Grouped Query Attention (GQA), trained on classical philosophy and mathematics texts, implemented in Julia with Flux.jl.

## Model Family Context

JuliaFluxGPT is the **largest model** in the Julia SLM collection, using a different framework (Flux.jl vs Lux.jl) and a more modern attention design (GQA):

| Model | Framework | Architecture | Params | Attention |
|---|---|---|---|---|
| **JuliaFluxGPT** | **Flux.jl** | **LLaMA-style GQA** | **~23M** | **8Q/2KV GQA** |
| [SymbioGPT-10M](https://huggingface.co/LisaMegaWatts/SymbioGPT-10M) | PyTorch | 4-organelle SymbioGPT | 11.6M | OrganelleGate |
| [JuliaSLM](https://huggingface.co/LisaMegaWatts/JuliaSLM) | Lux.jl | Transformer | 5.04M | 4-head MHA |
| [MonarchSLM](https://huggingface.co/LisaMegaWatts/MonarchSLM) | Lux.jl | Monarch Mixer | 4.98M | 8-head Monarch |
| [SymbioSLM](https://huggingface.co/LisaMegaWatts/SymbioSLM) | Lux.jl | Symbiogenesis | ~4.1M | 3 organelles |
| [MicroJulia](https://huggingface.co/LisaMegaWatts/MicroJulia) | Flux.jl | GPT-2 style | ~1M | Standard MHA |

## Architecture

```
GPT (LLaMA-style)
+-- wte: Embedding(2000 -> 512)  [weight-tied with output projection]
+-- blocks x 8:
|   +-- ln1: RMSNorm(512)
|   +-- attn: CausalSelfAttention
|   |   +-- wq: Dense(512 -> 512)    [8 query heads, 64 dim each]
|   |   +-- wkv: Dense(512 -> 256)   [2 KV heads, 64 dim each, fused K+V]
|   |   +-- proj: Dense(512 -> 512)
|   +-- ln2: RMSNorm(512)
|   +-- ffwd: SwiGLUFFN
|       +-- w_gate: Dense(512 -> 1344)  [gate path]
|       +-- w_up: Dense(512 -> 1344)    [value path]
|       +-- w_down: Dense(1344 -> 512)
+-- ln_f: RMSNorm(512)
+-- [output: weight-tied with wte]
```

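The weight tying marked in the tree above (the output head reuses `wte` rather than owning a separate projection) can be illustrated with a framework-agnostic NumPy sketch; the shapes come from this card, while the random tensors and token ids are stand-ins:

```python
import numpy as np

vocab, d_model = 2000, 512                  # from the model card
wte = np.random.randn(vocab, d_model).astype(np.float32)

ids = np.array([3, 17, 42])                 # arbitrary token ids
h = wte[ids]                                # embedding lookup, (3, 512)

# Weight tying: logits reuse the same matrix transposed, instead of a
# separate (512 -> 2000) output layer, saving ~1M parameters.
logits = h @ wte.T                          # (3, 2000)
print(logits.shape)                         # (3, 2000)
```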
@@ -61,14 +61,14 @@ GPT (LLaMA-style)

GQA (Ainslie et al., 2023) uses fewer key-value heads than query heads, reducing KV-cache memory during inference while maintaining quality:

- **8 query heads** (64 dim each) = full expressiveness in queries
- **2 KV heads** (64 dim each) = 4x KV memory reduction
- **4 query heads per KV group** = each KV head is shared by 4 query heads
- KV heads are repeated (expanded) to match query head count before attention computation

**Attention parameter savings:**
- Standard MHA: Q(512x512) + K(512x512) + V(512x512) + O(512x512) = 1,048,576
- GQA 8Q/2KV: Q(512x512) + KV(512x256) + O(512x512) = 655,360 (37% reduction)
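The head-sharing scheme above can be sketched independently of Flux.jl; the NumPy below is an illustration with random stand-in tensors (and an arbitrary sequence length `T`), not the actual implementation:

```python
import numpy as np

d_model, n_q, n_kv, head_dim, T = 512, 8, 2, 64, 10  # dims from the card; T arbitrary

k = np.random.randn(n_kv, T, head_dim)

# Each KV head serves n_q // n_kv = 4 query heads: repeat (expand) the KV
# heads along the head axis so shapes match the 8 query heads.
group = n_q // n_kv                                  # 4
k_expanded = np.repeat(k, group, axis=0)             # (8, T, 64); head 0 fills slots 0..3

# Parameter arithmetic quoted above (bias-free projections):
mha = 4 * d_model * d_model                          # separate Q, K, V, O
gqa = 2 * d_model * d_model + d_model * 2 * n_kv * head_dim  # Q + O + fused KV
print(k_expanded.shape, mha, gqa)                    # (8, 10, 64) 1048576 655360
```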

### RoPE (Rotary Position Embeddings)
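A minimal NumPy sketch of rotary embeddings using the card's `base=10000` and 64-dim heads. Note the assumptions: the Julia `apply_rope(k, cos, sin, T)` also takes the sequence length, and the even/odd channel pairing here is one common convention, not necessarily the one this checkpoint uses:

```python
import numpy as np

def rope_cos_sin(T, head_dim=64, base=10000.0):
    """Precompute cos/sin tables for rotary position embeddings."""
    inv_freq = base ** (-np.arange(0, head_dim, 2) / head_dim)  # (32,)
    angles = np.outer(np.arange(T), inv_freq)                   # (T, 32)
    return np.cos(angles), np.sin(angles)

def apply_rope(x, cos, sin):
    """Rotate each (even, odd) channel pair by a position-dependent angle."""
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

cos, sin = rope_cos_sin(T=16)
k = np.random.randn(16, 64)
k_rot = apply_rope(k, cos, sin)
# Rotations preserve per-position vector norms; position 0 is unrotated.
assert np.allclose(np.linalg.norm(k_rot, axis=-1), np.linalg.norm(k, axis=-1))
```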
@@ -82,7 +82,7 @@ k_rotated = apply_rope(k, cos, sin, T)
### SwiGLU FFN

```
hidden = max(64, round_to_64(4 * 512 * 2/3)) = 1344
gate = swish(w_gate(x))
value = w_up(x)
output = w_down(gate * value)
```
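The pseudocode above can be made concrete in NumPy. The helpers are assumptions mirroring the names in the card: `round_to_64` is taken to round to the nearest multiple of 64 (flooring would also yield 1344 here), and the weight matrices are random stand-ins:

```python
import numpy as np

def round_to_64(x):
    return int(round(x / 64) * 64)   # assumed: nearest multiple of 64

d_model = 512
hidden = max(64, round_to_64(4 * d_model * 2 / 3))   # 1344, matching the card

def swish(x):
    return x / (1.0 + np.exp(-x))    # x * sigmoid(x), a.k.a. SiLU

rng = np.random.default_rng(0)
w_gate = rng.normal(0, 0.02, (d_model, hidden))
w_up   = rng.normal(0, 0.02, (d_model, hidden))
w_down = rng.normal(0, 0.02, (hidden, d_model))

def swiglu(x):
    # Gated FFN: elementwise product of the gate and value paths.
    return (swish(x @ w_gate) * (x @ w_up)) @ w_down

x = rng.normal(size=(4, d_model))
print(hidden, swiglu(x).shape)       # 1344 (4, 512)
```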
@@ -92,15 +92,15 @@ output = w_down(gate * value)

| Parameter | Value |
|---|---|
| Total parameters | ~23M (22,790,656) |
| Embedding dim | 512 |
| Layers | 8 |
| Query heads | 8 |
| KV heads | 2 (GQA ratio = 4:1) |
| Head dim | 64 |
| FFN hidden dim | 1344 |
| Context length | 256 tokens |
| Vocabulary | 2,000 (ByteLevel BPE) |
| Position encoding | RoPE (base=10000) |
| Weight tying | Yes (forward pass uses wte.weight directly) |
| Bias | false (all layers) |
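The headline total can be re-derived from the table. This sketch assumes bias-free Dense layers and scale-only RMSNorm, as the card states:

```python
# Re-deriving the ~23M total from the configuration table.
vocab, d, layers, ffn, kv_out = 2000, 512, 8, 1344, 256

emb   = vocab * d                   # wte (tied with the output head)
attn  = d * d + d * kv_out + d * d  # wq + fused wkv + proj
norms = 2 * d                       # ln1 + ln2 scale vectors
ffwd  = 2 * d * ffn + ffn * d       # w_gate + w_up + w_down
block = attn + norms + ffwd

total = emb + layers * block + d    # + final RMSNorm (ln_f)
print(total)                        # 22790656
```

The result matches the card's exact figure of 22,790,656, confirming the shape listing is internally consistent.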
@@ -112,13 +112,12 @@ output = w_down(gate * value)
|---|---|
| Dataset | [philosophy-corpus](https://huggingface.co/datasets/LisaMegaWatts/philosophy-corpus) |
| Corpus | Classical philosophy and mathematics texts |
| Tokenizer | BPE (HuggingFace tokenizer.json format, 2000 tokens) |
| Framework | Julia + Flux.jl |
| Hardware | NVIDIA RTX 3060 12GB |
| Precision | Float32 |
| Best val loss | 6.622 (step 28998) |
| Dropout | 0.1 |

## Implementation Notes

@@ -172,10 +171,10 @@ Streaming supported with `"stream": true`.

| File | Description |
|---|---|
| `best_model.jld2` | Best checkpoint (step 28998, val_loss=6.622) |
| `final_model.jld2` | Final checkpoint |
| `checkpoint_latest.jld2` | Latest training checkpoint |
| `tokenizer.json` | BPE tokenizer (HuggingFace format, 2000 tokens) |

Checkpoint contains:
- `model_state` — Flux model weights