#=
checkpoint.jl: Load Flux model checkpoints for JuliaFluxGPT
Loads JLD2 checkpoints saved by the juliaflux_v2 training notebook.
Supports BPE tokenizer (tokenizer.json format) with character-level fallback.
NOTE: The GPT struct no longer has TiedDense; weight tying is done in the
forward pass. This simplifies checkpoint loading: we load all components
normally and skip any lm_head key in the checkpoint (it's redundant since
the output projection uses wte.weight directly).
=#
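
# A minimal sketch of the weight tying described in the NOTE above. The actual
# forward pass lives in model.jl; the (n_embd, vocab_size) layout of the
# embedding matrix is an assumption about that file, and this helper is purely
# illustrative (nothing below calls it). With tying, logits are just the
# transposed embedding table applied to the hidden states, so the checkpoint
# needs no separate lm_head weights.
tied_logits_sketch(wte_weight::AbstractMatrix, h::AbstractMatrix) = wte_weight' * h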
include("model.jl")

using Flux    # loadmodel! / testmode! are called below (model.jl also uses Flux)
using JLD2
using JSON3
# ═══════════════════════════════════════════════════════════════════════════════
# BPE Tokenizer (loaded from tokenizer.json, HuggingFace format)
# ═══════════════════════════════════════════════════════════════════════════════
struct BPETokenizer
    vocab::Dict{String, Int}                      # token string => 1-based id
    id_to_token::Dict{Int, String}                # inverse of vocab
    merges::Vector{Tuple{String, String}}         # ordered BPE merge rules
    merge_rank::Dict{Tuple{String, String}, Int}  # merge pair => priority (lower merges first)
    byte_to_unicode::Dict{UInt8, String}          # GPT-2 byte-level encoding table
    unicode_to_byte::Dict{Char, UInt8}            # inverse table, used when decoding
    word_cache::Dict{String, Vector{Int}}         # memoised per-word encodings
    gpt2_pattern::Regex                           # GPT-2 pre-tokenisation regex
end
# Build the reversible GPT-2 byte-level mapping: every byte 0x00-0xff gets a
# printable stand-in character, so arbitrary bytes can appear as keys in the
# BPE vocab.
function build_byte_to_unicode()
    bs = UInt8[]
    cs = Char[]
    # Printable bytes map to themselves ...
    for r in [0x21:0x7e, 0xa1:0xac, 0xae:0xff]
        for b in r
            push!(bs, b)
            push!(cs, Char(b))
        end
    end
    # ... everything else is shifted into the U+0100+ range.
    n = 0
    for b in 0x00:0xff
        if b ∉ bs
            push!(bs, b)
            push!(cs, Char(256 + n))
            n += 1
        end
    end
    b2u = Dict(bs[i] => string(cs[i]) for i in eachindex(bs))
    u2b = Dict(v[1] => k for (k, v) in b2u)
    return b2u, u2b
end
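
# For instance, the space byte 0x20 is not in the printable ranges above, so it
# should map to the familiar GPT-2 stand-in 'Ġ' (U+0120):
#   b2u, _ = build_byte_to_unicode()
#   b2u[UInt8(' ')]   # => "Ġ"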
# Load a HuggingFace-format tokenizer.json into a BPETokenizer.
function load_bpe_tokenizer(path::String)
    tok_json = JSON3.read(read(path, String))
    vocab = Dict{String, Int}()
    for (tok_str, id) in pairs(tok_json.model.vocab)
        vocab[string(tok_str)] = Int(id) + 1    # +1 for Julia 1-indexing
    end
    # Merges may be serialised either as ["a", "b"] pairs or as "a b" strings,
    # depending on the tokenizers version; accept both.
    merges = Tuple{String, String}[]
    for merge_entry in tok_json.model.merges
        if merge_entry isa AbstractVector && length(merge_entry) >= 2
            push!(merges, (String(merge_entry[1]), String(merge_entry[2])))
        else
            parts = split(string(merge_entry), " ", limit=2)
            if length(parts) == 2
                push!(merges, (String(parts[1]), String(parts[2])))
            end
        end
    end
    id_to_token = Dict{Int, String}(id => tok for (tok, id) in vocab)
    merge_rank = Dict{Tuple{String, String}, Int}(
        (a, b) => i for (i, (a, b)) in enumerate(merges)
    )
    b2u, u2b = build_byte_to_unicode()
    gpt2_pat = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
    BPETokenizer(vocab, id_to_token, merges, merge_rank, b2u, u2b,
                 Dict{String, Vector{Int}}(), gpt2_pat)
end
# Apply BPE merges to one pre-tokenised word (given as a vector of
# single-character stand-in strings): repeatedly merge the adjacent pair with
# the lowest merge rank until no known pair remains.
function bpe_encode_word(tok::BPETokenizer, word::Vector{String})
    tokens = copy(word)
    while length(tokens) >= 2
        # Find the highest-priority (lowest-rank) adjacent pair.
        best_rank = typemax(Int)
        best_pair = ("", "")
        for i in 1:length(tokens)-1
            rank = get(tok.merge_rank, (tokens[i], tokens[i+1]), typemax(Int))
            if rank < best_rank
                best_rank = rank
                best_pair = (tokens[i], tokens[i+1])
            end
        end
        best_rank == typemax(Int) && break    # no mergeable pair left
        # Merge every occurrence of the chosen pair.
        a, b = best_pair
        new_tokens = String[]
        i = 1
        while i <= length(tokens)
            if i < length(tokens) && tokens[i] == a && tokens[i+1] == b
                push!(new_tokens, a * b)
                i += 2
            else
                push!(new_tokens, tokens[i])
                i += 1
            end
        end
        tokens = new_tokens
    end
    return tokens
end
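
# A self-contained sketch of the merge loop above, using a two-rule toy merge
# table (real tables come from tokenizer.json). Only the fields the loop reads
# are populated. ["l","o","w","e","r"] merges to ["lo","w","e","r"], then to
# ["low","e","r"], at which point no known pair remains.
function bpe_merge_demo()
    demo = BPETokenizer(Dict{String, Int}(), Dict{Int, String}(),
                        [("l", "o"), ("lo", "w")],
                        Dict(("l", "o") => 1, ("lo", "w") => 2),
                        Dict{UInt8, String}(), Dict{Char, UInt8}(),
                        Dict{String, Vector{Int}}(), r"")
    return bpe_encode_word(demo, ["l", "o", "w", "e", "r"])   # => ["low", "e", "r"]
end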
# Encode a string to 1-based token ids: split with the GPT-2 regex, map each
# chunk through the byte-level table, then apply BPE merges.
function encode_bpe(tok::BPETokenizer, s::String)
    ids = Int[]
    for m in eachmatch(tok.gpt2_pattern, s)
        word = String(m.match)    # materialise the SubString for byte access and as cache key
        cached = get(tok.word_cache, word, nothing)
        if cached !== nothing
            append!(ids, cached)
        else
            word_bytes = Vector{UInt8}(word)
            pieces = [tok.byte_to_unicode[b] for b in word_bytes]
            tokens = bpe_encode_word(tok, pieces)
            word_ids = Int[]
            for t in tokens
                id = get(tok.vocab, t, nothing)
                id !== nothing && push!(word_ids, id)    # unknown pieces are skipped
            end
            tok.word_cache[word] = word_ids
            append!(ids, word_ids)
        end
    end
    return ids
end
# Decode token ids back to text: concatenate the stand-in strings, then map
# each stand-in character back to its original byte. Unknown ids and characters
# outside the byte table are silently dropped.
function decode_bpe(tok::BPETokenizer, ids::Vector{Int})
    text = join(get(tok.id_to_token, id, "") for id in ids)
    bytes = UInt8[tok.unicode_to_byte[c] for c in text if haskey(tok.unicode_to_byte, c)]
    return String(bytes)
end
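
# A hedged usage sketch (the default file name is illustrative; use the
# tokenizer.json that shipped with the checkpoint). Wrapped in a function so
# nothing runs at include time; for plain ASCII input whose pieces are all in
# the vocab, the round trip should return the input unchanged.
function bpe_roundtrip_demo(path::String = "tokenizer.json",
                            s::String = "Hello from JuliaFluxGPT!")
    tok = load_bpe_tokenizer(path)
    ids = encode_bpe(tok, s)
    return decode_bpe(tok, ids)
end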
# ═══════════════════════════════════════════════════════════════════════════════
# Checkpoint loading
# ═══════════════════════════════════════════════════════════════════════════════
function load_flux_checkpoint(checkpoint_path::String; tokenizer_path::String="")
    println("Loading checkpoint from $checkpoint_path ...")
    data = JLD2.load(checkpoint_path)

    # Model hyperparameters saved by the training notebook.
    hp = data["hyperparams"]
    vocab_size = Int(hp["vocab_size"])
    n_embd     = Int(hp["n_embd"])
    block_size = Int(hp["block_size"])
    n_layer    = Int(hp["n_layer"])
    n_head     = Int(hp["n_head"])
    n_kv_head  = Int(get(hp, "n_kv_head", hp["n_head"]))

    model = GPT(;
        vocab_size = vocab_size,
        n_embd     = n_embd,
        block_size = block_size,
        n_layer    = n_layer,
        n_head     = n_head,
        n_kv_head  = n_kv_head,
        dropout    = 0.0    # no dropout at inference; the checkpoint's saved value is not used
    )
    # Load weights component-by-component; any lm_head entry in the checkpoint
    # is simply never read (the forward pass reuses wte.weight).
    ms = data["model_state"]
    Flux.loadmodel!(model.wte, ms[:wte])
    Flux.loadmodel!(model.drop, ms[:drop])
    Flux.loadmodel!(model.blocks, ms[:blocks])
    Flux.loadmodel!(model.ln_f, ms[:ln_f])
    Flux.testmode!(model)    # disable dropout for inference

    step = get(data, "step", 0)
    best_val = get(data, "best_val_loss", Inf)
    println(" Model loaded: vocab=$vocab_size, embd=$n_embd, layers=$n_layer, " *
            "heads=$(n_head)Q/$(n_kv_head)KV, block=$block_size")
    println(" Step=$step, best_val=$(round(best_val, digits=4))")
    # Tokenizer: prefer the BPE tokenizer.json; fall back to a tiny
    # character-level vocabulary if none is available.
    encode_fn = nothing
    decode_fn = nothing
    if !isempty(tokenizer_path) && isfile(tokenizer_path)
        println(" Loading BPE tokenizer from $tokenizer_path")
        bpe = load_bpe_tokenizer(tokenizer_path)
        tok_vocab_size = length(bpe.vocab)
        if tok_vocab_size != vocab_size
            @warn "Vocab mismatch! Model expects vocab_size=$vocab_size but tokenizer has $tok_vocab_size tokens. " *
                  "Token IDs above $vocab_size will be clamped."
        end
        encode_fn = function(s)
            ids = encode_bpe(bpe, s)
            return [clamp(id, 1, vocab_size) for id in ids]
        end
        decode_fn = ids -> decode_bpe(bpe, ids)
        println(" BPE tokenizer loaded: $(tok_vocab_size) tokens (model vocab: $vocab_size)")
    else
        # Character-level fallback: lowercase letters, space, and period.
        chars = vcat(collect('a':'z'), [' ', '.'])
        stoi = Dict(c => i for (i, c) in enumerate(chars))
        itos = Dict(i => c for (i, c) in enumerate(chars))
        encode_fn = s -> [get(stoi, c, 1) for c in s]
        decode_fn = ids -> join(get(itos, id, '?') for id in ids)
        println(" No tokenizer.json found, using character-level fallback ($(length(chars)) chars)")
    end
    return (;
        model, vocab_size, n_embd, block_size, n_layer, n_head, n_kv_head,
        step, best_val, encode_fn, decode_fn
    )
end
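
# Hedged usage sketch: the file names below are illustrative; point them at the
# artifacts produced by the juliaflux_v2 training notebook. Wrapped in a
# function so nothing runs at include time.
function load_checkpoint_demo(; ckpt::String = "checkpoint.jld2",
                                tok::String = "tokenizer.json")
    ck = load_flux_checkpoint(ckpt; tokenizer_path=tok)
    ids = ck.encode_fn("Once upon a time")
    println("Encoded prompt to $(length(ids)) tokens (block_size=$(ck.block_size))")
    return ck
end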