#=
model.jl — Self-contained inference engine for JuliaSLM

Implements the full Lux-trained JuliaGPT architecture (RoPE, RMSNorm, SwiGLU,
weight-tied decoder-only transformer) using only NNlib primitives.
No Lux dependency required — parameters are loaded directly from JLD2.
CPU-only inference. Supports both character-level and BPE tokenizers.
=#

using NNlib
using NNlib: batched_mul
using Statistics
using Random
using JSON3
using TOML

# ═══════════════════════════════════════════════════════════════════
# Model configuration
# ═══════════════════════════════════════════════════════════════════

"""
    ModelConfig

Hyper-parameters of the decoder-only transformer. `vocab_size` is supplied by
the caller (see `load_config_toml`); the remaining fields come from TOML.
"""
struct ModelConfig
    embed_dim::Int
    n_layers::Int
    n_heads::Int
    head_dim::Int
    context_length::Int
    vocab_size::Int
    weight_tying::Bool
    bias::Bool
end

"""
    load_config_toml(path; vocab_size=0) -> ModelConfig

Parse the TOML file at `path` and build a `ModelConfig` from its `[model]`
table. Keys that are absent fall back to: `embed_dim=256`, `n_layers=6`,
`n_heads=4`, `head_dim=64`, `context_length=256`, `weight_tying=true`,
`bias=false`. `vocab_size` is never read from the file; it is taken from the
keyword argument.
"""
function load_config_toml(path::String; vocab_size::Int=0)
    model_section = get(TOML.parsefile(path), "model", Dict())
    setting(key, fallback) = get(model_section, key, fallback)

    embed_dim      = setting("embed_dim", 256)
    n_layers       = setting("n_layers", 6)
    n_heads        = setting("n_heads", 4)
    head_dim       = setting("head_dim", 64)
    context_length = setting("context_length", 256)
    weight_tying   = setting("weight_tying", true)
    bias           = setting("bias", false)

    return ModelConfig(embed_dim, n_layers, n_heads, head_dim,
                       context_length, vocab_size, weight_tying, bias)
end

# ═══════════════════════════════════════════════════════════════════
# Character-level tokenizer
# ═══════════════════════════════════════════════════════════════════

"""
    CharTokenizer

One token per character. Ids are 1-based positions in `idx_to_char`.
"""
struct CharTokenizer
    char_to_idx::Dict{Char, Int}
    idx_to_char::Vector{Char}
    vocab_size::Int
end

"""
    load_char_vocab_json(path) -> CharTokenizer

Read a JSON array of one-character strings; the i-th entry becomes token id `i`.
Each entry must contain exactly one character (`only` throws otherwise). If a
character appears twice, the later position wins in the lookup table.
"""
function load_char_vocab_json(path::String)
    entries = JSON3.read(read(path, String))

    alphabet = Char[]
    sizehint!(alphabet, length(entries))
    for entry in entries
        push!(alphabet, only(String(entry)))
    end

    lookup = Dict{Char, Int}()
    for (position, ch) in pairs(alphabet)
        lookup[ch] = position
    end

    return CharTokenizer(lookup, alphabet, length(alphabet))
end

"""
    encode(t::CharTokenizer, text) -> Vector{Int}

Map each character of `text` to its id. Characters missing from the vocabulary
are silently skipped.
"""
function encode(t::CharTokenizer, text::String)
    ids = Int[]
    sizehint!(ids, length(text))
    for ch in text
        haskey(t.char_to_idx, ch) || continue
        push!(ids, t.char_to_idx[ch])
    end
    return ids
end

"""
    decode(t::CharTokenizer, indices) -> String

Concatenate the characters for `indices`, skipping any id outside
`1:t.vocab_size`.
"""
function decode(t::CharTokenizer, indices::AbstractVector{<:Integer})
    in_range = Iterators.filter(i -> 1 <= i <= t.vocab_size, indices)
    return join(t.idx_to_char[i] for i in in_range)
