# NOTE(review): removed non-Julia page-capture residue ("Spaces: / Running / Running") from the top of the file.
#=
model.jl — Self-contained inference engine for MonarchSLM

Implements the Monarch Mixer architecture (sub-quadratic sequence mixing using
structured matrices) with RMSNorm, SwiGLU, and weight-tied output.
No Lux dependency — parameters loaded directly from JLD2. CPU-only inference.

Architecture per block:
    MonarchSequenceMixer (8 heads × MonarchMatrix + CausalDepthwiseConv + LearnedGate)
        → RMSNorm pre-norm + residual
    SwiGLU FFN → RMSNorm pre-norm + residual

References:
    Monarch Mixer (Dao et al., 2023): Sub-quadratic GEMM-based architecture
=#
using NNlib
using NNlib: batched_mul
using Statistics
using Random
using JSON3
using TOML
# ═══════════════════════════════════════════════════════════════════
# Model configuration
# ═══════════════════════════════════════════════════════════════════
"""
    ModelConfig

Static hyperparameters for the MonarchSLM inference engine, typically read
from the `[model]` table of a TOML file via `load_config_toml`.
"""
struct ModelConfig
    arch::String             # architecture identifier (e.g. "monarch")
    embed_dim::Int           # embedding / hidden dimension D
    n_layers::Int            # number of Monarch mixer blocks
    n_monarch_heads::Int     # heads per sequence mixer; embed_dim should divide evenly — TODO confirm
    conv_kernel_size::Int    # K for the causal depthwise conv
    context_length::Int      # max sequence length T; presumably a perfect square (isqrt is taken in precompute_inference_caches) — verify
    vocab_size::Int          # supplied by the tokenizer, not the TOML file
    weight_tying::Bool       # reuse tok_emb.weight' as output projection (see model_forward)
    bias::Bool               # NOTE(review): stored but not read anywhere in this file
end
"""
    load_config_toml(path; vocab_size=0) -> ModelConfig

Read the `[model]` table from the TOML file at `path` and build a
`ModelConfig`. Missing keys fall back to defaults; `vocab_size` comes from
the caller (it is determined by the tokenizer, not the config file).
"""
function load_config_toml(path::String; vocab_size::Int=0)
    section = get(TOML.parsefile(path), "model", Dict())
    value(key, default) = get(section, key, default)
    return ModelConfig(
        value("arch", "monarch"),
        value("embed_dim", 256),
        value("n_layers", 8),
        value("n_monarch_heads", 8),
        value("conv_kernel_size", 4),
        value("context_length", 256),
        vocab_size,
        value("weight_tying", true),
        value("bias", false),
    )
end
# ═══════════════════════════════════════════════════════════════════
# Character-level tokenizer
# ═══════════════════════════════════════════════════════════════════
"""
    CharTokenizer

Character-level tokenizer: a bijection between characters and 1-based ids.
"""
struct CharTokenizer
    char_to_idx::Dict{Char, Int}   # char → 1-based token id
    idx_to_char::Vector{Char}      # token id → char (inverse mapping)
    vocab_size::Int                # == length(idx_to_char)
end
"""
    load_char_vocab_json(path) -> CharTokenizer

Build a `CharTokenizer` from a JSON array of single-character strings.
`only` throws if any entry is not exactly one character.
"""
function load_char_vocab_json(path::String)
    entries = JSON3.read(read(path, String))
    alphabet = Char[only(String(e)) for e in entries]
    lookup = Dict{Char, Int}()
    for (i, c) in enumerate(alphabet)
        lookup[c] = i
    end
    return CharTokenizer(lookup, alphabet, length(alphabet))
end
"""
    encode(t::CharTokenizer, text) -> Vector{Int}

Map `text` to 1-based token ids; characters outside the vocabulary are
silently dropped.
"""
function encode(t::CharTokenizer, text::String)
    ids = Int[]
    sizehint!(ids, length(text))
    for ch in text
        haskey(t.char_to_idx, ch) && push!(ids, t.char_to_idx[ch])
    end
    return ids
end
"""
    decode(t::CharTokenizer, indices) -> String

Map token ids back to text; ids outside `1:vocab_size` are silently skipped.
"""
function decode(t::CharTokenizer, indices::AbstractVector{<:Integer})
    io = IOBuffer()
    for i in indices
        in(i, 1:t.vocab_size) && write(io, t.idx_to_char[i])
    end
    return String(take!(io))
