#= model.jl — Self-contained inference engine for MonarchSLM

Implements the Monarch Mixer architecture (sub-quadratic sequence mixing using
structured matrices) with RMSNorm, SwiGLU, and weight-tied output.
No Lux dependency — parameters loaded directly from JLD2. CPU-only inference.

Architecture per block:
    MonarchSequenceMixer (8 heads × MonarchMatrix + CausalDepthwiseConv + LearnedGate)
        → RMSNorm pre-norm + residual
    SwiGLU FFN → RMSNorm pre-norm + residual

References:
    Monarch Mixer (Dao et al., 2023): Sub-quadratic GEMM-based architecture
=#

using NNlib
using NNlib: batched_mul
using Statistics
using Random
using JSON3
using TOML

# ═══════════════════════════════════════════════════════════════════
# Model configuration
# ═══════════════════════════════════════════════════════════════════

"""
Static hyperparameters of a MonarchSLM checkpoint.

`context_length` must be a perfect square (T = p²) because each Monarch factor
is a p×p block-diagonal matrix — validated in `precompute_inference_caches`.
"""
struct ModelConfig
    arch::String
    embed_dim::Int
    n_layers::Int
    n_monarch_heads::Int
    conv_kernel_size::Int
    context_length::Int
    vocab_size::Int
    weight_tying::Bool
    bias::Bool
end

"""
    load_config_toml(path; vocab_size=0) -> ModelConfig

Read the `[model]` table from a TOML config file, falling back to defaults for
any missing key. `vocab_size` is supplied by the caller (it comes from the
tokenizer, not the config file).
"""
function load_config_toml(path::String; vocab_size::Int=0)
    raw = TOML.parsefile(path)
    m = get(raw, "model", Dict())
    return ModelConfig(
        get(m, "arch", "monarch"),
        get(m, "embed_dim", 256),
        get(m, "n_layers", 8),
        get(m, "n_monarch_heads", 8),
        get(m, "conv_kernel_size", 4),
        get(m, "context_length", 256),
        vocab_size,
        get(m, "weight_tying", true),
        get(m, "bias", false),
    )
end

# ═══════════════════════════════════════════════════════════════════
# Character-level tokenizer
# ═══════════════════════════════════════════════════════════════════

"""
Character-level tokenizer: a bijection between characters and 1-based IDs.
"""
struct CharTokenizer
    char_to_idx::Dict{Char, Int}
    idx_to_char::Vector{Char}
    vocab_size::Int
end

"""
    load_char_vocab_json(path) -> CharTokenizer

Load a JSON array of single-character strings; array order defines token IDs.
"""
function load_char_vocab_json(path::String)
    raw = JSON3.read(read(path, String))
    chars = Char[only(String(s)) for s in raw]
    char_to_idx = Dict(c => i for (i, c) in enumerate(chars))
    return CharTokenizer(char_to_idx, chars, length(chars))
end

"""
    encode(t::CharTokenizer, text) -> Vector{Int}

Map characters to token IDs; characters not in the vocabulary are dropped.
"""
function encode(t::CharTokenizer, text::String)
    indices = Int[]
    sizehint!(indices, length(text))
    for c in text
        idx = get(t.char_to_idx, c, nothing)
        idx !== nothing && push!(indices, idx)
    end
    return indices
end

"""
    decode(t::CharTokenizer, indices) -> String

Map token IDs back to characters; out-of-range IDs are silently skipped.
"""
function decode(t::CharTokenizer, indices::AbstractVector{<:Integer})
    buf = IOBuffer()
    for idx in indices
        if 1 <= idx <= t.vocab_size
            write(buf, t.idx_to_char[idx])
        end
    end
    return String(take!(buf))
end

# ═══════════════════════════════════════════════════════════════════
# BPE Tokenizer (GPT-2 style)
# ═══════════════════════════════════════════════════════════════════

