LisaMegaWatts committed on
Commit 0624a08 · verified · 1 Parent(s): 45c0e0a

Initial space setup: Distilled LLaMA-style OpenAI-compatible server

Files changed (6)
  1. Dockerfile +35 -0
  2. Project.toml +7 -0
  3. README.md +45 -5
  4. checkpoint.jl +222 -0
  5. model.jl +290 -0
  6. server.jl +312 -0
Dockerfile ADDED
@@ -0,0 +1,35 @@
+ FROM julia:1.10-bookworm
+
+ # HuggingFace Spaces requires user ID 1000
+ RUN useradd -m -u 1000 user
+
+ # Shared Julia depot for package caching
+ ENV JULIA_DEPOT_PATH=/opt/julia-depot
+ RUN mkdir -p /opt/julia-depot && chmod 777 /opt/julia-depot
+
+ # Copy project file first for dependency caching
+ COPY --chown=user Project.toml /home/user/app/
+
+ # Install and precompile Julia packages
+ RUN julia --project=/home/user/app -e ' \
+     using Pkg; \
+     Pkg.instantiate(); \
+     Pkg.precompile(); \
+     println("Precompile done")'
+
+ # Copy application code
+ COPY --chown=user model.jl /home/user/app/
+ COPY --chown=user checkpoint.jl /home/user/app/
+ COPY --chown=user server.jl /home/user/app/
+
+ # Create checkpoints directory (model downloads from HF at runtime)
+ RUN mkdir -p /home/user/app/checkpoints && chown user:user /home/user/app/checkpoints
+
+ # Switch to non-root user
+ USER user
+ ENV HOME=/home/user
+ WORKDIR /home/user/app
+
+ EXPOSE 7860
+
+ CMD ["julia", "--project=/home/user/app", "/home/user/app/server.jl"]
Project.toml ADDED
@@ -0,0 +1,7 @@
+ [deps]
+ Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
+ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+ HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
+ JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
+ JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
+ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
README.md CHANGED
@@ -1,10 +1,50 @@
  ---
- title: JuliaGPTDistill Space
- emoji: 👍
- colorFrom: pink
- colorTo: purple
+ title: JuliaGPTDistill
+ emoji: "🧬"
+ colorFrom: green
+ colorTo: blue
  sdk: docker
+ app_port: 7860
  pinned: false
+ license: mit
+ tags:
+ - julia
+ - flux-jl
+ - llama-style
+ - rope
+ - swiglu
+ - gqa
+ - rmsnorm
+ - bpe
+ - distillation
+ - philosophy
+ - openai-compatible
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # JuliaGPTDistill Space
+
+ Distilled LLaMA-style decoder model (256d, 4L, 4Q/2KV) trained via knowledge distillation from JuliaFluxGPT. BPE tokenizer (2000 tokens). Trained on classical philosophy and mathematics.
+
+ ## Endpoints
+
+ - `GET /` — Health check and model info
+ - `GET /v1/models` — List available models
+ - `POST /v1/chat/completions` — Generate text (supports streaming, top-k, top-p)
+
+ ## Usage
+
+ ```bash
+ curl -X POST https://LisaMegaWatts-JuliaGPTDistill-space.hf.space/v1/chat/completions \
+   -H "Content-Type: application/json" \
+   -d '{"messages": [{"role": "user", "content": "the nature of"}], "max_tokens": 200}'
+ ```
+
+ ## Architecture
+
+ - **Model**: 256d embed, 4 layers, 4Q/2KV heads (GQA), ~1.5M params
+ - **Tokenizer**: BPE (2000 tokens)
+ - **Normalization**: RMSNorm (pre-norm)
+ - **Feed-forward**: SwiGLU activation
+ - **Weight tying**: Shared embedding/output projection
+ - **Training**: Knowledge distillation from JuliaFluxGPT (10M params)
+ - **Framework**: Flux.jl
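The Usage section above shows curl; the same request can be made from Julia with the HTTP.jl and JSON3 packages already listed in Project.toml. A minimal sketch (non-streaming; the Space URL is the one given in the Usage section):

```julia
using HTTP, JSON3

# POST a chat completion request and print the generated text.
resp = HTTP.post(
    "https://LisaMegaWatts-JuliaGPTDistill-space.hf.space/v1/chat/completions",
    ["Content-Type" => "application/json"],
    JSON3.write(Dict(
        "messages"    => [Dict("role" => "user", "content" => "the nature of")],
        "max_tokens"  => 200,
        "temperature" => 0.8)))

body = JSON3.read(String(resp.body))
println(body.choices[1].message.content)
```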
checkpoint.jl ADDED
@@ -0,0 +1,222 @@
+ #=
+ checkpoint.jl — Load Flux model checkpoints for JuliaFluxGPT
+
+ Loads JLD2 checkpoints saved by the juliaflux_v2 training notebook.
+ Supports BPE tokenizer (tokenizer.json format) with character-level fallback.
+
+ NOTE: The GPT struct no longer has TiedDense — weight tying is done in the
+ forward pass. This simplifies checkpoint loading: we load all components
+ normally and skip any lm_head key in the checkpoint (it's redundant since
+ the output projection uses wte.weight directly).
+ =#
+
+ include("model.jl")
+ using JLD2
+ using JSON3
+
+ # ═══════════════════════════════════════════════════════════════════════════════
+ # BPE Tokenizer (loaded from tokenizer.json — HuggingFace format)
+ # ═══════════════════════════════════════════════════════════════════════════════
+
+ struct BPETokenizer
+     vocab::Dict{String, Int}
+     id_to_token::Dict{Int, String}
+     merges::Vector{Tuple{String, String}}
+     merge_rank::Dict{Tuple{String, String}, Int}
+     byte_to_unicode::Dict{UInt8, String}
+     unicode_to_byte::Dict{Char, UInt8}
+     word_cache::Dict{String, Vector{Int}}
+     gpt2_pattern::Regex
+ end
+
+ function build_byte_to_unicode()
+     bs = UInt8[]
+     cs = Char[]
+     for r in [0x21:0x7e, 0xa1:0xac, 0xae:0xff]
+         for b in r
+             push!(bs, b)
+             push!(cs, Char(b))
+         end
+     end
+     n = 0
+     for b in 0x00:0xff
+         if b ∉ bs
+             push!(bs, b)
+             push!(cs, Char(256 + n))
+             n += 1
+         end
+     end
+     b2u = Dict(bs[i] => string(cs[i]) for i in eachindex(bs))
+     u2b = Dict(v[1] => k for (k, v) in b2u)
+     return b2u, u2b
+ end
+
+ function load_bpe_tokenizer(path::String)
+     tok_json = JSON3.read(read(path, String))
+
+     vocab = Dict{String, Int}()
+     for (tok_str, id) in pairs(tok_json.model.vocab)
+         vocab[string(tok_str)] = Int(id) + 1  # +1 for Julia 1-indexing
+     end
+
+     merges = Tuple{String, String}[]
+     for merge_entry in tok_json.model.merges
+         if merge_entry isa AbstractVector && length(merge_entry) >= 2
+             push!(merges, (String(merge_entry[1]), String(merge_entry[2])))
+         else
+             parts = split(string(merge_entry), " ", limit=2)
+             if length(parts) == 2
+                 push!(merges, (String(parts[1]), String(parts[2])))
+             end
+         end
+     end
+
+     id_to_token = Dict{Int, String}(id => tok for (tok, id) in vocab)
+     merge_rank = Dict{Tuple{String, String}, Int}(
+         (a, b) => i for (i, (a, b)) in enumerate(merges)
+     )
+     b2u, u2b = build_byte_to_unicode()
+     gpt2_pat = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
+
+     BPETokenizer(vocab, id_to_token, merges, merge_rank, b2u, u2b,
+                  Dict{String, Vector{Int}}(), gpt2_pat)
+ end
+
+ function bpe_encode_word(tok::BPETokenizer, word::Vector{String})
+     tokens = copy(word)
+     while length(tokens) >= 2
+         best_rank = typemax(Int)
+         best_pair = ("", "")
+         for i in 1:length(tokens)-1
+             rank = get(tok.merge_rank, (tokens[i], tokens[i+1]), typemax(Int))
+             if rank < best_rank
+                 best_rank = rank
+                 best_pair = (tokens[i], tokens[i+1])
+             end
+         end
+         best_rank == typemax(Int) && break
+         a, b = best_pair
+         new_tokens = String[]
+         i = 1
+         while i <= length(tokens)
+             if i < length(tokens) && tokens[i] == a && tokens[i+1] == b
+                 push!(new_tokens, a * b)
+                 i += 2
+             else
+                 push!(new_tokens, tokens[i])
+                 i += 1
+             end
+         end
+         tokens = new_tokens
+     end
+     return tokens
+ end
+
+ function encode_bpe(tok::BPETokenizer, s::String)
+     ids = Int[]
+     for m in eachmatch(tok.gpt2_pattern, s)
+         word = m.match
+         cached = get(tok.word_cache, word, nothing)
+         if cached !== nothing
+             append!(ids, cached)
+         else
+             word_bytes = Vector{UInt8}(word)
+             chars = [tok.byte_to_unicode[b] for b in word_bytes]
+             tokens = bpe_encode_word(tok, chars)
+             word_ids = Int[]
+             for t in tokens
+                 id = get(tok.vocab, t, nothing)
+                 id !== nothing && push!(word_ids, id)
+             end
+             tok.word_cache[word] = word_ids
+             append!(ids, word_ids)
+         end
+     end
+     return ids
+ end
+
+ function decode_bpe(tok::BPETokenizer, ids::Vector{Int})
+     text = join(get(tok.id_to_token, id, "") for id in ids)
+     bytes = UInt8[tok.unicode_to_byte[c] for c in text if haskey(tok.unicode_to_byte, c)]
+     return String(bytes)
+ end
+
+ # ═══════════════════════════════════════════════════════════════════════════════
+ # Checkpoint loading
+ # ═══════════════════════════════════════════════════════════════════════════════
+
+ function load_flux_checkpoint(checkpoint_path::String; tokenizer_path::String="")
+     println("Loading checkpoint from $checkpoint_path ...")
+     data = JLD2.load(checkpoint_path)
+
+     hp = data["hyperparams"]
+     vocab_size = Int(hp["vocab_size"])
+     n_embd = Int(hp["n_embd"])
+     block_size = Int(hp["block_size"])
+     n_layer = Int(hp["n_layer"])
+     n_head = Int(hp["n_head"])
+     n_kv_head = Int(get(hp, "n_kv_head", hp["n_head"]))
+     dropout_val = Float64(get(hp, "dropout", 0.0))
+
+     model = GPT(;
+         vocab_size = vocab_size,
+         n_embd = n_embd,
+         block_size = block_size,
+         n_layer = n_layer,
+         n_head = n_head,
+         n_kv_head = n_kv_head,
+         dropout = 0.0  # No dropout at inference
+     )
+
+     # Load weights component-by-component
+     ms = data["model_state"]
+     Flux.loadmodel!(model.wte, ms[:wte])
+     Flux.loadmodel!(model.drop, ms[:drop])
+     Flux.loadmodel!(model.blocks, ms[:blocks])
+     Flux.loadmodel!(model.ln_f, ms[:ln_f])
+
+     # Set to test mode (disables dropout)
+     Flux.testmode!(model)
+
+     step = get(data, "step", 0)
+     best_val = get(data, "best_val_loss", Inf)
+
+     println(" Model loaded: vocab=$vocab_size, embd=$n_embd, layers=$n_layer, " *
+             "heads=$(n_head)Q/$(n_kv_head)KV, block=$block_size")
+     println(" Step=$step, best_val=$(round(best_val, digits=4))")
+
+     # Load tokenizer
+     encode_fn = nothing
+     decode_fn = nothing
+
+     if !isempty(tokenizer_path) && isfile(tokenizer_path)
+         println(" Loading BPE tokenizer from $tokenizer_path")
+         bpe = load_bpe_tokenizer(tokenizer_path)
+         tok_vocab_size = length(bpe.vocab)
+
+         if tok_vocab_size != vocab_size
+             @warn "Vocab mismatch! Model expects vocab_size=$vocab_size but tokenizer has $tok_vocab_size tokens. " *
+                   "Token IDs above $vocab_size will be clamped."
+         end
+
+         encode_fn = function(s)
+             ids = encode_bpe(bpe, s)
+             return [clamp(id, 1, vocab_size) for id in ids]
+         end
+         decode_fn = ids -> decode_bpe(bpe, ids)
+         println(" BPE tokenizer loaded: $(tok_vocab_size) tokens (model vocab: $vocab_size)")
+     else
+         # Character-level fallback
+         chars = vcat(collect('a':'z'), [' ', '.'])
+         stoi = Dict(c => i for (i, c) in enumerate(chars))
+         itos = Dict(i => c for (i, c) in enumerate(chars))
+         encode_fn = s -> [get(stoi, c, 1) for c in s]
+         decode_fn = ids -> join(get(itos, id, '?') for id in ids)
+         println(" No tokenizer.json found, using character-level fallback ($(length(chars)) chars)")
+     end
+
+     return (;
+         model, vocab_size, n_embd, block_size, n_layer, n_head, n_kv_head,
+         step, best_val, encode_fn, decode_fn
+     )
+ end
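As a quick sanity check of the tokenizer utilities above, the encode/decode pair should round-trip ordinary text. A small sketch, assuming tokenizer.json has already been downloaded to checkpoints/ (the path server.jl uses):

```julia
include("checkpoint.jl")

# Round-trip a phrase through the byte-level BPE defined above.
tok = load_bpe_tokenizer("checkpoints/tokenizer.json")
ids = encode_bpe(tok, "the nature of the good")
println(ids)                   # 1-indexed token ids
println(decode_bpe(tok, ids))  # should print the original phrase
```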
model.jl ADDED
@@ -0,0 +1,290 @@
+ #=
+ model.jl — LLaMA-style GPT model in Flux.jl for JuliaFluxGPT
+
+ Contains: RMSNorm, SwiGLU, CausalSelfAttention (GQA + RoPE),
+ TransformerBlock, GPT, and generation utilities.
+
+ Same architecture as juliaflux_v2.ipynb — extracted for inference serving.
+
+ NOTE: Weight tying is done by computing the output projection directly using
+ m.wte.weight in the forward pass. This matches the training notebooks and
+ ensures Flux.loadmodel! works without needing to skip lm_head.
+ =#
+
+ using Flux
+ using NNlib
+ using NNlib: batched_mul
+ using Statistics
+ using Random
+ using LinearAlgebra
+
+ # ═══════════════════════════════════════════════════════════════════════════════
+ # RoPE — Rotary Positional Embeddings
+ # ═══════════════════════════════════════════════════════════════════════════════
+
+ function precompute_rope_freqs(head_dim::Int, max_seq_len::Int; base::Float32 = 10000.0f0)
+     half_dim = head_dim ÷ 2
+     freqs = Float32[1.0f0 / (base ^ (Float32(2 * (i - 1)) / Float32(head_dim))) for i in 1:half_dim]
+     positions = Float32.(collect(0:max_seq_len-1))
+     angles = freqs * positions'
+     return cos.(angles), sin.(angles)
+ end
+
+ function apply_rope(x, cos_f, sin_f, T::Int)
+     d = size(x, 1) ÷ 2
+     x1 = x[1:d, :, :]
+     x2 = x[d+1:2d, :, :]
+     c = cos_f[:, 1:T]
+     s = sin_f[:, 1:T]
+     return vcat(x1 .* c .- x2 .* s, x1 .* s .+ x2 .* c)
+ end
+
+ # ═══════════════════════════════════════════════════════════════════════════════
+ # Model components
+ # ═══════════════════════════════════════════════════════════════════════════════
+
+ struct RMSNorm{W <: AbstractVector}
+     weight::W
+     eps::Float32
+ end
+
+ Flux.@layer RMSNorm
+
+ RMSNorm(dim::Int; eps::Float32 = 1.0f-6) = RMSNorm(ones(Float32, dim), eps)
+
+ function (rn::RMSNorm)(x)
+     rms = sqrt.(mean(x .^ 2, dims=1) .+ rn.eps)
+     return (x ./ rms) .* rn.weight
+ end
+
+ struct SwiGLUFFN
+     w_gate::Dense
+     w_up::Dense
+     w_down::Dense
+     drop::Dropout
+ end
+
+ Flux.@layer SwiGLUFFN
+
+ function SwiGLUFFN(n_embd::Int; bias=false, dropout=0.0)
+     raw_inner = Int(floor(4 * n_embd * 2 / 3))
+     inner_dim = max(64, 64 * div(raw_inner + 32, 64))
+     SwiGLUFFN(
+         Dense(n_embd => inner_dim; bias),
+         Dense(n_embd => inner_dim; bias),
+         Dense(inner_dim => n_embd; bias),
+         Dropout(dropout)
+     )
+ end
+
+ function (ff::SwiGLUFFN)(x)
+     ff.drop(ff.w_down(NNlib.swish(ff.w_gate(x)) .* ff.w_up(x)))
+ end
+
+ struct CausalSelfAttention
+     wq::Dense
+     wkv::Dense
+     proj::Dense
+     n_head::Int
+     n_kv_head::Int
+ end
+
+ Flux.@layer CausalSelfAttention trainable=(wq, wkv, proj)
+
+ function CausalSelfAttention(n_embd::Int, n_head::Int, n_kv_head::Int; bias=false)
+     head_dim = n_embd ÷ n_head
+     kv_dim = head_dim * n_kv_head
+     CausalSelfAttention(
+         Dense(n_embd => n_embd; bias),
+         Dense(n_embd => 2 * kv_dim; bias),
+         Dense(n_embd => n_embd; bias),
+         n_head,
+         n_kv_head
+     )
+ end
+
+ function (attn::CausalSelfAttention)(x, causal_mask, rope_cos, rope_sin)
+     C, T, B = size(x)
+     nh = attn.n_head
+     nkv = attn.n_kv_head
+     hs = C ÷ nh
+     kv_dim = hs * nkv
+     groups = nh ÷ nkv
+
+     q = attn.wq(x)
+     kv = attn.wkv(x)
+     k = kv[1:kv_dim, :, :]
+     v = kv[kv_dim+1:2*kv_dim, :, :]
+
+     q = reshape(permutedims(reshape(q, hs, nh, T, B), (1, 3, 2, 4)), hs, T, nh * B)
+     k = reshape(permutedims(reshape(k, hs, nkv, T, B), (1, 3, 2, 4)), hs, T, nkv * B)
+     v = reshape(permutedims(reshape(v, hs, nkv, T, B), (1, 3, 2, 4)), hs, T, nkv * B)
+
+     q = apply_rope(q, rope_cos, rope_sin, T)
+     k = apply_rope(k, rope_cos, rope_sin, T)
+
+     if groups > 1
+         k_4d = reshape(k, hs, T, nkv, B)
+         v_4d = reshape(v, hs, T, nkv, B)
+         k_rep = repeat(reshape(k_4d, hs, T, nkv, 1, B), 1, 1, 1, groups, 1)
+         v_rep = repeat(reshape(v_4d, hs, T, nkv, 1, B), 1, 1, 1, groups, 1)
+         k = reshape(permutedims(k_rep, (1, 2, 4, 3, 5)), hs, T, nh * B)
+         v = reshape(permutedims(v_rep, (1, 2, 4, 3, 5)), hs, T, nh * B)
+     end
+
+     scale = Float32(1 / sqrt(hs))
+     wei = batched_mul(permutedims(q, (2, 1, 3)), k) .* scale
+     wei = wei .+ causal_mask[1:T, 1:T]
+     wei = softmax(wei; dims=2)
+
+     out = batched_mul(v, permutedims(wei, (2, 1, 3)))
+     out = reshape(permutedims(reshape(out, hs, T, nh, B), (1, 3, 2, 4)), C, T, B)
+
+     attn.proj(out)
+ end
+
+ struct TransformerBlock
+     ln1::RMSNorm
+     attn::CausalSelfAttention
+     ln2::RMSNorm
+     ffwd::SwiGLUFFN
+ end
+
+ Flux.@layer TransformerBlock
+
+ function TransformerBlock(n_embd::Int, n_head::Int, n_kv_head::Int; dropout=0.0)
+     TransformerBlock(
+         RMSNorm(n_embd),
+         CausalSelfAttention(n_embd, n_head, n_kv_head),
+         RMSNorm(n_embd),
+         SwiGLUFFN(n_embd; dropout)
+     )
+ end
+
+ # ═══════════════════════════════════════════════════════════════════════════════
+ # GPT — weight-tied output projection (matches training notebooks)
+ # ═══════════════════════════════════════════════════════════════════════════════
+
+ struct GPT
+     wte::Embedding
+     drop::Dropout
+     blocks::Chain
+     ln_f::RMSNorm
+     # Precomputed constants (not trainable)
+     causal_mask::Matrix{Float32}
+     rope_cos::Matrix{Float32}
+     rope_sin::Matrix{Float32}
+     n_head::Int
+     n_kv_head::Int
+     block_size::Int
+ end
+
+ Flux.@layer GPT trainable=(wte, drop, blocks, ln_f)
+
+ function GPT(; vocab_size, n_embd, block_size, n_layer, n_head, n_kv_head, dropout=0.0)
+     head_dim = n_embd ÷ n_head
+     wte = Embedding(vocab_size => n_embd)
+     causal_mask = triu(fill(typemin(Float32), block_size, block_size), 1)
+     rope_cos, rope_sin = precompute_rope_freqs(head_dim, block_size)
+     GPT(
+         wte,
+         Dropout(dropout),
+         Chain([TransformerBlock(n_embd, n_head, n_kv_head; dropout) for _ in 1:n_layer]...),
+         RMSNorm(n_embd),
+         causal_mask,
+         rope_cos,
+         rope_sin,
+         n_head,
+         n_kv_head,
+         block_size
+     )
+ end
+
+ function (m::GPT)(idx)
+     B, T = size(idx)
+     tok = permutedims(m.wte(idx), (1, 3, 2))  # (C, T, B)
+     x = m.drop(tok)
+     for block in m.blocks
+         x = x .+ block.attn(block.ln1(x), m.causal_mask, m.rope_cos, m.rope_sin)
+         x = x .+ block.ffwd(block.ln2(x))
+     end
+     x = m.ln_f(x)
+     # Weight-tied output projection — same weight as embedding
+     W = m.wte.weight
+     C = size(x, 1)
+     x_flat = reshape(x, C, T * B)
+     out = W' * x_flat
+     reshape(out, size(W, 2), T, B)
+ end
+
+ # ═══════════════════════════════════════════════════════════════════════════════
+ # Text generation with streaming support
+ # ═══════════════════════════════════════════════════════════════════════════════
+
+ function generate_streaming(model, encode_fn, decode_fn, vocab_size::Int, block_size::Int;
+                             prompt::String="", max_tokens::Int=200, temperature::Float64=0.8,
+                             top_k::Int=40, top_p::Float64=1.0, on_token=nothing)
+     if !isempty(prompt)
+         prompt_ids = encode_fn(prompt)
+         idx = reshape(prompt_ids, 1, :)
+     else
+         idx = reshape([rand(1:vocab_size)], 1, 1)
+     end
+
+     generated_ids = Int[]
+
+     for _ in 1:max_tokens
+         idx_cond = idx[:, max(1, end-block_size+1):end]
+         logits = model(idx_cond)
+         logits_last = Vector{Float32}(logits[:, end, 1])
+
+         # Temperature scaling
+         logits_last ./= Float32(max(temperature, 0.01))
+
+         # Top-k filtering
+         if top_k > 0 && top_k < length(logits_last)
+             threshold = partialsort(logits_last, top_k; rev=true)
+             for i in eachindex(logits_last)
+                 if logits_last[i] < threshold
+                     logits_last[i] = -Inf32
+                 end
+             end
+         end
+
+         # Top-p (nucleus) filtering
+         if top_p < 1.0
+             sorted_indices = sortperm(logits_last; rev=true)
+             sorted_logits = logits_last[sorted_indices]
+             probs_sorted = NNlib.softmax(sorted_logits)
+             cumprobs = cumsum(Array(probs_sorted))
+             cutoff = something(findfirst(>=(Float32(top_p)), cumprobs), length(probs_sorted))
+             for i in (cutoff+1):length(sorted_indices)
+                 logits_last[sorted_indices[i]] = -Inf32
+             end
+         end
+
+         probs = NNlib.softmax(logits_last)
+         probs_cpu = Float64.(probs)
+
+         r = rand()
+         cum = 0.0
+         next_id = length(probs_cpu)
+         for (i, p) in enumerate(probs_cpu)
+             cum += p
+             if r <= cum
+                 next_id = i
+                 break
+             end
+         end
+
+         push!(generated_ids, next_id)
+         idx = hcat(idx, reshape([next_id], 1, 1))
+
+         if on_token !== nothing
+             token_str = decode_fn([next_id])
+             on_token(token_str)
+         end
+     end
+
+     return decode_fn(generated_ids)
+ end
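The forward pass returns logits shaped (vocab, T, B), which generate_streaming relies on when it reads logits[:, end, 1]. A shape check on a freshly initialized model is a cheap way to verify this; the sketch below uses the configuration from the README (256d, 4 layers, 4Q/2KV heads) with an assumed block size of 256 and vocab of 2000:

```julia
include("model.jl")

# Random-weight model with the README's configuration (block_size is assumed here).
m = GPT(; vocab_size=2000, n_embd=256, block_size=256,
          n_layer=4, n_head=4, n_kv_head=2)
Flux.testmode!(m)                      # disable dropout

idx = rand(1:2000, 1, 16)              # (B, T) matrix of token ids
logits = m(idx)
@assert size(logits) == (2000, 16, 1)  # (vocab, T, B), as generate_streaming expects
```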
server.jl ADDED
@@ -0,0 +1,312 @@
+ #=
+ server.jl — OpenAI-compatible inference server for JuliaGPTDistill
+
+ Serves a Flux.jl trained LLaMA-style GPT model (RoPE, GQA, RMSNorm, SwiGLU).
+ Downloads checkpoint and tokenizer from HuggingFace model repo on first run.
+
+ Endpoints:
+   GET  /                    -> health check / API info
+   GET  /v1/models           -> list available models
+   POST /v1/chat/completions -> generate text (OpenAI format, streaming supported)
+ =#
+
+ include("checkpoint.jl")
+ using HTTP
+ using UUIDs
+ using Downloads
+
+ # ═══════════════════════════════════════════════════════════════════
+ # Download artifacts from HuggingFace
+ # ═══════════════════════════════════════════════════════════════════
+
+ const CKPT_DIR = "checkpoints"
+ const CKPT_PATH = joinpath(CKPT_DIR, "best_model.jld2")
+ const TOKENIZER_PATH = joinpath(CKPT_DIR, "tokenizer.json")
+ const HF_REPO = get(ENV, "HF_REPO", "LisaMegaWatts/JuliaGPTDistill")
+ const PORT = parse(Int, get(ENV, "PORT", "7860"))
+
+ function download_from_hf(repo::String, filename::String, local_path::String)
+     url = "https://huggingface.co/$repo/resolve/main/$filename"
+     println("Downloading $url ...")
+     mkpath(dirname(local_path))
+     Downloads.download(url, local_path)
+     sz = round(filesize(local_path) / 1024^2, digits=1)
+     println("  -> $local_path ($sz MB)")
+ end
+
+ function ensure_artifacts()
+     for (localpath, remote) in [(CKPT_PATH, "best_model.jld2"),
+                                 (TOKENIZER_PATH, "tokenizer.json")]
+         if !isfile(localpath)
+             println("No local $remote found, downloading from $HF_REPO ...")
+             try
+                 download_from_hf(HF_REPO, remote, localpath)
+             catch e
+                 println("Download failed for $remote: $e")
+                 println("Place $remote at $localpath manually.")
+                 exit(1)
+             end
+         end
+     end
+ end
+
+ # ═══════════════════════════════════════════════════════════════════
+ # Download and load model
+ # ═══════════════════════════════════════════════════════════════════
+
+ ensure_artifacts()
+
+ println("\nLoading model...")
+ const CKPT = load_flux_checkpoint(CKPT_PATH; tokenizer_path=TOKENIZER_PATH)
+ const MODEL = CKPT.model
+ const VOCAB_SIZE = CKPT.vocab_size
+ const BLOCK_SIZE = CKPT.block_size
+ const ENCODE_FN = CKPT.encode_fn
+ const DECODE_FN = CKPT.decode_fn
+ const MODEL_CREATED_AT = Int(floor(time()))
+
+ println("\nModel ready: vocab=$(VOCAB_SIZE), embd=$(CKPT.n_embd), " *
+         "layers=$(CKPT.n_layer), heads=$(CKPT.n_head)Q/$(CKPT.n_kv_head)KV, " *
+         "block=$(BLOCK_SIZE)")
+
+ # ═══════════════════════════════════════════════════════════════════
+ # HTTP helpers
+ # ═══════════════════════════════════════════════════════════════════
+
+ const CORS_HEADERS = [
+     "Access-Control-Allow-Origin" => "*",
+     "Access-Control-Allow-Methods" => "GET, POST, OPTIONS",
+     "Access-Control-Allow-Headers" => "Content-Type, Authorization",
+ ]
+
+ function json_response(status::Int, body; extra_headers=[])
+     json_bytes = JSON3.write(body)
+     headers = [
+         "Content-Type" => "application/json",
+         CORS_HEADERS...,
+         extra_headers...
+     ]
+     return HTTP.Response(status, headers, json_bytes)
+ end
+
+ function cors_preflight()
+     return HTTP.Response(204, CORS_HEADERS)
+ end
+
+ # ═══════════════════════════════════════════════════════════════════
+ # Extract prompt from OpenAI chat messages
+ # ═══════════════════════════════════════════════════════════════════
+
+ function extract_prompt(messages)
+     if isempty(messages)
+         return ""
+     end
+     for i in length(messages):-1:1
+         role = string(get(messages[i], :role, ""))
+         if role == "user"
+             return string(get(messages[i], :content, ""))
+         end
+     end
+     return string(get(messages[end], :content, ""))
+ end
+
+ # ═══════════════════════════════════════════════════════════════════
+ # SSE helpers
+ # ═══════════════════════════════════════════════════════════════════
+
+ function sse_line(data)
+     return "data: $(JSON3.write(data))\n\n"
+ end
+
+ # ═══════════════════════════════════════════════════════════════════
+ # Request handler
+ # ═══════════════════════════════════════════════════════════════════
+
+ function handle_request(request::HTTP.Request)
+     method = request.method
+     target = request.target
+
+     # CORS preflight
+     if method == "OPTIONS"
+         return cors_preflight()
+     end
+
+     # GET / — health check and model info
+     if method == "GET" && target == "/"
+         return json_response(200, Dict(
+             "name" => "JuliaGPTDistill",
+             "version" => "1.0.0",
+             "description" => "Distilled LLaMA-style GPT in Flux.jl — knowledge distillation from JuliaFluxGPT",
+             "architecture" => "RoPE + SwiGLU + GQA + RMSNorm + weight tying",
+             "model" => Dict(
+                 "vocab_size" => VOCAB_SIZE,
+                 "n_embd" => CKPT.n_embd,
+                 "n_layer" => CKPT.n_layer,
+                 "n_head" => CKPT.n_head,
+                 "n_kv_head" => CKPT.n_kv_head,
+                 "block_size" => BLOCK_SIZE
+             ),
+             "endpoints" => ["/v1/models", "/v1/chat/completions"],
+             "features" => ["streaming", "OpenAI-compatible", "top-k", "top-p"],
+             "compatible_with" => ["OpenAI API", "OpenRouter"]
+         ))
+     end
+
+     # GET /v1/models — list available models
+     if method == "GET" && target == "/v1/models"
+         return json_response(200, Dict(
+             "object" => "list",
+             "data" => [Dict(
+                 "id" => "juliagptdistill-philosophy",
+                 "object" => "model",
+                 "created" => MODEL_CREATED_AT,
+                 "owned_by" => "juliagptdistill"
+             )]
+         ))
+     end
+
+     # POST /v1/chat/completions — generate text
+     if method == "POST" && target == "/v1/chat/completions"
+         local body
+         try
+             body = JSON3.read(String(request.body))
+         catch e
+             return json_response(400, Dict("error" => Dict(
+                 "message" => "Invalid JSON in request body",
+                 "type" => "invalid_request_error",
+                 "code" => "invalid_json")))
+         end
+
+         temperature = Float64(clamp(get(body, :temperature, 0.8), 0.01, 2.0))
+         max_tokens = Int(clamp(get(body, :max_tokens, 200), 1, BLOCK_SIZE))
+         top_k_val = Int(clamp(get(body, :top_k, 40), 0, VOCAB_SIZE))
+         top_p_val = Float64(clamp(get(body, :top_p, 1.0), 0.0, 1.0))
+         stream = Bool(get(body, :stream, false))
+
+         messages = get(body, :messages, [])
+         prompt_text = extract_prompt(messages)
+
+         if stream
+             # ── SSE streaming response (buffered) ──
+             completion_id = "chatcmpl-" * string(uuid4())
+             created = Int(floor(time()))
+
+             buf = IOBuffer()
+
+             # Initial chunk with role
+             initial_chunk = Dict(
+                 "id" => completion_id,
+                 "object" => "chat.completion.chunk",
+                 "created" => created,
+                 "model" => "juliagptdistill-philosophy",
+                 "choices" => [Dict(
+                     "index" => 0,
+                     "delta" => Dict("role" => "assistant", "content" => ""),
+                     "finish_reason" => nothing
+                 )]
+             )
+             write(buf, sse_line(initial_chunk))
+
+             token_count = Ref(0)
+
+             generate_streaming(MODEL, ENCODE_FN, DECODE_FN, VOCAB_SIZE, BLOCK_SIZE;
+                 prompt=prompt_text, max_tokens=max_tokens,
+                 temperature=temperature, top_k=top_k_val, top_p=top_p_val,
+                 on_token = function(token_str)
+                     token_count[] += 1
+                     chunk = Dict(
+                         "id" => completion_id,
+                         "object" => "chat.completion.chunk",
+                         "created" => created,
+                         "model" => "juliagptdistill-philosophy",
+                         "choices" => [Dict(
+                             "index" => 0,
+                             "delta" => Dict("content" => token_str),
+                             "finish_reason" => nothing
+                         )]
+                     )
+                     write(buf, sse_line(chunk))
+                 end)
+
+             # Final chunk with finish_reason
+             prompt_tokens = length(ENCODE_FN(prompt_text))
+             finish_chunk = Dict(
+                 "id" => completion_id,
+                 "object" => "chat.completion.chunk",
+                 "created" => created,
+                 "model" => "juliagptdistill-philosophy",
+                 "choices" => [Dict(
+                     "index" => 0,
+                     "delta" => Dict(),
+                     "finish_reason" => token_count[] >= max_tokens ? "length" : "stop"
+                 )],
+                 "usage" => Dict(
+                     "prompt_tokens" => prompt_tokens,
+                     "completion_tokens" => token_count[],
+                     "total_tokens" => prompt_tokens + token_count[]
+                 )
+             )
+             write(buf, sse_line(finish_chunk))
+             write(buf, "data: [DONE]\n\n")
+
+             sse_body = take!(buf)
+             headers = [
+                 "Content-Type" => "text/event-stream",
+                 "Cache-Control" => "no-cache",
+                 "X-Accel-Buffering" => "no",
+                 CORS_HEADERS...
+             ]
+             return HTTP.Response(200, headers, sse_body)
+
+         else
+             # ── Standard (non-streaming) response ──
+             n_completions = Int(clamp(get(body, :n, 1), 1, 4))
+
+             choices = []
+             total_completion_tokens = 0
+             for i in 1:n_completions
+                 text = generate_streaming(MODEL, ENCODE_FN, DECODE_FN, VOCAB_SIZE, BLOCK_SIZE;
+                     prompt=prompt_text, max_tokens=max_tokens,
+                     temperature=temperature, top_k=top_k_val, top_p=top_p_val)
+                 finish_reason = length(text) >= max_tokens ? "length" : "stop"
+                 push!(choices, Dict(
+                     "index" => i - 1,
+                     "message" => Dict("role" => "assistant", "content" => text),
+                     "finish_reason" => finish_reason))
+                 total_completion_tokens += length(text)
+             end
+
+             prompt_tokens = length(ENCODE_FN(prompt_text))
+             return json_response(200, Dict(
+                 "id" => "chatcmpl-" * string(uuid4()),
+                 "object" => "chat.completion",
+                 "created" => Int(floor(time())),
+                 "model" => "juliagptdistill-philosophy",
+                 "choices" => choices,
+                 "usage" => Dict(
+                     "prompt_tokens" => prompt_tokens,
+                     "completion_tokens" => total_completion_tokens,
+                     "total_tokens" => prompt_tokens + total_completion_tokens),
+                 "system_fingerprint" => "juliagptdistill-flux-v1"))
+         end
+     end
+
+     # 404 fallback
+     return json_response(404, Dict("error" => Dict(
+         "message" => "Not found: $method $target",
+         "type" => "invalid_request_error",
+         "code" => "not_found")))
+ end
+
+ # ═══════════════════════════════════════════════════════════════════
+ # Start server
+ # ═══════════════════════════════════════════════════════════════════
+
+ println("\nJuliaGPTDistill server starting on 0.0.0.0:$PORT ...")
+ println("  GET  http://localhost:$PORT/")
+ println("  GET  http://localhost:$PORT/v1/models")
+ println("  POST http://localhost:$PORT/v1/chat/completions")
+ println("  POST http://localhost:$PORT/v1/chat/completions (stream=true)")
+ println()
+
+ HTTP.serve(handle_request, "0.0.0.0", PORT)
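Once the server is running locally (e.g. `julia --project server.jl`), the streaming path can be exercised with HTTP.jl as well. Because the SSE response is buffered server-side, the whole event stream arrives in a single body; a minimal sketch that reassembles the streamed deltas (localhost and port 7860 match the defaults above):

```julia
using HTTP, JSON3

payload = JSON3.write(Dict(
    "messages"   => [Dict("role" => "user", "content" => "the nature of")],
    "max_tokens" => 50,
    "stream"     => true))

resp = HTTP.post("http://localhost:7860/v1/chat/completions",
                 ["Content-Type" => "application/json"], payload)

# Each SSE event is a "data: {json}" line; [DONE] terminates the stream.
for line in split(String(resp.body), "\n")
    startswith(line, "data: ") || continue
    chunk_json = strip(line[length("data: ")+1:end])
    chunk_json == "[DONE]" && break
    chunk = JSON3.read(chunk_json)
    delta = chunk.choices[1].delta
    haskey(delta, :content) && print(delta.content)
end
println()
```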