end
# ═══════════════════════════════════════════════════════════════════
# BPE Tokenizer (GPT-2 style)
# ═══════════════════════════════════════════════════════════════════
"""
    BPETokenizer

GPT-2 style byte-pair-encoding tokenizer. Token ids are stored 0-based in
`encoder`/`decoder` (as in the GPT-2 vocab files); `encode`/`decode` shift
by +1/−1 to use Julia's 1-based ids externally.
"""
struct BPETokenizer
    encoder::Dict{String, Int}                       # BPE token string → 0-based id
    decoder::Dict{Int, String}                       # inverse of encoder
    merges::Vector{Tuple{String, String}}            # merge rules in priority order
    merge_ranks::Dict{Tuple{String, String}, Int}    # pair → rank (lower merges first)
    byte_to_unicode::Dict{UInt8, Char}               # GPT-2 byte → printable-char table
    unicode_to_byte::Dict{Char, UInt8}               # inverse of byte_to_unicode
    vocab_size::Int                                  # == length(encoder)
    pat::Regex                                       # GPT-2 pre-tokenization pattern
end
"""
    load_bpe_tokenizer(vocab_path, merges_path) -> BPETokenizer

Load a GPT-2 style BPE tokenizer from `vocab.json` (token string → 0-based
id) and `merges.txt` (one merge rule per line, optional `#version` header).
Malformed merge lines (not exactly two fields) are skipped.
"""
function load_bpe_tokenizer(vocab_path::String, merges_path::String)
    encoder = JSON3.read(read(vocab_path, String), Dict{String, Int})
    decoder = Dict{Int, String}(v => k for (k, v) in encoder)
    merge_lines = readlines(merges_path)
    # Skip the "#version: ..." header when present. Guard the empty-file
    # case: `first` on an empty vector would throw.
    start = (!isempty(merge_lines) && startswith(first(merge_lines), "#")) ? 2 : 1
    merges = Tuple{String, String}[]
    for line in merge_lines[start:end]
        parts = split(strip(line))
        length(parts) == 2 && push!(merges, (String(parts[1]), String(parts[2])))
    end
    merge_ranks = Dict{Tuple{String, String}, Int}(m => i for (i, m) in enumerate(merges))
    b2u = _build_byte_to_unicode()
    u2b = Dict{Char, UInt8}(v => k for (k, v) in b2u)
    # GPT-2 pre-tokenization: contractions, letter runs, digit runs, symbol
    # runs (each optionally space-prefixed), then whitespace.
    pat = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
    return BPETokenizer(encoder, decoder, merges, merge_ranks, b2u, u2b,
                        length(encoder), pat)
end
"""
    encode(t::BPETokenizer, text) -> Vector{Int}

BPE-encode `text`: pre-tokenize with the GPT-2 regex, map raw bytes to
their printable stand-in characters, run greedy merges, and look up ids.
Returned ids are 1-based (vocab ids shifted by +1); unknown tokens are
silently dropped.
"""
function encode(t::BPETokenizer, text::String)
    ids = Int[]
    for piece in eachmatch(t.pat, text)
        # One symbol per raw byte of the matched piece, via the byte table.
        symbols = String[string(t.byte_to_unicode[b]) for b in Vector{UInt8}(piece.match)]
        for sym in _bpe_encode_word(symbols, t.merge_ranks)
            if haskey(t.encoder, sym)
                push!(ids, t.encoder[sym] + 1)   # shift 0-based vocab id → 1-based
            end
        end
    end
    return ids
end
"""
    decode(t::BPETokenizer, ids) -> String

Invert `encode`: map 1-based ids back to token strings (unknown ids become
""), then map each stand-in character back to its raw byte and reassemble
the UTF-8 string.
"""
function decode(t::BPETokenizer, ids::AbstractVector{<:Integer})
    text = join(get(t.decoder, id - 1, "") for id in ids)
    bytes = UInt8[]
    sizehint!(bytes, length(text))
    for ch in text
        if haskey(t.unicode_to_byte, ch)
            push!(bytes, t.unicode_to_byte[ch])
        else
            # Character outside the byte table: keep its own UTF-8 bytes.
            append!(bytes, codeunits(string(ch)))
        end
    end
    return String(bytes)