end
# ═══════════════════════════════════════════════════════════════════
# BPE Tokenizer (GPT-2 style)
# ═══════════════════════════════════════════════════════════════════

"""
    BPETokenizer

GPT-2 style byte-level BPE tokenizer. Ids stored in `encoder`/`decoder` are the
0-based ids from the vocabulary file; `encode`/`decode` shift them by one so
callers work with 1-based ids.
"""
struct BPETokenizer
    encoder::Dict{String, Int}     # token string → id (0-indexed from file)
    decoder::Dict{Int, String}     # id → token string (0-indexed)
    merges::Vector{Tuple{String, String}}
    merge_ranks::Dict{Tuple{String, String}, Int}  # pair → rank (lower merges first)
    byte_to_unicode::Dict{UInt8, Char}
    unicode_to_byte::Dict{Char, UInt8}
    vocab_size::Int
    pat::Regex                     # pre-tokenization pattern
end

"""
    load_bpe_tokenizer(vocab_path, merges_path) -> BPETokenizer

Load a tokenizer from a JSON vocabulary (`token => id`) and a merges file with
one space-separated pair per line. An optional first line starting with `#`
(e.g. a `#version` header) is skipped, as are lines that do not split into
exactly two fields. Earlier merges get lower ranks and are applied first.
"""
function load_bpe_tokenizer(vocab_path::String, merges_path::String)
    encoder = JSON3.read(read(vocab_path, String), Dict{String, Int})
    decoder = Dict{Int, String}(v => k for (k, v) in encoder)

    merge_lines = readlines(merges_path)
    # Guard `first` against an empty merges file.
    has_header = !isempty(merge_lines) && startswith(first(merge_lines), "#")
    start = has_header ? 2 : 1

    merges = Tuple{String, String}[]
    for line in @view merge_lines[start:end]
        parts = split(strip(line))
        length(parts) == 2 && push!(merges, (String(parts[1]), String(parts[2])))
    end
    merge_ranks = Dict{Tuple{String, String}, Int}(m => i for (i, m) in enumerate(merges))

    b2u = _build_byte_to_unicode()
    u2b = Dict{Char, UInt8}(v => k for (k, v) in b2u)

    pat = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
    return BPETokenizer(encoder, decoder, merges, merge_ranks, b2u, u2b, length(encoder), pat)
end

"""
    encode(t::BPETokenizer, text) -> Vector{Int}

Split `text` with `t.pat`, map every UTF-8 byte of each piece to its printable
stand-in character, apply BPE merges, and return the 1-based ids. Sub-tokens
that are not in the vocabulary are silently dropped.
"""
function encode(t::BPETokenizer, text::String)
    tokens = Int[]
    for m in eachmatch(t.pat, text)
        # `codeunits` gives the raw UTF-8 bytes of the matched SubString.
        symbols = [string(t.byte_to_unicode[b]) for b in codeunits(m.match)]
        for tok in _bpe_encode_word(symbols, t.merge_ranks)
            id = get(t.encoder, tok, nothing)
            id === nothing || push!(tokens, id + 1)  # +1 for Julia 1-indexing
        end
    end
    return tokens
end

"""
    decode(t::BPETokenizer, ids) -> String

Inverse of `encode`: look up each 1-based id (unknown ids decode to `""`),
concatenate the token strings, and map every stand-in character back to its raw
byte. Characters outside the byte map are emitted as their own UTF-8 bytes.
Decoding a subset of a sequence (e.g. a single token) can yield a `String` that
is not valid UTF-8, because one character may span several tokens.
"""
function decode(t::BPETokenizer, ids::AbstractVector{<:Integer})
    token_strs = [get(t.decoder, id - 1, "") for id in ids]  # -1 to undo Julia 1-indexing
    joined = join(token_strs)
    out = UInt8[]
    sizehint!(out, length(joined))
    for c in joined
        b = get(t.unicode_to_byte, c, nothing)
        if b !== nothing
            push!(out, b)
        else
            append!(out, codeunits(string(c)))
        end
    end
    return String(out)