"""
GPT-2 style byte-level BPE tokenizer. `encoder`/`decoder` use the checkpoint's
0-based IDs; the `encode`/`decode` methods shift to/from this file's 1-based
convention. `pat` is GPT-2's pre-tokenization regex.
"""
struct BPETokenizer
    encoder::Dict{String, Int}
    decoder::Dict{Int, String}
    merges::Vector{Tuple{String, String}}
    merge_ranks::Dict{Tuple{String, String}, Int}
    byte_to_unicode::Dict{UInt8, Char}
    unicode_to_byte::Dict{Char, UInt8}
    vocab_size::Int
    pat::Regex
end

"""
    load_bpe_tokenizer(vocab_path, merges_path) -> BPETokenizer

Load a GPT-2 `vocab.json` + `merges.txt` pair. A leading `#version` comment
line in the merges file is skipped; malformed merge lines are ignored.
"""
function load_bpe_tokenizer(vocab_path::String, merges_path::String)
    encoder = JSON3.read(read(vocab_path, String), Dict{String, Int})
    decoder = Dict{Int, String}(v => k for (k, v) in encoder)
    merge_lines = readlines(merges_path)
    # Guard against an empty merges file before peeking at the first line.
    start = (!isempty(merge_lines) && startswith(first(merge_lines), "#")) ? 2 : 1
    merges = Tuple{String, String}[]
    for line in merge_lines[start:end]
        parts = split(strip(line))
        length(parts) == 2 && push!(merges, (String(parts[1]), String(parts[2])))
    end
    merge_ranks = Dict{Tuple{String, String}, Int}(m => i for (i, m) in enumerate(merges))
    b2u = _build_byte_to_unicode()
    u2b = Dict{Char, UInt8}(v => k for (k, v) in b2u)
    pat = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
    return BPETokenizer(encoder, decoder, merges, merge_ranks, b2u, u2b,
                        length(encoder), pat)
end

"""
    encode(t::BPETokenizer, text) -> Vector{Int}

Pre-tokenize with the GPT-2 regex, byte-encode each piece through the
byte→unicode table, apply BPE merges, and map to 1-based token IDs.
Tokens missing from the vocabulary are dropped.
"""
function encode(t::BPETokenizer, text::String)
    tokens = Int[]
    for m in eachmatch(t.pat, text)
        word = m.match
        # codeunits works uniformly on String and SubString without copying
        # (the regex match is a SubString).
        encoded_chars = [string(t.byte_to_unicode[b]) for b in codeunits(word)]
        bpe_tokens = _bpe_encode_word(encoded_chars, t.merge_ranks)
        for tok in bpe_tokens
            id = get(t.encoder, tok, nothing)
            # +1: checkpoint vocab is 0-based, this engine is 1-based.
            id !== nothing && push!(tokens, id + 1)
        end
    end
    return tokens
end

"""
    decode(t::BPETokenizer, ids) -> String

Invert `encode`: map 1-based IDs to vocab strings, then map each unicode
stand-in character back to its raw byte. Characters outside the byte table
(shouldn't occur with a well-formed vocab) fall back to their UTF-8 bytes.
"""
function decode(t::BPETokenizer, ids::AbstractVector{<:Integer})
    token_strs = [get(t.decoder, id - 1, "") for id in ids]
    joined = join(token_strs)
    out = UInt8[]
    sizehint!(out, length(joined))
    for c in joined
        b = get(t.unicode_to_byte, c, nothing)
        if b !== nothing
            push!(out, b)
        else
            append!(out, codeunits(string(c)))
        end
    end
    return String(out)
end