end
"""
    _bpe_encode_word(symbols, merge_ranks) -> Vector{String}

Greedy BPE: repeatedly merge the adjacent pair with the lowest rank until
no adjacent pair appears in `merge_ranks`. Returns the final symbols.
"""
function _bpe_encode_word(symbols::Vector{String}, merge_ranks::Dict{Tuple{String, String}, Int})
    while length(symbols) > 1
        # Locate the lowest-ranked adjacent pair, if any exists.
        winner = nothing
        winner_rank = typemax(Int)
        for i in 1:length(symbols)-1
            candidate = (symbols[i], symbols[i+1])
            r = get(merge_ranks, candidate, typemax(Int))
            if r < winner_rank
                winner_rank = r
                winner = candidate
            end
        end
        winner_rank == typemax(Int) && break
        # Merge every non-overlapping occurrence, scanning left to right.
        merged = String[]
        i = 1
        while i <= length(symbols)
            if i < length(symbols) && (symbols[i], symbols[i+1]) == winner
                push!(merged, winner[1] * winner[2])
                i += 2
            else
                push!(merged, symbols[i])
                i += 1
            end
        end
        symbols = merged
    end
    return symbols
end
"""
    _build_byte_to_unicode() -> Dict{UInt8, Char}

GPT-2 byte → unicode table: printable Latin-1 bytes map to themselves,
while the remaining bytes are assigned codepoints 256, 257, … in ascending
byte order, so every byte gets a printable stand-in character.
"""
function _build_byte_to_unicode()
    printable = vcat(UInt8('!'):UInt8('~'), UInt8('¡'):UInt8('¬'), UInt8('®'):UInt8('ÿ'))
    table = Dict{UInt8, Char}(b => Char(b) for b in printable)
    next_cp = 256
    for b in 0x00:0xff
        if !haskey(table, b)
            table[b] = Char(next_cp)
            next_cp += 1
        end
    end
    return table
end
# ═══════════════════════════════════════════════════════════════════
# Unified tokenizer interface
# ═══════════════════════════════════════════════════════════════════
# Either tokenizer flavor; both expose a `vocab_size` field.
const Tokenizer = Union{CharTokenizer, BPETokenizer}

"""
    tokenizer_vocab_size(t::Tokenizer) -> Int

Number of tokens in the tokenizer's vocabulary.
"""
tokenizer_vocab_size(t::Tokenizer) = t.vocab_size
# ═══════════════════════════════════════════════════════════════════
# Causal mask (multiplicative 0/1 for Monarch)
# ═══════════════════════════════════════════════════════════════════
"""
    make_causal_mask(seq_len) -> Matrix{Float32}

Lower-triangular 0/1 mask: entry (i, j) is 1 when source position j is at
or before target position i. Applied multiplicatively to realized Monarch
matrices so each position only mixes information from its past.
"""
function make_causal_mask(seq_len::Int)
    mask = zeros(Float32, seq_len, seq_len)
    for i in 1:seq_len, j in 1:i
        mask[i, j] = 1.0f0
    end
    return mask
end
# ═══════════════════════════════════════════════════════════════════
# Layer primitives (shared with transformer)
# ═══════════════════════════════════════════════════════════════════
"""
    rmsnorm_forward(x, weight; eps=1.0f-6) -> Array

RMS normalization over the first (feature) dimension, with a learned
elementwise scale `weight`. `eps` guards against division by zero.
"""
function rmsnorm_forward(x, weight; eps=1.0f-6)
    denom = sqrt.(mean(abs2, x; dims=1) .+ eps)
    return weight .* (x ./ denom)
end
"""
    swiglu_forward(x, ps) -> Array

SwiGLU feed-forward: `w2 * (swish(w1*x) .* (v*x))`, applied position-wise.
`x` is (D, ...); trailing dims are flattened for the matmuls and restored.
Assumes `ps.w2` projects back to D rows — TODO confirm against checkpoint.
"""
function swiglu_forward(x, ps)
    dim = size(x, 1)
    flat = reshape(x, dim, :)
    activated = NNlib.swish.(ps.w1 * flat) .* (ps.v * flat)
    projected = ps.w2 * activated
    return reshape(projected, dim, size(x)[2:end]...)
