LisaMegaWatts committed on
Commit 6f2e71d · verified · 1 parent: b85c7dc

Restore Julia-native server (replace Python/FastAPI with Flux.jl + HTTP.jl)

Files changed (8)
  1. Dockerfile +31 -6
  2. Project.toml +7 -0
  3. README.md +39 -30
  4. checkpoint.jl +222 -0
  5. model.jl +290 -0
  6. requirements.txt +0 -7
  7. server.jl +312 -0
  8. server.py +0 -708
Dockerfile CHANGED

```diff
@@ -1,10 +1,35 @@
-FROM python:3.11-slim
+FROM julia:1.10-bookworm
+
+# HuggingFace Spaces requires user ID 1000
 RUN useradd -m -u 1000 user
-WORKDIR /home/user/app
-COPY --chown=user requirements.txt .
-RUN pip install --no-cache-dir -r requirements.txt
-COPY --chown=user server.py .
+
+# Shared Julia depot for package caching
+ENV JULIA_DEPOT_PATH=/opt/julia-depot
+RUN mkdir -p /opt/julia-depot && chmod 777 /opt/julia-depot
+
+# Copy project file first for dependency caching
+COPY --chown=user Project.toml /home/user/app/
+
+# Install and precompile Julia packages
+RUN julia --project=/home/user/app -e ' \
+    using Pkg; \
+    Pkg.instantiate(); \
+    Pkg.precompile(); \
+    println("Precompile done")'
+
+# Copy application code
+COPY --chown=user model.jl /home/user/app/
+COPY --chown=user checkpoint.jl /home/user/app/
+COPY --chown=user server.jl /home/user/app/
+
+# Create checkpoints directory (model downloads from HF at runtime)
+RUN mkdir -p /home/user/app/checkpoints && chown user:user /home/user/app/checkpoints
+
+# Switch to non-root user
 USER user
 ENV HOME=/home/user
+WORKDIR /home/user/app
+
 EXPOSE 7860
+
+CMD ["julia", "--project=/home/user/app", "/home/user/app/server.jl"]
-CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "7860"]
```
Project.toml ADDED

```diff
@@ -0,0 +1,7 @@
+[deps]
+Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
+JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
+JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
+NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
```
README.md CHANGED

````diff
@@ -1,51 +1,60 @@
 ---
 title: JuliaFluxGPT
-emoji: 🏛️
-colorFrom: purple
-colorTo: red
+emoji: "\U0001F9E0"
+colorFrom: blue
+colorTo: purple
 sdk: docker
+app_port: 7860
 pinned: false
 license: mit
-short_description: LLaMA-style GPT in Flux.jl — philosophy text generation
-app_port: 7860
+tags:
+  - julia
+  - flux-jl
+  - llama-style
+  - rope
+  - swiglu
+  - gqa
+  - rmsnorm
+  - bpe
+  - philosophy
+  - openai-compatible
 ---
 
 # JuliaFluxGPT
 
-A LLaMA-style small language model built in Flux.jl, trained on classical philosophy and mathematics texts with a pre-punctuation 28-character vocabulary.
+A LLaMA-style decoder-only model (RoPE, GQA, RMSNorm, SwiGLU, weight-tied) trained on classical philosophy and mathematics texts, implemented in Julia with Flux.jl. Serves an OpenAI-compatible API with streaming support.
 
-**100% Julia — no Python dependencies.**
+## Endpoints
 
-## API
+- `GET /` — Health check and model info
+- `GET /v1/models` — List available models
+- `POST /v1/chat/completions` — Generate text (supports streaming, top-k, top-p)
 
-OpenAI-compatible inference endpoint:
+## Usage
 
 ```bash
-curl -X POST https://lisamegawatts-juliafluxgpt.hf.space/v1/chat/completions \
-  -H "Content-Type: application/json" \
-  -d '{"messages":[{"role":"user","content":"the nature of"}],"temperature":0.8,"max_tokens":200}'
+# Non-streaming
+curl -X POST https://your-space.hf.space/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{"messages": [{"role": "user", "content": "the nature of"}], "max_tokens": 200}'
+
+# Streaming
+curl -X POST https://your-space.hf.space/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{"messages": [{"role": "user", "content": "the nature of"}], "stream": true, "temperature": 0.7, "top_k": 40}'
 ```
 
-### Endpoints
-
-| Method | Path | Description |
-|--------|------|-------------|
-| GET | `/` | Health check + API info |
-| GET | `/v1/models` | List available models |
-| POST | `/v1/chat/completions` | Generate text (OpenAI format) |
-
 ## Architecture
 
-- LLaMA-style decoder-only transformer
-- RoPE (Rotary Positional Embeddings)
-- SwiGLU feed-forward blocks
-- GQA (Grouped Query Attention)
-- RMSNorm (pre-norm)
-- Weight tying (embedding = output projection)
-- BPE tokenizer with character-level fallback
+- **Model**: ~10M params, 512d embed, 8 layers, 8Q/2KV heads (GQA)
+- **Sequence mixing**: Grouped Query Attention + RoPE
+- **Tokenizer**: BPE (2000 tokens)
+- **Framework**: Flux.jl
+- **Normalization**: RMSNorm (pre-norm)
+- **Feed-forward**: SwiGLU activation
+- **Weight tying**: Shared embedding/output projection
 
-## Links
+## Environment Variables
 
-- [Training data](https://huggingface.co/datasets/LisaMegaWatts/philosophy-corpus)
-- [Source code](https://github.com/DavinciDreams/JuliaGPT)
-- [JuliaGPT (autograd version)](https://huggingface.co/spaces/LisaMegaWatts/JuliaGPT)
+- `HF_REPO` — HuggingFace model repo (default: `LisaMegaWatts/JuliaFluxGPT`)
+- `PORT` — Server port (default: `7860`)
````
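The streaming mode the README describes emits standard OpenAI-style server-sent events: one `data: {...}` line per chunk, terminated by `data: [DONE]`. A minimal Python sketch of a client-side parser for that format (illustrative only — `parse_sse_deltas` is a hypothetical helper, not part of this repo):

```python
import json

def parse_sse_deltas(raw: str) -> str:
    """Collect `delta.content` pieces from an OpenAI-style SSE stream body.

    Each event is a line of the form `data: {...}`; the stream ends with
    the sentinel `data: [DONE]`.
    """
    out = []
    for line in raw.splitlines():
        line = line.strip()
        if not line.startswith("data:"):
            continue
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":
            break
        chunk = json.loads(payload)
        delta = chunk["choices"][0]["delta"]
        out.append(delta.get("content", ""))
    return "".join(out)

# Simulated stream body in the shape server.jl produces.
sample = (
    'data: {"choices": [{"delta": {"role": "assistant", "content": ""}}]}\n\n'
    'data: {"choices": [{"delta": {"content": "the nature"}}]}\n\n'
    'data: {"choices": [{"delta": {"content": " of being"}}]}\n\n'
    'data: [DONE]\n\n'
)
print(parse_sse_deltas(sample))  # → the nature of being
```

In practice a client would feed `iter_lines()` from a chunked HTTP response into the same loop instead of a pre-collected string.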
checkpoint.jl ADDED

@@ -0,0 +1,222 @@

```julia
#=
checkpoint.jl — Load Flux model checkpoints for JuliaFluxGPT

Loads JLD2 checkpoints saved by the juliaflux_v2 training notebook.
Supports BPE tokenizer (tokenizer.json format) with character-level fallback.

NOTE: The GPT struct no longer has TiedDense — weight tying is done in the
forward pass. This simplifies checkpoint loading: we load all components
normally and skip any lm_head key in the checkpoint (it's redundant since
the output projection uses wte.weight directly).
=#

include("model.jl")
using JLD2
using JSON3

# ═══════════════════════════════════════════════════════════════════════════════
# BPE Tokenizer (loaded from tokenizer.json — HuggingFace format)
# ═══════════════════════════════════════════════════════════════════════════════

struct BPETokenizer
    vocab::Dict{String, Int}
    id_to_token::Dict{Int, String}
    merges::Vector{Tuple{String, String}}
    merge_rank::Dict{Tuple{String, String}, Int}
    byte_to_unicode::Dict{UInt8, String}
    unicode_to_byte::Dict{Char, UInt8}
    word_cache::Dict{String, Vector{Int}}
    gpt2_pattern::Regex
end

function build_byte_to_unicode()
    bs = UInt8[]
    cs = Char[]
    for r in [0x21:0x7e, 0xa1:0xac, 0xae:0xff]
        for b in r
            push!(bs, b)
            push!(cs, Char(b))
        end
    end
    n = 0
    for b in 0x00:0xff
        if b ∉ bs
            push!(bs, b)
            push!(cs, Char(256 + n))
            n += 1
        end
    end
    b2u = Dict(bs[i] => string(cs[i]) for i in eachindex(bs))
    u2b = Dict(v[1] => k for (k, v) in b2u)
    return b2u, u2b
end

function load_bpe_tokenizer(path::String)
    tok_json = JSON3.read(read(path, String))

    vocab = Dict{String, Int}()
    for (tok_str, id) in pairs(tok_json.model.vocab)
        vocab[string(tok_str)] = Int(id) + 1  # +1 for Julia 1-indexing
    end

    merges = Tuple{String, String}[]
    for merge_entry in tok_json.model.merges
        if merge_entry isa AbstractVector && length(merge_entry) >= 2
            push!(merges, (String(merge_entry[1]), String(merge_entry[2])))
        else
            parts = split(string(merge_entry), " ", limit=2)
            if length(parts) == 2
                push!(merges, (String(parts[1]), String(parts[2])))
            end
        end
    end

    id_to_token = Dict{Int, String}(id => tok for (tok, id) in vocab)
    merge_rank = Dict{Tuple{String, String}, Int}(
        (a, b) => i for (i, (a, b)) in enumerate(merges)
    )
    b2u, u2b = build_byte_to_unicode()
    gpt2_pat = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"

    BPETokenizer(vocab, id_to_token, merges, merge_rank, b2u, u2b,
                 Dict{String, Vector{Int}}(), gpt2_pat)
end

function bpe_encode_word(tok::BPETokenizer, word::Vector{String})
    tokens = copy(word)
    while length(tokens) >= 2
        best_rank = typemax(Int)
        best_pair = ("", "")
        for i in 1:length(tokens)-1
            rank = get(tok.merge_rank, (tokens[i], tokens[i+1]), typemax(Int))
            if rank < best_rank
                best_rank = rank
                best_pair = (tokens[i], tokens[i+1])
            end
        end
        best_rank == typemax(Int) && break
        a, b = best_pair
        new_tokens = String[]
        i = 1
        while i <= length(tokens)
            if i < length(tokens) && tokens[i] == a && tokens[i+1] == b
                push!(new_tokens, a * b)
                i += 2
            else
                push!(new_tokens, tokens[i])
                i += 1
            end
        end
        tokens = new_tokens
    end
    return tokens
end

function encode_bpe(tok::BPETokenizer, s::String)
    ids = Int[]
    for m in eachmatch(tok.gpt2_pattern, s)
        word = m.match
        cached = get(tok.word_cache, word, nothing)
        if cached !== nothing
            append!(ids, cached)
        else
            word_bytes = Vector{UInt8}(word)
            chars = [tok.byte_to_unicode[b] for b in word_bytes]
            tokens = bpe_encode_word(tok, chars)
            word_ids = Int[]
            for t in tokens
                id = get(tok.vocab, t, nothing)
                id !== nothing && push!(word_ids, id)
            end
            tok.word_cache[word] = word_ids
            append!(ids, word_ids)
        end
    end
    return ids
end

function decode_bpe(tok::BPETokenizer, ids::Vector{Int})
    text = join(get(tok.id_to_token, id, "") for id in ids)
    bytes = UInt8[tok.unicode_to_byte[c] for c in text if haskey(tok.unicode_to_byte, c)]
    return String(bytes)
end

# ═══════════════════════════════════════════════════════════════════════════════
# Checkpoint loading
# ═══════════════════════════════════════════════════════════════════════════════

function load_flux_checkpoint(checkpoint_path::String; tokenizer_path::String="")
    println("Loading checkpoint from $checkpoint_path ...")
    data = JLD2.load(checkpoint_path)

    hp = data["hyperparams"]
    vocab_size = Int(hp["vocab_size"])
    n_embd = Int(hp["n_embd"])
    block_size = Int(hp["block_size"])
    n_layer = Int(hp["n_layer"])
    n_head = Int(hp["n_head"])
    n_kv_head = Int(get(hp, "n_kv_head", hp["n_head"]))
    dropout_val = Float64(get(hp, "dropout", 0.0))

    model = GPT(;
        vocab_size = vocab_size,
        n_embd = n_embd,
        block_size = block_size,
        n_layer = n_layer,
        n_head = n_head,
        n_kv_head = n_kv_head,
        dropout = 0.0  # No dropout at inference
    )

    # Load weights component-by-component
    ms = data["model_state"]
    Flux.loadmodel!(model.wte, ms[:wte])
    Flux.loadmodel!(model.drop, ms[:drop])
    Flux.loadmodel!(model.blocks, ms[:blocks])
    Flux.loadmodel!(model.ln_f, ms[:ln_f])

    # Set to test mode (disables dropout)
    Flux.testmode!(model)

    step = get(data, "step", 0)
    best_val = get(data, "best_val_loss", Inf)

    println("  Model loaded: vocab=$vocab_size, embd=$n_embd, layers=$n_layer, " *
            "heads=$(n_head)Q/$(n_kv_head)KV, block=$block_size")
    println("  Step=$step, best_val=$(round(best_val, digits=4))")

    # Load tokenizer
    encode_fn = nothing
    decode_fn = nothing

    if !isempty(tokenizer_path) && isfile(tokenizer_path)
        println("  Loading BPE tokenizer from $tokenizer_path")
        bpe = load_bpe_tokenizer(tokenizer_path)
        tok_vocab_size = length(bpe.vocab)

        if tok_vocab_size != vocab_size
            @warn "Vocab mismatch! Model expects vocab_size=$vocab_size but tokenizer has $tok_vocab_size tokens. " *
                  "Token IDs above $vocab_size will be clamped."
        end

        encode_fn = function(s)
            ids = encode_bpe(bpe, s)
            return [clamp(id, 1, vocab_size) for id in ids]
        end
        decode_fn = ids -> decode_bpe(bpe, ids)
        println("  BPE tokenizer loaded: $(tok_vocab_size) tokens (model vocab: $vocab_size)")
    else
        # Character-level fallback
        chars = vcat(collect('a':'z'), [' ', '.'])
        stoi = Dict(c => i for (i, c) in enumerate(chars))
        itos = Dict(i => c for (i, c) in enumerate(chars))
        encode_fn = s -> [get(stoi, c, 1) for c in s]
        decode_fn = ids -> join(get(itos, id, '?') for id in ids)
        println("  No tokenizer.json found, using character-level fallback ($(length(chars)) chars)")
    end

    return (;
        model, vocab_size, n_embd, block_size, n_layer, n_head, n_kv_head,
        step, best_val, encode_fn, decode_fn
    )
end
```
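The merge loop in `bpe_encode_word` above is standard greedy BPE: at each step, find the adjacent token pair with the lowest merge rank and fuse it, until no ranked pair remains. A Python re-sketch of that loop (illustrative only — not code from this repo):

```python
def bpe_merge(tokens, merge_rank):
    """Greedy BPE: repeatedly merge the adjacent pair with the lowest rank."""
    tokens = list(tokens)
    while len(tokens) >= 2:
        # Find the adjacent pair with the best (lowest) merge rank.
        pairs = list(zip(tokens, tokens[1:]))
        best = min(pairs, key=lambda p: merge_rank.get(p, float("inf")))
        if best not in merge_rank:
            break  # no mergeable pair left
        merged, i = [], 0
        while i < len(tokens):
            if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == best:
                merged.append(tokens[i] + tokens[i + 1])
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        tokens = merged
    return tokens

# Ranks as learned during training: lower rank = earlier, more frequent merge.
ranks = {("t", "h"): 0, ("th", "e"): 1}
print(bpe_merge(["t", "h", "e"], ranks))  # → ['the']
```

The Julia version additionally caches results per regex-split "word" (`word_cache`), which amortizes this quadratic loop across repeated words.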
model.jl ADDED

@@ -0,0 +1,290 @@

```julia
#=
model.jl — LLaMA-style GPT model in Flux.jl for JuliaFluxGPT

Contains: RMSNorm, SwiGLU, CausalSelfAttention (GQA + RoPE),
TransformerBlock, GPT, and generation utilities.

Same architecture as juliaflux_v2.ipynb — extracted for inference serving.

NOTE: Weight tying is done by computing the output projection directly using
m.wte.weight in the forward pass. This matches the training notebooks and
ensures Flux.loadmodel! works without needing to skip lm_head.
=#

using Flux
using NNlib
using NNlib: batched_mul
using Statistics
using Random
using LinearAlgebra

# ═══════════════════════════════════════════════════════════════════════════════
# RoPE — Rotary Positional Embeddings
# ═══════════════════════════════════════════════════════════════════════════════

function precompute_rope_freqs(head_dim::Int, max_seq_len::Int; base::Float32 = 10000.0f0)
    half_dim = head_dim ÷ 2
    freqs = Float32[1.0f0 / (base ^ (Float32(2 * (i - 1)) / Float32(head_dim))) for i in 1:half_dim]
    positions = Float32.(collect(0:max_seq_len-1))
    angles = freqs * positions'
    return cos.(angles), sin.(angles)
end

function apply_rope(x, cos_f, sin_f, T::Int)
    d = size(x, 1) ÷ 2
    x1 = x[1:d, :, :]
    x2 = x[d+1:2d, :, :]
    c = cos_f[:, 1:T]
    s = sin_f[:, 1:T]
    return vcat(x1 .* c .- x2 .* s, x1 .* s .+ x2 .* c)
end

# ═══════════════════════════════════════════════════════════════════════════════
# Model components
# ═══════════════════════════════════════════════════════════════════════════════

struct RMSNorm{W <: AbstractVector}
    weight::W
    eps::Float32
end

Flux.@layer RMSNorm

RMSNorm(dim::Int; eps::Float32 = 1.0f-6) = RMSNorm(ones(Float32, dim), eps)

function (rn::RMSNorm)(x)
    rms = sqrt.(mean(x .^ 2, dims=1) .+ rn.eps)
    return (x ./ rms) .* rn.weight
end

struct SwiGLUFFN
    w_gate::Dense
    w_up::Dense
    w_down::Dense
    drop::Dropout
end

Flux.@layer SwiGLUFFN

function SwiGLUFFN(n_embd::Int; bias=false, dropout=0.0)
    raw_inner = Int(floor(4 * n_embd * 2 / 3))
    inner_dim = max(64, 64 * div(raw_inner + 32, 64))
    SwiGLUFFN(
        Dense(n_embd => inner_dim; bias),
        Dense(n_embd => inner_dim; bias),
        Dense(inner_dim => n_embd; bias),
        Dropout(dropout)
    )
end

function (ff::SwiGLUFFN)(x)
    ff.drop(ff.w_down(NNlib.swish(ff.w_gate(x)) .* ff.w_up(x)))
end

struct CausalSelfAttention
    wq::Dense
    wkv::Dense
    proj::Dense
    n_head::Int
    n_kv_head::Int
end

Flux.@layer CausalSelfAttention trainable=(wq, wkv, proj)

function CausalSelfAttention(n_embd::Int, n_head::Int, n_kv_head::Int; bias=false)
    head_dim = n_embd ÷ n_head
    kv_dim = head_dim * n_kv_head
    CausalSelfAttention(
        Dense(n_embd => n_embd; bias),
        Dense(n_embd => 2 * kv_dim; bias),
        Dense(n_embd => n_embd; bias),
        n_head,
        n_kv_head
    )
end

function (attn::CausalSelfAttention)(x, causal_mask, rope_cos, rope_sin)
    C, T, B = size(x)
    nh = attn.n_head
    nkv = attn.n_kv_head
    hs = C ÷ nh
    kv_dim = hs * nkv
    groups = nh ÷ nkv

    q = attn.wq(x)
    kv = attn.wkv(x)
    k = kv[1:kv_dim, :, :]
    v = kv[kv_dim+1:2*kv_dim, :, :]

    q = reshape(permutedims(reshape(q, hs, nh, T, B), (1, 3, 2, 4)), hs, T, nh * B)
    k = reshape(permutedims(reshape(k, hs, nkv, T, B), (1, 3, 2, 4)), hs, T, nkv * B)
    v = reshape(permutedims(reshape(v, hs, nkv, T, B), (1, 3, 2, 4)), hs, T, nkv * B)

    q = apply_rope(q, rope_cos, rope_sin, T)
    k = apply_rope(k, rope_cos, rope_sin, T)

    if groups > 1
        k_4d = reshape(k, hs, T, nkv, B)
        v_4d = reshape(v, hs, T, nkv, B)
        k_rep = repeat(reshape(k_4d, hs, T, nkv, 1, B), 1, 1, 1, groups, 1)
        v_rep = repeat(reshape(v_4d, hs, T, nkv, 1, B), 1, 1, 1, groups, 1)
        k = reshape(permutedims(k_rep, (1, 2, 4, 3, 5)), hs, T, nh * B)
        v = reshape(permutedims(v_rep, (1, 2, 4, 3, 5)), hs, T, nh * B)
    end

    scale = Float32(1 / sqrt(hs))
    wei = batched_mul(permutedims(q, (2, 1, 3)), k) .* scale
    wei = wei .+ causal_mask[1:T, 1:T]
    wei = softmax(wei; dims=2)

    out = batched_mul(v, permutedims(wei, (2, 1, 3)))
    out = reshape(permutedims(reshape(out, hs, T, nh, B), (1, 3, 2, 4)), C, T, B)

    attn.proj(out)
end

struct TransformerBlock
    ln1::RMSNorm
    attn::CausalSelfAttention
    ln2::RMSNorm
    ffwd::SwiGLUFFN
end

Flux.@layer TransformerBlock

function TransformerBlock(n_embd::Int, n_head::Int, n_kv_head::Int; dropout=0.0)
    TransformerBlock(
        RMSNorm(n_embd),
        CausalSelfAttention(n_embd, n_head, n_kv_head),
        RMSNorm(n_embd),
        SwiGLUFFN(n_embd; dropout)
    )
end

# ═══════════════════════════════════════════════════════════════════════════════
# GPT — weight-tied output projection (matches training notebooks)
# ═══════════════════════════════════════════════════════════════════════════════

struct GPT
    wte::Embedding
    drop::Dropout
    blocks::Chain
    ln_f::RMSNorm
    # Precomputed constants (not trainable)
    causal_mask::Matrix{Float32}
    rope_cos::Matrix{Float32}
    rope_sin::Matrix{Float32}
    n_head::Int
    n_kv_head::Int
    block_size::Int
end

Flux.@layer GPT trainable=(wte, drop, blocks, ln_f)

function GPT(; vocab_size, n_embd, block_size, n_layer, n_head, n_kv_head, dropout=0.0)
    head_dim = n_embd ÷ n_head
    wte = Embedding(vocab_size => n_embd)
    causal_mask = triu(fill(typemin(Float32), block_size, block_size), 1)
    rope_cos, rope_sin = precompute_rope_freqs(head_dim, block_size)
    GPT(
        wte,
        Dropout(dropout),
        Chain([TransformerBlock(n_embd, n_head, n_kv_head; dropout) for _ in 1:n_layer]...),
        RMSNorm(n_embd),
        causal_mask,
        rope_cos,
        rope_sin,
        n_head,
        n_kv_head,
        block_size
    )
end

function (m::GPT)(idx)
    B, T = size(idx)
    tok = permutedims(m.wte(idx), (1, 3, 2))  # (C, T, B)
    x = m.drop(tok)
    for block in m.blocks
        x = x .+ block.attn(block.ln1(x), m.causal_mask, m.rope_cos, m.rope_sin)
        x = x .+ block.ffwd(block.ln2(x))
    end
    x = m.ln_f(x)
    # Weight-tied output projection — same weight as embedding
    W = m.wte.weight
    C = size(x, 1)
    x_flat = reshape(x, C, T * B)
    out = W' * x_flat
    reshape(out, size(W, 2), T, B)
end

# ═══════════════════════════════════════════════════════════════════════════════
# Text generation with streaming support
# ═══════════════════════════════════════════════════════════════════════════════

function generate_streaming(model, encode_fn, decode_fn, vocab_size::Int, block_size::Int;
                            prompt::String="", max_tokens::Int=200, temperature::Float64=0.8,
                            top_k::Int=40, top_p::Float64=1.0, on_token=nothing)
    if !isempty(prompt)
        prompt_ids = encode_fn(prompt)
        idx = reshape(prompt_ids, 1, :)
    else
        idx = reshape([rand(1:vocab_size)], 1, 1)
    end

    generated_ids = Int[]

    for _ in 1:max_tokens
        idx_cond = idx[:, max(1, end-block_size+1):end]
        logits = model(idx_cond)
        logits_last = Vector{Float32}(logits[:, end, 1])

        # Temperature scaling
        logits_last ./= Float32(max(temperature, 0.01))

        # Top-k filtering
        if top_k > 0 && top_k < length(logits_last)
            threshold = partialsort(logits_last, top_k; rev=true)
            for i in eachindex(logits_last)
                if logits_last[i] < threshold
                    logits_last[i] = -Inf32
                end
            end
        end

        # Top-p (nucleus) filtering
        if top_p < 1.0
            sorted_indices = sortperm(logits_last; rev=true)
            sorted_logits = logits_last[sorted_indices]
            probs_sorted = NNlib.softmax(sorted_logits)
            cumprobs = cumsum(Array(probs_sorted))
            cutoff = something(findfirst(>=(Float32(top_p)), cumprobs), length(probs_sorted))
            for i in (cutoff+1):length(sorted_indices)
                logits_last[sorted_indices[i]] = -Inf32
            end
        end

        probs = NNlib.softmax(logits_last)
        probs_cpu = Float64.(probs)

        r = rand()
        cum = 0.0
        next_id = length(probs_cpu)
        for (i, p) in enumerate(probs_cpu)
            cum += p
            if r <= cum
                next_id = i
                break
            end
        end

        push!(generated_ids, next_id)
        idx = hcat(idx, reshape([next_id], 1, 1))

        if on_token !== nothing
            token_str = decode_fn([next_id])
            on_token(token_str)
        end
    end

    return decode_fn(generated_ids)
end
```
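`generate_streaming` above applies temperature, then top-k, then top-p (nucleus) filtering before sampling from the renormalized softmax. A pure-Python sketch of the same filter pipeline (illustrative only — not code from this repo):

```python
import math

def filter_logits(logits, top_k=0, top_p=1.0):
    """Mask logits outside top-k / nucleus top-p to -inf, mirroring the Julia loop."""
    logits = list(logits)
    if 0 < top_k < len(logits):
        # Keep everything at or above the k-th largest logit (ties survive).
        threshold = sorted(logits, reverse=True)[top_k - 1]
        logits = [x if x >= threshold else -math.inf for x in logits]
    if top_p < 1.0:
        # Keep the smallest prefix (by descending probability) whose
        # cumulative softmax mass reaches top_p.
        order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
        total = sum(math.exp(logits[i]) for i in order if logits[i] != -math.inf)
        cum, keep = 0.0, set()
        for i in order:
            if logits[i] == -math.inf:
                break
            keep.add(i)
            cum += math.exp(logits[i]) / total
            if cum >= top_p:
                break
        logits = [x if i in keep else -math.inf for i, x in enumerate(logits)]
    return logits

print(filter_logits([3.0, 2.0, 1.0, 0.0], top_k=2))  # → [3.0, 2.0, -inf, -inf]
```

After filtering, both versions renormalize with softmax and draw via inverse-CDF sampling (`r = rand(); walk the cumulative sum`), so masked entries have exactly zero probability.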
requirements.txt DELETED

```diff
@@ -1,7 +0,0 @@
-fastapi>=0.110.0
-uvicorn>=0.29.0
-torch>=2.0.0
-h5py>=3.10.0
-huggingface_hub>=0.20.0
-pydantic>=2.0.0
-tokenizers>=0.15.0
```
server.jl ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #=
2
+ server.jl — OpenAI-compatible inference server for JuliaFluxGPT
3
+
4
+ Serves a Flux.jl trained LLaMA-style GPT model (RoPE, GQA, RMSNorm, SwiGLU).
5
+ Downloads checkpoint and tokenizer from HuggingFace model repo on first run.
6
+
7
+ Endpoints:
8
+ GET / -> health check / API info
9
+ GET /v1/models -> list available models
10
+ POST /v1/chat/completions -> generate text (OpenAI format, streaming supported)
11
+ =#
12
+
13
+ include("checkpoint.jl")
14
+ using HTTP
15
+ using UUIDs
16
+ using Downloads
17
+
18
+ # ═══════════════════════════════════════════════════════════════════
19
+ # Download artifacts from HuggingFace
20
+ # ═══════════════════════════════════════════════════════════════════
21
+
22
+ const CKPT_DIR = "checkpoints"
23
+ const CKPT_PATH = joinpath(CKPT_DIR, "best_model.jld2")
24
+ const TOKENIZER_PATH = joinpath(CKPT_DIR, "tokenizer.json")
25
+ const HF_REPO = get(ENV, "HF_REPO", "LisaMegaWatts/JuliaFluxGPT")
26
+ const PORT = parse(Int, get(ENV, "PORT", "7860"))
27
+
28
+ function download_from_hf(repo::String, filename::String, local_path::String)
29
+ url = "https://huggingface.co/$repo/resolve/main/$filename"
30
+ println("Downloading $url ...")
31
+ mkpath(dirname(local_path))
32
+ Downloads.download(url, local_path)
33
+ sz = round(filesize(local_path) / 1024^2, digits=1)
34
+ println(" -> $local_path ($sz MB)")
35
+ end
36
+
37
+ function ensure_artifacts()
38
+ for (localpath, remote) in [(CKPT_PATH, "best_model.jld2"),
39
+ (TOKENIZER_PATH, "tokenizer.json")]
40
+ if !isfile(localpath)
41
+ println("No local $remote found, downloading from $HF_REPO ...")
42
+ try
43
+ download_from_hf(HF_REPO, remote, localpath)
44
+ catch e
45
+ println("Download failed for $remote: $e")
46
+ println("Place $remote at $localpath manually.")
47
+ exit(1)
48
+ end
49
+ end
50
+ end
51
+ end
52
+
53
+ # ═══════════════════════════════════════════════════════════════════
54
+ # Download and load model
55
+ # ═══════════════════════════════════════════════════════════════════
56
+
57
+ ensure_artifacts()
58
+
59
+ println("\nLoading model...")
60
+ const CKPT = load_flux_checkpoint(CKPT_PATH; tokenizer_path=TOKENIZER_PATH)
61
+ const MODEL = CKPT.model
62
+ const VOCAB_SIZE = CKPT.vocab_size
63
+ const BLOCK_SIZE = CKPT.block_size
64
+ const ENCODE_FN = CKPT.encode_fn
65
+ const DECODE_FN = CKPT.decode_fn
66
+ const MODEL_CREATED_AT = Int(floor(time()))
67
+
68
+ println("\nModel ready: vocab=$(VOCAB_SIZE), embd=$(CKPT.n_embd), " *
69
+ "layers=$(CKPT.n_layer), heads=$(CKPT.n_head)Q/$(CKPT.n_kv_head)KV, " *
70
+ "block=$(BLOCK_SIZE)")
71
+
72
+ # ═══════════════════════════════════════════════════════════════════
73
+ # HTTP helpers
74
+ # ═══════════════════════════════════════════════════════════════════
75
+
76
+ const CORS_HEADERS = [
77
+ "Access-Control-Allow-Origin" => "*",
78
+ "Access-Control-Allow-Methods" => "GET, POST, OPTIONS",
79
+ "Access-Control-Allow-Headers" => "Content-Type, Authorization",
80
+ ]
81
+
82
+ function json_response(status::Int, body; extra_headers=[])
83
+ json_bytes = JSON3.write(body)
84
+ headers = [
85
+ "Content-Type" => "application/json",
86
+ CORS_HEADERS...,
87
+ extra_headers...
88
+ ]
89
+ return HTTP.Response(status, headers, json_bytes)
90
+ end
91
+
92
+ function cors_preflight()
93
+ return HTTP.Response(204, CORS_HEADERS)
94
+ end
95
+
96
+ # ═══════════════════════════════════════════════════════════════════
97
+ # Extract prompt from OpenAI chat messages
98
+ # ═══════════════════════════════════════════════════════════════════
99
+
100
+ function extract_prompt(messages)
101
+ if isempty(messages)
102
+ return ""
103
+ end
104
+ for i in length(messages):-1:1
105
+ role = string(get(messages[i], :role, ""))
106
+ if role == "user"
107
+ return string(get(messages[i], :content, ""))
108
+ end
109
+ end
110
+ return string(get(messages[end], :content, ""))
111
+ end
112
+
113
+ # ═════════════════════════════════════════════════════���═════════════
114
+ # SSE helpers
115
+ # ═══════════════════════════════════════════════════════════════════
116
+
117
+ function sse_line(data)
118
+ return "data: $(JSON3.write(data))\n\n"
119
+ end
120
+
121
+ # ═══════════════════════════════════════════════════════════════════
122
+ # Request handler
123
+ # ═══════════════════════════════════════════════════════════════════
124
+
125
+ function handle_request(request::HTTP.Request)
126
+ method = request.method
127
+ target = request.target
128
+
129
+ # CORS preflight
130
+ if method == "OPTIONS"
131
+ return cors_preflight()
132
+ end
133
+
134
+ # GET / — health check and model info
135
+ if method == "GET" && target == "/"
136
+ return json_response(200, Dict(
137
+ "name" => "JuliaFluxGPT",
138
+ "version" => "1.0.0",
139
+ "description" => "LLaMA-style GPT in Flux.jl — trained on philosophy and mathematics",
140
+ "architecture" => "RoPE + SwiGLU + GQA + RMSNorm + weight tying",
141
+ "model" => Dict(
142
+ "vocab_size" => VOCAB_SIZE,
143
+ "n_embd" => CKPT.n_embd,
144
+ "n_layer" => CKPT.n_layer,
145
+ "n_head" => CKPT.n_head,
146
+ "n_kv_head" => CKPT.n_kv_head,
147
+ "block_size" => BLOCK_SIZE
148
+ ),
149
+ "endpoints" => ["/v1/models", "/v1/chat/completions"],
150
+ "features" => ["streaming", "OpenAI-compatible", "top-k", "top-p"],
151
+ "compatible_with" => ["OpenAI API", "OpenRouter"]
152
+ ))
153
+ end
154
+
155
+ # GET /v1/models — list available models
156
+ if method == "GET" && target == "/v1/models"
157
+ return json_response(200, Dict(
158
+ "object" => "list",
159
+ "data" => [Dict(
160
+ "id" => "juliafluxgpt-philosophy",
161
+ "object" => "model",
162
+ "created" => MODEL_CREATED_AT,
163
+ "owned_by" => "juliafluxgpt"
164
+ )]
165
+ ))
166
+ end
167
+
168
+ # POST /v1/chat/completions — generate text
169
+ if method == "POST" && target == "/v1/chat/completions"
170
+ local body
171
+ try
172
+ body = JSON3.read(String(request.body))
173
+ catch e
174
+ return json_response(400, Dict("error" => Dict(
175
+ "message" => "Invalid JSON in request body",
176
+ "type" => "invalid_request_error",
177
+ "code" => "invalid_json")))
178
+ end
179
+
180
+ temperature = Float64(clamp(get(body, :temperature, 0.8), 0.01, 2.0))
181
+ max_tokens = Int(clamp(get(body, :max_tokens, 200), 1, BLOCK_SIZE))
182
+ top_k_val = Int(clamp(get(body, :top_k, 40), 0, VOCAB_SIZE))
183
+ top_p_val = Float64(clamp(get(body, :top_p, 1.0), 0.0, 1.0))
184
+ stream = Bool(get(body, :stream, false))
185
+
186
+ messages = get(body, :messages, [])
187
+ prompt_text = extract_prompt(messages)
188
+
189
+ if stream
190
+ # ── SSE streaming response (buffered) ──
191
+ completion_id = "chatcmpl-" * string(uuid4())
192
+ created = Int(floor(time()))
193
+
194
+ buf = IOBuffer()
195
+
196
+ # Initial chunk with role
197
+ initial_chunk = Dict(
198
+ "id" => completion_id,
199
+ "object" => "chat.completion.chunk",
200
+ "created" => created,
201
+ "model" => "juliafluxgpt-philosophy",
202
+ "choices" => [Dict(
203
+ "index" => 0,
204
+ "delta" => Dict("role" => "assistant", "content" => ""),
205
+ "finish_reason" => nothing
206
+ )]
207
+ )
208
+ write(buf, sse_line(initial_chunk))
209
+
210
+ token_count = Ref(0)
211
+
212
+ generate_streaming(MODEL, ENCODE_FN, DECODE_FN, VOCAB_SIZE, BLOCK_SIZE;
213
+ prompt=prompt_text, max_tokens=max_tokens,
214
+ temperature=temperature, top_k=top_k_val, top_p=top_p_val,
215
+ on_token = function(token_str)
216
+ token_count[] += 1
217
+ chunk = Dict(
218
+ "id" => completion_id,
219
+ "object" => "chat.completion.chunk",
220
+ "created" => created,
221
+ "model" => "juliafluxgpt-philosophy",
222
+ "choices" => [Dict(
223
+ "index" => 0,
224
+ "delta" => Dict("content" => token_str),
225
+ "finish_reason" => nothing
226
+ )]
227
+ )
228
+ write(buf, sse_line(chunk))
229
+ end)
230
+
231
+ # Final chunk with finish_reason
232
+ prompt_tokens = length(ENCODE_FN(prompt_text))
233
+ finish_chunk = Dict(
234
+ "id" => completion_id,
235
+ "object" => "chat.completion.chunk",
236
+ "created" => created,
237
+ "model" => "juliafluxgpt-philosophy",
238
+ "choices" => [Dict(
239
+ "index" => 0,
240
+ "delta" => Dict(),
241
+ "finish_reason" => token_count[] >= max_tokens ? "length" : "stop"
242
+ )],
243
+ "usage" => Dict(
244
+ "prompt_tokens" => prompt_tokens,
245
+ "completion_tokens" => token_count[],
246
+ "total_tokens" => prompt_tokens + token_count[]
247
+ )
248
+ )
249
+ write(buf, sse_line(finish_chunk))
250
+ write(buf, "data: [DONE]\n\n")
251
+
252
+ sse_body = take!(buf)
253
+ headers = [
254
+ "Content-Type" => "text/event-stream",
255
+ "Cache-Control" => "no-cache",
256
+ "X-Accel-Buffering" => "no",
257
+ CORS_HEADERS...
258
+ ]
259
+ return HTTP.Response(200, headers, sse_body)
260
+
261
+ else
262
+ # ── Standard (non-streaming) response ──
263
+ n_completions = Int(clamp(get(body, :n, 1), 1, 4))
264
+
265
+ choices = []
266
+ total_completion_tokens = 0
267
+ for i in 1:n_completions
268
+ text = generate_streaming(MODEL, ENCODE_FN, DECODE_FN, VOCAB_SIZE, BLOCK_SIZE;
269
+ prompt=prompt_text, max_tokens=max_tokens,
270
+ temperature=temperature, top_k=top_k_val, top_p=top_p_val)
271
+ completion_tokens = length(ENCODE_FN(text))  # approximate token count via re-encoding; length(text) counts characters
+ finish_reason = completion_tokens >= max_tokens ? "length" : "stop"
272
+ push!(choices, Dict(
273
+ "index" => i - 1,
274
+ "message" => Dict("role" => "assistant", "content" => text),
275
+ "finish_reason" => finish_reason))
276
+ total_completion_tokens += length(ENCODE_FN(text))  # count tokens, not characters, for the usage block
277
+ end
278
+
279
+ prompt_tokens = length(ENCODE_FN(prompt_text))
280
+ return json_response(200, Dict(
281
+ "id" => "chatcmpl-" * string(uuid4()),
282
+ "object" => "chat.completion",
283
+ "created" => Int(floor(time())),
284
+ "model" => "juliafluxgpt-philosophy",
285
+ "choices" => choices,
286
+ "usage" => Dict(
287
+ "prompt_tokens" => prompt_tokens,
288
+ "completion_tokens" => total_completion_tokens,
289
+ "total_tokens" => prompt_tokens + total_completion_tokens),
290
+ "system_fingerprint" => "juliafluxgpt-flux-v1"))
291
+ end
292
+ end
293
+
294
+ # 404 fallback
295
+ return json_response(404, Dict("error" => Dict(
296
+ "message" => "Not found: $method $target",
297
+ "type" => "invalid_request_error",
298
+ "code" => "not_found")))
299
+ end
300
+
301
+ # ═══════════════════════════════════════════════════════════════════
302
+ # Start server
303
+ # ═══════════════════════════════════════════════════════════════════
304
+
305
+ println("\nJuliaFluxGPT server starting on 0.0.0.0:$PORT ...")
306
+ println(" GET http://localhost:$PORT/")
307
+ println(" GET http://localhost:$PORT/v1/models")
308
+ println(" POST http://localhost:$PORT/v1/chat/completions")
309
+ println(" POST http://localhost:$PORT/v1/chat/completions (stream=true)")
310
+ println()
311
+
312
+ HTTP.serve(handle_request, "0.0.0.0", PORT)
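The streaming branch above emits OpenAI-style `chat.completion.chunk` lines framed as SSE `data:` records, ending with a `data: [DONE]` sentinel. A minimal client-side sketch shows how such a stream reassembles into the final assistant text; the sample payload below is illustrative (hand-written to match the chunk shape built in `server.jl`), not captured server output:

```python
import json

def collect_sse_content(raw: str) -> str:
    """Reassemble assistant text from an SSE stream of chat.completion.chunk records."""
    parts = []
    for line in raw.splitlines():
        if not line.startswith("data: "):
            continue                      # skip blank separator lines
        payload = line[len("data: "):]
        if payload == "[DONE]":
            break                         # end-of-stream sentinel
        chunk = json.loads(payload)
        delta = chunk["choices"][0]["delta"]
        parts.append(delta.get("content", ""))  # final chunk has an empty delta
    return "".join(parts)

# Illustrative two-token stream (shape matches the chunks built above)
sample = (
    'data: {"choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}\n\n'
    'data: {"choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}\n\n'
    'data: {"choices":[{"index":0,"delta":{"content":" world"},"finish_reason":null}]}\n\n'
    'data: [DONE]\n\n'
)
print(collect_sse_content(sample))  # → Hello world
```

Because the Julia server buffers the whole SSE body before responding, a client written this way still works unchanged; it simply receives all chunks at once.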
server.py DELETED
@@ -1,708 +0,0 @@
1
- """
2
- server.py — OpenAI-compatible FastAPI inference server for JuliaFluxGPT
3
-
4
- Endpoints:
5
- GET / -> model info
6
- GET /v1/models -> list available models
7
- POST /v1/chat/completions -> generate text (streaming via SSE)
8
-
9
- Weights are loaded from HuggingFace Hub at startup:
10
- repo: LisaMegaWatts/JuliaFluxGPT
11
- files: best_model.jld2, tokenizer.json
12
-
13
- Architecture: LLaMA-style GPT
14
- - RMSNorm (weight only, no bias)
15
- - RoPE (Rotary Positional Embeddings, base=10000)
16
- - GQA (Grouped Query Attention, 8 query heads / 2 KV heads)
17
- - SwiGLU FFN
18
- - Weight-tied output projection (lm_head shares wte weights)
19
- """
20
-
21
- from __future__ import annotations
22
-
23
- import json
24
- import math
25
- import os
26
- import time
27
- import uuid
28
- from typing import List, Optional
29
-
30
- import h5py
31
- import numpy as np
32
- import torch
33
- import torch.nn as nn
34
- import torch.nn.functional as F
35
- import uvicorn
36
- from fastapi import FastAPI, HTTPException, Request
37
- from fastapi.middleware.cors import CORSMiddleware
38
- from fastapi.responses import JSONResponse, StreamingResponse
39
- from fastapi.exceptions import RequestValidationError
40
- from huggingface_hub import hf_hub_download
41
- from pydantic import BaseModel
42
- from tokenizers import Tokenizer
43
-
44
- # ---------------------------------------------------------------------------
45
- # Hyperparameters (must match training checkpoint)
46
- # ---------------------------------------------------------------------------
47
-
48
- VOCAB_SIZE = 2000
49
- N_EMBD = 512
50
- N_HEAD = 8
51
- N_KV_HEAD = 2
52
- N_LAYER = 8
53
- BLOCK_SIZE = 256
54
- ROPE_BASE = 10000.0
55
- RMS_EPS = 1e-6
56
-
57
- MODEL_ID = "juliafluxgpt-philosophy"
58
- HF_REPO = "LisaMegaWatts/JuliaFluxGPT"
59
- HF_WEIGHTS = "best_model.jld2"
60
- HF_TOKENIZER = "tokenizer.json"
61
-
62
- DEVICE = torch.device("cpu") # HF Spaces free tier = CPU only
63
-
64
- # ---------------------------------------------------------------------------
65
- # RoPE helpers
66
- # ---------------------------------------------------------------------------
67
-
68
- def precompute_rope(head_dim: int, max_seq_len: int, base: float = 10000.0):
69
- """
70
- Returns (cos, sin) each of shape (max_seq_len, head_dim // 2).
71
- Sliced to actual sequence length in apply_rope.
72
- """
73
- half = head_dim // 2
74
- freqs = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
75
- positions = torch.arange(max_seq_len).float()
76
- angles = positions.unsqueeze(1) * freqs.unsqueeze(0) # (T, half)
77
- return torch.cos(angles), torch.sin(angles) # (T, half)
78
-
79
-
80
- def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
81
- """
82
- x : (B, n_head, T, head_dim)
83
- cos : (T, head_dim // 2)
84
- sin : (T, head_dim // 2)
85
- """
86
- T = x.shape[2]
87
- cos = cos[:T].unsqueeze(0).unsqueeze(0) # (1, 1, T, half)
88
- sin = sin[:T].unsqueeze(0).unsqueeze(0)
89
- d = x.shape[-1] // 2
90
- x1, x2 = x[..., :d], x[..., d:]
91
- return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
92
-
93
-
94
- # ---------------------------------------------------------------------------
95
- # Model components
96
- # ---------------------------------------------------------------------------
97
-
98
- class RMSNorm(nn.Module):
99
- def __init__(self, dim: int, eps: float = 1e-6):
100
- super().__init__()
101
- self.eps = eps
102
- self.weight = nn.Parameter(torch.ones(dim))
103
-
104
- def forward(self, x: torch.Tensor) -> torch.Tensor:
105
- # x: (B, T, C)
106
- rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
107
- return x / rms * self.weight
108
-
109
-
110
- class GQAttention(nn.Module):
111
- """
112
- Grouped Query Attention.
113
- wq : (n_embd, n_embd) — query projection
114
- wkv : (n_embd, 2 * kv_dim) — combined K+V projection
115
- proj: (n_embd, n_embd) — output projection
116
- """
117
-
118
- def __init__(self, n_embd: int, n_head: int, n_kv_head: int):
119
- super().__init__()
120
- assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
121
- assert n_head % n_kv_head == 0, "n_head must be divisible by n_kv_head"
122
- self.n_head = n_head
123
- self.n_kv_head = n_kv_head
124
- self.head_dim = n_embd // n_head
125
- kv_dim = self.head_dim * n_kv_head
126
-
127
- self.wq = nn.Linear(n_embd, n_embd, bias=False)
128
- self.wkv = nn.Linear(n_embd, 2 * kv_dim, bias=False)
129
- self.proj = nn.Linear(n_embd, n_embd, bias=False)
130
-
131
- def forward(
132
- self,
133
- x: torch.Tensor,
134
- rope_cos: torch.Tensor,
135
- rope_sin: torch.Tensor,
136
- ) -> torch.Tensor:
137
- B, T, C = x.shape
138
- nh, nkv, hd = self.n_head, self.n_kv_head, self.head_dim
139
- groups = nh // nkv
140
-
141
- # Project
142
- q = self.wq(x) # (B, T, n_embd)
143
- kv = self.wkv(x) # (B, T, 2*kv_dim)
144
- k, v = kv.split(hd * nkv, dim=-1) # each (B, T, kv_dim)
145
-
146
- # Reshape to (B, heads, T, head_dim)
147
- q = q.view(B, T, nh, hd).transpose(1, 2) # (B, nh, T, hd)
148
- k = k.view(B, T, nkv, hd).transpose(1, 2) # (B, nkv, T, hd)
149
- v = v.view(B, T, nkv, hd).transpose(1, 2) # (B, nkv, T, hd)
150
-
151
- # Apply RoPE to queries and keys
152
- q = apply_rope(q, rope_cos, rope_sin)
153
- k = apply_rope(k, rope_cos, rope_sin)
154
-
155
- # Expand KV heads to match query heads (GQA)
156
- if groups > 1:
157
- k = k.repeat_interleave(groups, dim=1) # (B, nh, T, hd)
158
- v = v.repeat_interleave(groups, dim=1)
159
-
160
- # Scaled dot-product attention with causal mask
161
- scale = math.sqrt(hd)
162
- attn = torch.matmul(q, k.transpose(-2, -1)) / scale # (B, nh, T, T)
163
-
164
- # Causal mask: upper triangle = -inf
165
- causal = torch.triu(
166
- torch.full((T, T), float("-inf"), device=x.device, dtype=x.dtype),
167
- diagonal=1,
168
- )
169
- attn = attn + causal
170
- attn = F.softmax(attn, dim=-1)
171
-
172
- # Weighted sum and reshape
173
- out = torch.matmul(attn, v) # (B, nh, T, hd)
174
- out = out.transpose(1, 2).contiguous().view(B, T, C)
175
-
176
- return self.proj(out)
177
-
178
-
179
- class SwiGLUFFN(nn.Module):
180
- """
181
- SwiGLU feed-forward network.
182
- forward: w_down(swish(w_gate(x)) * w_up(x))
183
- """
184
-
185
- def __init__(self, n_embd: int, inner_dim: int):
186
- super().__init__()
187
- self.w_gate = nn.Linear(n_embd, inner_dim, bias=False)
188
- self.w_up = nn.Linear(n_embd, inner_dim, bias=False)
189
- self.w_down = nn.Linear(inner_dim, n_embd, bias=False)
190
-
191
- def forward(self, x: torch.Tensor) -> torch.Tensor:
192
- return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))
193
-
194
-
195
- class Block(nn.Module):
196
- def __init__(self, n_embd: int, n_head: int, n_kv_head: int, inner_dim: int):
197
- super().__init__()
198
- self.ln1 = RMSNorm(n_embd)
199
- self.attn = GQAttention(n_embd, n_head, n_kv_head)
200
- self.ln2 = RMSNorm(n_embd)
201
- self.ffwd = SwiGLUFFN(n_embd, inner_dim)
202
-
203
- def forward(
204
- self,
205
- x: torch.Tensor,
206
- rope_cos: torch.Tensor,
207
- rope_sin: torch.Tensor,
208
- ) -> torch.Tensor:
209
- x = x + self.attn(self.ln1(x), rope_cos, rope_sin)
210
- x = x + self.ffwd(self.ln2(x))
211
- return x
212
-
213
-
214
- class GPT(nn.Module):
215
- def __init__(
216
- self,
217
- vocab_size: int,
218
- n_embd: int,
219
- n_head: int,
220
- n_kv_head: int,
221
- n_layer: int,
222
- block_size: int,
223
- inner_dim: int,
224
- rope_base: float = 10000.0,
225
- ):
226
- super().__init__()
227
- self.block_size = block_size
228
- head_dim = n_embd // n_head
229
-
230
- self.wte = nn.Embedding(vocab_size, n_embd)
231
- self.blocks = nn.ModuleList(
232
- [Block(n_embd, n_head, n_kv_head, inner_dim) for _ in range(n_layer)]
233
- )
234
- self.ln_f = RMSNorm(n_embd)
235
-
236
- # lm_head shares weights with wte (weight tying)
237
- self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
238
- self.lm_head.weight = self.wte.weight
239
-
240
- # Precompute RoPE frequencies — registered as buffers (not parameters)
241
- rope_cos, rope_sin = precompute_rope(head_dim, block_size, rope_base)
242
- self.register_buffer("rope_cos", rope_cos) # (block_size, head_dim//2)
243
- self.register_buffer("rope_sin", rope_sin)
244
-
245
- def forward(self, idx: torch.Tensor) -> torch.Tensor:
246
- """
247
- idx: (B, T) of token ids
248
- Returns logits: (B, T, vocab_size)
249
- """
250
- B, T = idx.shape
251
- assert T <= self.block_size, f"Sequence length {T} > block_size {self.block_size}"
252
-
253
- x = self.wte(idx) # (B, T, n_embd)
254
- cos = self.rope_cos[:T] # (T, head_dim//2)
255
- sin = self.rope_sin[:T]
256
-
257
- for block in self.blocks:
258
- x = block(x, cos, sin)
259
-
260
- x = self.ln_f(x)
261
- return self.lm_head(x) # (B, T, vocab_size)
262
-
263
-
264
- # ---------------------------------------------------------------------------
265
- # JLD2 weight loader
266
- # ---------------------------------------------------------------------------
267
-
268
- def _deref(f: h5py.File, ref):
269
- """Dereference an HDF5 object reference."""
270
- obj = f[ref]
271
- return obj[()] if isinstance(obj, h5py.Dataset) else obj
272
-
273
-
274
- def _get_weight(f: h5py.File, struct, *path):
275
- """
276
- Walk a numpy.void struct following *path, dereferencing HDF5 references
277
- at each step, and return the final value as a numpy array.
278
- """
279
- val = struct
280
- for p in path:
281
- val = val[p]
282
- if isinstance(val, h5py.h5r.Reference):
283
- val = _deref(f, val)
284
- if isinstance(val, np.ndarray):
285
- return val
286
- return np.array(f[val])
287
-
288
-
289
- def load_weights_from_jld2(path: str, model: GPT) -> None:
290
- """
291
- Read weights from a JLD2 (HDF5) file produced by Julia's Flux.jl and
292
- copy them into the PyTorch GPT model.
293
-
294
- Julia is column-major. h5py reads in row-major order, so:
295
- - Embedding (2-D, vocab x embd): h5py gives (vocab, embd) -> use as-is
296
- - Dense/Linear (2-D, in x out): h5py gives (in, out) -> transpose to (out, in)
297
- - 1-D vectors (RMSNorm weight): no transpose needed
298
- """
299
- print(f"Loading JLD2 weights from {path} ...")
300
- with h5py.File(path, "r") as f:
301
-
302
- ms = f["model_state"][()]
303
-
304
- # ── top-level embedding ──────────────────────────────────────────────
305
- wte_w = _get_weight(f, ms, "wte", "weight") # h5py: (vocab, embd)
306
- # No transpose: Julia Embedding stores (embd, vocab) internally,
307
- # but HDF5 row-major flip already gives us (vocab, embd) which is
308
- # what PyTorch Embedding expects.
309
- model.wte.weight.data.copy_(
310
- torch.from_numpy(wte_w.copy()).float()
311
- )
312
-
313
- # ── final layer norm ─────────────────────────────────────────────────
314
- ln_f_w = _get_weight(f, ms, "ln_f", "weight") # (embd,)
315
- model.ln_f.weight.data.copy_(
316
- torch.from_numpy(ln_f_w.copy()).float()
317
- )
318
-
319
- # ── transformer blocks ───────────────────────────────────────────────
320
- blocks_ref = ms["blocks"]
321
- if isinstance(blocks_ref, h5py.h5r.Reference):
322
- blocks_ref = _deref(f, blocks_ref)
323
- layers_ref = blocks_ref["layers"]
324
- if isinstance(layers_ref, h5py.h5r.Reference):
325
- layers_ref = _deref(f, layers_ref)
326
-
327
- for layer_idx, block in enumerate(model.blocks):
328
- # Julia layers are 1-indexed
329
- jl_key = str(layer_idx + 1)
330
- l = layers_ref[jl_key]
331
-
332
- def gw(*path):
333
- return _get_weight(f, l, *path)
334
-
335
- # Attention weights — h5py gives (in, out), need (out, in)
336
- wq_np = gw("attn", "wq", "weight") # (512, 512)
337
- wkv_np = gw("attn", "wkv", "weight") # (512, 256)
338
- proj_np = gw("attn", "proj", "weight") # (512, 512)
339
-
340
- block.attn.wq.weight.data.copy_(
341
- torch.from_numpy(wq_np.T.copy()).float()
342
- )
343
- block.attn.wkv.weight.data.copy_(
344
- torch.from_numpy(wkv_np.T.copy()).float()
345
- )
346
- block.attn.proj.weight.data.copy_(
347
- torch.from_numpy(proj_np.T.copy()).float()
348
- )
349
-
350
- # FFN weights — h5py gives (in, out), need (out, in)
351
- w_gate_np = gw("ffwd", "w_gate", "weight") # (512, 1344)
352
- w_up_np = gw("ffwd", "w_up", "weight") # (512, 1344)
353
- w_down_np = gw("ffwd", "w_down", "weight") # (1344, 512)
354
-
355
- block.ffwd.w_gate.weight.data.copy_(
356
- torch.from_numpy(w_gate_np.T.copy()).float()
357
- )
358
- block.ffwd.w_up.weight.data.copy_(
359
- torch.from_numpy(w_up_np.T.copy()).float()
360
- )
361
- block.ffwd.w_down.weight.data.copy_(
362
- torch.from_numpy(w_down_np.T.copy()).float()
363
- )
364
-
365
- # Layer norms — 1-D, no transpose
366
- ln1_np = gw("ln1", "weight") # (512,)
367
- ln2_np = gw("ln2", "weight") # (512,)
368
- block.ln1.weight.data.copy_(
369
- torch.from_numpy(ln1_np.copy()).float()
370
- )
371
- block.ln2.weight.data.copy_(
372
- torch.from_numpy(ln2_np.copy()).float()
373
- )
374
-
375
- # Weight tying: lm_head must share wte's storage
376
- model.lm_head.weight = model.wte.weight
377
- print("Weights loaded successfully.")
378
-
379
-
380
- # ---------------------------------------------------------------------------
381
- # Sampling helpers
382
- # ---------------------------------------------------------------------------
383
-
384
- @torch.inference_mode()
385
- def _sample_next_token(
386
- logits: torch.Tensor, # (vocab_size,) on CPU
387
- temperature: float,
388
- top_k: int,
389
- seen_ids: list[int],
390
- repetition_penalty: float,
391
- ) -> int:
392
- """
393
- Apply repetition penalty, temperature scaling, top-k filtering, then sample.
394
- """
395
- logits = logits.clone().float()
396
-
397
- # Repetition penalty
398
- if repetition_penalty != 1.0 and seen_ids:
399
- for tok_id in set(seen_ids):
400
- if logits[tok_id] > 0:
401
- logits[tok_id] /= repetition_penalty
402
- else:
403
- logits[tok_id] *= repetition_penalty
404
-
405
- # Temperature
406
- logits = logits / max(temperature, 1e-6)
407
-
408
- # Top-k
409
- if 0 < top_k < logits.size(0):
410
- topk_vals, _ = torch.topk(logits, top_k)
411
- threshold = topk_vals[-1]
412
- logits[logits < threshold] = float("-inf")
413
-
414
- probs = F.softmax(logits, dim=-1)
415
- next_id = torch.multinomial(probs, num_samples=1).item()
416
- return int(next_id)
417
-
418
-
419
- # ---------------------------------------------------------------------------
420
- # Model initialisation at module level
421
- # ---------------------------------------------------------------------------
422
-
423
- # Compute inner_dim to match Julia's SwiGLUFFN sizing:
424
- # raw_inner = floor(4 * n_embd * 2 / 3) = floor(4*512*2/3) = 1365
425
- # inner_dim = max(64, 64 * div(raw_inner + 32, 64))
426
- # = max(64, 64 * div(1397, 64))
427
- # = max(64, 64 * 21) = 1344
428
- _raw_inner = int(math.floor(4 * N_EMBD * 2 / 3))
429
- _INNER_DIM = max(64, 64 * ((_raw_inner + 32) // 64)) # 1344
430
-
431
- print(f"Building GPT model (n_layer={N_LAYER}, n_embd={N_EMBD}, "
432
- f"n_head={N_HEAD}, n_kv_head={N_KV_HEAD}, inner_dim={_INNER_DIM}) ...")
433
-
434
- MODEL = GPT(
435
- vocab_size=VOCAB_SIZE,
436
- n_embd=N_EMBD,
437
- n_head=N_HEAD,
438
- n_kv_head=N_KV_HEAD,
439
- n_layer=N_LAYER,
440
- block_size=BLOCK_SIZE,
441
- inner_dim=_INNER_DIM,
442
- rope_base=ROPE_BASE,
443
- ).to(DEVICE)
444
-
445
- # Download and load weights from HuggingFace Hub
446
- print(f"Downloading weights from {HF_REPO} ...")
447
- _weights_path = hf_hub_download(repo_id=HF_REPO, filename=HF_WEIGHTS)
448
- print(f"Downloading tokenizer from {HF_REPO} ...")
449
- _tokenizer_path = hf_hub_download(repo_id=HF_REPO, filename=HF_TOKENIZER)
450
-
451
- load_weights_from_jld2(_weights_path, MODEL)
452
- MODEL.eval()
453
-
454
- # Load tokenizer
455
- TOKENIZER: Tokenizer = Tokenizer.from_file(_tokenizer_path)
456
- print("Tokenizer loaded.")
457
-
458
- MODEL_CREATED_AT = int(time.time())
459
- print(f"JuliaFluxGPT ready on device={DEVICE}.")
460
-
461
-
462
- # ---------------------------------------------------------------------------
463
- # Token-by-token generator
464
- # ---------------------------------------------------------------------------
465
-
466
- @torch.inference_mode()
467
- def generate_stream(
468
- prompt: str,
469
- max_tokens: int = 200,
470
- temperature: float = 0.1,
471
- top_k: int = 8,
472
- repetition_penalty: float = 1.3,
473
- ):
474
- """
475
- Yields (token_text: str, is_last: bool) one token at a time.
476
- Uses a sliding window of BLOCK_SIZE tokens.
477
- """
478
- # Encode prompt; if empty start with a random token
479
- if prompt.strip():
480
- input_ids = TOKENIZER.encode(prompt).ids
481
- else:
482
- input_ids = [int(torch.randint(VOCAB_SIZE, (1,)).item())]
483
-
484
- context: list[int] = list(input_ids)
485
- generated: list[int] = []
486
-
487
- for step in range(max_tokens):
488
- # Sliding window
489
- window = context[-BLOCK_SIZE:]
490
- idx = torch.tensor([window], dtype=torch.long, device=DEVICE) # (1, T)
491
-
492
- logits = MODEL(idx) # (1, T, vocab_size)
493
- next_logits = logits[0, -1, :].cpu() # (vocab_size,)
494
-
495
- # Build seen window for repetition penalty (last 64 tokens)
496
- seen = context[max(0, len(context) - 64):]
497
-
498
- next_id = _sample_next_token(
499
- next_logits,
500
- temperature=temperature,
501
- top_k=top_k,
502
- seen_ids=seen,
503
- repetition_penalty=repetition_penalty,
504
- )
505
-
506
- generated.append(next_id)
507
- context.append(next_id)
508
-
509
- token_text = TOKENIZER.decode([next_id])
510
- is_last = (step == max_tokens - 1)
511
- yield token_text, is_last
512
-
513
-
514
- # ---------------------------------------------------------------------------
515
- # Pydantic request / response models
516
- # ---------------------------------------------------------------------------
517
-
518
- class Message(BaseModel):
519
- role: str
520
- content: str
521
-
522
-
523
- class ChatRequest(BaseModel):
524
- model: Optional[str] = MODEL_ID
525
- messages: List[Message]
526
- temperature: Optional[float] = 0.8
527
- max_tokens: Optional[int] = 200
528
- top_k: Optional[int] = 40
529
- repetition_penalty: Optional[float] = 1.3
530
- stream: Optional[bool] = False
531
- n: Optional[int] = 1
532
-
533
-
534
- # ---------------------------------------------------------------------------
535
- # FastAPI application
536
- # ---------------------------------------------------------------------------
537
-
538
- app = FastAPI(title="JuliaFluxGPT", version="1.0.0")
539
-
540
- app.add_middleware(
541
- CORSMiddleware,
542
- allow_origins=["*"],
543
- allow_methods=["*"],
544
- allow_headers=["*"],
545
- )
546
-
547
-
548
- def _openai_error(status: int, message: str, err_type: str = "invalid_request_error", code: str = None):
549
- body = {"error": {"message": message, "type": err_type}}
550
- if code:
551
- body["error"]["code"] = code
552
- return JSONResponse(status_code=status, content=body)
553
-
554
-
555
- @app.exception_handler(HTTPException)
556
- async def http_exception_handler(request: Request, exc: HTTPException):
557
- return _openai_error(exc.status_code, str(exc.detail))
558
-
559
-
560
- @app.exception_handler(RequestValidationError)
561
- async def validation_exception_handler(request: Request, exc: RequestValidationError):
562
- msg = "; ".join(f"{e['loc'][-1]}: {e['msg']}" for e in exc.errors())
563
- return _openai_error(422, msg, code="invalid_request_error")
564
-
565
-
566
- # ── GET / ────────────────────────────────────────────────────────────────────
567
-
568
- @app.get("/")
569
- def root():
570
- return {
571
- "name": "JuliaFluxGPT",
572
- "version": "1.0.0",
573
- "description": "LLaMA-style GPT in Flux.jl — trained on philosophy and mathematics",
574
- "architecture": "RoPE + SwiGLU + GQA + RMSNorm + weight tying",
575
- "hyperparams": {
576
- "vocab_size": VOCAB_SIZE,
577
- "n_embd": N_EMBD,
578
- "n_head": N_HEAD,
579
- "n_kv_head": N_KV_HEAD,
580
- "n_layer": N_LAYER,
581
- "block_size": BLOCK_SIZE,
582
- },
583
- "endpoints": ["/v1/models", "/v1/chat/completions"],
584
- "compatible_with": ["OpenAI API", "OpenRouter"],
585
- }
586
-
587
-
588
- # ── GET /v1/models ───────────────────────────────────────────────────────────
589
-
590
- @app.get("/v1/models")
591
- def list_models():
592
- return {
593
- "object": "list",
594
- "data": [
595
- {
596
- "id": MODEL_ID,
597
- "object": "model",
598
- "created": MODEL_CREATED_AT,
599
- "owned_by": "juliafluxgpt",
600
- }
601
- ],
602
- }
603
-
604
-
605
- # ── POST /v1/chat/completions ─────────────────────────────────────────────────
606
-
607
- def _sse(data: dict) -> str:
608
- return f"data: {json.dumps(data)}\n\n"
609
-
610
-
611
- def _stream_completion(prompt, max_tokens, temperature, top_k, rep_penalty, completion_id):
612
- """Synchronous generator that yields SSE chunks one token at a time."""
613
- token_count = 0
614
- for token_text, is_last in generate_stream(
615
- prompt=prompt,
616
- max_tokens=max_tokens,
617
- temperature=temperature,
618
- top_k=top_k,
619
- repetition_penalty=rep_penalty,
620
- ):
621
- token_count += 1
622
- finish_reason = ("length" if token_count >= max_tokens else "stop") if is_last else None
623
- yield _sse({
624
- "id": completion_id,
625
- "object": "chat.completion.chunk",
626
- "created": int(time.time()),
627
- "model": MODEL_ID,
628
- "choices": [{
629
- "index": 0,
630
- "delta": {"content": token_text},
631
- "finish_reason": finish_reason,
632
- }],
633
- })
634
-
635
- yield "data: [DONE]\n\n"
636
-
637
-
638
- @app.post("/v1/chat/completions")
639
- def chat_completions(request: ChatRequest):
640
- # Extract prompt from the last user message
641
- prompt = request.messages[-1].content.strip() if request.messages else ""
642
- if not prompt:
643
- raise HTTPException(status_code=400, detail="No content in messages")
644
-
645
- max_tokens = max(1, min(request.max_tokens or 200, BLOCK_SIZE))
646
- temperature = max(0.01, min(request.temperature or 0.8, 2.0))
647
- top_k = max(1, min(request.top_k or 40, VOCAB_SIZE))
648
- rep_penalty = max(1.0, min(request.repetition_penalty or 1.3, 3.0))
649
- n = max(1, min(request.n or 1, 4))
650
- completion_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"
651
-
652
- # ── Streaming ────────────────────────────────────────────────────────────
653
- if request.stream:
654
- return StreamingResponse(
655
- _stream_completion(
656
- prompt, max_tokens, temperature,
657
- top_k, rep_penalty, completion_id,
658
- ),
659
- media_type="text/event-stream",
660
- headers={"X-Accel-Buffering": "no"},
661
- )
662
-
663
- # ── Non-streaming (generate all n completions) ────────────────────────────
664
- choices = []
665
- total_completion_tokens = 0
666
-
667
- for i in range(n):
668
- tokens = list(
669
- generate_stream(
670
- prompt=prompt,
671
- max_tokens=max_tokens,
672
- temperature=temperature,
673
- top_k=top_k,
674
- repetition_penalty=rep_penalty,
675
- )
676
- )
677
- content = "".join(t for t, _ in tokens)
678
- total_completion_tokens += len(tokens)
679
- choices.append({
680
- "index": i,
681
- "message": {"role": "assistant", "content": content},
682
- "finish_reason": "length" if len(tokens) >= max_tokens else "stop",
683
- })
684
-
685
- prompt_tokens = len(TOKENIZER.encode(prompt).ids) if prompt else 0
686
-
687
- return {
688
- "id": completion_id,
689
- "object": "chat.completion",
690
- "created": int(time.time()),
691
- "model": MODEL_ID,
692
- "system_fingerprint": "juliafluxgpt-v1",
693
- "choices": choices,
694
- "usage": {
695
- "prompt_tokens": prompt_tokens,
696
- "completion_tokens": total_completion_tokens,
697
- "total_tokens": prompt_tokens + total_completion_tokens,
698
- },
699
- }
700
-
701
-
702
- # ---------------------------------------------------------------------------
703
- # Entrypoint
704
- # ---------------------------------------------------------------------------
705
-
706
- if __name__ == "__main__":
707
- port = int(os.environ.get("PORT", 7860))
708
- uvicorn.run("server:app", host="0.0.0.0", port=port, reload=False)