end

"""
    _bpe_encode_word(symbols, merge_ranks) -> Vector{String}

Repeatedly find the adjacent pair with the lowest rank in `merge_ranks` and
merge every non-overlapping left-to-right occurrence of it, until no adjacent
pair has a rank.
"""
function _bpe_encode_word(symbols::Vector{String}, merge_ranks::Dict{Tuple{String, String}, Int})
    while length(symbols) > 1
        best_pair = nothing
        best_rank = typemax(Int)
        for i in 1:length(symbols)-1
            pair = (symbols[i], symbols[i+1])
            rank = get(merge_ranks, pair, typemax(Int))
            if rank < best_rank
                best_rank = rank
                best_pair = pair
            end
        end
        best_rank == typemax(Int) && break

        new_symbols = String[]
        i = 1
        while i <= length(symbols)
            if i < length(symbols) && symbols[i] == best_pair[1] && symbols[i+1] == best_pair[2]
                push!(new_symbols, best_pair[1] * best_pair[2])
                i += 2
            else
                push!(new_symbols, symbols[i])
                i += 1
            end
        end
        symbols = new_symbols
    end
    return symbols
end

"""
    _build_byte_to_unicode() -> Dict{UInt8, Char}

Build GPT-2's reversible byte → character table. The 188 printable bytes
`'!':'~'`, `'¡':'¬'` and `'®':'ÿ'` map to the character with the same code
point. The remaining 68 bytes map, in increasing byte order, to code points
256, 257, ….
"""
function _build_byte_to_unicode()
    bs = UInt8[]
    append!(bs, UInt8('!'):UInt8('~'))
    append!(bs, UInt8('¡'):UInt8('¬'))
    append!(bs, UInt8('®'):UInt8('ÿ'))
    cs = Int[Int(b) for b in bs]
    n = 0
    for b in 0x00:0xff
        if !(b in bs)
            push!(bs, b)
            push!(cs, 256 + n)
            n += 1
        end
    end
    return Dict{UInt8, Char}(b => Char(c) for (b, c) in zip(bs, cs))
end

# ═══════════════════════════════════════════════════════════════════
# Unified tokenizer interface
# ═══════════════════════════════════════════════════════════════════

const Tokenizer = Union{CharTokenizer, BPETokenizer}

tokenizer_vocab_size(t::CharTokenizer) = t.vocab_size
tokenizer_vocab_size(t::BPETokenizer) = t.vocab_size

# ═══════════════════════════════════════════════════════════════════
# Causal mask
# ═══════════════════════════════════════════════════════════════════

"""
    make_causal_mask(seq_len) -> Matrix{Float32}

Additive attention mask indexed `[query i, key j]`: `0` where `j <= i`,
`typemin(Float32)` (−Inf) otherwise.
"""
function make_causal_mask(seq_len::Int)
    return Float32[j <= i ? 0.0f0 : typemin(Float32) for i in 1:seq_len, j in 1:seq_len]
end

# ═══════════════════════════════════════════════════════════════════
# Rotary Positional Embedding
# ═══════════════════════════════════════════════════════════════════

"""
    compute_rope_caches(head_dim, max_seq_len) -> (cos, sin)

Precompute RoPE angles `θ_k · pos` for frequencies `θ_k = 10000^(-2k/head_dim)`
and positions `0:max_seq_len-1`. For even `head_dim`, each returned matrix has
size `(head_dim ÷ 2, max_seq_len)`.
"""
function compute_rope_caches(head_dim::Int, max_seq_len::Int)
    freqs = Float32.(1.0 ./ (10000.0 .^ ((0:2:(head_dim-1)) ./ head_dim)))
    positions = Float32.(0:(max_seq_len-1))
    angles = freqs * positions'
    return cos.(angles), sin.(angles)
end