end
# ═══════════════════════════════════════════════════════════════════
# Monarch Matrix realization
# ═══════════════════════════════════════════════════════════════════
| """ | |
| monarch_realize(L1, L2, p) -> Matrix{Float32} | |
| Materialize T×T Monarch matrix: M = Pᵀ · BlockDiag(L1) · P · BlockDiag(L2) | |
| where L1, L2 are (p, p, p) block-diagonal factors and T = p². | |
| """ | |
| function monarch_realize(L1, L2, p::Int) | |
| T = p * p | |
| # Start with identity matrix | |
| I_T = Float32[i == j ? 1.0f0 : 0.0f0 for i in 1:T, j in 1:T] | |
| # Reshape columns: (T, T) → (p, p, T) | |
| x = reshape(I_T, p, p, T) | |
| # Apply L2 block-diagonal: for each block k, multiply L2[:,:,k] @ x[:,k,:] | |
| x = permutedims(x, (1, 3, 2)) # (p, T, p) | |
| x = batched_mul(L2, x) # (p, p, p) × (p, T, p) → (p, T, p) | |
| x = permutedims(x, (1, 3, 2)) # (p, p, T) | |
| # Permutation P: transpose the p×p grid | |
| x = permutedims(x, (2, 1, 3)) | |
| # Apply L1 block-diagonal | |
| x = permutedims(x, (1, 3, 2)) # (p, T, p) | |
| x = batched_mul(L1, x) # (p, T, p) | |
| x = permutedims(x, (1, 3, 2)) # (p, p, T) | |
| # Undo permutation | |
| x = permutedims(x, (2, 1, 3)) | |
| return reshape(x, T, T) | |
| end | |
# ═══════════════════════════════════════════════════════════════════
# Causal Depthwise Conv1d
# ═══════════════════════════════════════════════════════════════════
| """ | |
| causal_depthwise_conv1d(x, kernel) -> Array | |
| x: (D, T, B), kernel: (K, D) | |
| Causal convolution: pad K-1 zeros on the left, sum over kernel taps. | |
| """ | |
| function causal_depthwise_conv1d(x, kernel) | |
| D, T, B = size(x) | |
| K = size(kernel, 1) | |
| # Causal pad: K-1 zeros on the left | |
| pad = zeros(Float32, D, K - 1, B) | |
| x_padded = cat(pad, x; dims=2) # (D, T+K-1, B) | |
| # Sum over kernel taps | |
| out = sum(1:K) do k | |
| reshape(kernel[k:k, :], D, 1, 1) .* x_padded[:, k:k+T-1, :] | |
| end | |
| return out | |
| end | |
# ═══════════════════════════════════════════════════════════════════
# Pre-compute inference caches (Monarch matrices + causal mask)
# ═══════════════════════════════════════════════════════════════════
| """ | |
| precompute_inference_caches(config, ps) -> NamedTuple | |
| Pre-realize all Monarch matrices and apply causal mask once at startup. | |
| Avoids recomputing them on every forward pass during generation. | |
| """ | |
| function precompute_inference_caches(config::ModelConfig, ps) | |
| p = isqrt(config.context_length) | |
| mask = make_causal_mask(config.context_length) | |
| # Pre-realize all Monarch matrices: monarch_ms[layer][head] = masked T×T matrix | |
| monarch_ms = Vector{Vector{Matrix{Float32}}}(undef, config.n_layers) | |
| for i in 1:config.n_layers | |
| name = Symbol("block_$i") | |
| bp = getproperty(ps.blocks, name) | |
| layer_ms = Vector{Matrix{Float32}}(undef, config.n_monarch_heads) | |
| for j in 1:config.n_monarch_heads | |
| head_name = Symbol("head_$j") | |
| ps_m = getproperty(bp.seq_mixer.monarchs, head_name) | |
| M = monarch_realize(ps_m.L1, ps_m.L2, p) .* mask | |
| layer_ms[j] = M | |
| end | |
| monarch_ms[i] = layer_ms | |
| end | |
| return (; mask, monarch_ms) | |
| end | |
# ═══════════════════════════════════════════════════════════════════
# Monarch Sequence Mixer forward pass (uses cached matrices)
# ═══════════════════════════════════════════════════════════════════
"""
    monarch_sequence_mixer_forward(x, ps, n_heads, monarch_ms_layer) -> Array

Monarch sequence mixer for one block. `x` is (D, T, B); returns the same
shape. Combines a causal depthwise conv (local context) with per-head
Monarch matmuls over the time axis (global context), then applies a learned
elementwise sigmoid gate.

`monarch_ms_layer[i]` is the pre-realized, causally-masked T_max×T_max
matrix for head `i` (see `precompute_inference_caches`).

Throws `ArgumentError` if `D` is not divisible by `n_heads` (the old code
silently dropped the trailing channels in that case).
"""
function monarch_sequence_mixer_forward(x, ps, n_heads::Int, monarch_ms_layer)
    D, T, B = size(x)
    H = n_heads
    # Guard: with D % H != 0 the per-head slices would cover only H*(D÷H)
    # channels and the residual broadcast downstream would fail obscurely.
    D % H == 0 || throw(ArgumentError("embed dim $D not divisible by n_heads=$H"))
    HD = D ÷ H
    # 1. Causal depthwise conv for local context
    conv_out = causal_depthwise_conv1d(x, ps.conv.kernel)
    # 2. Multi-head Monarch mixing (pre-realized matrices)
    monarch_slices = map(1:H) do i
        # Slice cached matrix to the actual sequence length; the causal mask
        # makes the top-left T×T corner exactly the length-T operator.
        M_t = monarch_ms_layer[i][1:T, 1:T]
        # This head's channel slice: (HD, T, B)
        x_slice = x[(i-1)*HD+1:i*HD, :, :]
        # Matmul over time: (T, T) × (T, HD*B) → (T, HD*B)
        x_flat = reshape(permutedims(x_slice, (2, 1, 3)), T, HD * B)
        y_flat = M_t * x_flat
        # Reshape back: (T, HD*B) → (T, HD, B) → (HD, T, B)
        permutedims(reshape(y_flat, T, HD, B), (2, 1, 3))
    end
    # Concatenate heads along the channel dimension (reduce(vcat, ...) has a
    # specialized method and avoids splatting a length-H argument list).
    monarch_out = reduce(vcat, monarch_slices)
    # 3. Combine conv (local) + Monarch (global), then gate
    gate = NNlib.sigmoid_fast.(ps.gate.weight)
    return gate .* (conv_out .+ monarch_out)