"""
    _bpe_encode_word(symbols, merge_ranks) -> Vector{String}

Greedy BPE: repeatedly merge the adjacent pair with the lowest merge rank
until no mergeable pair remains.
"""
function _bpe_encode_word(symbols::Vector{String},
                          merge_ranks::Dict{Tuple{String, String}, Int})
    while length(symbols) > 1
        # Find the adjacent pair with the best (lowest) merge rank.
        best_pair = nothing
        best_rank = typemax(Int)
        for i in 1:length(symbols)-1
            pair = (symbols[i], symbols[i+1])
            rank = get(merge_ranks, pair, typemax(Int))
            if rank < best_rank
                best_rank = rank
                best_pair = pair
            end
        end
        best_rank == typemax(Int) && break
        # Merge every non-overlapping occurrence of best_pair, left to right.
        new_symbols = String[]
        i = 1
        while i <= length(symbols)
            if i < length(symbols) && symbols[i] == best_pair[1] &&
               symbols[i+1] == best_pair[2]
                push!(new_symbols, best_pair[1] * best_pair[2])
                i += 2
            else
                push!(new_symbols, symbols[i])
                i += 1
            end
        end
        symbols = new_symbols
    end
    return symbols
end

"""
    _build_byte_to_unicode() -> Dict{UInt8, Char}

GPT-2's reversible byte→unicode table: printable latin-1 bytes map to
themselves; the remaining bytes map to code points 256 + n so every byte has
a visible, distinct stand-in character.
"""
function _build_byte_to_unicode()
    bs = UInt8[]
    append!(bs, UInt8('!'):UInt8('~'))
    append!(bs, UInt8('¡'):UInt8('¬'))
    append!(bs, UInt8('®'):UInt8('ÿ'))
    cs = Int[Int(b) for b in bs]
    n = 0
    for b in 0x00:0xff
        if !(b in bs)
            push!(bs, b)
            push!(cs, 256 + n)
            n += 1
        end
    end
    return Dict{UInt8, Char}(b => Char(c) for (b, c) in zip(bs, cs))
end

# ═══════════════════════════════════════════════════════════════════
# Unified tokenizer interface
# ═══════════════════════════════════════════════════════════════════

const Tokenizer = Union{CharTokenizer, BPETokenizer}

tokenizer_vocab_size(t::CharTokenizer) = t.vocab_size
tokenizer_vocab_size(t::BPETokenizer) = t.vocab_size

# ═══════════════════════════════════════════════════════════════════
# Causal mask (multiplicative 0/1 for Monarch)
# ═══════════════════════════════════════════════════════════════════

"""
    make_causal_mask(seq_len) -> Matrix{Float32}

Lower-triangular 0/1 mask: entry (i, j) is 1 iff j ≤ i, so position i may
only attend to positions ≤ i.
"""
function make_causal_mask(seq_len::Int)
    return Float32[j <= i ? 1.0f0 : 0.0f0 for i in 1:seq_len, j in 1:seq_len]
end

# ═══════════════════════════════════════════════════════════════════
# Layer primitives (shared with transformer)
# ═══════════════════════════════════════════════════════════════════

"""
    rmsnorm_forward(x, weight; eps=1.0f-6)

RMSNorm over the feature dimension (dims=1), scaled element-wise by `weight`.
"""
function rmsnorm_forward(x, weight; eps=1.0f-6)
    rms = sqrt.(mean(x .^ 2; dims=1) .+ eps)
    return weight .* (x ./ rms)
end

"""
    swiglu_forward(x, ps)

SwiGLU FFN: `w2 * (swish(w1*x) .* (v*x))`. Operates on the flattened
(D, T*B) view and restores the original trailing dimensions.
"""
function swiglu_forward(x, ps)
    D = size(x, 1)
    x_flat = reshape(x, D, :)
    gate = ps.w1 * x_flat
    val = ps.v * x_flat
    hidden = NNlib.swish.(gate) .* val
    out = ps.w2 * hidden
    return reshape(out, D, size(x)[2:end]...)
end

# ═══════════════════════════════════════════════════════════════════
# Monarch Matrix realization
# ═══════════════════════════════════════════════════════════════════