"""
    apply_rotary_emb(x, cos_cache, sin_cache, seq_len)

Rotate the first half of dimension 1 of `x` against the second half
(non-interleaved "rotate-half" convention), using the first `seq_len` cache
columns. The cache columns broadcast along dimension 2 of `x`.
"""
function apply_rotary_emb(x, cos_cache, sin_cache, seq_len)
    half = size(x, 1) ÷ 2
    x1 = x[1:half, :, :, :]
    x2 = x[half+1:end, :, :, :]
    c = cos_cache[:, 1:seq_len]
    s = sin_cache[:, 1:seq_len]
    o1 = x1 .* c .- x2 .* s
    o2 = x1 .* s .+ x2 .* c
    return vcat(o1, o2)
end

# ═══════════════════════════════════════════════════════════════════
# Layer primitives
# ═══════════════════════════════════════════════════════════════════

"""
    rmsnorm_forward(x, weight; eps=1f-6)

RMS-normalize `x` along dimension 1 and scale by `weight`.
"""
function rmsnorm_forward(x, weight; eps=1.0f-6)
    rms = sqrt.(mean(x .^ 2; dims=1) .+ eps)
    return weight .* (x ./ rms)
end

"""
    swiglu_forward(x, ps)

SwiGLU feed-forward: `ps.w2 * (swish(ps.w1 * x) .* (ps.v * x))`, applied
column-wise after flattening all trailing dims of `x`. The output has the same
shape as `x`.
"""
function swiglu_forward(x, ps)
    D = size(x, 1)
    x_flat = reshape(x, D, :)
    gate = ps.w1 * x_flat
    val = ps.v * x_flat
    hidden = NNlib.swish.(gate) .* val
    out = ps.w2 * hidden
    return reshape(out, D, size(x)[2:end]...)
end

"""
    self_attention_forward(x, ps, n_heads, head_dim, rope_cos, rope_sin, mask)

Multi-head causal self-attention with RoPE on queries and keys. `x` is
`(D, T, B)`. `mask` is an additive `(T, T)` mask indexed `[query, key]`. The
output has the same shape as `x`.
"""
function self_attention_forward(x, ps, n_heads, head_dim, rope_cos, rope_sin, mask)
    D, T, B = size(x)
    H = n_heads
    HD = head_dim

    x_flat = reshape(x, D, T * B)
    # NOTE(review): `ps.wq * x_flat` is (HD·H, T·B). Reshaping it directly to
    # (HD, T, H, B) interleaves the T and H axes whenever T > 1 and H > 1; the
    # conventional split is reshape to (HD, H, T, B) followed by permutedims.
    # It is kept unchanged on the assumption that the Lux training code uses
    # the identical reshape (the output reshape below inverts it). Confirm
    # before changing: a different layout would not match the trained weights.
    q = reshape(ps.wq * x_flat, HD, T, H, B)
    k = reshape(ps.wk * x_flat, HD, T, H, B)
    v = reshape(ps.wv * x_flat, HD, T, H, B)

    q = apply_rotary_emb(q, rope_cos, rope_sin, T)
    k = apply_rotary_emb(k, rope_cos, rope_sin, T)

    scale = Float32(1.0 / sqrt(Float64(HD)))
    q_r = reshape(q, HD, T, H * B)
    k_r = reshape(k, HD, T, H * B)
    # attn[i, j, :] = scale · (q_i ⋅ k_j): queries along dim 1, keys along dim 2.
    attn = batched_mul(permutedims(q_r, (2, 1, 3)), k_r) .* scale
    attn = attn .+ mask
    attn = NNlib.softmax(attn; dims=2)  # normalize over keys

    v_r = reshape(v, HD, T, H * B)
    out = batched_mul(v_r, permutedims(attn, (2, 1, 3)))
    out = reshape(out, HD * H, T, B)

    result = reshape(ps.wo * reshape(out, HD * H, T * B), D, T, B)
    return result
end

