metadata
language:
  - en
license: mit
library_name: flux
tags:
  - julia
  - flux-jl
  - distillation
  - knowledge-distillation
  - llama-style
  - gqa
  - rope
  - rmsnorm
  - swiglu
  - bpe
  - philosophy
  - text-generation
pipeline_tag: text-generation
datasets:
  - LisaMegaWatts/philosophy-corpus
model-index:
  - name: JuliaGPTDistill
    results:
      - task:
          type: text-generation
          name: Text Generation
        dataset:
          type: LisaMegaWatts/philosophy-corpus
          name: philosophy-corpus
        metrics:
          - type: loss
            value: 7.44
            name: Val Loss
            verified: false

JuliaGPTDistill

A ~5M-parameter LLaMA-style student model distilled from JuliaFluxGPT (10M parameters). It is trained with temperature-scaled knowledge distillation to compress the teacher's knowledge into a smaller architecture.

Architecture

| Parameter | Value |
|---|---|
| Architecture | LLaMA-style (RMSNorm, SwiGLU, RoPE, GQA) |
| Embedding dim | 256 |
| Layers | 4 |
| Query heads | 4 |
| KV heads | 2 (GQA ratio 2:1) |
| Head dim | 64 |
| Context length | 256 tokens |
| Vocabulary | 2,000 (ByteLevel BPE) |
| Dropout | 0.1 |
| Weight tying | Yes |
| Framework | Julia + Flux.jl |
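
The derived sizes in the table follow from the listed hyperparameters. The sketch below recomputes them in plain Julia and adds a rough parameter estimate. The feed-forward hidden width (`ffn_hidden = 4 * n_embd`), the absence of bias terms, and the contiguous query-to-KV head mapping are assumptions, not values stated in this card, so the total is only an order-of-magnitude check.

```julia
# Hyperparameters taken from the architecture table.
n_embd, n_layer, n_head, n_kv_head, vocab_size = 256, 4, 4, 2, 2000

head_dim  = n_embd ÷ n_head          # 256 ÷ 4 = 64
kv_dim    = n_kv_head * head_dim     # width of the K and V projections
gqa_group = n_head ÷ n_kv_head       # query heads sharing one KV head

# ASSUMPTION: query heads are grouped contiguously (1-based head indices).
kv_head_for(q_head) = (q_head - 1) ÷ gqa_group + 1

# ASSUMPTION: SwiGLU hidden width is 4 * n_embd and no layer uses biases.
ffn_hidden = 4 * n_embd

attn_params  = 2 * n_embd * n_embd + 2 * n_embd * kv_dim   # Wq, Wo and Wk, Wv
ffn_params   = 3 * n_embd * ffn_hidden                     # gate, up, down
norm_params  = 2 * n_embd                                  # two RMSNorm scales
layer_params = attn_params + ffn_params + norm_params

# Weight tying: the embedding matrix is counted once; final RMSNorm adds n_embd.
total_params = n_layer * layer_params + vocab_size * n_embd + n_embd

println("head_dim = $head_dim, kv_dim = $kv_dim, gqa_group = $gqa_group")
println("estimated parameters = $total_params")
```

Under these assumptions the estimate comes to roughly 4.4M, in the same range as the stated ~5M. A different hidden width would shift it.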

Distillation Settings

| Parameter | Value |
|---|---|
| Teacher model | JuliaFluxGPT (512d/8L/8Q/2KV) |
| KD temperature | 4.0 |
| KD alpha | 0.5 |
| Loss | 0.5 * CE + 0.5 * KL(teacher \|\| student) |
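
As an illustration of the loss row, here is a minimal single-token sketch in plain Julia. It is not the training code from this repo: the function names are hypothetical, `alpha` is assumed to weight the cross-entropy term (with 0.5 either convention gives the same value), and the `T^2` scaling that some distillation setups apply to the KL term is left out because the card does not mention it.

```julia
# Numerically stable log-softmax over a vector of logits.
function log_softmax(z::AbstractVector{<:Real})
    s = z .- maximum(z)
    return s .- log(sum(exp.(s)))
end

# Hypothetical helper: alpha * CE + (1 - alpha) * KL(teacher || student)
# for one position. `target` is a 1-based token index.
function kd_loss(student_logits, teacher_logits, target::Integer; T=4.0, alpha=0.5)
    # Hard-label cross-entropy on the unscaled student logits.
    ce = -log_softmax(student_logits)[target]

    # Temperature-softened log-probabilities for teacher and student.
    log_p_t = log_softmax(teacher_logits ./ T)
    log_p_s = log_softmax(student_logits ./ T)

    # KL(teacher || student) = sum_i p_t[i] * (log p_t[i] - log p_s[i])
    kl = sum(exp.(log_p_t) .* (log_p_t .- log_p_s))

    return alpha * ce + (1 - alpha) * kl
end
```

When the student matches the teacher exactly the KL term vanishes and only the weighted cross-entropy remains.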

Training

| | Value |
|---|---|
| Dataset | philosophy-corpus |
| Tokenizer | BPE (2,000 vocab, ByteLevel) |
| Training steps | 4,089 |
| Best val loss | 7.44 |
| Hardware | NVIDIA RTX 3060 12GB |

Inference Settings

| Parameter | Value |
|---|---|
| vocab_size | 2,000 |
| context_length | 256 |
| temperature | 0.8 |
| top_k | 40 |
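
The `temperature` and `top_k` settings can be applied to a vector of next-token logits roughly as follows. This is a sketch using only Julia's standard library; `sample_top_k` is a hypothetical name, not a function shipped with this repo, and the returned index is 1-based, so it may need an offset if the tokenizer uses 0-based IDs.

```julia
using Random

# Sample one token index from `logits` using temperature scaling and top-k filtering.
function sample_top_k(logits::AbstractVector{<:Real};
                      temperature=0.8, top_k=40, rng=Random.default_rng())
    k = min(top_k, length(logits))

    # Indices of the k largest logits, largest first.
    idx = partialsortperm(logits, 1:k; rev=true)

    # Softmax over the surviving logits after dividing by the temperature.
    scaled = logits[idx] ./ temperature
    probs = exp.(scaled .- maximum(scaled))
    probs ./= sum(probs)

    # Inverse-CDF sampling over the k candidates.
    r = rand(rng)
    cumulative = 0.0
    for (i, p) in zip(idx, probs)
        cumulative += p
        r <= cumulative && return i
    end
    return idx[end]  # guard against floating-point rounding
end
```

During generation the input should also be truncated to the last 256 tokens, matching `context_length`.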

Note: This model requires the same BPE tokenizer used by JuliaFluxGPT. No tokenizer file is included in this repo — use the tokenizer from JuliaFluxGPT.

Checkpoint Format

JLD2 files containing:

  • model_state: Flux model weights
  • hyperparams: Dict("n_embd"=>256, "n_layer"=>4, "n_head"=>4, "n_kv_head"=>2, "vocab_size"=>2000, "block_size"=>256, "dropout"=>0.1, "kd_temperature"=>4.0, "kd_alpha"=>0.5)
  • step, best_val_loss, train_losses, val_losses
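
A checkpoint's metadata can be inspected with JLD2.jl before any weights are loaded. The sketch below assumes the entries are stored as top-level keys with exactly the names listed above and that `hyperparams` is a `Dict` with string keys. `checkpoint_summary` is a hypothetical helper, and rebuilding the model from `model_state` is not shown because the model definition is not part of this card.

```julia
using JLD2

# Read the small metadata entries without touching `model_state`.
function checkpoint_summary(path::AbstractString)
    jldopen(path, "r") do f
        hp = f["hyperparams"]
        (step = f["step"],
         best_val_loss = f["best_val_loss"],
         n_embd = hp["n_embd"],
         n_layer = hp["n_layer"],
         block_size = hp["block_size"])
    end
end

# Usage, with the path pointing at a downloaded checkpoint:
# summary = checkpoint_summary("best_model.jld2")
# println(summary)
```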

Files

| File | Description |
|---|---|
| best_model.jld2 | Best validation loss checkpoint |
| final_model.jld2 | Final training step checkpoint |
| checkpoint_latest.jld2 | Latest periodic checkpoint |

Provenance

License

MIT