JuliaFluxGPT
A ~23M parameter LLaMA-style decoder-only model with Grouped Query Attention (GQA), trained on classical philosophy and mathematics texts, implemented in Julia with Flux.jl.
Model Family Context
JuliaFluxGPT is the largest model in the Julia SLM collection. Unlike the Lux.jl-based models in the family, it is built on Flux.jl and uses a more recent attention design (GQA):
| Model | Framework | Architecture | Params | Attention |
|---|---|---|---|---|
| JuliaFluxGPT | Flux.jl | LLaMA-style GQA | ~23M | 8Q/2KV GQA |
| SymbioGPT-10M | PyTorch | 4-organelle SymbioGPT | 11.6M | OrganelleGate |
| JuliaSLM | Lux.jl | Transformer | 5.04M | 4-head MHA |
| MonarchSLM | Lux.jl | Monarch Mixer | 4.98M | 8-head Monarch |
| SymbioSLM | Lux.jl | Symbiogenesis | ~4.1M | 3 organelles |
| MicroJulia | Flux.jl | GPT-2 style | ~1M | Standard MHA |
Architecture
GPT (LLaMA-style)
+-- wte: Embedding(2000 -> 512)          [weight-tied with output projection]
+-- blocks x 8:
|   +-- ln1: RMSNorm(512)
|   +-- attn: CausalSelfAttention
|   |   +-- wq: Dense(512 -> 512)        [8 query heads, 64 dim each]
|   |   +-- wkv: Dense(512 -> 256)       [2 KV heads, 64 dim each, fused K+V]
|   |   +-- proj: Dense(512 -> 512)
|   +-- ln2: RMSNorm(512)
|   +-- ffwd: SwiGLUFFN
|       +-- w_gate: Dense(512 -> 1344)   [gate path]
|       +-- w_up: Dense(512 -> 1344)     [value path]
|       +-- w_down: Dense(1344 -> 512)
+-- ln_f: RMSNorm(512)
+-- [output: weight-tied with wte]
Grouped Query Attention (GQA)
GQA (Ainslie et al., 2023) uses fewer key-value heads than query heads, which shrinks the KV cache at inference time while keeping quality close to that of full multi-head attention:
- 8 query heads (64 dim each): queries keep the same per-head resolution as 8-head MHA
- 2 KV heads (64 dim each): the KV cache is 4x smaller than with 8 KV heads
- 4 query heads per KV head: each KV head is shared by a group of 4 query heads
- KV heads are repeated (expanded) to match query head count before attention computation
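The expansion step in the last bullet can be sketched in plain Julia. This is a minimal illustration and not the model's actual code: the helper name `repeat_kv` and the `(head_dim, n_heads, T)` array layout are assumptions made for the example.

```julia
# Hypothetical helper: expand KV heads so every query head has a matching K/V head.
# Assumed layout: (head_dim, n_heads, T).
function repeat_kv(kv::AbstractArray{<:Real,3}, n_rep::Int)
    # `inner` repeats each head n_rep times in place, so groups stay contiguous:
    # heads [1, 2] become [1, 1, 1, 1, 2, 2, 2, 2] for n_rep = 4.
    return repeat(kv; inner=(1, n_rep, 1))
end

head_dim, n_head, n_kv_head, T = 64, 8, 2, 5
k = randn(Float32, head_dim, n_kv_head, T)
k_expanded = repeat_kv(k, div(n_head, n_kv_head))

println(size(k_expanded))  # (64, 8, 5)
```

With this grouping, query heads 1 to 4 attend against KV head 1 and query heads 5 to 8 against KV head 2. Only the 2 original KV heads would need to be cached; the expanded copy can be rebuilt on the fly.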
Attention parameter savings:
- Standard MHA: Q(512x512) + K(512x512) + V(512x512) + O(512x512) = 1,048,576
- GQA 8Q/2KV: Q(512x512) + KV(512x256) + O(512x512) = 655,360 (37.5% reduction)
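As a quick check, the arithmetic behind the two totals can be reproduced in a few lines of Julia. The variable names are chosen for this example; the only assumption is that all projections are bias-free, as stated in the model details.

```julia
n_embd, n_head, n_kv_head = 512, 8, 2
head_dim = div(n_embd, n_head)                 # 64

# Bias-free projections, so each one contributes in_dim * out_dim parameters.
mha_params = 4 * n_embd * n_embd               # Q, K, V, O
kv_dim     = 2 * n_kv_head * head_dim          # fused K+V output width: 256
gqa_params = n_embd * n_embd + n_embd * kv_dim + n_embd * n_embd  # Q, KV, O

reduction = 1 - gqa_params / mha_params
println((mha_params, gqa_params, reduction))   # (1048576, 655360, 0.375)
```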
RoPE (Rotary Position Embeddings)
RoPE is applied to Q and K after projection and before the attention scores are computed:
cos_cache, sin_cache = precompute_rope_freqs(head_dim=64, max_seq_len=256)
q_rotated = apply_rope(q, cos_cache, sin_cache, T)
k_rotated = apply_rope(k, cos_cache, sin_cache, T)
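For readers unfamiliar with RoPE, the following is a minimal plain-Julia sketch of what two helpers with these names might do. It is not the model's implementation. It assumes a single head stored as a `(head_dim, T)` matrix and the "split halves" pairing, where dimension `i` is rotated together with dimension `i + head_dim/2`. The actual code may use a different layout or the interleaved pairing.

```julia
# Hypothetical stand-ins for the helpers named above.
function precompute_rope_freqs(; head_dim::Int=64, max_seq_len::Int=256, base::Real=10000)
    half = div(head_dim, 2)
    # One rotation frequency per dimension pair.
    inv_freq = [1 / Float64(base)^(2 * (i - 1) / head_dim) for i in 1:half]
    # angles[i, t] = inv_freq[i] * position, with positions starting at 0.
    angles = inv_freq .* transpose(0:max_seq_len-1)   # (half, max_seq_len)
    return cos.(angles), sin.(angles)
end

function apply_rope(x::AbstractMatrix, cos_cache, sin_cache, T::Int)
    half = div(size(x, 1), 2)
    c = @view cos_cache[:, 1:T]
    s = @view sin_cache[:, 1:T]
    x1 = @view x[1:half, 1:T]
    x2 = @view x[half+1:end, 1:T]
    # Rotate each (x1, x2) pair by its position-dependent angle.
    return vcat(x1 .* c .- x2 .* s, x1 .* s .+ x2 .* c)
end

cos_cache, sin_cache = precompute_rope_freqs(head_dim=64, max_seq_len=256)
T = 10
q = randn(64, T)
q_rotated = apply_rope(q, cos_cache, sin_cache, T)
```

Because each pair is only rotated, the norm of every position's vector is preserved, and the first position (angle 0) is left unchanged.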
SwiGLU FFN
hidden = max(64, round_to_64(4 * 512 * 2/3))  # = 1344
gate = swish.(w_gate(x))
value = w_up(x)
output = w_down(gate .* value)
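The same computation can be written as a self-contained sketch on plain matrices. It assumes that `round_to_64` rounds to the nearest multiple of 64, which is the reading consistent with 1344, since 4 * 512 * 2/3 ≈ 1365.3. The function names `swiglu_hidden` and `swiglu` are made up for this example.

```julia
# Assumed meaning of round_to_64: nearest multiple of 64.
round_to_64(x) = 64 * round(Int, x / 64)
swiglu_hidden(n_embd) = max(64, round_to_64(4 * n_embd * 2 / 3))

swish(x) = x / (1 + exp(-x))   # x * sigmoid(x)

# Bias-free SwiGLU; weights are (out_dim, in_dim), x is (n_embd, T).
function swiglu(W_gate, W_up, W_down, x)
    gate  = swish.(W_gate * x)       # gate path
    value = W_up * x                 # value path
    return W_down * (gate .* value)  # elementwise gating, then project back
end

n_embd = 512
hidden = swiglu_hidden(n_embd)
W_gate = 0.02f0 .* randn(Float32, hidden, n_embd)
W_up   = 0.02f0 .* randn(Float32, hidden, n_embd)
W_down = 0.02f0 .* randn(Float32, n_embd, hidden)
x = randn(Float32, n_embd, 4)
y = swiglu(W_gate, W_up, W_down, x)

println(hidden)    # 1344
println(size(y))   # (512, 4)
```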
Model Details
| Parameter | Value |
|---|---|
| Total parameters | ~23M (22,790,656) |
| Embedding dim | 512 |
| Layers | 8 |
| Query heads | 8 |
| KV heads | 2 (GQA ratio = 4:1) |
| Head dim | 64 |
| FFN hidden dim | 1344 |
| Context length | 256 tokens |
| Vocabulary | 2,000 (ByteLevel BPE) |
| Position encoding | RoPE (base=10000) |
| Weight tying | Yes (forward pass uses wte.weight directly) |
| Bias | false (all layers) |
| Dropout | 0.1 (training), 0.0 (inference) |
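The total parameter count can be reproduced from the values in this table. The sketch below assumes that each RMSNorm holds a single scale vector of length 512 and that the tied output projection adds no parameters of its own. Under those assumptions the sum matches the reported figure exactly.

```julia
vocab_size, n_embd, n_layer = 2000, 512, 8
n_head, n_kv_head, ffn_hidden = 8, 2, 1344
head_dim = div(n_embd, n_head)

embedding = vocab_size * n_embd                    # shared with the output projection
attn      = n_embd * n_embd +                      # wq
            n_embd * (2 * n_kv_head * head_dim) +  # wkv (fused K+V)
            n_embd * n_embd                        # proj
ffn       = 3 * n_embd * ffn_hidden                # w_gate, w_up, w_down
norms     = 2 * n_embd                             # ln1, ln2 (assumed: one scale vector each)
per_block = attn + ffn + norms

total = embedding + n_layer * per_block + n_embd   # final term is ln_f
println(total)  # 22790656
```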
Training
| Setting | Value |
|---|---|
| Dataset | philosophy-corpus |
| Corpus | Classical philosophy and mathematics texts |
| Tokenizer | BPE (HuggingFace tokenizer.json format, 2000 tokens) |
| Framework | Julia + Flux.jl |
| Hardware | NVIDIA RTX 3060 12GB |
| Precision | Float32 |
| Best val loss | 6.622 (step 28998) |
| Dropout | 0.1 |
Implementation Notes
Flux.jl vs Lux.jl
JuliaFluxGPT uses Flux.jl (implicit parameters, @layer macro) rather than Lux.jl (explicit parameters). Key differences:
| Aspect | Flux.jl (this model) | Lux.jl (JuliaSLM family) |
|---|---|---|
| Parameter style | Implicit (stored in model struct) | Explicit (separate ps NamedTuple) |
| State management | Flux.testmode!() | Explicit state st |
| Serialization | Flux.loadmodel!() | JLD2 direct load |
| AD backend | Zygote | Zygote |
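The difference in parameter style can be illustrated without either library. The sketch below uses plain Julia structs. All names (`TinyDense`, `TinyDenseSpec`, `init_params`, `apply_layer`) are hypothetical and only mimic the two patterns; they are not Flux.jl or Lux.jl APIs.

```julia
# Pattern 1 (Flux-like): the layer struct owns its weights and is callable.
struct TinyDense
    W::Matrix{Float32}
end
(layer::TinyDense)(x) = layer.W * x

# Pattern 2 (Lux-like): the layer only describes shapes; weights live in a
# separate NamedTuple that is passed in on every call.
struct TinyDenseSpec
    in_dim::Int
    out_dim::Int
end
init_params(spec::TinyDenseSpec) = (W = randn(Float32, spec.out_dim, spec.in_dim),)
apply_layer(spec::TinyDenseSpec, ps, x) = ps.W * x

x = randn(Float32, 4, 3)

implicit_layer = TinyDense(randn(Float32, 2, 4))
y1 = implicit_layer(x)               # weights come from the struct

spec = TinyDenseSpec(4, 2)
ps = init_params(spec)
y2 = apply_layer(spec, ps, x)        # weights are an explicit argument

println((size(y1), size(y2)))        # ((2, 3), (2, 3))
```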
Weight Tying Implementation
Weight tying is implemented in the forward pass rather than through a separate tied layer:
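function (m::GPT)(idx)
    # ... forward through blocks ...
    x = m.ln_f(x)                    # x has shape (C, T, B)
    C, T, B = size(x)
    W = m.wte.weight                 # (C, vocab_size); reuse embedding weights
    vocab_size = size(W, 2)
    out = W' * reshape(x, C, T*B)    # transposed matmul -> (vocab_size, T*B)
    reshape(out, vocab_size, T, B)
end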
This avoids complications with Flux.loadmodel! when loading checkpoints.
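To see the shapes involved, here is a standalone sketch with random arrays in place of the model. It assumes the embedding table is stored as `(C, vocab_size)`, one column per token, which is the layout the transposed matmul above implies. The transformer blocks are skipped entirely.

```julia
vocab_size, C, T, B = 2000, 512, 6, 2

W = 0.02f0 .* randn(Float32, C, vocab_size)   # embedding table, one column per token
idx = rand(1:vocab_size, T, B)                # token ids

x = W[:, idx]                                 # embedding lookup -> (C, T, B)
# ... transformer blocks and the final norm would go here ...
out = W' * reshape(x, C, T * B)               # tied output projection -> (vocab_size, T*B)
logits = reshape(out, vocab_size, T, B)

println(size(logits))  # (2000, 6, 2)
```

Since the same matrix is used on both ends, `logits[v, t, b]` is the dot product between token `v`'s embedding and the hidden state at position `t`.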
Usage
OpenAI-Compatible API
Served via JuliaFluxGPT Space:
curl -X POST https://lisamegawatts-juliafluxgpt.hf.space/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": "the nature of"}],
"max_tokens": 200,
"temperature": 0.8,
"top_k": 40
}'
Streaming is supported by setting "stream": true in the request body.
Files
| File | Description |
|---|---|
| best_model.jld2 | Best checkpoint (step 28998, val_loss=6.622) |
| final_model.jld2 | Final checkpoint |
| checkpoint_latest.jld2 | Latest training checkpoint |
| tokenizer.json | BPE tokenizer (HuggingFace format, 2000 tokens) |
Each checkpoint contains:
- model_state: Flux model weights
- hyperparams: Dict with vocab_size, n_embd, block_size, n_layer, n_head, n_kv_head
- step: Training step at checkpoint
- best_val_loss: Best validation loss achieved
Provenance
- Author: LisaMegaWatts
- Source: DavinciDreams/symbiogenesis
- Training notebook: juliaflux_v2.ipynb
- Training date: February 2026
- Architecture reference: LLaMA (Touvron et al., 2023) with GQA (Ainslie et al., 2023)
References
- Touvron, H., et al. (2023). LLaMA: Open and Efficient Foundation Language Models.
- Ainslie, J., et al. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints.
- Karpathy, A. (2023). nanoGPT. GitHub repository.
Citation
@misc{juliafluxgpt2026,
title={JuliaFluxGPT: A LLaMA-style GQA Model in Julia/Flux.jl},
author={LisaMegaWatts},
year={2026},
url={https://huggingface.co/LisaMegaWatts/JuliaFluxGPT}
}
License
MIT