# ═══════════════════════════════════════════════════════════════════
# Full model forward pass (uses pre-computed causal mask)
# ═══════════════════════════════════════════════════════════════════

"""
    model_forward(config, ps, rope_cos, rope_sin, x, causal_mask) -> logits

Run the transformer on integer token ids `x` of shape `(seq_len, batch)` and
return logits of shape `(vocab, seq_len, batch)`. `causal_mask` must be at
least `seq_len × seq_len`; its top-left corner is used. `ps` is accessed as
`ps.tok_emb.weight`, `ps.blocks.block_i.{ln1,attn,ln2,ffn}`, `ps.ln_f.weight`,
and (without weight tying) `ps.head.weight`. Presumably it is a nested
NamedTuple as loaded from JLD2 — TODO confirm.
"""
function model_forward(config::ModelConfig, ps, rope_cos, rope_sin, x, causal_mask)
    T = size(x, 1)
    # Token embedding: (seq_len, batch) → (embed_dim, seq_len, batch)
    h = ps.tok_emb.weight[:, x]

    # A view avoids copying the pre-computed mask on every call.
    mask = @view causal_mask[1:T, 1:T]

    for i in 1:config.n_layers
        bp = getproperty(ps.blocks, Symbol("block_$i"))

        # Pre-norm attention + residual
        normed = rmsnorm_forward(h, bp.ln1.weight)
        attn_out = self_attention_forward(normed, bp.attn, config.n_heads, config.head_dim,
                                          rope_cos, rope_sin, mask)
        h = h .+ attn_out

        # Pre-norm FFN + residual
        normed2 = rmsnorm_forward(h, bp.ln2.weight)
        ffn_out = swiglu_forward(normed2, bp.ffn)
        h = h .+ ffn_out
    end

    h = rmsnorm_forward(h, ps.ln_f.weight)

    # Output projection (tied to the embedding matrix when configured)
    D, T_out, B = size(h)
    h_flat = reshape(h, D, T_out * B)
    if config.weight_tying
        logits = ps.tok_emb.weight' * h_flat
    else
        logits = ps.head.weight * h_flat
    end
    return reshape(logits, :, T_out, B)
end

# ═══════════════════════════════════════════════════════════════════
# Sampling helpers
# ═══════════════════════════════════════════════════════════════════

"""
    top_k_filter(logits, k) -> Vector

Keep every logit that is at least the `k`-th largest value and replace the rest
with `typemin(eltype(logits))` (−Inf for floats). Ties with the `k`-th value
are kept, so more than `k` entries may survive. When `k <= 0` or `logits` is
empty, an unmodified copy is returned.
"""
function top_k_filter(logits::AbstractVector, k::Int)
    k = min(k, length(logits))
    k <= 0 && return collect(logits)
    # Only the k-th largest value is needed; a partial sort avoids sorting the vocab.
    threshold = partialsort(logits, k; rev=true)
    floor_value = typemin(eltype(logits))
    return map(l -> l >= threshold ? l : floor_value, logits)
end

"""
    top_p_filter(logits, p) -> Vector

Nucleus filtering: keep the smallest prefix of logits (sorted in descending
order) whose softmax mass reaches `p`, including the token that crosses `p`.
Every other entry becomes `typemin(eltype(logits))`.
"""
function top_p_filter(logits::AbstractVector, p::Float64)
    sorted_indices = sortperm(Array(logits); rev=true)
    sorted_logits = logits[sorted_indices]
    probs = NNlib.softmax(sorted_logits)
    cumprobs = cumsum(Array(probs))
    cutoff = something(findfirst(>=(p), cumprobs), length(probs))
    result = fill(typemin(eltype(logits)), length(logits))
    for i in 1:cutoff
        result[sorted_indices[i]] = logits[sorted_indices[i]]
    end
    return result
end