"""
    monarch_realize(L1, L2, p) -> Matrix{Float32}

Materialize T×T Monarch matrix: M = Pᵀ · BlockDiag(L1) · P · BlockDiag(L2)
where L1, L2 are (p, p, p) block-diagonal factors and T = p².

Done once at startup (see `precompute_inference_caches`) by pushing the
identity matrix through the factored operator.
"""
function monarch_realize(L1, L2, p::Int)
    T = p * p
    # Start with identity matrix
    I_T = Float32[i == j ? 1.0f0 : 0.0f0 for i in 1:T, j in 1:T]
    # Reshape columns: (T, T) → (p, p, T)
    x = reshape(I_T, p, p, T)
    # Apply L2 block-diagonal: for each block k, multiply L2[:,:,k] @ x[:,k,:]
    x = permutedims(x, (1, 3, 2))  # (p, T, p)
    x = batched_mul(L2, x)         # (p, p, p) × (p, T, p) → (p, T, p)
    x = permutedims(x, (1, 3, 2))  # (p, p, T)
    # Permutation P: transpose the p×p grid
    x = permutedims(x, (2, 1, 3))
    # Apply L1 block-diagonal
    x = permutedims(x, (1, 3, 2))  # (p, T, p)
    x = batched_mul(L1, x)         # (p, T, p)
    x = permutedims(x, (1, 3, 2))  # (p, p, T)
    # Undo permutation
    x = permutedims(x, (2, 1, 3))
    return reshape(x, T, T)
end

# ═══════════════════════════════════════════════════════════════════
# Causal Depthwise Conv1d
# ═══════════════════════════════════════════════════════════════════

"""
    causal_depthwise_conv1d(x, kernel) -> Array

x: (D, T, B), kernel: (K, D)
Causal convolution: pad K-1 zeros on the left, sum over kernel taps.
"""
function causal_depthwise_conv1d(x, kernel)
    D, T, B = size(x)
    K = size(kernel, 1)
    # Causal pad: K-1 zeros on the left
    pad = zeros(Float32, D, K - 1, B)
    x_padded = cat(pad, x; dims=2)  # (D, T+K-1, B)
    # Sum over kernel taps; view avoids copying the shifted window per tap.
    out = sum(1:K) do k
        reshape(kernel[k, :], D, 1, 1) .* view(x_padded, :, k:k+T-1, :)
    end
    return out
end

# ═══════════════════════════════════════════════════════════════════
# Pre-compute inference caches (Monarch matrices + causal mask)
# ═══════════════════════════════════════════════════════════════════

"""
    precompute_inference_caches(config, ps) -> NamedTuple

Pre-realize all Monarch matrices and apply causal mask once at startup.
Avoids recomputing them on every forward pass during generation.

Throws `ArgumentError` if `config.context_length` is not a perfect square
(the Monarch factorization requires T = p²).
"""
function precompute_inference_caches(config::ModelConfig, ps)
    p = isqrt(config.context_length)
    # Without this check a non-square context_length produces a (p², p²)
    # Monarch matrix that fails the `.* mask` broadcast with a cryptic error.
    p * p == config.context_length || throw(ArgumentError(
        "context_length must be a perfect square, got $(config.context_length)"))
    mask = make_causal_mask(config.context_length)
    # Pre-realize all Monarch matrices: monarch_ms[layer][head] = masked T×T matrix
    monarch_ms = Vector{Vector{Matrix{Float32}}}(undef, config.n_layers)
    for i in 1:config.n_layers
        name = Symbol("block_$i")
        bp = getproperty(ps.blocks, name)
        layer_ms = Vector{Matrix{Float32}}(undef, config.n_monarch_heads)
        for j in 1:config.n_monarch_heads
            head_name = Symbol("head_$j")
            ps_m = getproperty(bp.seq_mixer.monarchs, head_name)
            M = monarch_realize(ps_m.L1, ps_m.L2, p) .* mask
            layer_ms[j] = M
        end
        monarch_ms[i] = layer_ms
    end
    return (; mask, monarch_ms)
end

# ═══════════════════════════════════════════════════════════════════
# Monarch Sequence Mixer forward pass (uses cached matrices)
# ═══════════════════════════════════════════════════════════════════

