#=
checkpoint.jl - Load Flux model checkpoints for JuliaFluxGPT

Loads JLD2 checkpoints saved by the juliaflux_v2 training notebook.
Supports BPE tokenizer (tokenizer.json format) with character-level fallback.

NOTE: The GPT struct no longer has TiedDense; weight tying is done in the
forward pass. This simplifies checkpoint loading: we load all components
normally and skip any lm_head key in the checkpoint (it's redundant since
the output projection uses wte.weight directly).
=#
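
# Conceptually, the tied output projection in the forward pass looks like this
# sketch (illustrative only; the real forward code and exact shapes live in model.jl):
#
#     h      = ln_f(blocks(drop(wte(idx))))   # hidden states, (n_embd, T) for one sequence
#     logits = wte.weight' * h                # embedding matrix reused as the lm_head
#
# which is why this loader never restores an lm_head entry from the checkpoint.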
include("model.jl")

using JLD2
using JSON3
# ──────────────────────────────────────────────────────────────────────────────
# BPE Tokenizer (loaded from tokenizer.json - HuggingFace format)
# ──────────────────────────────────────────────────────────────────────────────
struct BPETokenizer
    vocab::Dict{String, Int}
    id_to_token::Dict{Int, String}
    merges::Vector{Tuple{String, String}}
    merge_rank::Dict{Tuple{String, String}, Int}
    byte_to_unicode::Dict{UInt8, String}
    unicode_to_byte::Dict{Char, UInt8}
    word_cache::Dict{String, Vector{Int}}
    gpt2_pattern::Regex
end
function build_byte_to_unicode()
    # GPT-2 style byte-level BPE: map every byte to a printable Unicode character.
    # Printable bytes keep their own codepoint; the rest are shifted past 0xFF.
    bs = UInt8[]
    cs = Char[]
    for r in [0x21:0x7e, 0xa1:0xac, 0xae:0xff]
        for b in r
            push!(bs, b)
            push!(cs, Char(b))
        end
    end
    n = 0
    for b in 0x00:0xff
        if b ∉ bs
            push!(bs, b)
            push!(cs, Char(256 + n))
            n += 1
        end
    end
    b2u = Dict(bs[i] => string(cs[i]) for i in eachindex(bs))
    u2b = Dict(v[1] => k for (k, v) in b2u)
    return b2u, u2b
end
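
# Example values of the byte<->unicode mapping (these follow directly from the
# construction above):
#
#     b2u, u2b = build_byte_to_unicode()
#     b2u[0x41]  # => "A"   (printable bytes map to themselves)
#     b2u[0x20]  # => "Ġ"   (space is remapped past 0xFF, the familiar GPT-2 marker)
#     u2b['Ġ']   # => 0x20  (inverse direction, used by decode_bpe)
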
function load_bpe_tokenizer(path::String)
    tok_json = JSON3.read(read(path, String))

    vocab = Dict{String, Int}()
    for (tok_str, id) in pairs(tok_json.model.vocab)
        vocab[string(tok_str)] = Int(id) + 1  # +1 for Julia 1-indexing
    end

    merges = Tuple{String, String}[]
    for merge_entry in tok_json.model.merges
        if merge_entry isa AbstractVector && length(merge_entry) >= 2
            push!(merges, (String(merge_entry[1]), String(merge_entry[2])))
        else
            parts = split(string(merge_entry), " ", limit=2)
            if length(parts) == 2
                push!(merges, (String(parts[1]), String(parts[2])))
            end
        end
    end

    id_to_token = Dict{Int, String}(id => tok for (tok, id) in vocab)
    merge_rank = Dict{Tuple{String, String}, Int}(
        (a, b) => i for (i, (a, b)) in enumerate(merges)
    )
    b2u, u2b = build_byte_to_unicode()
    gpt2_pat = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
    BPETokenizer(vocab, id_to_token, merges, merge_rank, b2u, u2b,
                 Dict{String, Vector{Int}}(), gpt2_pat)
end
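
# For reference, the slice of a HuggingFace tokenizer.json this reader touches
# (abridged sketch; real files carry many more fields). "merges" may be stored
# either as "left right" strings or as ["left", "right"] pairs, hence the two
# branches above:
#
#     { "model": { "vocab":  { "<token>": <id>, ... },
#                  "merges": [ "<left> <right>", ... ] } }
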
function bpe_encode_word(tok::BPETokenizer, word::Vector{String})
    tokens = copy(word)
    while length(tokens) >= 2
        # Find the adjacent pair with the best (lowest) merge rank.
        best_rank = typemax(Int)
        best_pair = ("", "")
        for i in 1:length(tokens)-1
            rank = get(tok.merge_rank, (tokens[i], tokens[i+1]), typemax(Int))
            if rank < best_rank
                best_rank = rank
                best_pair = (tokens[i], tokens[i+1])
            end
        end
        best_rank == typemax(Int) && break  # no mergeable pair left

        # Replace every occurrence of the best pair with the merged token.
        a, b = best_pair
        new_tokens = String[]
        i = 1
        while i <= length(tokens)
            if i < length(tokens) && tokens[i] == a && tokens[i+1] == b
                push!(new_tokens, a * b)
                i += 2
            else
                push!(new_tokens, tokens[i])
                i += 1
            end
        end
        tokens = new_tokens
    end
    return tokens
end
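
# Illustrative walkthrough (toy merges, not from a real tokenizer.json): with
# merge_rank = Dict(("l","o") => 1, ("lo","w") => 2), encoding ["l","o","w"]
# merges the lowest-ranked adjacent pair first, giving ["lo","w"], then ["low"],
# and stops once no adjacent pair appears in merge_rank.
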
function encode_bpe(tok::BPETokenizer, s::String)
    ids = Int[]
    for m in eachmatch(tok.gpt2_pattern, s)
        word = String(m.match)  # materialize the SubString before caching/byte conversion
        cached = get(tok.word_cache, word, nothing)
        if cached !== nothing
            append!(ids, cached)
        else
            word_bytes = Vector{UInt8}(word)
            chars = [tok.byte_to_unicode[b] for b in word_bytes]
            tokens = bpe_encode_word(tok, chars)
            word_ids = Int[]
            for t in tokens
                id = get(tok.vocab, t, nothing)
                id !== nothing && push!(word_ids, id)
            end
            tok.word_cache[word] = word_ids
            append!(ids, word_ids)
        end
    end
    return ids
end
function decode_bpe(tok::BPETokenizer, ids::Vector{Int})
    text = join(get(tok.id_to_token, id, "") for id in ids)
    bytes = UInt8[tok.unicode_to_byte[c] for c in text if haskey(tok.unicode_to_byte, c)]
    return String(bytes)
end
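
# Round-trip usage sketch (assumes a tokenizer.json is available locally;
# decode_bpe(encode_bpe(...)) recovers the text only when every BPE piece is
# present in the vocab, since unknown pieces are silently dropped above):
#
#     tok  = load_bpe_tokenizer("tokenizer.json")
#     ids  = encode_bpe(tok, "Hello, world!")
#     text = decode_bpe(tok, ids)
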
# ──────────────────────────────────────────────────────────────────────────────
# Checkpoint loading
# ──────────────────────────────────────────────────────────────────────────────
function load_flux_checkpoint(checkpoint_path::String; tokenizer_path::String="")
    println("Loading checkpoint from $checkpoint_path ...")
    data = JLD2.load(checkpoint_path)

    hp = data["hyperparams"]
    vocab_size = Int(hp["vocab_size"])
    n_embd = Int(hp["n_embd"])
    block_size = Int(hp["block_size"])
    n_layer = Int(hp["n_layer"])
    n_head = Int(hp["n_head"])
    n_kv_head = Int(get(hp, "n_kv_head", hp["n_head"]))
    dropout_val = Float64(get(hp, "dropout", 0.0))

    model = GPT(;
        vocab_size = vocab_size,
        n_embd = n_embd,
        block_size = block_size,
        n_layer = n_layer,
        n_head = n_head,
        n_kv_head = n_kv_head,
        dropout = 0.0  # No dropout at inference
    )

    # Load weights component-by-component
    ms = data["model_state"]
    Flux.loadmodel!(model.wte, ms[:wte])
    Flux.loadmodel!(model.drop, ms[:drop])
    Flux.loadmodel!(model.blocks, ms[:blocks])
    Flux.loadmodel!(model.ln_f, ms[:ln_f])

    # Set to test mode (disables dropout)
    Flux.testmode!(model)

    step = get(data, "step", 0)
    best_val = get(data, "best_val_loss", Inf)
    println(" Model loaded: vocab=$vocab_size, embd=$n_embd, layers=$n_layer, " *
            "heads=$(n_head)Q/$(n_kv_head)KV, block=$block_size")
    println(" Step=$step, best_val=$(round(best_val, digits=4))")

    # Load tokenizer
    encode_fn = nothing
    decode_fn = nothing
    if !isempty(tokenizer_path) && isfile(tokenizer_path)
        println(" Loading BPE tokenizer from $tokenizer_path")
        bpe = load_bpe_tokenizer(tokenizer_path)
        tok_vocab_size = length(bpe.vocab)
        if tok_vocab_size != vocab_size
            @warn "Vocab mismatch! Model expects vocab_size=$vocab_size but tokenizer has $tok_vocab_size tokens. " *
                  "Token IDs above $vocab_size will be clamped."
        end
        encode_fn = function(s)
            ids = encode_bpe(bpe, s)
            return [clamp(id, 1, vocab_size) for id in ids]
        end
        decode_fn = ids -> decode_bpe(bpe, ids)
        println(" BPE tokenizer loaded: $(tok_vocab_size) tokens (model vocab: $vocab_size)")
    else
        # Character-level fallback
        chars = vcat(collect('a':'z'), [' ', '.'])
        stoi = Dict(c => i for (i, c) in enumerate(chars))
        itos = Dict(i => c for (i, c) in enumerate(chars))
        encode_fn = s -> [get(stoi, c, 1) for c in s]
        decode_fn = ids -> join(get(itos, id, '?') for id in ids)
        println(" No tokenizer.json found, using character-level fallback ($(length(chars)) chars)")
    end

    return (;
        model, vocab_size, n_embd, block_size, n_layer, n_head, n_kv_head,
        step, best_val, encode_fn, decode_fn
    )
end
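
# Usage sketch (file names below are placeholders, not shipped with this file):
#
#     ckpt = load_flux_checkpoint("model_checkpoint.jld2"; tokenizer_path="tokenizer.json")
#     ids  = ckpt.encode_fn("Once upon a time")
#     # ...sample from ckpt.model (see model.jl), then turn ids back into text:
#     text = ckpt.decode_fn(ids)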