end
# ═══════════════════════════════════════════════════════════════════
# Full model forward pass (uses cached data)
# ═══════════════════════════════════════════════════════════════════
"""
    model_forward(config, ps, x, caches) -> Array

Full forward pass. `x` is a (seq_len, batch) matrix of 1-based token ids;
returns logits of shape (vocab, seq_len, batch). `caches` is the result of
`precompute_inference_caches` (pre-realized, masked Monarch matrices).
"""
function model_forward(config::ModelConfig, ps, x, caches)
    # (Removed an unused local `T = size(x, 1)` — seq length is re-derived
    # from the hidden state below.)
    # Token embedding: (seq_len, batch) → (embed_dim, seq_len, batch)
    h = ps.tok_emb.weight[:, x]
    # Monarch blocks: pre-norm mixer + residual, then pre-norm FFN + residual
    for i in 1:config.n_layers
        bp = getproperty(ps.blocks, Symbol("block_$i"))
        normed = rmsnorm_forward(h, bp.ln1.weight)
        mixed = monarch_sequence_mixer_forward(normed, bp.seq_mixer,
                                               config.n_monarch_heads,
                                               caches.monarch_ms[i])
        h = h .+ mixed
        normed2 = rmsnorm_forward(h, bp.ln2.weight)
        h = h .+ swiglu_forward(normed2, bp.ffn)
    end
    # Final norm
    h = rmsnorm_forward(h, ps.ln_f.weight)
    # Output projection: weight-tied (reuse embedding transposed) or separate head
    D, T_out, B = size(h)
    h_flat = reshape(h, D, T_out * B)
    logits = config.weight_tying ? ps.tok_emb.weight' * h_flat :
                                   ps.head.weight * h_flat
    return reshape(logits, :, T_out, B)