"""
    monarch_sequence_mixer_forward(x, ps, n_heads, monarch_ms_layer)

Sequence mixing for one block: local causal depthwise conv plus multi-head
global Monarch mixing, combined additively and passed through a learned
sigmoid gate. `x` is (D, T, B); `monarch_ms_layer` holds one pre-realized,
causally-masked T_max×T_max matrix per head.
"""
function monarch_sequence_mixer_forward(x, ps, n_heads::Int, monarch_ms_layer)
    D, T, B = size(x)
    H = n_heads
    HD = D ÷ H
    # 1. Causal depthwise conv for local context
    conv_out = causal_depthwise_conv1d(x, ps.conv.kernel)
    # 2. Multi-head Monarch mixing (pre-realized matrices)
    monarch_slices = map(1:H) do i
        # View into the cached matrix at the actual sequence length
        # (avoids a T×T copy on every generation step).
        M_t = view(monarch_ms_layer[i], 1:T, 1:T)
        # Extract this head's channel slice: (HD, T, B)
        ch_start = (i - 1) * HD + 1
        ch_end = i * HD
        x_slice = view(x, ch_start:ch_end, :, :)
        # Matmul: (T, T) × (T, HD*B) → (T, HD*B)
        x_flat = reshape(permutedims(x_slice, (2, 1, 3)), T, HD * B)
        y_flat = M_t * x_flat
        # Reshape back: (T, HD*B) → (T, HD, B) → (HD, T, B)
        permutedims(reshape(y_flat, T, HD, B), (2, 1, 3))
    end
    # Concatenate heads along channel dimension
    monarch_out = cat(monarch_slices...; dims=1)
    # 3. Combine conv (local) + Monarch (global), then gate
    combined = conv_out .+ monarch_out
    gate = NNlib.sigmoid_fast.(ps.gate.weight)
    gated = gate .* combined
    return gated
end

# ═══════════════════════════════════════════════════════════════════
# Full model forward pass (uses cached data)
# ═══════════════════════════════════════════════════════════════════

"""
    model_forward(config, ps, x, caches) -> Array

Full forward pass. `x` is a (seq_len, batch) matrix of 1-based token IDs;
returns (vocab, seq_len, batch) logits. `caches` comes from
`precompute_inference_caches`.
"""
function model_forward(config::ModelConfig, ps, x, caches)
    # Token embedding: (seq_len, batch) → (embed_dim, seq_len, batch)
    h = ps.tok_emb.weight[:, x]
    # Monarch blocks
    for i in 1:config.n_layers
        name = Symbol("block_$i")
        bp = getproperty(ps.blocks, name)
        # Pre-norm sequence mixing + residual
        normed = rmsnorm_forward(h, bp.ln1.weight)
        mixed = monarch_sequence_mixer_forward(normed, bp.seq_mixer,
                                               config.n_monarch_heads,
                                               caches.monarch_ms[i])
        h = h .+ mixed
        # Pre-norm FFN + residual
        normed2 = rmsnorm_forward(h, bp.ln2.weight)
        ffn_out = swiglu_forward(normed2, bp.ffn)
        h = h .+ ffn_out
    end
    # Final norm
    h = rmsnorm_forward(h, ps.ln_f.weight)
    # Output projection: weight-tied (embedding transpose) or separate head
    D, T_out, B = size(h)
    h_flat = reshape(h, D, T_out * B)
    if config.weight_tying
        logits = ps.tok_emb.weight' * h_flat
    else
        logits = ps.head.weight * h_flat
    end
    return reshape(logits, :, T_out, B)
end

# ═══════════════════════════════════════════════════════════════════
# Sampling helpers
# ═══════════════════════════════════════════════════════════════════
"""
    top_k_filter(logits, k) -> Vector

Keep the `k` largest logits and set the rest to `typemin(eltype)` so softmax
assigns them zero probability. Ties at the threshold are all kept, so more
than `k` entries may survive.
"""
function top_k_filter(logits::AbstractVector, k::Int)
    k = min(k, length(logits))
    sorted = sort(Array(logits); rev=true)
    threshold = sorted[k]
    return map(l -> l >= threshold ? l : typemin(eltype(logits)), logits)