"""
    sample_categorical([rng,] probs) -> Int

Draw index `i` with probability `probs[i]` by inverse-CDF sampling. `probs` is
assumed non-negative and to sum to ≈1. Zero-probability entries are never
returned. If Float32 rounding leaves the cumulative sum below the uniform
draw, the last index with positive probability is returned (or `lastindex`
if none is positive).
"""
function sample_categorical(rng::AbstractRNG, probs::AbstractVector)
    r = rand(rng, Float32)
    cumulative = 0.0f0
    for i in eachindex(probs)
        cumulative += probs[i]
        # Strict `<`: an entry with zero mass cannot be picked, even when r == 0.
        r < cumulative && return i
    end
    # Rounding fallback: return a token that actually has mass, not the last
    # vocab entry (which is typically masked out by top-k / top-p).
    fallback = findlast(>(0), probs)
    return fallback === nothing ? lastindex(probs) : fallback
end

sample_categorical(probs::AbstractVector) = sample_categorical(Random.default_rng(), probs)

# ═══════════════════════════════════════════════════════════════════
# Text generation with streaming callback
# ═══════════════════════════════════════════════════════════════════

"""
    generate_streaming(config, ps, rope_cos, rope_sin, tokenizer, prompt; kwargs...) -> String

Autoregressively generate up to `max_tokens` tokens after `prompt` and return
only the generated text (the prompt is not included). Each step recomputes the
forward pass over the last `config.context_length` tokens.

# Keywords
- `max_tokens=200`: number of tokens to generate.
- `temperature=0.8`: logits are divided by this value; `temperature <= 0`
  selects greedy (argmax) decoding.
- `top_k=0`: if positive, keep only the `top_k` largest logits.
- `top_p=1.0`: if below 1, apply nucleus filtering with this mass.
- `on_token=nothing`: called with each token's decoded string as it is
  produced. For BPE, one token may decode to an incomplete UTF-8 sequence. The
  concatenation of all pieces is still correct.
- `causal_mask=nothing`: pre-computed mask; built for `config.context_length`
  when omitted.
- `rng=Random.default_rng()`: random source for sampling, for reproducibility.

An empty encoded prompt is seeded with one random token id.
"""
function generate_streaming(config::ModelConfig, ps, rope_cos, rope_sin,
                            tokenizer::Tokenizer, prompt::String;
                            max_tokens::Int=200,
                            temperature::Float64=0.8,
                            top_k::Int=0,
                            top_p::Float64=1.0,
                            on_token=nothing,
                            causal_mask=nothing,
                            rng::AbstractRNG=Random.default_rng())
    tokens = encode(tokenizer, prompt)
    if isempty(tokens)
        tokens = [rand(rng, 1:tokenizer_vocab_size(tokenizer))]
    end

    # Build the mask once for the full context; model_forward slices it per step.
    if causal_mask === nothing
        causal_mask = make_causal_mask(config.context_length)
    end

    generated = String[]
    for _ in 1:max_tokens
        # Sliding window over the most recent `context_length` tokens.
        ctx = if length(tokens) > config.context_length
            tokens[end-config.context_length+1:end]
        else
            copy(tokens)
        end
        x = reshape(ctx, :, 1)  # (seq_len, batch = 1)

        logits = model_forward(config, ps, rope_cos, rope_sin, x, causal_mask)
        next_logits = Vector{Float32}(logits[:, end, 1])

        if temperature <= 0.0
            # Greedy decoding: dividing by a zero temperature would produce Inf/NaN.
            next_token = argmax(next_logits)
        else
            if temperature != 1.0
                next_logits ./= Float32(temperature)
            end
            if top_k > 0
                next_logits = top_k_filter(next_logits, top_k)
            end
            if top_p < 1.0
                next_logits = top_p_filter(next_logits, top_p)
            end
            probs = NNlib.softmax(next_logits)
            next_token = sample_categorical(rng, probs)
        end

        push!(tokens, next_token)
        token_str = decode(tokenizer, [next_token])
        push!(generated, token_str)
        on_token !== nothing && on_token(token_str)
    end

    return join(generated)
end