end
# ═══════════════════════════════════════════════════════════════════
# Sampling helpers
# ═══════════════════════════════════════════════════════════════════
"""
    top_k_filter(logits, k) -> Vector

Keep the `k` largest logits and set the rest to `typemin(eltype)` (−Inf for
floats) so softmax gives them zero probability. `k` is clamped to the
vector length; values tied with the k-th largest are all kept.
"""
function top_k_filter(logits::AbstractVector, k::Int)
    k = min(k, length(logits))
    # partialsort finds the k-th largest in O(n), vs. the O(n log n) full
    # sort this previously used.
    threshold = partialsort(Array(logits), k; rev=true)
    return map(l -> l >= threshold ? l : typemin(eltype(logits)), logits)
end
"""
    top_p_filter(logits, p) -> Vector

Nucleus-sampling filter: keep the smallest set of highest-probability
tokens whose cumulative softmax mass reaches `p`; every other entry becomes
`typemin(eltype)`.
"""
function top_p_filter(logits::AbstractVector, p::Float64)
    order = sortperm(Array(logits); rev=true)
    mass = cumsum(Array(NNlib.softmax(logits[order])))
    keep = something(findfirst(>=(p), mass), length(mass))
    filtered = fill(typemin(eltype(logits)), length(logits))
    for rank in 1:keep
        idx = order[rank]
        filtered[idx] = logits[idx]
    end
    return filtered
end
"""
    sample_categorical(probs) -> Int

Draw an index from the categorical distribution `probs` (assumed to sum to
~1) by inverse-CDF sampling with the global RNG. Falls back to the last
index if floating-point rounding leaves the CDF short of the drawn value.
"""
function sample_categorical(probs::AbstractVector)
    u = rand(Float32)
    acc = 0.0f0
    for i in eachindex(probs)
        acc += probs[i]
        u <= acc && return i
    end
    return length(probs)
end
# ═══════════════════════════════════════════════════════════════════
# Text generation with streaming callback
# ═══════════════════════════════════════════════════════════════════
"""
    generate_streaming(config, ps, tokenizer, prompt;
                       max_tokens=200, temperature=0.8, top_k=0, top_p=1.0,
                       on_token=nothing, caches=nothing) -> String

Autoregressively sample `max_tokens` tokens after `prompt` and return the
generated text (prompt excluded). Each decoded token string is also passed
to `on_token` as it is produced, enabling streaming output.

- `temperature` rescales logits before softmax (NOTE(review): 0 would
  divide by zero — callers must pass a positive value).
- `top_k > 0` enables top-k filtering; `top_p < 1.0` enables nucleus filtering.
- `caches`: pass the result of `precompute_inference_caches` to avoid
  re-realizing Monarch matrices on every call.
"""
function generate_streaming(config::ModelConfig, ps,
                            tokenizer::Tokenizer, prompt::String;
                            max_tokens::Int=200,
                            temperature::Float64=0.8,
                            top_k::Int=0,
                            top_p::Float64=1.0,
                            on_token=nothing,
                            caches=nothing)
    tokens = encode(tokenizer, prompt)
    if isempty(tokens)
        # Empty or unencodable prompt: seed with one random token so the
        # model still has an input sequence.
        tokens = [rand(1:tokenizer_vocab_size(tokenizer))]
    end
    # Use provided caches or compute them once
    if caches === nothing
        caches = precompute_inference_caches(config, ps)
    end
    generated = String[]
    for _ in 1:max_tokens
        # Crop the context to the trailing `context_length` tokens.
        ctx = if length(tokens) > config.context_length
            tokens[end-config.context_length+1:end]
        else
            copy(tokens)
        end
        x = reshape(ctx, :, 1)  # (seq_len, batch=1)
        logits = model_forward(config, ps, x, caches)
        # Only the logits at the final position matter for the next token.
        next_logits = Vector{Float32}(logits[:, end, 1])
        if temperature != 1.0
            next_logits ./= Float32(temperature)
        end
        if top_k > 0
            next_logits = top_k_filter(next_logits, top_k)
        end
        if top_p < 1.0
            next_logits = top_p_filter(next_logits, top_p)
        end
        probs = NNlib.softmax(next_logits)
        next_token = sample_categorical(probs)
        push!(tokens, next_token)
        token_str = decode(tokenizer, [next_token])
        push!(generated, token_str)
        # Stream the decoded piece to the caller as soon as it exists.
        on_token !== nothing && on_token(token_str)
    end
    return join(generated)
end