end

"""
    top_p_filter(logits, p) -> Vector

Nucleus sampling filter: keep the smallest set of highest-probability logits
whose cumulative softmax mass reaches `p`; mask the rest to `typemin`.
"""
function top_p_filter(logits::AbstractVector, p::Float64)
    sorted_indices = sortperm(Array(logits); rev=true)
    sorted_logits = logits[sorted_indices]
    probs = NNlib.softmax(sorted_logits)
    cumprobs = cumsum(Array(probs))
    # First index where cumulative mass reaches p; keep everything if p is
    # never reached (numerical edge case).
    cutoff = something(findfirst(>=(p), cumprobs), length(probs))
    result = fill(typemin(eltype(logits)), length(logits))
    for i in 1:cutoff
        result[sorted_indices[i]] = logits[sorted_indices[i]]
    end
    return result
end

"""
    sample_categorical(probs) -> Int

Draw one index from a probability vector by inverse-CDF sampling. Falls back
to the last index if rounding leaves the cumulative sum slightly below the
drawn value.
"""
function sample_categorical(probs::AbstractVector)
    r = rand(Float32)
    cumulative = 0.0f0
    for i in eachindex(probs)
        cumulative += probs[i]
        r <= cumulative && return i
    end
    return length(probs)
end

# ═══════════════════════════════════════════════════════════════════
# Text generation with streaming callback
# ═══════════════════════════════════════════════════════════════════

"""
    generate_streaming(config, ps, tokenizer, prompt; kwargs...) -> String

Autoregressively generate up to `max_tokens` tokens from `prompt`, invoking
`on_token(str)` (if given) as each token is decoded. Returns the generated
text (prompt excluded).

Keywords: `temperature` (≤ 0 means greedy argmax decoding), `top_k` (0 = off),
`top_p` (1.0 = off), `caches` (precomputed inference caches; computed once
here if not supplied).
"""
function generate_streaming(config::ModelConfig, ps, tokenizer::Tokenizer,
                            prompt::String;
                            max_tokens::Int=200,
                            temperature::Float64=0.8,
                            top_k::Int=0,
                            top_p::Float64=1.0,
                            on_token=nothing,
                            caches=nothing)
    tokens = encode(tokenizer, prompt)
    if isempty(tokens)
        # Empty/unencodable prompt: seed with one random token.
        tokens = [rand(1:tokenizer_vocab_size(tokenizer))]
    end
    # Use provided caches or compute them once
    if caches === nothing
        caches = precompute_inference_caches(config, ps)
    end
    generated = String[]
    for _ in 1:max_tokens
        # Crop context to the model's window. The `copy` matters: `reshape`
        # shares memory, and a later `push!` on an aliased vector would throw.
        ctx = if length(tokens) > config.context_length
            tokens[end-config.context_length+1:end]
        else
            copy(tokens)
        end
        x = reshape(ctx, :, 1)
        logits = model_forward(config, ps, x, caches)
        next_logits = Vector{Float32}(logits[:, end, 1])
        local next_token
        if temperature <= 0.0
            # Greedy decoding. (Dividing by temperature 0 produced Inf/NaN
            # logits, NaN softmax, and a degenerate constant-token sampler.)
            next_token = argmax(next_logits)
        else
            if temperature != 1.0
                next_logits ./= Float32(temperature)
            end
            if top_k > 0
                next_logits = top_k_filter(next_logits, top_k)
            end
            if top_p < 1.0
                next_logits = top_p_filter(next_logits, top_p)
            end
            probs = NNlib.softmax(next_logits)
            next_token = sample_categorical(probs)
        end
        push!(tokens, next_token)
        token_str = decode(tokenizer, [next_token])
        push!(generated, token_str)
        on_token !== nothing && on_token(token_str)
    end
    return join(generated)
end